This commit is contained in:
allegroai 2022-04-05 13:47:07 +03:00
parent f4e4423b3a
commit 5a6ec697e1
4 changed files with 74 additions and 4 deletions

View File

@ -55,6 +55,7 @@ class WeightsFileHandler(object):
_model_store_lookup_lock = threading.Lock()
_model_pre_callbacks = {}
_model_post_callbacks = {}
model_wildcards = {}
class CallbackType(Enum):
def __str__(self):

View File

@ -83,6 +83,28 @@ class Framework(Options):
'.cbm': (catboost, ),
}
# Groups framework values under a "parent" framework name (the dict key).
# get_framework_parents() scans this mapping and returns every key whose tuple
# contains the requested framework, so one framework may have several parents
# (e.g. xgboost is listed under both "xgboost" and "joblib").
__parent_mapping = {
"tensorflow": (
tensorflow,
tensorflowjs,
tensorflowlite,
keras,
),
"pytorch": (pytorch,),
"xgboost": (xgboost,),
"lightgbm": (lightgbm,),
"catboost": (catboost,),
"joblib": (scikitlearn, xgboost)
}
@classmethod
def get_framework_parents(cls, framework):
    """
    Look up the parent framework names that a given framework belongs to.

    :param framework: framework value to search for in the parent mapping
    :return: list of parent names (mapping keys) whose members include ``framework``,
        in mapping iteration order; an empty list when no group contains it
    """
    return [
        parent_name
        for parent_name, members in cls.__parent_mapping.items()
        if framework in members
    ]
@classmethod
def _get_file_ext(cls, framework, filename):
mapping = cls.__file_extensions_mapping

View File

@ -52,6 +52,7 @@ from .binding.hydra_bind import PatchHydra
from .binding.click_bind import PatchClick
from .binding.fire_bind import PatchFire
from .binding.jsonargs_bind import PatchJsonArgParse
from .binding.frameworks import WeightsFileHandler
from .config import (
config, DEV_TASK_NO_REUSE, get_is_master_node, DEBUG_SIMULATE_REMOTE_TASK, DEV_DEFAULT_OUTPUT_URI,
deferred_config, TASK_SET_ITERATION_OFFSET, )
@ -60,7 +61,7 @@ from .config.cache import SessionCache
from .debugging.log import LoggerRoot
from .errors import UsageError
from .logger import Logger
from .model import Model, InputModel, OutputModel
from .model import Model, InputModel, OutputModel, Framework
from .task_parameters import TaskParameters
from .utilities.config import verify_basic_value
from .binding.args import (
@ -74,6 +75,7 @@ from .utilities.resource_monitor import ResourceMonitor
from .utilities.seed import make_deterministic
from .utilities.lowlevel.threads import get_current_thread_id
from .utilities.process.mp import BackgroundMonitor, leave_process
from .utilities.matching import matches_any_wildcard
# noinspection PyProtectedMember
from .backend_interface.task.args import _Arguments
@ -364,7 +366,9 @@ class Task(_Task):
- ``True`` - Automatically connect (default)
- ``False`` - Do not automatically connect
- A dictionary - In addition to a boolean, you can use a dictionary for fine-grained control of connected
frameworks. The dictionary keys are frameworks and the values are booleans.
frameworks. The dictionary keys are frameworks and values are booleans or wildcard strings.
In case of wildcard strings, the local path of a model has to match at least one wildcard for it to be
saved/loaded by ClearML.
Keys missing from the dictionary default to ``True``, and an empty dictionary defaults to ``False``.
For example:
@ -372,8 +376,8 @@ class Task(_Task):
.. code-block:: py
auto_connect_frameworks={
'matplotlib': True, 'tensorflow': True, 'tensorboard': True, 'pytorch': True,
'xgboost': True, 'scikit': True, 'fastai': True,
'matplotlib': True, 'tensorflow': ['*.hdf5', 'something_else*'], 'tensorboard': True,
'pytorch': ['*.pt'], 'xgboost': True, 'scikit': True, 'fastai': True,
'lightgbm': True, 'hydra': True, 'detect_repository': True, 'tfdefines': True,
'joblib': True, 'megengine': True, 'catboost': True
}
@ -599,6 +603,7 @@ class Task(_Task):
task, report_mem_used_per_process=not config.get(
'development.worker.report_global_mem_used', False))
task._resource_monitor.start()
cls.__add_model_wildcards(auto_connect_frameworks)
# make sure all random generators are initialized with new seed
make_deterministic(task.get_random_seed())
@ -3818,6 +3823,28 @@ class Task(_Task):
return True
return False
@classmethod
def __add_model_wildcards(cls, auto_connect_frameworks):
    """
    Register per-framework model path wildcards and install the filtering callback.

    For every entry of ``auto_connect_frameworks`` whose value is a wildcard string or a
    list/tuple of wildcard strings, the wildcards are stored in
    ``WeightsFileHandler.model_wildcards`` under the framework key. A pre-callback is then
    added to ``WeightsFileHandler``; it returns ``None`` for a model whose local path matches
    none of the wildcards registered for its parent framework(s), and returns the model info
    unchanged otherwise.

    :param auto_connect_frameworks: the value passed to ``Task.init``; anything that is not a
        dictionary is ignored (no wildcards stored, no callback registered)
    """
    if not isinstance(auto_connect_frameworks, dict):
        return

    for framework_name, value in auto_connect_frameworks.items():
        if isinstance(value, str):
            value = [value]
        # booleans (and any other non-sequence value) carry no wildcards and are skipped
        if isinstance(value, (list, tuple)):
            WeightsFileHandler.model_wildcards[framework_name] = list(value)

    def callback(_, model_info):
        # NOTE(review): presumably an earlier pre-callback can already have dropped the model
        # and handed us None - confirm against WeightsFileHandler; nothing to filter then.
        if model_info is None:
            return None
        parents = Framework.get_framework_parents(model_info.framework)
        wildcards = []
        for parent in parents:
            # A parent framework without registered wildcards must not raise KeyError:
            # the user may have given wildcards for only some of the frameworks.
            wildcards.extend(WeightsFileHandler.model_wildcards.get(parent) or [])
        if not wildcards:
            # no wildcard restriction applies to this framework
            return model_info
        if not matches_any_wildcard(model_info.local_model_path, wildcards):
            return None
        return model_info

    WeightsFileHandler.add_pre_callback(callback)
def __getstate__(self):
    # type: () -> dict
    """
    Build the reduced state dictionary used when this object is pickled.

    :return: dict with the keys ``'main'`` (result of ``is_main_task()``), ``'id'`` (the
        object's ``id``) and ``'offline'`` (result of ``is_offline()``)
    """
    is_main = self.is_main_task()
    task_id = self.id
    offline = self.is_offline()
    return dict(main=is_main, id=task_id, offline=offline)

View File

@ -0,0 +1,20 @@
import fnmatch
from typing import Union
def matches_any_wildcard(pattern, wildcards):
    # type: (str, Union[str, list]) -> bool
    """
    Tell whether a string matches at least one of the given shell-style wildcards.

    :param pattern: the string to test (e.g. a file path)
    :param wildcards: a single wildcard string or a list of wildcard strings
    :return: True if ``pattern`` matches any of the wildcards, False otherwise
        (including when the list of wildcards is empty)
    """
    # a lone string is treated as a one-element list rather than iterated per character
    candidates = [wildcards] if isinstance(wildcards, str) else wildcards
    return any(fnmatch.fnmatch(pattern, candidate) for candidate in candidates)