mirror of
https://github.com/clearml/clearml
synced 2025-02-07 21:33:25 +00:00
Add wild card support in model auto-logging (https://clearml.slack.com/archives/CTK20V944/p1644931337863039)
This commit is contained in:
parent
f4e4423b3a
commit
5a6ec697e1
@ -55,6 +55,7 @@ class WeightsFileHandler(object):
|
||||
_model_store_lookup_lock = threading.Lock()
|
||||
_model_pre_callbacks = {}
|
||||
_model_post_callbacks = {}
|
||||
model_wildcards = {}
|
||||
|
||||
class CallbackType(Enum):
|
||||
def __str__(self):
|
||||
|
@ -83,6 +83,28 @@ class Framework(Options):
|
||||
'.cbm': (catboost, ),
|
||||
}
|
||||
|
||||
# Groups framework identifiers under a parent name. ``get_framework_parents``
# scans this mapping and returns every key whose tuple contains the queried
# framework, so a framework may have more than one parent (``xgboost`` is
# listed under both "xgboost" and "joblib").
# NOTE(review): the keys appear to mirror the ``auto_connect_frameworks`` keys
# used by ``Task.init`` -- confirm before adding or renaming entries.
__parent_mapping = {
    "tensorflow": (
        tensorflow,
        tensorflowjs,
        tensorflowlite,
        keras,
    ),
    "pytorch": (pytorch,),
    "xgboost": (xgboost,),
    "lightgbm": (lightgbm,),
    "catboost": (catboost,),
    "joblib": (scikitlearn, xgboost)
}
|
||||
|
||||
@classmethod
def get_framework_parents(cls, framework):
    """
    Collect the parent names under which a framework is grouped.

    :param framework: framework identifier to look up in the parent mapping
    :return: list of mapping keys whose framework tuple contains ``framework``,
        in mapping order; an empty list when the framework has no parent
    """
    return [
        parent_name
        for parent_name, members in cls.__parent_mapping.items()
        if framework in members
    ]
|
||||
|
||||
@classmethod
|
||||
def _get_file_ext(cls, framework, filename):
|
||||
mapping = cls.__file_extensions_mapping
|
||||
|
@ -52,6 +52,7 @@ from .binding.hydra_bind import PatchHydra
|
||||
from .binding.click_bind import PatchClick
|
||||
from .binding.fire_bind import PatchFire
|
||||
from .binding.jsonargs_bind import PatchJsonArgParse
|
||||
from .binding.frameworks import WeightsFileHandler
|
||||
from .config import (
|
||||
config, DEV_TASK_NO_REUSE, get_is_master_node, DEBUG_SIMULATE_REMOTE_TASK, DEV_DEFAULT_OUTPUT_URI,
|
||||
deferred_config, TASK_SET_ITERATION_OFFSET, )
|
||||
@ -60,7 +61,7 @@ from .config.cache import SessionCache
|
||||
from .debugging.log import LoggerRoot
|
||||
from .errors import UsageError
|
||||
from .logger import Logger
|
||||
from .model import Model, InputModel, OutputModel
|
||||
from .model import Model, InputModel, OutputModel, Framework
|
||||
from .task_parameters import TaskParameters
|
||||
from .utilities.config import verify_basic_value
|
||||
from .binding.args import (
|
||||
@ -74,6 +75,7 @@ from .utilities.resource_monitor import ResourceMonitor
|
||||
from .utilities.seed import make_deterministic
|
||||
from .utilities.lowlevel.threads import get_current_thread_id
|
||||
from .utilities.process.mp import BackgroundMonitor, leave_process
|
||||
from .utilities.matching import matches_any_wildcard
|
||||
# noinspection PyProtectedMember
|
||||
from .backend_interface.task.args import _Arguments
|
||||
|
||||
@ -364,7 +366,9 @@ class Task(_Task):
|
||||
- ``True`` - Automatically connect (default)
|
||||
- ``False`` - Do not automatically connect
|
||||
- A dictionary - In addition to a boolean, you can use a dictionary for fine-grained control of connected
|
||||
frameworks. The dictionary keys are frameworks and the values are booleans.
|
||||
frameworks. The dictionary keys are frameworks and values are booleans or wildcard strings.
|
||||
In case of wildcard strings, the local path of a model has to match at least one wildcard to be
|
||||
saved/loaded by ClearML.
|
||||
Keys missing from the dictionary default to ``True``, and an empty dictionary defaults to ``False``.
|
||||
|
||||
For example:
|
||||
@ -372,8 +376,8 @@ class Task(_Task):
|
||||
.. code-block:: py
|
||||
|
||||
auto_connect_frameworks={
|
||||
'matplotlib': True, 'tensorflow': True, 'tensorboard': True, 'pytorch': True,
|
||||
'xgboost': True, 'scikit': True, 'fastai': True,
|
||||
'matplotlib': True, 'tensorflow': ['*.hdf5', 'something_else*'], 'tensorboard': True,
|
||||
'pytorch': ['*.pt'], 'xgboost': True, 'scikit': True, 'fastai': True,
|
||||
'lightgbm': True, 'hydra': True, 'detect_repository': True, 'tfdefines': True,
|
||||
'joblib': True, 'megengine': True, 'catboost': True
|
||||
}
|
||||
@ -599,6 +603,7 @@ class Task(_Task):
|
||||
task, report_mem_used_per_process=not config.get(
|
||||
'development.worker.report_global_mem_used', False))
|
||||
task._resource_monitor.start()
|
||||
cls.__add_model_wildcards(auto_connect_frameworks)
|
||||
|
||||
# make sure all random generators are initialized with new seed
|
||||
make_deterministic(task.get_random_seed())
|
||||
@ -3818,6 +3823,28 @@ class Task(_Task):
|
||||
return True
|
||||
return False
|
||||
|
||||
@classmethod
def __add_model_wildcards(cls, auto_connect_frameworks):
    """
    Register the model path wildcards supplied through ``auto_connect_frameworks``.

    For every dictionary entry whose value is a wildcard string or a list/tuple of
    wildcard strings, the wildcards are stored in ``WeightsFileHandler.model_wildcards``
    under the entry's key. A pre-callback is then added to ``WeightsFileHandler``
    that filters models by their local path.

    :param auto_connect_frameworks: the ``auto_connect_frameworks`` argument; anything
        other than a dictionary is ignored and no callback is registered
    """
    if isinstance(auto_connect_frameworks, dict):
        for k, v in auto_connect_frameworks.items():
            if isinstance(v, str):
                v = [v]
            if isinstance(v, (list, tuple)):
                WeightsFileHandler.model_wildcards[k] = list(v)

        def callback(_, model_info):
            # Gather the wildcards of every parent framework of this model.
            # Parents without registered wildcards (e.g. configured with a plain
            # ``True``, or absent from the dictionary) must be skipped: indexing
            # the dictionary directly raised KeyError for them.
            parents = Framework.get_framework_parents(model_info.framework)
            wildcards = []
            for parent in parents:
                parent_wildcards = WeightsFileHandler.model_wildcards.get(parent)
                if parent_wildcards:
                    wildcards.extend(parent_wildcards)
            # no wildcard restriction applies to this framework: keep the model
            if not wildcards:
                return model_info
            # returning None tells the handler the model did not pass the filter
            if not matches_any_wildcard(model_info.local_model_path, wildcards):
                return None
            return model_info

        WeightsFileHandler.add_pre_callback(callback)
|
||||
|
||||
def __getstate__(self):
|
||||
# type: () -> dict
|
||||
return {'main': self.is_main_task(), 'id': self.id, 'offline': self.is_offline()}
|
||||
|
20
clearml/utilities/matching.py
Normal file
20
clearml/utilities/matching.py
Normal file
@ -0,0 +1,20 @@
|
||||
import fnmatch
|
||||
from typing import Union
|
||||
|
||||
|
||||
def matches_any_wildcard(pattern, wildcards):
    # type: (str, Union[str, list]) -> bool
    """
    Tell whether a string matches at least one shell-style wildcard.

    :param pattern: the string to test (e.g. a model's local path)
    :param wildcards: a single wildcard string, or a collection of wildcard
        strings, to test against using ``fnmatch`` rules

    :return: True if ``pattern`` matches any of the wildcards, False otherwise
        (including when no wildcards are given)
    """
    # a lone string is treated as a one-element collection
    candidates = [wildcards] if isinstance(wildcards, str) else wildcards
    return any(fnmatch.fnmatch(pattern, candidate) for candidate in candidates)
|
Loading…
Reference in New Issue
Block a user