revital 2023-10-25 10:34:38 +03:00
commit 52c47b5551
13 changed files with 271 additions and 160 deletions

View File

@@ -101,6 +101,7 @@ class PipelineController(object):
        explicit_docker_image = attrib(type=str, default=None)  # The Docker image the node uses, specified at creation
        recursively_parse_parameters = attrib(type=bool, default=False)  # if True, recursively parse parameters in
                                                                         # lists, dicts, or tuples
+       output_uri = attrib(type=Union[bool, str], default=None)  # The default location for output models and other artifacts

        def __attrs_post_init__(self):
            if self.parents is None:
@@ -134,6 +135,26 @@ class PipelineController(object):
            new_copy.task_factory_func = self.task_factory_func
            return new_copy

+       def set_job_ended(self):
+           if self.job_ended:
+               return
+           # noinspection PyBroadException
+           try:
+               self.job.task.reload()
+               self.job_ended = self.job_started + self.job.task.data.active_duration
+           except Exception as e:
+               pass
+
+       def set_job_started(self):
+           if self.job_started:
+               return
+           # noinspection PyBroadException
+           try:
+               self.job_started = self.job.task.data.started.timestamp()
+           except Exception:
+               pass
+
    def __init__(
            self,
            name,  # type: str
@@ -155,7 +176,8 @@ class PipelineController(object):
            repo_commit=None,  # type: Optional[str]
            always_create_from_code=True,  # type: bool
            artifact_serialization_function=None,  # type: Optional[Callable[[Any], Union[bytes, bytearray]]]
-           artifact_deserialization_function=None  # type: Optional[Callable[[bytes], Any]]
+           artifact_deserialization_function=None,  # type: Optional[Callable[[bytes], Any]]
+           output_uri=None  # type: Optional[Union[str, bool]]
    ):
        # type: (...) -> None
        """
@@ -242,6 +264,9 @@ class PipelineController(object):
                def deserialize(bytes_):
                    import dill
                    return dill.loads(bytes_)
+       :param output_uri: The storage / output url for this pipeline. This is the default location for output
+           models and other artifacts. Check Task.init reference docs for more info (output_uri is a parameter).
+           The `output_uri` of this pipeline's steps will default to this value.
        """
        if auto_version_bump is not None:
            warnings.warn("PipelineController.auto_version_bump is deprecated. It will be ignored", DeprecationWarning)
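
Usage sketch (illustrative, not part of the diff; the bucket and step names are assumptions): the new output_uri argument sets the default artifact and model destination for the whole pipeline, and individual steps can still override it.

    from clearml import PipelineController

    pipe = PipelineController(
        name="my pipeline",
        project="examples",
        version="1.0.0",
        output_uri="s3://my-bucket/pipelines",   # default destination for the controller and all steps
    )
    pipe.add_step(
        name="train",
        base_task_project="examples",
        base_task_name="train model",
        output_uri="s3://my-bucket/train-step",  # per-step override; leaving it None falls back to the pipeline default
    )
    pipe.start_locally(run_pipeline_steps_locally=True)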
@@ -316,6 +341,9 @@ class PipelineController(object):
                project_id=self._task.project, system_tags=self._project_system_tags)

        self._task.set_system_tags((self._task.get_system_tags() or []) + [self._tag])
+       if output_uri is not None:
+           self._task.output_uri = output_uri
+       self._output_uri = output_uri
        self._task.set_base_docker(
            docker_image=docker, docker_arguments=docker_args, docker_setup_bash_script=docker_bash_setup_script
        )
@@ -387,7 +415,8 @@ class PipelineController(object):
            base_task_factory=None,  # type: Optional[Callable[[PipelineController.Node], Task]]
            retry_on_failure=None,  # type: Optional[Union[int, Callable[[PipelineController, PipelineController.Node, int], bool]]]   # noqa
            status_change_callback=None,  # type: Optional[Callable[[PipelineController, PipelineController.Node, str], None]]  # noqa
-           recursively_parse_parameters=False  # type: bool
+           recursively_parse_parameters=False,  # type: bool
+           output_uri=None  # type: Optional[Union[str, bool]]
    ):
        # type: (...) -> bool
        """
@@ -529,7 +558,9 @@ class PipelineController(object):
                    previous_status  # type: str
                ):
                    pass
+       :param output_uri: The storage / output url for this step. This is the default location for output
+           models and other artifacts. Check Task.init reference docs for more info (output_uri is a parameter).

        :return: True if successful
        """
@@ -588,6 +619,7 @@ class PipelineController(object):
            monitor_metrics=monitor_metrics or [],
            monitor_artifacts=monitor_artifacts or [],
            monitor_models=monitor_models or [],
+           output_uri=self._output_uri if output_uri is None else output_uri
        )
        self._retries[name] = 0
        self._retries_callbacks[name] = retry_on_failure if callable(retry_on_failure) else \
@@ -632,7 +664,8 @@ class PipelineController(object):
            cache_executed_step=False,  # type: bool
            retry_on_failure=None,  # type: Optional[Union[int, Callable[[PipelineController, PipelineController.Node, int], bool]]]   # noqa
            status_change_callback=None,  # type: Optional[Callable[[PipelineController, PipelineController.Node, str], None]]  # noqa
-           tags=None  # type: Optional[Union[str, Sequence[str]]]
+           tags=None,  # type: Optional[Union[str, Sequence[str]]]
+           output_uri=None  # type: Optional[Union[str, bool]]
    ):
        # type: (...) -> bool
        """
@@ -799,6 +832,8 @@ class PipelineController(object):
        :param tags: A list of tags for the specific pipeline step.
            When executing a Pipeline remotely
            (i.e. launching the pipeline from the UI/enqueuing it), this method has no effect.
+       :param output_uri: The storage / output url for this step. This is the default location for output
+           models and other artifacts. Check Task.init reference docs for more info (output_uri is a parameter).

        :return: True if successful
        """
@@ -838,7 +873,8 @@ class PipelineController(object):
            cache_executed_step=cache_executed_step,
            retry_on_failure=retry_on_failure,
            status_change_callback=status_change_callback,
-           tags=tags
+           tags=tags,
+           output_uri=output_uri
        )

    def start(
@@ -1014,8 +1050,8 @@ class PipelineController(object):
        return cls._get_pipeline_task().get_logger()

    @classmethod
-   def upload_model(cls, model_name, model_local_path):
-       # type: (str, str) -> OutputModel
+   def upload_model(cls, model_name, model_local_path, upload_uri=None):
+       # type: (str, str, Optional[str]) -> OutputModel
        """
        Upload (add) a model to the main Pipeline Task object.
        This function can be called from any pipeline component to directly add models into the main pipeline Task
@@ -1028,12 +1064,16 @@ class PipelineController(object):
        :param model_local_path: Path to the local model file or directory to be uploaded.
            If a local directory is provided the content of the folder (recursively) will be
            packaged into a zip file and uploaded
+       :param upload_uri: The URI of the storage destination for model weights upload. The default value
+           is the previously used URI.
+
+       :return: The uploaded OutputModel
        """
        task = cls._get_pipeline_task()
        model_name = str(model_name)
        model_local_path = Path(model_local_path)
        out_model = OutputModel(task=task, name=model_name)
-       out_model.update_weights(weights_filename=model_local_path.as_posix())
+       out_model.update_weights(weights_filename=model_local_path.as_posix(), upload_uri=upload_uri)
        return out_model

    @classmethod
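
Usage sketch (illustrative, not part of the diff; the function, model name, and bucket are assumptions): a pipeline component can now control where upload_model() pushes the weights instead of reusing the previously used URI.

    from clearml import PipelineController

    def train_step():
        # ... train and save weights to "model.pt" ...
        PipelineController.upload_model(
            model_name="resnet-finetuned",
            model_local_path="model.pt",
            upload_uri="s3://my-bucket/pipeline-models",  # new optional argument added by this commit
        )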
@@ -1457,7 +1497,7 @@ class PipelineController(object):
            self, docker, docker_args, docker_bash_setup_script,
            function, function_input_artifacts, function_kwargs, function_return,
            auto_connect_frameworks, auto_connect_arg_parser,
-           packages, project_name, task_name, task_type, repo, branch, commit, helper_functions
+           packages, project_name, task_name, task_type, repo, branch, commit, helper_functions, output_uri=None
    ):
        task_definition = CreateFromFunction.create_task_from_function(
            a_function=function,
@@ -1476,7 +1516,7 @@ class PipelineController(object):
            docker=docker,
            docker_args=docker_args,
            docker_bash_setup_script=docker_bash_setup_script,
-           output_uri=None,
+           output_uri=output_uri,
            helper_functions=helper_functions,
            dry_run=True,
            task_template_header=self._task_template_header,
@@ -1591,7 +1631,7 @@ class PipelineController(object):
            'target_project': self._target_project,
        }
        pipeline_dag = self._serialize()
        # serialize pipeline state
        if self._task and self._auto_connect_task:
            # check if we are either running locally or that we are running remotely,
@@ -1631,6 +1671,7 @@ class PipelineController(object):
                self._runtime_property_hash: "{}:{}".format(pipeline_hash, self._version),
                "version": self._version
            })
+           self._task.set_user_properties(version=self._version)
        else:
            self._task.connect_configuration(pipeline_dag, name=self._config_section)
            connected_args = set()
@@ -1927,7 +1968,8 @@ class PipelineController(object):
            cache_executed_step=False,  # type: bool
            retry_on_failure=None,  # type: Optional[Union[int, Callable[[PipelineController, PipelineController.Node, int], bool]]]   # noqa
            status_change_callback=None,  # type: Optional[Callable[[PipelineController, PipelineController.Node, str], None]]  # noqa
-           tags=None  # type: Optional[Union[str, Sequence[str]]]
+           tags=None,  # type: Optional[Union[str, Sequence[str]]]
+           output_uri=None  # type: Optional[Union[str, bool]]
    ):
        # type: (...) -> bool
        """
@@ -2094,6 +2136,8 @@ class PipelineController(object):
        :param tags: A list of tags for the specific pipeline step.
            When executing a Pipeline remotely
            (i.e. launching the pipeline from the UI/enqueuing it), this method has no effect.
+       :param output_uri: The storage / output url for this step. This is the default location for output
+           models and other artifacts. Check Task.init reference docs for more info (output_uri is a parameter).

        :return: True if successful
        """
@@ -2107,6 +2151,9 @@ class PipelineController(object):
        self._verify_node_name(name)

+       if output_uri is None:
+           output_uri = self._output_uri
+
        function_input_artifacts = {}
        # go over function_kwargs, split it into string and input artifacts
        for k, v in function_kwargs.items():
@@ -2145,7 +2192,7 @@ class PipelineController(object):
                function_input_artifacts, function_kwargs, function_return,
                auto_connect_frameworks, auto_connect_arg_parser,
                packages, project_name, task_name,
-               task_type, repo, repo_branch, repo_commit, helper_functions)
+               task_type, repo, repo_branch, repo_commit, helper_functions, output_uri=output_uri)

        elif self._task.running_locally() or self._task.get_configuration_object(name=name) is None:
            project_name = project_name or self._get_target_project() or self._task.get_project_name()
@@ -2155,7 +2202,7 @@ class PipelineController(object):
                function_input_artifacts, function_kwargs, function_return,
                auto_connect_frameworks, auto_connect_arg_parser,
                packages, project_name, task_name,
-               task_type, repo, repo_branch, repo_commit, helper_functions)
+               task_type, repo, repo_branch, repo_commit, helper_functions, output_uri=output_uri)
            # update configuration with the task definitions
            # noinspection PyProtectedMember
            self._task._set_configuration(
@@ -2180,6 +2227,9 @@ class PipelineController(object):
            if tags:
                a_task.add_tags(tags)

+           if output_uri is not None:
+               a_task.output_uri = output_uri
+
            return a_task

        self._nodes[name] = self.Node(
@@ -2195,7 +2245,8 @@ class PipelineController(object):
            monitor_metrics=monitor_metrics,
            monitor_models=monitor_models,
            job_code_section=job_code_section,
-           explicit_docker_image=docker
+           explicit_docker_image=docker,
+           output_uri=output_uri
        )
        self._retries[name] = 0
        self._retries_callbacks[name] = retry_on_failure if callable(retry_on_failure) else \
@@ -2284,13 +2335,14 @@ class PipelineController(object):
                disable_clone_task=disable_clone_task,
                task_overrides=task_overrides,
                allow_caching=node.cache_executed_step,
+               output_uri=node.output_uri,
                **extra_args
            )
        except Exception:
            self._pipeline_task_status_failed = True
            raise

-       node.job_started = time()
+       node.job_started = None
        node.job_ended = None
        node.job_type = str(node.job.task.task_type)
@@ -2546,6 +2598,8 @@ class PipelineController(object):
        """
        previous_status = node.status

+       if node.job and node.job.is_running():
+           node.set_job_started()
        update_job_ended = node.job_started and not node.job_ended

        if node.executed is not None:
@@ -2582,7 +2636,7 @@ class PipelineController(object):
                node.status = "pending"

        if update_job_ended and node.status in ("aborted", "failed", "completed"):
-           node.job_ended = time()
+           node.set_job_ended()

        if (
            previous_status is not None
if node_failed and self._abort_running_steps_on_failure and not node.continue_on_fail: if node_failed and self._abort_running_steps_on_failure and not node.continue_on_fail:
nodes_failed_stop_pipeline.append(node.name) nodes_failed_stop_pipeline.append(node.name)
elif node.timeout: elif node.timeout:
started = node.job.task.data.started node.set_job_started()
if (datetime.now().astimezone(started.tzinfo) - started).total_seconds() > node.timeout: if (datetime.now().astimezone(started.tzinfo) - started).total_seconds() > node.timeout:
node.job.abort() node.job.abort()
completed_jobs.append(j) completed_jobs.append(j)
@@ -3261,7 +3315,8 @@ class PipelineDecorator(PipelineController):
            repo_branch=None,  # type: Optional[str]
            repo_commit=None,  # type: Optional[str]
            artifact_serialization_function=None,  # type: Optional[Callable[[Any], Union[bytes, bytearray]]]
-           artifact_deserialization_function=None  # type: Optional[Callable[[bytes], Any]]
+           artifact_deserialization_function=None,  # type: Optional[Callable[[bytes], Any]]
+           output_uri=None  # type: Optional[Union[str, bool]]
    ):
        # type: (...) -> ()
        """
@@ -3341,6 +3396,9 @@ class PipelineDecorator(PipelineController):
                def deserialize(bytes_):
                    import dill
                    return dill.loads(bytes_)
+       :param output_uri: The storage / output url for this pipeline. This is the default location for output
+           models and other artifacts. Check Task.init reference docs for more info (output_uri is a parameter).
+           The `output_uri` of this pipeline's steps will default to this value.
        """
        super(PipelineDecorator, self).__init__(
            name=name,
@@ -3361,7 +3419,8 @@ class PipelineDecorator(PipelineController):
            repo_commit=repo_commit,
            always_create_from_code=False,
            artifact_serialization_function=artifact_serialization_function,
-           artifact_deserialization_function=artifact_deserialization_function
+           artifact_deserialization_function=artifact_deserialization_function,
+           output_uri=output_uri
        )

        # if we are in eager execution, make sure parent class knows it
@@ -3583,7 +3642,7 @@ class PipelineDecorator(PipelineController):
            function, function_input_artifacts, function_kwargs, function_return,
            auto_connect_frameworks, auto_connect_arg_parser,
            packages, project_name, task_name, task_type, repo, branch, commit,
-           helper_functions
+           helper_functions, output_uri=None
    ):
        def sanitize(function_source):
            matched = re.match(r"[\s]*@[\w]*.component[\s\\]*\(", function_source)
@@ -3621,7 +3680,7 @@ class PipelineDecorator(PipelineController):
            docker=docker,
            docker_args=docker_args,
            docker_bash_setup_script=docker_bash_setup_script,
-           output_uri=None,
+           output_uri=output_uri,
            helper_functions=helper_functions,
            dry_run=True,
            task_template_header=self._task_template_header,
@@ -3703,7 +3762,8 @@ class PipelineDecorator(PipelineController):
            pre_execute_callback=None,  # type: Optional[Callable[[PipelineController, PipelineController.Node, dict], bool]]  # noqa
            post_execute_callback=None,  # type: Optional[Callable[[PipelineController, PipelineController.Node], None]]  # noqa
            status_change_callback=None,  # type: Optional[Callable[[PipelineController, PipelineController.Node, str], None]]  # noqa
-           tags=None  # type: Optional[Union[str, Sequence[str]]]
+           tags=None,  # type: Optional[Union[str, Sequence[str]]]
+           output_uri=None  # type: Optional[Union[str, bool]]
    ):
        # type: (...) -> Callable
        """
@@ -3841,6 +3901,8 @@ class PipelineDecorator(PipelineController):
        :param tags: A list of tags for the specific pipeline step.
            When executing a Pipeline remotely
            (i.e. launching the pipeline from the UI/enqueuing it), this method has no effect.
+       :param output_uri: The storage / output url for this step. This is the default location for output
+           models and other artifacts. Check Task.init reference docs for more info (output_uri is a parameter).

        :return: function wrapper
        """
@@ -3883,7 +3945,8 @@ class PipelineDecorator(PipelineController):
            pre_execute_callback=pre_execute_callback,
            post_execute_callback=post_execute_callback,
            status_change_callback=status_change_callback,
-           tags=tags
+           tags=tags,
+           output_uri=output_uri
        )

        if cls._singleton:
@@ -4109,7 +4172,8 @@ class PipelineDecorator(PipelineController):
            repo_branch=None,  # type: Optional[str]
            repo_commit=None,  # type: Optional[str]
            artifact_serialization_function=None,  # type: Optional[Callable[[Any], Union[bytes, bytearray]]]
-           artifact_deserialization_function=None  # type: Optional[Callable[[bytes], Any]]
+           artifact_deserialization_function=None,  # type: Optional[Callable[[bytes], Any]]
+           output_uri=None  # type: Optional[Union[str, bool]]
    ):
        # type: (...) -> Callable
        """
@@ -4220,6 +4284,9 @@ class PipelineDecorator(PipelineController):
                def deserialize(bytes_):
                    import dill
                    return dill.loads(bytes_)
+       :param output_uri: The storage / output url for this pipeline. This is the default location for output
+           models and other artifacts. Check Task.init reference docs for more info (output_uri is a parameter).
+           The `output_uri` of this pipeline's steps will default to this value.
        """
        def decorator_wrap(func):
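
Usage sketch (illustrative, not part of the diff; project, step names and bucket are assumptions): the same output_uri default and per-step override expressed with the decorator API.

    from clearml import PipelineDecorator

    @PipelineDecorator.component(return_values=["model_path"], output_uri="s3://my-bucket/steps")
    def train(dataset_id):
        # ... train a model and return a path / artifact ...
        return "model.pt"

    @PipelineDecorator.pipeline(name="demo pipeline", project="examples", output_uri="s3://my-bucket/pipelines")
    def run(dataset_id="1234"):
        model_path = train(dataset_id)

    if __name__ == "__main__":
        PipelineDecorator.run_locally()
        run()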
@@ -4265,7 +4332,8 @@ class PipelineDecorator(PipelineController):
                repo_branch=repo_branch,
                repo_commit=repo_commit,
                artifact_serialization_function=artifact_serialization_function,
-               artifact_deserialization_function=artifact_deserialization_function
+               artifact_deserialization_function=artifact_deserialization_function,
+               output_uri=output_uri
            )
            ret_val = func(**pipeline_kwargs)
            LazyEvalWrapper.trigger_all_remote_references()
@@ -4316,7 +4384,8 @@ class PipelineDecorator(PipelineController):
                repo_branch=repo_branch,
                repo_commit=repo_commit,
                artifact_serialization_function=artifact_serialization_function,
-               artifact_deserialization_function=artifact_deserialization_function
+               artifact_deserialization_function=artifact_deserialization_function,
+               output_uri=output_uri
            )
            a_pipeline._args_map = args_map or {}

View File

@@ -522,6 +522,7 @@ class ClearmlJob(BaseJob):
            disable_clone_task=False,  # type: bool
            allow_caching=False,  # type: bool
            target_project=None,  # type: Optional[str]
+           output_uri=None,  # type: Optional[Union[str, bool]]
            **kwargs  # type: Any
    ):
        # type: (...) -> ()
@@ -545,6 +546,8 @@ class ClearmlJob(BaseJob):
            If True, use the base_task_id directly (base-task must be in draft-mode / created),
        :param bool allow_caching: If True, check if we have a previously executed Task with the same specification.
            If we do, use it and set internal is_cached flag. Default False (always create new Task).
+       :param Union[str, bool] output_uri: The storage / output url for this job. This is the default location for
+           output models and other artifacts. Check Task.init reference docs for more info (output_uri is a parameter).
        :param str target_project: Optional, Set the target project name to create the cloned Task in.
        """
        super(ClearmlJob, self).__init__()
@@ -660,6 +663,8 @@ class ClearmlJob(BaseJob):
            # noinspection PyProtectedMember
            self.task._edit(**sections)

+       if output_uri is not None:
+           self.task.output_uri = output_uri
        self._set_task_cache_hash(self.task, task_hash)
        self.task_started = False
        self._worker = None
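
Usage sketch (illustrative, not part of the diff; the task id and queue name are placeholders): ClearmlJob is the helper the pipeline uses to clone and enqueue a step, and with this commit the clone's default artifact destination can be set directly.

    from clearml.automation.job import ClearmlJob

    job = ClearmlJob(
        base_task_id="<template-task-id>",
        parameter_override={"General/epochs": 1},
        output_uri="s3://my-bucket/jobs",   # new: applied to the cloned task before it is launched
    )
    job.launch(queue_name="default")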

View File

@@ -78,13 +78,15 @@ class InterfaceBase(SessionInterface):
            except MaxRequestSizeError as e:
                res = CallResult(meta=ResponseMeta.from_raw_data(status_code=400, text=str(e)))
                error_msg = 'Failed sending: %s' % str(e)
-           except requests.exceptions.ConnectionError:
+           except requests.exceptions.ConnectionError as e:
                # We couldn't send the request for more than the retries times configure in the api configuration file,
                # so we will end the loop and raise the exception to the upper level.
                # Notice: this is a connectivity error and not a backend error.
-               if raise_on_errors:
-                   raise
+               # if raise_on_errors:
+               #     raise
                res = None
+               if log and num_retries >= cls._num_retry_warning_display:
+                   log.warning('Retrying, previous request failed %s: %s' % (str(type(req)), str(e)))
            except cls._JSON_EXCEPTION as e:
                if log:
                    log.error(

View File

@@ -527,7 +527,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
                return False
        return bool(self.data.ready)

-   def download_model_weights(self, raise_on_error=False, force_download=False):
+   def download_model_weights(self, raise_on_error=False, force_download=False, extract_archive=False):
        """
        Download the model weights into a local file in our cache
@@ -537,6 +537,8 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
        :param bool force_download: If True, the base artifact will be downloaded,
            even if the artifact is already cached.

+       :param bool extract_archive: If True, unzip the downloaded file if possible
+
        :return: a local path to a downloaded copy of the model
        """
        uri = self.data.uri
@@ -556,7 +558,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
            Model._local_model_to_id_uri.pop(dl_file, None)

        local_download = StorageManager.get_local_copy(
-           uri, extract_archive=False, force_download=force_download
+           uri, extract_archive=extract_archive, force_download=force_download
        )

        # save local model, so we can later query what was the original one

View File

@@ -178,7 +178,7 @@ class CreateAndPopulate(object):
                project=Task.get_project_id(self.project_name),
                type=str(self.task_type or Task.TaskTypes.training),
            )  # type: dict
-           if self.output_uri:
+           if self.output_uri is not None:
                task_state['output'] = dict(destination=self.output_uri)
        else:
            task_state = dict(script={})
@@ -391,7 +391,7 @@ class CreateAndPopulate(object):
        return task

    def _set_output_uri(self, task):
-       if self.output_uri:
+       if self.output_uri is not None:
            try:
                task.output_uri = self.output_uri
            except ValueError:

View File

@@ -1344,12 +1344,14 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
        params = self.get_parameters(cast=cast)
        return params.get(name, default)

-   def delete_parameter(self, name):
+   def delete_parameter(self, name, force=False):
        # type: (str) -> bool
        """
        Delete a parameter by its full name Section/name.

        :param name: Parameter name in full, i.e. Section/name. For example, 'Args/batch_size'
+       :param force: If set to True then both new and running task hyper params can be deleted.
+           Otherwise only the new task ones. Default is False
        :return: True if the parameter was deleted successfully
        """
        if not Session.check_min_api_version('2.9'):
@@ -1360,7 +1362,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
        with self._edit_lock:
            paramkey = tasks.ParamKey(section=name.split('/', 1)[0], name=name.split('/', 1)[1])
            res = self.send(tasks.DeleteHyperParamsRequest(
-               task=self.id, hyperparams=[paramkey]), raise_on_errors=False)
+               task=self.id, hyperparams=[paramkey], force=force), raise_on_errors=False)
        self.reload()
        return res.ok()
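
Usage sketch (illustrative, not part of the diff; the task id is a placeholder): with force=True a hyperparameter can now be deleted from a task that is already running, not only from a draft task.

    from clearml import Task

    task = Task.get_task(task_id="<task-id>")
    task.delete_parameter("Args/overrides", force=True)  # returns True on success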

View File

@@ -15,6 +15,8 @@ class PatchHydra(object):
    _config_section = 'OmegaConf'
    _parameter_section = 'Hydra'
    _parameter_allow_full_edit = '_allow_omegaconf_edit_'
+   _should_delete_overrides = False
+   _overrides_section = "Args/overrides"

    @classmethod
    def patch_hydra(cls):
@@ -42,6 +44,12 @@ class PatchHydra(object):
        except Exception:
            return False

+   @classmethod
+   def delete_overrides(cls):
+       if not cls._should_delete_overrides or not cls._current_task:
+           return
+       cls._current_task.delete_parameter(cls._overrides_section, force=True)
+
    @staticmethod
    def update_current_task(task):
        # set current Task before patching
@@ -50,11 +58,24 @@ class PatchHydra(object):
            return
        if PatchHydra.patch_hydra():
            # check if we have an untracked state, store it.
-           if PatchHydra._last_untracked_state.get('connect'):
-               PatchHydra._current_task.connect(**PatchHydra._last_untracked_state['connect'])
-           if PatchHydra._last_untracked_state.get('_set_configuration'):
+           if PatchHydra._last_untracked_state.get("connect"):
+               if PatchHydra._parameter_allow_full_edit in PatchHydra._last_untracked_state["connect"].get("mutable", {}):
+                   allow_omegaconf_edit_section = PatchHydra._parameter_section + "/" + PatchHydra._parameter_allow_full_edit
+                   allow_omegaconf_edit_section_val = PatchHydra._last_untracked_state["connect"]["mutable"].pop(
+                       PatchHydra._parameter_allow_full_edit
+                   )
+                   PatchHydra._current_task.set_parameter(
+                       allow_omegaconf_edit_section,
+                       allow_omegaconf_edit_section_val,
+                       description="If True, the `{}` parameter section will be completely ignored. The OmegaConf will instead be pulled from the `{}` section".format(
+                           PatchHydra._parameter_section,
+                           PatchHydra._config_section
+                       )
+                   )
+               PatchHydra._current_task.connect(**PatchHydra._last_untracked_state["connect"])
+           if PatchHydra._last_untracked_state.get("_set_configuration"):
                # noinspection PyProtectedMember
-               PatchHydra._current_task._set_configuration(**PatchHydra._last_untracked_state['_set_configuration'])
+               PatchHydra._current_task._set_configuration(**PatchHydra._last_untracked_state["_set_configuration"])
            PatchHydra._last_untracked_state = {}
        else:
            # if patching failed set it to None
@@ -63,36 +84,34 @@ class PatchHydra(object):
    @staticmethod
    def _patched_hydra_run(self, config_name, task_function, overrides, *args, **kwargs):
        PatchHydra._allow_omegaconf_edit = False
+       if not running_remotely():
+           return PatchHydra._original_hydra_run(self, config_name, task_function, overrides, *args, **kwargs)

-       # store the config
+       # get the parameters from the backend
        # noinspection PyBroadException
        try:
-           if running_remotely():
-               if not PatchHydra._current_task:
-                   from ..task import Task
-                   PatchHydra._current_task = Task.get_task(task_id=get_remote_task_id())
-               # get the _parameter_allow_full_edit casted back to boolean
-               connected_config = dict()
-               connected_config[PatchHydra._parameter_allow_full_edit] = False
-               PatchHydra._current_task.connect(connected_config, name=PatchHydra._parameter_section)
-               PatchHydra._allow_omegaconf_edit = connected_config.pop(PatchHydra._parameter_allow_full_edit, None)
-               # get all the overrides
-               full_parameters = PatchHydra._current_task.get_parameters(backwards_compatibility=False)
-               stored_config = {k[len(PatchHydra._parameter_section)+1:]: v for k, v in full_parameters.items()
-                                if k.startswith(PatchHydra._parameter_section+'/')}
-               stored_config.pop(PatchHydra._parameter_allow_full_edit, None)
-               # noinspection PyBroadException
-               try:
-                   overrides = yaml.safe_load(full_parameters.get("Args/overrides", "")) or []
-               except Exception:
-                   overrides = []
-               if overrides and not isinstance(overrides, (list, tuple)):
-                   overrides = [overrides]
-               overrides += ['{}={}'.format(k, v) for k, v in stored_config.items()]
-               overrides = [("+" + o) if (o.startswith("+") and not o.startswith("++")) else o for o in overrides]
-           else:
-               # We take care of it inside the _patched_run_job
-               pass
+           if not PatchHydra._current_task:
+               from ..task import Task
+               PatchHydra._current_task = Task.get_task(task_id=get_remote_task_id())
+           # get the _parameter_allow_full_edit casted back to boolean
+           connected_config = {}
+           connected_config[PatchHydra._parameter_allow_full_edit] = False
+           PatchHydra._current_task.connect(connected_config, name=PatchHydra._parameter_section)
+           PatchHydra._allow_omegaconf_edit = connected_config.pop(PatchHydra._parameter_allow_full_edit, None)
+           # get all the overrides
+           full_parameters = PatchHydra._current_task.get_parameters(backwards_compatibility=False)
+           stored_config = {k[len(PatchHydra._parameter_section)+1:]: v for k, v in full_parameters.items()
+                            if k.startswith(PatchHydra._parameter_section+'/')}
+           stored_config.pop(PatchHydra._parameter_allow_full_edit, None)
+           for override_k, override_v in stored_config.items():
+               if override_k.startswith("~"):
+                   new_override = override_k
+               else:
+                   new_override = "++" + override_k.lstrip("+")
+               if override_v is not None and override_v != "":
+                   new_override += "=" + override_v
+               overrides.append(new_override)
+           PatchHydra._should_delete_overrides = True
        except Exception:
            pass
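
Illustration (not part of the diff; the helper name is hypothetical): the override translation added above, shown in isolation. A stored UI parameter such as Hydra/num_layers=12 is re-applied as the hydra command-line override ++num_layers=12, while deletion-style keys that start with "~" are passed through unchanged.

    def to_hydra_override(key, value):
        # mirrors the loop in _patched_hydra_run under the assumption of the same semantics
        if key.startswith("~"):
            new_override = key
        else:
            new_override = "++" + key.lstrip("+")
        if value is not None and value != "":
            new_override += "=" + str(value)
        return new_override

    assert to_hydra_override("num_layers", "12") == "++num_layers=12"
    assert to_hydra_override("~augment", "") == "~augment"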
@@ -114,12 +133,18 @@ class PatchHydra(object):
        # store the config
        # noinspection PyBroadException
        try:
-           if running_remotely():
-               # we take care of it in the hydra run (where we have access to the overrides)
-               pass
-           else:
+           if not running_remotely():
+               # note that we fetch the overrides from the backend in hydra run when running remotely,
+               # here we just get them from hydra to be stored as configuration/parameters
                overrides = config.hydra.overrides.task
-               stored_config = dict(arg.split('=', 1) for arg in overrides)
+               stored_config = {}
+               for arg in overrides:
+                   arg = arg.lstrip("+")
+                   if "=" in arg:
+                       k, v = arg.split("=", 1)
+                       stored_config[k] = v
+                   else:
+                       stored_config[arg] = None
                stored_config[PatchHydra._parameter_allow_full_edit] = False
                if PatchHydra._current_task:
                    PatchHydra._current_task.connect(stored_config, name=PatchHydra._parameter_section)
@@ -127,9 +152,7 @@ class PatchHydra(object):
                else:
                    PatchHydra._last_untracked_state['connect'] = dict(
                        mutable=stored_config, name=PatchHydra._parameter_section)
-               # Maybe ?! remove the overrides section from the Args (we have it here)
-               # But when used with a Pipeline this is the only section we get... so we leave it here anyhow
-               # PatchHydra._current_task.delete_parameter('Args/overrides')
+               PatchHydra._should_delete_overrides = True
        except Exception:
            pass
@@ -176,8 +199,7 @@ class PatchHydra(object):
            else:
                # noinspection PyProtectedMember
                omega_yaml = PatchHydra._current_task._get_configuration_text(PatchHydra._config_section)
-               loaded_config = OmegaConf.load(io.StringIO(omega_yaml))
-               a_config = OmegaConf.merge(a_config, loaded_config)
+               a_config = OmegaConf.load(io.StringIO(omega_yaml))
                PatchHydra._register_omegaconf(a_config, is_read_only=False)
            return task_function(a_config, *a_args, **a_kwargs)
@@ -194,10 +216,6 @@ class PatchHydra(object):
        description = 'Full OmegaConf YAML configuration overridden! ({}/{}=True)'.format(
            PatchHydra._parameter_section, PatchHydra._parameter_allow_full_edit)

-       # we should not have the hydra section in the config, but this seems never to be the case anymore.
-       # config = config.copy()
-       # config.pop('hydra', None)
        configuration = dict(
            name=PatchHydra._config_section,
            description=description,

View File

@@ -28,7 +28,9 @@ class PatchJsonArgParse(object):
    _commands_sep = "."
    _command_type = "jsonargparse.Command"
    _command_name = "subcommand"
+   _special_fields = ["config", "subcommand"]
    _section_name = "Args"
+   _allow_jsonargparse_overrides = "_allow_config_file_override_from_ui_"
    __remote_task_params = {}
    __remote_task_params_dict = {}
    __patched = False
@@ -60,6 +62,7 @@ class PatchJsonArgParse(object):
            return
        args = {}
        args_type = {}
+       have_config_file = False
        for k, v in cls._args.items():
            key_with_section = cls._section_name + cls._args_sep + k
            args[key_with_section] = v
@@ -75,27 +78,18 @@ class PatchJsonArgParse(object):
                elif isinstance(v, Path) or (isinstance(v, list) and all(isinstance(sub_v, Path) for sub_v in v)):
                    args[key_with_section] = json.dumps(PatchJsonArgParse._handle_path(v))
                    args_type[key_with_section] = PatchJsonArgParse.path_type
+                   have_config_file = True
                else:
                    args[key_with_section] = str(v)
            except Exception:
                pass

-       args, args_type = cls.__delete_config_args(parser, args, args_type, subcommand=subcommand)
        cls._current_task._set_parameters(args, __update=True, __parameters_types=args_type)
-
-   @classmethod
-   def __delete_config_args(cls, parser, args, args_type, subcommand=None):
-       if not parser:
-           return args, args_type
-       paths = PatchJsonArgParse.__get_paths_from_dict(cls._args)
-       for path in paths:
-           args_to_delete = PatchJsonArgParse.__get_args_from_path(parser, path, subcommand=subcommand)
-           for arg_to_delete_key, arg_to_delete_value in args_to_delete.items():
-               key_with_section = cls._section_name + cls._args_sep + arg_to_delete_key
-               if key_with_section in args and args[key_with_section] == arg_to_delete_value:
-                   del args[key_with_section]
-               if key_with_section in args_type:
-                   del args_type[key_with_section]
-       return args, args_type
+       if have_config_file:
+           cls._current_task.set_parameter(
+               cls._section_name + cls._args_sep + cls._allow_jsonargparse_overrides,
+               False,
+               description="If True, values in the config file will be overriden by values found in the UI. Otherwise, the values in the config file have priority"
+           )

    @staticmethod
    def _adapt_typehints(original_fn, val, *args, **kwargs):
@@ -103,6 +97,17 @@ class PatchJsonArgParse(object):
            return original_fn(val, *args, **kwargs)
        return original_fn(val, *args, **kwargs)

    @staticmethod
+   def __restore_args(parser, args, subcommand=None):
+       paths = PatchJsonArgParse.__get_paths_from_dict(args)
+       for path in paths:
+           args_to_restore = PatchJsonArgParse.__get_args_from_path(parser, path, subcommand=subcommand)
+           for arg_to_restore_key, arg_to_restore_value in args_to_restore.items():
+               if arg_to_restore_key in PatchJsonArgParse._special_fields:
+                   continue
+               args[arg_to_restore_key] = arg_to_restore_value
+       return args
+
+   @staticmethod
    def _parse_args(original_fn, obj, *args, **kwargs):
        if not PatchJsonArgParse._current_task:
@@ -119,6 +124,13 @@ class PatchJsonArgParse(object):
            params_namespace = Namespace()
            for k, v in params.items():
                params_namespace[k] = v
+           allow_jsonargparse_overrides_value = params.pop(PatchJsonArgParse._allow_jsonargparse_overrides, True)
+           if not allow_jsonargparse_overrides_value:
+               params_namespace = PatchJsonArgParse.__restore_args(
+                   obj,
+                   params_namespace,
+                   subcommand=params_namespace.get(PatchJsonArgParse._command_name)
+               )
            return params_namespace
        except Exception as e:
            logging.getLogger(__file__).warning("Failed parsing jsonargparse arguments: {}".format(e))
@@ -210,7 +222,7 @@ class PatchJsonArgParse(object):
        parsed_cfg = parser.parse_string(path.get_content(), _skip_check=True, _fail_no_subcommand=False)
        if subcommand:
            parsed_cfg = {
-               ((subcommand + PatchJsonArgParse._commands_sep) if k not in ["config", "subcommand"] else "") + k: v
+               ((subcommand + PatchJsonArgParse._commands_sep) if k not in PatchJsonArgParse._special_fields else "") + k: v
                for k, v in parsed_cfg.items()
            }
        return parsed_cfg
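
Usage sketch (illustrative, not part of the diff; script contents are assumptions): a jsonargparse CLI run under ClearML. When a config file is passed, the binding now also records Args/_allow_config_file_override_from_ui_ (default False), so a remote rerun restores the config-file values via __restore_args() instead of letting UI edits silently override them.

    from jsonargparse import ArgumentParser, ActionConfigFile
    from clearml import Task

    task = Task.init(project_name="examples", task_name="jsonargparse demo")

    parser = ArgumentParser()
    parser.add_argument("--lr", type=float, default=0.001)
    parser.add_argument("--config", action=ActionConfigFile)  # e.g. --config train.yaml
    args = parser.parse_args()
    print(args.lr)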

View File

@@ -3215,17 +3215,17 @@ class Dataset(object):
            pool.close()

    def _verify_dataset_folder(self, target_base_folder, part, chunk_selection, max_workers):
-       # type: (Path, Optional[int], Optional[dict], Optional[int]) -> bool
+       # type: (Path, int, dict, int) -> bool

-       def verify_file_or_link(base_folder, ds_part, ds_chunk_selection, file_entry):
-           # type: (Path, Optional[int], Optional[dict], FileEntry) -> Optional[bool]
+       def __verify_file_or_link(target_base_folder, file_entry, part=None, chunk_selection=None):
+           # type: (Path, Union[FileEntry, LinkEntry], Optional[int], Optional[dict]) -> bool

            # check if we need the file for the requested dataset part
            if ds_part is not None:
                f_parts = ds_chunk_selection.get(file_entry.parent_dataset_id, [])
                # file is not in requested dataset part, no need to check it.
                if self._get_chunk_idx_from_artifact_name(file_entry.artifact_name) not in f_parts:
-                   return None
+                   return True

            # check if the local size and the stored size match (faster than comparing hash)
            if (base_folder / file_entry.relative_path).stat().st_size != file_entry.size:
@@ -3237,21 +3237,26 @@ class Dataset(object):
        # check dataset file size, if we have a full match no need for parent dataset download / merge
        verified = True
        # noinspection PyBroadException
+       tp = None
        try:
            futures_ = []
-           with ThreadPoolExecutor(max_workers=max_workers) as tp:
+           tp = ThreadPoolExecutor(max_workers=max_workers)
            for f in self._dataset_file_entries.values():
-               future = tp.submit(verify_file_or_link, target_base_folder, part, chunk_selection, f)
+               future = tp.submit(__verify_file_or_link, target_base_folder, f, part, chunk_selection)
                futures_.append(future)
            for f in self._dataset_link_entries.values():
                # don't check whether link is in dataset part, hence None for part and chunk_selection
-               future = tp.submit(verify_file_or_link, target_base_folder, None, None, f)
+               future = tp.submit(__verify_file_or_link, target_base_folder, f, None, None)
                futures_.append(future)

-           verified = all(f.result() is not False for f in futures_)
+           verified = all(f.result() for f in futures_)
        except Exception:
            verified = False
+       finally:
+           if tp is not None:
+               # we already have our result, close all pending checks (improves performance when verified==False)
+               tp.shutdown(cancel_futures=True)

        return verified

View File

@@ -362,8 +362,8 @@ class BaseModel(object):
        self._task_connect_name = None
        self._set_task(task)

-   def get_weights(self, raise_on_error=False, force_download=False):
-       # type: (bool, bool) -> str
+   def get_weights(self, raise_on_error=False, force_download=False, extract_archive=False):
+       # type: (bool, bool, bool) -> str
        """
        Download the base model and return the locally stored filename.
@@ -373,17 +373,19 @@ class BaseModel(object):
        :param bool force_download: If True, the base model will be downloaded,
            even if the base model is already cached.

+       :param bool extract_archive: If True, the downloaded weights file will be extracted if possible
+
        :return: The locally stored file.
        """
        # download model (synchronously) and return local file
        return self._get_base_model().download_model_weights(
-           raise_on_error=raise_on_error, force_download=force_download
+           raise_on_error=raise_on_error, force_download=force_download, extract_archive=extract_archive
        )

    def get_weights_package(
-       self, return_path=False, raise_on_error=False, force_download=False
+       self, return_path=False, raise_on_error=False, force_download=False, extract_archive=True
    ):
-       # type: (bool, bool, bool) -> Optional[Union[str, List[Path]]]
+       # type: (bool, bool, bool, bool) -> Optional[Union[str, List[Path]]]
        """
        Download the base model package into a temporary directory (extract the files), or return a list of the
        locally stored filenames.
@@ -399,6 +401,8 @@ class BaseModel(object):
        :param bool force_download: If True, the base artifact will be downloaded,
            even if the artifact is already cached.

+       :param bool extract_archive: If True, the downloaded weights file will be extracted if possible
+
        :return: The model weights, or a list of the locally stored filenames.
            if raise_on_error=False, returns None on error.
        """
@@ -407,40 +411,21 @@ class BaseModel(object):
            raise ValueError("Model is not packaged")

        # download packaged model
-       packed_file = self.get_weights(
-           raise_on_error=raise_on_error, force_download=force_download
+       model_path = self.get_weights(
+           raise_on_error=raise_on_error, force_download=force_download, extract_archive=extract_archive
        )
-       if not packed_file:
+       if not model_path:
            if raise_on_error:
                raise ValueError(
                    "Model package '{}' could not be downloaded".format(self.url)
                )
            return None

-       # unpack
-       target_folder = mkdtemp(prefix="model_package_")
-       if not target_folder:
-           raise ValueError(
-               "cannot create temporary directory for packed weight files"
-           )
-
-       for func in (zipfile.ZipFile, tarfile.open):
-           try:
-               obj = func(packed_file)
-               obj.extractall(path=target_folder)
-               break
-           except (zipfile.BadZipfile, tarfile.ReadError):
-               pass
-       else:
-           raise ValueError(
-               "cannot extract files from packaged model at %s", packed_file
-           )
-
        if return_path:
-           return target_folder
+           return model_path

-       target_files = list(Path(target_folder).glob("*"))
+       target_files = list(Path(model_path).glob("*"))
        return target_files

    def report_scalar(self, title, series, value, iteration):
@@ -1374,17 +1359,18 @@ class Model(BaseModel):
        self._base_model = None

    def get_local_copy(
-       self, extract_archive=True, raise_on_error=False, force_download=False
+       self, extract_archive=None, raise_on_error=False, force_download=False
    ):
-       # type: (bool, bool, bool) -> str
+       # type: (Optional[bool], bool, bool) -> str
        """
        Retrieve a valid link to the model file(s).
        If the model URL is a file system link, it will be returned directly.
        If the model URL points to a remote location (http/s3/gs etc.),
        it will download the file(s) and return the temporary location of the downloaded model.

-       :param bool extract_archive: If True, and the model is of type 'packaged' (e.g. TensorFlow compressed folder)
-           The returned path will be a temporary folder containing the archive content
+       :param bool extract_archive: If True, the local copy will be extracted if possible. If False,
+           the local copy will not be extracted. If None (default), the downloaded file will be extracted
+           if the model is a package.
        :param bool raise_on_error: If True, and the artifact could not be downloaded,
            raise ValueError, otherwise return None on failure and output log warning.
        :param bool force_download: If True, the artifact will be downloaded,
@@ -1392,14 +1378,17 @@ class Model(BaseModel):
        :return: A local path to the model (or a downloaded copy of it).
        """
-       if extract_archive and self._is_package():
+       if self._is_package():
            return self.get_weights_package(
                return_path=True,
                raise_on_error=raise_on_error,
                force_download=force_download,
+               extract_archive=True if extract_archive is None else extract_archive
            )
        return self.get_weights(
-           raise_on_error=raise_on_error, force_download=force_download
+           raise_on_error=raise_on_error,
+           force_download=force_download,
+           extract_archive=False if extract_archive is None else extract_archive
        )

    def _get_base_model(self):
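
Usage sketch (illustrative, not part of the diff; the model id is a placeholder): the reworked extract_archive semantics of get_local_copy().

    from clearml import InputModel

    model = InputModel(model_id="<model-id>")

    # None (default): extract only if the stored model is a package (e.g. a zipped folder)
    default_path = model.get_local_copy()

    # False: always return the raw downloaded file, even for packaged models
    raw_path = model.get_local_copy(extract_archive=False)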

View File

@@ -739,6 +739,8 @@ class Task(_Task):
                if argparser_parseargs_called():
                    for parser, parsed_args in get_argparser_last_args():
                        task._connect_argparse(parser=parser, parsed_args=parsed_args)
+
+               PatchHydra.delete_overrides()
            elif argparser_parseargs_called():
                # actually we have nothing to do, in remote running, the argparser will ignore
                # all non argparser parameters, only caveat if parameter connected with the same name
@@ -2604,15 +2606,15 @@ class Task(_Task):
        self.reload()
        script = self.data.script
        if repository is not None:
-           script.repository = str(repository) or None
+           script.repository = str(repository)
        if branch is not None:
-           script.branch = str(branch) or None
+           script.branch = str(branch)
            if script.tag:
                script.tag = None
        if commit is not None:
-           script.version_num = str(commit) or None
+           script.version_num = str(commit)
        if diff is not None:
-           script.diff = str(diff) or None
+           script.diff = str(diff)
        if working_dir is not None:
            script.working_dir = str(working_dir)
        if entry_point is not None:

View File

@@ -1,6 +1,7 @@
import itertools
import json
from copy import copy
+from logging import getLogger

import six
import yaml

@@ -132,16 +133,21 @@ def cast_basic_type(value, type_str):
    parts = type_str.split('/')
    # nested = len(parts) > 1

-   if parts[0] in ('list', 'tuple'):
-       v = '[' + value.lstrip('[(').rstrip('])') + ']'
-       v = yaml.load(v, Loader=yaml.SafeLoader)
-       return basic_types.get(parts[0])(v)
-   elif parts[0] in ('dict', ):
+   if parts[0] in ("list", "tuple", "dict"):
+       # noinspection PyBroadException
        try:
-           return json.loads(value)
+           # lists/tuple/dicts should be json loadable
+           return basic_types.get(parts[0])(json.loads(value))
        except Exception:
-           pass
-       return value
+           # noinspection PyBroadException
+           try:
+               # fallback to legacy basic type loading
+               v = '[' + value.lstrip('[(').rstrip('])') + ']'
+               v = yaml.load(v, Loader=yaml.SafeLoader)
+               return basic_types.get(parts[0])(v)
+           except Exception:
+               getLogger().warning("Could not cast `{}` to basic type. Returning it as `str`".format(value))
+               return value

    t = basic_types.get(str(type_str).lower().strip(), False)
    if t is not False:
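
Illustration (not part of the diff; the helper name is hypothetical): the json-first, yaml-fallback casting order that cast_basic_type now uses for list/tuple/dict type strings.

    import json
    import yaml

    def cast_list_like(value):
        # mirror of the new logic for a "list" type string, under the assumption of the same semantics
        try:
            return list(json.loads(value))                    # preferred path: valid JSON
        except Exception:
            v = "[" + value.lstrip("[(").rstrip("])") + "]"   # legacy fallback
            return list(yaml.load(v, Loader=yaml.SafeLoader))

    print(cast_list_like("[1, 2, 3]"))   # JSON path   -> [1, 2, 3]
    print(cast_list_like("(1, 2, 3)"))   # legacy path -> [1, 2, 3]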

View File

@@ -69,8 +69,7 @@
    "Input the following parameters:\n",
    "* `name` - Name of the PipelineController task which will created\n",
    "* `project` - Project which the controller will be associated with\n",
-   "* `version` - Pipeline's version number. This version allows to uniquely identify the pipeline template execution.\n",
-   "* `auto_version_bump` (default True) - if the same pipeline version already exists (with any difference from the current one), the current pipeline version will be bumped to a new version (e.g. 1.0.0 -> 1.0.1 , 1.2 -> 1.3, 10 -> 11)\n",
+   "* `version` - Pipeline's version number. This version allows to uniquely identify the pipeline template execution. If not set, find the pipeline's latest version and increment it. If no such version is found, defaults to `1.0.0`.\n",
    "    "
   ]
  },