mirror of
https://github.com/clearml/clearml
synced 2025-06-25 17:46:28 +00:00
Add wild card support in model auto-logging (https://clearml.slack.com/archives/CTK20V944/p1644931337863039)
This commit is contained in:
parent
f4e4423b3a
commit
5a6ec697e1
@ -55,6 +55,7 @@ class WeightsFileHandler(object):
|
|||||||
_model_store_lookup_lock = threading.Lock()
|
_model_store_lookup_lock = threading.Lock()
|
||||||
_model_pre_callbacks = {}
|
_model_pre_callbacks = {}
|
||||||
_model_post_callbacks = {}
|
_model_post_callbacks = {}
|
||||||
|
model_wildcards = {}
|
||||||
|
|
||||||
class CallbackType(Enum):
|
class CallbackType(Enum):
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
|
@ -83,6 +83,28 @@ class Framework(Options):
|
|||||||
'.cbm': (catboost, ),
|
'.cbm': (catboost, ),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
__parent_mapping = {
|
||||||
|
"tensorflow": (
|
||||||
|
tensorflow,
|
||||||
|
tensorflowjs,
|
||||||
|
tensorflowlite,
|
||||||
|
keras,
|
||||||
|
),
|
||||||
|
"pytorch": (pytorch,),
|
||||||
|
"xgboost": (xgboost,),
|
||||||
|
"lightgbm": (lightgbm,),
|
||||||
|
"catboost": (catboost,),
|
||||||
|
"joblib": (scikitlearn, xgboost)
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_framework_parents(cls, framework):
|
||||||
|
parents = []
|
||||||
|
for k, v in cls.__parent_mapping.items():
|
||||||
|
if framework in v:
|
||||||
|
parents.append(k)
|
||||||
|
return parents
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _get_file_ext(cls, framework, filename):
|
def _get_file_ext(cls, framework, filename):
|
||||||
mapping = cls.__file_extensions_mapping
|
mapping = cls.__file_extensions_mapping
|
||||||
|
@ -52,6 +52,7 @@ from .binding.hydra_bind import PatchHydra
|
|||||||
from .binding.click_bind import PatchClick
|
from .binding.click_bind import PatchClick
|
||||||
from .binding.fire_bind import PatchFire
|
from .binding.fire_bind import PatchFire
|
||||||
from .binding.jsonargs_bind import PatchJsonArgParse
|
from .binding.jsonargs_bind import PatchJsonArgParse
|
||||||
|
from .binding.frameworks import WeightsFileHandler
|
||||||
from .config import (
|
from .config import (
|
||||||
config, DEV_TASK_NO_REUSE, get_is_master_node, DEBUG_SIMULATE_REMOTE_TASK, DEV_DEFAULT_OUTPUT_URI,
|
config, DEV_TASK_NO_REUSE, get_is_master_node, DEBUG_SIMULATE_REMOTE_TASK, DEV_DEFAULT_OUTPUT_URI,
|
||||||
deferred_config, TASK_SET_ITERATION_OFFSET, )
|
deferred_config, TASK_SET_ITERATION_OFFSET, )
|
||||||
@ -60,7 +61,7 @@ from .config.cache import SessionCache
|
|||||||
from .debugging.log import LoggerRoot
|
from .debugging.log import LoggerRoot
|
||||||
from .errors import UsageError
|
from .errors import UsageError
|
||||||
from .logger import Logger
|
from .logger import Logger
|
||||||
from .model import Model, InputModel, OutputModel
|
from .model import Model, InputModel, OutputModel, Framework
|
||||||
from .task_parameters import TaskParameters
|
from .task_parameters import TaskParameters
|
||||||
from .utilities.config import verify_basic_value
|
from .utilities.config import verify_basic_value
|
||||||
from .binding.args import (
|
from .binding.args import (
|
||||||
@ -74,6 +75,7 @@ from .utilities.resource_monitor import ResourceMonitor
|
|||||||
from .utilities.seed import make_deterministic
|
from .utilities.seed import make_deterministic
|
||||||
from .utilities.lowlevel.threads import get_current_thread_id
|
from .utilities.lowlevel.threads import get_current_thread_id
|
||||||
from .utilities.process.mp import BackgroundMonitor, leave_process
|
from .utilities.process.mp import BackgroundMonitor, leave_process
|
||||||
|
from .utilities.matching import matches_any_wildcard
|
||||||
# noinspection PyProtectedMember
|
# noinspection PyProtectedMember
|
||||||
from .backend_interface.task.args import _Arguments
|
from .backend_interface.task.args import _Arguments
|
||||||
|
|
||||||
@ -364,7 +366,9 @@ class Task(_Task):
|
|||||||
- ``True`` - Automatically connect (default)
|
- ``True`` - Automatically connect (default)
|
||||||
- ``False`` - Do not automatically connect
|
- ``False`` - Do not automatically connect
|
||||||
- A dictionary - In addition to a boolean, you can use a dictionary for fined grained control of connected
|
- A dictionary - In addition to a boolean, you can use a dictionary for fined grained control of connected
|
||||||
frameworks. The dictionary keys are frameworks and the values are booleans.
|
frameworks. The dictionary keys are frameworks and values are booleans or wildcard strings.
|
||||||
|
In case of wildcard strings, the local path of models have to match at least one wildcard to be
|
||||||
|
saved/loaded by ClearML.
|
||||||
Keys missing from the dictionary default to ``True``, and an empty dictionary defaults to ``False``.
|
Keys missing from the dictionary default to ``True``, and an empty dictionary defaults to ``False``.
|
||||||
|
|
||||||
For example:
|
For example:
|
||||||
@ -372,8 +376,8 @@ class Task(_Task):
|
|||||||
.. code-block:: py
|
.. code-block:: py
|
||||||
|
|
||||||
auto_connect_frameworks={
|
auto_connect_frameworks={
|
||||||
'matplotlib': True, 'tensorflow': True, 'tensorboard': True, 'pytorch': True,
|
'matplotlib': True, 'tensorflow': ['*.hdf5, 'something_else*], 'tensorboard': True,
|
||||||
'xgboost': True, 'scikit': True, 'fastai': True,
|
'pytorch': ['*.pt'], 'xgboost': True, 'scikit': True, 'fastai': True,
|
||||||
'lightgbm': True, 'hydra': True, 'detect_repository': True, 'tfdefines': True,
|
'lightgbm': True, 'hydra': True, 'detect_repository': True, 'tfdefines': True,
|
||||||
'joblib': True, 'megengine': True, 'catboost': True
|
'joblib': True, 'megengine': True, 'catboost': True
|
||||||
}
|
}
|
||||||
@ -599,6 +603,7 @@ class Task(_Task):
|
|||||||
task, report_mem_used_per_process=not config.get(
|
task, report_mem_used_per_process=not config.get(
|
||||||
'development.worker.report_global_mem_used', False))
|
'development.worker.report_global_mem_used', False))
|
||||||
task._resource_monitor.start()
|
task._resource_monitor.start()
|
||||||
|
cls.__add_model_wildcards(auto_connect_frameworks)
|
||||||
|
|
||||||
# make sure all random generators are initialized with new seed
|
# make sure all random generators are initialized with new seed
|
||||||
make_deterministic(task.get_random_seed())
|
make_deterministic(task.get_random_seed())
|
||||||
@ -3818,6 +3823,28 @@ class Task(_Task):
|
|||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def __add_model_wildcards(cls, auto_connect_frameworks):
|
||||||
|
if isinstance(auto_connect_frameworks, dict):
|
||||||
|
for k, v in auto_connect_frameworks.items():
|
||||||
|
if isinstance(v, str):
|
||||||
|
v = [v]
|
||||||
|
if isinstance(v, list):
|
||||||
|
WeightsFileHandler.model_wildcards[k] = v
|
||||||
|
|
||||||
|
def callback(_, model_info):
|
||||||
|
parents = Framework.get_framework_parents(model_info.framework)
|
||||||
|
wildcards = []
|
||||||
|
for parent in parents:
|
||||||
|
wildcards.extend(WeightsFileHandler.model_wildcards[parent])
|
||||||
|
if not wildcards:
|
||||||
|
return model_info
|
||||||
|
if not matches_any_wildcard(model_info.local_model_path, wildcards):
|
||||||
|
return None
|
||||||
|
return model_info
|
||||||
|
|
||||||
|
WeightsFileHandler.add_pre_callback(callback)
|
||||||
|
|
||||||
def __getstate__(self):
|
def __getstate__(self):
|
||||||
# type: () -> dict
|
# type: () -> dict
|
||||||
return {'main': self.is_main_task(), 'id': self.id, 'offline': self.is_offline()}
|
return {'main': self.is_main_task(), 'id': self.id, 'offline': self.is_offline()}
|
||||||
|
20
clearml/utilities/matching.py
Normal file
20
clearml/utilities/matching.py
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
import fnmatch
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
|
||||||
|
def matches_any_wildcard(pattern, wildcards):
|
||||||
|
# type: (str, Union[str, list]) -> bool
|
||||||
|
"""
|
||||||
|
Checks if given pattern matches any supplied wildcard
|
||||||
|
|
||||||
|
:param pattern: pattern to check
|
||||||
|
:param wildcards: wildcards to check against
|
||||||
|
|
||||||
|
:return: True if pattern matches any wildcard and False otherwise
|
||||||
|
"""
|
||||||
|
if isinstance(wildcards, str):
|
||||||
|
wildcards = [wildcards]
|
||||||
|
for wildcard in wildcards:
|
||||||
|
if fnmatch.fnmatch(pattern, wildcard):
|
||||||
|
return True
|
||||||
|
return False
|
Loading…
Reference in New Issue
Block a user