diff --git a/clearml/binding/frameworks/__init__.py b/clearml/binding/frameworks/__init__.py index 90c1a140..8b59716d 100644 --- a/clearml/binding/frameworks/__init__.py +++ b/clearml/binding/frameworks/__init__.py @@ -55,6 +55,7 @@ class WeightsFileHandler(object): _model_store_lookup_lock = threading.Lock() _model_pre_callbacks = {} _model_post_callbacks = {} + model_wildcards = {} class CallbackType(Enum): def __str__(self): diff --git a/clearml/model.py b/clearml/model.py index fec0dcfe..cd601a43 100644 --- a/clearml/model.py +++ b/clearml/model.py @@ -83,6 +83,28 @@ class Framework(Options): '.cbm': (catboost, ), } + __parent_mapping = { + "tensorflow": ( + tensorflow, + tensorflowjs, + tensorflowlite, + keras, + ), + "pytorch": (pytorch,), + "xgboost": (xgboost,), + "lightgbm": (lightgbm,), + "catboost": (catboost,), + "joblib": (scikitlearn, xgboost) + } + + @classmethod + def get_framework_parents(cls, framework): + parents = [] + for k, v in cls.__parent_mapping.items(): + if framework in v: + parents.append(k) + return parents + @classmethod def _get_file_ext(cls, framework, filename): mapping = cls.__file_extensions_mapping diff --git a/clearml/task.py b/clearml/task.py index 4eac49a5..17e4f461 100644 --- a/clearml/task.py +++ b/clearml/task.py @@ -52,6 +52,7 @@ from .binding.hydra_bind import PatchHydra from .binding.click_bind import PatchClick from .binding.fire_bind import PatchFire from .binding.jsonargs_bind import PatchJsonArgParse +from .binding.frameworks import WeightsFileHandler from .config import ( config, DEV_TASK_NO_REUSE, get_is_master_node, DEBUG_SIMULATE_REMOTE_TASK, DEV_DEFAULT_OUTPUT_URI, deferred_config, TASK_SET_ITERATION_OFFSET, ) @@ -60,7 +61,7 @@ from .config.cache import SessionCache from .debugging.log import LoggerRoot from .errors import UsageError from .logger import Logger -from .model import Model, InputModel, OutputModel +from .model import Model, InputModel, OutputModel, Framework from .task_parameters import TaskParameters from .utilities.config import verify_basic_value from .binding.args import ( @@ -74,6 +75,7 @@ from .utilities.resource_monitor import ResourceMonitor from .utilities.seed import make_deterministic from .utilities.lowlevel.threads import get_current_thread_id from .utilities.process.mp import BackgroundMonitor, leave_process +from .utilities.matching import matches_any_wildcard # noinspection PyProtectedMember from .backend_interface.task.args import _Arguments @@ -364,7 +366,9 @@ class Task(_Task): - ``True`` - Automatically connect (default) - ``False`` - Do not automatically connect - A dictionary - In addition to a boolean, you can use a dictionary for fined grained control of connected - frameworks. The dictionary keys are frameworks and the values are booleans. + frameworks. The dictionary keys are frameworks and values are booleans or wildcard strings. + In case of wildcard strings, the local path of models have to match at least one wildcard to be + saved/loaded by ClearML. Keys missing from the dictionary default to ``True``, and an empty dictionary defaults to ``False``. For example: @@ -372,8 +376,8 @@ class Task(_Task): .. code-block:: py auto_connect_frameworks={ - 'matplotlib': True, 'tensorflow': True, 'tensorboard': True, 'pytorch': True, - 'xgboost': True, 'scikit': True, 'fastai': True, + 'matplotlib': True, 'tensorflow': ['*.hdf5, 'something_else*], 'tensorboard': True, + 'pytorch': ['*.pt'], 'xgboost': True, 'scikit': True, 'fastai': True, 'lightgbm': True, 'hydra': True, 'detect_repository': True, 'tfdefines': True, 'joblib': True, 'megengine': True, 'catboost': True } @@ -599,6 +603,7 @@ class Task(_Task): task, report_mem_used_per_process=not config.get( 'development.worker.report_global_mem_used', False)) task._resource_monitor.start() + cls.__add_model_wildcards(auto_connect_frameworks) # make sure all random generators are initialized with new seed make_deterministic(task.get_random_seed()) @@ -3818,6 +3823,28 @@ class Task(_Task): return True return False + @classmethod + def __add_model_wildcards(cls, auto_connect_frameworks): + if isinstance(auto_connect_frameworks, dict): + for k, v in auto_connect_frameworks.items(): + if isinstance(v, str): + v = [v] + if isinstance(v, list): + WeightsFileHandler.model_wildcards[k] = v + + def callback(_, model_info): + parents = Framework.get_framework_parents(model_info.framework) + wildcards = [] + for parent in parents: + wildcards.extend(WeightsFileHandler.model_wildcards[parent]) + if not wildcards: + return model_info + if not matches_any_wildcard(model_info.local_model_path, wildcards): + return None + return model_info + + WeightsFileHandler.add_pre_callback(callback) + def __getstate__(self): # type: () -> dict return {'main': self.is_main_task(), 'id': self.id, 'offline': self.is_offline()} diff --git a/clearml/utilities/matching.py b/clearml/utilities/matching.py new file mode 100644 index 00000000..47824986 --- /dev/null +++ b/clearml/utilities/matching.py @@ -0,0 +1,20 @@ +import fnmatch +from typing import Union + + +def matches_any_wildcard(pattern, wildcards): + # type: (str, Union[str, list]) -> bool + """ + Checks if given pattern matches any supplied wildcard + + :param pattern: pattern to check + :param wildcards: wildcards to check against + + :return: True if pattern matches any wildcard and False otherwise + """ + if isinstance(wildcards, str): + wildcards = [wildcards] + for wildcard in wildcards: + if fnmatch.fnmatch(pattern, wildcard): + return True + return False