Unbind all patches when Task.close() is called

allegroai 2022-05-12 23:47:58 +03:00
parent 3c396b1f7e
commit 5dff9dd404
16 changed files with 350 additions and 227 deletions
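The change applies one pattern across all sixteen files: every auto-logging patch class keeps the bound task in a `_current_task` slot that `update_current_task()` may set to `None`, monkey-patching happens at most once behind a `__patched` flag, and every patched wrapper falls through to the original function when no task is bound. A condensed, runnable sketch of that pattern (all names here are illustrative, not ClearML's):

```python
import types

# stand-in for a third-party module we want to instrument (hypothetical)
some_library = types.SimpleNamespace(do_work=lambda x: x * 2)


class PatchExample(object):
    _current_task = None   # the bound task, or None once Task.close() unbinds us
    _original_fn = None    # the wrapped function, kept for pass-through calls
    __patched = False      # make sure we monkey-patch at most once per process

    @classmethod
    def update_current_task(cls, task):
        cls._current_task = task   # task=None is the new "unbind" path
        if not task:
            return
        if not cls.__patched:
            cls.__patched = True
            cls._patch()

    @classmethod
    def _patch(cls):
        cls._original_fn = some_library.do_work
        some_library.do_work = cls._patched_do_work

    @staticmethod
    def _patched_do_work(*args, **kwargs):
        # once unbound, the wrapper is a transparent pass-through
        if not PatchExample._current_task:
            return PatchExample._original_fn(*args, **kwargs)
        ret = PatchExample._original_fn(*args, **kwargs)
        PatchExample._current_task.records.append(("do_work", args))  # fake "logging"
        return ret


class FakeTask(object):   # a toy stand-in for a clearml Task
    def __init__(self):
        self.records = []


task = FakeTask()
PatchExample.update_current_task(task)   # bind
some_library.do_work(3)                  # intercepted and recorded
PatchExample.update_current_task(None)   # what Task.close() now does
some_library.do_work(3)                  # plain call, nothing recorded
print(task.records)                      # [('do_work', (3,))]
```

The patches are deliberately never removed: undoing a monkey-patch is racy once anyone has captured a reference to the wrapper, so unbinding simply turns every wrapper into a no-op.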

View File

@ -8,12 +8,15 @@ from ..config import running_remotely
class PatchAbsl(object):
_original_DEFINE_flag = None
_original_FLAGS_parse_call = None
_task = None
_current_task = None
__patched = False
@classmethod
def update_current_task(cls, current_task):
cls._task = current_task
cls._patch_absl()
def update_current_task(cls, task):
cls._current_task = task
if not cls.__patched:
cls._patch_absl()
cls.__patched = True
@classmethod
def _patch_absl(cls):
@ -53,7 +56,7 @@ class PatchAbsl(object):
@staticmethod
def _patched_define_flag(*args, **kwargs):
if not PatchAbsl._task or not PatchAbsl._original_DEFINE_flag:
if not PatchAbsl._current_task or not PatchAbsl._original_DEFINE_flag:
if PatchAbsl._original_DEFINE_flag:
return PatchAbsl._original_DEFINE_flag(*args, **kwargs)
else:
@ -73,7 +76,7 @@ class PatchAbsl(object):
# noinspection PyBroadException
try:
if param_name and flag:
param_dict = PatchAbsl._task._arguments.copy_to_dict(
param_dict = PatchAbsl._current_task._arguments.copy_to_dict(
{param_name: flag.value}, prefix=_Arguments._prefix_tf_defines)
flag.value = param_dict.get(param_name, flag.value)
except Exception:
@ -82,13 +85,17 @@ class PatchAbsl(object):
else:
if flag and param_name:
value = flag.value
PatchAbsl._task.update_parameters({_Arguments._prefix_tf_defines + param_name: value}, )
PatchAbsl._current_task.update_parameters(
{_Arguments._prefix_tf_defines + param_name: value},
)
ret = PatchAbsl._original_DEFINE_flag(*args, **kwargs)
return ret
@staticmethod
def _patched_FLAGS_parse_call(self, *args, **kwargs):
ret = PatchAbsl._original_FLAGS_parse_call(self, *args, **kwargs)
if not PatchAbsl._current_task:
return ret
# noinspection PyBroadException
try:
PatchAbsl._update_current_flags(self)
@ -98,13 +105,13 @@ class PatchAbsl(object):
@classmethod
def _update_current_flags(cls, FLAGS):
if not cls._task:
if not cls._current_task:
return
# noinspection PyBroadException
try:
if running_remotely():
param_dict = dict((k, FLAGS[k].value) for k in FLAGS)
param_dict = cls._task._arguments.copy_to_dict(param_dict, prefix=_Arguments._prefix_tf_defines)
param_dict = cls._current_task._arguments.copy_to_dict(param_dict, prefix=_Arguments._prefix_tf_defines)
for k, v in param_dict.items():
# noinspection PyBroadException
try:
@ -127,7 +134,7 @@ class PatchAbsl(object):
param_types = dict([(k, FLAGS[k].flag_type() or None) for k in FLAGS])
except Exception:
param_types = None
cls._task._arguments.copy_from_dict(
cls._current_task._arguments.copy_from_dict(
parameters,
prefix=_Arguments._prefix_tf_defines,
descriptions=descriptions, param_types=param_types,

View File

@ -16,7 +16,7 @@ class PatchClick:
_num_commands = 0
_command_type = 'click.Command'
_section_name = 'Args'
_main_task = None
_current_task = None
__remote_task_params = None
__remote_task_params_dict = {}
__patched = False
@ -26,8 +26,8 @@ class PatchClick:
if Command is None:
return
cls._current_task = task
if task:
cls._main_task = task
PatchClick._update_task_args()
if not cls.__patched:
@ -53,11 +53,11 @@ class PatchClick:
@classmethod
def _update_task_args(cls):
if running_remotely() or not cls._main_task or not cls._args:
if running_remotely() or not cls._current_task or not cls._args:
return
param_val, param_types, param_desc = cls.args()
# noinspection PyProtectedMember
cls._main_task._set_parameters(
cls._current_task._set_parameters(
param_val,
__update=True,
__parameters_descriptions=param_desc,

View File

@ -1,4 +1,5 @@
import os
from time import sleep
import six
@ -7,26 +8,29 @@ from ..utilities.process.mp import BackgroundMonitor
class EnvironmentBind(object):
_task = None
_current_task = None
_environment_section = 'Environment'
__patched = False
@classmethod
def update_current_task(cls, current_task):
cls._task = current_task
def update_current_task(cls, task):
cls._current_task = task
# noinspection PyBroadException
try:
cls._bind_environment()
if not cls.__patched:
cls.__patched = True
cls._bind_environment()
except Exception:
pass
@classmethod
def _bind_environment(cls):
if not cls._task:
if not cls._current_task:
return
# get ENVIRONMENT and put it into the OS environment
if running_remotely():
params = cls._task.get_parameters_as_dict()
params = cls._current_task.get_parameters_as_dict()
if params and cls._environment_section in params:
# put back into os:
os.environ.update(params[cls._environment_section])
@ -55,14 +59,18 @@ class EnvironmentBind(object):
elif match in os.environ:
env_param.update({match: os.environ.get(match)})
# store os environments
cls._task.connect(env_param, cls._environment_section)
cls._current_task.connect(env_param, cls._environment_section)
class PatchOsFork(object):
_original_fork = None
_current_task = None
@classmethod
def patch_fork(cls):
def patch_fork(cls, task):
cls._current_task = task
if not task:
return
# noinspection PyBroadException
try:
# only once
@ -88,11 +96,13 @@ class PatchOsFork(object):
Task._wait_for_deferred(task)
ret = PatchOsFork._original_fork(*args, **kwargs)
if not PatchOsFork._current_task:
return ret
# Make sure the new process stdout is logged
if not ret:
# force creating a Task
task = Task.current_task()
if task is None:
if not task:
return ret
# # Hack: now make sure we setup the reporter threads (Log+Reporter)
@ -106,7 +116,7 @@ class PatchOsFork(object):
# just make sure we flush the internal state (the atexit handler caught by the external signal does the rest)
# in theory we should not have to do any of that, but for some reason if we do not,
# the signal is never caught by the signal callbacks, not sure why....
sleep(0.1)
# Since atexit handlers do not run in forked processes, we have to call them manually here
if task:
try:

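`PatchOsFork` now receives the task explicitly and keeps it in `_current_task`, so the fork hook can be disarmed the same way as the framework patches. A minimal sketch of the shape (the child-process bookkeeping is elided; `PatchForkExample` is illustrative):

```python
import os


class PatchForkExample(object):
    _current_task = None
    _original_fork = None

    @classmethod
    def patch_fork(cls, task):
        # Task.close() calls this with task=None to disarm the hook
        cls._current_task = task
        if not task:
            return
        if cls._original_fork is None and hasattr(os, "fork"):
            cls._original_fork = os.fork
            os.fork = cls._patched_fork

    @staticmethod
    def _patched_fork(*args, **kwargs):
        ret = PatchForkExample._original_fork(*args, **kwargs)
        # unbound: behave exactly like os.fork
        if not PatchForkExample._current_task:
            return ret
        if ret == 0:
            # child process: this is where stdout capture and the reporter
            # threads would be re-attached for the bound task
            pass
        return ret
```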
View File

@ -20,7 +20,7 @@ class PatchFire:
_section_name = "Args"
_args_sep = "/"
_commands_sep = "."
_main_task = None
_current_task = None
__remote_task_params = None
__remote_task_params_dict = {}
__patched = False
@ -38,8 +38,8 @@ class PatchFire:
if fire is None:
return
cls._current_task = task
if task:
cls._main_task = task
cls._update_task_args()
if not cls.__patched:
@ -53,7 +53,7 @@ class PatchFire:
@classmethod
def _update_task_args(cls):
if running_remotely() or not cls._main_task:
if running_remotely() or not cls._current_task:
return
args = {}
parameters_types = {}
@ -118,7 +118,7 @@ class PatchFire:
parameters_types = {**parameters_types, **unused_paramenters_types}
# noinspection PyProtectedMember
cls._main_task._set_parameters(
cls._current_task._set_parameters(
args,
__update=True,
__parameters_types=parameters_types,
@ -126,25 +126,25 @@ class PatchFire:
@staticmethod
def __Fire(original_fn, component, args_, parsed_flag_args, context, name, *args, **kwargs): # noqa
if running_remotely():
command = PatchFire._load_task_params()
if command is not None:
replaced_args = command.split(PatchFire._commands_sep)
else:
replaced_args = []
for param in PatchFire.__remote_task_params[PatchFire._section_name].values():
if command is not None and param.type == PatchFire._command_arg_type_template % command:
replaced_args.append("--" + param.name[len(command + PatchFire._args_sep):])
value = PatchFire.__remote_task_params_dict[param.name]
if len(value) > 0:
replaced_args.append(value)
if param.type == PatchFire._shared_arg_type:
replaced_args.append("--" + param.name)
value = PatchFire.__remote_task_params_dict[param.name]
if len(value) > 0:
replaced_args.append(value)
return original_fn(component, replaced_args, parsed_flag_args, context, name, *args, **kwargs)
return original_fn(component, args_, parsed_flag_args, context, name, *args, **kwargs)
if not running_remotely():
return original_fn(component, args_, parsed_flag_args, context, name, *args, **kwargs)
command = PatchFire._load_task_params()
if command is not None:
replaced_args = command.split(PatchFire._commands_sep)
else:
replaced_args = []
for param in PatchFire.__remote_task_params[PatchFire._section_name].values():
if command is not None and param.type == PatchFire._command_arg_type_template % command:
replaced_args.append("--" + param.name[len(command + PatchFire._args_sep):])
value = PatchFire.__remote_task_params_dict[param.name]
if len(value) > 0:
replaced_args.append(value)
if param.type == PatchFire._shared_arg_type:
replaced_args.append("--" + param.name)
value = PatchFire.__remote_task_params_dict[param.name]
if len(value) > 0:
replaced_args.append(value)
return original_fn(component, replaced_args, parsed_flag_args, context, name, *args, **kwargs)
@staticmethod
def __CallAndUpdateTrace( # noqa

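The `__Fire` wrapper is re-flowed from one big `if running_remotely():` block into an early return, but the remote logic is unchanged: it rebuilds the CLI argument vector from the parameters stored on the task. A simplified, self-contained version of that reconstruction (the real code also tracks per-command and shared parameter types):

```python
def rebuild_argv(command, params, commands_sep=".", args_sep="/"):
    # params maps stored names to string values, e.g. {"train/lr": "0.1"}
    # (an assumed shape; the real code reads them from the remote task)
    argv = command.split(commands_sep) if command else []
    for name, value in params.items():
        if command and name.startswith(command + args_sep):
            argv.append("--" + name[len(command + args_sep):])
        elif args_sep not in name:
            argv.append("--" + name)   # a parameter shared by all commands
        else:
            continue                   # belongs to a different command
        if value:
            argv.append(value)
    return argv


print(rebuild_argv("train", {"train/lr": "0.1", "seed": "7"}))
# ['train', '--lr', '0.1', '--seed', '7']
```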
View File

@ -11,13 +11,15 @@ from ...model import Framework
class PatchCatBoostModelIO(PatchBaseModelIO):
__main_task = None
_current_task = None
__patched = None
__callback_cls = None
@staticmethod
def update_current_task(task, **kwargs):
PatchCatBoostModelIO.__main_task = task
PatchCatBoostModelIO._current_task = task
if not task:
return
PatchCatBoostModelIO._patch_model_io()
PostImportHookPatching.add_on_import("catboost", PatchCatBoostModelIO._patch_model_io)
@ -38,16 +40,16 @@ class PatchCatBoostModelIO(PatchBaseModelIO):
CatBoost.fit = _patched_call(CatBoost.fit, PatchCatBoostModelIO._fit)
CatBoostClassifier.fit = _patched_call(CatBoostClassifier.fit, PatchCatBoostModelIO._fit)
CatBoostRegressor.fit = _patched_call(CatBoostRegressor.fit, PatchCatBoostModelIO._fit)
CatBoostRanker.fit = _patched_call(CatBoostRegressor.fit, PatchCatBoostModelIO._fit)
CatBoostRanker.fit = _patched_call(CatBoostRanker.fit, PatchCatBoostModelIO._fit)
except Exception as e:
logger = PatchCatBoostModelIO.__main_task.get_logger()
logger = PatchCatBoostModelIO._current_task.get_logger()
logger.report_text("Failed patching Catboost. Exception is: '" + str(e) + "'")
@staticmethod
def _save(original_fn, obj, f, *args, **kwargs):
# see https://catboost.ai/en/docs/concepts/python-reference_catboost_save_model
ret = original_fn(obj, f, *args, **kwargs)
if not PatchCatBoostModelIO.__main_task:
if not PatchCatBoostModelIO._current_task:
return ret
if isinstance(f, six.string_types):
filename = f
@ -60,13 +62,16 @@ class PatchCatBoostModelIO(PatchBaseModelIO):
except Exception:
model_name = None
WeightsFileHandler.create_output_model(
obj, filename, Framework.catboost, PatchCatBoostModelIO.__main_task, singlefile=True, model_name=model_name
obj, filename, Framework.catboost, PatchCatBoostModelIO._current_task, singlefile=True, model_name=model_name
)
return ret
@staticmethod
def _load(original_fn, f, *args, **kwargs):
# see https://catboost.ai/en/docs/concepts/python-reference_catboost_load_model
if not PatchCatBoostModelIO._current_task:
return original_fn(f, *args, **kwargs)
if isinstance(f, six.string_types):
filename = f
elif len(args) >= 1 and isinstance(args[0], six.string_types):
@ -74,13 +79,10 @@ class PatchCatBoostModelIO(PatchBaseModelIO):
else:
filename = None
if not PatchCatBoostModelIO.__main_task:
return original_fn(f, *args, **kwargs)
# register input model
empty = _Empty()
model = original_fn(f, *args, **kwargs)
WeightsFileHandler.restore_weights_file(empty, filename, Framework.catboost, PatchCatBoostModelIO.__main_task)
WeightsFileHandler.restore_weights_file(empty, filename, Framework.catboost, PatchCatBoostModelIO._current_task)
if empty.trains_in_model:
# noinspection PyBroadException
try:
@ -91,13 +93,15 @@ class PatchCatBoostModelIO(PatchBaseModelIO):
@staticmethod
def _fit(original_fn, obj, *args, **kwargs):
if not PatchCatBoostModelIO._current_task:
return original_fn(obj, *args, **kwargs)
callbacks = kwargs.get("callbacks") or []
kwargs["callbacks"] = callbacks + [PatchCatBoostModelIO.__callback_cls(task=PatchCatBoostModelIO.__main_task)]
kwargs["callbacks"] = callbacks + [PatchCatBoostModelIO.__callback_cls(task=PatchCatBoostModelIO._current_task)]
# noinspection PyBroadException
try:
return original_fn(obj, *args, **kwargs)
except Exception:
logger = PatchCatBoostModelIO.__main_task.get_logger()
logger = PatchCatBoostModelIO._current_task.get_logger()
logger.report_text(
"Catboost metrics logging is not supported for GPU. "
"See https://github.com/catboost/catboost/issues/1792"

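All of these hooks are installed through the `_patched_call` helper, which explains the unusual `(original_fn, obj, ...)` signatures above: the wrapper hands the original function to the patch as its first argument. A sketch of how such a helper might look (the real one also guards against wrapping the same function twice):

```python
import functools


def _patched_call(original_fn, patched_fn):
    # return a wrapper that passes the original as the first argument
    @functools.wraps(original_fn)
    def _inner(*args, **kwargs):
        return patched_fn(original_fn, *args, **kwargs)
    return _inner


class Model(object):
    def fit(self, x):
        return "fit(%s)" % x


def _fit(original_fn, obj, *args, **kwargs):
    ret = original_fn(obj, *args, **kwargs)
    print("logged:", ret)   # stand-in for reporting to _current_task
    return ret


Model.fit = _patched_call(Model.fit, _fit)
print(Model().fit(3))       # prints "logged: fit(3)" then "fit(3)"
```

This hunk also fixes a standalone bug: `CatBoostRanker.fit` was previously wrapped around `CatBoostRegressor.fit` instead of its own method.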
View File

@ -25,11 +25,9 @@ class PatchFastai(object):
try:
if Version(fastai.__version__) < Version("2.0.0"):
PatchFastaiV1.update_current_task(task)
PatchFastaiV1.patch_model_callback()
PostImportHookPatching.add_on_import("fastai", PatchFastaiV1.patch_model_callback)
else:
PatchFastaiV2.update_current_task(task)
PatchFastaiV2.patch_model_callback()
PostImportHookPatching.add_on_import("fastai", PatchFastaiV2.patch_model_callback)
except Exception:
pass
@ -38,11 +36,17 @@ class PatchFastai(object):
class PatchFastaiV1(object):
__metrics_names = {}
__gradient_hist_helpers = {}
_main_task = None
_current_task = None
__patched = False
@staticmethod
def update_current_task(task, **_):
PatchFastaiV1._main_task = task
PatchFastaiV1._current_task = task
if not task:
return
if not PatchFastaiV1.__patched:
PatchFastaiV1.__patched = True
PatchFastaiV1.patch_model_callback()
@staticmethod
def patch_model_callback():
@ -52,7 +56,6 @@ class PatchFastaiV1(object):
try:
from fastai.basic_train import Recorder
Recorder.on_batch_end = _patched_call(Recorder.on_batch_end, PatchFastaiV1._on_batch_end)
Recorder.on_backward_end = _patched_call(Recorder.on_backward_end, PatchFastaiV1._on_backward_end)
Recorder.on_epoch_end = _patched_call(Recorder.on_epoch_end, PatchFastaiV1._on_epoch_end)
@ -65,7 +68,7 @@ class PatchFastaiV1(object):
@staticmethod
def _on_train_begin(original_fn, recorder, *args, **kwargs):
original_fn(recorder, *args, **kwargs)
if not PatchFastaiV1._main_task:
if not PatchFastaiV1._current_task:
return
# noinspection PyBroadException
try:
@ -84,7 +87,7 @@ class PatchFastaiV1(object):
original_fn(recorder, *args, **kwargs)
if not PatchFastaiV1._main_task:
if not PatchFastaiV1._current_task:
return
# noinspection PyBroadException
@ -112,7 +115,7 @@ class PatchFastaiV1(object):
min_gradient=gradient_stats[:, 5].min(),
)
logger = PatchFastaiV1._main_task.get_logger()
logger = PatchFastaiV1._current_task.get_logger()
iteration = kwargs.get("iteration", 0)
for name, val in stats_report.items():
logger.report_scalar(title="model_stats_gradients", series=name, value=val, iteration=iteration)
@ -122,26 +125,26 @@ class PatchFastaiV1(object):
@staticmethod
def _on_epoch_end(original_fn, recorder, *args, **kwargs):
original_fn(recorder, *args, **kwargs)
if not PatchFastaiV1._main_task:
if not PatchFastaiV1._current_task:
return
# noinspection PyBroadException
try:
logger = PatchFastaiV1._main_task.get_logger()
logger = PatchFastaiV1._current_task.get_logger()
iteration = kwargs.get("iteration")
for series, value in zip(
PatchFastaiV1.__metrics_names[id(recorder)],
[kwargs.get("smooth_loss")] + kwargs.get("last_metrics", []),
):
logger.report_scalar(title="metrics", series=series, value=value, iteration=iteration)
PatchFastaiV1._main_task.flush()
PatchFastaiV1._current_task.flush()
except Exception:
pass
@staticmethod
def _on_batch_end(original_fn, recorder, *args, **kwargs):
original_fn(recorder, *args, **kwargs)
if not PatchFastaiV1._main_task:
if not PatchFastaiV1._current_task:
return
# noinspection PyBroadException
@ -150,7 +153,7 @@ class PatchFastaiV1(object):
if iteration == 0 or not kwargs.get("train"):
return
logger = PatchFastaiV1._main_task.get_logger()
logger = PatchFastaiV1._current_task.get_logger()
logger.report_scalar(
title="metrics",
series="train_loss",
@ -174,11 +177,17 @@ class PatchFastaiV1(object):
class PatchFastaiV2(object):
_main_task = None
_current_task = None
__patched = False
@staticmethod
def update_current_task(task, **_):
PatchFastaiV2._main_task = task
PatchFastaiV2._current_task = task
if not task:
return
if not PatchFastaiV2.__patched:
PatchFastaiV2.__patched = True
PatchFastaiV2.patch_model_callback()
@staticmethod
def patch_model_callback():
@ -212,13 +221,15 @@ class PatchFastaiV2(object):
self.logger = noop
self.__id = str(PatchFastaiV2.PatchFastaiCallbacks.__id)
PatchFastaiV2.PatchFastaiCallbacks.__id += 1
self.__gradient_hist_helper = WeightsGradientHistHelper(PatchFastaiV2._main_task.get_logger())
self.__gradient_hist_helper = WeightsGradientHistHelper(PatchFastaiV2._current_task.get_logger())
def after_batch(self):
# noinspection PyBroadException
try:
super().after_batch() # noqa
logger = PatchFastaiV2._main_task.get_logger()
if not PatchFastaiV2._current_task:
return
logger = PatchFastaiV2._current_task.get_logger()
if not self.training: # noqa
return
self.__train_iter += 1
@ -254,7 +265,9 @@ class PatchFastaiV2(object):
# noinspection PyBroadException
try:
super().after_epoch() # noqa
logger = PatchFastaiV2._main_task.get_logger()
if not PatchFastaiV2._current_task:
return
logger = PatchFastaiV2._current_task.get_logger()
for metric in self._valid_mets: # noqa
logger.report_scalar(
title="metrics_" + self.__id,
@ -270,7 +283,9 @@ class PatchFastaiV2(object):
try:
if hasattr(fastai.learner.Recorder, "before_step"):
super().before_step() # noqa
logger = PatchFastaiV2._main_task.get_logger()
if not PatchFastaiV2._current_task:
return
logger = PatchFastaiV2._current_task.get_logger()
gradients = [
x.grad.clone().detach().cpu() for x in self.learn.model.parameters() if x.grad is not None
] # noqa

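The fastai dispatcher no longer calls `patch_model_callback()` or registers `PostImportHookPatching.add_on_import()` itself; each version class now patches once from its own `update_current_task()`. For readers unfamiliar with the import-hook half of this machinery, a bare-bones sketch of what an `add_on_import` can look like (the real `PostImportHookPatching` is more careful about re-imports and thread safety):

```python
import builtins
import sys

_hooks = {}
_original_import = builtins.__import__


def add_on_import(module_name, callback):
    # fire immediately if the module is already loaded, otherwise defer
    if module_name in sys.modules:
        callback()
        return
    _hooks.setdefault(module_name, []).append(callback)


def _patched_import(name, *args, **kwargs):
    module = _original_import(name, *args, **kwargs)
    for cb in _hooks.pop(name, []):   # run each pending hook exactly once
        cb()
    return module


builtins.__import__ = _patched_import

add_on_import("json", lambda: print("json imported, applying patches"))
import json  # noqa -- triggers (or has already triggered) the hook
```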
View File

@ -11,12 +11,14 @@ from ...model import Framework
class PatchLIGHTgbmModelIO(PatchBaseModelIO):
__main_task = None
_current_task = None
__patched = None
@staticmethod
def update_current_task(task, **kwargs):
PatchLIGHTgbmModelIO.__main_task = task
PatchLIGHTgbmModelIO._current_task = task
if not task:
return
PatchLIGHTgbmModelIO._patch_model_io()
PostImportHookPatching.add_on_import('lightgbm', PatchLIGHTgbmModelIO._patch_model_io)
@ -43,7 +45,7 @@ class PatchLIGHTgbmModelIO(PatchBaseModelIO):
@staticmethod
def _save(original_fn, obj, f, *args, **kwargs):
ret = original_fn(obj, f, *args, **kwargs)
if not PatchLIGHTgbmModelIO.__main_task:
if not PatchLIGHTgbmModelIO._current_task:
return ret
if isinstance(f, six.string_types):
@ -64,12 +66,15 @@ class PatchLIGHTgbmModelIO(PatchBaseModelIO):
model_name = Path(filename).stem
except Exception:
model_name = None
WeightsFileHandler.create_output_model(obj, filename, Framework.lightgbm, PatchLIGHTgbmModelIO.__main_task,
WeightsFileHandler.create_output_model(obj, filename, Framework.lightgbm, PatchLIGHTgbmModelIO._current_task,
singlefile=True, model_name=model_name)
return ret
@staticmethod
def _load(original_fn, model_file, *args, **kwargs):
if not PatchLIGHTgbmModelIO._current_task:
return original_fn(model_file, *args, **kwargs)
if isinstance(model_file, six.string_types):
filename = model_file
elif hasattr(model_file, 'name'):
@ -79,21 +84,18 @@ class PatchLIGHTgbmModelIO(PatchBaseModelIO):
else:
filename = None
if not PatchLIGHTgbmModelIO.__main_task:
return original_fn(model_file, *args, **kwargs)
# register input model
empty = _Empty()
# Hack: disabled
if False and running_remotely():
filename = WeightsFileHandler.restore_weights_file(empty, filename, Framework.xgboost,
PatchLIGHTgbmModelIO.__main_task)
PatchLIGHTgbmModelIO._current_task)
model = original_fn(model_file=filename or model_file, *args, **kwargs)
else:
# try to load model before registering, in case we fail
model = original_fn(model_file=model_file, *args, **kwargs)
WeightsFileHandler.restore_weights_file(empty, filename, Framework.lightgbm,
PatchLIGHTgbmModelIO.__main_task)
PatchLIGHTgbmModelIO._current_task)
if empty.trains_in_model:
# noinspection PyBroadException
@ -110,7 +112,7 @@ class PatchLIGHTgbmModelIO(PatchBaseModelIO):
# logging the results to scalars section
# noinspection PyBroadException
try:
logger = PatchLIGHTgbmModelIO.__main_task.get_logger()
logger = PatchLIGHTgbmModelIO._current_task.get_logger()
iteration = env.iteration
for data_title, data_series, value, _ in env.evaluation_result_list:
logger.report_scalar(title=data_title, series=data_series, value="{:.6f}".format(value),
@ -121,12 +123,12 @@ class PatchLIGHTgbmModelIO(PatchBaseModelIO):
kwargs.setdefault("callbacks", []).append(trains_lightgbm_callback())
ret = original_fn(*args, **kwargs)
if not PatchLIGHTgbmModelIO.__main_task:
if not PatchLIGHTgbmModelIO._current_task:
return ret
params = args[0] if args else kwargs.get('params', {})
for k, v in params.items():
if isinstance(v, set):
params[k] = list(v)
if params:
PatchLIGHTgbmModelIO.__main_task.connect(params)
PatchLIGHTgbmModelIO._current_task.connect(params)
return ret

View File

@ -10,14 +10,14 @@ from ...model import Framework
class PatchMegEngineModelIO(PatchBaseModelIO):
__main_task = None
_current_task = None
__patched = None
# __patched_lightning = None
@staticmethod
def update_current_task(task, **_):
PatchMegEngineModelIO.__main_task = task
PatchMegEngineModelIO._current_task = task
if not task:
return
PatchMegEngineModelIO._patch_model_io()
PostImportHookPatching.add_on_import(
'megengine', PatchMegEngineModelIO._patch_model_io
@ -58,7 +58,7 @@ class PatchMegEngineModelIO(PatchBaseModelIO):
ret = original_fn(obj, f, *args, **kwargs)
# if there is no main task or this is a nested call
if not PatchMegEngineModelIO.__main_task:
if not PatchMegEngineModelIO._current_task:
return ret
# noinspection PyBroadException
@ -93,7 +93,7 @@ class PatchMegEngineModelIO(PatchBaseModelIO):
WeightsFileHandler.create_output_model(
obj, filename, Framework.megengine,
PatchMegEngineModelIO.__main_task,
PatchMegEngineModelIO._current_task,
singlefile=True, model_name=model_name,
)
@ -102,7 +102,7 @@ class PatchMegEngineModelIO(PatchBaseModelIO):
@staticmethod
def _load(original_fn, f, *args, **kwargs):
# if there is no main task or this is a nested call
if not PatchMegEngineModelIO.__main_task:
if not PatchMegEngineModelIO._current_task:
return original_fn(f, *args, **kwargs)
# noinspection PyBroadException
@ -124,7 +124,7 @@ class PatchMegEngineModelIO(PatchBaseModelIO):
model = original_fn(f, *args, **kwargs)
WeightsFileHandler.restore_weights_file(
empty, filename, Framework.megengine,
PatchMegEngineModelIO.__main_task
PatchMegEngineModelIO._current_task
)
if empty.trains_in_model:
@ -139,7 +139,7 @@ class PatchMegEngineModelIO(PatchBaseModelIO):
@staticmethod
def _load_from_obj(original_fn, obj, f, *args, **kwargs):
# if there is no main task or this is a nested call
if not PatchMegEngineModelIO.__main_task:
if not PatchMegEngineModelIO._current_task:
return original_fn(obj, f, *args, **kwargs)
# noinspection PyBroadException
@ -161,7 +161,7 @@ class PatchMegEngineModelIO(PatchBaseModelIO):
model = original_fn(obj, f, *args, **kwargs)
WeightsFileHandler.restore_weights_file(
empty, filename, Framework.megengine,
PatchMegEngineModelIO.__main_task,
PatchMegEngineModelIO._current_task,
)
if empty.trains_in_model:

View File

@ -11,13 +11,15 @@ from ...model import Framework
class PatchPyTorchModelIO(PatchBaseModelIO):
__main_task = None
_current_task = None
__patched = None
__patched_lightning = None
@staticmethod
def update_current_task(task, **_):
PatchPyTorchModelIO.__main_task = task
PatchPyTorchModelIO._current_task = task
if not task:
return
PatchPyTorchModelIO._patch_model_io()
PatchPyTorchModelIO._patch_lightning_io()
PostImportHookPatching.add_on_import('torch', PatchPyTorchModelIO._patch_model_io)
@ -112,7 +114,7 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
ret = original_fn(obj, f, *args, **kwargs)
# if there is no main task or this is a nested call
if not PatchPyTorchModelIO.__main_task:
if not PatchPyTorchModelIO._current_task:
return ret
# pytorch-lightning check if rank is zero
@ -154,14 +156,14 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
model_name = None
WeightsFileHandler.create_output_model(
obj, filename, Framework.pytorch, PatchPyTorchModelIO.__main_task, singlefile=True, model_name=model_name)
obj, filename, Framework.pytorch, PatchPyTorchModelIO._current_task, singlefile=True, model_name=model_name)
return ret
@staticmethod
def _load(original_fn, f, *args, **kwargs):
# if there is no main task or this is a nested call
if not PatchPyTorchModelIO.__main_task:
if not PatchPyTorchModelIO._current_task:
return original_fn(f, *args, **kwargs)
# noinspection PyBroadException
@ -182,13 +184,13 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
# Hack: disabled
if False and running_remotely():
filename = WeightsFileHandler.restore_weights_file(
empty, filename, Framework.pytorch, PatchPyTorchModelIO.__main_task)
empty, filename, Framework.pytorch, PatchPyTorchModelIO._current_task)
model = original_fn(filename or f, *args, **kwargs)
else:
# try to load model before registering, in case we fail
model = original_fn(f, *args, **kwargs)
WeightsFileHandler.restore_weights_file(
empty, filename, Framework.pytorch, PatchPyTorchModelIO.__main_task)
empty, filename, Framework.pytorch, PatchPyTorchModelIO._current_task)
if empty.trains_in_model:
# noinspection PyBroadException
@ -202,7 +204,7 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
@staticmethod
def _load_from_obj(original_fn, obj, f, *args, **kwargs):
# if there is no main task or this is a nested call
if not PatchPyTorchModelIO.__main_task:
if not PatchPyTorchModelIO._current_task:
return original_fn(obj, f, *args, **kwargs)
# noinspection PyBroadException
@ -223,13 +225,13 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
# Hack: disabled
if False and running_remotely():
filename = WeightsFileHandler.restore_weights_file(
empty, filename, Framework.pytorch, PatchPyTorchModelIO.__main_task)
empty, filename, Framework.pytorch, PatchPyTorchModelIO._current_task)
model = original_fn(obj, filename or f, *args, **kwargs)
else:
# try to load model before registering, in case we fail
model = original_fn(obj, f, *args, **kwargs)
WeightsFileHandler.restore_weights_file(
empty, filename, Framework.pytorch, PatchPyTorchModelIO.__main_task)
empty, filename, Framework.pytorch, PatchPyTorchModelIO._current_task)
if empty.trains_in_model:
# noinspection PyBroadException

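The `_load` hooks in this file (and the lightgbm/xgboost ones nearby) all share the same ordering: bail out early when unbound, load first, register the input model second, so a corrupt checkpoint raises before anything is recorded. Distilled (`CURRENT_TASK` and `register_input_model` are stand-ins for the class attribute and `WeightsFileHandler.restore_weights_file`):

```python
CURRENT_TASK = None   # stand-in for PatchPyTorchModelIO._current_task


def register_input_model(filename):
    # hypothetical stand-in for WeightsFileHandler.restore_weights_file(...)
    print("registered input model:", filename)


def _load(original_fn, f, *args, **kwargs):
    # new in this commit: pass through untouched when no task is bound
    if not CURRENT_TASK:
        return original_fn(f, *args, **kwargs)
    # try to load the model before registering it, in case loading fails
    model = original_fn(f, *args, **kwargs)
    register_input_model(f)
    return model


print(_load(lambda f: "model<%s>" % f, "weights.pt"))   # model<weights.pt>
```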
View File

@ -232,7 +232,7 @@ class EventTrainsWriter(object):
TF SummaryWriter implementation that converts the Tensorboard summary into
ClearML events and reports the events (metrics) for a ClearML task (logger).
"""
__main_task = None
_current_task = None
__report_hparams = True
_add_lock = threading.RLock()
_series_name_lookup = {}
@ -641,7 +641,7 @@ class EventTrainsWriter(object):
session_start_info = parse_session_start_info_plugin_data(content)
session_start_info = MessageToDict(session_start_info)
hparams = session_start_info["hparams"]
EventTrainsWriter.__main_task.update_parameters(
EventTrainsWriter._current_task.update_parameters(
{"TB_hparams/{}".format(k): v for k, v in hparams.items()}
)
except Exception:
@ -844,13 +844,13 @@ class EventTrainsWriter(object):
@classmethod
def update_current_task(cls, task, **kwargs):
cls.__report_hparams = kwargs.get('report_hparams', False)
if cls.__main_task != task:
if cls._current_task != task:
with cls._add_lock:
cls._series_name_lookup = {}
cls._title_series_writers_lookup = {}
cls._event_writers_id_to_logdir = {}
cls._title_series_wraparound_counter = {}
cls.__main_task = task
cls._current_task = task
# noinspection PyCallingNonCallable
@ -905,9 +905,10 @@ class ProxyEventsWriter(object):
# noinspection PyPep8Naming
class PatchSummaryToEventTransformer(object):
__main_task = None
_current_task = None
__original_getattribute = None
__original_getattributeX = None
__patched = False
_original_add_event = None
_original_add_eventT = None
_original_add_eventX = None
@ -930,15 +931,19 @@ class PatchSummaryToEventTransformer(object):
@staticmethod
def update_current_task(task, **kwargs):
PatchSummaryToEventTransformer.defaults_dict.update(kwargs)
PatchSummaryToEventTransformer.__main_task = task
# make sure we patched the SummaryToEventTransformer
PatchSummaryToEventTransformer._patch_summary_to_event_transformer()
PostImportHookPatching.add_on_import('tensorflow',
PatchSummaryToEventTransformer._patch_summary_to_event_transformer)
PostImportHookPatching.add_on_import('torch',
PatchSummaryToEventTransformer._patch_summary_to_event_transformer)
PostImportHookPatching.add_on_import('tensorboardX',
PatchSummaryToEventTransformer._patch_summary_to_event_transformer)
PatchSummaryToEventTransformer._current_task = task
if not task:
return
if not PatchSummaryToEventTransformer.__patched:
PatchSummaryToEventTransformer.__patched = True
# make sure we patched the SummaryToEventTransformer
PatchSummaryToEventTransformer._patch_summary_to_event_transformer()
PostImportHookPatching.add_on_import('tensorflow',
PatchSummaryToEventTransformer._patch_summary_to_event_transformer)
PostImportHookPatching.add_on_import('torch',
PatchSummaryToEventTransformer._patch_summary_to_event_transformer)
PostImportHookPatching.add_on_import('tensorboardX',
PatchSummaryToEventTransformer._patch_summary_to_event_transformer)
@staticmethod
def _patch_summary_to_event_transformer():
@ -1002,7 +1007,7 @@ class PatchSummaryToEventTransformer(object):
@staticmethod
def _patched_add_eventT(self, *args, **kwargs):
if not hasattr(self, 'clearml') or not PatchSummaryToEventTransformer.__main_task:
if not hasattr(self, 'clearml') or not PatchSummaryToEventTransformer._current_task:
return PatchSummaryToEventTransformer._original_add_eventT(self, *args, **kwargs)
if not self.clearml: # noqa
# noinspection PyBroadException
@ -1010,7 +1015,7 @@ class PatchSummaryToEventTransformer(object):
logdir = self.get_logdir()
except Exception:
logdir = None
self.clearml = EventTrainsWriter(PatchSummaryToEventTransformer.__main_task.get_logger(),
self.clearml = EventTrainsWriter(PatchSummaryToEventTransformer._current_task.get_logger(),
logdir=logdir, **PatchSummaryToEventTransformer.defaults_dict)
# noinspection PyBroadException
try:
@ -1021,7 +1026,7 @@ class PatchSummaryToEventTransformer(object):
@staticmethod
def _patched_add_eventX(self, *args, **kwargs):
if not hasattr(self, 'clearml') or not PatchSummaryToEventTransformer.__main_task:
if not hasattr(self, 'clearml') or not PatchSummaryToEventTransformer._current_task:
return PatchSummaryToEventTransformer._original_add_eventX(self, *args, **kwargs)
if not self.clearml:
# noinspection PyBroadException
@ -1029,7 +1034,7 @@ class PatchSummaryToEventTransformer(object):
logdir = self.get_logdir()
except Exception:
logdir = None
self.clearml = EventTrainsWriter(PatchSummaryToEventTransformer.__main_task.get_logger(),
self.clearml = EventTrainsWriter(PatchSummaryToEventTransformer._current_task.get_logger(),
logdir=logdir, **PatchSummaryToEventTransformer.defaults_dict)
# noinspection PyBroadException
try:
@ -1051,7 +1056,7 @@ class PatchSummaryToEventTransformer(object):
@staticmethod
def _patched_getattribute_(self, attr, get_base):
# no main task, zero chance we have a ClearML event logger
if PatchSummaryToEventTransformer.__main_task is None:
if PatchSummaryToEventTransformer._current_task is None:
return get_base(self, attr)
# check if we already have a ClearML event logger
@ -1068,7 +1073,7 @@ class PatchSummaryToEventTransformer(object):
except Exception:
logdir = None
defaults_dict = __dict__.get('_trains_defaults') or PatchSummaryToEventTransformer.defaults_dict
trains_event = EventTrainsWriter(PatchSummaryToEventTransformer.__main_task.get_logger(),
trains_event = EventTrainsWriter(PatchSummaryToEventTransformer._current_task.get_logger(),
logdir=logdir, **defaults_dict)
# order is important, the return value of ProxyEventsWriter is the last object in the list
@ -1111,8 +1116,9 @@ class _ModelAdapter(object):
class PatchModelCheckPointCallback(object):
__main_task = None
_current_task = None
__original_getattribute = None
__patched = False
defaults_dict = dict(
config_text=None,
config_dict=None,
@ -1132,11 +1138,15 @@ class PatchModelCheckPointCallback(object):
@staticmethod
def update_current_task(task, **kwargs):
PatchModelCheckPointCallback.defaults_dict.update(kwargs)
PatchModelCheckPointCallback.__main_task = task
# make sure we patched the SummaryToEventTransformer
PatchModelCheckPointCallback._patch_model_checkpoint()
PostImportHookPatching.add_on_import('keras', PatchModelCheckPointCallback._patch_model_checkpoint)
PostImportHookPatching.add_on_import('tensorflow', PatchModelCheckPointCallback._patch_model_checkpoint)
PatchModelCheckPointCallback._current_task = task
if not task:
return
if not PatchModelCheckPointCallback.__patched:
PatchModelCheckPointCallback.__patched = True
# make sure we patched the SummaryToEventTransformer
PatchModelCheckPointCallback._patch_model_checkpoint()
PostImportHookPatching.add_on_import('keras', PatchModelCheckPointCallback._patch_model_checkpoint)
PostImportHookPatching.add_on_import('tensorflow', PatchModelCheckPointCallback._patch_model_checkpoint)
@staticmethod
def _patch_model_checkpoint():
@ -1176,7 +1186,7 @@ class PatchModelCheckPointCallback(object):
get_base = PatchModelCheckPointCallback.__original_getattribute
# no main task, zero chance we have a ClearML event logger
if PatchModelCheckPointCallback.__main_task is None:
if PatchModelCheckPointCallback._current_task is None:
return get_base(self, attr)
# check if we already have a ClearML event logger
@ -1189,17 +1199,17 @@ class PatchModelCheckPointCallback(object):
base_model = __dict__['model']
defaults_dict = __dict__.get('_trains_defaults') or PatchModelCheckPointCallback.defaults_dict
output_model = OutputModel(
PatchModelCheckPointCallback.__main_task,
PatchModelCheckPointCallback._current_task,
config_text=defaults_dict.get('config_text'),
config_dict=defaults_dict.get('config_dict'),
name=defaults_dict.get('name'),
comment=defaults_dict.get('comment'),
label_enumeration=defaults_dict.get('label_enumeration') or
PatchModelCheckPointCallback.__main_task.get_labels_enumeration(),
PatchModelCheckPointCallback._current_task.get_labels_enumeration(),
framework=Framework.keras,
)
output_model.set_upload_destination(
PatchModelCheckPointCallback.__main_task.get_output_destination(raise_on_error=False))
PatchModelCheckPointCallback._current_task.get_output_destination(raise_on_error=False))
trains_model = _ModelAdapter(base_model, output_model)
# order is important, the return value of ProxyEventsWriter is the last object in the list
@ -1209,26 +1219,32 @@ class PatchModelCheckPointCallback(object):
# noinspection PyProtectedMember,PyUnresolvedReferences
class PatchTensorFlowEager(object):
__main_task = None
_current_task = None
__original_fn_scalar = None
__original_fn_hist = None
__original_fn_image = None
__original_fn_write_summary = None
__trains_event_writer = {}
__tf_tb_writer_id_to_logdir = {}
__patched = False
defaults_dict = dict(
report_freq=1, image_report_freq=1, histogram_update_freq_multiplier=5,
histogram_granularity=50)
@staticmethod
def update_current_task(task, **kwargs):
if task != PatchTensorFlowEager.__main_task:
if task != PatchTensorFlowEager._current_task:
PatchTensorFlowEager.__trains_event_writer = {}
PatchTensorFlowEager.defaults_dict.update(kwargs)
PatchTensorFlowEager.__main_task = task
# make sure we patched the SummaryToEventTransformer
PatchTensorFlowEager._patch_summary_ops()
PostImportHookPatching.add_on_import('tensorflow', PatchTensorFlowEager._patch_summary_ops)
PatchTensorFlowEager._current_task = task
if not task:
return
if not PatchTensorFlowEager.__patched:
PatchTensorFlowEager.__patched = True
# make sure we patched the SummaryToEventTransformer
PatchTensorFlowEager._patch_summary_ops()
PostImportHookPatching.add_on_import('tensorflow', PatchTensorFlowEager._patch_summary_ops)
@staticmethod
def _patch_summary_ops():
@ -1245,7 +1261,7 @@ class PatchTensorFlowEager(object):
gen_summary_ops.write_image_summary = PatchTensorFlowEager._write_image_summary
PatchTensorFlowEager.__original_fn_hist = gen_summary_ops.write_histogram_summary
gen_summary_ops.write_histogram_summary = PatchTensorFlowEager._write_hist_summary
PatchTensorFlowEager.__write_summary = gen_summary_ops.write_summary
PatchTensorFlowEager.__original_fn_write_summary = gen_summary_ops.write_summary
gen_summary_ops.write_summary = PatchTensorFlowEager._write_summary
gen_summary_ops.create_summary_file_writer = partial(IsTensorboardInit._patched_tb__init__,
gen_summary_ops.create_summary_file_writer)
@ -1270,6 +1286,8 @@ class PatchTensorFlowEager(object):
@staticmethod
def _create_summary_file_writer(original_fn, *args, **kwargs):
if not PatchTensorFlowEager._current_task:
return original_fn(*args, **kwargs)
# noinspection PyBroadException
try:
a_logdir = None
@ -1294,7 +1312,7 @@ class PatchTensorFlowEager(object):
@staticmethod
def _get_event_writer(writer):
if not PatchTensorFlowEager.__main_task:
if not PatchTensorFlowEager._current_task:
return None
if not PatchTensorFlowEager.__trains_event_writer.get(id(writer)):
# noinspection PyBroadException
@ -1324,7 +1342,7 @@ class PatchTensorFlowEager(object):
logdir = None
PatchTensorFlowEager.__trains_event_writer[id(writer)] = EventTrainsWriter(
logger=PatchTensorFlowEager.__main_task.get_logger(), logdir=logdir,
logger=PatchTensorFlowEager._current_task.get_logger(), logdir=logdir,
**PatchTensorFlowEager.defaults_dict)
return PatchTensorFlowEager.__trains_event_writer[id(writer)]
@ -1337,6 +1355,10 @@ class PatchTensorFlowEager(object):
@staticmethod
def _write_summary(writer, step, tensor, tag, summary_metadata, name=None, **kwargs):
if not PatchTensorFlowEager._current_task:
return PatchTensorFlowEager.__original_fn_write_summary(
writer, step, tensor, tag, summary_metadata, name, **kwargs
)
event_writer = PatchTensorFlowEager._get_event_writer(writer)
# make sure we can get the tensors values
if event_writer and isinstance(step, int) or hasattr(step, 'numpy'):
@ -1370,13 +1392,17 @@ class PatchTensorFlowEager(object):
step=int(step.numpy()) if not isinstance(step, int) else step,
values=None, audio_data=audio_bytes)
else:
pass # print('unsupported plugin_type', plugin_type)
pass
except Exception:
pass
return PatchTensorFlowEager.__write_summary(writer, step, tensor, tag, summary_metadata, name, **kwargs)
return PatchTensorFlowEager.__original_fn_write_summary(
writer, step, tensor, tag, summary_metadata, name, **kwargs
)
@staticmethod
def _write_scalar_summary(writer, step, tag, value, name=None, **kwargs):
if not PatchTensorFlowEager._current_task:
return PatchTensorFlowEager.__original_fn_scalar(writer, step, tag, value, name, **kwargs)
event_writer = PatchTensorFlowEager._get_event_writer(writer)
if event_writer and isinstance(step, int) or hasattr(step, 'numpy'):
try:
@ -1417,6 +1443,8 @@ class PatchTensorFlowEager(object):
@staticmethod
def _write_hist_summary(writer, step, tag, values, name, **kwargs):
if not PatchTensorFlowEager._current_task:
return PatchTensorFlowEager.__original_fn_hist(writer, step, tag, values, name, **kwargs)
event_writer = PatchTensorFlowEager._get_event_writer(writer)
if event_writer and isinstance(step, int) or hasattr(step, 'numpy'):
try:
@ -1459,6 +1487,9 @@ class PatchTensorFlowEager(object):
@staticmethod
def _write_image_summary(writer, step, tag, tensor, bad_color, max_images, name, **kwargs):
if not PatchTensorFlowEager._current_task:
return PatchTensorFlowEager.__original_fn_image(
writer, step, tag, tensor, bad_color, max_images, name, **kwargs)
event_writer = PatchTensorFlowEager._get_event_writer(writer)
if event_writer and isinstance(step, int) or hasattr(step, 'numpy'):
try:
@ -1526,13 +1557,15 @@ class PatchTensorFlowEager(object):
# noinspection PyPep8Naming,SpellCheckingInspection
class PatchKerasModelIO(object):
__main_task = None
_current_task = None
__patched_keras = None
__patched_tensorflow = None
@staticmethod
def update_current_task(task, **_):
PatchKerasModelIO.__main_task = task
PatchKerasModelIO._current_task = task
if not task:
return
PatchKerasModelIO._patch_model_checkpoint()
PostImportHookPatching.add_on_import('tensorflow', PatchKerasModelIO._patch_model_checkpoint)
PostImportHookPatching.add_on_import('keras', PatchKerasModelIO._patch_model_checkpoint)
@ -1692,7 +1725,7 @@ class PatchKerasModelIO(object):
def _updated_config(original_fn, self):
config = original_fn(self)
# check if we have main task
if PatchKerasModelIO.__main_task is None:
if PatchKerasModelIO._current_task is None:
return config
try:
@ -1717,10 +1750,10 @@ class PatchKerasModelIO(object):
else:
# todo: support multiple models for the same task
self.trains_out_model.append(OutputModel(
task=PatchKerasModelIO.__main_task,
task=PatchKerasModelIO._current_task,
config_dict=safe_config,
name=PatchKerasModelIO.__main_task.name + ' ' + model_name_id,
label_enumeration=PatchKerasModelIO.__main_task.get_labels_enumeration(),
name=PatchKerasModelIO._current_task.name + ' ' + model_name_id,
label_enumeration=PatchKerasModelIO._current_task.get_labels_enumeration(),
framework=Framework.keras,
))
except Exception as ex:
@ -1738,7 +1771,7 @@ class PatchKerasModelIO(object):
self = _Empty()
# check if we have main task
if PatchKerasModelIO.__main_task is None:
if PatchKerasModelIO._current_task is None:
return self
try:
@ -1751,10 +1784,10 @@ class PatchKerasModelIO(object):
# check if object already has InputModel
self.trains_in_model = InputModel.empty(
config_dict=config_dict,
label_enumeration=PatchKerasModelIO.__main_task.get_labels_enumeration(),
label_enumeration=PatchKerasModelIO._current_task.get_labels_enumeration(),
)
# todo: support multiple models for the same task
PatchKerasModelIO.__main_task.connect(self.trains_in_model)
PatchKerasModelIO._current_task.connect(self.trains_in_model)
# if we are running remotely we should deserialize the object
# because someone might have changed the configuration
# Hack: disabled
@ -1781,7 +1814,7 @@ class PatchKerasModelIO(object):
@staticmethod
def _load_weights(original_fn, self, *args, **kwargs):
# check if we have main task
if PatchKerasModelIO.__main_task is None:
if PatchKerasModelIO._current_task is None:
return original_fn(self, *args, **kwargs)
# get filepath
@ -1794,7 +1827,7 @@ class PatchKerasModelIO(object):
if False and running_remotely():
# register/load model weights
filepath = WeightsFileHandler.restore_weights_file(self, filepath, Framework.keras,
PatchKerasModelIO.__main_task)
PatchKerasModelIO._current_task)
if 'filepath' in kwargs:
kwargs['filepath'] = filepath
else:
@ -1805,11 +1838,13 @@ class PatchKerasModelIO(object):
# try to load the files; if something goes wrong, an exception will be raised before we register the file
model = original_fn(self, *args, **kwargs)
# register/load model weights
WeightsFileHandler.restore_weights_file(self, filepath, Framework.keras, PatchKerasModelIO.__main_task)
WeightsFileHandler.restore_weights_file(self, filepath, Framework.keras, PatchKerasModelIO._current_task)
return model
@staticmethod
def _save(original_fn, self, *args, **kwargs):
if not PatchKerasModelIO._current_task:
return original_fn(self, *args, **kwargs)
if hasattr(self, 'trains_out_model') and self.trains_out_model:
# noinspection PyProtectedMember
self.trains_out_model[-1]._processed = False
@ -1823,12 +1858,13 @@ class PatchKerasModelIO(object):
@staticmethod
def _save_weights(original_fn, self, *args, **kwargs):
original_fn(self, *args, **kwargs)
PatchKerasModelIO._update_outputmodel(self, *args, **kwargs)
if PatchKerasModelIO._current_task:
PatchKerasModelIO._update_outputmodel(self, *args, **kwargs)
@staticmethod
def _update_outputmodel(self, *args, **kwargs):
# check if we have main task
if PatchKerasModelIO.__main_task is None:
if not PatchKerasModelIO._current_task:
return
try:
@ -1851,7 +1887,7 @@ class PatchKerasModelIO(object):
if filepath:
WeightsFileHandler.create_output_model(
self, filepath, Framework.keras, PatchKerasModelIO.__main_task,
self, filepath, Framework.keras, PatchKerasModelIO._current_task,
config_obj=config or None, singlefile=True)
except Exception as ex:
@ -1860,12 +1896,12 @@ class PatchKerasModelIO(object):
@staticmethod
def _save_model(original_fn, model, filepath, *args, **kwargs):
original_fn(model, filepath, *args, **kwargs)
if PatchKerasModelIO.__main_task:
if PatchKerasModelIO._current_task:
PatchKerasModelIO._update_outputmodel(model, filepath)
@staticmethod
def _load_model(original_fn, filepath, *args, **kwargs):
if not PatchKerasModelIO.__main_task:
if not PatchKerasModelIO._current_task:
return original_fn(filepath, *args, **kwargs)
empty = _Empty()
@ -1873,12 +1909,12 @@ class PatchKerasModelIO(object):
if False and running_remotely():
# register/load model weights
filepath = WeightsFileHandler.restore_weights_file(empty, filepath, Framework.keras,
PatchKerasModelIO.__main_task)
PatchKerasModelIO._current_task)
model = original_fn(filepath, *args, **kwargs)
else:
model = original_fn(filepath, *args, **kwargs)
# register/load model weights
WeightsFileHandler.restore_weights_file(empty, filepath, Framework.keras, PatchKerasModelIO.__main_task)
WeightsFileHandler.restore_weights_file(empty, filepath, Framework.keras, PatchKerasModelIO._current_task)
# update the input model object
if empty.trains_in_model:
# noinspection PyBroadException
@ -1891,12 +1927,14 @@ class PatchKerasModelIO(object):
class PatchTensorflowModelIO(object):
__main_task = None
_current_task = None
__patched = None
@staticmethod
def update_current_task(task, **_):
PatchTensorflowModelIO.__main_task = task
PatchTensorflowModelIO._current_task = task
if not task:
return
PatchTensorflowModelIO._patch_model_checkpoint()
PostImportHookPatching.add_on_import('tensorflow', PatchTensorflowModelIO._patch_model_checkpoint)
@ -2028,29 +2066,31 @@ class PatchTensorflowModelIO(object):
@staticmethod
def _save(original_fn, self, sess, save_path, *args, **kwargs):
saved_path = original_fn(self, sess, save_path, *args, **kwargs)
if not saved_path:
if not saved_path or not PatchTensorflowModelIO._current_task:
return saved_path
# store output Model
return WeightsFileHandler.create_output_model(self, saved_path, Framework.tensorflow,
PatchTensorflowModelIO.__main_task)
PatchTensorflowModelIO._current_task)
@staticmethod
def _save_model(original_fn, obj, export_dir, *args, **kwargs):
original_fn(obj, export_dir, *args, **kwargs)
if not PatchKerasModelIO._current_task:
return
# store output Model
WeightsFileHandler.create_output_model(obj, export_dir, Framework.tensorflow,
PatchTensorflowModelIO.__main_task)
PatchTensorflowModelIO._current_task)
@staticmethod
def _restore(original_fn, self, sess, save_path, *args, **kwargs):
if PatchTensorflowModelIO.__main_task is None:
if PatchTensorflowModelIO._current_task is None:
return original_fn(self, sess, save_path, *args, **kwargs)
# Hack: disabled
if False and running_remotely():
# register/load model weights
save_path = WeightsFileHandler.restore_weights_file(self, save_path, Framework.tensorflow,
PatchTensorflowModelIO.__main_task)
PatchTensorflowModelIO._current_task)
# load model
return original_fn(self, sess, save_path, *args, **kwargs)
@ -2058,12 +2098,12 @@ class PatchTensorflowModelIO(object):
model = original_fn(self, sess, save_path, *args, **kwargs)
# register/load model weights
WeightsFileHandler.restore_weights_file(self, save_path, Framework.tensorflow,
PatchTensorflowModelIO.__main_task)
PatchTensorflowModelIO._current_task)
return model
@staticmethod
def _load_lt_2_0(original_fn, sess, tags=None, export_dir=None, *args, **saver_kwargs):
if PatchTensorflowModelIO.__main_task is None:
if PatchTensorflowModelIO._current_task is None:
return original_fn(sess, tags, export_dir, *args, **saver_kwargs)
# register input model
@ -2072,7 +2112,7 @@ class PatchTensorflowModelIO(object):
if False and running_remotely():
export_dir = WeightsFileHandler.restore_weights_file(
empty, export_dir, Framework.tensorflow,
PatchTensorflowModelIO.__main_task
PatchTensorflowModelIO._current_task
)
model = original_fn(sess, tags, export_dir, *args, **saver_kwargs)
else:
@ -2080,7 +2120,7 @@ class PatchTensorflowModelIO(object):
model = original_fn(sess, tags, export_dir, *args, **saver_kwargs)
WeightsFileHandler.restore_weights_file(
empty, export_dir, Framework.tensorflow,
PatchTensorflowModelIO.__main_task
PatchTensorflowModelIO._current_task
)
if empty.trains_in_model:
@ -2093,7 +2133,7 @@ class PatchTensorflowModelIO(object):
@staticmethod
def _load(original_fn, export_dir, *args, **saver_kwargs):
if PatchTensorflowModelIO.__main_task is None:
if PatchTensorflowModelIO._current_task is None:
return original_fn(export_dir, *args, **saver_kwargs)
# register input model
@ -2102,7 +2142,7 @@ class PatchTensorflowModelIO(object):
if False and running_remotely():
export_dir = WeightsFileHandler.restore_weights_file(
empty, export_dir, Framework.tensorflow,
PatchTensorflowModelIO.__main_task
PatchTensorflowModelIO._current_task
)
model = original_fn(export_dir, *args, **saver_kwargs)
else:
@ -2110,7 +2150,7 @@ class PatchTensorflowModelIO(object):
model = original_fn(export_dir, *args, **saver_kwargs)
WeightsFileHandler.restore_weights_file(
empty, export_dir, Framework.tensorflow,
PatchTensorflowModelIO.__main_task
PatchTensorflowModelIO._current_task
)
if empty.trains_in_model:
@ -2124,24 +2164,24 @@ class PatchTensorflowModelIO(object):
@staticmethod
def _ckpt_save(original_fn, self, file_prefix, *args, **kwargs):
checkpoint_path = original_fn(self, file_prefix, *args, **kwargs)
if PatchTensorflowModelIO.__main_task is None:
if PatchTensorflowModelIO._current_task is None:
return checkpoint_path
WeightsFileHandler.create_output_model(self, checkpoint_path, Framework.tensorflow,
PatchTensorflowModelIO.__main_task)
PatchTensorflowModelIO._current_task)
return checkpoint_path
@staticmethod
def _ckpt_write(original_fn, self, file_prefix, *args, **kwargs):
checkpoint_path = original_fn(self, file_prefix, *args, **kwargs)
if PatchTensorflowModelIO.__main_task is None:
if PatchTensorflowModelIO._current_task is None:
return checkpoint_path
WeightsFileHandler.create_output_model(self, checkpoint_path, Framework.tensorflow,
PatchTensorflowModelIO.__main_task)
PatchTensorflowModelIO._current_task)
return checkpoint_path
@staticmethod
def _ckpt_restore(original_fn, self, save_path, *args, **kwargs):
if PatchTensorflowModelIO.__main_task is None:
if PatchTensorflowModelIO._current_task is None:
return original_fn(self, save_path, *args, **kwargs)
# register input model
@ -2149,13 +2189,13 @@ class PatchTensorflowModelIO(object):
# Hack: disabled
if False and running_remotely():
save_path = WeightsFileHandler.restore_weights_file(empty, save_path, Framework.tensorflow,
PatchTensorflowModelIO.__main_task)
PatchTensorflowModelIO._current_task)
model = original_fn(self, save_path, *args, **kwargs)
else:
# try to load model before registering it, in case it fails.
model = original_fn(self, save_path, *args, **kwargs)
WeightsFileHandler.restore_weights_file(empty, save_path, Framework.tensorflow,
PatchTensorflowModelIO.__main_task)
PatchTensorflowModelIO._current_task)
if empty.trains_in_model:
# noinspection PyBroadException
@ -2167,12 +2207,14 @@ class PatchTensorflowModelIO(object):
class PatchTensorflow2ModelIO(object):
__main_task = None
_current_task = None
__patched = None
@staticmethod
def update_current_task(task, **_):
PatchTensorflow2ModelIO.__main_task = task
PatchTensorflow2ModelIO._current_task = task
if not task:
return
PatchTensorflow2ModelIO._patch_model_checkpoint()
PostImportHookPatching.add_on_import('tensorflow', PatchTensorflow2ModelIO._patch_model_checkpoint)
@ -2210,18 +2252,20 @@ class PatchTensorflow2ModelIO(object):
@staticmethod
def _save(original_fn, self, file_prefix, *args, **kwargs):
model = original_fn(self, file_prefix, *args, **kwargs)
if not PatchTensorflow2ModelIO._current_task:
return model
# store output Model
# noinspection PyBroadException
try:
WeightsFileHandler.create_output_model(self, file_prefix, Framework.tensorflow,
PatchTensorflow2ModelIO.__main_task)
PatchTensorflow2ModelIO._current_task)
except Exception:
pass
return model
@staticmethod
def _restore(original_fn, self, save_path, *args, **kwargs):
if PatchTensorflow2ModelIO.__main_task is None:
if not PatchTensorflow2ModelIO._current_task:
return original_fn(self, save_path, *args, **kwargs)
# Hack: disabled
@ -2230,7 +2274,7 @@ class PatchTensorflow2ModelIO(object):
# noinspection PyBroadException
try:
save_path = WeightsFileHandler.restore_weights_file(self, save_path, Framework.tensorflow,
PatchTensorflow2ModelIO.__main_task)
PatchTensorflow2ModelIO._current_task)
except Exception:
pass
# load model
@ -2242,7 +2286,7 @@ class PatchTensorflow2ModelIO(object):
# noinspection PyBroadException
try:
WeightsFileHandler.restore_weights_file(self, save_path, Framework.tensorflow,
PatchTensorflow2ModelIO.__main_task)
PatchTensorflow2ModelIO._current_task)
except Exception:
pass
return model

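A note on the pervasive rename in this file: `__main_task` is a double-underscore attribute, which Python name-mangles per class, so each of the seven patch classes here stored the task under a different private name; the single-underscore `_current_task` is the same attribute everywhere, which suits the uniform rebinding that `Task.close()` now performs. A quick demonstration of the mangling:

```python
class Patch(object):
    __main_task = "bound"     # name-mangled to _Patch__main_task
    _current_task = "bound"   # single underscore: reachable under its own name


print(Patch._current_task)      # bound
print(Patch._Patch__main_task)  # bound -- only reachable via the mangled name
# Patch.__main_task             # AttributeError outside the class body
```

The stored original for `write_summary` is likewise renamed, from `__write_summary` to `__original_fn_write_summary`, matching its `__original_fn_scalar`/`__original_fn_hist` siblings.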
View File

@ -11,13 +11,15 @@ from ...model import Framework
class PatchXGBoostModelIO(PatchBaseModelIO):
__main_task = None
_current_task = None
__patched = None
__callback_cls = None
@staticmethod
def update_current_task(task, **kwargs):
PatchXGBoostModelIO.__main_task = task
PatchXGBoostModelIO._current_task = task
if not task:
return
PatchXGBoostModelIO._patch_model_io()
PostImportHookPatching.add_on_import('xgboost', PatchXGBoostModelIO._patch_model_io)
@ -55,7 +57,7 @@ class PatchXGBoostModelIO(PatchBaseModelIO):
@staticmethod
def _save(original_fn, obj, f, *args, **kwargs):
ret = original_fn(obj, f, *args, **kwargs)
if not PatchXGBoostModelIO.__main_task:
if not PatchXGBoostModelIO._current_task:
return ret
if isinstance(f, six.string_types):
@ -76,12 +78,15 @@ class PatchXGBoostModelIO(PatchBaseModelIO):
model_name = Path(filename).stem
except Exception:
model_name = None
WeightsFileHandler.create_output_model(obj, filename, Framework.xgboost, PatchXGBoostModelIO.__main_task,
WeightsFileHandler.create_output_model(obj, filename, Framework.xgboost, PatchXGBoostModelIO._current_task,
singlefile=True, model_name=model_name)
return ret
@staticmethod
def _load(original_fn, f, *args, **kwargs):
if not PatchXGBoostModelIO._current_task:
return original_fn(f, *args, **kwargs)
if isinstance(f, six.string_types):
filename = f
elif hasattr(f, 'name'):
@ -91,21 +96,18 @@ class PatchXGBoostModelIO(PatchBaseModelIO):
else:
filename = None
if not PatchXGBoostModelIO.__main_task:
return original_fn(f, *args, **kwargs)
# register input model
empty = _Empty()
# Hack: disabled
if False and running_remotely():
filename = WeightsFileHandler.restore_weights_file(empty, filename, Framework.xgboost,
PatchXGBoostModelIO.__main_task)
PatchXGBoostModelIO._current_task)
model = original_fn(filename or f, *args, **kwargs)
else:
# try to load model before registering, in case we fail
model = original_fn(f, *args, **kwargs)
WeightsFileHandler.restore_weights_file(empty, filename, Framework.xgboost,
PatchXGBoostModelIO.__main_task)
PatchXGBoostModelIO._current_task)
if empty.trains_in_model:
# noinspection PyBroadException
@ -117,10 +119,12 @@ class PatchXGBoostModelIO(PatchBaseModelIO):
@staticmethod
def _train(original_fn, *args, **kwargs):
if not PatchXGBoostModelIO._current_task:
return original_fn(*args, **kwargs)
if PatchXGBoostModelIO.__callback_cls:
callbacks = kwargs.get('callbacks') or []
kwargs['callbacks'] = callbacks + [
PatchXGBoostModelIO.__callback_cls(task=PatchXGBoostModelIO.__main_task)
PatchXGBoostModelIO.__callback_cls(task=PatchXGBoostModelIO._current_task)
]
return original_fn(*args, **kwargs)

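`_train` now guards on `_current_task` before injecting the reporting callback, and the injection appends to, rather than replaces, whatever callbacks the user passed. The shape, reduced to plain Python (`LoggingCallback` and `train` are stand-ins for ClearML's callback class and `xgboost.train`):

```python
CURRENT_TASK = object()   # pretend a task is bound; None would disarm this


class LoggingCallback(object):
    # stand-in for the ClearML xgboost reporting callback
    def __init__(self, task):
        self.task = task


def _train(original_fn, *args, **kwargs):
    if not CURRENT_TASK:
        return original_fn(*args, **kwargs)
    callbacks = kwargs.get("callbacks") or []
    # append rather than replace, so user callbacks keep working
    kwargs["callbacks"] = callbacks + [LoggingCallback(task=CURRENT_TASK)]
    return original_fn(*args, **kwargs)


def train(**kwargs):      # stand-in for xgboost.train
    return [type(c).__name__ for c in kwargs.get("callbacks", [])]


print(_train(train))      # ['LoggingCallback']
```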
View File

@ -2,7 +2,7 @@ import io
import sys
from functools import partial
from ..config import running_remotely, get_remote_task_id, DEV_TASK_NO_REUSE
from ..config import running_remotely, DEV_TASK_NO_REUSE
from ..debugging.log import LoggerRoot
@ -46,6 +46,8 @@ class PatchHydra(object):
def update_current_task(task):
# set current Task before patching
PatchHydra._current_task = task
if not task:
return
if PatchHydra.patch_hydra():
# check if we have an untracked state, store it.
if PatchHydra._last_untracked_state.get('connect'):
@ -66,9 +68,6 @@ class PatchHydra(object):
# noinspection PyBroadException
try:
if running_remotely():
if not PatchHydra._current_task:
from ..task import Task
PatchHydra._current_task = Task.get_task(task_id=get_remote_task_id())
# get the _parameter_allow_full_edit casted back to boolean
connected_config = dict()
connected_config[PatchHydra._parameter_allow_full_edit] = False
@ -126,8 +125,9 @@ class PatchHydra(object):
if pre_app_task_init_call and not running_remotely():
LoggerRoot.get_base_logger(PatchHydra).info(
'Task.init called outside of Hydra-App. For full Hydra multi-run support, '
'move the Task.init call into the Hydra-App main function')
"Task.init called outside of Hydra-App. For full Hydra multi-run support, "
"move the Task.init call into the Hydra-App main function"
)
kwargs["config"] = config
kwargs["task_function"] = partial(PatchHydra._patched_task_function, task_function,)

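`PatchHydra` carries the other half of the lifecycle: configuration seen while no task was bound is parked in `_last_untracked_state` and replayed when a task arrives (the `_last_untracked_state.get('connect')` check above). A sketch of that buffering, with an invented `record()` entry point standing in for the patched Hydra hooks:

```python
class PatchHydraExample(object):
    _current_task = None
    _last_untracked_state = {}   # config captured before any task was bound

    @classmethod
    def record(cls, overrides):
        # called from the patched hydra entry point (illustrative)
        if cls._current_task:
            cls._current_task.connect(overrides)
        else:
            cls._last_untracked_state["connect"] = overrides

    @classmethod
    def update_current_task(cls, task):
        cls._current_task = task
        if not task:
            return
        # replay whatever was captured while we were unbound
        if cls._last_untracked_state.get("connect"):
            task.connect(cls._last_untracked_state.pop("connect"))
```

The removed `Task.get_task(task_id=get_remote_task_id())` fallback fits the same theme: once `None` is a legal bound state, the patch no longer resurrects a task on its own.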
View File

@ -77,6 +77,8 @@ class PatchedJoblib(object):
@staticmethod
def update_current_task(task):
PatchedJoblib._current_task = task
if not task:
return
PatchedJoblib.patch_joblib()
@staticmethod

View File

@ -14,7 +14,7 @@ from ..utilities.proxy_object import flatten_dictionary
class PatchJsonArgParse(object):
_args = {}
_main_task = None
_current_task = None
_args_sep = "/"
_args_type = {}
_commands_sep = "."
@ -24,14 +24,19 @@ class PatchJsonArgParse(object):
__remote_task_params = {}
__patched = False
@classmethod
def update_current_task(cls, task):
cls._current_task = task
if not task:
return
cls.patch(task)
@classmethod
def patch(cls, task):
if ArgumentParser is None:
return
if task:
cls._main_task = task
PatchJsonArgParse._update_task_args()
PatchJsonArgParse._update_task_args()
if not cls.__patched:
cls.__patched = True
@ -39,14 +44,16 @@ class PatchJsonArgParse(object):
@classmethod
def _update_task_args(cls):
if running_remotely() or not cls._main_task or not cls._args:
if running_remotely() or not cls._current_task or not cls._args:
return
args = {cls._section_name + cls._args_sep + k: v for k, v in cls._args.items()}
args_type = {cls._section_name + cls._args_sep + k: v for k, v in cls._args_type.items()}
cls._main_task._set_parameters(args, __update=True, __parameters_types=args_type)
cls._current_task._set_parameters(args, __update=True, __parameters_types=args_type)
@staticmethod
def _parse_args(original_fn, obj, *args, **kwargs):
if not PatchJsonArgParse._current_task:
return original_fn(obj, *args, **kwargs)
if len(args) == 1:
kwargs["args"] = args[0]
args = []

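`PatchJsonArgParse` gets the same split as Click and Fire: `update_current_task()` binds (or unbinds) and `patch()` installs the hooks once. The parameters themselves are pushed under a flattened `Args/` namespace, as in `_update_task_args` above; for example:

```python
SECTION, SEP = "Args", "/"


def to_task_params(args, args_type):
    # flatten parsed CLI values into the task's "Args/<name>" namespace
    params = {SECTION + SEP + k: v for k, v in args.items()}
    param_types = {SECTION + SEP + k: t for k, t in args_type.items()}
    return params, param_types


print(to_task_params({"lr": 0.1}, {"lr": float}))
# ({'Args/lr': 0.1}, {'Args/lr': <class 'float'>})
```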
View File

@ -207,12 +207,15 @@ class PatchedMatplotlib:
PatchedMatplotlib._current_task = task
PostImportHookPatching.add_on_import('matplotlib.pyplot', PatchedMatplotlib.patch_matplotlib)
PostImportHookPatching.add_on_import('matplotlib.pylab', PatchedMatplotlib.patch_matplotlib)
elif PatchedMatplotlib.patch_matplotlib():
else:
PatchedMatplotlib.patch_matplotlib()
PatchedMatplotlib._current_task = task
@staticmethod
def patched_imshow(*args, **kw):
ret = PatchedMatplotlib._patched_original_imshow(*args, **kw)
if not PatchedMatplotlib._current_task:
return ret
try:
from matplotlib import _pylab_helpers
# store on the plot that this is an imshow plot
@ -227,6 +230,8 @@ class PatchedMatplotlib:
@staticmethod
def patched_savefig(self, *args, **kw):
ret = PatchedMatplotlib._patched_original_savefig(self, *args, **kw)
if not PatchedMatplotlib._current_task:
return ret
# noinspection PyBroadException
try:
fname = kw.get('fname') or args[0]
@ -256,6 +261,8 @@ class PatchedMatplotlib:
@staticmethod
def patched_figure_show(self, *args, **kw):
if not PatchedMatplotlib._current_task:
return PatchedMatplotlib._patched_original_figure(self, *args, **kw)
tid = threading._get_ident() if six.PY2 else threading.get_ident()
if PatchedMatplotlib._recursion_guard.get(tid):
# we are inside a guard, do nothing
@ -269,6 +276,8 @@ class PatchedMatplotlib:
@staticmethod
def patched_show(*args, **kw):
if not PatchedMatplotlib._current_task:
return PatchedMatplotlib._patched_original_plot(*args, **kw)
tid = threading._get_ident() if six.PY2 else threading.get_ident()
PatchedMatplotlib._recursion_guard[tid] = True
# noinspection PyBroadException

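The matplotlib hunk also untangles a rebinding quirk: previously `_current_task = task` sat inside `elif PatchedMatplotlib.patch_matplotlib():`, so the assignment depended on the patcher's return value; now it patches first and assigns unconditionally. Distilled (assuming `patch_matplotlib()` returns falsy when the hooks are already installed):

```python
def patch_matplotlib():
    # pretend the pyplot hooks were already installed by an earlier bind
    return False


current_task = None

# old shape: the rebind is tied to the patcher's return value
if patch_matplotlib():
    current_task = "task-A"   # skipped -- the new task is silently dropped
print(current_task)           # None

# new shape: patch (idempotently), then always store the task
patch_matplotlib()
current_task = "task-A"
print(current_task)           # task-A
```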
View File

@ -626,8 +626,7 @@ class Task(_Task):
task.__register_at_exit(task._at_exit)
# always patch OS forking because of ProcessPool and the alike
PatchOsFork.patch_fork()
PatchOsFork.patch_fork(task)
if auto_connect_frameworks:
def should_connect(*keys):
"""
@ -708,13 +707,15 @@ class Task(_Task):
make_deterministic(task.get_random_seed())
if auto_connect_arg_parser:
EnvironmentBind.update_current_task(Task.__main_task)
EnvironmentBind.update_current_task(task)
PatchJsonArgParse.update_current_task(task)
# Patch ArgParser to be aware of the current task
argparser_update_currenttask(Task.__main_task)
PatchClick.patch(Task.__main_task)
PatchFire.patch(Task.__main_task)
PatchJsonArgParse.patch(Task.__main_task)
argparser_update_currenttask(task)
PatchClick.patch(task)
PatchFire.patch(task)
# set excluded arguments
if isinstance(auto_connect_arg_parser, dict):
@ -1686,6 +1687,22 @@ class Task(_Task):
# noinspection PyProtectedMember
Logger._remove_std_logger()
# unbind everything
PatchHydra.update_current_task(None)
PatchedJoblib.update_current_task(None)
PatchedMatplotlib.update_current_task(None)
PatchAbsl.update_current_task(None)
TensorflowBinding.update_current_task(None)
PatchPyTorchModelIO.update_current_task(None)
PatchMegEngineModelIO.update_current_task(None)
PatchXGBoostModelIO.update_current_task(None)
PatchCatBoostModelIO.update_current_task(None)
PatchFastai.update_current_task(None)
PatchLIGHTgbmModelIO.update_current_task(None)
EnvironmentBind.update_current_task(None)
PatchJsonArgParse.update_current_task(None)
PatchOsFork.patch_fork(None)
def delete(
self,
delete_artifacts_and_models=True,
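The `task.py` hunk is the payoff: `Task.close()` now walks every patch class and rebinds it to `None`, so the monkey-patches stay installed but turn into pass-throughs, and `Task.init()` stops reaching for the global `Task.__main_task` in favor of the local `task` it just created. From a user's point of view, the lifecycle becomes:

```python
from clearml import Task

task = Task.init(project_name="demo", task_name="run-1")
# ... framework calls are intercepted and logged to `task` ...
task.close()
# wrappers now see _current_task=None and fall through to the originals

task2 = Task.init(project_name="demo", task_name="run-2")
# the same installed patches are rebound to the new task
```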