Fix broken Task.init() auto_connect_frameworks wildcard filter support

This commit is contained in:
allegroai 2022-05-01 22:32:49 +03:00
parent 3d3a835435
commit dd4eca24c3
2 changed files with 12 additions and 8 deletions

View File

@ -101,6 +101,8 @@ class Framework(Options):
@classmethod @classmethod
def get_framework_parents(cls, framework): def get_framework_parents(cls, framework):
if not framework:
return []
parents = [] parents = []
for k, v in cls.__parent_mapping.items(): for k, v in cls.__parent_mapping.items():
if framework in v: if framework in v:
@ -542,7 +544,7 @@ class Model(BaseModel):
res = _Model._get_default_session().send( res = _Model._get_default_session().send(
models.GetAllRequest( models.GetAllRequest(
project=[project.id] if project else None, project=[project.id] if project else None,
name=model_name or None, name=exact_match_regex(model_name) if model_name is not None else None,
only_fields=only_fields, only_fields=only_fields,
tags=tags or None, tags=tags or None,
system_tags=["-" + cls._archived_tag] if not include_archived else None, system_tags=["-" + cls._archived_tag] if not include_archived else None,

View File

@ -211,7 +211,7 @@ class Task(_Task):
continue_last_task=False, # type: Union[bool, str, int] continue_last_task=False, # type: Union[bool, str, int]
output_uri=None, # type: Optional[Union[str, bool]] output_uri=None, # type: Optional[Union[str, bool]]
auto_connect_arg_parser=True, # type: Union[bool, Mapping[str, bool]] auto_connect_arg_parser=True, # type: Union[bool, Mapping[str, bool]]
auto_connect_frameworks=True, # type: Union[bool, Mapping[str, bool]] auto_connect_frameworks=True, # type: Union[bool, Mapping[str, Union[bool, str, list]]]
auto_resource_monitoring=True, # type: bool auto_resource_monitoring=True, # type: bool
auto_connect_streams=True, # type: Union[bool, Mapping[str, bool]] auto_connect_streams=True, # type: Union[bool, Mapping[str, bool]]
wait_for_task_init=True, # type: bool wait_for_task_init=True, # type: bool
@ -370,11 +370,12 @@ class Task(_Task):
- A dictionary - In addition to a boolean, you can use a dictionary for fined grained control of connected - A dictionary - In addition to a boolean, you can use a dictionary for fined grained control of connected
frameworks. The dictionary keys are frameworks and the values are booleans, other dictionaries used for frameworks. The dictionary keys are frameworks and the values are booleans, other dictionaries used for
finer control or wildcard strings. finer control or wildcard strings.
In case of wildcard strings, the local path of models have to match at least one wildcard to be In case of wildcard strings, the local path of a model file has to match at least one wildcard to be
saved/loaded by ClearML. saved/loaded by ClearML. Example:
{'pytorch' : '*.pt', 'tensorflow': '*'}
Keys missing from the dictionary default to ``True``, and an empty dictionary defaults to ``False``. Keys missing from the dictionary default to ``True``, and an empty dictionary defaults to ``False``.
Supported keys for finer control: Supported keys for finer control:
'tensorboard': {'report_hparams': bool} # whether or not to report TensorBoard hyperparameters {'tensorboard': {'report_hparams': bool}} # whether to report TensorBoard hyperparameters
For example: For example:
@ -3950,14 +3951,15 @@ class Task(_Task):
for k, v in auto_connect_frameworks.items(): for k, v in auto_connect_frameworks.items():
if isinstance(v, str): if isinstance(v, str):
v = [v] v = [v]
if isinstance(v, list): if isinstance(v, (list, tuple)):
WeightsFileHandler.model_wildcards[k] = v WeightsFileHandler.model_wildcards[k] = [str(i) for i in v]
def callback(_, model_info): def callback(_, model_info):
parents = Framework.get_framework_parents(model_info.framework) parents = Framework.get_framework_parents(model_info.framework)
wildcards = [] wildcards = []
for parent in parents: for parent in parents:
wildcards.extend(WeightsFileHandler.model_wildcards[parent]) if WeightsFileHandler.model_wildcards.get(parent):
wildcards.extend(WeightsFileHandler.model_wildcards[parent])
if not wildcards: if not wildcards:
return model_info return model_info
if not matches_any_wildcard(model_info.local_model_path, wildcards): if not matches_any_wildcard(model_info.local_model_path, wildcards):