Fix broken Task.init() auto_connect_frameworks wildcard filter support

This commit is contained in:
allegroai 2022-05-01 22:32:49 +03:00
parent 3d3a835435
commit dd4eca24c3
2 changed files with 12 additions and 8 deletions

View File

@ -101,6 +101,8 @@ class Framework(Options):
@classmethod
def get_framework_parents(cls, framework):
if not framework:
return []
parents = []
for k, v in cls.__parent_mapping.items():
if framework in v:
@ -542,7 +544,7 @@ class Model(BaseModel):
res = _Model._get_default_session().send(
models.GetAllRequest(
project=[project.id] if project else None,
name=model_name or None,
name=exact_match_regex(model_name) if model_name is not None else None,
only_fields=only_fields,
tags=tags or None,
system_tags=["-" + cls._archived_tag] if not include_archived else None,

View File

@ -211,7 +211,7 @@ class Task(_Task):
continue_last_task=False, # type: Union[bool, str, int]
output_uri=None, # type: Optional[Union[str, bool]]
auto_connect_arg_parser=True, # type: Union[bool, Mapping[str, bool]]
auto_connect_frameworks=True, # type: Union[bool, Mapping[str, bool]]
auto_connect_frameworks=True, # type: Union[bool, Mapping[str, Union[bool, str, list]]]
auto_resource_monitoring=True, # type: bool
auto_connect_streams=True, # type: Union[bool, Mapping[str, bool]]
wait_for_task_init=True, # type: bool
@ -370,11 +370,12 @@ class Task(_Task):
- A dictionary - In addition to a boolean, you can use a dictionary for fined grained control of connected
frameworks. The dictionary keys are frameworks and the values are booleans, other dictionaries used for
finer control or wildcard strings.
In case of wildcard strings, the local path of models have to match at least one wildcard to be
saved/loaded by ClearML.
In case of wildcard strings, the local path of a model file has to match at least one wildcard to be
saved/loaded by ClearML. Example:
{'pytorch' : '*.pt', 'tensorflow': '*'}
Keys missing from the dictionary default to ``True``, and an empty dictionary defaults to ``False``.
Supported keys for finer control:
'tensorboard': {'report_hparams': bool} # whether or not to report TensorBoard hyperparameters
{'tensorboard': {'report_hparams': bool}} # whether to report TensorBoard hyperparameters
For example:
@ -3950,14 +3951,15 @@ class Task(_Task):
for k, v in auto_connect_frameworks.items():
if isinstance(v, str):
v = [v]
if isinstance(v, list):
WeightsFileHandler.model_wildcards[k] = v
if isinstance(v, (list, tuple)):
WeightsFileHandler.model_wildcards[k] = [str(i) for i in v]
def callback(_, model_info):
parents = Framework.get_framework_parents(model_info.framework)
wildcards = []
for parent in parents:
wildcards.extend(WeightsFileHandler.model_wildcards[parent])
if WeightsFileHandler.model_wildcards.get(parent):
wildcards.extend(WeightsFileHandler.model_wildcards[parent])
if not wildcards:
return model_info
if not matches_any_wildcard(model_info.local_model_path, wildcards):