Unbind all patches when Task.close() is called

allegroai 2022-05-12 23:47:58 +03:00
parent 3c396b1f7e
commit 5dff9dd404
16 changed files with 350 additions and 227 deletions
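Every binding below follows the same three-step change: the task attribute is renamed to a shared `_current_task` name (from `_task`, `_main_task`, or `__main_task`), `update_current_task()` installs its monkey-patches only once behind a `__patched` guard and accepts `task=None` to unbind, and every patched entry point falls back to the original function when no task is bound. A minimal sketch of the pattern, for orientation only (`PatchExample` and its members are illustrative, not actual ClearML classes):

class PatchExample(object):
    _current_task = None  # single shared name, replacing _task / _main_task
    __patched = False     # monkey-patch at most once per process

    @classmethod
    def update_current_task(cls, task):
        # called both to bind a new task and, with task=None, to unbind
        # (which is what Task.close() relies on)
        cls._current_task = task
        if not task:
            return
        if not cls.__patched:
            cls.__patched = True
            cls._patch()

    @classmethod
    def _patch(cls):
        pass  # install the framework monkey-patches here, exactly once

    @staticmethod
    def _patched_fn(original_fn, *args, **kwargs):
        # every patched entry point checks the task first, so after
        # unbinding it degrades to a plain pass-through to the original
        if not PatchExample._current_task:
            return original_fn(*args, **kwargs)
        # ... report to PatchExample._current_task, then:
        return original_fn(*args, **kwargs)

Because the patches themselves stay installed, unbinding is just clearing `_current_task`; no teardown of the patched functions is needed.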

View File

@@ -8,12 +8,15 @@ from ..config import running_remotely
 class PatchAbsl(object):
     _original_DEFINE_flag = None
     _original_FLAGS_parse_call = None
-    _task = None
+    _current_task = None
+    __patched = False
 
     @classmethod
-    def update_current_task(cls, current_task):
-        cls._task = current_task
-        cls._patch_absl()
+    def update_current_task(cls, task):
+        cls._current_task = task
+        if not cls.__patched:
+            cls._patch_absl()
+            cls.__patched = True
 
     @classmethod
     def _patch_absl(cls):
@@ -53,7 +56,7 @@ class PatchAbsl(object):
 
     @staticmethod
     def _patched_define_flag(*args, **kwargs):
-        if not PatchAbsl._task or not PatchAbsl._original_DEFINE_flag:
+        if not PatchAbsl._current_task or not PatchAbsl._original_DEFINE_flag:
            if PatchAbsl._original_DEFINE_flag:
                return PatchAbsl._original_DEFINE_flag(*args, **kwargs)
            else:
@@ -73,7 +76,7 @@ class PatchAbsl(object):
            # noinspection PyBroadException
            try:
                if param_name and flag:
-                    param_dict = PatchAbsl._task._arguments.copy_to_dict(
+                    param_dict = PatchAbsl._current_task._arguments.copy_to_dict(
                        {param_name: flag.value}, prefix=_Arguments._prefix_tf_defines)
                    flag.value = param_dict.get(param_name, flag.value)
            except Exception:
@@ -82,13 +85,17 @@ class PatchAbsl(object):
        else:
            if flag and param_name:
                value = flag.value
-                PatchAbsl._task.update_parameters({_Arguments._prefix_tf_defines + param_name: value}, )
+                PatchAbsl._current_task.update_parameters(
+                    {_Arguments._prefix_tf_defines + param_name: value},
+                )
        ret = PatchAbsl._original_DEFINE_flag(*args, **kwargs)
        return ret
 
     @staticmethod
     def _patched_FLAGS_parse_call(self, *args, **kwargs):
        ret = PatchAbsl._original_FLAGS_parse_call(self, *args, **kwargs)
+        if not PatchAbsl._current_task:
+            return ret
        # noinspection PyBroadException
        try:
            PatchAbsl._update_current_flags(self)
@@ -98,13 +105,13 @@ class PatchAbsl(object):
 
     @classmethod
     def _update_current_flags(cls, FLAGS):
-        if not cls._task:
+        if not cls._current_task:
            return
        # noinspection PyBroadException
        try:
            if running_remotely():
                param_dict = dict((k, FLAGS[k].value) for k in FLAGS)
-                param_dict = cls._task._arguments.copy_to_dict(param_dict, prefix=_Arguments._prefix_tf_defines)
+                param_dict = cls._current_task._arguments.copy_to_dict(param_dict, prefix=_Arguments._prefix_tf_defines)
                for k, v in param_dict.items():
                    # noinspection PyBroadException
                    try:
@@ -127,7 +134,7 @@ class PatchAbsl(object):
                param_types = dict([(k, FLAGS[k].flag_type() or None) for k in FLAGS])
            except Exception:
                param_types = None
-            cls._task._arguments.copy_from_dict(
+            cls._current_task._arguments.copy_from_dict(
                parameters,
                prefix=_Arguments._prefix_tf_defines,
                descriptions=descriptions, param_types=param_types,

View File

@@ -16,7 +16,7 @@ class PatchClick:
     _num_commands = 0
     _command_type = 'click.Command'
     _section_name = 'Args'
-    _main_task = None
+    _current_task = None
     __remote_task_params = None
     __remote_task_params_dict = {}
     __patched = False
@@ -26,8 +26,8 @@ class PatchClick:
        if Command is None:
            return
 
+        cls._current_task = task
        if task:
-            cls._main_task = task
            PatchClick._update_task_args()
 
        if not cls.__patched:
@@ -53,11 +53,11 @@ class PatchClick:
 
     @classmethod
     def _update_task_args(cls):
-        if running_remotely() or not cls._main_task or not cls._args:
+        if running_remotely() or not cls._current_task or not cls._args:
            return
        param_val, param_types, param_desc = cls.args()
        # noinspection PyProtectedMember
-        cls._main_task._set_parameters(
+        cls._current_task._set_parameters(
            param_val,
            __update=True,
            __parameters_descriptions=param_desc,

View File

@@ -1,4 +1,5 @@
 import os
+from time import sleep
 
 import six
@@ -7,26 +8,29 @@ from ..utilities.process.mp import BackgroundMonitor
 class EnvironmentBind(object):
-    _task = None
+    _current_task = None
     _environment_section = 'Environment'
+    __patched = False
 
     @classmethod
-    def update_current_task(cls, current_task):
-        cls._task = current_task
+    def update_current_task(cls, task):
+        cls._current_task = task
        # noinspection PyBroadException
        try:
-            cls._bind_environment()
+            if not cls.__patched:
+                cls.__patched = True
+                cls._bind_environment()
        except Exception:
            pass
 
     @classmethod
     def _bind_environment(cls):
-        if not cls._task:
+        if not cls._current_task:
            return
 
        # get ENVIRONMENT and put it into the OS environment
        if running_remotely():
-            params = cls._task.get_parameters_as_dict()
+            params = cls._current_task.get_parameters_as_dict()
            if params and cls._environment_section in params:
                # put back into os:
                os.environ.update(params[cls._environment_section])
@@ -55,14 +59,18 @@ class EnvironmentBind(object):
            elif match in os.environ:
                env_param.update({match: os.environ.get(match)})
 
        # store os environments
-        cls._task.connect(env_param, cls._environment_section)
+        cls._current_task.connect(env_param, cls._environment_section)
 
 
 class PatchOsFork(object):
     _original_fork = None
+    _current_task = None
 
     @classmethod
-    def patch_fork(cls):
+    def patch_fork(cls, task):
+        cls._current_task = task
+        if not task:
+            return
        # noinspection PyBroadException
        try:
            # only once
@@ -88,11 +96,13 @@ class PatchOsFork(object):
                Task._wait_for_deferred(task)
 
        ret = PatchOsFork._original_fork(*args, **kwargs)
+        if not PatchOsFork._current_task:
+            return ret
        # Make sure the new process stdout is logged
        if not ret:
            # force creating a Task
            task = Task.current_task()
-            if task is None:
+            if not task:
                return ret
 
            # # Hack: now make sure we setup the reporter threads (Log+Reporter)
@@ -106,7 +116,7 @@ class PatchOsFork(object):
            # just make sure we flush the internal state (the at exist caught by the external signal does the rest
            # in theory we should not have to do any of that, but for some reason if we do not
            # the signal is never caught by the signal call backs, not sure why....
+            sleep(0.1)
 
            # Since at_exist handlers do not work on forked processes, we have to manually call them here
            if task:
                try:

View File

@@ -20,7 +20,7 @@ class PatchFire:
     _section_name = "Args"
     _args_sep = "/"
     _commands_sep = "."
-    _main_task = None
+    _current_task = None
     __remote_task_params = None
     __remote_task_params_dict = {}
     __patched = False
@@ -38,8 +38,8 @@ class PatchFire:
        if fire is None:
            return
 
+        cls._current_task = task
        if task:
-            cls._main_task = task
            cls._update_task_args()
 
        if not cls.__patched:
@@ -53,7 +53,7 @@ class PatchFire:
 
     @classmethod
     def _update_task_args(cls):
-        if running_remotely() or not cls._main_task:
+        if running_remotely() or not cls._current_task:
            return
        args = {}
        parameters_types = {}
@@ -118,7 +118,7 @@ class PatchFire:
        parameters_types = {**parameters_types, **unused_paramenters_types}
        # noinspection PyProtectedMember
-        cls._main_task._set_parameters(
+        cls._current_task._set_parameters(
            args,
            __update=True,
            __parameters_types=parameters_types,
@@ -126,25 +126,25 @@ class PatchFire:
 
     @staticmethod
     def __Fire(original_fn, component, args_, parsed_flag_args, context, name, *args, **kwargs):  # noqa
-        if running_remotely():
-            command = PatchFire._load_task_params()
-            if command is not None:
-                replaced_args = command.split(PatchFire._commands_sep)
-            else:
-                replaced_args = []
-            for param in PatchFire.__remote_task_params[PatchFire._section_name].values():
-                if command is not None and param.type == PatchFire._command_arg_type_template % command:
-                    replaced_args.append("--" + param.name[len(command + PatchFire._args_sep):])
-                    value = PatchFire.__remote_task_params_dict[param.name]
-                    if len(value) > 0:
-                        replaced_args.append(value)
-                if param.type == PatchFire._shared_arg_type:
-                    replaced_args.append("--" + param.name)
-                    value = PatchFire.__remote_task_params_dict[param.name]
-                    if len(value) > 0:
-                        replaced_args.append(value)
-            return original_fn(component, replaced_args, parsed_flag_args, context, name, *args, **kwargs)
-        return original_fn(component, args_, parsed_flag_args, context, name, *args, **kwargs)
+        if not running_remotely():
+            return original_fn(component, args_, parsed_flag_args, context, name, *args, **kwargs)
+        command = PatchFire._load_task_params()
+        if command is not None:
+            replaced_args = command.split(PatchFire._commands_sep)
+        else:
+            replaced_args = []
+        for param in PatchFire.__remote_task_params[PatchFire._section_name].values():
+            if command is not None and param.type == PatchFire._command_arg_type_template % command:
+                replaced_args.append("--" + param.name[len(command + PatchFire._args_sep):])
+                value = PatchFire.__remote_task_params_dict[param.name]
+                if len(value) > 0:
+                    replaced_args.append(value)
+            if param.type == PatchFire._shared_arg_type:
+                replaced_args.append("--" + param.name)
+                value = PatchFire.__remote_task_params_dict[param.name]
+                if len(value) > 0:
+                    replaced_args.append(value)
+        return original_fn(component, replaced_args, parsed_flag_args, context, name, *args, **kwargs)
 
     @staticmethod
     def __CallAndUpdateTrace(  # noqa

View File

@@ -11,13 +11,15 @@ from ...model import Framework
 class PatchCatBoostModelIO(PatchBaseModelIO):
-    __main_task = None
+    _current_task = None
     __patched = None
     __callback_cls = None
 
     @staticmethod
     def update_current_task(task, **kwargs):
-        PatchCatBoostModelIO.__main_task = task
+        PatchCatBoostModelIO._current_task = task
+        if not task:
+            return
        PatchCatBoostModelIO._patch_model_io()
        PostImportHookPatching.add_on_import("catboost", PatchCatBoostModelIO._patch_model_io)
@@ -38,16 +40,16 @@ class PatchCatBoostModelIO(PatchBaseModelIO):
            CatBoost.fit = _patched_call(CatBoost.fit, PatchCatBoostModelIO._fit)
            CatBoostClassifier.fit = _patched_call(CatBoostClassifier.fit, PatchCatBoostModelIO._fit)
            CatBoostRegressor.fit = _patched_call(CatBoostRegressor.fit, PatchCatBoostModelIO._fit)
-            CatBoostRanker.fit = _patched_call(CatBoostRegressor.fit, PatchCatBoostModelIO._fit)
+            CatBoostRanker.fit = _patched_call(CatBoostRanker.fit, PatchCatBoostModelIO._fit)
        except Exception as e:
-            logger = PatchCatBoostModelIO.__main_task.get_logger()
+            logger = PatchCatBoostModelIO._current_task.get_logger()
            logger.report_text("Failed patching Catboost. Exception is: '" + str(e) + "'")
 
     @staticmethod
     def _save(original_fn, obj, f, *args, **kwargs):
        # see https://catboost.ai/en/docs/concepts/python-reference_catboost_save_model
        ret = original_fn(obj, f, *args, **kwargs)
-        if not PatchCatBoostModelIO.__main_task:
+        if not PatchCatBoostModelIO._current_task:
            return ret
        if isinstance(f, six.string_types):
            filename = f
@@ -60,13 +62,16 @@ class PatchCatBoostModelIO(PatchBaseModelIO):
        except Exception:
            model_name = None
        WeightsFileHandler.create_output_model(
-            obj, filename, Framework.catboost, PatchCatBoostModelIO.__main_task, singlefile=True, model_name=model_name
+            obj, filename, Framework.catboost, PatchCatBoostModelIO._current_task, singlefile=True, model_name=model_name
        )
        return ret
 
     @staticmethod
     def _load(original_fn, f, *args, **kwargs):
        # see https://catboost.ai/en/docs/concepts/python-reference_catboost_load_model
+        if not PatchCatBoostModelIO._current_task:
+            return original_fn(f, *args, **kwargs)
        if isinstance(f, six.string_types):
            filename = f
        elif len(args) >= 1 and isinstance(args[0], six.string_types):
@@ -74,13 +79,10 @@ class PatchCatBoostModelIO(PatchBaseModelIO):
        else:
            filename = None
 
-        if not PatchCatBoostModelIO.__main_task:
-            return original_fn(f, *args, **kwargs)
-
        # register input model
        empty = _Empty()
        model = original_fn(f, *args, **kwargs)
-        WeightsFileHandler.restore_weights_file(empty, filename, Framework.catboost, PatchCatBoostModelIO.__main_task)
+        WeightsFileHandler.restore_weights_file(empty, filename, Framework.catboost, PatchCatBoostModelIO._current_task)
        if empty.trains_in_model:
            # noinspection PyBroadException
            try:
@@ -91,13 +93,15 @@ class PatchCatBoostModelIO(PatchBaseModelIO):
 
     @staticmethod
     def _fit(original_fn, obj, *args, **kwargs):
+        if not PatchCatBoostModelIO._current_task:
+            return original_fn(obj, *args, **kwargs)
        callbacks = kwargs.get("callbacks") or []
-        kwargs["callbacks"] = callbacks + [PatchCatBoostModelIO.__callback_cls(task=PatchCatBoostModelIO.__main_task)]
+        kwargs["callbacks"] = callbacks + [PatchCatBoostModelIO.__callback_cls(task=PatchCatBoostModelIO._current_task)]
        # noinspection PyBroadException
        try:
            return original_fn(obj, *args, **kwargs)
        except Exception:
-            logger = PatchCatBoostModelIO.__main_task.get_logger()
+            logger = PatchCatBoostModelIO._current_task.get_logger()
            logger.report_text(
                "Catboost metrics logging is not supported for GPU. "
                "See https://github.com/catboost/catboost/issues/1792"

View File

@@ -25,11 +25,9 @@ class PatchFastai(object):
        try:
            if Version(fastai.__version__) < Version("2.0.0"):
                PatchFastaiV1.update_current_task(task)
-                PatchFastaiV1.patch_model_callback()
                PostImportHookPatching.add_on_import("fastai", PatchFastaiV1.patch_model_callback)
            else:
                PatchFastaiV2.update_current_task(task)
-                PatchFastaiV2.patch_model_callback()
                PostImportHookPatching.add_on_import("fastai", PatchFastaiV2.patch_model_callback)
        except Exception:
            pass
@@ -38,11 +36,17 @@ class PatchFastai(object):
 
 class PatchFastaiV1(object):
     __metrics_names = {}
     __gradient_hist_helpers = {}
-    _main_task = None
+    _current_task = None
+    __patched = False
 
     @staticmethod
     def update_current_task(task, **_):
-        PatchFastaiV1._main_task = task
+        PatchFastaiV1._current_task = task
+        if not task:
+            return
+        if not PatchFastaiV1.__patched:
+            PatchFastaiV1.__patched = True
+            PatchFastaiV1.patch_model_callback()
 
     @staticmethod
     def patch_model_callback():
@@ -52,7 +56,6 @@ class PatchFastaiV1(object):
        try:
            from fastai.basic_train import Recorder
-
            Recorder.on_batch_end = _patched_call(Recorder.on_batch_end, PatchFastaiV1._on_batch_end)
            Recorder.on_backward_end = _patched_call(Recorder.on_backward_end, PatchFastaiV1._on_backward_end)
            Recorder.on_epoch_end = _patched_call(Recorder.on_epoch_end, PatchFastaiV1._on_epoch_end)
@@ -65,7 +68,7 @@ class PatchFastaiV1(object):
     @staticmethod
     def _on_train_begin(original_fn, recorder, *args, **kwargs):
        original_fn(recorder, *args, **kwargs)
-        if not PatchFastaiV1._main_task:
+        if not PatchFastaiV1._current_task:
            return
        # noinspection PyBroadException
        try:
@@ -84,7 +87,7 @@ class PatchFastaiV1(object):
        original_fn(recorder, *args, **kwargs)
-        if not PatchFastaiV1._main_task:
+        if not PatchFastaiV1._current_task:
            return
        # noinspection PyBroadException
@@ -112,7 +115,7 @@ class PatchFastaiV1(object):
                min_gradient=gradient_stats[:, 5].min(),
            )
-            logger = PatchFastaiV1._main_task.get_logger()
+            logger = PatchFastaiV1._current_task.get_logger()
            iteration = kwargs.get("iteration", 0)
            for name, val in stats_report.items():
                logger.report_scalar(title="model_stats_gradients", series=name, value=val, iteration=iteration)
@@ -122,26 +125,26 @@ class PatchFastaiV1(object):
 
     @staticmethod
     def _on_epoch_end(original_fn, recorder, *args, **kwargs):
        original_fn(recorder, *args, **kwargs)
-        if not PatchFastaiV1._main_task:
+        if not PatchFastaiV1._current_task:
            return
        # noinspection PyBroadException
        try:
-            logger = PatchFastaiV1._main_task.get_logger()
+            logger = PatchFastaiV1._current_task.get_logger()
            iteration = kwargs.get("iteration")
            for series, value in zip(
                PatchFastaiV1.__metrics_names[id(recorder)],
                [kwargs.get("smooth_loss")] + kwargs.get("last_metrics", []),
            ):
                logger.report_scalar(title="metrics", series=series, value=value, iteration=iteration)
-            PatchFastaiV1._main_task.flush()
+            PatchFastaiV1._current_task.flush()
        except Exception:
            pass
 
     @staticmethod
     def _on_batch_end(original_fn, recorder, *args, **kwargs):
        original_fn(recorder, *args, **kwargs)
-        if not PatchFastaiV1._main_task:
+        if not PatchFastaiV1._current_task:
            return
        # noinspection PyBroadException
@@ -150,7 +153,7 @@ class PatchFastaiV1(object):
            if iteration == 0 or not kwargs.get("train"):
                return
-            logger = PatchFastaiV1._main_task.get_logger()
+            logger = PatchFastaiV1._current_task.get_logger()
            logger.report_scalar(
                title="metrics",
                series="train_loss",
@@ -174,11 +177,17 @@ class PatchFastaiV1(object):
 
 class PatchFastaiV2(object):
-    _main_task = None
+    _current_task = None
+    __patched = False
 
     @staticmethod
     def update_current_task(task, **_):
-        PatchFastaiV2._main_task = task
+        PatchFastaiV2._current_task = task
+        if not task:
+            return
+        if not PatchFastaiV2.__patched:
+            PatchFastaiV2.__patched = True
+            PatchFastaiV2.patch_model_callback()
 
     @staticmethod
     def patch_model_callback():
@@ -212,13 +221,15 @@ class PatchFastaiV2(object):
                self.logger = noop
                self.__id = str(PatchFastaiV2.PatchFastaiCallbacks.__id)
                PatchFastaiV2.PatchFastaiCallbacks.__id += 1
-                self.__gradient_hist_helper = WeightsGradientHistHelper(PatchFastaiV2._main_task.get_logger())
+                self.__gradient_hist_helper = WeightsGradientHistHelper(PatchFastaiV2._current_task.get_logger())
 
            def after_batch(self):
                # noinspection PyBroadException
                try:
                    super().after_batch()  # noqa
-                    logger = PatchFastaiV2._main_task.get_logger()
+                    if not PatchFastaiV2._current_task:
+                        return
+                    logger = PatchFastaiV2._current_task.get_logger()
                    if not self.training:  # noqa
                        return
                    self.__train_iter += 1
@@ -254,7 +265,9 @@ class PatchFastaiV2(object):
                # noinspection PyBroadException
                try:
                    super().after_epoch()  # noqa
-                    logger = PatchFastaiV2._main_task.get_logger()
+                    if not PatchFastaiV2._current_task:
+                        return
+                    logger = PatchFastaiV2._current_task.get_logger()
                    for metric in self._valid_mets:  # noqa
                        logger.report_scalar(
                            title="metrics_" + self.__id,
@@ -270,7 +283,9 @@ class PatchFastaiV2(object):
                try:
                    if hasattr(fastai.learner.Recorder, "before_step"):
                        super().before_step()  # noqa
-                    logger = PatchFastaiV2._main_task.get_logger()
+                    if not PatchFastaiV2._current_task:
+                        return
+                    logger = PatchFastaiV2._current_task.get_logger()
                    gradients = [
                        x.grad.clone().detach().cpu() for x in self.learn.model.parameters() if x.grad is not None
                    ]  # noqa

View File

@@ -11,12 +11,14 @@ from ...model import Framework
 class PatchLIGHTgbmModelIO(PatchBaseModelIO):
-    __main_task = None
+    _current_task = None
     __patched = None
 
     @staticmethod
     def update_current_task(task, **kwargs):
-        PatchLIGHTgbmModelIO.__main_task = task
+        PatchLIGHTgbmModelIO._current_task = task
+        if not task:
+            return
        PatchLIGHTgbmModelIO._patch_model_io()
        PostImportHookPatching.add_on_import('lightgbm', PatchLIGHTgbmModelIO._patch_model_io)
@@ -43,7 +45,7 @@ class PatchLIGHTgbmModelIO(PatchBaseModelIO):
     @staticmethod
     def _save(original_fn, obj, f, *args, **kwargs):
        ret = original_fn(obj, f, *args, **kwargs)
-        if not PatchLIGHTgbmModelIO.__main_task:
+        if not PatchLIGHTgbmModelIO._current_task:
            return ret
 
        if isinstance(f, six.string_types):
@@ -64,12 +66,15 @@ class PatchLIGHTgbmModelIO(PatchBaseModelIO):
            model_name = Path(filename).stem
        except Exception:
            model_name = None
-        WeightsFileHandler.create_output_model(obj, filename, Framework.lightgbm, PatchLIGHTgbmModelIO.__main_task,
+        WeightsFileHandler.create_output_model(obj, filename, Framework.lightgbm, PatchLIGHTgbmModelIO._current_task,
                                               singlefile=True, model_name=model_name)
        return ret
 
     @staticmethod
     def _load(original_fn, model_file, *args, **kwargs):
+        if not PatchLIGHTgbmModelIO._current_task:
+            return original_fn(model_file, *args, **kwargs)
        if isinstance(model_file, six.string_types):
            filename = model_file
        elif hasattr(model_file, 'name'):
@@ -79,21 +84,18 @@ class PatchLIGHTgbmModelIO(PatchBaseModelIO):
        else:
            filename = None
 
-        if not PatchLIGHTgbmModelIO.__main_task:
-            return original_fn(model_file, *args, **kwargs)
-
        # register input model
        empty = _Empty()
        # Hack: disabled
        if False and running_remotely():
            filename = WeightsFileHandler.restore_weights_file(empty, filename, Framework.xgboost,
-                                                               PatchLIGHTgbmModelIO.__main_task)
+                                                               PatchLIGHTgbmModelIO._current_task)
            model = original_fn(model_file=filename or model_file, *args, **kwargs)
        else:
            # try to load model before registering, in case we fail
            model = original_fn(model_file=model_file, *args, **kwargs)
            WeightsFileHandler.restore_weights_file(empty, filename, Framework.lightgbm,
-                                                    PatchLIGHTgbmModelIO.__main_task)
+                                                    PatchLIGHTgbmModelIO._current_task)
 
        if empty.trains_in_model:
            # noinspection PyBroadException
@@ -110,7 +112,7 @@ class PatchLIGHTgbmModelIO(PatchBaseModelIO):
            # logging the results to scalars section
            # noinspection PyBroadException
            try:
-                logger = PatchLIGHTgbmModelIO.__main_task.get_logger()
+                logger = PatchLIGHTgbmModelIO._current_task.get_logger()
                iteration = env.iteration
                for data_title, data_series, value, _ in env.evaluation_result_list:
                    logger.report_scalar(title=data_title, series=data_series, value="{:.6f}".format(value),
@@ -121,12 +123,12 @@ class PatchLIGHTgbmModelIO(PatchBaseModelIO):
        kwargs.setdefault("callbacks", []).append(trains_lightgbm_callback())
 
        ret = original_fn(*args, **kwargs)
-        if not PatchLIGHTgbmModelIO.__main_task:
+        if not PatchLIGHTgbmModelIO._current_task:
            return ret
 
        params = args[0] if args else kwargs.get('params', {})
        for k, v in params.items():
            if isinstance(v, set):
                params[k] = list(v)
        if params:
-            PatchLIGHTgbmModelIO.__main_task.connect(params)
+            PatchLIGHTgbmModelIO._current_task.connect(params)
        return ret

View File

@@ -10,14 +10,14 @@ from ...model import Framework
 class PatchMegEngineModelIO(PatchBaseModelIO):
-    __main_task = None
+    _current_task = None
     __patched = None
-    # __patched_lightning = None
 
     @staticmethod
     def update_current_task(task, **_):
-        PatchMegEngineModelIO.__main_task = task
+        PatchMegEngineModelIO._current_task = task
+        if not task:
+            return
        PatchMegEngineModelIO._patch_model_io()
        PostImportHookPatching.add_on_import(
            'megengine', PatchMegEngineModelIO._patch_model_io
@@ -58,7 +58,7 @@ class PatchMegEngineModelIO(PatchBaseModelIO):
        ret = original_fn(obj, f, *args, **kwargs)
        # if there is no main task or this is a nested call
-        if not PatchMegEngineModelIO.__main_task:
+        if not PatchMegEngineModelIO._current_task:
            return ret
 
        # noinspection PyBroadException
@@ -93,7 +93,7 @@ class PatchMegEngineModelIO(PatchBaseModelIO):
            WeightsFileHandler.create_output_model(
                obj, filename, Framework.megengine,
-                PatchMegEngineModelIO.__main_task,
+                PatchMegEngineModelIO._current_task,
                singlefile=True, model_name=model_name,
            )
@@ -102,7 +102,7 @@ class PatchMegEngineModelIO(PatchBaseModelIO):
     @staticmethod
     def _load(original_fn, f, *args, **kwargs):
        # if there is no main task or this is a nested call
-        if not PatchMegEngineModelIO.__main_task:
+        if not PatchMegEngineModelIO._current_task:
            return original_fn(f, *args, **kwargs)
 
        # noinspection PyBroadException
@@ -124,7 +124,7 @@ class PatchMegEngineModelIO(PatchBaseModelIO):
        model = original_fn(f, *args, **kwargs)
        WeightsFileHandler.restore_weights_file(
            empty, filename, Framework.megengine,
-            PatchMegEngineModelIO.__main_task
+            PatchMegEngineModelIO._current_task
        )
 
        if empty.trains_in_model:
@@ -139,7 +139,7 @@ class PatchMegEngineModelIO(PatchBaseModelIO):
     @staticmethod
     def _load_from_obj(original_fn, obj, f, *args, **kwargs):
        # if there is no main task or this is a nested call
-        if not PatchMegEngineModelIO.__main_task:
+        if not PatchMegEngineModelIO._current_task:
            return original_fn(obj, f, *args, **kwargs)
 
        # noinspection PyBroadException
@@ -161,7 +161,7 @@ class PatchMegEngineModelIO(PatchBaseModelIO):
        model = original_fn(obj, f, *args, **kwargs)
        WeightsFileHandler.restore_weights_file(
            empty, filename, Framework.megengine,
-            PatchMegEngineModelIO.__main_task,
+            PatchMegEngineModelIO._current_task,
        )
 
        if empty.trains_in_model:

View File

@@ -11,13 +11,15 @@ from ...model import Framework
 class PatchPyTorchModelIO(PatchBaseModelIO):
-    __main_task = None
+    _current_task = None
     __patched = None
     __patched_lightning = None
 
     @staticmethod
     def update_current_task(task, **_):
-        PatchPyTorchModelIO.__main_task = task
+        PatchPyTorchModelIO._current_task = task
+        if not task:
+            return
        PatchPyTorchModelIO._patch_model_io()
        PatchPyTorchModelIO._patch_lightning_io()
        PostImportHookPatching.add_on_import('torch', PatchPyTorchModelIO._patch_model_io)
@@ -112,7 +114,7 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
        ret = original_fn(obj, f, *args, **kwargs)
        # if there is no main task or this is a nested call
-        if not PatchPyTorchModelIO.__main_task:
+        if not PatchPyTorchModelIO._current_task:
            return ret
 
        # pytorch-lightning check if rank is zero
@@ -154,14 +156,14 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
            model_name = None
 
        WeightsFileHandler.create_output_model(
-            obj, filename, Framework.pytorch, PatchPyTorchModelIO.__main_task, singlefile=True, model_name=model_name)
+            obj, filename, Framework.pytorch, PatchPyTorchModelIO._current_task, singlefile=True, model_name=model_name)
 
        return ret
 
     @staticmethod
     def _load(original_fn, f, *args, **kwargs):
        # if there is no main task or this is a nested call
-        if not PatchPyTorchModelIO.__main_task:
+        if not PatchPyTorchModelIO._current_task:
            return original_fn(f, *args, **kwargs)
 
        # noinspection PyBroadException
@@ -182,13 +184,13 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
        # Hack: disabled
        if False and running_remotely():
            filename = WeightsFileHandler.restore_weights_file(
-                empty, filename, Framework.pytorch, PatchPyTorchModelIO.__main_task)
+                empty, filename, Framework.pytorch, PatchPyTorchModelIO._current_task)
            model = original_fn(filename or f, *args, **kwargs)
        else:
            # try to load model before registering, in case we fail
            model = original_fn(f, *args, **kwargs)
            WeightsFileHandler.restore_weights_file(
-                empty, filename, Framework.pytorch, PatchPyTorchModelIO.__main_task)
+                empty, filename, Framework.pytorch, PatchPyTorchModelIO._current_task)
 
        if empty.trains_in_model:
            # noinspection PyBroadException
@@ -202,7 +204,7 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
     @staticmethod
     def _load_from_obj(original_fn, obj, f, *args, **kwargs):
        # if there is no main task or this is a nested call
-        if not PatchPyTorchModelIO.__main_task:
+        if not PatchPyTorchModelIO._current_task:
            return original_fn(obj, f, *args, **kwargs)
 
        # noinspection PyBroadException
@@ -223,13 +225,13 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
        # Hack: disabled
        if False and running_remotely():
            filename = WeightsFileHandler.restore_weights_file(
-                empty, filename, Framework.pytorch, PatchPyTorchModelIO.__main_task)
+                empty, filename, Framework.pytorch, PatchPyTorchModelIO._current_task)
            model = original_fn(obj, filename or f, *args, **kwargs)
        else:
            # try to load model before registering, in case we fail
            model = original_fn(obj, f, *args, **kwargs)
            WeightsFileHandler.restore_weights_file(
-                empty, filename, Framework.pytorch, PatchPyTorchModelIO.__main_task)
+                empty, filename, Framework.pytorch, PatchPyTorchModelIO._current_task)
 
        if empty.trains_in_model:
            # noinspection PyBroadException

View File

@ -232,7 +232,7 @@ class EventTrainsWriter(object):
TF SummaryWriter implementation that converts the tensorboard's summary into TF SummaryWriter implementation that converts the tensorboard's summary into
ClearML events and reports the events (metrics) for an ClearML task (logger). ClearML events and reports the events (metrics) for an ClearML task (logger).
""" """
__main_task = None _current_task = None
__report_hparams = True __report_hparams = True
_add_lock = threading.RLock() _add_lock = threading.RLock()
_series_name_lookup = {} _series_name_lookup = {}
@ -641,7 +641,7 @@ class EventTrainsWriter(object):
session_start_info = parse_session_start_info_plugin_data(content) session_start_info = parse_session_start_info_plugin_data(content)
session_start_info = MessageToDict(session_start_info) session_start_info = MessageToDict(session_start_info)
hparams = session_start_info["hparams"] hparams = session_start_info["hparams"]
EventTrainsWriter.__main_task.update_parameters( EventTrainsWriter._current_task.update_parameters(
{"TB_hparams/{}".format(k): v for k, v in hparams.items()} {"TB_hparams/{}".format(k): v for k, v in hparams.items()}
) )
except Exception: except Exception:
@ -844,13 +844,13 @@ class EventTrainsWriter(object):
@classmethod @classmethod
def update_current_task(cls, task, **kwargs): def update_current_task(cls, task, **kwargs):
cls.__report_hparams = kwargs.get('report_hparams', False) cls.__report_hparams = kwargs.get('report_hparams', False)
if cls.__main_task != task: if cls._current_task != task:
with cls._add_lock: with cls._add_lock:
cls._series_name_lookup = {} cls._series_name_lookup = {}
cls._title_series_writers_lookup = {} cls._title_series_writers_lookup = {}
cls._event_writers_id_to_logdir = {} cls._event_writers_id_to_logdir = {}
cls._title_series_wraparound_counter = {} cls._title_series_wraparound_counter = {}
cls.__main_task = task cls._current_task = task
# noinspection PyCallingNonCallable # noinspection PyCallingNonCallable
@ -905,9 +905,10 @@ class ProxyEventsWriter(object):
# noinspection PyPep8Naming # noinspection PyPep8Naming
class PatchSummaryToEventTransformer(object): class PatchSummaryToEventTransformer(object):
__main_task = None _current_task = None
__original_getattribute = None __original_getattribute = None
__original_getattributeX = None __original_getattributeX = None
__patched = False
_original_add_event = None _original_add_event = None
_original_add_eventT = None _original_add_eventT = None
_original_add_eventX = None _original_add_eventX = None
@ -930,15 +931,19 @@ class PatchSummaryToEventTransformer(object):
@staticmethod @staticmethod
def update_current_task(task, **kwargs): def update_current_task(task, **kwargs):
PatchSummaryToEventTransformer.defaults_dict.update(kwargs) PatchSummaryToEventTransformer.defaults_dict.update(kwargs)
PatchSummaryToEventTransformer.__main_task = task PatchSummaryToEventTransformer._current_task = task
# make sure we patched the SummaryToEventTransformer if not task:
PatchSummaryToEventTransformer._patch_summary_to_event_transformer() return
PostImportHookPatching.add_on_import('tensorflow', if not PatchSummaryToEventTransformer.__patched:
PatchSummaryToEventTransformer._patch_summary_to_event_transformer) PatchSummaryToEventTransformer.__patched = True
PostImportHookPatching.add_on_import('torch', # make sure we patched the SummaryToEventTransformer
PatchSummaryToEventTransformer._patch_summary_to_event_transformer) PatchSummaryToEventTransformer._patch_summary_to_event_transformer()
PostImportHookPatching.add_on_import('tensorboardX', PostImportHookPatching.add_on_import('tensorflow',
PatchSummaryToEventTransformer._patch_summary_to_event_transformer) PatchSummaryToEventTransformer._patch_summary_to_event_transformer)
PostImportHookPatching.add_on_import('torch',
PatchSummaryToEventTransformer._patch_summary_to_event_transformer)
PostImportHookPatching.add_on_import('tensorboardX',
PatchSummaryToEventTransformer._patch_summary_to_event_transformer)
@staticmethod @staticmethod
def _patch_summary_to_event_transformer(): def _patch_summary_to_event_transformer():
@ -1002,7 +1007,7 @@ class PatchSummaryToEventTransformer(object):
@staticmethod @staticmethod
def _patched_add_eventT(self, *args, **kwargs): def _patched_add_eventT(self, *args, **kwargs):
if not hasattr(self, 'clearml') or not PatchSummaryToEventTransformer.__main_task: if not hasattr(self, 'clearml') or not PatchSummaryToEventTransformer._current_task:
return PatchSummaryToEventTransformer._original_add_eventT(self, *args, **kwargs) return PatchSummaryToEventTransformer._original_add_eventT(self, *args, **kwargs)
if not self.clearml: # noqa if not self.clearml: # noqa
# noinspection PyBroadException # noinspection PyBroadException
@ -1010,7 +1015,7 @@ class PatchSummaryToEventTransformer(object):
logdir = self.get_logdir() logdir = self.get_logdir()
except Exception: except Exception:
logdir = None logdir = None
self.clearml = EventTrainsWriter(PatchSummaryToEventTransformer.__main_task.get_logger(), self.clearml = EventTrainsWriter(PatchSummaryToEventTransformer._current_task.get_logger(),
logdir=logdir, **PatchSummaryToEventTransformer.defaults_dict) logdir=logdir, **PatchSummaryToEventTransformer.defaults_dict)
# noinspection PyBroadException # noinspection PyBroadException
try: try:
@ -1021,7 +1026,7 @@ class PatchSummaryToEventTransformer(object):
@staticmethod @staticmethod
def _patched_add_eventX(self, *args, **kwargs): def _patched_add_eventX(self, *args, **kwargs):
if not hasattr(self, 'clearml') or not PatchSummaryToEventTransformer.__main_task: if not hasattr(self, 'clearml') or not PatchSummaryToEventTransformer._current_task:
return PatchSummaryToEventTransformer._original_add_eventX(self, *args, **kwargs) return PatchSummaryToEventTransformer._original_add_eventX(self, *args, **kwargs)
if not self.clearml: if not self.clearml:
# noinspection PyBroadException # noinspection PyBroadException
@ -1029,7 +1034,7 @@ class PatchSummaryToEventTransformer(object):
logdir = self.get_logdir() logdir = self.get_logdir()
except Exception: except Exception:
logdir = None logdir = None
self.clearml = EventTrainsWriter(PatchSummaryToEventTransformer.__main_task.get_logger(), self.clearml = EventTrainsWriter(PatchSummaryToEventTransformer._current_task.get_logger(),
logdir=logdir, **PatchSummaryToEventTransformer.defaults_dict) logdir=logdir, **PatchSummaryToEventTransformer.defaults_dict)
# noinspection PyBroadException # noinspection PyBroadException
try: try:
@ -1051,7 +1056,7 @@ class PatchSummaryToEventTransformer(object):
@staticmethod @staticmethod
def _patched_getattribute_(self, attr, get_base): def _patched_getattribute_(self, attr, get_base):
# no main task, zero chance we have an ClearML event logger # no main task, zero chance we have an ClearML event logger
if PatchSummaryToEventTransformer.__main_task is None: if PatchSummaryToEventTransformer._current_task is None:
return get_base(self, attr) return get_base(self, attr)
# check if we already have an ClearML event logger # check if we already have an ClearML event logger
@ -1068,7 +1073,7 @@ class PatchSummaryToEventTransformer(object):
except Exception: except Exception:
logdir = None logdir = None
defaults_dict = __dict__.get('_trains_defaults') or PatchSummaryToEventTransformer.defaults_dict defaults_dict = __dict__.get('_trains_defaults') or PatchSummaryToEventTransformer.defaults_dict
trains_event = EventTrainsWriter(PatchSummaryToEventTransformer.__main_task.get_logger(), trains_event = EventTrainsWriter(PatchSummaryToEventTransformer._current_task.get_logger(),
logdir=logdir, **defaults_dict) logdir=logdir, **defaults_dict)
# order is important, the return value of ProxyEventsWriter is the last object in the list # order is important, the return value of ProxyEventsWriter is the last object in the list
@ -1111,8 +1116,9 @@ class _ModelAdapter(object):
class PatchModelCheckPointCallback(object): class PatchModelCheckPointCallback(object):
__main_task = None _current_task = None
__original_getattribute = None __original_getattribute = None
__patched = False
defaults_dict = dict( defaults_dict = dict(
config_text=None, config_text=None,
config_dict=None, config_dict=None,
@ -1132,11 +1138,15 @@ class PatchModelCheckPointCallback(object):
@staticmethod @staticmethod
def update_current_task(task, **kwargs): def update_current_task(task, **kwargs):
PatchModelCheckPointCallback.defaults_dict.update(kwargs) PatchModelCheckPointCallback.defaults_dict.update(kwargs)
PatchModelCheckPointCallback.__main_task = task PatchModelCheckPointCallback._current_task = task
# make sure we patched the SummaryToEventTransformer if not task:
PatchModelCheckPointCallback._patch_model_checkpoint() return
PostImportHookPatching.add_on_import('keras', PatchModelCheckPointCallback._patch_model_checkpoint) if not PatchModelCheckPointCallback.__patched:
PostImportHookPatching.add_on_import('tensorflow', PatchModelCheckPointCallback._patch_model_checkpoint) PatchModelCheckPointCallback.__patched = True
# make sure we patched the SummaryToEventTransformer
PatchModelCheckPointCallback._patch_model_checkpoint()
PostImportHookPatching.add_on_import('keras', PatchModelCheckPointCallback._patch_model_checkpoint)
PostImportHookPatching.add_on_import('tensorflow', PatchModelCheckPointCallback._patch_model_checkpoint)
@staticmethod @staticmethod
def _patch_model_checkpoint(): def _patch_model_checkpoint():
@ -1176,7 +1186,7 @@ class PatchModelCheckPointCallback(object):
get_base = PatchModelCheckPointCallback.__original_getattribute get_base = PatchModelCheckPointCallback.__original_getattribute
# no main task, zero chance we have an ClearML event logger # no main task, zero chance we have an ClearML event logger
if PatchModelCheckPointCallback.__main_task is None: if PatchModelCheckPointCallback._current_task is None:
return get_base(self, attr) return get_base(self, attr)
# check if we already have an ClearML event logger # check if we already have an ClearML event logger
@ -1189,17 +1199,17 @@ class PatchModelCheckPointCallback(object):
base_model = __dict__['model'] base_model = __dict__['model']
defaults_dict = __dict__.get('_trains_defaults') or PatchModelCheckPointCallback.defaults_dict defaults_dict = __dict__.get('_trains_defaults') or PatchModelCheckPointCallback.defaults_dict
output_model = OutputModel( output_model = OutputModel(
PatchModelCheckPointCallback.__main_task, PatchModelCheckPointCallback._current_task,
config_text=defaults_dict.get('config_text'), config_text=defaults_dict.get('config_text'),
config_dict=defaults_dict.get('config_dict'), config_dict=defaults_dict.get('config_dict'),
name=defaults_dict.get('name'), name=defaults_dict.get('name'),
comment=defaults_dict.get('comment'), comment=defaults_dict.get('comment'),
label_enumeration=defaults_dict.get('label_enumeration') or label_enumeration=defaults_dict.get('label_enumeration') or
PatchModelCheckPointCallback.__main_task.get_labels_enumeration(), PatchModelCheckPointCallback._current_task.get_labels_enumeration(),
framework=Framework.keras, framework=Framework.keras,
) )
output_model.set_upload_destination( output_model.set_upload_destination(
PatchModelCheckPointCallback.__main_task.get_output_destination(raise_on_error=False)) PatchModelCheckPointCallback._current_task.get_output_destination(raise_on_error=False))
trains_model = _ModelAdapter(base_model, output_model) trains_model = _ModelAdapter(base_model, output_model)
# order is important, the return value of ProxyEventsWriter is the last object in the list # order is important, the return value of ProxyEventsWriter is the last object in the list
@ -1209,26 +1219,32 @@ class PatchModelCheckPointCallback(object):
# noinspection PyProtectedMember,PyUnresolvedReferences # noinspection PyProtectedMember,PyUnresolvedReferences
class PatchTensorFlowEager(object): class PatchTensorFlowEager(object):
__main_task = None _current_task = None
__original_fn_scalar = None __original_fn_scalar = None
__original_fn_hist = None __original_fn_hist = None
__original_fn_image = None __original_fn_image = None
__original_fn_write_summary = None
__trains_event_writer = {} __trains_event_writer = {}
__tf_tb_writer_id_to_logdir = {} __tf_tb_writer_id_to_logdir = {}
__patched = False
defaults_dict = dict( defaults_dict = dict(
report_freq=1, image_report_freq=1, histogram_update_freq_multiplier=5, report_freq=1, image_report_freq=1, histogram_update_freq_multiplier=5,
histogram_granularity=50) histogram_granularity=50)
@staticmethod @staticmethod
def update_current_task(task, **kwargs): def update_current_task(task, **kwargs):
if task != PatchTensorFlowEager.__main_task: if task != PatchTensorFlowEager._current_task:
PatchTensorFlowEager.__trains_event_writer = {} PatchTensorFlowEager.__trains_event_writer = {}
PatchTensorFlowEager.defaults_dict.update(kwargs) PatchTensorFlowEager.defaults_dict.update(kwargs)
PatchTensorFlowEager.__main_task = task PatchTensorFlowEager._current_task = task
# make sure we patched the SummaryToEventTransformer if not task:
PatchTensorFlowEager._patch_summary_ops() return
PostImportHookPatching.add_on_import('tensorflow', PatchTensorFlowEager._patch_summary_ops) if not PatchTensorFlowEager.__patched:
PatchTensorFlowEager.__patched = True
# make sure we patched the SummaryToEventTransformer
PatchTensorFlowEager._patch_summary_ops()
PostImportHookPatching.add_on_import('tensorflow', PatchTensorFlowEager._patch_summary_ops)
@staticmethod @staticmethod
def _patch_summary_ops(): def _patch_summary_ops():
@ -1245,7 +1261,7 @@ class PatchTensorFlowEager(object):
gen_summary_ops.write_image_summary = PatchTensorFlowEager._write_image_summary gen_summary_ops.write_image_summary = PatchTensorFlowEager._write_image_summary
PatchTensorFlowEager.__original_fn_hist = gen_summary_ops.write_histogram_summary PatchTensorFlowEager.__original_fn_hist = gen_summary_ops.write_histogram_summary
gen_summary_ops.write_histogram_summary = PatchTensorFlowEager._write_hist_summary gen_summary_ops.write_histogram_summary = PatchTensorFlowEager._write_hist_summary
PatchTensorFlowEager.__write_summary = gen_summary_ops.write_summary PatchTensorFlowEager.__original_fn_write_summary = gen_summary_ops.write_summary
gen_summary_ops.write_summary = PatchTensorFlowEager._write_summary gen_summary_ops.write_summary = PatchTensorFlowEager._write_summary
gen_summary_ops.create_summary_file_writer = partial(IsTensorboardInit._patched_tb__init__, gen_summary_ops.create_summary_file_writer = partial(IsTensorboardInit._patched_tb__init__,
gen_summary_ops.create_summary_file_writer) gen_summary_ops.create_summary_file_writer)
@ -1270,6 +1286,8 @@ class PatchTensorFlowEager(object):
@staticmethod @staticmethod
def _create_summary_file_writer(original_fn, *args, **kwargs): def _create_summary_file_writer(original_fn, *args, **kwargs):
if not PatchTensorFlowEager._current_task:
return original_fn(*args, **kwargs)
# noinspection PyBroadException # noinspection PyBroadException
try: try:
a_logdir = None a_logdir = None
@ -1294,7 +1312,7 @@ class PatchTensorFlowEager(object):
@staticmethod @staticmethod
def _get_event_writer(writer): def _get_event_writer(writer):
if not PatchTensorFlowEager.__main_task: if not PatchTensorFlowEager._current_task:
return None return None
if not PatchTensorFlowEager.__trains_event_writer.get(id(writer)): if not PatchTensorFlowEager.__trains_event_writer.get(id(writer)):
# noinspection PyBroadException # noinspection PyBroadException
@ -1324,7 +1342,7 @@ class PatchTensorFlowEager(object):
logdir = None logdir = None
PatchTensorFlowEager.__trains_event_writer[id(writer)] = EventTrainsWriter( PatchTensorFlowEager.__trains_event_writer[id(writer)] = EventTrainsWriter(
logger=PatchTensorFlowEager.__main_task.get_logger(), logdir=logdir, logger=PatchTensorFlowEager._current_task.get_logger(), logdir=logdir,
**PatchTensorFlowEager.defaults_dict) **PatchTensorFlowEager.defaults_dict)
return PatchTensorFlowEager.__trains_event_writer[id(writer)] return PatchTensorFlowEager.__trains_event_writer[id(writer)]
@ -1337,6 +1355,10 @@ class PatchTensorFlowEager(object):
@staticmethod @staticmethod
def _write_summary(writer, step, tensor, tag, summary_metadata, name=None, **kwargs): def _write_summary(writer, step, tensor, tag, summary_metadata, name=None, **kwargs):
if not PatchTensorFlowEager._current_task:
return PatchTensorFlowEager.__original_fn_write_summary(
writer, step, tensor, tag, summary_metadata, name, **kwargs
)
event_writer = PatchTensorFlowEager._get_event_writer(writer) event_writer = PatchTensorFlowEager._get_event_writer(writer)
# make sure we can get the tensors values # make sure we can get the tensors values
if event_writer and isinstance(step, int) or hasattr(step, 'numpy'): if event_writer and isinstance(step, int) or hasattr(step, 'numpy'):
@ -1370,13 +1392,17 @@ class PatchTensorFlowEager(object):
step=int(step.numpy()) if not isinstance(step, int) else step, step=int(step.numpy()) if not isinstance(step, int) else step,
values=None, audio_data=audio_bytes) values=None, audio_data=audio_bytes)
else: else:
pass # print('unsupported plugin_type', plugin_type) pass
except Exception: except Exception:
pass pass
return PatchTensorFlowEager.__write_summary(writer, step, tensor, tag, summary_metadata, name, **kwargs) return PatchTensorFlowEager.__original_fn_write_summary(
writer, step, tensor, tag, summary_metadata, name, **kwargs
)
@staticmethod @staticmethod
def _write_scalar_summary(writer, step, tag, value, name=None, **kwargs): def _write_scalar_summary(writer, step, tag, value, name=None, **kwargs):
if not PatchTensorFlowEager._current_task:
return PatchTensorFlowEager.__original_fn_scalar(writer, step, tag, value, name, **kwargs)
event_writer = PatchTensorFlowEager._get_event_writer(writer) event_writer = PatchTensorFlowEager._get_event_writer(writer)
if event_writer and isinstance(step, int) or hasattr(step, 'numpy'): if event_writer and isinstance(step, int) or hasattr(step, 'numpy'):
try: try:
@ -1417,6 +1443,8 @@ class PatchTensorFlowEager(object):
@staticmethod @staticmethod
def _write_hist_summary(writer, step, tag, values, name, **kwargs): def _write_hist_summary(writer, step, tag, values, name, **kwargs):
if not PatchTensorFlowEager._current_task:
return PatchTensorFlowEager.__original_fn_hist(writer, step, tag, values, name, **kwargs)
event_writer = PatchTensorFlowEager._get_event_writer(writer) event_writer = PatchTensorFlowEager._get_event_writer(writer)
if event_writer and (isinstance(step, int) or hasattr(step, 'numpy')): if event_writer and (isinstance(step, int) or hasattr(step, 'numpy')):
try: try:
@ -1459,6 +1487,9 @@ class PatchTensorFlowEager(object):
@staticmethod @staticmethod
def _write_image_summary(writer, step, tag, tensor, bad_color, max_images, name, **kwargs): def _write_image_summary(writer, step, tag, tensor, bad_color, max_images, name, **kwargs):
if not PatchTensorFlowEager._current_task:
return PatchTensorFlowEager.__original_fn_image(
writer, step, tag, tensor, bad_color, max_images, name, **kwargs)
event_writer = PatchTensorFlowEager._get_event_writer(writer) event_writer = PatchTensorFlowEager._get_event_writer(writer)
if event_writer and (isinstance(step, int) or hasattr(step, 'numpy')): if event_writer and (isinstance(step, int) or hasattr(step, 'numpy')):
try: try:
@ -1526,13 +1557,15 @@ class PatchTensorFlowEager(object):
# noinspection PyPep8Naming,SpellCheckingInspection # noinspection PyPep8Naming,SpellCheckingInspection
class PatchKerasModelIO(object): class PatchKerasModelIO(object):
__main_task = None _current_task = None
__patched_keras = None __patched_keras = None
__patched_tensorflow = None __patched_tensorflow = None
@staticmethod @staticmethod
def update_current_task(task, **_): def update_current_task(task, **_):
PatchKerasModelIO.__main_task = task PatchKerasModelIO._current_task = task
if not task:
return
PatchKerasModelIO._patch_model_checkpoint() PatchKerasModelIO._patch_model_checkpoint()
PostImportHookPatching.add_on_import('tensorflow', PatchKerasModelIO._patch_model_checkpoint) PostImportHookPatching.add_on_import('tensorflow', PatchKerasModelIO._patch_model_checkpoint)
PostImportHookPatching.add_on_import('keras', PatchKerasModelIO._patch_model_checkpoint) PostImportHookPatching.add_on_import('keras', PatchKerasModelIO._patch_model_checkpoint)
@ -1692,7 +1725,7 @@ class PatchKerasModelIO(object):
def _updated_config(original_fn, self): def _updated_config(original_fn, self):
config = original_fn(self) config = original_fn(self)
# check if we have main task # check if we have main task
if PatchKerasModelIO.__main_task is None: if PatchKerasModelIO._current_task is None:
return config return config
try: try:
@ -1717,10 +1750,10 @@ class PatchKerasModelIO(object):
else: else:
# todo: support multiple models for the same task # todo: support multiple models for the same task
self.trains_out_model.append(OutputModel( self.trains_out_model.append(OutputModel(
task=PatchKerasModelIO.__main_task, task=PatchKerasModelIO._current_task,
config_dict=safe_config, config_dict=safe_config,
name=PatchKerasModelIO.__main_task.name + ' ' + model_name_id, name=PatchKerasModelIO._current_task.name + ' ' + model_name_id,
label_enumeration=PatchKerasModelIO.__main_task.get_labels_enumeration(), label_enumeration=PatchKerasModelIO._current_task.get_labels_enumeration(),
framework=Framework.keras, framework=Framework.keras,
)) ))
except Exception as ex: except Exception as ex:
@ -1738,7 +1771,7 @@ class PatchKerasModelIO(object):
self = _Empty() self = _Empty()
# check if we have main task # check if we have main task
if PatchKerasModelIO.__main_task is None: if PatchKerasModelIO._current_task is None:
return self return self
try: try:
@ -1751,10 +1784,10 @@ class PatchKerasModelIO(object):
# check if object already has InputModel # check if object already has InputModel
self.trains_in_model = InputModel.empty( self.trains_in_model = InputModel.empty(
config_dict=config_dict, config_dict=config_dict,
label_enumeration=PatchKerasModelIO.__main_task.get_labels_enumeration(), label_enumeration=PatchKerasModelIO._current_task.get_labels_enumeration(),
) )
# todo: support multiple models for the same task # todo: support multiple models for the same task
PatchKerasModelIO.__main_task.connect(self.trains_in_model) PatchKerasModelIO._current_task.connect(self.trains_in_model)
# if we are running remotely we should deserialize the object # if we are running remotely we should deserialize the object
# because someone might have changed the configuration # because someone might have changed the configuration
# Hack: disabled # Hack: disabled
@ -1781,7 +1814,7 @@ class PatchKerasModelIO(object):
@staticmethod @staticmethod
def _load_weights(original_fn, self, *args, **kwargs): def _load_weights(original_fn, self, *args, **kwargs):
# check if we have main task # check if we have main task
if PatchKerasModelIO.__main_task is None: if PatchKerasModelIO._current_task is None:
return original_fn(self, *args, **kwargs) return original_fn(self, *args, **kwargs)
# get filepath # get filepath
@ -1794,7 +1827,7 @@ class PatchKerasModelIO(object):
if False and running_remotely(): if False and running_remotely():
# register/load model weights # register/load model weights
filepath = WeightsFileHandler.restore_weights_file(self, filepath, Framework.keras, filepath = WeightsFileHandler.restore_weights_file(self, filepath, Framework.keras,
PatchKerasModelIO.__main_task) PatchKerasModelIO._current_task)
if 'filepath' in kwargs: if 'filepath' in kwargs:
kwargs['filepath'] = filepath kwargs['filepath'] = filepath
else: else:
@ -1805,11 +1838,13 @@ class PatchKerasModelIO(object):
# try to load the files; if anything goes wrong, the exception is raised before we register the file # try to load the files; if anything goes wrong, the exception is raised before we register the file
model = original_fn(self, *args, **kwargs) model = original_fn(self, *args, **kwargs)
# register/load model weights # register/load model weights
WeightsFileHandler.restore_weights_file(self, filepath, Framework.keras, PatchKerasModelIO.__main_task) WeightsFileHandler.restore_weights_file(self, filepath, Framework.keras, PatchKerasModelIO._current_task)
return model return model
@staticmethod @staticmethod
def _save(original_fn, self, *args, **kwargs): def _save(original_fn, self, *args, **kwargs):
if not PatchKerasModelIO._current_task:
return original_fn(self, *args, **kwargs)
if hasattr(self, 'trains_out_model') and self.trains_out_model: if hasattr(self, 'trains_out_model') and self.trains_out_model:
# noinspection PyProtectedMember # noinspection PyProtectedMember
self.trains_out_model[-1]._processed = False self.trains_out_model[-1]._processed = False
@ -1823,12 +1858,13 @@ class PatchKerasModelIO(object):
@staticmethod @staticmethod
def _save_weights(original_fn, self, *args, **kwargs): def _save_weights(original_fn, self, *args, **kwargs):
original_fn(self, *args, **kwargs) original_fn(self, *args, **kwargs)
PatchKerasModelIO._update_outputmodel(self, *args, **kwargs) if PatchKerasModelIO._current_task:
PatchKerasModelIO._update_outputmodel(self, *args, **kwargs)
@staticmethod @staticmethod
def _update_outputmodel(self, *args, **kwargs): def _update_outputmodel(self, *args, **kwargs):
# check if we have main task # check if we have main task
if PatchKerasModelIO.__main_task is None: if not PatchKerasModelIO._current_task:
return return
try: try:
@ -1851,7 +1887,7 @@ class PatchKerasModelIO(object):
if filepath: if filepath:
WeightsFileHandler.create_output_model( WeightsFileHandler.create_output_model(
self, filepath, Framework.keras, PatchKerasModelIO.__main_task, self, filepath, Framework.keras, PatchKerasModelIO._current_task,
config_obj=config or None, singlefile=True) config_obj=config or None, singlefile=True)
except Exception as ex: except Exception as ex:
@ -1860,12 +1896,12 @@ class PatchKerasModelIO(object):
@staticmethod @staticmethod
def _save_model(original_fn, model, filepath, *args, **kwargs): def _save_model(original_fn, model, filepath, *args, **kwargs):
original_fn(model, filepath, *args, **kwargs) original_fn(model, filepath, *args, **kwargs)
if PatchKerasModelIO.__main_task: if PatchKerasModelIO._current_task:
PatchKerasModelIO._update_outputmodel(model, filepath) PatchKerasModelIO._update_outputmodel(model, filepath)
@staticmethod @staticmethod
def _load_model(original_fn, filepath, *args, **kwargs): def _load_model(original_fn, filepath, *args, **kwargs):
if not PatchKerasModelIO.__main_task: if not PatchKerasModelIO._current_task:
return original_fn(filepath, *args, **kwargs) return original_fn(filepath, *args, **kwargs)
empty = _Empty() empty = _Empty()
@ -1873,12 +1909,12 @@ class PatchKerasModelIO(object):
if False and running_remotely(): if False and running_remotely():
# register/load model weights # register/load model weights
filepath = WeightsFileHandler.restore_weights_file(empty, filepath, Framework.keras, filepath = WeightsFileHandler.restore_weights_file(empty, filepath, Framework.keras,
PatchKerasModelIO.__main_task) PatchKerasModelIO._current_task)
model = original_fn(filepath, *args, **kwargs) model = original_fn(filepath, *args, **kwargs)
else: else:
model = original_fn(filepath, *args, **kwargs) model = original_fn(filepath, *args, **kwargs)
# register/load model weights # register/load model weights
WeightsFileHandler.restore_weights_file(empty, filepath, Framework.keras, PatchKerasModelIO.__main_task) WeightsFileHandler.restore_weights_file(empty, filepath, Framework.keras, PatchKerasModelIO._current_task)
# update the input model object # update the input model object
if empty.trains_in_model: if empty.trains_in_model:
# noinspection PyBroadException # noinspection PyBroadException
@ -1891,12 +1927,14 @@ class PatchKerasModelIO(object):
class PatchTensorflowModelIO(object): class PatchTensorflowModelIO(object):
__main_task = None _current_task = None
__patched = None __patched = None
@staticmethod @staticmethod
def update_current_task(task, **_): def update_current_task(task, **_):
PatchTensorflowModelIO.__main_task = task PatchTensorflowModelIO._current_task = task
if not task:
return
PatchTensorflowModelIO._patch_model_checkpoint() PatchTensorflowModelIO._patch_model_checkpoint()
PostImportHookPatching.add_on_import('tensorflow', PatchTensorflowModelIO._patch_model_checkpoint) PostImportHookPatching.add_on_import('tensorflow', PatchTensorflowModelIO._patch_model_checkpoint)
@ -2028,29 +2066,31 @@ class PatchTensorflowModelIO(object):
@staticmethod @staticmethod
def _save(original_fn, self, sess, save_path, *args, **kwargs): def _save(original_fn, self, sess, save_path, *args, **kwargs):
saved_path = original_fn(self, sess, save_path, *args, **kwargs) saved_path = original_fn(self, sess, save_path, *args, **kwargs)
if not saved_path: if not saved_path or not PatchTensorflowModelIO._current_task:
return saved_path return saved_path
# store output Model # store output Model
return WeightsFileHandler.create_output_model(self, saved_path, Framework.tensorflow, return WeightsFileHandler.create_output_model(self, saved_path, Framework.tensorflow,
PatchTensorflowModelIO.__main_task) PatchTensorflowModelIO._current_task)
@staticmethod @staticmethod
def _save_model(original_fn, obj, export_dir, *args, **kwargs): def _save_model(original_fn, obj, export_dir, *args, **kwargs):
original_fn(obj, export_dir, *args, **kwargs) original_fn(obj, export_dir, *args, **kwargs)
if not PatchTensorflowModelIO._current_task:
return
# store output Model # store output Model
WeightsFileHandler.create_output_model(obj, export_dir, Framework.tensorflow, WeightsFileHandler.create_output_model(obj, export_dir, Framework.tensorflow,
PatchTensorflowModelIO.__main_task) PatchTensorflowModelIO._current_task)
@staticmethod @staticmethod
def _restore(original_fn, self, sess, save_path, *args, **kwargs): def _restore(original_fn, self, sess, save_path, *args, **kwargs):
if PatchTensorflowModelIO.__main_task is None: if PatchTensorflowModelIO._current_task is None:
return original_fn(self, sess, save_path, *args, **kwargs) return original_fn(self, sess, save_path, *args, **kwargs)
# Hack: disabled # Hack: disabled
if False and running_remotely(): if False and running_remotely():
# register/load model weights # register/load model weights
save_path = WeightsFileHandler.restore_weights_file(self, save_path, Framework.tensorflow, save_path = WeightsFileHandler.restore_weights_file(self, save_path, Framework.tensorflow,
PatchTensorflowModelIO.__main_task) PatchTensorflowModelIO._current_task)
# load model # load model
return original_fn(self, sess, save_path, *args, **kwargs) return original_fn(self, sess, save_path, *args, **kwargs)
@ -2058,12 +2098,12 @@ class PatchTensorflowModelIO(object):
model = original_fn(self, sess, save_path, *args, **kwargs) model = original_fn(self, sess, save_path, *args, **kwargs)
# register/load model weights # register/load model weights
WeightsFileHandler.restore_weights_file(self, save_path, Framework.tensorflow, WeightsFileHandler.restore_weights_file(self, save_path, Framework.tensorflow,
PatchTensorflowModelIO.__main_task) PatchTensorflowModelIO._current_task)
return model return model
@staticmethod @staticmethod
def _load_lt_2_0(original_fn, sess, tags=None, export_dir=None, *args, **saver_kwargs): def _load_lt_2_0(original_fn, sess, tags=None, export_dir=None, *args, **saver_kwargs):
if PatchTensorflowModelIO.__main_task is None: if PatchTensorflowModelIO._current_task is None:
return original_fn(sess, tags, export_dir, *args, **saver_kwargs) return original_fn(sess, tags, export_dir, *args, **saver_kwargs)
# register input model # register input model
@ -2072,7 +2112,7 @@ class PatchTensorflowModelIO(object):
if False and running_remotely(): if False and running_remotely():
export_dir = WeightsFileHandler.restore_weights_file( export_dir = WeightsFileHandler.restore_weights_file(
empty, export_dir, Framework.tensorflow, empty, export_dir, Framework.tensorflow,
PatchTensorflowModelIO.__main_task PatchTensorflowModelIO._current_task
) )
model = original_fn(sess, tags, export_dir, *args, **saver_kwargs) model = original_fn(sess, tags, export_dir, *args, **saver_kwargs)
else: else:
@ -2080,7 +2120,7 @@ class PatchTensorflowModelIO(object):
model = original_fn(sess, tags, export_dir, *args, **saver_kwargs) model = original_fn(sess, tags, export_dir, *args, **saver_kwargs)
WeightsFileHandler.restore_weights_file( WeightsFileHandler.restore_weights_file(
empty, export_dir, Framework.tensorflow, empty, export_dir, Framework.tensorflow,
PatchTensorflowModelIO.__main_task PatchTensorflowModelIO._current_task
) )
if empty.trains_in_model: if empty.trains_in_model:
@ -2093,7 +2133,7 @@ class PatchTensorflowModelIO(object):
@staticmethod @staticmethod
def _load(original_fn, export_dir, *args, **saver_kwargs): def _load(original_fn, export_dir, *args, **saver_kwargs):
if PatchTensorflowModelIO.__main_task is None: if PatchTensorflowModelIO._current_task is None:
return original_fn(export_dir, *args, **saver_kwargs) return original_fn(export_dir, *args, **saver_kwargs)
# register input model # register input model
@ -2102,7 +2142,7 @@ class PatchTensorflowModelIO(object):
if False and running_remotely(): if False and running_remotely():
export_dir = WeightsFileHandler.restore_weights_file( export_dir = WeightsFileHandler.restore_weights_file(
empty, export_dir, Framework.tensorflow, empty, export_dir, Framework.tensorflow,
PatchTensorflowModelIO.__main_task PatchTensorflowModelIO._current_task
) )
model = original_fn(export_dir, *args, **saver_kwargs) model = original_fn(export_dir, *args, **saver_kwargs)
else: else:
@ -2110,7 +2150,7 @@ class PatchTensorflowModelIO(object):
model = original_fn(export_dir, *args, **saver_kwargs) model = original_fn(export_dir, *args, **saver_kwargs)
WeightsFileHandler.restore_weights_file( WeightsFileHandler.restore_weights_file(
empty, export_dir, Framework.tensorflow, empty, export_dir, Framework.tensorflow,
PatchTensorflowModelIO.__main_task PatchTensorflowModelIO._current_task
) )
if empty.trains_in_model: if empty.trains_in_model:
@ -2124,24 +2164,24 @@ class PatchTensorflowModelIO(object):
@staticmethod @staticmethod
def _ckpt_save(original_fn, self, file_prefix, *args, **kwargs): def _ckpt_save(original_fn, self, file_prefix, *args, **kwargs):
checkpoint_path = original_fn(self, file_prefix, *args, **kwargs) checkpoint_path = original_fn(self, file_prefix, *args, **kwargs)
if PatchTensorflowModelIO.__main_task is None: if PatchTensorflowModelIO._current_task is None:
return checkpoint_path return checkpoint_path
WeightsFileHandler.create_output_model(self, checkpoint_path, Framework.tensorflow, WeightsFileHandler.create_output_model(self, checkpoint_path, Framework.tensorflow,
PatchTensorflowModelIO.__main_task) PatchTensorflowModelIO._current_task)
return checkpoint_path return checkpoint_path
@staticmethod @staticmethod
def _ckpt_write(original_fn, self, file_prefix, *args, **kwargs): def _ckpt_write(original_fn, self, file_prefix, *args, **kwargs):
checkpoint_path = original_fn(self, file_prefix, *args, **kwargs) checkpoint_path = original_fn(self, file_prefix, *args, **kwargs)
if PatchTensorflowModelIO.__main_task is None: if PatchTensorflowModelIO._current_task is None:
return checkpoint_path return checkpoint_path
WeightsFileHandler.create_output_model(self, checkpoint_path, Framework.tensorflow, WeightsFileHandler.create_output_model(self, checkpoint_path, Framework.tensorflow,
PatchTensorflowModelIO.__main_task) PatchTensorflowModelIO._current_task)
return checkpoint_path return checkpoint_path
@staticmethod @staticmethod
def _ckpt_restore(original_fn, self, save_path, *args, **kwargs): def _ckpt_restore(original_fn, self, save_path, *args, **kwargs):
if PatchTensorflowModelIO.__main_task is None: if PatchTensorflowModelIO._current_task is None:
return original_fn(self, save_path, *args, **kwargs) return original_fn(self, save_path, *args, **kwargs)
# register input model # register input model
@ -2149,13 +2189,13 @@ class PatchTensorflowModelIO(object):
# Hack: disabled # Hack: disabled
if False and running_remotely(): if False and running_remotely():
save_path = WeightsFileHandler.restore_weights_file(empty, save_path, Framework.tensorflow, save_path = WeightsFileHandler.restore_weights_file(empty, save_path, Framework.tensorflow,
PatchTensorflowModelIO.__main_task) PatchTensorflowModelIO._current_task)
model = original_fn(self, save_path, *args, **kwargs) model = original_fn(self, save_path, *args, **kwargs)
else: else:
# try to load model before registering it, in case it fails. # try to load model before registering it, in case it fails.
model = original_fn(self, save_path, *args, **kwargs) model = original_fn(self, save_path, *args, **kwargs)
WeightsFileHandler.restore_weights_file(empty, save_path, Framework.tensorflow, WeightsFileHandler.restore_weights_file(empty, save_path, Framework.tensorflow,
PatchTensorflowModelIO.__main_task) PatchTensorflowModelIO._current_task)
if empty.trains_in_model: if empty.trains_in_model:
# noinspection PyBroadException # noinspection PyBroadException
@ -2167,12 +2207,14 @@ class PatchTensorflowModelIO(object):
class PatchTensorflow2ModelIO(object): class PatchTensorflow2ModelIO(object):
__main_task = None _current_task = None
__patched = None __patched = None
@staticmethod @staticmethod
def update_current_task(task, **_): def update_current_task(task, **_):
PatchTensorflow2ModelIO.__main_task = task PatchTensorflow2ModelIO._current_task = task
if not task:
return
PatchTensorflow2ModelIO._patch_model_checkpoint() PatchTensorflow2ModelIO._patch_model_checkpoint()
PostImportHookPatching.add_on_import('tensorflow', PatchTensorflow2ModelIO._patch_model_checkpoint) PostImportHookPatching.add_on_import('tensorflow', PatchTensorflow2ModelIO._patch_model_checkpoint)
@ -2210,18 +2252,20 @@ class PatchTensorflow2ModelIO(object):
@staticmethod @staticmethod
def _save(original_fn, self, file_prefix, *args, **kwargs): def _save(original_fn, self, file_prefix, *args, **kwargs):
model = original_fn(self, file_prefix, *args, **kwargs) model = original_fn(self, file_prefix, *args, **kwargs)
if not PatchTensorflow2ModelIO._current_task:
return model
# store output Model # store output Model
# noinspection PyBroadException # noinspection PyBroadException
try: try:
WeightsFileHandler.create_output_model(self, file_prefix, Framework.tensorflow, WeightsFileHandler.create_output_model(self, file_prefix, Framework.tensorflow,
PatchTensorflow2ModelIO.__main_task) PatchTensorflow2ModelIO._current_task)
except Exception: except Exception:
pass pass
return model return model
@staticmethod @staticmethod
def _restore(original_fn, self, save_path, *args, **kwargs): def _restore(original_fn, self, save_path, *args, **kwargs):
if PatchTensorflow2ModelIO.__main_task is None: if not PatchTensorflow2ModelIO._current_task:
return original_fn(self, save_path, *args, **kwargs) return original_fn(self, save_path, *args, **kwargs)
# Hack: disabled # Hack: disabled
@ -2230,7 +2274,7 @@ class PatchTensorflow2ModelIO(object):
# noinspection PyBroadException # noinspection PyBroadException
try: try:
save_path = WeightsFileHandler.restore_weights_file(self, save_path, Framework.tensorflow, save_path = WeightsFileHandler.restore_weights_file(self, save_path, Framework.tensorflow,
PatchTensorflow2ModelIO.__main_task) PatchTensorflow2ModelIO._current_task)
except Exception: except Exception:
pass pass
# load model # load model
@ -2242,7 +2286,7 @@ class PatchTensorflow2ModelIO(object):
# noinspection PyBroadException # noinspection PyBroadException
try: try:
WeightsFileHandler.restore_weights_file(self, save_path, Framework.tensorflow, WeightsFileHandler.restore_weights_file(self, save_path, Framework.tensorflow,
PatchTensorflow2ModelIO.__main_task) PatchTensorflow2ModelIO._current_task)
except Exception: except Exception:
pass pass
return model return model
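All of the model-IO hooks in this file follow the same two rules: register an output only after the framework reports a successful save, and finish loading before registering an input so that a failed load never records a model. A minimal sketch of that discipline; the `register_*` helpers are hypothetical stand-ins for the `WeightsFileHandler` calls, not ClearML API:

class _Patcher(object):
    _current_task = None  # cleared via update_current_task(None) on Task.close()

def register_output_model(obj, path, task):
    # hypothetical stand-in for WeightsFileHandler.create_output_model
    print("would register output model:", path)

def register_input_model(obj, path, task):
    # hypothetical stand-in for WeightsFileHandler.restore_weights_file
    print("would register input model:", path)

def patched_save(original_save, self, path, *args, **kwargs):
    saved_path = original_save(self, path, *args, **kwargs)
    if not saved_path or _Patcher._current_task is None:
        return saved_path  # nothing was saved, or the patches are unbound
    register_output_model(self, saved_path, _Patcher._current_task)
    return saved_path

def patched_restore(original_restore, self, path, *args, **kwargs):
    if _Patcher._current_task is None:
        return original_restore(self, path, *args, **kwargs)
    # load first, register second: a failed load must not record a model
    model = original_restore(self, path, *args, **kwargs)
    register_input_model(self, path, _Patcher._current_task)
    return model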

View File

@ -11,13 +11,15 @@ from ...model import Framework
class PatchXGBoostModelIO(PatchBaseModelIO): class PatchXGBoostModelIO(PatchBaseModelIO):
__main_task = None _current_task = None
__patched = None __patched = None
__callback_cls = None __callback_cls = None
@staticmethod @staticmethod
def update_current_task(task, **kwargs): def update_current_task(task, **kwargs):
PatchXGBoostModelIO.__main_task = task PatchXGBoostModelIO._current_task = task
if not task:
return
PatchXGBoostModelIO._patch_model_io() PatchXGBoostModelIO._patch_model_io()
PostImportHookPatching.add_on_import('xgboost', PatchXGBoostModelIO._patch_model_io) PostImportHookPatching.add_on_import('xgboost', PatchXGBoostModelIO._patch_model_io)
@ -55,7 +57,7 @@ class PatchXGBoostModelIO(PatchBaseModelIO):
@staticmethod @staticmethod
def _save(original_fn, obj, f, *args, **kwargs): def _save(original_fn, obj, f, *args, **kwargs):
ret = original_fn(obj, f, *args, **kwargs) ret = original_fn(obj, f, *args, **kwargs)
if not PatchXGBoostModelIO.__main_task: if not PatchXGBoostModelIO._current_task:
return ret return ret
if isinstance(f, six.string_types): if isinstance(f, six.string_types):
@ -76,12 +78,15 @@ class PatchXGBoostModelIO(PatchBaseModelIO):
model_name = Path(filename).stem model_name = Path(filename).stem
except Exception: except Exception:
model_name = None model_name = None
WeightsFileHandler.create_output_model(obj, filename, Framework.xgboost, PatchXGBoostModelIO.__main_task, WeightsFileHandler.create_output_model(obj, filename, Framework.xgboost, PatchXGBoostModelIO._current_task,
singlefile=True, model_name=model_name) singlefile=True, model_name=model_name)
return ret return ret
@staticmethod @staticmethod
def _load(original_fn, f, *args, **kwargs): def _load(original_fn, f, *args, **kwargs):
if not PatchXGBoostModelIO._current_task:
return original_fn(f, *args, **kwargs)
if isinstance(f, six.string_types): if isinstance(f, six.string_types):
filename = f filename = f
elif hasattr(f, 'name'): elif hasattr(f, 'name'):
@ -91,21 +96,18 @@ class PatchXGBoostModelIO(PatchBaseModelIO):
else: else:
filename = None filename = None
if not PatchXGBoostModelIO.__main_task:
return original_fn(f, *args, **kwargs)
# register input model # register input model
empty = _Empty() empty = _Empty()
# Hack: disabled # Hack: disabled
if False and running_remotely(): if False and running_remotely():
filename = WeightsFileHandler.restore_weights_file(empty, filename, Framework.xgboost, filename = WeightsFileHandler.restore_weights_file(empty, filename, Framework.xgboost,
PatchXGBoostModelIO.__main_task) PatchXGBoostModelIO._current_task)
model = original_fn(filename or f, *args, **kwargs) model = original_fn(filename or f, *args, **kwargs)
else: else:
# try to load the model before registering it, in case loading fails # try to load the model before registering it, in case loading fails
model = original_fn(f, *args, **kwargs) model = original_fn(f, *args, **kwargs)
WeightsFileHandler.restore_weights_file(empty, filename, Framework.xgboost, WeightsFileHandler.restore_weights_file(empty, filename, Framework.xgboost,
PatchXGBoostModelIO.__main_task) PatchXGBoostModelIO._current_task)
if empty.trains_in_model: if empty.trains_in_model:
# noinspection PyBroadException # noinspection PyBroadException
@ -117,10 +119,12 @@ class PatchXGBoostModelIO(PatchBaseModelIO):
@staticmethod @staticmethod
def _train(original_fn, *args, **kwargs): def _train(original_fn, *args, **kwargs):
if not PatchXGBoostModelIO._current_task:
return original_fn(*args, **kwargs)
if PatchXGBoostModelIO.__callback_cls: if PatchXGBoostModelIO.__callback_cls:
callbacks = kwargs.get('callbacks') or [] callbacks = kwargs.get('callbacks') or []
kwargs['callbacks'] = callbacks + [ kwargs['callbacks'] = callbacks + [
PatchXGBoostModelIO.__callback_cls(task=PatchXGBoostModelIO.__main_task) PatchXGBoostModelIO.__callback_cls(task=PatchXGBoostModelIO._current_task)
] ]
return original_fn(*args, **kwargs) return original_fn(*args, **kwargs)
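The `_train` hook above works by injecting a reporting callback rather than wrapping the training loop itself. A rough sketch, assuming the xgboost >= 1.3 callback API; `_MetricsCallback` and the module-level `_current_task` are illustrative stand-ins, not ClearML names:

import xgboost

_current_task = None  # rebound by update_current_task()

class _MetricsCallback(xgboost.callback.TrainingCallback):
    def __init__(self, task=None):
        super(_MetricsCallback, self).__init__()
        self.task = task

    def after_iteration(self, model, epoch, evals_log):
        # a real callback would report the evals_log scalars to self.task
        return False  # False = keep training

def patched_train(original_fn, *args, **kwargs):
    if _current_task is None:
        return original_fn(*args, **kwargs)  # unbound: pure pass-through
    # `callbacks + [...]` builds a new list, so a caller-owned list is untouched
    callbacks = kwargs.get("callbacks") or []
    kwargs["callbacks"] = callbacks + [_MetricsCallback(task=_current_task)]
    return original_fn(*args, **kwargs)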

View File

@ -2,7 +2,7 @@ import io
import sys import sys
from functools import partial from functools import partial
from ..config import running_remotely, get_remote_task_id, DEV_TASK_NO_REUSE from ..config import running_remotely, DEV_TASK_NO_REUSE
from ..debugging.log import LoggerRoot from ..debugging.log import LoggerRoot
@ -46,6 +46,8 @@ class PatchHydra(object):
def update_current_task(task): def update_current_task(task):
# set current Task before patching # set current Task before patching
PatchHydra._current_task = task PatchHydra._current_task = task
if not task:
return
if PatchHydra.patch_hydra(): if PatchHydra.patch_hydra():
# check if we have an untracked state, store it. # check if we have an untracked state, store it.
if PatchHydra._last_untracked_state.get('connect'): if PatchHydra._last_untracked_state.get('connect'):
@ -66,9 +68,6 @@ class PatchHydra(object):
# noinspection PyBroadException # noinspection PyBroadException
try: try:
if running_remotely(): if running_remotely():
if not PatchHydra._current_task:
from ..task import Task
PatchHydra._current_task = Task.get_task(task_id=get_remote_task_id())
# get the _parameter_allow_full_edit casted back to boolean # get the _parameter_allow_full_edit casted back to boolean
connected_config = dict() connected_config = dict()
connected_config[PatchHydra._parameter_allow_full_edit] = False connected_config[PatchHydra._parameter_allow_full_edit] = False
@ -126,8 +125,9 @@ class PatchHydra(object):
if pre_app_task_init_call and not running_remotely(): if pre_app_task_init_call and not running_remotely():
LoggerRoot.get_base_logger(PatchHydra).info( LoggerRoot.get_base_logger(PatchHydra).info(
'Task.init called outside of Hydra-App. For full Hydra multi-run support, ' "Task.init called outside of Hydra-App. For full Hydra multi-run support, "
'move the Task.init call into the Hydra-App main function') "move the Task.init call into the Hydra-App main function"
)
kwargs["config"] = config kwargs["config"] = config
kwargs["task_function"] = partial(PatchHydra._patched_task_function, task_function,) kwargs["task_function"] = partial(PatchHydra._patched_task_function, task_function,)

View File

@ -77,6 +77,8 @@ class PatchedJoblib(object):
@staticmethod @staticmethod
def update_current_task(task): def update_current_task(task):
PatchedJoblib._current_task = task PatchedJoblib._current_task = task
if not task:
return
PatchedJoblib.patch_joblib() PatchedJoblib.patch_joblib()
@staticmethod @staticmethod

View File

@ -14,7 +14,7 @@ from ..utilities.proxy_object import flatten_dictionary
class PatchJsonArgParse(object): class PatchJsonArgParse(object):
_args = {} _args = {}
_main_task = None _current_task = None
_args_sep = "/" _args_sep = "/"
_args_type = {} _args_type = {}
_commands_sep = "." _commands_sep = "."
@ -24,14 +24,19 @@ class PatchJsonArgParse(object):
__remote_task_params = {} __remote_task_params = {}
__patched = False __patched = False
@classmethod
def update_current_task(cls, task):
cls._current_task = task
if not task:
return
cls.patch(task)
@classmethod @classmethod
def patch(cls, task): def patch(cls, task):
if ArgumentParser is None: if ArgumentParser is None:
return return
if task: PatchJsonArgParse._update_task_args()
cls._main_task = task
PatchJsonArgParse._update_task_args()
if not cls.__patched: if not cls.__patched:
cls.__patched = True cls.__patched = True
@ -39,14 +44,16 @@ class PatchJsonArgParse(object):
@classmethod @classmethod
def _update_task_args(cls): def _update_task_args(cls):
if running_remotely() or not cls._main_task or not cls._args: if running_remotely() or not cls._current_task or not cls._args:
return return
args = {cls._section_name + cls._args_sep + k: v for k, v in cls._args.items()} args = {cls._section_name + cls._args_sep + k: v for k, v in cls._args.items()}
args_type = {cls._section_name + cls._args_sep + k: v for k, v in cls._args_type.items()} args_type = {cls._section_name + cls._args_sep + k: v for k, v in cls._args_type.items()}
cls._main_task._set_parameters(args, __update=True, __parameters_types=args_type) cls._current_task._set_parameters(args, __update=True, __parameters_types=args_type)
@staticmethod @staticmethod
def _parse_args(original_fn, obj, *args, **kwargs): def _parse_args(original_fn, obj, *args, **kwargs):
if not PatchJsonArgParse._current_task:
return original_fn(obj, *args, **kwargs)
if len(args) == 1: if len(args) == 1:
kwargs["args"] = args[0] kwargs["args"] = args[0]
args = [] args = []
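The new `update_current_task` above captures the bind/unbind discipline this commit applies everywhere: install the monkeypatch at most once, then rebind or clear `_current_task` freely. A condensed sketch; `_install()` stands in for the real wrapping code:

class RebindablePatch(object):
    _current_task = None
    __patched = False

    @classmethod
    def update_current_task(cls, task):
        cls._current_task = task  # task=None unbinds (called from Task.close())
        if task is None:
            return  # keep the wrapper installed; it is now a no-op
        if not cls.__patched:
            cls.__patched = True
            cls._install()

    @classmethod
    def _install(cls):
        pass  # hypothetical: wrap parse_args(), save(), show(), etc.

Unbinding instead of unpatching sidesteps the hard problem of safely restoring attributes on third-party modules once other code may already hold references to the wrappers.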

View File

@ -207,12 +207,15 @@ class PatchedMatplotlib:
PatchedMatplotlib._current_task = task PatchedMatplotlib._current_task = task
PostImportHookPatching.add_on_import('matplotlib.pyplot', PatchedMatplotlib.patch_matplotlib) PostImportHookPatching.add_on_import('matplotlib.pyplot', PatchedMatplotlib.patch_matplotlib)
PostImportHookPatching.add_on_import('matplotlib.pylab', PatchedMatplotlib.patch_matplotlib) PostImportHookPatching.add_on_import('matplotlib.pylab', PatchedMatplotlib.patch_matplotlib)
elif PatchedMatplotlib.patch_matplotlib(): else:
PatchedMatplotlib.patch_matplotlib()
PatchedMatplotlib._current_task = task PatchedMatplotlib._current_task = task
@staticmethod @staticmethod
def patched_imshow(*args, **kw): def patched_imshow(*args, **kw):
ret = PatchedMatplotlib._patched_original_imshow(*args, **kw) ret = PatchedMatplotlib._patched_original_imshow(*args, **kw)
if not PatchedMatplotlib._current_task:
return ret
try: try:
from matplotlib import _pylab_helpers from matplotlib import _pylab_helpers
# store on the plot that this is an imshow plot # store on the plot that this is an imshow plot
@ -227,6 +230,8 @@ class PatchedMatplotlib:
@staticmethod @staticmethod
def patched_savefig(self, *args, **kw): def patched_savefig(self, *args, **kw):
ret = PatchedMatplotlib._patched_original_savefig(self, *args, **kw) ret = PatchedMatplotlib._patched_original_savefig(self, *args, **kw)
if not PatchedMatplotlib._current_task:
return ret
# noinspection PyBroadException # noinspection PyBroadException
try: try:
fname = kw.get('fname') or args[0] fname = kw.get('fname') or args[0]
@ -256,6 +261,8 @@ class PatchedMatplotlib:
@staticmethod @staticmethod
def patched_figure_show(self, *args, **kw): def patched_figure_show(self, *args, **kw):
if not PatchedMatplotlib._current_task:
return PatchedMatplotlib._patched_original_figure(self, *args, **kw)
tid = threading._get_ident() if six.PY2 else threading.get_ident() tid = threading._get_ident() if six.PY2 else threading.get_ident()
if PatchedMatplotlib._recursion_guard.get(tid): if PatchedMatplotlib._recursion_guard.get(tid):
# we are inside a guard, do nothing # we are inside a guard, do nothing
@ -269,6 +276,8 @@ class PatchedMatplotlib:
@staticmethod @staticmethod
def patched_show(*args, **kw): def patched_show(*args, **kw):
if not PatchedMatplotlib._current_task:
return PatchedMatplotlib._patched_original_plot(*args, **kw)
tid = threading._get_ident() if six.PY2 else threading.get_ident() tid = threading._get_ident() if six.PY2 else threading.get_ident()
PatchedMatplotlib._recursion_guard[tid] = True PatchedMatplotlib._recursion_guard[tid] = True
# noinspection PyBroadException # noinspection PyBroadException
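The `_recursion_guard` lookups above are a per-thread re-entrancy latch: the capture code may itself trigger `show()`, and the guard makes that inner call fall straight through. A minimal sketch (the `six.PY2` branch is omitted; `capture_current_figure` is a placeholder):

import threading

_recursion_guard = {}

def capture_current_figure():
    pass  # placeholder: render the active figure and report it to the task

def patched_show(original_show, *args, **kwargs):
    tid = threading.get_ident()
    if _recursion_guard.get(tid):
        # re-entered from our own capture code on this thread: call through
        return original_show(*args, **kwargs)
    _recursion_guard[tid] = True
    try:
        capture_current_figure()
        return original_show(*args, **kwargs)
    finally:
        _recursion_guard[tid] = False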

View File

@ -626,8 +626,7 @@ class Task(_Task):
task.__register_at_exit(task._at_exit) task.__register_at_exit(task._at_exit)
# always patch OS forking because of ProcessPool and the alike # always patch OS forking because of ProcessPool and the alike
PatchOsFork.patch_fork() PatchOsFork.patch_fork(task)
if auto_connect_frameworks: if auto_connect_frameworks:
def should_connect(*keys): def should_connect(*keys):
""" """
@ -708,13 +707,15 @@ class Task(_Task):
make_deterministic(task.get_random_seed()) make_deterministic(task.get_random_seed())
if auto_connect_arg_parser: if auto_connect_arg_parser:
EnvironmentBind.update_current_task(Task.__main_task) EnvironmentBind.update_current_task(task)
PatchJsonArgParse.update_current_task(task)
# Patch ArgParser to be aware of the current task # Patch ArgParser to be aware of the current task
argparser_update_currenttask(Task.__main_task) argparser_update_currenttask(task)
PatchClick.patch(Task.__main_task)
PatchFire.patch(Task.__main_task) PatchClick.patch(task)
PatchJsonArgParse.patch(Task.__main_task) PatchFire.patch(task)
# set excluded arguments # set excluded arguments
if isinstance(auto_connect_arg_parser, dict): if isinstance(auto_connect_arg_parser, dict):
@ -1686,6 +1687,22 @@ class Task(_Task):
# noinspection PyProtectedMember # noinspection PyProtectedMember
Logger._remove_std_logger() Logger._remove_std_logger()
# unbind everything
PatchHydra.update_current_task(None)
PatchedJoblib.update_current_task(None)
PatchedMatplotlib.update_current_task(None)
PatchAbsl.update_current_task(None)
TensorflowBinding.update_current_task(None)
PatchPyTorchModelIO.update_current_task(None)
PatchMegEngineModelIO.update_current_task(None)
PatchXGBoostModelIO.update_current_task(None)
PatchCatBoostModelIO.update_current_task(None)
PatchFastai.update_current_task(None)
PatchLIGHTgbmModelIO.update_current_task(None)
EnvironmentBind.update_current_task(None)
PatchJsonArgParse.update_current_task(None)
PatchOsFork.patch_fork(None)
def delete( def delete(
self, self,
delete_artifacts_and_models=True, delete_artifacts_and_models=True,
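Taken together, the unbind block above is what the commit title promises: `Task.close()` hands `None` to every patcher, so framework calls made after close run against the original implementations. A hypothetical usage illustrating the effect (the project and task names are invented):

from clearml import Task

task = Task.init(project_name="examples", task_name="unbind-demo")
# ... matplotlib / TF / Keras / XGBoost calls here are intercepted
#     and logged to `task` by the patched frameworks ...
task.close()
# after close(), every patcher holds _current_task = None, so the same
# framework calls now short-circuit to the original implementations and
# nothing more is reported to the closed task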