mirror of
https://github.com/clearml/clearml
synced 2025-05-10 23:50:39 +00:00
Merge branch 'master' of https://github.com/allegroai/clearml
This commit is contained in:
commit
52c47b5551
@ -101,6 +101,7 @@ class PipelineController(object):
|
||||
explicit_docker_image = attrib(type=str, default=None) # The Docker image the node uses, specified at creation
|
||||
recursively_parse_parameters = attrib(type=bool, default=False) # if True, recursively parse parameters in
|
||||
# lists, dicts, or tuples
|
||||
output_uri = attrib(type=Union[bool, str], default=None) # The default location for output models and other artifacts
|
||||
|
||||
def __attrs_post_init__(self):
|
||||
if self.parents is None:
|
||||
@ -134,6 +135,26 @@ class PipelineController(object):
|
||||
new_copy.task_factory_func = self.task_factory_func
|
||||
return new_copy
|
||||
|
||||
def set_job_ended(self):
|
||||
if self.job_ended:
|
||||
return
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
self.job.task.reload()
|
||||
self.job_ended = self.job_started + self.job.task.data.active_duration
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
def set_job_started(self):
|
||||
if self.job_started:
|
||||
return
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
self.job_started = self.job.task.data.started.timestamp()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name, # type: str
|
||||
@ -155,7 +176,8 @@ class PipelineController(object):
|
||||
repo_commit=None, # type: Optional[str]
|
||||
always_create_from_code=True, # type: bool
|
||||
artifact_serialization_function=None, # type: Optional[Callable[[Any], Union[bytes, bytearray]]]
|
||||
artifact_deserialization_function=None # type: Optional[Callable[[bytes], Any]]
|
||||
artifact_deserialization_function=None, # type: Optional[Callable[[bytes], Any]]
|
||||
output_uri=None # type: Optional[Union[str, bool]]
|
||||
):
|
||||
# type: (...) -> None
|
||||
"""
|
||||
@ -242,6 +264,9 @@ class PipelineController(object):
|
||||
def deserialize(bytes_):
|
||||
import dill
|
||||
return dill.loads(bytes_)
|
||||
:param output_uri: The storage / output url for this pipeline. This is the default location for output
|
||||
models and other artifacts. Check Task.init reference docs for more info (output_uri is a parameter).
|
||||
The `output_uri` of this pipeline's steps will default to this value.
|
||||
"""
|
||||
if auto_version_bump is not None:
|
||||
warnings.warn("PipelineController.auto_version_bump is deprecated. It will be ignored", DeprecationWarning)
|
||||
@ -316,6 +341,9 @@ class PipelineController(object):
|
||||
project_id=self._task.project, system_tags=self._project_system_tags)
|
||||
|
||||
self._task.set_system_tags((self._task.get_system_tags() or []) + [self._tag])
|
||||
if output_uri is not None:
|
||||
self._task.output_uri = output_uri
|
||||
self._output_uri = output_uri
|
||||
self._task.set_base_docker(
|
||||
docker_image=docker, docker_arguments=docker_args, docker_setup_bash_script=docker_bash_setup_script
|
||||
)
|
||||
@ -387,7 +415,8 @@ class PipelineController(object):
|
||||
base_task_factory=None, # type: Optional[Callable[[PipelineController.Node], Task]]
|
||||
retry_on_failure=None, # type: Optional[Union[int, Callable[[PipelineController, PipelineController.Node, int], bool]]] # noqa
|
||||
status_change_callback=None, # type: Optional[Callable[[PipelineController, PipelineController.Node, str], None]] # noqa
|
||||
recursively_parse_parameters=False # type: bool
|
||||
recursively_parse_parameters=False, # type: bool
|
||||
output_uri=None # type: Optional[Union[str, bool]]
|
||||
):
|
||||
# type: (...) -> bool
|
||||
"""
|
||||
@ -529,7 +558,9 @@ class PipelineController(object):
|
||||
previous_status # type: str
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
:param output_uri: The storage / output url for this step. This is the default location for output
|
||||
models and other artifacts. Check Task.init reference docs for more info (output_uri is a parameter).
|
||||
|
||||
:return: True if successful
|
||||
"""
|
||||
@ -588,6 +619,7 @@ class PipelineController(object):
|
||||
monitor_metrics=monitor_metrics or [],
|
||||
monitor_artifacts=monitor_artifacts or [],
|
||||
monitor_models=monitor_models or [],
|
||||
output_uri=self._output_uri if output_uri is None else output_uri
|
||||
)
|
||||
self._retries[name] = 0
|
||||
self._retries_callbacks[name] = retry_on_failure if callable(retry_on_failure) else \
|
||||
@ -632,7 +664,8 @@ class PipelineController(object):
|
||||
cache_executed_step=False, # type: bool
|
||||
retry_on_failure=None, # type: Optional[Union[int, Callable[[PipelineController, PipelineController.Node, int], bool]]] # noqa
|
||||
status_change_callback=None, # type: Optional[Callable[[PipelineController, PipelineController.Node, str], None]] # noqa
|
||||
tags=None # type: Optional[Union[str, Sequence[str]]]
|
||||
tags=None, # type: Optional[Union[str, Sequence[str]]]
|
||||
output_uri=None # type: Optional[Union[str, bool]]
|
||||
):
|
||||
# type: (...) -> bool
|
||||
"""
|
||||
@ -799,6 +832,8 @@ class PipelineController(object):
|
||||
:param tags: A list of tags for the specific pipeline step.
|
||||
When executing a Pipeline remotely
|
||||
(i.e. launching the pipeline from the UI/enqueuing it), this method has no effect.
|
||||
:param output_uri: The storage / output url for this step. This is the default location for output
|
||||
models and other artifacts. Check Task.init reference docs for more info (output_uri is a parameter).
|
||||
|
||||
:return: True if successful
|
||||
"""
|
||||
@ -838,7 +873,8 @@ class PipelineController(object):
|
||||
cache_executed_step=cache_executed_step,
|
||||
retry_on_failure=retry_on_failure,
|
||||
status_change_callback=status_change_callback,
|
||||
tags=tags
|
||||
tags=tags,
|
||||
output_uri=output_uri
|
||||
)
|
||||
|
||||
def start(
|
||||
@ -1014,8 +1050,8 @@ class PipelineController(object):
|
||||
return cls._get_pipeline_task().get_logger()
|
||||
|
||||
@classmethod
|
||||
def upload_model(cls, model_name, model_local_path):
|
||||
# type: (str, str) -> OutputModel
|
||||
def upload_model(cls, model_name, model_local_path, upload_uri=None):
|
||||
# type: (str, str, Optional[str]) -> OutputModel
|
||||
"""
|
||||
Upload (add) a model to the main Pipeline Task object.
|
||||
This function can be called from any pipeline component to directly add models into the main pipeline Task
|
||||
@ -1028,12 +1064,16 @@ class PipelineController(object):
|
||||
:param model_local_path: Path to the local model file or directory to be uploaded.
|
||||
If a local directory is provided the content of the folder (recursively) will be
|
||||
packaged into a zip file and uploaded
|
||||
:param upload_uri: The URI of the storage destination for model weights upload. The default value
|
||||
is the previously used URI.
|
||||
|
||||
:return: The uploaded OutputModel
|
||||
"""
|
||||
task = cls._get_pipeline_task()
|
||||
model_name = str(model_name)
|
||||
model_local_path = Path(model_local_path)
|
||||
out_model = OutputModel(task=task, name=model_name)
|
||||
out_model.update_weights(weights_filename=model_local_path.as_posix())
|
||||
out_model.update_weights(weights_filename=model_local_path.as_posix(), upload_uri=upload_uri)
|
||||
return out_model
|
||||
|
||||
@classmethod
|
||||
@ -1457,7 +1497,7 @@ class PipelineController(object):
|
||||
self, docker, docker_args, docker_bash_setup_script,
|
||||
function, function_input_artifacts, function_kwargs, function_return,
|
||||
auto_connect_frameworks, auto_connect_arg_parser,
|
||||
packages, project_name, task_name, task_type, repo, branch, commit, helper_functions
|
||||
packages, project_name, task_name, task_type, repo, branch, commit, helper_functions, output_uri=None
|
||||
):
|
||||
task_definition = CreateFromFunction.create_task_from_function(
|
||||
a_function=function,
|
||||
@ -1476,7 +1516,7 @@ class PipelineController(object):
|
||||
docker=docker,
|
||||
docker_args=docker_args,
|
||||
docker_bash_setup_script=docker_bash_setup_script,
|
||||
output_uri=None,
|
||||
output_uri=output_uri,
|
||||
helper_functions=helper_functions,
|
||||
dry_run=True,
|
||||
task_template_header=self._task_template_header,
|
||||
@ -1591,7 +1631,7 @@ class PipelineController(object):
|
||||
'target_project': self._target_project,
|
||||
}
|
||||
pipeline_dag = self._serialize()
|
||||
|
||||
|
||||
# serialize pipeline state
|
||||
if self._task and self._auto_connect_task:
|
||||
# check if we are either running locally or that we are running remotely,
|
||||
@ -1631,6 +1671,7 @@ class PipelineController(object):
|
||||
self._runtime_property_hash: "{}:{}".format(pipeline_hash, self._version),
|
||||
"version": self._version
|
||||
})
|
||||
self._task.set_user_properties(version=self._version)
|
||||
else:
|
||||
self._task.connect_configuration(pipeline_dag, name=self._config_section)
|
||||
connected_args = set()
|
||||
@ -1927,7 +1968,8 @@ class PipelineController(object):
|
||||
cache_executed_step=False, # type: bool
|
||||
retry_on_failure=None, # type: Optional[Union[int, Callable[[PipelineController, PipelineController.Node, int], bool]]] # noqa
|
||||
status_change_callback=None, # type: Optional[Callable[[PipelineController, PipelineController.Node, str], None]] # noqa
|
||||
tags=None # type: Optional[Union[str, Sequence[str]]]
|
||||
tags=None, # type: Optional[Union[str, Sequence[str]]]
|
||||
output_uri=None # type: Optional[Union[str, bool]]
|
||||
):
|
||||
# type: (...) -> bool
|
||||
"""
|
||||
@ -2094,6 +2136,8 @@ class PipelineController(object):
|
||||
:param tags: A list of tags for the specific pipeline step.
|
||||
When executing a Pipeline remotely
|
||||
(i.e. launching the pipeline from the UI/enqueuing it), this method has no effect.
|
||||
:param output_uri: The storage / output url for this step. This is the default location for output
|
||||
models and other artifacts. Check Task.init reference docs for more info (output_uri is a parameter).
|
||||
|
||||
:return: True if successful
|
||||
"""
|
||||
@ -2107,6 +2151,9 @@ class PipelineController(object):
|
||||
|
||||
self._verify_node_name(name)
|
||||
|
||||
if output_uri is None:
|
||||
output_uri = self._output_uri
|
||||
|
||||
function_input_artifacts = {}
|
||||
# go over function_kwargs, split it into string and input artifacts
|
||||
for k, v in function_kwargs.items():
|
||||
@ -2145,7 +2192,7 @@ class PipelineController(object):
|
||||
function_input_artifacts, function_kwargs, function_return,
|
||||
auto_connect_frameworks, auto_connect_arg_parser,
|
||||
packages, project_name, task_name,
|
||||
task_type, repo, repo_branch, repo_commit, helper_functions)
|
||||
task_type, repo, repo_branch, repo_commit, helper_functions, output_uri=output_uri)
|
||||
|
||||
elif self._task.running_locally() or self._task.get_configuration_object(name=name) is None:
|
||||
project_name = project_name or self._get_target_project() or self._task.get_project_name()
|
||||
@ -2155,7 +2202,7 @@ class PipelineController(object):
|
||||
function_input_artifacts, function_kwargs, function_return,
|
||||
auto_connect_frameworks, auto_connect_arg_parser,
|
||||
packages, project_name, task_name,
|
||||
task_type, repo, repo_branch, repo_commit, helper_functions)
|
||||
task_type, repo, repo_branch, repo_commit, helper_functions, output_uri=output_uri)
|
||||
# update configuration with the task definitions
|
||||
# noinspection PyProtectedMember
|
||||
self._task._set_configuration(
|
||||
@ -2180,6 +2227,9 @@ class PipelineController(object):
|
||||
if tags:
|
||||
a_task.add_tags(tags)
|
||||
|
||||
if output_uri is not None:
|
||||
a_task.output_uri = output_uri
|
||||
|
||||
return a_task
|
||||
|
||||
self._nodes[name] = self.Node(
|
||||
@ -2195,7 +2245,8 @@ class PipelineController(object):
|
||||
monitor_metrics=monitor_metrics,
|
||||
monitor_models=monitor_models,
|
||||
job_code_section=job_code_section,
|
||||
explicit_docker_image=docker
|
||||
explicit_docker_image=docker,
|
||||
output_uri=output_uri
|
||||
)
|
||||
self._retries[name] = 0
|
||||
self._retries_callbacks[name] = retry_on_failure if callable(retry_on_failure) else \
|
||||
@ -2284,13 +2335,14 @@ class PipelineController(object):
|
||||
disable_clone_task=disable_clone_task,
|
||||
task_overrides=task_overrides,
|
||||
allow_caching=node.cache_executed_step,
|
||||
output_uri=node.output_uri,
|
||||
**extra_args
|
||||
)
|
||||
except Exception:
|
||||
self._pipeline_task_status_failed = True
|
||||
raise
|
||||
|
||||
node.job_started = time()
|
||||
node.job_started = None
|
||||
node.job_ended = None
|
||||
node.job_type = str(node.job.task.task_type)
|
||||
|
||||
@ -2546,6 +2598,8 @@ class PipelineController(object):
|
||||
"""
|
||||
previous_status = node.status
|
||||
|
||||
if node.job and node.job.is_running():
|
||||
node.set_job_started()
|
||||
update_job_ended = node.job_started and not node.job_ended
|
||||
|
||||
if node.executed is not None:
|
||||
@ -2582,7 +2636,7 @@ class PipelineController(object):
|
||||
node.status = "pending"
|
||||
|
||||
if update_job_ended and node.status in ("aborted", "failed", "completed"):
|
||||
node.job_ended = time()
|
||||
node.set_job_ended()
|
||||
|
||||
if (
|
||||
previous_status is not None
|
||||
@ -2679,7 +2733,7 @@ class PipelineController(object):
|
||||
if node_failed and self._abort_running_steps_on_failure and not node.continue_on_fail:
|
||||
nodes_failed_stop_pipeline.append(node.name)
|
||||
elif node.timeout:
|
||||
started = node.job.task.data.started
|
||||
node.set_job_started()
|
||||
if (datetime.now().astimezone(started.tzinfo) - started).total_seconds() > node.timeout:
|
||||
node.job.abort()
|
||||
completed_jobs.append(j)
|
||||
@ -3261,7 +3315,8 @@ class PipelineDecorator(PipelineController):
|
||||
repo_branch=None, # type: Optional[str]
|
||||
repo_commit=None, # type: Optional[str]
|
||||
artifact_serialization_function=None, # type: Optional[Callable[[Any], Union[bytes, bytearray]]]
|
||||
artifact_deserialization_function=None # type: Optional[Callable[[bytes], Any]]
|
||||
artifact_deserialization_function=None, # type: Optional[Callable[[bytes], Any]]
|
||||
output_uri=None # type: Optional[Union[str, bool]]
|
||||
):
|
||||
# type: (...) -> ()
|
||||
"""
|
||||
@ -3341,6 +3396,9 @@ class PipelineDecorator(PipelineController):
|
||||
def deserialize(bytes_):
|
||||
import dill
|
||||
return dill.loads(bytes_)
|
||||
:param output_uri: The storage / output url for this pipeline. This is the default location for output
|
||||
models and other artifacts. Check Task.init reference docs for more info (output_uri is a parameter).
|
||||
The `output_uri` of this pipeline's steps will default to this value.
|
||||
"""
|
||||
super(PipelineDecorator, self).__init__(
|
||||
name=name,
|
||||
@ -3361,7 +3419,8 @@ class PipelineDecorator(PipelineController):
|
||||
repo_commit=repo_commit,
|
||||
always_create_from_code=False,
|
||||
artifact_serialization_function=artifact_serialization_function,
|
||||
artifact_deserialization_function=artifact_deserialization_function
|
||||
artifact_deserialization_function=artifact_deserialization_function,
|
||||
output_uri=output_uri
|
||||
)
|
||||
|
||||
# if we are in eager execution, make sure parent class knows it
|
||||
@ -3583,7 +3642,7 @@ class PipelineDecorator(PipelineController):
|
||||
function, function_input_artifacts, function_kwargs, function_return,
|
||||
auto_connect_frameworks, auto_connect_arg_parser,
|
||||
packages, project_name, task_name, task_type, repo, branch, commit,
|
||||
helper_functions
|
||||
helper_functions, output_uri=None
|
||||
):
|
||||
def sanitize(function_source):
|
||||
matched = re.match(r"[\s]*@[\w]*.component[\s\\]*\(", function_source)
|
||||
@ -3621,7 +3680,7 @@ class PipelineDecorator(PipelineController):
|
||||
docker=docker,
|
||||
docker_args=docker_args,
|
||||
docker_bash_setup_script=docker_bash_setup_script,
|
||||
output_uri=None,
|
||||
output_uri=output_uri,
|
||||
helper_functions=helper_functions,
|
||||
dry_run=True,
|
||||
task_template_header=self._task_template_header,
|
||||
@ -3703,7 +3762,8 @@ class PipelineDecorator(PipelineController):
|
||||
pre_execute_callback=None, # type: Optional[Callable[[PipelineController, PipelineController.Node, dict], bool]] # noqa
|
||||
post_execute_callback=None, # type: Optional[Callable[[PipelineController, PipelineController.Node], None]] # noqa
|
||||
status_change_callback=None, # type: Optional[Callable[[PipelineController, PipelineController.Node, str], None]] # noqa
|
||||
tags=None # type: Optional[Union[str, Sequence[str]]]
|
||||
tags=None, # type: Optional[Union[str, Sequence[str]]]
|
||||
output_uri=None # type: Optional[Union[str, bool]]
|
||||
):
|
||||
# type: (...) -> Callable
|
||||
"""
|
||||
@ -3841,6 +3901,8 @@ class PipelineDecorator(PipelineController):
|
||||
:param tags: A list of tags for the specific pipeline step.
|
||||
When executing a Pipeline remotely
|
||||
(i.e. launching the pipeline from the UI/enqueuing it), this method has no effect.
|
||||
:param output_uri: The storage / output url for this step. This is the default location for output
|
||||
models and other artifacts. Check Task.init reference docs for more info (output_uri is a parameter).
|
||||
|
||||
:return: function wrapper
|
||||
"""
|
||||
@ -3883,7 +3945,8 @@ class PipelineDecorator(PipelineController):
|
||||
pre_execute_callback=pre_execute_callback,
|
||||
post_execute_callback=post_execute_callback,
|
||||
status_change_callback=status_change_callback,
|
||||
tags=tags
|
||||
tags=tags,
|
||||
output_uri=output_uri
|
||||
)
|
||||
|
||||
if cls._singleton:
|
||||
@ -4109,7 +4172,8 @@ class PipelineDecorator(PipelineController):
|
||||
repo_branch=None, # type: Optional[str]
|
||||
repo_commit=None, # type: Optional[str]
|
||||
artifact_serialization_function=None, # type: Optional[Callable[[Any], Union[bytes, bytearray]]]
|
||||
artifact_deserialization_function=None # type: Optional[Callable[[bytes], Any]]
|
||||
artifact_deserialization_function=None, # type: Optional[Callable[[bytes], Any]]
|
||||
output_uri=None # type: Optional[Union[str, bool]]
|
||||
):
|
||||
# type: (...) -> Callable
|
||||
"""
|
||||
@ -4220,6 +4284,9 @@ class PipelineDecorator(PipelineController):
|
||||
def deserialize(bytes_):
|
||||
import dill
|
||||
return dill.loads(bytes_)
|
||||
:param output_uri: The storage / output url for this pipeline. This is the default location for output
|
||||
models and other artifacts. Check Task.init reference docs for more info (output_uri is a parameter).
|
||||
The `output_uri` of this pipeline's steps will default to this value.
|
||||
"""
|
||||
def decorator_wrap(func):
|
||||
|
||||
@ -4265,7 +4332,8 @@ class PipelineDecorator(PipelineController):
|
||||
repo_branch=repo_branch,
|
||||
repo_commit=repo_commit,
|
||||
artifact_serialization_function=artifact_serialization_function,
|
||||
artifact_deserialization_function=artifact_deserialization_function
|
||||
artifact_deserialization_function=artifact_deserialization_function,
|
||||
output_uri=output_uri
|
||||
)
|
||||
ret_val = func(**pipeline_kwargs)
|
||||
LazyEvalWrapper.trigger_all_remote_references()
|
||||
@ -4316,7 +4384,8 @@ class PipelineDecorator(PipelineController):
|
||||
repo_branch=repo_branch,
|
||||
repo_commit=repo_commit,
|
||||
artifact_serialization_function=artifact_serialization_function,
|
||||
artifact_deserialization_function=artifact_deserialization_function
|
||||
artifact_deserialization_function=artifact_deserialization_function,
|
||||
output_uri=output_uri
|
||||
)
|
||||
|
||||
a_pipeline._args_map = args_map or {}
|
||||
|
@ -522,6 +522,7 @@ class ClearmlJob(BaseJob):
|
||||
disable_clone_task=False, # type: bool
|
||||
allow_caching=False, # type: bool
|
||||
target_project=None, # type: Optional[str]
|
||||
output_uri=None, # type: Optional[Union[str, bool]]
|
||||
**kwargs # type: Any
|
||||
):
|
||||
# type: (...) -> ()
|
||||
@ -545,6 +546,8 @@ class ClearmlJob(BaseJob):
|
||||
If True, use the base_task_id directly (base-task must be in draft-mode / created),
|
||||
:param bool allow_caching: If True, check if we have a previously executed Task with the same specification.
|
||||
If we do, use it and set internal is_cached flag. Default False (always create new Task).
|
||||
:param Union[str, bool] output_uri: The storage / output url for this job. This is the default location for
|
||||
output models and other artifacts. Check Task.init reference docs for more info (output_uri is a parameter).
|
||||
:param str target_project: Optional, Set the target project name to create the cloned Task in.
|
||||
"""
|
||||
super(ClearmlJob, self).__init__()
|
||||
@ -660,6 +663,8 @@ class ClearmlJob(BaseJob):
|
||||
# noinspection PyProtectedMember
|
||||
self.task._edit(**sections)
|
||||
|
||||
if output_uri is not None:
|
||||
self.task.output_uri = output_uri
|
||||
self._set_task_cache_hash(self.task, task_hash)
|
||||
self.task_started = False
|
||||
self._worker = None
|
||||
|
@ -78,13 +78,15 @@ class InterfaceBase(SessionInterface):
|
||||
except MaxRequestSizeError as e:
|
||||
res = CallResult(meta=ResponseMeta.from_raw_data(status_code=400, text=str(e)))
|
||||
error_msg = 'Failed sending: %s' % str(e)
|
||||
except requests.exceptions.ConnectionError:
|
||||
except requests.exceptions.ConnectionError as e:
|
||||
# We couldn't send the request for more than the retries times configure in the api configuration file,
|
||||
# so we will end the loop and raise the exception to the upper level.
|
||||
# Notice: this is a connectivity error and not a backend error.
|
||||
if raise_on_errors:
|
||||
raise
|
||||
# if raise_on_errors:
|
||||
# raise
|
||||
res = None
|
||||
if log and num_retries >= cls._num_retry_warning_display:
|
||||
log.warning('Retrying, previous request failed %s: %s' % (str(type(req)), str(e)))
|
||||
except cls._JSON_EXCEPTION as e:
|
||||
if log:
|
||||
log.error(
|
||||
|
@ -527,7 +527,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
||||
return False
|
||||
return bool(self.data.ready)
|
||||
|
||||
def download_model_weights(self, raise_on_error=False, force_download=False):
|
||||
def download_model_weights(self, raise_on_error=False, force_download=False, extract_archive=False):
|
||||
"""
|
||||
Download the model weights into a local file in our cache
|
||||
|
||||
@ -537,6 +537,8 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
||||
:param bool force_download: If True, the base artifact will be downloaded,
|
||||
even if the artifact is already cached.
|
||||
|
||||
:param bool extract_archive: If True, unzip the downloaded file if possible
|
||||
|
||||
:return: a local path to a downloaded copy of the model
|
||||
"""
|
||||
uri = self.data.uri
|
||||
@ -556,7 +558,7 @@ class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
|
||||
Model._local_model_to_id_uri.pop(dl_file, None)
|
||||
|
||||
local_download = StorageManager.get_local_copy(
|
||||
uri, extract_archive=False, force_download=force_download
|
||||
uri, extract_archive=extract_archive, force_download=force_download
|
||||
)
|
||||
|
||||
# save local model, so we can later query what was the original one
|
||||
|
@ -178,7 +178,7 @@ class CreateAndPopulate(object):
|
||||
project=Task.get_project_id(self.project_name),
|
||||
type=str(self.task_type or Task.TaskTypes.training),
|
||||
) # type: dict
|
||||
if self.output_uri:
|
||||
if self.output_uri is not None:
|
||||
task_state['output'] = dict(destination=self.output_uri)
|
||||
else:
|
||||
task_state = dict(script={})
|
||||
@ -391,7 +391,7 @@ class CreateAndPopulate(object):
|
||||
return task
|
||||
|
||||
def _set_output_uri(self, task):
|
||||
if self.output_uri:
|
||||
if self.output_uri is not None:
|
||||
try:
|
||||
task.output_uri = self.output_uri
|
||||
except ValueError:
|
||||
|
@ -1344,12 +1344,14 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
||||
params = self.get_parameters(cast=cast)
|
||||
return params.get(name, default)
|
||||
|
||||
def delete_parameter(self, name):
|
||||
def delete_parameter(self, name, force=False):
|
||||
# type: (str) -> bool
|
||||
"""
|
||||
Delete a parameter by its full name Section/name.
|
||||
|
||||
:param name: Parameter name in full, i.e. Section/name. For example, 'Args/batch_size'
|
||||
:param force: If set to True then both new and running task hyper params can be deleted.
|
||||
Otherwise only the new task ones. Default is False
|
||||
:return: True if the parameter was deleted successfully
|
||||
"""
|
||||
if not Session.check_min_api_version('2.9'):
|
||||
@ -1360,7 +1362,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
|
||||
with self._edit_lock:
|
||||
paramkey = tasks.ParamKey(section=name.split('/', 1)[0], name=name.split('/', 1)[1])
|
||||
res = self.send(tasks.DeleteHyperParamsRequest(
|
||||
task=self.id, hyperparams=[paramkey]), raise_on_errors=False)
|
||||
task=self.id, hyperparams=[paramkey], force=force), raise_on_errors=False)
|
||||
self.reload()
|
||||
|
||||
return res.ok()
|
||||
|
@ -15,6 +15,8 @@ class PatchHydra(object):
|
||||
_config_section = 'OmegaConf'
|
||||
_parameter_section = 'Hydra'
|
||||
_parameter_allow_full_edit = '_allow_omegaconf_edit_'
|
||||
_should_delete_overrides = False
|
||||
_overrides_section = "Args/overrides"
|
||||
|
||||
@classmethod
|
||||
def patch_hydra(cls):
|
||||
@ -42,6 +44,12 @@ class PatchHydra(object):
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def delete_overrides(cls):
|
||||
if not cls._should_delete_overrides or not cls._current_task:
|
||||
return
|
||||
cls._current_task.delete_parameter(cls._overrides_section, force=True)
|
||||
|
||||
@staticmethod
|
||||
def update_current_task(task):
|
||||
# set current Task before patching
|
||||
@ -50,11 +58,24 @@ class PatchHydra(object):
|
||||
return
|
||||
if PatchHydra.patch_hydra():
|
||||
# check if we have an untracked state, store it.
|
||||
if PatchHydra._last_untracked_state.get('connect'):
|
||||
PatchHydra._current_task.connect(**PatchHydra._last_untracked_state['connect'])
|
||||
if PatchHydra._last_untracked_state.get('_set_configuration'):
|
||||
if PatchHydra._last_untracked_state.get("connect"):
|
||||
if PatchHydra._parameter_allow_full_edit in PatchHydra._last_untracked_state["connect"].get("mutable", {}):
|
||||
allow_omegaconf_edit_section = PatchHydra._parameter_section + "/" + PatchHydra._parameter_allow_full_edit
|
||||
allow_omegaconf_edit_section_val = PatchHydra._last_untracked_state["connect"]["mutable"].pop(
|
||||
PatchHydra._parameter_allow_full_edit
|
||||
)
|
||||
PatchHydra._current_task.set_parameter(
|
||||
allow_omegaconf_edit_section,
|
||||
allow_omegaconf_edit_section_val,
|
||||
description="If True, the `{}` parameter section will be completely ignored. The OmegaConf will instead be pulled from the `{}` section".format(
|
||||
PatchHydra._parameter_section,
|
||||
PatchHydra._config_section
|
||||
)
|
||||
)
|
||||
PatchHydra._current_task.connect(**PatchHydra._last_untracked_state["connect"])
|
||||
if PatchHydra._last_untracked_state.get("_set_configuration"):
|
||||
# noinspection PyProtectedMember
|
||||
PatchHydra._current_task._set_configuration(**PatchHydra._last_untracked_state['_set_configuration'])
|
||||
PatchHydra._current_task._set_configuration(**PatchHydra._last_untracked_state["_set_configuration"])
|
||||
PatchHydra._last_untracked_state = {}
|
||||
else:
|
||||
# if patching failed set it to None
|
||||
@ -63,36 +84,34 @@ class PatchHydra(object):
|
||||
@staticmethod
|
||||
def _patched_hydra_run(self, config_name, task_function, overrides, *args, **kwargs):
|
||||
PatchHydra._allow_omegaconf_edit = False
|
||||
if not running_remotely():
|
||||
return PatchHydra._original_hydra_run(self, config_name, task_function, overrides, *args, **kwargs)
|
||||
|
||||
# store the config
|
||||
# get the parameters from the backend
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
if running_remotely():
|
||||
if not PatchHydra._current_task:
|
||||
from ..task import Task
|
||||
PatchHydra._current_task = Task.get_task(task_id=get_remote_task_id())
|
||||
# get the _parameter_allow_full_edit casted back to boolean
|
||||
connected_config = dict()
|
||||
connected_config[PatchHydra._parameter_allow_full_edit] = False
|
||||
PatchHydra._current_task.connect(connected_config, name=PatchHydra._parameter_section)
|
||||
PatchHydra._allow_omegaconf_edit = connected_config.pop(PatchHydra._parameter_allow_full_edit, None)
|
||||
# get all the overrides
|
||||
full_parameters = PatchHydra._current_task.get_parameters(backwards_compatibility=False)
|
||||
stored_config = {k[len(PatchHydra._parameter_section)+1:]: v for k, v in full_parameters.items()
|
||||
if k.startswith(PatchHydra._parameter_section+'/')}
|
||||
stored_config.pop(PatchHydra._parameter_allow_full_edit, None)
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
overrides = yaml.safe_load(full_parameters.get("Args/overrides", "")) or []
|
||||
except Exception:
|
||||
overrides = []
|
||||
if overrides and not isinstance(overrides, (list, tuple)):
|
||||
overrides = [overrides]
|
||||
overrides += ['{}={}'.format(k, v) for k, v in stored_config.items()]
|
||||
overrides = [("+" + o) if (o.startswith("+") and not o.startswith("++")) else o for o in overrides]
|
||||
else:
|
||||
# We take care of it inside the _patched_run_job
|
||||
pass
|
||||
if not PatchHydra._current_task:
|
||||
from ..task import Task
|
||||
PatchHydra._current_task = Task.get_task(task_id=get_remote_task_id())
|
||||
# get the _parameter_allow_full_edit casted back to boolean
|
||||
connected_config = {}
|
||||
connected_config[PatchHydra._parameter_allow_full_edit] = False
|
||||
PatchHydra._current_task.connect(connected_config, name=PatchHydra._parameter_section)
|
||||
PatchHydra._allow_omegaconf_edit = connected_config.pop(PatchHydra._parameter_allow_full_edit, None)
|
||||
# get all the overrides
|
||||
full_parameters = PatchHydra._current_task.get_parameters(backwards_compatibility=False)
|
||||
stored_config = {k[len(PatchHydra._parameter_section)+1:]: v for k, v in full_parameters.items()
|
||||
if k.startswith(PatchHydra._parameter_section+'/')}
|
||||
stored_config.pop(PatchHydra._parameter_allow_full_edit, None)
|
||||
for override_k, override_v in stored_config.items():
|
||||
if override_k.startswith("~"):
|
||||
new_override = override_k
|
||||
else:
|
||||
new_override = "++" + override_k.lstrip("+")
|
||||
if override_v is not None and override_v != "":
|
||||
new_override += "=" + override_v
|
||||
overrides.append(new_override)
|
||||
PatchHydra._should_delete_overrides = True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@ -114,12 +133,18 @@ class PatchHydra(object):
|
||||
# store the config
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
if running_remotely():
|
||||
# we take care of it in the hydra run (where we have access to the overrides)
|
||||
pass
|
||||
else:
|
||||
if not running_remotely():
|
||||
# note that we fetch the overrides from the backend in hydra run when running remotely,
|
||||
# here we just get them from hydra to be stored as configuration/parameters
|
||||
overrides = config.hydra.overrides.task
|
||||
stored_config = dict(arg.split('=', 1) for arg in overrides)
|
||||
stored_config = {}
|
||||
for arg in overrides:
|
||||
arg = arg.lstrip("+")
|
||||
if "=" in arg:
|
||||
k, v = arg.split("=", 1)
|
||||
stored_config[k] = v
|
||||
else:
|
||||
stored_config[arg] = None
|
||||
stored_config[PatchHydra._parameter_allow_full_edit] = False
|
||||
if PatchHydra._current_task:
|
||||
PatchHydra._current_task.connect(stored_config, name=PatchHydra._parameter_section)
|
||||
@ -127,9 +152,7 @@ class PatchHydra(object):
|
||||
else:
|
||||
PatchHydra._last_untracked_state['connect'] = dict(
|
||||
mutable=stored_config, name=PatchHydra._parameter_section)
|
||||
# Maybe ?! remove the overrides section from the Args (we have it here)
|
||||
# But when used with a Pipeline this is the only section we get... so we leave it here anyhow
|
||||
# PatchHydra._current_task.delete_parameter('Args/overrides')
|
||||
PatchHydra._should_delete_overrides = True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@ -176,8 +199,7 @@ class PatchHydra(object):
|
||||
else:
|
||||
# noinspection PyProtectedMember
|
||||
omega_yaml = PatchHydra._current_task._get_configuration_text(PatchHydra._config_section)
|
||||
loaded_config = OmegaConf.load(io.StringIO(omega_yaml))
|
||||
a_config = OmegaConf.merge(a_config, loaded_config)
|
||||
a_config = OmegaConf.load(io.StringIO(omega_yaml))
|
||||
PatchHydra._register_omegaconf(a_config, is_read_only=False)
|
||||
return task_function(a_config, *a_args, **a_kwargs)
|
||||
|
||||
@ -194,10 +216,6 @@ class PatchHydra(object):
|
||||
description = 'Full OmegaConf YAML configuration overridden! ({}/{}=True)'.format(
|
||||
PatchHydra._parameter_section, PatchHydra._parameter_allow_full_edit)
|
||||
|
||||
# we should not have the hydra section in the config, but this seems never to be the case anymore.
|
||||
# config = config.copy()
|
||||
# config.pop('hydra', None)
|
||||
|
||||
configuration = dict(
|
||||
name=PatchHydra._config_section,
|
||||
description=description,
|
||||
|
@ -28,7 +28,9 @@ class PatchJsonArgParse(object):
|
||||
_commands_sep = "."
|
||||
_command_type = "jsonargparse.Command"
|
||||
_command_name = "subcommand"
|
||||
_special_fields = ["config", "subcommand"]
|
||||
_section_name = "Args"
|
||||
_allow_jsonargparse_overrides = "_allow_config_file_override_from_ui_"
|
||||
__remote_task_params = {}
|
||||
__remote_task_params_dict = {}
|
||||
__patched = False
|
||||
@ -60,6 +62,7 @@ class PatchJsonArgParse(object):
|
||||
return
|
||||
args = {}
|
||||
args_type = {}
|
||||
have_config_file = False
|
||||
for k, v in cls._args.items():
|
||||
key_with_section = cls._section_name + cls._args_sep + k
|
||||
args[key_with_section] = v
|
||||
@ -75,27 +78,18 @@ class PatchJsonArgParse(object):
|
||||
elif isinstance(v, Path) or (isinstance(v, list) and all(isinstance(sub_v, Path) for sub_v in v)):
|
||||
args[key_with_section] = json.dumps(PatchJsonArgParse._handle_path(v))
|
||||
args_type[key_with_section] = PatchJsonArgParse.path_type
|
||||
have_config_file = True
|
||||
else:
|
||||
args[key_with_section] = str(v)
|
||||
except Exception:
|
||||
pass
|
||||
args, args_type = cls.__delete_config_args(parser, args, args_type, subcommand=subcommand)
|
||||
cls._current_task._set_parameters(args, __update=True, __parameters_types=args_type)
|
||||
|
||||
@classmethod
|
||||
def __delete_config_args(cls, parser, args, args_type, subcommand=None):
|
||||
if not parser:
|
||||
return args, args_type
|
||||
paths = PatchJsonArgParse.__get_paths_from_dict(cls._args)
|
||||
for path in paths:
|
||||
args_to_delete = PatchJsonArgParse.__get_args_from_path(parser, path, subcommand=subcommand)
|
||||
for arg_to_delete_key, arg_to_delete_value in args_to_delete.items():
|
||||
key_with_section = cls._section_name + cls._args_sep + arg_to_delete_key
|
||||
if key_with_section in args and args[key_with_section] == arg_to_delete_value:
|
||||
del args[key_with_section]
|
||||
if key_with_section in args_type:
|
||||
del args_type[key_with_section]
|
||||
return args, args_type
|
||||
if have_config_file:
|
||||
cls._current_task.set_parameter(
|
||||
cls._section_name + cls._args_sep + cls._allow_jsonargparse_overrides,
|
||||
False,
|
||||
description="If True, values in the config file will be overriden by values found in the UI. Otherwise, the values in the config file have priority"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _adapt_typehints(original_fn, val, *args, **kwargs):
|
||||
@ -103,6 +97,17 @@ class PatchJsonArgParse(object):
|
||||
return original_fn(val, *args, **kwargs)
|
||||
return original_fn(val, *args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def __restore_args(parser, args, subcommand=None):
|
||||
paths = PatchJsonArgParse.__get_paths_from_dict(args)
|
||||
for path in paths:
|
||||
args_to_restore = PatchJsonArgParse.__get_args_from_path(parser, path, subcommand=subcommand)
|
||||
for arg_to_restore_key, arg_to_restore_value in args_to_restore.items():
|
||||
if arg_to_restore_key in PatchJsonArgParse._special_fields:
|
||||
continue
|
||||
args[arg_to_restore_key] = arg_to_restore_value
|
||||
return args
|
||||
|
||||
@staticmethod
|
||||
def _parse_args(original_fn, obj, *args, **kwargs):
|
||||
if not PatchJsonArgParse._current_task:
|
||||
@ -119,6 +124,13 @@ class PatchJsonArgParse(object):
|
||||
params_namespace = Namespace()
|
||||
for k, v in params.items():
|
||||
params_namespace[k] = v
|
||||
allow_jsonargparse_overrides_value = params.pop(PatchJsonArgParse._allow_jsonargparse_overrides, True)
|
||||
if not allow_jsonargparse_overrides_value:
|
||||
params_namespace = PatchJsonArgParse.__restore_args(
|
||||
obj,
|
||||
params_namespace,
|
||||
subcommand=params_namespace.get(PatchJsonArgParse._command_name)
|
||||
)
|
||||
return params_namespace
|
||||
except Exception as e:
|
||||
logging.getLogger(__file__).warning("Failed parsing jsonargparse arguments: {}".format(e))
|
||||
@ -210,7 +222,7 @@ class PatchJsonArgParse(object):
|
||||
parsed_cfg = parser.parse_string(path.get_content(), _skip_check=True, _fail_no_subcommand=False)
|
||||
if subcommand:
|
||||
parsed_cfg = {
|
||||
((subcommand + PatchJsonArgParse._commands_sep) if k not in ["config", "subcommand"] else "") + k: v
|
||||
((subcommand + PatchJsonArgParse._commands_sep) if k not in PatchJsonArgParse._special_fields else "") + k: v
|
||||
for k, v in parsed_cfg.items()
|
||||
}
|
||||
return parsed_cfg
|
||||
|
@ -3215,17 +3215,17 @@ class Dataset(object):
|
||||
pool.close()
|
||||
|
||||
def _verify_dataset_folder(self, target_base_folder, part, chunk_selection, max_workers):
|
||||
# type: (Path, Optional[int], Optional[dict], Optional[int]) -> bool
|
||||
# type: (Path, int, dict, int) -> bool
|
||||
|
||||
def verify_file_or_link(base_folder, ds_part, ds_chunk_selection, file_entry):
|
||||
# type: (Path, Optional[int], Optional[dict], FileEntry) -> Optional[bool]
|
||||
def __verify_file_or_link(target_base_folder, file_entry, part=None, chunk_selection=None):
|
||||
# type: (Path, Union[FileEntry, LinkEntry], Optional[int], Optional[dict]) -> bool
|
||||
|
||||
# check if we need the file for the requested dataset part
|
||||
if ds_part is not None:
|
||||
f_parts = ds_chunk_selection.get(file_entry.parent_dataset_id, [])
|
||||
# file is not in requested dataset part, no need to check it.
|
||||
if self._get_chunk_idx_from_artifact_name(file_entry.artifact_name) not in f_parts:
|
||||
return None
|
||||
return True
|
||||
|
||||
# check if the local size and the stored size match (faster than comparing hash)
|
||||
if (base_folder / file_entry.relative_path).stat().st_size != file_entry.size:
|
||||
@ -3237,21 +3237,26 @@ class Dataset(object):
|
||||
# check dataset file size, if we have a full match no need for parent dataset download / merge
|
||||
verified = True
|
||||
# noinspection PyBroadException
|
||||
tp = None
|
||||
try:
|
||||
futures_ = []
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as tp:
|
||||
for f in self._dataset_file_entries.values():
|
||||
future = tp.submit(verify_file_or_link, target_base_folder, part, chunk_selection, f)
|
||||
futures_.append(future)
|
||||
tp = ThreadPoolExecutor(max_workers=max_workers)
|
||||
for f in self._dataset_file_entries.values():
|
||||
future = tp.submit(__verify_file_or_link, target_base_folder, f, part, chunk_selection)
|
||||
futures_.append(future)
|
||||
|
||||
for f in self._dataset_link_entries.values():
|
||||
# don't check whether link is in dataset part, hence None for part and chunk_selection
|
||||
future = tp.submit(verify_file_or_link, target_base_folder, None, None, f)
|
||||
futures_.append(future)
|
||||
for f in self._dataset_link_entries.values():
|
||||
# don't check whether link is in dataset part, hence None for part and chunk_selection
|
||||
future = tp.submit(__verify_file_or_link, target_base_folder, f, None, None)
|
||||
futures_.append(future)
|
||||
|
||||
verified = all(f.result() is not False for f in futures_)
|
||||
verified = all(f.result() for f in futures_)
|
||||
except Exception:
|
||||
verified = False
|
||||
finally:
|
||||
if tp is not None:
|
||||
# we already have our result, close all pending checks (improves performance when verified==False)
|
||||
tp.shutdown(cancel_futures=True)
|
||||
|
||||
return verified
|
||||
|
||||
|
@ -362,8 +362,8 @@ class BaseModel(object):
|
||||
self._task_connect_name = None
|
||||
self._set_task(task)
|
||||
|
||||
def get_weights(self, raise_on_error=False, force_download=False):
|
||||
# type: (bool, bool) -> str
|
||||
def get_weights(self, raise_on_error=False, force_download=False, extract_archive=False):
|
||||
# type: (bool, bool, bool) -> str
|
||||
"""
|
||||
Download the base model and return the locally stored filename.
|
||||
|
||||
@ -373,17 +373,19 @@ class BaseModel(object):
|
||||
:param bool force_download: If True, the base model will be downloaded,
|
||||
even if the base model is already cached.
|
||||
|
||||
:param bool extract_archive: If True, the downloaded weights file will be extracted if possible
|
||||
|
||||
:return: The locally stored file.
|
||||
"""
|
||||
# download model (synchronously) and return local file
|
||||
return self._get_base_model().download_model_weights(
|
||||
raise_on_error=raise_on_error, force_download=force_download
|
||||
raise_on_error=raise_on_error, force_download=force_download, extract_archive=extract_archive
|
||||
)
|
||||
|
||||
def get_weights_package(
|
||||
self, return_path=False, raise_on_error=False, force_download=False
|
||||
self, return_path=False, raise_on_error=False, force_download=False, extract_archive=True
|
||||
):
|
||||
# type: (bool, bool, bool) -> Optional[Union[str, List[Path]]]
|
||||
# type: (bool, bool, bool, bool) -> Optional[Union[str, List[Path]]]
|
||||
"""
|
||||
Download the base model package into a temporary directory (extract the files), or return a list of the
|
||||
locally stored filenames.
|
||||
@ -399,6 +401,8 @@ class BaseModel(object):
|
||||
:param bool force_download: If True, the base artifact will be downloaded,
|
||||
even if the artifact is already cached.
|
||||
|
||||
:param bool extract_archive: If True, the downloaded weights file will be extracted if possible
|
||||
|
||||
:return: The model weights, or a list of the locally stored filenames.
|
||||
if raise_on_error=False, returns None on error.
|
||||
"""
|
||||
@ -407,40 +411,21 @@ class BaseModel(object):
|
||||
raise ValueError("Model is not packaged")
|
||||
|
||||
# download packaged model
|
||||
packed_file = self.get_weights(
|
||||
raise_on_error=raise_on_error, force_download=force_download
|
||||
model_path = self.get_weights(
|
||||
raise_on_error=raise_on_error, force_download=force_download, extract_archive=extract_archive
|
||||
)
|
||||
|
||||
if not packed_file:
|
||||
if not model_path:
|
||||
if raise_on_error:
|
||||
raise ValueError(
|
||||
"Model package '{}' could not be downloaded".format(self.url)
|
||||
)
|
||||
return None
|
||||
|
||||
# unpack
|
||||
target_folder = mkdtemp(prefix="model_package_")
|
||||
if not target_folder:
|
||||
raise ValueError(
|
||||
"cannot create temporary directory for packed weight files"
|
||||
)
|
||||
|
||||
for func in (zipfile.ZipFile, tarfile.open):
|
||||
try:
|
||||
obj = func(packed_file)
|
||||
obj.extractall(path=target_folder)
|
||||
break
|
||||
except (zipfile.BadZipfile, tarfile.ReadError):
|
||||
pass
|
||||
else:
|
||||
raise ValueError(
|
||||
"cannot extract files from packaged model at %s", packed_file
|
||||
)
|
||||
|
||||
if return_path:
|
||||
return target_folder
|
||||
return model_path
|
||||
|
||||
target_files = list(Path(target_folder).glob("*"))
|
||||
target_files = list(Path(model_path).glob("*"))
|
||||
return target_files
|
||||
|
||||
def report_scalar(self, title, series, value, iteration):
|
||||
@ -1374,17 +1359,18 @@ class Model(BaseModel):
|
||||
self._base_model = None
|
||||
|
||||
def get_local_copy(
|
||||
self, extract_archive=True, raise_on_error=False, force_download=False
|
||||
self, extract_archive=None, raise_on_error=False, force_download=False
|
||||
):
|
||||
# type: (bool, bool, bool) -> str
|
||||
# type: (Optional[bool], bool, bool) -> str
|
||||
"""
|
||||
Retrieve a valid link to the model file(s).
|
||||
If the model URL is a file system link, it will be returned directly.
|
||||
If the model URL points to a remote location (http/s3/gs etc.),
|
||||
it will download the file(s) and return the temporary location of the downloaded model.
|
||||
|
||||
:param bool extract_archive: If True, and the model is of type 'packaged' (e.g. TensorFlow compressed folder)
|
||||
The returned path will be a temporary folder containing the archive content
|
||||
:param bool extract_archive: If True, the local copy will be extracted if possible. If False,
|
||||
the local copy will not be extracted. If None (default), the downloaded file will be extracted
|
||||
if the model is a package.
|
||||
:param bool raise_on_error: If True, and the artifact could not be downloaded,
|
||||
raise ValueError, otherwise return None on failure and output log warning.
|
||||
:param bool force_download: If True, the artifact will be downloaded,
|
||||
@ -1392,14 +1378,17 @@ class Model(BaseModel):
|
||||
|
||||
:return: A local path to the model (or a downloaded copy of it).
|
||||
"""
|
||||
if extract_archive and self._is_package():
|
||||
if self._is_package():
|
||||
return self.get_weights_package(
|
||||
return_path=True,
|
||||
raise_on_error=raise_on_error,
|
||||
force_download=force_download,
|
||||
extract_archive=True if extract_archive is None else extract_archive
|
||||
)
|
||||
return self.get_weights(
|
||||
raise_on_error=raise_on_error, force_download=force_download
|
||||
raise_on_error=raise_on_error,
|
||||
force_download=force_download,
|
||||
extract_archive=False if extract_archive is None else extract_archive
|
||||
)
|
||||
|
||||
def _get_base_model(self):
|
||||
|
@ -739,6 +739,8 @@ class Task(_Task):
|
||||
if argparser_parseargs_called():
|
||||
for parser, parsed_args in get_argparser_last_args():
|
||||
task._connect_argparse(parser=parser, parsed_args=parsed_args)
|
||||
|
||||
PatchHydra.delete_overrides()
|
||||
elif argparser_parseargs_called():
|
||||
# actually we have nothing to do, in remote running, the argparser will ignore
|
||||
# all non argparser parameters, only caveat if parameter connected with the same name
|
||||
@ -2604,15 +2606,15 @@ class Task(_Task):
|
||||
self.reload()
|
||||
script = self.data.script
|
||||
if repository is not None:
|
||||
script.repository = str(repository) or None
|
||||
script.repository = str(repository)
|
||||
if branch is not None:
|
||||
script.branch = str(branch) or None
|
||||
script.branch = str(branch)
|
||||
if script.tag:
|
||||
script.tag = None
|
||||
if commit is not None:
|
||||
script.version_num = str(commit) or None
|
||||
script.version_num = str(commit)
|
||||
if diff is not None:
|
||||
script.diff = str(diff) or None
|
||||
script.diff = str(diff)
|
||||
if working_dir is not None:
|
||||
script.working_dir = str(working_dir)
|
||||
if entry_point is not None:
|
||||
|
@ -1,6 +1,7 @@
|
||||
import itertools
|
||||
import json
|
||||
from copy import copy
|
||||
from logging import getLogger
|
||||
|
||||
import six
|
||||
import yaml
|
||||
@ -132,16 +133,21 @@ def cast_basic_type(value, type_str):
|
||||
parts = type_str.split('/')
|
||||
# nested = len(parts) > 1
|
||||
|
||||
if parts[0] in ('list', 'tuple'):
|
||||
v = '[' + value.lstrip('[(').rstrip('])') + ']'
|
||||
v = yaml.load(v, Loader=yaml.SafeLoader)
|
||||
return basic_types.get(parts[0])(v)
|
||||
elif parts[0] in ('dict', ):
|
||||
if parts[0] in ("list", "tuple", "dict"):
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
return json.loads(value)
|
||||
# lists/tuple/dicts should be json loadable
|
||||
return basic_types.get(parts[0])(json.loads(value))
|
||||
except Exception:
|
||||
pass
|
||||
return value
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
# fallback to legacy basic type loading
|
||||
v = '[' + value.lstrip('[(').rstrip('])') + ']'
|
||||
v = yaml.load(v, Loader=yaml.SafeLoader)
|
||||
return basic_types.get(parts[0])(v)
|
||||
except Exception:
|
||||
getLogger().warning("Could not cast `{}` to basic type. Returning it as `str`".format(value))
|
||||
return value
|
||||
|
||||
t = basic_types.get(str(type_str).lower().strip(), False)
|
||||
if t is not False:
|
||||
|
@ -69,8 +69,7 @@
|
||||
"Input the following parameters:\n",
|
||||
"* `name` - Name of the PipelineController task which will created\n",
|
||||
"* `project` - Project which the controller will be associated with\n",
|
||||
"* `version` - Pipeline's version number. This version allows to uniquely identify the pipeline template execution.\n",
|
||||
"* `auto_version_bump` (default True) - if the same pipeline version already exists (with any difference from the current one), the current pipeline version will be bumped to a new version (e.g. 1.0.0 -> 1.0.1 , 1.2 -> 1.3, 10 -> 11)\n",
|
||||
"* `version` - Pipeline's version number. This version allows to uniquely identify the pipeline template execution. If not set, find the pipeline's latest version and increment it. If no such version is found, defaults to `1.0.0`.\n",
|
||||
" "
|
||||
]
|
||||
},
|
||||
|
Loading…
Reference in New Issue
Block a user