Commit 52c47b5551 by revital, 2023-10-25 10:34:38 +03:00
13 changed files with 271 additions and 160 deletions

View File

@ -101,6 +101,7 @@ class PipelineController(object):
explicit_docker_image = attrib(type=str, default=None) # The Docker image the node uses, specified at creation
recursively_parse_parameters = attrib(type=bool, default=False) # if True, recursively parse parameters in
# lists, dicts, or tuples
output_uri = attrib(type=Union[bool, str], default=None) # The default location for output models and other artifacts
def __attrs_post_init__(self):
if self.parents is None:
@ -134,6 +135,26 @@ class PipelineController(object):
new_copy.task_factory_func = self.task_factory_func
return new_copy
def set_job_ended(self):
if self.job_ended:
return
# noinspection PyBroadException
try:
self.job.task.reload()
self.job_ended = self.job_started + self.job.task.data.active_duration
except Exception:
pass
def set_job_started(self):
if self.job_started:
return
# noinspection PyBroadException
try:
self.job_started = self.job.task.data.started.timestamp()
except Exception:
pass
def __init__(
self,
name, # type: str
@ -155,7 +176,8 @@ class PipelineController(object):
repo_commit=None, # type: Optional[str]
always_create_from_code=True, # type: bool
artifact_serialization_function=None, # type: Optional[Callable[[Any], Union[bytes, bytearray]]]
artifact_deserialization_function=None # type: Optional[Callable[[bytes], Any]]
artifact_deserialization_function=None, # type: Optional[Callable[[bytes], Any]]
output_uri=None # type: Optional[Union[str, bool]]
):
# type: (...) -> None
"""
@ -242,6 +264,9 @@ class PipelineController(object):
def deserialize(bytes_):
import dill
return dill.loads(bytes_)
:param output_uri: The storage / output url for this pipeline. This is the default location for output
models and other artifacts. Check Task.init reference docs for more info (output_uri is a parameter).
The `output_uri` of this pipeline's steps will default to this value.
"""
if auto_version_bump is not None:
warnings.warn("PipelineController.auto_version_bump is deprecated. It will be ignored", DeprecationWarning)
@ -316,6 +341,9 @@ class PipelineController(object):
project_id=self._task.project, system_tags=self._project_system_tags)
self._task.set_system_tags((self._task.get_system_tags() or []) + [self._tag])
if output_uri is not None:
self._task.output_uri = output_uri
self._output_uri = output_uri
self._task.set_base_docker(
docker_image=docker, docker_arguments=docker_args, docker_setup_bash_script=docker_bash_setup_script
)
@ -387,7 +415,8 @@ class PipelineController(object):
base_task_factory=None, # type: Optional[Callable[[PipelineController.Node], Task]]
retry_on_failure=None, # type: Optional[Union[int, Callable[[PipelineController, PipelineController.Node, int], bool]]] # noqa
status_change_callback=None, # type: Optional[Callable[[PipelineController, PipelineController.Node, str], None]] # noqa
recursively_parse_parameters=False # type: bool
recursively_parse_parameters=False, # type: bool
output_uri=None # type: Optional[Union[str, bool]]
):
# type: (...) -> bool
"""
@ -530,6 +559,8 @@ class PipelineController(object):
):
pass
:param output_uri: The storage / output url for this step. This is the default location for output
models and other artifacts. Check Task.init reference docs for more info (output_uri is a parameter).
:return: True if successful
"""
@ -588,6 +619,7 @@ class PipelineController(object):
monitor_metrics=monitor_metrics or [],
monitor_artifacts=monitor_artifacts or [],
monitor_models=monitor_models or [],
output_uri=self._output_uri if output_uri is None else output_uri
)
self._retries[name] = 0
self._retries_callbacks[name] = retry_on_failure if callable(retry_on_failure) else \
@ -632,7 +664,8 @@ class PipelineController(object):
cache_executed_step=False, # type: bool
retry_on_failure=None, # type: Optional[Union[int, Callable[[PipelineController, PipelineController.Node, int], bool]]] # noqa
status_change_callback=None, # type: Optional[Callable[[PipelineController, PipelineController.Node, str], None]] # noqa
tags=None # type: Optional[Union[str, Sequence[str]]]
tags=None, # type: Optional[Union[str, Sequence[str]]]
output_uri=None # type: Optional[Union[str, bool]]
):
# type: (...) -> bool
"""
@ -799,6 +832,8 @@ class PipelineController(object):
:param tags: A list of tags for the specific pipeline step.
When executing a Pipeline remotely
(i.e. launching the pipeline from the UI/enqueuing it), this method has no effect.
:param output_uri: The storage / output url for this step. This is the default location for output
models and other artifacts. Check Task.init reference docs for more info (output_uri is a parameter).
:return: True if successful
"""
@ -838,7 +873,8 @@ class PipelineController(object):
cache_executed_step=cache_executed_step,
retry_on_failure=retry_on_failure,
status_change_callback=status_change_callback,
tags=tags
tags=tags,
output_uri=output_uri
)
def start(
@ -1014,8 +1050,8 @@ class PipelineController(object):
return cls._get_pipeline_task().get_logger()
@classmethod
def upload_model(cls, model_name, model_local_path):
# type: (str, str) -> OutputModel
def upload_model(cls, model_name, model_local_path, upload_uri=None):
# type: (str, str, Optional[str]) -> OutputModel
"""
Upload (add) a model to the main Pipeline Task object.
This function can be called from any pipeline component to directly add models into the main pipeline Task
@ -1028,12 +1064,16 @@ class PipelineController(object):
:param model_local_path: Path to the local model file or directory to be uploaded.
If a local directory is provided the content of the folder (recursively) will be
packaged into a zip file and uploaded
:param upload_uri: The URI of the storage destination for model weights upload. The default value
is the previously used URI.
:return: The uploaded OutputModel
"""
task = cls._get_pipeline_task()
model_name = str(model_name)
model_local_path = Path(model_local_path)
out_model = OutputModel(task=task, name=model_name)
out_model.update_weights(weights_filename=model_local_path.as_posix())
out_model.update_weights(weights_filename=model_local_path.as_posix(), upload_uri=upload_uri)
return out_model
@classmethod
@ -1457,7 +1497,7 @@ class PipelineController(object):
self, docker, docker_args, docker_bash_setup_script,
function, function_input_artifacts, function_kwargs, function_return,
auto_connect_frameworks, auto_connect_arg_parser,
packages, project_name, task_name, task_type, repo, branch, commit, helper_functions
packages, project_name, task_name, task_type, repo, branch, commit, helper_functions, output_uri=None
):
task_definition = CreateFromFunction.create_task_from_function(
a_function=function,
@ -1476,7 +1516,7 @@ class PipelineController(object):
docker=docker,
docker_args=docker_args,
docker_bash_setup_script=docker_bash_setup_script,
output_uri=None,
output_uri=output_uri,
helper_functions=helper_functions,
dry_run=True,
task_template_header=self._task_template_header,
@ -1631,6 +1671,7 @@ class PipelineController(object):
self._runtime_property_hash: "{}:{}".format(pipeline_hash, self._version),
"version": self._version
})
self._task.set_user_properties(version=self._version)
else:
self._task.connect_configuration(pipeline_dag, name=self._config_section)
connected_args = set()
@ -1927,7 +1968,8 @@ class PipelineController(object):
cache_executed_step=False, # type: bool
retry_on_failure=None, # type: Optional[Union[int, Callable[[PipelineController, PipelineController.Node, int], bool]]] # noqa
status_change_callback=None, # type: Optional[Callable[[PipelineController, PipelineController.Node, str], None]] # noqa
tags=None # type: Optional[Union[str, Sequence[str]]]
tags=None, # type: Optional[Union[str, Sequence[str]]]
output_uri=None # type: Optional[Union[str, bool]]
):
# type: (...) -> bool
"""
@ -2094,6 +2136,8 @@ class PipelineController(object):
:param tags: A list of tags for the specific pipeline step.
When executing a Pipeline remotely
(i.e. launching the pipeline from the UI/enqueuing it), this method has no effect.
:param output_uri: The storage / output url for this step. This is the default location for output
models and other artifacts. Check Task.init reference docs for more info (output_uri is a parameter).
:return: True if successful
"""
@ -2107,6 +2151,9 @@ class PipelineController(object):
self._verify_node_name(name)
if output_uri is None:
output_uri = self._output_uri
function_input_artifacts = {}
# go over function_kwargs, split it into string and input artifacts
for k, v in function_kwargs.items():
@ -2145,7 +2192,7 @@ class PipelineController(object):
function_input_artifacts, function_kwargs, function_return,
auto_connect_frameworks, auto_connect_arg_parser,
packages, project_name, task_name,
task_type, repo, repo_branch, repo_commit, helper_functions)
task_type, repo, repo_branch, repo_commit, helper_functions, output_uri=output_uri)
elif self._task.running_locally() or self._task.get_configuration_object(name=name) is None:
project_name = project_name or self._get_target_project() or self._task.get_project_name()
@ -2155,7 +2202,7 @@ class PipelineController(object):
function_input_artifacts, function_kwargs, function_return,
auto_connect_frameworks, auto_connect_arg_parser,
packages, project_name, task_name,
task_type, repo, repo_branch, repo_commit, helper_functions)
task_type, repo, repo_branch, repo_commit, helper_functions, output_uri=output_uri)
# update configuration with the task definitions
# noinspection PyProtectedMember
self._task._set_configuration(
@ -2180,6 +2227,9 @@ class PipelineController(object):
if tags:
a_task.add_tags(tags)
if output_uri is not None:
a_task.output_uri = output_uri
return a_task
self._nodes[name] = self.Node(
@ -2195,7 +2245,8 @@ class PipelineController(object):
monitor_metrics=monitor_metrics,
monitor_models=monitor_models,
job_code_section=job_code_section,
explicit_docker_image=docker
explicit_docker_image=docker,
output_uri=output_uri
)
self._retries[name] = 0
self._retries_callbacks[name] = retry_on_failure if callable(retry_on_failure) else \
@ -2284,13 +2335,14 @@ class PipelineController(object):
disable_clone_task=disable_clone_task,
task_overrides=task_overrides,
allow_caching=node.cache_executed_step,
output_uri=node.output_uri,
**extra_args
)
except Exception:
self._pipeline_task_status_failed = True
raise
node.job_started = time()
node.job_started = None
node.job_ended = None
node.job_type = str(node.job.task.task_type)
@ -2546,6 +2598,8 @@ class PipelineController(object):
"""
previous_status = node.status
if node.job and node.job.is_running():
node.set_job_started()
update_job_ended = node.job_started and not node.job_ended
if node.executed is not None:
@ -2582,7 +2636,7 @@ class PipelineController(object):
node.status = "pending"
if update_job_ended and node.status in ("aborted", "failed", "completed"):
node.job_ended = time()
node.set_job_ended()
if (
previous_status is not None
@ -2679,7 +2733,7 @@ class PipelineController(object):
if node_failed and self._abort_running_steps_on_failure and not node.continue_on_fail:
nodes_failed_stop_pipeline.append(node.name)
elif node.timeout:
started = node.job.task.data.started
node.set_job_started()
if (datetime.now().astimezone(started.tzinfo) - started).total_seconds() > node.timeout:
node.job.abort()
completed_jobs.append(j)
@ -3261,7 +3315,8 @@ class PipelineDecorator(PipelineController):
repo_branch=None, # type: Optional[str]
repo_commit=None, # type: Optional[str]
artifact_serialization_function=None, # type: Optional[Callable[[Any], Union[bytes, bytearray]]]
artifact_deserialization_function=None # type: Optional[Callable[[bytes], Any]]
artifact_deserialization_function=None, # type: Optional[Callable[[bytes], Any]]
output_uri=None # type: Optional[Union[str, bool]]
):
# type: (...) -> ()
"""
@ -3341,6 +3396,9 @@ class PipelineDecorator(PipelineController):
def deserialize(bytes_):
import dill
return dill.loads(bytes_)
:param output_uri: The storage / output url for this pipeline. This is the default location for output
models and other artifacts. Check Task.init reference docs for more info (output_uri is a parameter).
The `output_uri` of this pipeline's steps will default to this value.
"""
super(PipelineDecorator, self).__init__(
name=name,
@ -3361,7 +3419,8 @@ class PipelineDecorator(PipelineController):
repo_commit=repo_commit,
always_create_from_code=False,
artifact_serialization_function=artifact_serialization_function,
artifact_deserialization_function=artifact_deserialization_function
artifact_deserialization_function=artifact_deserialization_function,
output_uri=output_uri
)
# if we are in eager execution, make sure parent class knows it
@ -3583,7 +3642,7 @@ class PipelineDecorator(PipelineController):
function, function_input_artifacts, function_kwargs, function_return,
auto_connect_frameworks, auto_connect_arg_parser,
packages, project_name, task_name, task_type, repo, branch, commit,
helper_functions
helper_functions, output_uri=None
):
def sanitize(function_source):
matched = re.match(r"[\s]*@[\w]*.component[\s\\]*\(", function_source)
@ -3621,7 +3680,7 @@ class PipelineDecorator(PipelineController):
docker=docker,
docker_args=docker_args,
docker_bash_setup_script=docker_bash_setup_script,
output_uri=None,
output_uri=output_uri,
helper_functions=helper_functions,
dry_run=True,
task_template_header=self._task_template_header,
@ -3703,7 +3762,8 @@ class PipelineDecorator(PipelineController):
pre_execute_callback=None, # type: Optional[Callable[[PipelineController, PipelineController.Node, dict], bool]] # noqa
post_execute_callback=None, # type: Optional[Callable[[PipelineController, PipelineController.Node], None]] # noqa
status_change_callback=None, # type: Optional[Callable[[PipelineController, PipelineController.Node, str], None]] # noqa
tags=None # type: Optional[Union[str, Sequence[str]]]
tags=None, # type: Optional[Union[str, Sequence[str]]]
output_uri=None # type: Optional[Union[str, bool]]
):
# type: (...) -> Callable
"""
@ -3841,6 +3901,8 @@ class PipelineDecorator(PipelineController):
:param tags: A list of tags for the specific pipeline step.
When executing a Pipeline remotely
(i.e. launching the pipeline from the UI/enqueuing it), this method has no effect.
:param output_uri: The storage / output url for this step. This is the default location for output
models and other artifacts. Check Task.init reference docs for more info (output_uri is a parameter).
:return: function wrapper
"""
@ -3883,7 +3945,8 @@ class PipelineDecorator(PipelineController):
pre_execute_callback=pre_execute_callback,
post_execute_callback=post_execute_callback,
status_change_callback=status_change_callback,
tags=tags
tags=tags,
output_uri=output_uri
)
if cls._singleton:
@ -4109,7 +4172,8 @@ class PipelineDecorator(PipelineController):
repo_branch=None, # type: Optional[str]
repo_commit=None, # type: Optional[str]
artifact_serialization_function=None, # type: Optional[Callable[[Any], Union[bytes, bytearray]]]
artifact_deserialization_function=None # type: Optional[Callable[[bytes], Any]]
artifact_deserialization_function=None, # type: Optional[Callable[[bytes], Any]]
output_uri=None # type: Optional[Union[str, bool]]
):
# type: (...) -> Callable
"""
@ -4220,6 +4284,9 @@ class PipelineDecorator(PipelineController):
def deserialize(bytes_):
import dill
return dill.loads(bytes_)
:param output_uri: The storage / output url for this pipeline. This is the default location for output
models and other artifacts. Check Task.init reference docs for more info (output_uri is a parameter).
The `output_uri` of this pipeline's steps will default to this value.
"""
def decorator_wrap(func):
@ -4265,7 +4332,8 @@ class PipelineDecorator(PipelineController):
repo_branch=repo_branch,
repo_commit=repo_commit,
artifact_serialization_function=artifact_serialization_function,
artifact_deserialization_function=artifact_deserialization_function
artifact_deserialization_function=artifact_deserialization_function,
output_uri=output_uri
)
ret_val = func(**pipeline_kwargs)
LazyEvalWrapper.trigger_all_remote_references()
@ -4316,7 +4384,8 @@ class PipelineDecorator(PipelineController):
repo_branch=repo_branch,
repo_commit=repo_commit,
artifact_serialization_function=artifact_serialization_function,
artifact_deserialization_function=artifact_deserialization_function
artifact_deserialization_function=artifact_deserialization_function,
output_uri=output_uri
)
a_pipeline._args_map = args_map or {}
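A minimal usage sketch of the new `output_uri` keyword threaded through the controller, `add_step`/`add_function_step`, and the decorators above. The project name, bucket URLs, and step function below are illustrative assumptions, not part of this commit:

from clearml import PipelineController

def train_step(n_epochs=1):
    # placeholder step body, stands in for any pipeline component
    return n_epochs

pipe = PipelineController(
    name="example-pipeline",
    project="examples",
    version="1.0.0",
    output_uri="s3://my-bucket/pipelines",  # pipeline-level default added by this commit
)
pipe.add_function_step(
    name="train",
    function=train_step,
    output_uri="s3://my-bucket/models",  # per-step override; when omitted, the pipeline's value is used
)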

View File

@ -522,6 +522,7 @@ class ClearmlJob(BaseJob):
disable_clone_task=False, # type: bool
allow_caching=False, # type: bool
target_project=None, # type: Optional[str]
output_uri=None, # type: Optional[Union[str, bool]]
**kwargs # type: Any
):
# type: (...) -> ()
@ -545,6 +546,8 @@ class ClearmlJob(BaseJob):
If True, use the base_task_id directly (base-task must be in draft-mode / created),
:param bool allow_caching: If True, check if we have a previously executed Task with the same specification.
If we do, use it and set internal is_cached flag. Default False (always create new Task).
:param Union[str, bool] output_uri: The storage / output url for this job. This is the default location for
output models and other artifacts. Check Task.init reference docs for more info (output_uri is a parameter).
:param str target_project: Optional, Set the target project name to create the cloned Task in.
"""
super(ClearmlJob, self).__init__()
@ -660,6 +663,8 @@ class ClearmlJob(BaseJob):
# noinspection PyProtectedMember
self.task._edit(**sections)
if output_uri is not None:
self.task.output_uri = output_uri
self._set_task_cache_hash(self.task, task_hash)
self.task_started = False
self._worker = None
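For completeness, a hedged sketch of the same keyword on `ClearmlJob` (the template task ID, bucket URL, and queue name are placeholders):

from clearml.automation.job import ClearmlJob

job = ClearmlJob(
    base_task_id="<template-task-id>",  # placeholder, not from this commit
    output_uri="s3://my-bucket/jobs",   # new keyword handled by the hunk above
)
job.launch(queue_name="default")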

View File

@ -78,13 +78,15 @@ class InterfaceBase(SessionInterface):
except MaxRequestSizeError as e:
res = CallResult(meta=ResponseMeta.from_raw_data(status_code=400, text=str(e)))
error_msg = 'Failed sending: %s' % str(e)
except requests.exceptions.ConnectionError:
except requests.exceptions.ConnectionError as e:
# We couldn't send the request after more than the number of retries configured in the api configuration file,
# so we will end the loop and raise the exception to the upper level.
# Notice: this is a connectivity error and not a backend error.
if raise_on_errors:
raise
# if raise_on_errors:
# raise
res = None
if log and num_retries >= cls._num_retry_warning_display:
log.warning('Retrying, previous request failed %s: %s' % (str(type(req)), str(e)))
except cls._JSON_EXCEPTION as e:
if log:
log.error(

View File

@ -527,7 +527,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
return False
return bool(self.data.ready)
def download_model_weights(self, raise_on_error=False, force_download=False):
def download_model_weights(self, raise_on_error=False, force_download=False, extract_archive=False):
"""
Download the model weights into a local file in our cache
@ -537,6 +537,8 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
:param bool force_download: If True, the base artifact will be downloaded,
even if the artifact is already cached.
:param bool extract_archive: If True, unzip the downloaded file if possible
:return: a local path to a downloaded copy of the model
"""
uri = self.data.uri
@ -556,7 +558,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
Model._local_model_to_id_uri.pop(dl_file, None)
local_download = StorageManager.get_local_copy(
uri, extract_archive=False, force_download=force_download
uri, extract_archive=extract_archive, force_download=force_download
)
# save local model, so we can later query what was the original one

View File

@ -178,7 +178,7 @@ class CreateAndPopulate(object):
project=Task.get_project_id(self.project_name),
type=str(self.task_type or Task.TaskTypes.training),
) # type: dict
if self.output_uri:
if self.output_uri is not None:
task_state['output'] = dict(destination=self.output_uri)
else:
task_state = dict(script={})
@ -391,7 +391,7 @@ class CreateAndPopulate(object):
return task
def _set_output_uri(self, task):
if self.output_uri:
if self.output_uri is not None:
try:
task.output_uri = self.output_uri
except ValueError:

View File

@ -1344,12 +1344,14 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
params = self.get_parameters(cast=cast)
return params.get(name, default)
def delete_parameter(self, name):
def delete_parameter(self, name, force=False):
# type: (str, bool) -> bool
"""
Delete a parameter by its full name Section/name.
:param name: Parameter name in full, i.e. Section/name. For example, 'Args/batch_size'
:param force: If True, hyperparameters can be deleted from both new and running tasks.
Otherwise, only hyperparameters of a new (draft) task can be deleted. Default is False
:return: True if the parameter was deleted successfully
"""
if not Session.check_min_api_version('2.9'):
@ -1360,7 +1362,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
with self._edit_lock:
paramkey = tasks.ParamKey(section=name.split('/', 1)[0], name=name.split('/', 1)[1])
res = self.send(tasks.DeleteHyperParamsRequest(
task=self.id, hyperparams=[paramkey]), raise_on_errors=False)
task=self.id, hyperparams=[paramkey], force=force), raise_on_errors=False)
self.reload()
return res.ok()
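A short sketch of the new `force` flag (the task ID is a placeholder); the Hydra integration below relies on it to drop the stale overrides section from a task that is already running:

from clearml import Task

task = Task.get_task(task_id="<some-task-id>")  # placeholder id
# force=True allows deleting the hyperparameter even while the task is running
task.delete_parameter("Args/overrides", force=True)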

View File

@ -15,6 +15,8 @@ class PatchHydra(object):
_config_section = 'OmegaConf'
_parameter_section = 'Hydra'
_parameter_allow_full_edit = '_allow_omegaconf_edit_'
_should_delete_overrides = False
_overrides_section = "Args/overrides"
@classmethod
def patch_hydra(cls):
@ -42,6 +44,12 @@ class PatchHydra(object):
except Exception:
return False
@classmethod
def delete_overrides(cls):
if not cls._should_delete_overrides or not cls._current_task:
return
cls._current_task.delete_parameter(cls._overrides_section, force=True)
@staticmethod
def update_current_task(task):
# set current Task before patching
@ -50,11 +58,24 @@ class PatchHydra(object):
return
if PatchHydra.patch_hydra():
# check if we have an untracked state, store it.
if PatchHydra._last_untracked_state.get('connect'):
PatchHydra._current_task.connect(**PatchHydra._last_untracked_state['connect'])
if PatchHydra._last_untracked_state.get('_set_configuration'):
if PatchHydra._last_untracked_state.get("connect"):
if PatchHydra._parameter_allow_full_edit in PatchHydra._last_untracked_state["connect"].get("mutable", {}):
allow_omegaconf_edit_section = PatchHydra._parameter_section + "/" + PatchHydra._parameter_allow_full_edit
allow_omegaconf_edit_section_val = PatchHydra._last_untracked_state["connect"]["mutable"].pop(
PatchHydra._parameter_allow_full_edit
)
PatchHydra._current_task.set_parameter(
allow_omegaconf_edit_section,
allow_omegaconf_edit_section_val,
description="If True, the `{}` parameter section will be completely ignored. The OmegaConf will instead be pulled from the `{}` section".format(
PatchHydra._parameter_section,
PatchHydra._config_section
)
)
PatchHydra._current_task.connect(**PatchHydra._last_untracked_state["connect"])
if PatchHydra._last_untracked_state.get("_set_configuration"):
# noinspection PyProtectedMember
PatchHydra._current_task._set_configuration(**PatchHydra._last_untracked_state['_set_configuration'])
PatchHydra._current_task._set_configuration(**PatchHydra._last_untracked_state["_set_configuration"])
PatchHydra._last_untracked_state = {}
else:
# if patching failed set it to None
@ -63,16 +84,17 @@ class PatchHydra(object):
@staticmethod
def _patched_hydra_run(self, config_name, task_function, overrides, *args, **kwargs):
PatchHydra._allow_omegaconf_edit = False
if not running_remotely():
return PatchHydra._original_hydra_run(self, config_name, task_function, overrides, *args, **kwargs)
# store the config
# get the parameters from the backend
# noinspection PyBroadException
try:
if running_remotely():
if not PatchHydra._current_task:
from ..task import Task
PatchHydra._current_task = Task.get_task(task_id=get_remote_task_id())
# get the _parameter_allow_full_edit casted back to boolean
connected_config = dict()
connected_config = {}
connected_config[PatchHydra._parameter_allow_full_edit] = False
PatchHydra._current_task.connect(connected_config, name=PatchHydra._parameter_section)
PatchHydra._allow_omegaconf_edit = connected_config.pop(PatchHydra._parameter_allow_full_edit, None)
@ -81,18 +103,15 @@ class PatchHydra(object):
stored_config = {k[len(PatchHydra._parameter_section)+1:]: v for k, v in full_parameters.items()
if k.startswith(PatchHydra._parameter_section+'/')}
stored_config.pop(PatchHydra._parameter_allow_full_edit, None)
# noinspection PyBroadException
try:
overrides = yaml.safe_load(full_parameters.get("Args/overrides", "")) or []
except Exception:
overrides = []
if overrides and not isinstance(overrides, (list, tuple)):
overrides = [overrides]
overrides += ['{}={}'.format(k, v) for k, v in stored_config.items()]
overrides = [("+" + o) if (o.startswith("+") and not o.startswith("++")) else o for o in overrides]
for override_k, override_v in stored_config.items():
if override_k.startswith("~"):
new_override = override_k
else:
# We take care of it inside the _patched_run_job
pass
new_override = "++" + override_k.lstrip("+")
if override_v is not None and override_v != "":
new_override += "=" + override_v
overrides.append(new_override)
PatchHydra._should_delete_overrides = True
except Exception:
pass
@ -114,12 +133,18 @@ class PatchHydra(object):
# store the config
# noinspection PyBroadException
try:
if running_remotely():
# we take care of it in the hydra run (where we have access to the overrides)
pass
else:
if not running_remotely():
# note that we fetch the overrides from the backend in hydra run when running remotely,
# here we just get them from hydra to be stored as configuration/parameters
overrides = config.hydra.overrides.task
stored_config = dict(arg.split('=', 1) for arg in overrides)
stored_config = {}
for arg in overrides:
arg = arg.lstrip("+")
if "=" in arg:
k, v = arg.split("=", 1)
stored_config[k] = v
else:
stored_config[arg] = None
stored_config[PatchHydra._parameter_allow_full_edit] = False
if PatchHydra._current_task:
PatchHydra._current_task.connect(stored_config, name=PatchHydra._parameter_section)
@ -127,9 +152,7 @@ class PatchHydra(object):
else:
PatchHydra._last_untracked_state['connect'] = dict(
mutable=stored_config, name=PatchHydra._parameter_section)
# Maybe ?! remove the overrides section from the Args (we have it here)
# But when used with a Pipeline this is the only section we get... so we leave it here anyhow
# PatchHydra._current_task.delete_parameter('Args/overrides')
PatchHydra._should_delete_overrides = True
except Exception:
pass
@ -176,8 +199,7 @@ class PatchHydra(object):
else:
# noinspection PyProtectedMember
omega_yaml = PatchHydra._current_task._get_configuration_text(PatchHydra._config_section)
loaded_config = OmegaConf.load(io.StringIO(omega_yaml))
a_config = OmegaConf.merge(a_config, loaded_config)
a_config = OmegaConf.load(io.StringIO(omega_yaml))
PatchHydra._register_omegaconf(a_config, is_read_only=False)
return task_function(a_config, *a_args, **a_kwargs)
@ -194,10 +216,6 @@ class PatchHydra(object):
description = 'Full OmegaConf YAML configuration overridden! ({}/{}=True)'.format(
PatchHydra._parameter_section, PatchHydra._parameter_allow_full_edit)
# we should not have the hydra section in the config, but this seems never to be the case anymore.
# config = config.copy()
# config.pop('hydra', None)
configuration = dict(
name=PatchHydra._config_section,
description=description,
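As a hedged illustration of the override rewriting in `_patched_hydra_run` above (keys and values are made up): entries pulled from the backend are forced to `++` overrides so existing keys are replaced and new ones appended, while deletion overrides (`~key`) pass through unchanged.

# standalone sketch of the rewrite rule applied above
stored_config = {"trainer.lr": "0.01", "+trainer.extra": "1", "~trainer.dropout": None}

overrides = []
for k, v in stored_config.items():
    if k.startswith("~"):
        new_override = k                      # deletion override, keep as-is
    else:
        new_override = "++" + k.lstrip("+")   # force add-or-override semantics
    if v is not None and v != "":
        new_override += "=" + str(v)
    overrides.append(new_override)

print(overrides)  # ['++trainer.lr=0.01', '++trainer.extra=1', '~trainer.dropout']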

View File

@ -28,7 +28,9 @@ class PatchJsonArgParse(object):
_commands_sep = "."
_command_type = "jsonargparse.Command"
_command_name = "subcommand"
_special_fields = ["config", "subcommand"]
_section_name = "Args"
_allow_jsonargparse_overrides = "_allow_config_file_override_from_ui_"
__remote_task_params = {}
__remote_task_params_dict = {}
__patched = False
@ -60,6 +62,7 @@ class PatchJsonArgParse(object):
return
args = {}
args_type = {}
have_config_file = False
for k, v in cls._args.items():
key_with_section = cls._section_name + cls._args_sep + k
args[key_with_section] = v
@ -75,27 +78,18 @@ class PatchJsonArgParse(object):
elif isinstance(v, Path) or (isinstance(v, list) and all(isinstance(sub_v, Path) for sub_v in v)):
args[key_with_section] = json.dumps(PatchJsonArgParse._handle_path(v))
args_type[key_with_section] = PatchJsonArgParse.path_type
have_config_file = True
else:
args[key_with_section] = str(v)
except Exception:
pass
args, args_type = cls.__delete_config_args(parser, args, args_type, subcommand=subcommand)
cls._current_task._set_parameters(args, __update=True, __parameters_types=args_type)
@classmethod
def __delete_config_args(cls, parser, args, args_type, subcommand=None):
if not parser:
return args, args_type
paths = PatchJsonArgParse.__get_paths_from_dict(cls._args)
for path in paths:
args_to_delete = PatchJsonArgParse.__get_args_from_path(parser, path, subcommand=subcommand)
for arg_to_delete_key, arg_to_delete_value in args_to_delete.items():
key_with_section = cls._section_name + cls._args_sep + arg_to_delete_key
if key_with_section in args and args[key_with_section] == arg_to_delete_value:
del args[key_with_section]
if key_with_section in args_type:
del args_type[key_with_section]
return args, args_type
if have_config_file:
cls._current_task.set_parameter(
cls._section_name + cls._args_sep + cls._allow_jsonargparse_overrides,
False,
description="If True, values in the config file will be overriden by values found in the UI. Otherwise, the values in the config file have priority"
)
@staticmethod
def _adapt_typehints(original_fn, val, *args, **kwargs):
@ -103,6 +97,17 @@ class PatchJsonArgParse(object):
return original_fn(val, *args, **kwargs)
return original_fn(val, *args, **kwargs)
@staticmethod
def __restore_args(parser, args, subcommand=None):
paths = PatchJsonArgParse.__get_paths_from_dict(args)
for path in paths:
args_to_restore = PatchJsonArgParse.__get_args_from_path(parser, path, subcommand=subcommand)
for arg_to_restore_key, arg_to_restore_value in args_to_restore.items():
if arg_to_restore_key in PatchJsonArgParse._special_fields:
continue
args[arg_to_restore_key] = arg_to_restore_value
return args
@staticmethod
def _parse_args(original_fn, obj, *args, **kwargs):
if not PatchJsonArgParse._current_task:
@ -119,6 +124,13 @@ class PatchJsonArgParse(object):
params_namespace = Namespace()
for k, v in params.items():
params_namespace[k] = v
allow_jsonargparse_overrides_value = params.pop(PatchJsonArgParse._allow_jsonargparse_overrides, True)
if not allow_jsonargparse_overrides_value:
params_namespace = PatchJsonArgParse.__restore_args(
obj,
params_namespace,
subcommand=params_namespace.get(PatchJsonArgParse._command_name)
)
return params_namespace
except Exception as e:
logging.getLogger(__file__).warning("Failed parsing jsonargparse arguments: {}".format(e))
@ -210,7 +222,7 @@ class PatchJsonArgParse(object):
parsed_cfg = parser.parse_string(path.get_content(), _skip_check=True, _fail_no_subcommand=False)
if subcommand:
parsed_cfg = {
((subcommand + PatchJsonArgParse._commands_sep) if k not in ["config", "subcommand"] else "") + k: v
((subcommand + PatchJsonArgParse._commands_sep) if k not in PatchJsonArgParse._special_fields else "") + k: v
for k, v in parsed_cfg.items()
}
return parsed_cfg
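A hedged sketch of how the new `_allow_config_file_override_from_ui_` flag can be used. It is stored as a regular hyperparameter (assuming the usual `Args/<name>` layout), so it can be flipped before enqueuing a cloned task; the task ID is a placeholder:

from clearml import Task

task = Task.get_task(task_id="<cloned-task-id>")  # placeholder id
# False => values parsed from the linked config file take priority over UI edits;
# True (the previous behaviour) => UI values override the config file
task.set_parameter("Args/_allow_config_file_override_from_ui_", False)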

View File

@ -3215,17 +3215,17 @@ class Dataset(object):
pool.close()
def _verify_dataset_folder(self, target_base_folder, part, chunk_selection, max_workers):
# type: (Path, Optional[int], Optional[dict], Optional[int]) -> bool
# type: (Path, int, dict, int) -> bool
def verify_file_or_link(base_folder, ds_part, ds_chunk_selection, file_entry):
# type: (Path, Optional[int], Optional[dict], FileEntry) -> Optional[bool]
def __verify_file_or_link(target_base_folder, file_entry, part=None, chunk_selection=None):
# type: (Path, Union[FileEntry, LinkEntry], Optional[int], Optional[dict]) -> bool
# check if we need the file for the requested dataset part
if ds_part is not None:
f_parts = ds_chunk_selection.get(file_entry.parent_dataset_id, [])
# file is not in requested dataset part, no need to check it.
if self._get_chunk_idx_from_artifact_name(file_entry.artifact_name) not in f_parts:
return None
return True
# check if the local size and the stored size match (faster than comparing hash)
if (base_folder / file_entry.relative_path).stat().st_size != file_entry.size:
@ -3237,21 +3237,26 @@ class Dataset(object):
# check dataset file size, if we have a full match no need for parent dataset download / merge
verified = True
# noinspection PyBroadException
tp = None
try:
futures_ = []
with ThreadPoolExecutor(max_workers=max_workers) as tp:
tp = ThreadPoolExecutor(max_workers=max_workers)
for f in self._dataset_file_entries.values():
future = tp.submit(verify_file_or_link, target_base_folder, part, chunk_selection, f)
future = tp.submit(__verify_file_or_link, target_base_folder, f, part, chunk_selection)
futures_.append(future)
for f in self._dataset_link_entries.values():
# don't check whether link is in dataset part, hence None for part and chunk_selection
future = tp.submit(verify_file_or_link, target_base_folder, None, None, f)
future = tp.submit(__verify_file_or_link, target_base_folder, f, None, None)
futures_.append(future)
verified = all(f.result() is not False for f in futures_)
verified = all(f.result() for f in futures_)
except Exception:
verified = False
finally:
if tp is not None:
# we already have our result, close all pending checks (improves performance when verified==False)
tp.shutdown(cancel_futures=True)
return verified
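A minimal sketch of the early-shutdown pattern adopted above: the pool is shut down with cancel_futures once the verdict is known, so pending checks are dropped when verification has already failed. Note that `cancel_futures` requires Python 3.9+; the checker function and file list below are made up:

from concurrent.futures import ThreadPoolExecutor

def verify(path):
    return bool(path)  # placeholder check

paths = ["a.txt", "b.txt", ""]
pool = ThreadPoolExecutor(max_workers=4)
try:
    futures = [pool.submit(verify, p) for p in paths]
    verified = all(f.result() for f in futures)
finally:
    # stop scheduling the remaining checks once the answer is known
    pool.shutdown(cancel_futures=True)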

View File

@ -362,8 +362,8 @@ class BaseModel(object):
self._task_connect_name = None
self._set_task(task)
def get_weights(self, raise_on_error=False, force_download=False):
# type: (bool, bool) -> str
def get_weights(self, raise_on_error=False, force_download=False, extract_archive=False):
# type: (bool, bool, bool) -> str
"""
Download the base model and return the locally stored filename.
@ -373,17 +373,19 @@ class BaseModel(object):
:param bool force_download: If True, the base model will be downloaded,
even if the base model is already cached.
:param bool extract_archive: If True, the downloaded weights file will be extracted if possible
:return: The locally stored file.
"""
# download model (synchronously) and return local file
return self._get_base_model().download_model_weights(
raise_on_error=raise_on_error, force_download=force_download
raise_on_error=raise_on_error, force_download=force_download, extract_archive=extract_archive
)
def get_weights_package(
self, return_path=False, raise_on_error=False, force_download=False
self, return_path=False, raise_on_error=False, force_download=False, extract_archive=True
):
# type: (bool, bool, bool) -> Optional[Union[str, List[Path]]]
# type: (bool, bool, bool, bool) -> Optional[Union[str, List[Path]]]
"""
Download the base model package into a temporary directory (extract the files), or return a list of the
locally stored filenames.
@ -399,6 +401,8 @@ class BaseModel(object):
:param bool force_download: If True, the base artifact will be downloaded,
even if the artifact is already cached.
:param bool extract_archive: If True, the downloaded weights file will be extracted if possible
:return: The model weights, or a list of the locally stored filenames.
if raise_on_error=False, returns None on error.
"""
@ -407,40 +411,21 @@ class BaseModel(object):
raise ValueError("Model is not packaged")
# download packaged model
packed_file = self.get_weights(
raise_on_error=raise_on_error, force_download=force_download
model_path = self.get_weights(
raise_on_error=raise_on_error, force_download=force_download, extract_archive=extract_archive
)
if not packed_file:
if not model_path:
if raise_on_error:
raise ValueError(
"Model package '{}' could not be downloaded".format(self.url)
)
return None
# unpack
target_folder = mkdtemp(prefix="model_package_")
if not target_folder:
raise ValueError(
"cannot create temporary directory for packed weight files"
)
for func in (zipfile.ZipFile, tarfile.open):
try:
obj = func(packed_file)
obj.extractall(path=target_folder)
break
except (zipfile.BadZipfile, tarfile.ReadError):
pass
else:
raise ValueError(
"cannot extract files from packaged model at %s", packed_file
)
if return_path:
return target_folder
return model_path
target_files = list(Path(target_folder).glob("*"))
target_files = list(Path(model_path).glob("*"))
return target_files
def report_scalar(self, title, series, value, iteration):
@ -1374,17 +1359,18 @@ class Model(BaseModel):
self._base_model = None
def get_local_copy(
self, extract_archive=True, raise_on_error=False, force_download=False
self, extract_archive=None, raise_on_error=False, force_download=False
):
# type: (bool, bool, bool) -> str
# type: (Optional[bool], bool, bool) -> str
"""
Retrieve a valid link to the model file(s).
If the model URL is a file system link, it will be returned directly.
If the model URL points to a remote location (http/s3/gs etc.),
it will download the file(s) and return the temporary location of the downloaded model.
:param bool extract_archive: If True, and the model is of type 'packaged' (e.g. TensorFlow compressed folder)
The returned path will be a temporary folder containing the archive content
:param bool extract_archive: If True, the local copy will be extracted if possible. If False,
the local copy will not be extracted. If None (default), the downloaded file will be extracted
if the model is a package.
:param bool raise_on_error: If True, and the artifact could not be downloaded,
raise ValueError, otherwise return None on failure and output log warning.
:param bool force_download: If True, the artifact will be downloaded,
@ -1392,14 +1378,17 @@ class Model(BaseModel):
:return: A local path to the model (or a downloaded copy of it).
"""
if extract_archive and self._is_package():
if self._is_package():
return self.get_weights_package(
return_path=True,
raise_on_error=raise_on_error,
force_download=force_download,
extract_archive=True if extract_archive is None else extract_archive
)
return self.get_weights(
raise_on_error=raise_on_error, force_download=force_download
raise_on_error=raise_on_error,
force_download=force_download,
extract_archive=False if extract_archive is None else extract_archive
)
def _get_base_model(self):
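A hedged sketch of the resulting `extract_archive` semantics on the public model API (the model ID is a placeholder):

from clearml import InputModel

model = InputModel(model_id="<model-id>")  # placeholder id
# None (default): extract only when the model is a stored package (zip/tar)
path = model.get_local_copy(extract_archive=None)
# False: always return the downloaded file as-is, even for packages
raw_path = model.get_local_copy(extract_archive=False)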

View File

@ -739,6 +739,8 @@ class Task(_Task):
if argparser_parseargs_called():
for parser, parsed_args in get_argparser_last_args():
task._connect_argparse(parser=parser, parsed_args=parsed_args)
PatchHydra.delete_overrides()
elif argparser_parseargs_called():
# actually we have nothing to do, in remote running, the argparser will ignore
# all non argparser parameters, only caveat if parameter connected with the same name
@ -2604,15 +2606,15 @@ class Task(_Task):
self.reload()
script = self.data.script
if repository is not None:
script.repository = str(repository) or None
script.repository = str(repository)
if branch is not None:
script.branch = str(branch) or None
script.branch = str(branch)
if script.tag:
script.tag = None
if commit is not None:
script.version_num = str(commit) or None
script.version_num = str(commit)
if diff is not None:
script.diff = str(diff) or None
script.diff = str(diff)
if working_dir is not None:
script.working_dir = str(working_dir)
if entry_point is not None:

View File

@ -1,6 +1,7 @@
import itertools
import json
from copy import copy
from logging import getLogger
import six
import yaml
@ -132,15 +133,20 @@ def cast_basic_type(value, type_str):
parts = type_str.split('/')
# nested = len(parts) > 1
if parts[0] in ('list', 'tuple'):
if parts[0] in ("list", "tuple", "dict"):
# noinspection PyBroadException
try:
# lists/tuple/dicts should be json loadable
return basic_types.get(parts[0])(json.loads(value))
except Exception:
# noinspection PyBroadException
try:
# fallback to legacy basic type loading
v = '[' + value.lstrip('[(').rstrip('])') + ']'
v = yaml.load(v, Loader=yaml.SafeLoader)
return basic_types.get(parts[0])(v)
elif parts[0] in ('dict', ):
try:
return json.loads(value)
except Exception:
pass
getLogger().warning("Could not cast `{}` to basic type. Returning it as `str`".format(value))
return value
t = basic_types.get(str(type_str).lower().strip(), False)
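A hedged, standalone illustration of the JSON-first casting rule introduced above (inputs are made up; the real implementation also tries a legacy YAML fallback before giving up and warning):

import json

def cast_container(value, container_type):
    basic_types = {"list": list, "tuple": tuple, "dict": dict}
    try:
        # lists/tuples/dicts are expected to be JSON-loadable strings
        return basic_types[container_type](json.loads(value))
    except Exception:
        return value  # fall back to returning the raw string, as the warning path does

print(cast_container('[1, 2, 3]', "tuple"))   # (1, 2, 3)
print(cast_container('{"lr": 0.1}', "dict"))  # {'lr': 0.1}
print(cast_container('not-json', "list"))     # not-json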

View File

@ -69,8 +69,7 @@
"Input the following parameters:\n",
"* `name` - Name of the PipelineController task which will created\n",
"* `project` - Project which the controller will be associated with\n",
"* `version` - Pipeline's version number. This version allows to uniquely identify the pipeline template execution.\n",
"* `auto_version_bump` (default True) - if the same pipeline version already exists (with any difference from the current one), the current pipeline version will be bumped to a new version (e.g. 1.0.0 -> 1.0.1 , 1.2 -> 1.3, 10 -> 11)\n",
"* `version` - Pipeline's version number. This version allows to uniquely identify the pipeline template execution. If not set, find the pipeline's latest version and increment it. If no such version is found, defaults to `1.0.0`.\n",
" "
]
},