Add Task.get_requirements() returning the task’s requirements

This commit is contained in:
allegroai 2024-03-28 15:07:26 +02:00
parent 7ab2197ec0
commit 49e9e7370b
3 changed files with 29 additions and 2 deletions

View File

@ -91,7 +91,7 @@ from .utilities.config import verify_basic_value
from .binding.args import ( from .binding.args import (
argparser_parseargs_called, get_argparser_last_args, argparser_parseargs_called, get_argparser_last_args,
argparser_update_currenttask, ) argparser_update_currenttask, )
from .utilities.dicts import ReadOnlyDict, merge_dicts from .utilities.dicts import ReadOnlyDict, merge_dicts, RequirementsDict
from .utilities.proxy_object import ( from .utilities.proxy_object import (
ProxyDictPreWrite, ProxyDictPostWrite, flatten_dictionary, ProxyDictPreWrite, ProxyDictPostWrite, flatten_dictionary,
nested_from_flat_dictionary, naive_nested_from_flat_dictionary, StubObject as _TaskStub) nested_from_flat_dictionary, naive_nested_from_flat_dictionary, StubObject as _TaskStub)
@ -1641,6 +1641,19 @@ class Task(_Task):
self.data.script.version_num = commit or "" self.data.script.version_num = commit or ""
self._edit(script=self.data.script) self._edit(script=self.data.script)
def get_requirements(self):
# type: () -> RequirementsDict
"""
Get the task's requirements
:return: A `RequirementsDict` object that holds the `pip`, `conda`, `orig_pip` requirements.
"""
if not running_remotely() and self.is_main_task():
self._wait_for_repo_detection(timeout=300.)
requirements_dict = RequirementsDict()
requirements_dict.update(self.data.script.requirements)
return requirements_dict
def connect_configuration(self, configuration, name=None, description=None, ignore_remote_overrides=False): def connect_configuration(self, configuration, name=None, description=None, ignore_remote_overrides=False):
# type: (Union[Mapping, list, Path, str], Optional[str], Optional[str], bool) -> Union[dict, Path, str] # type: (Union[Mapping, list, Path, str], Optional[str], Optional[str], bool) -> Union[dict, Path, str]
""" """

View File

@ -115,6 +115,20 @@ class NestedBlobsDict(BlobsDict):
return self._keys(self, '') return self._keys(self, '')
class RequirementsDict(dict):
@property
def pip(self):
return self.get("pip")
@property
def conda(self):
return self.get("conda")
@property
def orig_pip(self):
return self.get("orig_pip")
def merge_dicts(dict1, dict2): def merge_dicts(dict1, dict2):
""" Recursively merges dict2 into dict1 """ """ Recursively merges dict2 into dict1 """
if not isinstance(dict1, dict) or not isinstance(dict2, dict): if not isinstance(dict1, dict) or not isinstance(dict2, dict):