diff --git a/clearml/binding/absl_bind.py b/clearml/binding/absl_bind.py
index 8228699c..80aad923 100644
--- a/clearml/binding/absl_bind.py
+++ b/clearml/binding/absl_bind.py
@@ -8,12 +8,15 @@ from ..config import running_remotely
 
 class PatchAbsl(object):
     _original_DEFINE_flag = None
     _original_FLAGS_parse_call = None
-    _task = None
+    _current_task = None
+    __patched = False
 
     @classmethod
-    def update_current_task(cls, current_task):
-        cls._task = current_task
-        cls._patch_absl()
+    def update_current_task(cls, task):
+        cls._current_task = task
+        if not cls.__patched:
+            cls._patch_absl()
+            cls.__patched = True
 
     @classmethod
     def _patch_absl(cls):
@@ -53,7 +56,7 @@ class PatchAbsl(object):
 
     @staticmethod
     def _patched_define_flag(*args, **kwargs):
-        if not PatchAbsl._task or not PatchAbsl._original_DEFINE_flag:
+        if not PatchAbsl._current_task or not PatchAbsl._original_DEFINE_flag:
             if PatchAbsl._original_DEFINE_flag:
                 return PatchAbsl._original_DEFINE_flag(*args, **kwargs)
             else:
@@ -73,7 +76,7 @@ class PatchAbsl(object):
             # noinspection PyBroadException
             try:
                 if param_name and flag:
-                    param_dict = PatchAbsl._task._arguments.copy_to_dict(
+                    param_dict = PatchAbsl._current_task._arguments.copy_to_dict(
                         {param_name: flag.value}, prefix=_Arguments._prefix_tf_defines)
                     flag.value = param_dict.get(param_name, flag.value)
             except Exception:
@@ -82,13 +85,17 @@ class PatchAbsl(object):
         else:
             if flag and param_name:
                 value = flag.value
-                PatchAbsl._task.update_parameters({_Arguments._prefix_tf_defines + param_name: value}, )
+                PatchAbsl._current_task.update_parameters(
+                    {_Arguments._prefix_tf_defines + param_name: value},
+                )
         ret = PatchAbsl._original_DEFINE_flag(*args, **kwargs)
         return ret
 
     @staticmethod
     def _patched_FLAGS_parse_call(self, *args, **kwargs):
         ret = PatchAbsl._original_FLAGS_parse_call(self, *args, **kwargs)
+        if not PatchAbsl._current_task:
+            return ret
         # noinspection PyBroadException
         try:
             PatchAbsl._update_current_flags(self)
@@ -98,13 +105,13 @@ class PatchAbsl(object):
 
     @classmethod
     def _update_current_flags(cls, FLAGS):
-        if not cls._task:
+        if not cls._current_task:
             return
         # noinspection PyBroadException
         try:
             if running_remotely():
                 param_dict = dict((k, FLAGS[k].value) for k in FLAGS)
-                param_dict = cls._task._arguments.copy_to_dict(param_dict, prefix=_Arguments._prefix_tf_defines)
+                param_dict = cls._current_task._arguments.copy_to_dict(param_dict, prefix=_Arguments._prefix_tf_defines)
                 for k, v in param_dict.items():
                     # noinspection PyBroadException
                     try:
@@ -127,7 +134,7 @@ class PatchAbsl(object):
                 param_types = dict([(k, FLAGS[k].flag_type() or None) for k in FLAGS])
             except Exception:
                 param_types = None
-            cls._task._arguments.copy_from_dict(
+            cls._current_task._arguments.copy_from_dict(
                 parameters,
                 prefix=_Arguments._prefix_tf_defines,
                 descriptions=descriptions,
                 param_types=param_types,
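The absl change above is the template repeated across every binding in this patch: update_current_task() always records the task (including None), while the irreversible monkey-patching is guarded so it runs at most once. A minimal standalone sketch of that pattern, with illustrative names that are not clearml internals:

    import functools
    import types

    some_library = types.SimpleNamespace(do_work=lambda x: x * 2)

    class PatchSomething:
        _current_task = None   # rebindable; None means "pass through to the original"
        __patched = False      # the monkey-patch itself is applied at most once

        @classmethod
        def update_current_task(cls, task):
            cls._current_task = task
            if task and not cls.__patched:
                cls.__patched = True
                cls._patch()

        @classmethod
        def _patch(cls):
            original = some_library.do_work

            @functools.wraps(original)
            def wrapper(*args, **kwargs):
                if not cls._current_task:              # unbound: behave as if never patched
                    return original(*args, **kwargs)
                result = original(*args, **kwargs)
                print("logged to", cls._current_task)  # stand-in for task reporting
                return result

            some_library.do_work = wrapper

    PatchSomething.update_current_task("task-1")
    some_library.do_work(3)                 # intercepted and logged
    PatchSomething.update_current_task(None)
    some_library.do_work(3)                 # falls through to the original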
diff --git a/clearml/binding/click_bind.py b/clearml/binding/click_bind.py
index c6ffbb08..0fb55af4 100644
--- a/clearml/binding/click_bind.py
+++ b/clearml/binding/click_bind.py
@@ -16,7 +16,7 @@ class PatchClick:
     _num_commands = 0
     _command_type = 'click.Command'
     _section_name = 'Args'
-    _main_task = None
+    _current_task = None
     __remote_task_params = None
     __remote_task_params_dict = {}
     __patched = False
@@ -26,8 +26,8 @@ class PatchClick:
         if Command is None:
             return
 
+        cls._current_task = task
         if task:
-            cls._main_task = task
             PatchClick._update_task_args()
 
         if not cls.__patched:
@@ -53,11 +53,11 @@ class PatchClick:
 
     @classmethod
     def _update_task_args(cls):
-        if running_remotely() or not cls._main_task or not cls._args:
+        if running_remotely() or not cls._current_task or not cls._args:
             return
         param_val, param_types, param_desc = cls.args()
         # noinspection PyProtectedMember
-        cls._main_task._set_parameters(
+        cls._current_task._set_parameters(
             param_val,
             __update=True,
             __parameters_descriptions=param_desc,
diff --git a/clearml/binding/environ_bind.py b/clearml/binding/environ_bind.py
index 8d6ce306..f91b86cb 100644
--- a/clearml/binding/environ_bind.py
+++ b/clearml/binding/environ_bind.py
@@ -1,4 +1,5 @@
 import os
+from time import sleep
 
 import six
 
@@ -7,26 +8,29 @@ from ..utilities.process.mp import BackgroundMonitor
 
 
 class EnvironmentBind(object):
-    _task = None
+    _current_task = None
     _environment_section = 'Environment'
+    __patched = False
 
     @classmethod
-    def update_current_task(cls, current_task):
-        cls._task = current_task
+    def update_current_task(cls, task):
+        cls._current_task = task
         # noinspection PyBroadException
         try:
-            cls._bind_environment()
+            if not cls.__patched:
+                cls.__patched = True
+                cls._bind_environment()
        except Exception:
            pass

    @classmethod
    def _bind_environment(cls):
-        if not cls._task:
+        if not cls._current_task:
            return
        # get ENVIRONMENT and put it into the OS environment
        if running_remotely():
-            params = cls._task.get_parameters_as_dict()
+            params = cls._current_task.get_parameters_as_dict()
            if params and cls._environment_section in params:
                # put back into os:
                os.environ.update(params[cls._environment_section])
@@ -55,14 +59,18 @@ class EnvironmentBind(object):
            elif match in os.environ:
                env_param.update({match: os.environ.get(match)})
        # store os environments
-        cls._task.connect(env_param, cls._environment_section)
+        cls._current_task.connect(env_param, cls._environment_section)
 
 
 class PatchOsFork(object):
     _original_fork = None
+    _current_task = None
 
     @classmethod
-    def patch_fork(cls):
+    def patch_fork(cls, task):
+        cls._current_task = task
+        if not task:
+            return
         # noinspection PyBroadException
         try:
             # only once
@@ -88,11 +96,13 @@ class PatchOsFork(object):
                 Task._wait_for_deferred(task)
 
         ret = PatchOsFork._original_fork(*args, **kwargs)
+        if not PatchOsFork._current_task:
+            return ret
         # Make sure the new process stdout is logged
         if not ret:
             # force creating a Task
             task = Task.current_task()
-            if task is None:
+            if not task:
                 return ret
 
             # # Hack: now make sure we setup the reporter threads (Log+Reporter)
@@ -106,7 +116,7 @@ class PatchOsFork(object):
             # just make sure we flush the internal state (the at-exit handler caught by the external signal does the rest
             # in theory we should not have to do any of that, but for some reason if we do not
             # the signal is never caught by the signal callbacks, not sure why....
-
+            sleep(0.1)
             # Since atexit handlers do not work on forked processes, we have to manually call them here
             if task:
                 try:
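For context on the manual shutdown work in the patched fork: atexit handlers registered in the parent are skipped entirely when a forked child terminates via os._exit(), which is one common reason the child has to flush task state itself. A POSIX-only demonstration of that behavior (not clearml code):

    import atexit
    import os
    import sys
    import time

    atexit.register(lambda: print("atexit ran in pid", os.getpid()))

    pid = os.fork()
    if pid == 0:          # child process
        time.sleep(0.1)   # brief grace period, mirroring the patch above
        os._exit(0)       # hard exit: atexit handlers never run
    else:                 # parent process
        os.waitpid(pid, 0)
        sys.exit(0)       # normal interpreter exit: atexit runs here only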
diff --git a/clearml/binding/fire_bind.py b/clearml/binding/fire_bind.py
index 0f22f221..fb59f5a7 100644
--- a/clearml/binding/fire_bind.py
+++ b/clearml/binding/fire_bind.py
@@ -20,7 +20,7 @@ class PatchFire:
     _section_name = "Args"
     _args_sep = "/"
     _commands_sep = "."
-    _main_task = None
+    _current_task = None
     __remote_task_params = None
     __remote_task_params_dict = {}
     __patched = False
@@ -38,8 +38,8 @@ class PatchFire:
         if fire is None:
             return
 
+        cls._current_task = task
         if task:
-            cls._main_task = task
             cls._update_task_args()
 
         if not cls.__patched:
@@ -53,7 +53,7 @@ class PatchFire:
 
     @classmethod
     def _update_task_args(cls):
-        if running_remotely() or not cls._main_task:
+        if running_remotely() or not cls._current_task:
             return
         args = {}
         parameters_types = {}
@@ -118,7 +118,7 @@ class PatchFire:
             parameters_types = {**parameters_types, **unused_paramenters_types}
 
         # noinspection PyProtectedMember
-        cls._main_task._set_parameters(
+        cls._current_task._set_parameters(
             args,
             __update=True,
             __parameters_types=parameters_types,
@@ -126,25 +126,25 @@ class PatchFire:
 
     @staticmethod
     def __Fire(original_fn, component, args_, parsed_flag_args, context, name, *args, **kwargs):  # noqa
-        if running_remotely():
-            command = PatchFire._load_task_params()
-            if command is not None:
-                replaced_args = command.split(PatchFire._commands_sep)
-            else:
-                replaced_args = []
-            for param in PatchFire.__remote_task_params[PatchFire._section_name].values():
-                if command is not None and param.type == PatchFire._command_arg_type_template % command:
-                    replaced_args.append("--" + param.name[len(command + PatchFire._args_sep):])
-                    value = PatchFire.__remote_task_params_dict[param.name]
-                    if len(value) > 0:
-                        replaced_args.append(value)
-                if param.type == PatchFire._shared_arg_type:
-                    replaced_args.append("--" + param.name)
-                    value = PatchFire.__remote_task_params_dict[param.name]
-                    if len(value) > 0:
-                        replaced_args.append(value)
-            return original_fn(component, replaced_args, parsed_flag_args, context, name, *args, **kwargs)
-        return original_fn(component, args_, parsed_flag_args, context, name, *args, **kwargs)
+        if not running_remotely():
+            return original_fn(component, args_, parsed_flag_args, context, name, *args, **kwargs)
+        command = PatchFire._load_task_params()
+        if command is not None:
+            replaced_args = command.split(PatchFire._commands_sep)
+        else:
+            replaced_args = []
+        for param in PatchFire.__remote_task_params[PatchFire._section_name].values():
+            if command is not None and param.type == PatchFire._command_arg_type_template % command:
+                replaced_args.append("--" + param.name[len(command + PatchFire._args_sep):])
+                value = PatchFire.__remote_task_params_dict[param.name]
+                if len(value) > 0:
+                    replaced_args.append(value)
+            if param.type == PatchFire._shared_arg_type:
+                replaced_args.append("--" + param.name)
+                value = PatchFire.__remote_task_params_dict[param.name]
+                if len(value) > 0:
+                    replaced_args.append(value)
+        return original_fn(component, replaced_args, parsed_flag_args, context, name, *args, **kwargs)
 
     @staticmethod
     def __CallAndUpdateTrace(  # noqa
diff --git a/clearml/binding/frameworks/catboost_bind.py b/clearml/binding/frameworks/catboost_bind.py
index d83e5ac9..52afcbd4 100644
--- a/clearml/binding/frameworks/catboost_bind.py
+++ b/clearml/binding/frameworks/catboost_bind.py
@@ -11,13 +11,15 @@ from ...model import Framework
 
 
 class PatchCatBoostModelIO(PatchBaseModelIO):
-    __main_task = None
+    _current_task = None
    __patched = None
    __callback_cls = None

    @staticmethod
    def update_current_task(task, **kwargs):
-        PatchCatBoostModelIO.__main_task = task
+        PatchCatBoostModelIO._current_task = task
+        if not task:
+            return
        PatchCatBoostModelIO._patch_model_io()
        PostImportHookPatching.add_on_import("catboost", PatchCatBoostModelIO._patch_model_io)

@@ -38,16 +40,16 @@ class PatchCatBoostModelIO(PatchBaseModelIO):
            CatBoost.fit = _patched_call(CatBoost.fit, PatchCatBoostModelIO._fit)
            CatBoostClassifier.fit = _patched_call(CatBoostClassifier.fit, PatchCatBoostModelIO._fit)
            CatBoostRegressor.fit = _patched_call(CatBoostRegressor.fit, PatchCatBoostModelIO._fit)
-            CatBoostRanker.fit = _patched_call(CatBoostRegressor.fit, PatchCatBoostModelIO._fit)
+            CatBoostRanker.fit = _patched_call(CatBoostRanker.fit, PatchCatBoostModelIO._fit)
        except Exception as e:
-            logger = PatchCatBoostModelIO.__main_task.get_logger()
+            logger = PatchCatBoostModelIO._current_task.get_logger()
            logger.report_text("Failed patching Catboost. Exception is: '" + str(e) + "'")

    @staticmethod
    def _save(original_fn, obj, f, *args, **kwargs):
        # see https://catboost.ai/en/docs/concepts/python-reference_catboost_save_model
        ret = original_fn(obj, f, *args, **kwargs)
-        if not PatchCatBoostModelIO.__main_task:
+        if not PatchCatBoostModelIO._current_task:
            return ret
        if isinstance(f, six.string_types):
            filename = f
@@ -60,13 +62,16 @@ class PatchCatBoostModelIO(PatchBaseModelIO):
        except Exception:
            model_name = None
        WeightsFileHandler.create_output_model(
-            obj, filename, Framework.catboost, PatchCatBoostModelIO.__main_task, singlefile=True, model_name=model_name
+            obj, filename, Framework.catboost, PatchCatBoostModelIO._current_task, singlefile=True, model_name=model_name
        )
        return ret

    @staticmethod
    def _load(original_fn, f, *args, **kwargs):
        # see https://catboost.ai/en/docs/concepts/python-reference_catboost_load_model
+        if not PatchCatBoostModelIO._current_task:
+            return original_fn(f, *args, **kwargs)
+
        if isinstance(f, six.string_types):
            filename = f
        elif len(args) >= 1 and isinstance(args[0], six.string_types):
@@ -74,13 +79,10 @@ class PatchCatBoostModelIO(PatchBaseModelIO):
        else:
            filename = None

-        if not PatchCatBoostModelIO.__main_task:
-            return original_fn(f, *args, **kwargs)
-
        # register input model
        empty = _Empty()
        model = original_fn(f, *args, **kwargs)
-        WeightsFileHandler.restore_weights_file(empty, filename, Framework.catboost, PatchCatBoostModelIO.__main_task)
+        WeightsFileHandler.restore_weights_file(empty, filename, Framework.catboost, PatchCatBoostModelIO._current_task)
        if empty.trains_in_model:
            # noinspection PyBroadException
            try:
@@ -91,13 +93,15 @@ class PatchCatBoostModelIO(PatchBaseModelIO):

    @staticmethod
    def _fit(original_fn, obj, *args, **kwargs):
+        if not PatchCatBoostModelIO._current_task:
+            return original_fn(obj, *args, **kwargs)
        callbacks = kwargs.get("callbacks") or []
-        kwargs["callbacks"] = callbacks + [PatchCatBoostModelIO.__callback_cls(task=PatchCatBoostModelIO.__main_task)]
+        kwargs["callbacks"] = callbacks + [PatchCatBoostModelIO.__callback_cls(task=PatchCatBoostModelIO._current_task)]
        # noinspection PyBroadException
        try:
            return original_fn(obj, *args, **kwargs)
        except Exception:
-            logger = PatchCatBoostModelIO.__main_task.get_logger()
+            logger = PatchCatBoostModelIO._current_task.get_logger()
            logger.report_text(
                "Catboost metrics logging is not supported for GPU. "
" "See https://github.com/catboost/catboost/issues/1792" diff --git a/clearml/binding/frameworks/fastai_bind.py b/clearml/binding/frameworks/fastai_bind.py index ce1fb703..82bb2bd0 100644 --- a/clearml/binding/frameworks/fastai_bind.py +++ b/clearml/binding/frameworks/fastai_bind.py @@ -25,11 +25,9 @@ class PatchFastai(object): try: if Version(fastai.__version__) < Version("2.0.0"): PatchFastaiV1.update_current_task(task) - PatchFastaiV1.patch_model_callback() PostImportHookPatching.add_on_import("fastai", PatchFastaiV1.patch_model_callback) else: PatchFastaiV2.update_current_task(task) - PatchFastaiV2.patch_model_callback() PostImportHookPatching.add_on_import("fastai", PatchFastaiV2.patch_model_callback) except Exception: pass @@ -38,11 +36,17 @@ class PatchFastai(object): class PatchFastaiV1(object): __metrics_names = {} __gradient_hist_helpers = {} - _main_task = None + _current_task = None + __patched = False @staticmethod def update_current_task(task, **_): - PatchFastaiV1._main_task = task + PatchFastaiV1._current_task = task + if not task: + return + if not PatchFastaiV1.__patched: + PatchFastaiV1.__patched = True + PatchFastaiV1.patch_model_callback() @staticmethod def patch_model_callback(): @@ -52,7 +56,6 @@ class PatchFastaiV1(object): try: from fastai.basic_train import Recorder - Recorder.on_batch_end = _patched_call(Recorder.on_batch_end, PatchFastaiV1._on_batch_end) Recorder.on_backward_end = _patched_call(Recorder.on_backward_end, PatchFastaiV1._on_backward_end) Recorder.on_epoch_end = _patched_call(Recorder.on_epoch_end, PatchFastaiV1._on_epoch_end) @@ -65,7 +68,7 @@ class PatchFastaiV1(object): @staticmethod def _on_train_begin(original_fn, recorder, *args, **kwargs): original_fn(recorder, *args, **kwargs) - if not PatchFastaiV1._main_task: + if not PatchFastaiV1._current_task: return # noinspection PyBroadException try: @@ -84,7 +87,7 @@ class PatchFastaiV1(object): original_fn(recorder, *args, **kwargs) - if not PatchFastaiV1._main_task: + if not PatchFastaiV1._current_task: return # noinspection PyBroadException @@ -112,7 +115,7 @@ class PatchFastaiV1(object): min_gradient=gradient_stats[:, 5].min(), ) - logger = PatchFastaiV1._main_task.get_logger() + logger = PatchFastaiV1._current_task.get_logger() iteration = kwargs.get("iteration", 0) for name, val in stats_report.items(): logger.report_scalar(title="model_stats_gradients", series=name, value=val, iteration=iteration) @@ -122,26 +125,26 @@ class PatchFastaiV1(object): @staticmethod def _on_epoch_end(original_fn, recorder, *args, **kwargs): original_fn(recorder, *args, **kwargs) - if not PatchFastaiV1._main_task: + if not PatchFastaiV1._current_task: return # noinspection PyBroadException try: - logger = PatchFastaiV1._main_task.get_logger() + logger = PatchFastaiV1._current_task.get_logger() iteration = kwargs.get("iteration") for series, value in zip( PatchFastaiV1.__metrics_names[id(recorder)], [kwargs.get("smooth_loss")] + kwargs.get("last_metrics", []), ): logger.report_scalar(title="metrics", series=series, value=value, iteration=iteration) - PatchFastaiV1._main_task.flush() + PatchFastaiV1._current_task.flush() except Exception: pass @staticmethod def _on_batch_end(original_fn, recorder, *args, **kwargs): original_fn(recorder, *args, **kwargs) - if not PatchFastaiV1._main_task: + if not PatchFastaiV1._current_task: return # noinspection PyBroadException @@ -150,7 +153,7 @@ class PatchFastaiV1(object): if iteration == 0 or not kwargs.get("train"): return - logger = PatchFastaiV1._main_task.get_logger() + 
diff --git a/clearml/binding/frameworks/lightgbm_bind.py b/clearml/binding/frameworks/lightgbm_bind.py
index 524e8c1f..3493027a 100644
--- a/clearml/binding/frameworks/lightgbm_bind.py
+++ b/clearml/binding/frameworks/lightgbm_bind.py
@@ -11,12 +11,14 @@ from ...model import Framework
 
 
 class PatchLIGHTgbmModelIO(PatchBaseModelIO):
-    __main_task = None
+    _current_task = None
    __patched = None

    @staticmethod
    def update_current_task(task, **kwargs):
-        PatchLIGHTgbmModelIO.__main_task = task
+        PatchLIGHTgbmModelIO._current_task = task
+        if not task:
+            return
        PatchLIGHTgbmModelIO._patch_model_io()
        PostImportHookPatching.add_on_import('lightgbm', PatchLIGHTgbmModelIO._patch_model_io)

@@ -43,7 +45,7 @@ class PatchLIGHTgbmModelIO(PatchBaseModelIO):
    @staticmethod
    def _save(original_fn, obj, f, *args, **kwargs):
        ret = original_fn(obj, f, *args, **kwargs)
-        if not PatchLIGHTgbmModelIO.__main_task:
+        if not PatchLIGHTgbmModelIO._current_task:
            return ret

        if isinstance(f, six.string_types):
@@ -64,12 +66,15 @@ class PatchLIGHTgbmModelIO(PatchBaseModelIO):
            model_name = Path(filename).stem
        except Exception:
            model_name = None
-        WeightsFileHandler.create_output_model(obj, filename, Framework.lightgbm, PatchLIGHTgbmModelIO.__main_task,
+        WeightsFileHandler.create_output_model(obj, filename, Framework.lightgbm, PatchLIGHTgbmModelIO._current_task,
                                               singlefile=True, model_name=model_name)
        return ret

    @staticmethod
    def _load(original_fn, model_file, *args, **kwargs):
+        if not PatchLIGHTgbmModelIO._current_task:
+            return original_fn(model_file, *args, **kwargs)
+
        if isinstance(model_file, six.string_types):
            filename = model_file
        elif hasattr(model_file, 'name'):
@@ -79,21 +84,18 @@ class PatchLIGHTgbmModelIO(PatchBaseModelIO):
        else:
            filename = None

-        if not PatchLIGHTgbmModelIO.__main_task:
-            return original_fn(model_file, *args, **kwargs)
-
        # register input model
        empty = _Empty()
        # Hack: disabled
        if False and running_remotely():
            filename = WeightsFileHandler.restore_weights_file(empty, filename, Framework.xgboost,
-                                                               PatchLIGHTgbmModelIO.__main_task)
+                                                               PatchLIGHTgbmModelIO._current_task)
            model = original_fn(model_file=filename or model_file, *args, **kwargs)
        else:
            # try to load model before registering, in case we fail
            model = original_fn(model_file=model_file, *args, **kwargs)
            WeightsFileHandler.restore_weights_file(empty, filename, Framework.lightgbm,
-                                                    PatchLIGHTgbmModelIO.__main_task)
+                                                    PatchLIGHTgbmModelIO._current_task)

        if empty.trains_in_model:
            # noinspection PyBroadException
@@ -110,7 +112,7 @@ class PatchLIGHTgbmModelIO(PatchBaseModelIO):
            # logging the results to scalars section
            # noinspection PyBroadException
            try:
-                logger = PatchLIGHTgbmModelIO.__main_task.get_logger()
+                logger = PatchLIGHTgbmModelIO._current_task.get_logger()
                iteration = env.iteration
                for data_title, data_series, value, _ in env.evaluation_result_list:
                    logger.report_scalar(title=data_title, series=data_series, value="{:.6f}".format(value),
@@ -121,12 +123,12 @@ class PatchLIGHTgbmModelIO(PatchBaseModelIO):

        kwargs.setdefault("callbacks", []).append(trains_lightgbm_callback())
        ret = original_fn(*args, **kwargs)
-        if not PatchLIGHTgbmModelIO.__main_task:
+        if not PatchLIGHTgbmModelIO._current_task:
            return ret
        params = args[0] if args else kwargs.get('params', {})
        for k, v in params.items():
            if isinstance(v, set):
                params[k] = list(v)
        if params:
-            PatchLIGHTgbmModelIO.__main_task.connect(params)
+            PatchLIGHTgbmModelIO._current_task.connect(params)
        return ret
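A side note on the patched train() above: it casts set-valued LightGBM parameters to lists before task.connect(params), presumably because sets cannot be serialized for parameter storage. A quick illustration of the problem and the fix:

    import json

    params = {"metric": {"auc", "binary_logloss"}, "num_leaves": 31}
    try:
        json.dumps(params)
    except TypeError as e:
        print("sets break serialization:", e)

    params = {k: list(v) if isinstance(v, set) else v for k, v in params.items()}
    print(json.dumps(params))  # now serializable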
diff --git a/clearml/binding/frameworks/megengine_bind.py b/clearml/binding/frameworks/megengine_bind.py
index a06db8e5..af64a7b1 100644
--- a/clearml/binding/frameworks/megengine_bind.py
+++ b/clearml/binding/frameworks/megengine_bind.py
@@ -10,14 +10,14 @@ from ...model import Framework
 
 
 class PatchMegEngineModelIO(PatchBaseModelIO):
-    __main_task = None
+    _current_task = None
    __patched = None

-    # __patched_lightning = None
-
    @staticmethod
    def update_current_task(task, **_):
-        PatchMegEngineModelIO.__main_task = task
+        PatchMegEngineModelIO._current_task = task
+        if not task:
+            return
        PatchMegEngineModelIO._patch_model_io()
        PostImportHookPatching.add_on_import(
            'megengine', PatchMegEngineModelIO._patch_model_io
@@ -58,7 +58,7 @@ class PatchMegEngineModelIO(PatchBaseModelIO):
        ret = original_fn(obj, f, *args, **kwargs)

        # if there is no main task or this is a nested call
-        if not PatchMegEngineModelIO.__main_task:
+        if not PatchMegEngineModelIO._current_task:
            return ret

        # noinspection PyBroadException
@@ -93,7 +93,7 @@ class PatchMegEngineModelIO(PatchBaseModelIO):
            WeightsFileHandler.create_output_model(
                obj, filename, Framework.megengine,
-                PatchMegEngineModelIO.__main_task,
+                PatchMegEngineModelIO._current_task,
                singlefile=True, model_name=model_name,
            )

@@ -102,7 +102,7 @@ class PatchMegEngineModelIO(PatchBaseModelIO):
    @staticmethod
    def _load(original_fn, f, *args, **kwargs):
        # if there is no main task or this is a nested call
-        if not PatchMegEngineModelIO.__main_task:
+        if not PatchMegEngineModelIO._current_task:
            return original_fn(f, *args, **kwargs)

        # noinspection PyBroadException
@@ -124,7 +124,7 @@ class PatchMegEngineModelIO(PatchBaseModelIO):
        model = original_fn(f, *args, **kwargs)
        WeightsFileHandler.restore_weights_file(
            empty, filename, Framework.megengine,
-            PatchMegEngineModelIO.__main_task
+            PatchMegEngineModelIO._current_task
        )

        if empty.trains_in_model:
@@ -139,7 +139,7 @@ class PatchMegEngineModelIO(PatchBaseModelIO):
    @staticmethod
    def _load_from_obj(original_fn, obj, f, *args, **kwargs):
        # if there is no main task or this is a nested call
-        if not PatchMegEngineModelIO.__main_task:
+        if not PatchMegEngineModelIO._current_task:
            return original_fn(obj, f, *args, **kwargs)

        # noinspection PyBroadException
@@ -161,7 +161,7 @@ class PatchMegEngineModelIO(PatchBaseModelIO):
        model = original_fn(obj, f, *args, **kwargs)
        WeightsFileHandler.restore_weights_file(
            empty, filename, Framework.megengine,
-            PatchMegEngineModelIO.__main_task,
+            PatchMegEngineModelIO._current_task,
        )

        if empty.trains_in_model:
diff --git a/clearml/binding/frameworks/pytorch_bind.py b/clearml/binding/frameworks/pytorch_bind.py
index 1a7c92e4..4cfea85a 100644
--- a/clearml/binding/frameworks/pytorch_bind.py
+++ b/clearml/binding/frameworks/pytorch_bind.py
@@ -11,13 +11,15 @@ from ...model import Framework
 
 
 class PatchPyTorchModelIO(PatchBaseModelIO):
-    __main_task = None
+    _current_task = None
    __patched = None
    __patched_lightning = None

    @staticmethod
    def update_current_task(task, **_):
-        PatchPyTorchModelIO.__main_task = task
+        PatchPyTorchModelIO._current_task = task
+        if not task:
+            return
        PatchPyTorchModelIO._patch_model_io()
        PatchPyTorchModelIO._patch_lightning_io()
        PostImportHookPatching.add_on_import('torch', PatchPyTorchModelIO._patch_model_io)
@@ -112,7 +114,7 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
        ret = original_fn(obj, f, *args, **kwargs)

        # if there is no main task or this is a nested call
-        if not PatchPyTorchModelIO.__main_task:
+        if not PatchPyTorchModelIO._current_task:
            return ret

        # pytorch-lightning check if rank is zero
@@ -154,14 +156,14 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
            model_name = None

        WeightsFileHandler.create_output_model(
-            obj, filename, Framework.pytorch, PatchPyTorchModelIO.__main_task, singlefile=True, model_name=model_name)
+            obj, filename, Framework.pytorch, PatchPyTorchModelIO._current_task, singlefile=True, model_name=model_name)

        return ret

    @staticmethod
    def _load(original_fn, f, *args, **kwargs):
        # if there is no main task or this is a nested call
-        if not PatchPyTorchModelIO.__main_task:
+        if not PatchPyTorchModelIO._current_task:
            return original_fn(f, *args, **kwargs)

        # noinspection PyBroadException
@@ -182,13 +184,13 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
        # Hack: disabled
        if False and running_remotely():
            filename = WeightsFileHandler.restore_weights_file(
-                empty, filename, Framework.pytorch, PatchPyTorchModelIO.__main_task)
+                empty, filename, Framework.pytorch, PatchPyTorchModelIO._current_task)
            model = original_fn(filename or f, *args, **kwargs)
        else:
            # try to load model before registering, in case we fail
            model = original_fn(f, *args, **kwargs)
            WeightsFileHandler.restore_weights_file(
-                empty, filename, Framework.pytorch, PatchPyTorchModelIO.__main_task)
+                empty, filename, Framework.pytorch, PatchPyTorchModelIO._current_task)

        if empty.trains_in_model:
            # noinspection PyBroadException
@@ -202,7 +204,7 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
    @staticmethod
    def _load_from_obj(original_fn, obj, f, *args, **kwargs):
        # if there is no main task or this is a nested call
-        if not PatchPyTorchModelIO.__main_task:
+        if not PatchPyTorchModelIO._current_task:
            return original_fn(obj, f, *args, **kwargs)

        # noinspection PyBroadException
@@ -223,13 +225,13 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
        # Hack: disabled
        if False and running_remotely():
            filename = WeightsFileHandler.restore_weights_file(
-                empty, filename, Framework.pytorch, PatchPyTorchModelIO.__main_task)
+                empty, filename, Framework.pytorch, PatchPyTorchModelIO._current_task)
            model = original_fn(obj, filename or f, *args, **kwargs)
        else:
            # try to load model before registering, in case we fail
            model = original_fn(obj, f, *args, **kwargs)
            WeightsFileHandler.restore_weights_file(
-                empty, filename, Framework.pytorch, PatchPyTorchModelIO.__main_task)
+                empty, filename, Framework.pytorch, PatchPyTorchModelIO._current_task)

        if empty.trains_in_model:
            # noinspection PyBroadException
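The rename from __main_task to _current_task, repeated across all of these classes, also changes Python semantics: a double-underscore class attribute is name-mangled per class, which makes it awkward to reach or reset from outside the defining class body (tests, mixins, or a shared base such as PatchBaseModelIO). A small demonstration:

    class Patcher:
        __main_task = "task"      # actually stored as _Patcher__main_task

    print("_Patcher__main_task" in vars(Patcher))   # True: only the mangled name exists
    try:
        Patcher.__main_task                          # no mangling outside the class body
    except AttributeError as e:
        print("not reachable as written:", e)

    Patcher._current_task = "task"                   # single underscore: a plain attribute
    print(Patcher._current_task)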
diff --git a/clearml/binding/frameworks/tensorflow_bind.py b/clearml/binding/frameworks/tensorflow_bind.py
index b2a47d9b..745ece15 100644
--- a/clearml/binding/frameworks/tensorflow_bind.py
+++ b/clearml/binding/frameworks/tensorflow_bind.py
@@ -232,7 +232,7 @@ class EventTrainsWriter(object):
    TF SummaryWriter implementation that converts the tensorboard's summary into
    ClearML events and reports the events (metrics) for a ClearML task (logger).
    """
-    __main_task = None
+    _current_task = None
    __report_hparams = True
    _add_lock = threading.RLock()
    _series_name_lookup = {}
@@ -641,7 +641,7 @@ class EventTrainsWriter(object):
            session_start_info = parse_session_start_info_plugin_data(content)
            session_start_info = MessageToDict(session_start_info)
            hparams = session_start_info["hparams"]
-            EventTrainsWriter.__main_task.update_parameters(
+            EventTrainsWriter._current_task.update_parameters(
                {"TB_hparams/{}".format(k): v for k, v in hparams.items()}
            )
        except Exception:
@@ -844,13 +844,13 @@ class EventTrainsWriter(object):
    @classmethod
    def update_current_task(cls, task, **kwargs):
        cls.__report_hparams = kwargs.get('report_hparams', False)
-        if cls.__main_task != task:
+        if cls._current_task != task:
            with cls._add_lock:
                cls._series_name_lookup = {}
                cls._title_series_writers_lookup = {}
                cls._event_writers_id_to_logdir = {}
                cls._title_series_wraparound_counter = {}
-        cls.__main_task = task
+        cls._current_task = task
 
 
 # noinspection PyCallingNonCallable
@@ -905,9 +905,10 @@ class ProxyEventsWriter(object):
 
 # noinspection PyPep8Naming
 class PatchSummaryToEventTransformer(object):
-    __main_task = None
+    _current_task = None
    __original_getattribute = None
    __original_getattributeX = None
+    __patched = False
    _original_add_event = None
    _original_add_eventT = None
    _original_add_eventX = None
@@ -930,15 +931,19 @@ class PatchSummaryToEventTransformer(object):
    @staticmethod
    def update_current_task(task, **kwargs):
        PatchSummaryToEventTransformer.defaults_dict.update(kwargs)
-        PatchSummaryToEventTransformer.__main_task = task
-        # make sure we patched the SummaryToEventTransformer
-        PatchSummaryToEventTransformer._patch_summary_to_event_transformer()
-        PostImportHookPatching.add_on_import('tensorflow',
-                                             PatchSummaryToEventTransformer._patch_summary_to_event_transformer)
-        PostImportHookPatching.add_on_import('torch',
-                                             PatchSummaryToEventTransformer._patch_summary_to_event_transformer)
-        PostImportHookPatching.add_on_import('tensorboardX',
-                                             PatchSummaryToEventTransformer._patch_summary_to_event_transformer)
+        PatchSummaryToEventTransformer._current_task = task
+        if not task:
+            return
+        if not PatchSummaryToEventTransformer.__patched:
+            PatchSummaryToEventTransformer.__patched = True
+            # make sure we patched the SummaryToEventTransformer
+            PatchSummaryToEventTransformer._patch_summary_to_event_transformer()
+            PostImportHookPatching.add_on_import('tensorflow',
+                                                 PatchSummaryToEventTransformer._patch_summary_to_event_transformer)
+            PostImportHookPatching.add_on_import('torch',
+                                                 PatchSummaryToEventTransformer._patch_summary_to_event_transformer)
+            PostImportHookPatching.add_on_import('tensorboardX',
+                                                 PatchSummaryToEventTransformer._patch_summary_to_event_transformer)
 
    @staticmethod
    def _patch_summary_to_event_transformer():
@@ -1002,7 +1007,7 @@ class PatchSummaryToEventTransformer(object):
 
    @staticmethod
    def _patched_add_eventT(self, *args, **kwargs):
-        if not hasattr(self, 'clearml') or not PatchSummaryToEventTransformer.__main_task:
+        if not hasattr(self, 'clearml') or not PatchSummaryToEventTransformer._current_task:
            return PatchSummaryToEventTransformer._original_add_eventT(self, *args, **kwargs)
        if not self.clearml:  # noqa
            # noinspection PyBroadException
@@ -1010,7 +1015,7 @@ class PatchSummaryToEventTransformer(object):
                logdir = self.get_logdir()
            except Exception:
                logdir = None
-            self.clearml = EventTrainsWriter(PatchSummaryToEventTransformer.__main_task.get_logger(),
+            self.clearml = EventTrainsWriter(PatchSummaryToEventTransformer._current_task.get_logger(),
                                             logdir=logdir, **PatchSummaryToEventTransformer.defaults_dict)
        # noinspection PyBroadException
        try:
@@ -1021,7 +1026,7 @@ class PatchSummaryToEventTransformer(object):
 
    @staticmethod
    def _patched_add_eventX(self, *args, **kwargs):
-        if not hasattr(self, 'clearml') or not PatchSummaryToEventTransformer.__main_task:
+        if not hasattr(self, 'clearml') or not PatchSummaryToEventTransformer._current_task:
            return PatchSummaryToEventTransformer._original_add_eventX(self, *args, **kwargs)
        if not self.clearml:
            # noinspection PyBroadException
@@ -1029,7 +1034,7 @@ class PatchSummaryToEventTransformer(object):
                logdir = self.get_logdir()
            except Exception:
                logdir = None
-            self.clearml = EventTrainsWriter(PatchSummaryToEventTransformer.__main_task.get_logger(),
+            self.clearml = EventTrainsWriter(PatchSummaryToEventTransformer._current_task.get_logger(),
                                             logdir=logdir, **PatchSummaryToEventTransformer.defaults_dict)
        # noinspection PyBroadException
        try:
@@ -1051,7 +1056,7 @@ class PatchSummaryToEventTransformer(object):
    @staticmethod
    def _patched_getattribute_(self, attr, get_base):
        # no main task, zero chance we have a ClearML event logger
-        if PatchSummaryToEventTransformer.__main_task is None:
+        if PatchSummaryToEventTransformer._current_task is None:
            return get_base(self, attr)

        # check if we already have a ClearML event logger
@@ -1068,7 +1073,7 @@ class PatchSummaryToEventTransformer(object):
        except Exception:
            logdir = None
        defaults_dict = __dict__.get('_trains_defaults') or PatchSummaryToEventTransformer.defaults_dict
-        trains_event = EventTrainsWriter(PatchSummaryToEventTransformer.__main_task.get_logger(),
+        trains_event = EventTrainsWriter(PatchSummaryToEventTransformer._current_task.get_logger(),
                                         logdir=logdir, **defaults_dict)

        # order is important, the return value of ProxyEventsWriter is the last object in the list
@@ -1111,8 +1116,9 @@ class _ModelAdapter(object):
 
 
 class PatchModelCheckPointCallback(object):
-    __main_task = None
+    _current_task = None
    __original_getattribute = None
+    __patched = False
    defaults_dict = dict(
        config_text=None,
        config_dict=None,
@@ -1132,11 +1138,15 @@ class PatchModelCheckPointCallback(object):
    @staticmethod
    def update_current_task(task, **kwargs):
        PatchModelCheckPointCallback.defaults_dict.update(kwargs)
-        PatchModelCheckPointCallback.__main_task = task
-        # make sure we patched the SummaryToEventTransformer
-        PatchModelCheckPointCallback._patch_model_checkpoint()
-        PostImportHookPatching.add_on_import('keras', PatchModelCheckPointCallback._patch_model_checkpoint)
-        PostImportHookPatching.add_on_import('tensorflow', PatchModelCheckPointCallback._patch_model_checkpoint)
+        PatchModelCheckPointCallback._current_task = task
+        if not task:
+            return
+        if not PatchModelCheckPointCallback.__patched:
+            PatchModelCheckPointCallback.__patched = True
+            # make sure we patched the model checkpoint callback
+            PatchModelCheckPointCallback._patch_model_checkpoint()
+            PostImportHookPatching.add_on_import('keras', PatchModelCheckPointCallback._patch_model_checkpoint)
+            PostImportHookPatching.add_on_import('tensorflow', PatchModelCheckPointCallback._patch_model_checkpoint)
 
    @staticmethod
    def _patch_model_checkpoint():
@@ -1176,7 +1186,7 @@ class PatchModelCheckPointCallback(object):
        get_base = PatchModelCheckPointCallback.__original_getattribute

        # no main task, zero chance we have a ClearML event logger
-        if PatchModelCheckPointCallback.__main_task is None:
+        if PatchModelCheckPointCallback._current_task is None:
            return get_base(self, attr)

        # check if we already have a ClearML event logger
@@ -1189,17 +1199,17 @@ class PatchModelCheckPointCallback(object):
        base_model = __dict__['model']
        defaults_dict = __dict__.get('_trains_defaults') or PatchModelCheckPointCallback.defaults_dict
        output_model = OutputModel(
-            PatchModelCheckPointCallback.__main_task,
+            PatchModelCheckPointCallback._current_task,
            config_text=defaults_dict.get('config_text'),
            config_dict=defaults_dict.get('config_dict'),
            name=defaults_dict.get('name'),
            comment=defaults_dict.get('comment'),
            label_enumeration=defaults_dict.get('label_enumeration') or
-            PatchModelCheckPointCallback.__main_task.get_labels_enumeration(),
+            PatchModelCheckPointCallback._current_task.get_labels_enumeration(),
            framework=Framework.keras,
        )
        output_model.set_upload_destination(
-            PatchModelCheckPointCallback.__main_task.get_output_destination(raise_on_error=False))
+            PatchModelCheckPointCallback._current_task.get_output_destination(raise_on_error=False))
        trains_model = _ModelAdapter(base_model, output_model)

        # order is important, the return value of ProxyEventsWriter is the last object in the list
@@ -1209,26 +1219,32 @@ class PatchModelCheckPointCallback(object):
 
 # noinspection PyProtectedMember,PyUnresolvedReferences
 class PatchTensorFlowEager(object):
-    __main_task = None
+    _current_task = None
    __original_fn_scalar = None
    __original_fn_hist = None
    __original_fn_image = None
+    __original_fn_write_summary = None
    __trains_event_writer = {}
    __tf_tb_writer_id_to_logdir = {}
+    __patched = False
    defaults_dict = dict(
        report_freq=1, image_report_freq=1, histogram_update_freq_multiplier=5,
        histogram_granularity=50)

    @staticmethod
    def update_current_task(task, **kwargs):
-        if task != PatchTensorFlowEager.__main_task:
+        if task != PatchTensorFlowEager._current_task:
            PatchTensorFlowEager.__trains_event_writer = {}
        PatchTensorFlowEager.defaults_dict.update(kwargs)
-        PatchTensorFlowEager.__main_task = task
-        # make sure we patched the SummaryToEventTransformer
-        PatchTensorFlowEager._patch_summary_ops()
-        PostImportHookPatching.add_on_import('tensorflow', PatchTensorFlowEager._patch_summary_ops)
+        PatchTensorFlowEager._current_task = task
+        if not task:
+            return
+        if not PatchTensorFlowEager.__patched:
+            PatchTensorFlowEager.__patched = True
+            # make sure we patched the summary ops
+            PatchTensorFlowEager._patch_summary_ops()
+            PostImportHookPatching.add_on_import('tensorflow', PatchTensorFlowEager._patch_summary_ops)
 
    @staticmethod
    def _patch_summary_ops():
@@ -1245,7 +1261,7 @@ class PatchTensorFlowEager(object):
                gen_summary_ops.write_image_summary = PatchTensorFlowEager._write_image_summary
                PatchTensorFlowEager.__original_fn_hist = gen_summary_ops.write_histogram_summary
                gen_summary_ops.write_histogram_summary = PatchTensorFlowEager._write_hist_summary
-                PatchTensorFlowEager.__write_summary = gen_summary_ops.write_summary
+                PatchTensorFlowEager.__original_fn_write_summary = gen_summary_ops.write_summary
                gen_summary_ops.write_summary = PatchTensorFlowEager._write_summary
                gen_summary_ops.create_summary_file_writer = partial(IsTensorboardInit._patched_tb__init__,
                                                                     gen_summary_ops.create_summary_file_writer)
@@ -1270,6 +1286,8 @@ class PatchTensorFlowEager(object):
 
    @staticmethod
    def _create_summary_file_writer(original_fn, *args, **kwargs):
+        if not PatchTensorFlowEager._current_task:
+            return original_fn(*args, **kwargs)
        # noinspection PyBroadException
        try:
            a_logdir = None
@@ -1294,7 +1312,7 @@ class PatchTensorFlowEager(object):
 
    @staticmethod
    def _get_event_writer(writer):
-        if not PatchTensorFlowEager.__main_task:
+        if not PatchTensorFlowEager._current_task:
            return None
        if not PatchTensorFlowEager.__trains_event_writer.get(id(writer)):
            # noinspection PyBroadException
@@ -1324,7 +1342,7 @@ class PatchTensorFlowEager(object):
                logdir = None

            PatchTensorFlowEager.__trains_event_writer[id(writer)] = EventTrainsWriter(
-                logger=PatchTensorFlowEager.__main_task.get_logger(), logdir=logdir,
+                logger=PatchTensorFlowEager._current_task.get_logger(), logdir=logdir,
                **PatchTensorFlowEager.defaults_dict)
        return PatchTensorFlowEager.__trains_event_writer[id(writer)]
 
@@ -1337,6 +1355,10 @@ class PatchTensorFlowEager(object):
 
    @staticmethod
    def _write_summary(writer, step, tensor, tag, summary_metadata, name=None, **kwargs):
+        if not PatchTensorFlowEager._current_task:
+            return PatchTensorFlowEager.__original_fn_write_summary(
+                writer, step, tensor, tag, summary_metadata, name, **kwargs
+            )
        event_writer = PatchTensorFlowEager._get_event_writer(writer)
        # make sure we can get the tensors values
        if event_writer and isinstance(step, int) or hasattr(step, 'numpy'):
@@ -1370,13 +1392,17 @@ class PatchTensorFlowEager(object):
                        step=int(step.numpy()) if not isinstance(step, int) else step,
                        values=None, audio_data=audio_bytes)
                else:
-                    pass  # print('unsupported plugin_type', plugin_type)
+                    pass
            except Exception:
                pass
-        return PatchTensorFlowEager.__write_summary(writer, step, tensor, tag, summary_metadata, name, **kwargs)
+        return PatchTensorFlowEager.__original_fn_write_summary(
+            writer, step, tensor, tag, summary_metadata, name, **kwargs
+        )
 
    @staticmethod
    def _write_scalar_summary(writer, step, tag, value, name=None, **kwargs):
+        if not PatchTensorFlowEager._current_task:
+            return PatchTensorFlowEager.__original_fn_scalar(writer, step, tag, value, name, **kwargs)
        event_writer = PatchTensorFlowEager._get_event_writer(writer)
        if event_writer and isinstance(step, int) or hasattr(step, 'numpy'):
            try:
@@ -1417,6 +1443,8 @@ class PatchTensorFlowEager(object):
 
    @staticmethod
    def _write_hist_summary(writer, step, tag, values, name, **kwargs):
+        if not PatchTensorFlowEager._current_task:
+            return PatchTensorFlowEager.__original_fn_hist(writer, step, tag, values, name, **kwargs)
        event_writer = PatchTensorFlowEager._get_event_writer(writer)
        if event_writer and isinstance(step, int) or hasattr(step, 'numpy'):
            try:
@@ -1459,6 +1487,9 @@ class PatchTensorFlowEager(object):
 
    @staticmethod
    def _write_image_summary(writer, step, tag, tensor, bad_color, max_images, name, **kwargs):
+        if not PatchTensorFlowEager._current_task:
+            return PatchTensorFlowEager.__original_fn_image(
+                writer, step, tag, tensor, bad_color, max_images, name, **kwargs)
        event_writer = PatchTensorFlowEager._get_event_writer(writer)
        if event_writer and isinstance(step, int) or hasattr(step, 'numpy'):
            try:
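One pre-existing subtlety worth noting in the guards above: "event_writer and isinstance(step, int) or hasattr(step, 'numpy')" parses as "(event_writer and isinstance(step, int)) or hasattr(step, 'numpy')", so a step tensor exposing .numpy() satisfies the condition even when event_writer is None; the body is then shielded only by the surrounding try/except. A quick check of the precedence:

    class FakeStep:                # stand-in for a TF tensor with .numpy()
        def numpy(self):
            return 3

    event_writer = None
    step = FakeStep()
    print(event_writer and isinstance(step, int) or hasattr(step, "numpy"))    # True
    print((event_writer and isinstance(step, int)) or hasattr(step, "numpy"))  # same grouping
    print(event_writer and (isinstance(step, int) or hasattr(step, "numpy")))  # None: the likely intent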
@@ -1526,13 +1557,15 @@ class PatchTensorFlowEager(object):
 
 # noinspection PyPep8Naming,SpellCheckingInspection
 class PatchKerasModelIO(object):
-    __main_task = None
+    _current_task = None
    __patched_keras = None
    __patched_tensorflow = None

    @staticmethod
    def update_current_task(task, **_):
-        PatchKerasModelIO.__main_task = task
+        PatchKerasModelIO._current_task = task
+        if not task:
+            return
        PatchKerasModelIO._patch_model_checkpoint()
        PostImportHookPatching.add_on_import('tensorflow', PatchKerasModelIO._patch_model_checkpoint)
        PostImportHookPatching.add_on_import('keras', PatchKerasModelIO._patch_model_checkpoint)
@@ -1692,7 +1725,7 @@ class PatchKerasModelIO(object):
    def _updated_config(original_fn, self):
        config = original_fn(self)
        # check if we have main task
-        if PatchKerasModelIO.__main_task is None:
+        if PatchKerasModelIO._current_task is None:
            return config

        try:
@@ -1717,10 +1750,10 @@ class PatchKerasModelIO(object):
            else:
                # todo: support multiple models for the same task
                self.trains_out_model.append(OutputModel(
-                    task=PatchKerasModelIO.__main_task,
+                    task=PatchKerasModelIO._current_task,
                    config_dict=safe_config,
-                    name=PatchKerasModelIO.__main_task.name + ' ' + model_name_id,
-                    label_enumeration=PatchKerasModelIO.__main_task.get_labels_enumeration(),
+                    name=PatchKerasModelIO._current_task.name + ' ' + model_name_id,
+                    label_enumeration=PatchKerasModelIO._current_task.get_labels_enumeration(),
                    framework=Framework.keras,
                ))
        except Exception as ex:
@@ -1738,7 +1771,7 @@ class PatchKerasModelIO(object):
            self = _Empty()

        # check if we have main task
-        if PatchKerasModelIO.__main_task is None:
+        if PatchKerasModelIO._current_task is None:
            return self

        try:
@@ -1751,10 +1784,10 @@ class PatchKerasModelIO(object):
            # check if object already has InputModel
            self.trains_in_model = InputModel.empty(
                config_dict=config_dict,
-                label_enumeration=PatchKerasModelIO.__main_task.get_labels_enumeration(),
+                label_enumeration=PatchKerasModelIO._current_task.get_labels_enumeration(),
            )
            # todo: support multiple models for the same task
-            PatchKerasModelIO.__main_task.connect(self.trains_in_model)
+            PatchKerasModelIO._current_task.connect(self.trains_in_model)
            # if we are running remotely we should deserialize the object
            # because someone might have changed the configuration
            # Hack: disabled
@@ -1781,7 +1814,7 @@ class PatchKerasModelIO(object):
    @staticmethod
    def _load_weights(original_fn, self, *args, **kwargs):
        # check if we have main task
-        if PatchKerasModelIO.__main_task is None:
+        if PatchKerasModelIO._current_task is None:
            return original_fn(self, *args, **kwargs)

        # get filepath
@@ -1794,7 +1827,7 @@ class PatchKerasModelIO(object):
        if False and running_remotely():
            # register/load model weights
            filepath = WeightsFileHandler.restore_weights_file(self, filepath, Framework.keras,
-                                                               PatchKerasModelIO.__main_task)
+                                                               PatchKerasModelIO._current_task)
            if 'filepath' in kwargs:
                kwargs['filepath'] = filepath
            else:
@@ -1805,11 +1838,13 @@ class PatchKerasModelIO(object):
        # try to load the files, if something happened exception will be raised before we register the file
        model = original_fn(self, *args, **kwargs)
        # register/load model weights
-        WeightsFileHandler.restore_weights_file(self, filepath, Framework.keras, PatchKerasModelIO.__main_task)
+        WeightsFileHandler.restore_weights_file(self, filepath, Framework.keras, PatchKerasModelIO._current_task)
        return model

    @staticmethod
    def _save(original_fn, self, *args, **kwargs):
+        if not PatchKerasModelIO._current_task:
+            return original_fn(self, *args, **kwargs)
        if hasattr(self, 'trains_out_model') and self.trains_out_model:
            # noinspection PyProtectedMember
            self.trains_out_model[-1]._processed = False
@@ -1823,12 +1858,13 @@ class PatchKerasModelIO(object):
    @staticmethod
    def _save_weights(original_fn, self, *args, **kwargs):
        original_fn(self, *args, **kwargs)
-        PatchKerasModelIO._update_outputmodel(self, *args, **kwargs)
+        if PatchKerasModelIO._current_task:
+            PatchKerasModelIO._update_outputmodel(self, *args, **kwargs)

    @staticmethod
    def _update_outputmodel(self, *args, **kwargs):
        # check if we have main task
-        if PatchKerasModelIO.__main_task is None:
+        if not PatchKerasModelIO._current_task:
            return

        try:
@@ -1851,7 +1887,7 @@ class PatchKerasModelIO(object):

            if filepath:
                WeightsFileHandler.create_output_model(
-                    self, filepath, Framework.keras, PatchKerasModelIO.__main_task,
+                    self, filepath, Framework.keras, PatchKerasModelIO._current_task,
                    config_obj=config or None, singlefile=True)

        except Exception as ex:
@@ -1860,12 +1896,12 @@ class PatchKerasModelIO(object):
    @staticmethod
    def _save_model(original_fn, model, filepath, *args, **kwargs):
        original_fn(model, filepath, *args, **kwargs)
-        if PatchKerasModelIO.__main_task:
+        if PatchKerasModelIO._current_task:
            PatchKerasModelIO._update_outputmodel(model, filepath)

    @staticmethod
    def _load_model(original_fn, filepath, *args, **kwargs):
-        if not PatchKerasModelIO.__main_task:
+        if not PatchKerasModelIO._current_task:
            return original_fn(filepath, *args, **kwargs)

        empty = _Empty()
@@ -1873,12 +1909,12 @@ class PatchKerasModelIO(object):
        if False and running_remotely():
            # register/load model weights
            filepath = WeightsFileHandler.restore_weights_file(empty, filepath, Framework.keras,
-                                                               PatchKerasModelIO.__main_task)
+                                                               PatchKerasModelIO._current_task)
            model = original_fn(filepath, *args, **kwargs)
        else:
            model = original_fn(filepath, *args, **kwargs)
            # register/load model weights
-            WeightsFileHandler.restore_weights_file(empty, filepath, Framework.keras, PatchKerasModelIO.__main_task)
+            WeightsFileHandler.restore_weights_file(empty, filepath, Framework.keras, PatchKerasModelIO._current_task)
        # update the input model object
        if empty.trains_in_model:
            # noinspection PyBroadException
@@ -1891,12 +1927,14 @@ class PatchKerasModelIO(object):
 
 class PatchTensorflowModelIO(object):
-    __main_task = None
+    _current_task = None
    __patched = None

    @staticmethod
    def update_current_task(task, **_):
-        PatchTensorflowModelIO.__main_task = task
+        PatchTensorflowModelIO._current_task = task
+        if not task:
+            return
        PatchTensorflowModelIO._patch_model_checkpoint()
        PostImportHookPatching.add_on_import('tensorflow', PatchTensorflowModelIO._patch_model_checkpoint)
@@ -2028,29 +2066,31 @@ class PatchTensorflowModelIO(object):
    @staticmethod
    def _save(original_fn, self, sess, save_path, *args, **kwargs):
        saved_path = original_fn(self, sess, save_path, *args, **kwargs)
-        if not saved_path:
+        if not saved_path or not PatchTensorflowModelIO._current_task:
            return saved_path
        # store output Model
        return WeightsFileHandler.create_output_model(self, saved_path, Framework.tensorflow,
-                                                      PatchTensorflowModelIO.__main_task)
+                                                      PatchTensorflowModelIO._current_task)

    @staticmethod
    def _save_model(original_fn, obj, export_dir, *args, **kwargs):
        original_fn(obj, export_dir, *args, **kwargs)
+        if not PatchTensorflowModelIO._current_task:
+            return
        # store output Model
        WeightsFileHandler.create_output_model(obj, export_dir, Framework.tensorflow,
-                                               PatchTensorflowModelIO.__main_task)
+                                               PatchTensorflowModelIO._current_task)

    @staticmethod
    def _restore(original_fn, self, sess, save_path, *args, **kwargs):
-        if PatchTensorflowModelIO.__main_task is None:
+        if PatchTensorflowModelIO._current_task is None:
            return original_fn(self, sess, save_path, *args, **kwargs)

        # Hack: disabled
        if False and running_remotely():
            # register/load model weights
            save_path = WeightsFileHandler.restore_weights_file(self, save_path, Framework.tensorflow,
-                                                                PatchTensorflowModelIO.__main_task)
+                                                                PatchTensorflowModelIO._current_task)
            # load model
            return original_fn(self, sess, save_path, *args, **kwargs)
@@ -2058,12 +2098,12 @@ class PatchTensorflowModelIO(object):
        model = original_fn(self, sess, save_path, *args, **kwargs)
        # register/load model weights
        WeightsFileHandler.restore_weights_file(self, save_path, Framework.tensorflow,
-                                                PatchTensorflowModelIO.__main_task)
+                                                PatchTensorflowModelIO._current_task)
        return model

    @staticmethod
    def _load_lt_2_0(original_fn, sess, tags=None, export_dir=None, *args, **saver_kwargs):
-        if PatchTensorflowModelIO.__main_task is None:
+        if PatchTensorflowModelIO._current_task is None:
            return original_fn(sess, tags, export_dir, *args, **saver_kwargs)

        # register input model
@@ -2072,7 +2112,7 @@ class PatchTensorflowModelIO(object):
        if False and running_remotely():
            export_dir = WeightsFileHandler.restore_weights_file(
                empty, export_dir, Framework.tensorflow,
-                PatchTensorflowModelIO.__main_task
+                PatchTensorflowModelIO._current_task
            )
            model = original_fn(sess, tags, export_dir, *args, **saver_kwargs)
        else:
@@ -2080,7 +2120,7 @@ class PatchTensorflowModelIO(object):
            model = original_fn(sess, tags, export_dir, *args, **saver_kwargs)
            WeightsFileHandler.restore_weights_file(
                empty, export_dir, Framework.tensorflow,
-                PatchTensorflowModelIO.__main_task
+                PatchTensorflowModelIO._current_task
            )

        if empty.trains_in_model:
@@ -2093,7 +2133,7 @@ class PatchTensorflowModelIO(object):
 
    @staticmethod
    def _load(original_fn, export_dir, *args, **saver_kwargs):
-        if PatchTensorflowModelIO.__main_task is None:
+        if PatchTensorflowModelIO._current_task is None:
            return original_fn(export_dir, *args, **saver_kwargs)

        # register input model
@@ -2102,7 +2142,7 @@ class PatchTensorflowModelIO(object):
        if False and running_remotely():
            export_dir = WeightsFileHandler.restore_weights_file(
                empty, export_dir, Framework.tensorflow,
-                PatchTensorflowModelIO.__main_task
+                PatchTensorflowModelIO._current_task
            )
            model = original_fn(export_dir, *args, **saver_kwargs)
        else:
@@ -2110,7 +2150,7 @@ class PatchTensorflowModelIO(object):
            model = original_fn(export_dir, *args, **saver_kwargs)
            WeightsFileHandler.restore_weights_file(
                empty, export_dir, Framework.tensorflow,
-                PatchTensorflowModelIO.__main_task
+                PatchTensorflowModelIO._current_task
            )

        if empty.trains_in_model:
@@ -2124,24 +2164,24 @@ class PatchTensorflowModelIO(object):
    @staticmethod
    def _ckpt_save(original_fn, self, file_prefix, *args, **kwargs):
        checkpoint_path = original_fn(self, file_prefix, *args, **kwargs)
-        if PatchTensorflowModelIO.__main_task is None:
+        if PatchTensorflowModelIO._current_task is None:
            return checkpoint_path
        WeightsFileHandler.create_output_model(self, checkpoint_path, Framework.tensorflow,
-                                               PatchTensorflowModelIO.__main_task)
+                                               PatchTensorflowModelIO._current_task)
        return checkpoint_path

    @staticmethod
    def _ckpt_write(original_fn, self, file_prefix, *args, **kwargs):
        checkpoint_path = original_fn(self, file_prefix, *args, **kwargs)
-        if PatchTensorflowModelIO.__main_task is None:
+        if PatchTensorflowModelIO._current_task is None:
            return checkpoint_path
        WeightsFileHandler.create_output_model(self, checkpoint_path, Framework.tensorflow,
-                                               PatchTensorflowModelIO.__main_task)
+                                               PatchTensorflowModelIO._current_task)
        return checkpoint_path

    @staticmethod
    def _ckpt_restore(original_fn, self, save_path, *args, **kwargs):
-        if
 PatchTensorflowModelIO.__main_task is None:
+        if PatchTensorflowModelIO._current_task is None:
            return original_fn(self, save_path, *args, **kwargs)

        # register input model
@@ -2149,13 +2189,13 @@ class PatchTensorflowModelIO(object):
        # Hack: disabled
        if False and running_remotely():
            save_path = WeightsFileHandler.restore_weights_file(empty, save_path, Framework.tensorflow,
-                                                                PatchTensorflowModelIO.__main_task)
+                                                                PatchTensorflowModelIO._current_task)
            model = original_fn(self, save_path, *args, **kwargs)
        else:
            # try to load model before registering it, in case it fails.
            model = original_fn(self, save_path, *args, **kwargs)
            WeightsFileHandler.restore_weights_file(empty, save_path, Framework.tensorflow,
-                                                    PatchTensorflowModelIO.__main_task)
+                                                    PatchTensorflowModelIO._current_task)

        if empty.trains_in_model:
            # noinspection PyBroadException
@@ -2167,12 +2207,14 @@ class PatchTensorflowModelIO(object):
 
 class PatchTensorflow2ModelIO(object):
-    __main_task = None
+    _current_task = None
    __patched = None

    @staticmethod
    def update_current_task(task, **_):
-        PatchTensorflow2ModelIO.__main_task = task
+        PatchTensorflow2ModelIO._current_task = task
+        if not task:
+            return
        PatchTensorflow2ModelIO._patch_model_checkpoint()
        PostImportHookPatching.add_on_import('tensorflow', PatchTensorflow2ModelIO._patch_model_checkpoint)
@@ -2210,18 +2252,20 @@ class PatchTensorflow2ModelIO(object):
    @staticmethod
    def _save(original_fn, self, file_prefix, *args, **kwargs):
        model = original_fn(self, file_prefix, *args, **kwargs)
+        if not PatchTensorflow2ModelIO._current_task:
+            return model
        # store output Model
        # noinspection PyBroadException
        try:
            WeightsFileHandler.create_output_model(self, file_prefix, Framework.tensorflow,
-                                                   PatchTensorflow2ModelIO.__main_task)
+                                                   PatchTensorflow2ModelIO._current_task)
        except Exception:
            pass
        return model

    @staticmethod
    def _restore(original_fn, self, save_path, *args, **kwargs):
-        if PatchTensorflow2ModelIO.__main_task is None:
+        if not PatchTensorflow2ModelIO._current_task:
            return original_fn(self, save_path, *args, **kwargs)

        # Hack: disabled
@@ -2230,7 +2274,7 @@ class PatchTensorflow2ModelIO(object):
            # noinspection PyBroadException
            try:
                save_path = WeightsFileHandler.restore_weights_file(self, save_path, Framework.tensorflow,
-                                                                    PatchTensorflow2ModelIO.__main_task)
+                                                                    PatchTensorflow2ModelIO._current_task)
            except Exception:
                pass
            # load model
@@ -2242,7 +2286,7 @@ class PatchTensorflow2ModelIO(object):
        # noinspection PyBroadException
        try:
            WeightsFileHandler.restore_weights_file(self, save_path, Framework.tensorflow,
-                                                    PatchTensorflow2ModelIO.__main_task)
+                                                    PatchTensorflow2ModelIO._current_task)
        except Exception:
            pass
        return model
diff --git
 a/clearml/binding/frameworks/xgboost_bind.py b/clearml/binding/frameworks/xgboost_bind.py
index fb408d00..2fc1e9ec 100644
--- a/clearml/binding/frameworks/xgboost_bind.py
+++ b/clearml/binding/frameworks/xgboost_bind.py
@@ -11,13 +11,15 @@ from ...model import Framework
 
 
 class PatchXGBoostModelIO(PatchBaseModelIO):
-    __main_task = None
+    _current_task = None
    __patched = None
    __callback_cls = None

    @staticmethod
    def update_current_task(task, **kwargs):
-        PatchXGBoostModelIO.__main_task = task
+        PatchXGBoostModelIO._current_task = task
+        if not task:
+            return
        PatchXGBoostModelIO._patch_model_io()
        PostImportHookPatching.add_on_import('xgboost', PatchXGBoostModelIO._patch_model_io)

@@ -55,7 +57,7 @@ class PatchXGBoostModelIO(PatchBaseModelIO):
    @staticmethod
    def _save(original_fn, obj, f, *args, **kwargs):
        ret = original_fn(obj, f, *args, **kwargs)
-        if not PatchXGBoostModelIO.__main_task:
+        if not PatchXGBoostModelIO._current_task:
            return ret

        if isinstance(f, six.string_types):
@@ -76,12 +78,15 @@ class PatchXGBoostModelIO(PatchBaseModelIO):
            model_name = Path(filename).stem
        except Exception:
            model_name = None
-        WeightsFileHandler.create_output_model(obj, filename, Framework.xgboost, PatchXGBoostModelIO.__main_task,
+        WeightsFileHandler.create_output_model(obj, filename, Framework.xgboost, PatchXGBoostModelIO._current_task,
                                               singlefile=True, model_name=model_name)
        return ret

    @staticmethod
    def _load(original_fn, f, *args, **kwargs):
+        if not PatchXGBoostModelIO._current_task:
+            return original_fn(f, *args, **kwargs)
+
        if isinstance(f, six.string_types):
            filename = f
        elif hasattr(f, 'name'):
@@ -91,21 +96,18 @@ class PatchXGBoostModelIO(PatchBaseModelIO):
        else:
            filename = None

-        if not PatchXGBoostModelIO.__main_task:
-            return original_fn(f, *args, **kwargs)
-
        # register input model
        empty = _Empty()
        # Hack: disabled
        if False and running_remotely():
            filename = WeightsFileHandler.restore_weights_file(empty, filename, Framework.xgboost,
-                                                               PatchXGBoostModelIO.__main_task)
+                                                               PatchXGBoostModelIO._current_task)
            model = original_fn(filename or f, *args, **kwargs)
        else:
            # try to load model before registering, in case we fail
            model = original_fn(f, *args, **kwargs)
            WeightsFileHandler.restore_weights_file(empty, filename, Framework.xgboost,
-                                                    PatchXGBoostModelIO.__main_task)
+                                                    PatchXGBoostModelIO._current_task)

        if empty.trains_in_model:
            # noinspection PyBroadException
@@ -117,10 +119,12 @@ class PatchXGBoostModelIO(PatchBaseModelIO):
 
    @staticmethod
    def _train(original_fn, *args, **kwargs):
+        if not PatchXGBoostModelIO._current_task:
+            return original_fn(*args, **kwargs)
        if PatchXGBoostModelIO.__callback_cls:
            callbacks = kwargs.get('callbacks') or []
            kwargs['callbacks'] = callbacks + [
-                PatchXGBoostModelIO.__callback_cls(task=PatchXGBoostModelIO.__main_task)
+                PatchXGBoostModelIO.__callback_cls(task=PatchXGBoostModelIO._current_task)
            ]
        return original_fn(*args, **kwargs)
diff --git a/clearml/binding/hydra_bind.py b/clearml/binding/hydra_bind.py
index f3b4ecfb..5cd6f823 100644
--- a/clearml/binding/hydra_bind.py
+++ b/clearml/binding/hydra_bind.py
@@ -2,7 +2,7 @@ import io
 import sys
 from functools import partial
 
-from ..config import running_remotely, get_remote_task_id, DEV_TASK_NO_REUSE
+from ..config import running_remotely, DEV_TASK_NO_REUSE
 from ..debugging.log import LoggerRoot
 
 
@@ -46,6 +46,8 @@ class PatchHydra(object):
    def update_current_task(task):
        # set current Task before patching
        PatchHydra._current_task = task
+        if not task:
+            return
        if PatchHydra.patch_hydra():
            # check if we have an untracked
diff --git a/clearml/binding/hydra_bind.py b/clearml/binding/hydra_bind.py
index f3b4ecfb..5cd6f823 100644
--- a/clearml/binding/hydra_bind.py
+++ b/clearml/binding/hydra_bind.py
@@ -2,7 +2,7 @@ import io
 import sys
 from functools import partial
 
-from ..config import running_remotely, get_remote_task_id, DEV_TASK_NO_REUSE
+from ..config import running_remotely, DEV_TASK_NO_REUSE
 from ..debugging.log import LoggerRoot
 
 
@@ -46,6 +46,8 @@ class PatchHydra(object):
     def update_current_task(task):
         # set current Task before patching
         PatchHydra._current_task = task
+        if not task:
+            return
         if PatchHydra.patch_hydra():
             # check if we have an untracked state, store it.
             if PatchHydra._last_untracked_state.get('connect'):
@@ -66,9 +68,6 @@
         # noinspection PyBroadException
         try:
             if running_remotely():
-                if not PatchHydra._current_task:
-                    from ..task import Task
-                    PatchHydra._current_task = Task.get_task(task_id=get_remote_task_id())
                 # get the _parameter_allow_full_edit casted back to boolean
                 connected_config = dict()
                 connected_config[PatchHydra._parameter_allow_full_edit] = False
@@ -126,8 +125,9 @@
 
         if pre_app_task_init_call and not running_remotely():
             LoggerRoot.get_base_logger(PatchHydra).info(
-                'Task.init called outside of Hydra-App. For full Hydra multi-run support, '
-                'move the Task.init call into the Hydra-App main function')
+                "Task.init called outside of Hydra-App. For full Hydra multi-run support, "
+                "move the Task.init call into the Hydra-App main function"
+            )
 
         kwargs["config"] = config
         kwargs["task_function"] = partial(PatchHydra._patched_task_function, task_function,)
diff --git a/clearml/binding/joblib_bind.py b/clearml/binding/joblib_bind.py
index bcdd7008..547da002 100644
--- a/clearml/binding/joblib_bind.py
+++ b/clearml/binding/joblib_bind.py
@@ -77,6 +77,8 @@ class PatchedJoblib(object):
     @staticmethod
     def update_current_task(task):
         PatchedJoblib._current_task = task
+        if not task:
+            return
         PatchedJoblib.patch_joblib()
 
     @staticmethod
diff --git a/clearml/binding/jsonargs_bind.py b/clearml/binding/jsonargs_bind.py
index 7019d2bc..eeba275b 100644
--- a/clearml/binding/jsonargs_bind.py
+++ b/clearml/binding/jsonargs_bind.py
@@ -14,7 +14,7 @@ from ..utilities.proxy_object import flatten_dictionary
 
 class PatchJsonArgParse(object):
     _args = {}
-    _main_task = None
+    _current_task = None
     _args_sep = "/"
     _args_type = {}
     _commands_sep = "."
@@ -24,14 +24,19 @@
     __remote_task_params = {}
     __patched = False
 
+    @classmethod
+    def update_current_task(cls, task):
+        cls._current_task = task
+        if not task:
+            return
+        cls.patch(task)
+
     @classmethod
     def patch(cls, task):
         if ArgumentParser is None:
             return
 
-        if task:
-            cls._main_task = task
-            PatchJsonArgParse._update_task_args()
+        PatchJsonArgParse._update_task_args()
 
         if not cls.__patched:
             cls.__patched = True
@@ -39,14 +44,16 @@
     @classmethod
     def _update_task_args(cls):
-        if running_remotely() or not cls._main_task or not cls._args:
+        if running_remotely() or not cls._current_task or not cls._args:
             return
         args = {cls._section_name + cls._args_sep + k: v for k, v in cls._args.items()}
         args_type = {cls._section_name + cls._args_sep + k: v for k, v in cls._args_type.items()}
-        cls._main_task._set_parameters(args, __update=True, __parameters_types=args_type)
+        cls._current_task._set_parameters(args, __update=True, __parameters_types=args_type)
 
     @staticmethod
     def _parse_args(original_fn, obj, *args, **kwargs):
+        if not PatchJsonArgParse._current_task:
+            return original_fn(obj, *args, **kwargs)
         if len(args) == 1:
             kwargs["args"] = args[0]
             args = []
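Note: the new PatchJsonArgParse.update_current_task gives this patcher the same two-way entry point as the others, and the task.py hunks below rely on it: passing a task binds and patches, passing None unbinds. Usage, given an initialized Task object `task`:

    PatchJsonArgParse.update_current_task(task)  # bind: store the task, then patch
    PatchJsonArgParse.update_current_task(None)  # unbind: store None, return early
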
diff --git a/clearml/binding/matplotlib_bind.py b/clearml/binding/matplotlib_bind.py
index 1c7341b8..df7c9af3 100644
--- a/clearml/binding/matplotlib_bind.py
+++ b/clearml/binding/matplotlib_bind.py
@@ -207,12 +207,15 @@ class PatchedMatplotlib:
             PatchedMatplotlib._current_task = task
             PostImportHookPatching.add_on_import('matplotlib.pyplot', PatchedMatplotlib.patch_matplotlib)
             PostImportHookPatching.add_on_import('matplotlib.pylab', PatchedMatplotlib.patch_matplotlib)
-        elif PatchedMatplotlib.patch_matplotlib():
+        else:
+            PatchedMatplotlib.patch_matplotlib()
             PatchedMatplotlib._current_task = task
 
     @staticmethod
     def patched_imshow(*args, **kw):
         ret = PatchedMatplotlib._patched_original_imshow(*args, **kw)
+        if not PatchedMatplotlib._current_task:
+            return ret
         try:
             from matplotlib import _pylab_helpers
             # store on the plot that this is an imshow plot
@@ -227,6 +230,8 @@ class PatchedMatplotlib:
     @staticmethod
     def patched_savefig(self, *args, **kw):
         ret = PatchedMatplotlib._patched_original_savefig(self, *args, **kw)
+        if not PatchedMatplotlib._current_task:
+            return ret
         # noinspection PyBroadException
         try:
             fname = kw.get('fname') or args[0]
@@ -256,6 +261,8 @@ class PatchedMatplotlib:
 
     @staticmethod
     def patched_figure_show(self, *args, **kw):
+        if not PatchedMatplotlib._current_task:
+            return PatchedMatplotlib._patched_original_figure(self, *args, **kw)
         tid = threading._get_ident() if six.PY2 else threading.get_ident()
         if PatchedMatplotlib._recursion_guard.get(tid):
             # we are inside a gaurd do nothing
@@ -269,6 +276,8 @@ class PatchedMatplotlib:
 
     @staticmethod
     def patched_show(*args, **kw):
+        if not PatchedMatplotlib._current_task:
+            return PatchedMatplotlib._patched_original_plot(*args, **kw)
         tid = threading._get_ident() if six.PY2 else threading.get_ident()
         PatchedMatplotlib._recursion_guard[tid] = True
         # noinspection PyBroadException
diff --git a/clearml/task.py b/clearml/task.py
index 70f41ad4..47e4d56c 100644
--- a/clearml/task.py
+++ b/clearml/task.py
@@ -626,8 +626,7 @@ class Task(_Task):
             task.__register_at_exit(task._at_exit)
 
         # always patch OS forking because of ProcessPool and the alike
-        PatchOsFork.patch_fork()
-
+        PatchOsFork.patch_fork(task)
         if auto_connect_frameworks:
             def should_connect(*keys):
                 """
@@ -708,13 +707,15 @@ class Task(_Task):
                 make_deterministic(task.get_random_seed())
 
             if auto_connect_arg_parser:
-                EnvironmentBind.update_current_task(Task.__main_task)
+                EnvironmentBind.update_current_task(task)
+
+                PatchJsonArgParse.update_current_task(task)
 
                 # Patch ArgParser to be aware of the current task
-                argparser_update_currenttask(Task.__main_task)
-                PatchClick.patch(Task.__main_task)
-                PatchFire.patch(Task.__main_task)
-                PatchJsonArgParse.patch(Task.__main_task)
+                argparser_update_currenttask(task)
+
+                PatchClick.patch(task)
+                PatchFire.patch(task)
 
                 # set excluded arguments
                 if isinstance(auto_connect_arg_parser, dict):
@@ -1686,6 +1687,22 @@ class Task(_Task):
             # noinspection PyProtectedMember
             Logger._remove_std_logger()
 
+        # unbind everything
+        PatchHydra.update_current_task(None)
+        PatchedJoblib.update_current_task(None)
+        PatchedMatplotlib.update_current_task(None)
+        PatchAbsl.update_current_task(None)
+        TensorflowBinding.update_current_task(None)
+        PatchPyTorchModelIO.update_current_task(None)
+        PatchMegEngineModelIO.update_current_task(None)
+        PatchXGBoostModelIO.update_current_task(None)
+        PatchCatBoostModelIO.update_current_task(None)
+        PatchFastai.update_current_task(None)
+        PatchLIGHTgbmModelIO.update_current_task(None)
+        EnvironmentBind.update_current_task(None)
+        PatchJsonArgParse.update_current_task(None)
+        PatchOsFork.patch_fork(None)
+
     def delete(
         self,
         delete_artifacts_and_models=True,
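Note: the closing task.py hunk is the payoff of the renames and early returns above — because every patcher now treats update_current_task(None) as an unbind, the task shutdown path can neutralize all auto-logging hooks in one block. A hedged smoke-test sketch, assuming the unbind block sits in the flow that task.close() eventually triggers (Task.init and close are real clearml calls; the comments describe expected behavior under this diff, not asserted output):

    from clearml import Task

    task = Task.init(project_name='examples', task_name='bind-unbind')
    # framework save/load and plotting calls made here are captured by the hooks
    task.close()
    # after the unbind block has run, the hooks stay installed but every
    # patched wrapper falls through to the original framework function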