mirror of https://github.com/clearml/clearml
synced 2025-05-22 21:14:04 +00:00

Unbind all patches when Task.close() is called

This commit is contained in:
parent 3c396b1f7e
commit 5dff9dd404
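Note: every change below follows the same pattern across the framework bindings. The class attribute previously named _task or __main_task becomes _current_task, every patched function gains an early return that falls back to the original call when no task is bound, and the one-time monkey-patching is guarded by a __patched flag, so a later update_current_task(None) can unbind the task without touching the patches. A minimal sketch of that pattern (the BindingBase name and the _patch stub are illustrative, not part of the commit):

    class BindingBase(object):
        # set by update_current_task(); cleared (set to None) when the task closes
        _current_task = None
        _patched = False

        @classmethod
        def update_current_task(cls, task):
            # task may be None: that unbinds the integration without un-patching
            cls._current_task = task
            if not task:
                return
            if not cls._patched:
                cls._patched = True
                cls._patch()  # monkey-patch the framework exactly once

        @classmethod
        def _patch(cls):
            pass  # the real classes replace framework functions here

        @classmethod
        def _patched_call(cls, original_fn, *args, **kwargs):
            if not cls._current_task:
                # unbound: behave exactly like the unpatched framework
                return original_fn(*args, **kwargs)
            # ... report to cls._current_task here, then delegate ...
            return original_fn(*args, **kwargs)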
@@ -8,12 +8,15 @@ from ..config import running_remotely
 class PatchAbsl(object):
     _original_DEFINE_flag = None
     _original_FLAGS_parse_call = None
-    _task = None
+    _current_task = None
+    __patched = False

     @classmethod
-    def update_current_task(cls, current_task):
-        cls._task = current_task
-        cls._patch_absl()
+    def update_current_task(cls, task):
+        cls._current_task = task
+        if not cls.__patched:
+            cls._patch_absl()
+            cls.__patched = True

     @classmethod
     def _patch_absl(cls):
@@ -53,7 +56,7 @@ class PatchAbsl(object):

     @staticmethod
     def _patched_define_flag(*args, **kwargs):
-        if not PatchAbsl._task or not PatchAbsl._original_DEFINE_flag:
+        if not PatchAbsl._current_task or not PatchAbsl._original_DEFINE_flag:
             if PatchAbsl._original_DEFINE_flag:
                 return PatchAbsl._original_DEFINE_flag(*args, **kwargs)
             else:
@@ -73,7 +76,7 @@ class PatchAbsl(object):
             # noinspection PyBroadException
             try:
                 if param_name and flag:
-                    param_dict = PatchAbsl._task._arguments.copy_to_dict(
+                    param_dict = PatchAbsl._current_task._arguments.copy_to_dict(
                         {param_name: flag.value}, prefix=_Arguments._prefix_tf_defines)
                     flag.value = param_dict.get(param_name, flag.value)
             except Exception:
@@ -82,13 +85,17 @@ class PatchAbsl(object):
         else:
             if flag and param_name:
                 value = flag.value
-                PatchAbsl._task.update_parameters({_Arguments._prefix_tf_defines + param_name: value}, )
+                PatchAbsl._current_task.update_parameters(
+                    {_Arguments._prefix_tf_defines + param_name: value},
+                )
         ret = PatchAbsl._original_DEFINE_flag(*args, **kwargs)
         return ret

     @staticmethod
     def _patched_FLAGS_parse_call(self, *args, **kwargs):
         ret = PatchAbsl._original_FLAGS_parse_call(self, *args, **kwargs)
+        if not PatchAbsl._current_task:
+            return ret
         # noinspection PyBroadException
         try:
             PatchAbsl._update_current_flags(self)
@@ -98,13 +105,13 @@ class PatchAbsl(object):

     @classmethod
     def _update_current_flags(cls, FLAGS):
-        if not cls._task:
+        if not cls._current_task:
             return
         # noinspection PyBroadException
         try:
             if running_remotely():
                 param_dict = dict((k, FLAGS[k].value) for k in FLAGS)
-                param_dict = cls._task._arguments.copy_to_dict(param_dict, prefix=_Arguments._prefix_tf_defines)
+                param_dict = cls._current_task._arguments.copy_to_dict(param_dict, prefix=_Arguments._prefix_tf_defines)
                 for k, v in param_dict.items():
                     # noinspection PyBroadException
                     try:
@@ -127,7 +134,7 @@ class PatchAbsl(object):
                 param_types = dict([(k, FLAGS[k].flag_type() or None) for k in FLAGS])
             except Exception:
                 param_types = None
-            cls._task._arguments.copy_from_dict(
+            cls._current_task._arguments.copy_from_dict(
                 parameters,
                 prefix=_Arguments._prefix_tf_defines,
                 descriptions=descriptions, param_types=param_types,
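The patches themselves are never removed, only the task reference is dropped: once other code has captured the patched function, restoring the original is unreliable, while a None check at call time is cheap. A toy demonstration of that unbind behaviour (all names here are ours, not clearml's):

    import time

    _original_time = time.time
    _current_task = None

    def _patched_time():
        if _current_task is not None:
            print("reporting to", _current_task)  # stand-in for task logging
        return _original_time()

    time.time = _patched_time

    _current_task = "task-123"
    time.time()   # reports, then delegates to the original
    _current_task = None
    time.time()   # unbound: behaves exactly like the original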
@@ -16,7 +16,7 @@ class PatchClick:
     _num_commands = 0
     _command_type = 'click.Command'
     _section_name = 'Args'
-    _main_task = None
+    _current_task = None
     __remote_task_params = None
     __remote_task_params_dict = {}
     __patched = False
@@ -26,8 +26,8 @@ class PatchClick:
         if Command is None:
             return

+        cls._current_task = task
         if task:
-            cls._main_task = task
             PatchClick._update_task_args()

         if not cls.__patched:
@@ -53,11 +53,11 @@ class PatchClick:

     @classmethod
     def _update_task_args(cls):
-        if running_remotely() or not cls._main_task or not cls._args:
+        if running_remotely() or not cls._current_task or not cls._args:
             return
         param_val, param_types, param_desc = cls.args()
         # noinspection PyProtectedMember
-        cls._main_task._set_parameters(
+        cls._current_task._set_parameters(
             param_val,
             __update=True,
             __parameters_descriptions=param_desc,
@@ -1,4 +1,5 @@
 import os
+from time import sleep

 import six

@@ -7,26 +8,29 @@ from ..utilities.process.mp import BackgroundMonitor


 class EnvironmentBind(object):
-    _task = None
+    _current_task = None
     _environment_section = 'Environment'
+    __patched = False

     @classmethod
-    def update_current_task(cls, current_task):
-        cls._task = current_task
+    def update_current_task(cls, task):
+        cls._current_task = task
         # noinspection PyBroadException
         try:
-            cls._bind_environment()
+            if not cls.__patched:
+                cls.__patched = True
+                cls._bind_environment()
         except Exception:
             pass

     @classmethod
     def _bind_environment(cls):
-        if not cls._task:
+        if not cls._current_task:
             return

         # get ENVIRONMENT and put it into the OS environment
         if running_remotely():
-            params = cls._task.get_parameters_as_dict()
+            params = cls._current_task.get_parameters_as_dict()
             if params and cls._environment_section in params:
                 # put back into os:
                 os.environ.update(params[cls._environment_section])
@@ -55,14 +59,18 @@ class EnvironmentBind(object):
             elif match in os.environ:
                 env_param.update({match: os.environ.get(match)})
         # store os environments
-        cls._task.connect(env_param, cls._environment_section)
+        cls._current_task.connect(env_param, cls._environment_section)


 class PatchOsFork(object):
     _original_fork = None
+    _current_task = None

     @classmethod
-    def patch_fork(cls):
+    def patch_fork(cls, task):
+        cls._current_task = task
+        if not task:
+            return
         # noinspection PyBroadException
         try:
             # only once
@@ -88,11 +96,13 @@ class PatchOsFork(object):
             Task._wait_for_deferred(task)

         ret = PatchOsFork._original_fork(*args, **kwargs)
+        if not PatchOsFork._current_task:
+            return ret
         # Make sure the new process stdout is logged
         if not ret:
             # force creating a Task
             task = Task.current_task()
-            if task is None:
+            if not task:
                 return ret

             # # Hack: now make sure we setup the reporter threads (Log+Reporter)
@@ -106,7 +116,7 @@ class PatchOsFork(object):
             # just make sure we flush the internal state (the at exist caught by the external signal does the rest
             # in theory we should not have to do any of that, but for some reason if we do not
             # the signal is never caught by the signal call backs, not sure why....
-
+            sleep(0.1)
             # Since at_exist handlers do not work on forked processes, we have to manually call them here
             if task:
                 try:
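PatchOsFork now receives the task explicitly, and the patched fork returns immediately when no task is bound, so forked children of a closed task no longer try to re-attach reporting. A rough sketch of that guard under the same assumptions (the class and method names here are illustrative):

    import os

    class ForkPatchSketch(object):
        _original_fork = None
        _current_task = None

        @classmethod
        def patch_fork(cls, task):
            cls._current_task = task
            if not task:
                return  # nothing to report to; leave fork alone
            if cls._original_fork is None and hasattr(os, "fork"):
                cls._original_fork = os.fork
                os.fork = cls._patched_fork

        @staticmethod
        def _patched_fork():
            ret = ForkPatchSketch._original_fork()
            if not ForkPatchSketch._current_task:
                return ret  # unbound: plain fork semantics
            # in the child (ret == 0) the real code re-creates the reporting threads
            return ret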
@@ -20,7 +20,7 @@ class PatchFire:
     _section_name = "Args"
     _args_sep = "/"
     _commands_sep = "."
-    _main_task = None
+    _current_task = None
     __remote_task_params = None
     __remote_task_params_dict = {}
     __patched = False
@@ -38,8 +38,8 @@ class PatchFire:
         if fire is None:
             return

+        cls._current_task = task
         if task:
-            cls._main_task = task
             cls._update_task_args()

         if not cls.__patched:
@@ -53,7 +53,7 @@ class PatchFire:

     @classmethod
     def _update_task_args(cls):
-        if running_remotely() or not cls._main_task:
+        if running_remotely() or not cls._current_task:
             return
         args = {}
         parameters_types = {}
@@ -118,7 +118,7 @@ class PatchFire:
         parameters_types = {**parameters_types, **unused_paramenters_types}

         # noinspection PyProtectedMember
-        cls._main_task._set_parameters(
+        cls._current_task._set_parameters(
             args,
             __update=True,
             __parameters_types=parameters_types,
@@ -126,25 +126,25 @@ class PatchFire:

     @staticmethod
     def __Fire(original_fn, component, args_, parsed_flag_args, context, name, *args, **kwargs):  # noqa
-        if running_remotely():
-            command = PatchFire._load_task_params()
-            if command is not None:
-                replaced_args = command.split(PatchFire._commands_sep)
-            else:
-                replaced_args = []
-            for param in PatchFire.__remote_task_params[PatchFire._section_name].values():
-                if command is not None and param.type == PatchFire._command_arg_type_template % command:
-                    replaced_args.append("--" + param.name[len(command + PatchFire._args_sep):])
-                    value = PatchFire.__remote_task_params_dict[param.name]
-                    if len(value) > 0:
-                        replaced_args.append(value)
-                if param.type == PatchFire._shared_arg_type:
-                    replaced_args.append("--" + param.name)
-                    value = PatchFire.__remote_task_params_dict[param.name]
-                    if len(value) > 0:
-                        replaced_args.append(value)
-            return original_fn(component, replaced_args, parsed_flag_args, context, name, *args, **kwargs)
-        return original_fn(component, args_, parsed_flag_args, context, name, *args, **kwargs)
+        if not running_remotely():
+            return original_fn(component, args_, parsed_flag_args, context, name, *args, **kwargs)
+        command = PatchFire._load_task_params()
+        if command is not None:
+            replaced_args = command.split(PatchFire._commands_sep)
+        else:
+            replaced_args = []
+        for param in PatchFire.__remote_task_params[PatchFire._section_name].values():
+            if command is not None and param.type == PatchFire._command_arg_type_template % command:
+                replaced_args.append("--" + param.name[len(command + PatchFire._args_sep):])
+                value = PatchFire.__remote_task_params_dict[param.name]
+                if len(value) > 0:
+                    replaced_args.append(value)
+            if param.type == PatchFire._shared_arg_type:
+                replaced_args.append("--" + param.name)
+                value = PatchFire.__remote_task_params_dict[param.name]
+                if len(value) > 0:
+                    replaced_args.append(value)
+        return original_fn(component, replaced_args, parsed_flag_args, context, name, *args, **kwargs)

     @staticmethod
     def __CallAndUpdateTrace(  # noqa
@@ -11,13 +11,15 @@ from ...model import Framework


 class PatchCatBoostModelIO(PatchBaseModelIO):
-    __main_task = None
+    _current_task = None
     __patched = None
     __callback_cls = None

     @staticmethod
     def update_current_task(task, **kwargs):
-        PatchCatBoostModelIO.__main_task = task
+        PatchCatBoostModelIO._current_task = task
+        if not task:
+            return
         PatchCatBoostModelIO._patch_model_io()
         PostImportHookPatching.add_on_import("catboost", PatchCatBoostModelIO._patch_model_io)

@@ -38,16 +40,16 @@ class PatchCatBoostModelIO(PatchBaseModelIO):
             CatBoost.fit = _patched_call(CatBoost.fit, PatchCatBoostModelIO._fit)
             CatBoostClassifier.fit = _patched_call(CatBoostClassifier.fit, PatchCatBoostModelIO._fit)
             CatBoostRegressor.fit = _patched_call(CatBoostRegressor.fit, PatchCatBoostModelIO._fit)
-            CatBoostRanker.fit = _patched_call(CatBoostRegressor.fit, PatchCatBoostModelIO._fit)
+            CatBoostRanker.fit = _patched_call(CatBoostRanker.fit, PatchCatBoostModelIO._fit)
         except Exception as e:
-            logger = PatchCatBoostModelIO.__main_task.get_logger()
+            logger = PatchCatBoostModelIO._current_task.get_logger()
             logger.report_text("Failed patching Catboost. Exception is: '" + str(e) + "'")

     @staticmethod
     def _save(original_fn, obj, f, *args, **kwargs):
         # see https://catboost.ai/en/docs/concepts/python-reference_catboost_save_model
         ret = original_fn(obj, f, *args, **kwargs)
-        if not PatchCatBoostModelIO.__main_task:
+        if not PatchCatBoostModelIO._current_task:
             return ret
         if isinstance(f, six.string_types):
             filename = f
@@ -60,13 +62,16 @@ class PatchCatBoostModelIO(PatchBaseModelIO):
         except Exception:
             model_name = None
         WeightsFileHandler.create_output_model(
-            obj, filename, Framework.catboost, PatchCatBoostModelIO.__main_task, singlefile=True, model_name=model_name
+            obj, filename, Framework.catboost, PatchCatBoostModelIO._current_task, singlefile=True, model_name=model_name
         )
         return ret

     @staticmethod
     def _load(original_fn, f, *args, **kwargs):
         # see https://catboost.ai/en/docs/concepts/python-reference_catboost_load_model
+        if not PatchCatBoostModelIO._current_task:
+            return original_fn(f, *args, **kwargs)
+
         if isinstance(f, six.string_types):
             filename = f
         elif len(args) >= 1 and isinstance(args[0], six.string_types):
@@ -74,13 +79,10 @@ class PatchCatBoostModelIO(PatchBaseModelIO):
         else:
             filename = None

-        if not PatchCatBoostModelIO.__main_task:
-            return original_fn(f, *args, **kwargs)
-
         # register input model
         empty = _Empty()
         model = original_fn(f, *args, **kwargs)
-        WeightsFileHandler.restore_weights_file(empty, filename, Framework.catboost, PatchCatBoostModelIO.__main_task)
+        WeightsFileHandler.restore_weights_file(empty, filename, Framework.catboost, PatchCatBoostModelIO._current_task)
         if empty.trains_in_model:
             # noinspection PyBroadException
             try:
@@ -91,13 +93,15 @@ class PatchCatBoostModelIO(PatchBaseModelIO):

     @staticmethod
     def _fit(original_fn, obj, *args, **kwargs):
+        if not PatchCatBoostModelIO._current_task:
+            return original_fn(obj, *args, **kwargs)
         callbacks = kwargs.get("callbacks") or []
-        kwargs["callbacks"] = callbacks + [PatchCatBoostModelIO.__callback_cls(task=PatchCatBoostModelIO.__main_task)]
+        kwargs["callbacks"] = callbacks + [PatchCatBoostModelIO.__callback_cls(task=PatchCatBoostModelIO._current_task)]
         # noinspection PyBroadException
         try:
             return original_fn(obj, *args, **kwargs)
         except Exception:
-            logger = PatchCatBoostModelIO.__main_task.get_logger()
+            logger = PatchCatBoostModelIO._current_task.get_logger()
             logger.report_text(
                 "Catboost metrics logging is not supported for GPU. "
                 "See https://github.com/catboost/catboost/issues/1792"
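The CatBoost binding injects its metrics callback into fit() only when a task is bound; otherwise fit() runs untouched. Roughly, under the same assumptions (LoggingCallback is a stand-in for the real ClearML callback class; after_iteration is CatBoost's user-callback hook):

    class LoggingCallback(object):
        def __init__(self, task):
            self.task = task

        def after_iteration(self, info):
            # info.iteration / info.metrics would be reported to self.task here
            return True  # True tells CatBoost to keep training

    def fit_with_callback(original_fit, current_task, *args, **kwargs):
        if not current_task:
            return original_fit(*args, **kwargs)  # unbound: vanilla fit()
        callbacks = kwargs.get("callbacks") or []
        kwargs["callbacks"] = callbacks + [LoggingCallback(current_task)]
        return original_fit(*args, **kwargs)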
@@ -25,11 +25,9 @@ class PatchFastai(object):
         try:
             if Version(fastai.__version__) < Version("2.0.0"):
                 PatchFastaiV1.update_current_task(task)
-                PatchFastaiV1.patch_model_callback()
                 PostImportHookPatching.add_on_import("fastai", PatchFastaiV1.patch_model_callback)
             else:
                 PatchFastaiV2.update_current_task(task)
-                PatchFastaiV2.patch_model_callback()
                 PostImportHookPatching.add_on_import("fastai", PatchFastaiV2.patch_model_callback)
         except Exception:
             pass
@@ -38,11 +36,17 @@
 class PatchFastaiV1(object):
     __metrics_names = {}
     __gradient_hist_helpers = {}
-    _main_task = None
+    _current_task = None
+    __patched = False

     @staticmethod
     def update_current_task(task, **_):
-        PatchFastaiV1._main_task = task
+        PatchFastaiV1._current_task = task
+        if not task:
+            return
+        if not PatchFastaiV1.__patched:
+            PatchFastaiV1.__patched = True
+            PatchFastaiV1.patch_model_callback()

     @staticmethod
     def patch_model_callback():
@@ -52,7 +56,6 @@ class PatchFastaiV1(object):

         try:
             from fastai.basic_train import Recorder
-
             Recorder.on_batch_end = _patched_call(Recorder.on_batch_end, PatchFastaiV1._on_batch_end)
             Recorder.on_backward_end = _patched_call(Recorder.on_backward_end, PatchFastaiV1._on_backward_end)
             Recorder.on_epoch_end = _patched_call(Recorder.on_epoch_end, PatchFastaiV1._on_epoch_end)
@@ -65,7 +68,7 @@ class PatchFastaiV1(object):
     @staticmethod
     def _on_train_begin(original_fn, recorder, *args, **kwargs):
         original_fn(recorder, *args, **kwargs)
-        if not PatchFastaiV1._main_task:
+        if not PatchFastaiV1._current_task:
             return
         # noinspection PyBroadException
         try:
@@ -84,7 +87,7 @@ class PatchFastaiV1(object):

         original_fn(recorder, *args, **kwargs)

-        if not PatchFastaiV1._main_task:
+        if not PatchFastaiV1._current_task:
             return

         # noinspection PyBroadException
@@ -112,7 +115,7 @@ class PatchFastaiV1(object):
                 min_gradient=gradient_stats[:, 5].min(),
             )

-            logger = PatchFastaiV1._main_task.get_logger()
+            logger = PatchFastaiV1._current_task.get_logger()
             iteration = kwargs.get("iteration", 0)
             for name, val in stats_report.items():
                 logger.report_scalar(title="model_stats_gradients", series=name, value=val, iteration=iteration)
@@ -122,26 +125,26 @@ class PatchFastaiV1(object):
     @staticmethod
     def _on_epoch_end(original_fn, recorder, *args, **kwargs):
         original_fn(recorder, *args, **kwargs)
-        if not PatchFastaiV1._main_task:
+        if not PatchFastaiV1._current_task:
             return

         # noinspection PyBroadException
         try:
-            logger = PatchFastaiV1._main_task.get_logger()
+            logger = PatchFastaiV1._current_task.get_logger()
             iteration = kwargs.get("iteration")
             for series, value in zip(
                 PatchFastaiV1.__metrics_names[id(recorder)],
                 [kwargs.get("smooth_loss")] + kwargs.get("last_metrics", []),
             ):
                 logger.report_scalar(title="metrics", series=series, value=value, iteration=iteration)
-            PatchFastaiV1._main_task.flush()
+            PatchFastaiV1._current_task.flush()
         except Exception:
             pass

     @staticmethod
     def _on_batch_end(original_fn, recorder, *args, **kwargs):
         original_fn(recorder, *args, **kwargs)
-        if not PatchFastaiV1._main_task:
+        if not PatchFastaiV1._current_task:
             return

         # noinspection PyBroadException
@@ -150,7 +153,7 @@ class PatchFastaiV1(object):
             if iteration == 0 or not kwargs.get("train"):
                 return

-            logger = PatchFastaiV1._main_task.get_logger()
+            logger = PatchFastaiV1._current_task.get_logger()
             logger.report_scalar(
                 title="metrics",
                 series="train_loss",
@@ -174,11 +177,17 @@ class PatchFastaiV1(object):


 class PatchFastaiV2(object):
-    _main_task = None
+    _current_task = None
+    __patched = False

     @staticmethod
     def update_current_task(task, **_):
-        PatchFastaiV2._main_task = task
+        PatchFastaiV2._current_task = task
+        if not task:
+            return
+        if not PatchFastaiV2.__patched:
+            PatchFastaiV2.__patched = True
+            PatchFastaiV2.patch_model_callback()

     @staticmethod
     def patch_model_callback():
@@ -212,13 +221,15 @@ class PatchFastaiV2(object):
             self.logger = noop
             self.__id = str(PatchFastaiV2.PatchFastaiCallbacks.__id)
             PatchFastaiV2.PatchFastaiCallbacks.__id += 1
-            self.__gradient_hist_helper = WeightsGradientHistHelper(PatchFastaiV2._main_task.get_logger())
+            self.__gradient_hist_helper = WeightsGradientHistHelper(PatchFastaiV2._current_task.get_logger())

         def after_batch(self):
             # noinspection PyBroadException
             try:
                 super().after_batch()  # noqa
-                logger = PatchFastaiV2._main_task.get_logger()
+                if not PatchFastaiV2._current_task:
+                    return
+                logger = PatchFastaiV2._current_task.get_logger()
                 if not self.training:  # noqa
                     return
                 self.__train_iter += 1
@@ -254,7 +265,9 @@ class PatchFastaiV2(object):
             # noinspection PyBroadException
             try:
                 super().after_epoch()  # noqa
-                logger = PatchFastaiV2._main_task.get_logger()
+                if not PatchFastaiV2._current_task:
+                    return
+                logger = PatchFastaiV2._current_task.get_logger()
                 for metric in self._valid_mets:  # noqa
                     logger.report_scalar(
                         title="metrics_" + self.__id,
@@ -270,7 +283,9 @@ class PatchFastaiV2(object):
             try:
                 if hasattr(fastai.learner.Recorder, "before_step"):
                     super().before_step()  # noqa
-                logger = PatchFastaiV2._main_task.get_logger()
+                if not PatchFastaiV2._current_task:
+                    return
+                logger = PatchFastaiV2._current_task.get_logger()
                 gradients = [
                     x.grad.clone().detach().cpu() for x in self.learn.model.parameters() if x.grad is not None
                 ]  # noqa
@@ -11,12 +11,14 @@ from ...model import Framework


 class PatchLIGHTgbmModelIO(PatchBaseModelIO):
-    __main_task = None
+    _current_task = None
     __patched = None

     @staticmethod
     def update_current_task(task, **kwargs):
-        PatchLIGHTgbmModelIO.__main_task = task
+        PatchLIGHTgbmModelIO._current_task = task
+        if not task:
+            return
         PatchLIGHTgbmModelIO._patch_model_io()
         PostImportHookPatching.add_on_import('lightgbm', PatchLIGHTgbmModelIO._patch_model_io)

@@ -43,7 +45,7 @@ class PatchLIGHTgbmModelIO(PatchBaseModelIO):
     @staticmethod
     def _save(original_fn, obj, f, *args, **kwargs):
         ret = original_fn(obj, f, *args, **kwargs)
-        if not PatchLIGHTgbmModelIO.__main_task:
+        if not PatchLIGHTgbmModelIO._current_task:
             return ret

         if isinstance(f, six.string_types):
@@ -64,12 +66,15 @@ class PatchLIGHTgbmModelIO(PatchBaseModelIO):
             model_name = Path(filename).stem
         except Exception:
             model_name = None
-        WeightsFileHandler.create_output_model(obj, filename, Framework.lightgbm, PatchLIGHTgbmModelIO.__main_task,
+        WeightsFileHandler.create_output_model(obj, filename, Framework.lightgbm, PatchLIGHTgbmModelIO._current_task,
                                                singlefile=True, model_name=model_name)
         return ret

     @staticmethod
     def _load(original_fn, model_file, *args, **kwargs):
+        if not PatchLIGHTgbmModelIO._current_task:
+            return original_fn(model_file, *args, **kwargs)
+
         if isinstance(model_file, six.string_types):
             filename = model_file
         elif hasattr(model_file, 'name'):
@@ -79,21 +84,18 @@ class PatchLIGHTgbmModelIO(PatchBaseModelIO):
         else:
             filename = None

-        if not PatchLIGHTgbmModelIO.__main_task:
-            return original_fn(model_file, *args, **kwargs)
-
         # register input model
         empty = _Empty()
         # Hack: disabled
         if False and running_remotely():
             filename = WeightsFileHandler.restore_weights_file(empty, filename, Framework.xgboost,
-                                                               PatchLIGHTgbmModelIO.__main_task)
+                                                               PatchLIGHTgbmModelIO._current_task)
             model = original_fn(model_file=filename or model_file, *args, **kwargs)
         else:
             # try to load model before registering, in case we fail
             model = original_fn(model_file=model_file, *args, **kwargs)
             WeightsFileHandler.restore_weights_file(empty, filename, Framework.lightgbm,
-                                                    PatchLIGHTgbmModelIO.__main_task)
+                                                    PatchLIGHTgbmModelIO._current_task)

         if empty.trains_in_model:
             # noinspection PyBroadException
@@ -110,7 +112,7 @@ class PatchLIGHTgbmModelIO(PatchBaseModelIO):
             # logging the results to scalars section
             # noinspection PyBroadException
             try:
-                logger = PatchLIGHTgbmModelIO.__main_task.get_logger()
+                logger = PatchLIGHTgbmModelIO._current_task.get_logger()
                 iteration = env.iteration
                 for data_title, data_series, value, _ in env.evaluation_result_list:
                     logger.report_scalar(title=data_title, series=data_series, value="{:.6f}".format(value),
@@ -121,12 +123,12 @@ class PatchLIGHTgbmModelIO(PatchBaseModelIO):

         kwargs.setdefault("callbacks", []).append(trains_lightgbm_callback())
         ret = original_fn(*args, **kwargs)
-        if not PatchLIGHTgbmModelIO.__main_task:
+        if not PatchLIGHTgbmModelIO._current_task:
            return ret
         params = args[0] if args else kwargs.get('params', {})
         for k, v in params.items():
             if isinstance(v, set):
                 params[k] = list(v)
         if params:
-            PatchLIGHTgbmModelIO.__main_task.connect(params)
+            PatchLIGHTgbmModelIO._current_task.connect(params)
         return ret
@@ -10,14 +10,14 @@ from ...model import Framework


 class PatchMegEngineModelIO(PatchBaseModelIO):
-    __main_task = None
+    _current_task = None
     __patched = None

     # __patched_lightning = None

     @staticmethod
     def update_current_task(task, **_):
-        PatchMegEngineModelIO.__main_task = task
+        PatchMegEngineModelIO._current_task = task
         if not task:
             return
         PatchMegEngineModelIO._patch_model_io()
         PostImportHookPatching.add_on_import(
             'megengine', PatchMegEngineModelIO._patch_model_io
@@ -58,7 +58,7 @@ class PatchMegEngineModelIO(PatchBaseModelIO):
         ret = original_fn(obj, f, *args, **kwargs)

         # if there is no main task or this is a nested call
-        if not PatchMegEngineModelIO.__main_task:
+        if not PatchMegEngineModelIO._current_task:
             return ret

         # noinspection PyBroadException
@@ -93,7 +93,7 @@ class PatchMegEngineModelIO(PatchBaseModelIO):

         WeightsFileHandler.create_output_model(
             obj, filename, Framework.megengine,
-            PatchMegEngineModelIO.__main_task,
+            PatchMegEngineModelIO._current_task,
             singlefile=True, model_name=model_name,
         )

@@ -102,7 +102,7 @@ class PatchMegEngineModelIO(PatchBaseModelIO):
     @staticmethod
     def _load(original_fn, f, *args, **kwargs):
         # if there is no main task or this is a nested call
-        if not PatchMegEngineModelIO.__main_task:
+        if not PatchMegEngineModelIO._current_task:
             return original_fn(f, *args, **kwargs)

         # noinspection PyBroadException
@@ -124,7 +124,7 @@ class PatchMegEngineModelIO(PatchBaseModelIO):
         model = original_fn(f, *args, **kwargs)
         WeightsFileHandler.restore_weights_file(
             empty, filename, Framework.megengine,
-            PatchMegEngineModelIO.__main_task
+            PatchMegEngineModelIO._current_task
         )

         if empty.trains_in_model:
@@ -139,7 +139,7 @@ class PatchMegEngineModelIO(PatchBaseModelIO):
     @staticmethod
     def _load_from_obj(original_fn, obj, f, *args, **kwargs):
         # if there is no main task or this is a nested call
-        if not PatchMegEngineModelIO.__main_task:
+        if not PatchMegEngineModelIO._current_task:
             return original_fn(obj, f, *args, **kwargs)

         # noinspection PyBroadException
@@ -161,7 +161,7 @@ class PatchMegEngineModelIO(PatchBaseModelIO):
         model = original_fn(obj, f, *args, **kwargs)
         WeightsFileHandler.restore_weights_file(
             empty, filename, Framework.megengine,
-            PatchMegEngineModelIO.__main_task,
+            PatchMegEngineModelIO._current_task,
         )

         if empty.trains_in_model:
@@ -11,13 +11,15 @@ from ...model import Framework


 class PatchPyTorchModelIO(PatchBaseModelIO):
-    __main_task = None
+    _current_task = None
     __patched = None
     __patched_lightning = None

     @staticmethod
     def update_current_task(task, **_):
-        PatchPyTorchModelIO.__main_task = task
+        PatchPyTorchModelIO._current_task = task
+        if not task:
+            return
         PatchPyTorchModelIO._patch_model_io()
         PatchPyTorchModelIO._patch_lightning_io()
         PostImportHookPatching.add_on_import('torch', PatchPyTorchModelIO._patch_model_io)
@@ -112,7 +114,7 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
         ret = original_fn(obj, f, *args, **kwargs)

         # if there is no main task or this is a nested call
-        if not PatchPyTorchModelIO.__main_task:
+        if not PatchPyTorchModelIO._current_task:
             return ret

         # pytorch-lightning check if rank is zero
@@ -154,14 +156,14 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
             model_name = None

         WeightsFileHandler.create_output_model(
-            obj, filename, Framework.pytorch, PatchPyTorchModelIO.__main_task, singlefile=True, model_name=model_name)
+            obj, filename, Framework.pytorch, PatchPyTorchModelIO._current_task, singlefile=True, model_name=model_name)

         return ret

     @staticmethod
     def _load(original_fn, f, *args, **kwargs):
         # if there is no main task or this is a nested call
-        if not PatchPyTorchModelIO.__main_task:
+        if not PatchPyTorchModelIO._current_task:
             return original_fn(f, *args, **kwargs)

         # noinspection PyBroadException
@@ -182,13 +184,13 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
         # Hack: disabled
         if False and running_remotely():
             filename = WeightsFileHandler.restore_weights_file(
-                empty, filename, Framework.pytorch, PatchPyTorchModelIO.__main_task)
+                empty, filename, Framework.pytorch, PatchPyTorchModelIO._current_task)
             model = original_fn(filename or f, *args, **kwargs)
         else:
             # try to load model before registering, in case we fail
             model = original_fn(f, *args, **kwargs)
             WeightsFileHandler.restore_weights_file(
-                empty, filename, Framework.pytorch, PatchPyTorchModelIO.__main_task)
+                empty, filename, Framework.pytorch, PatchPyTorchModelIO._current_task)

         if empty.trains_in_model:
             # noinspection PyBroadException
@@ -202,7 +204,7 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
     @staticmethod
     def _load_from_obj(original_fn, obj, f, *args, **kwargs):
         # if there is no main task or this is a nested call
-        if not PatchPyTorchModelIO.__main_task:
+        if not PatchPyTorchModelIO._current_task:
             return original_fn(obj, f, *args, **kwargs)

         # noinspection PyBroadException
@@ -223,13 +225,13 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
         # Hack: disabled
         if False and running_remotely():
             filename = WeightsFileHandler.restore_weights_file(
-                empty, filename, Framework.pytorch, PatchPyTorchModelIO.__main_task)
+                empty, filename, Framework.pytorch, PatchPyTorchModelIO._current_task)
             model = original_fn(obj, filename or f, *args, **kwargs)
         else:
             # try to load model before registering, in case we fail
             model = original_fn(obj, f, *args, **kwargs)
             WeightsFileHandler.restore_weights_file(
-                empty, filename, Framework.pytorch, PatchPyTorchModelIO.__main_task)
+                empty, filename, Framework.pytorch, PatchPyTorchModelIO._current_task)

         if empty.trains_in_model:
             # noinspection PyBroadException
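All the model-IO bindings (CatBoost, LightGBM, MegEngine, PyTorch, and the TensorFlow/Keras classes below) guard their save/load wrappers the same way: save first, register the artifact only if a task is bound. A condensed sketch of that idea (register_output_model is our stand-in for WeightsFileHandler.create_output_model, not a real clearml function):

    def register_output_model(task, obj, filename):
        # stand-in for WeightsFileHandler.create_output_model(...)
        print("registering", filename, "on", task)

    def save_with_registration(original_save, current_task, obj, f, *args, **kwargs):
        ret = original_save(obj, f, *args, **kwargs)
        if not current_task:
            return ret  # unbound: just save, register nothing
        filename = f if isinstance(f, str) else getattr(f, "name", None)
        if filename:
            register_output_model(current_task, obj, filename)
        return ret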
@@ -232,7 +232,7 @@ class EventTrainsWriter(object):
     TF SummaryWriter implementation that converts the tensorboard's summary into
     ClearML events and reports the events (metrics) for an ClearML task (logger).
     """
-    __main_task = None
+    _current_task = None
     __report_hparams = True
     _add_lock = threading.RLock()
     _series_name_lookup = {}
@@ -641,7 +641,7 @@ class EventTrainsWriter(object):
             session_start_info = parse_session_start_info_plugin_data(content)
             session_start_info = MessageToDict(session_start_info)
             hparams = session_start_info["hparams"]
-            EventTrainsWriter.__main_task.update_parameters(
+            EventTrainsWriter._current_task.update_parameters(
                 {"TB_hparams/{}".format(k): v for k, v in hparams.items()}
             )
         except Exception:
@@ -844,13 +844,13 @@ class EventTrainsWriter(object):
     @classmethod
     def update_current_task(cls, task, **kwargs):
         cls.__report_hparams = kwargs.get('report_hparams', False)
-        if cls.__main_task != task:
+        if cls._current_task != task:
             with cls._add_lock:
                 cls._series_name_lookup = {}
                 cls._title_series_writers_lookup = {}
                 cls._event_writers_id_to_logdir = {}
                 cls._title_series_wraparound_counter = {}
-        cls.__main_task = task
+        cls._current_task = task


 # noinspection PyCallingNonCallable
@@ -905,9 +905,10 @@ class ProxyEventsWriter(object):

 # noinspection PyPep8Naming
 class PatchSummaryToEventTransformer(object):
-    __main_task = None
+    _current_task = None
     __original_getattribute = None
     __original_getattributeX = None
+    __patched = False
     _original_add_event = None
     _original_add_eventT = None
     _original_add_eventX = None
@@ -930,15 +931,19 @@ class PatchSummaryToEventTransformer(object):
     @staticmethod
     def update_current_task(task, **kwargs):
         PatchSummaryToEventTransformer.defaults_dict.update(kwargs)
-        PatchSummaryToEventTransformer.__main_task = task
-        # make sure we patched the SummaryToEventTransformer
-        PatchSummaryToEventTransformer._patch_summary_to_event_transformer()
-        PostImportHookPatching.add_on_import('tensorflow',
-                                             PatchSummaryToEventTransformer._patch_summary_to_event_transformer)
-        PostImportHookPatching.add_on_import('torch',
-                                             PatchSummaryToEventTransformer._patch_summary_to_event_transformer)
-        PostImportHookPatching.add_on_import('tensorboardX',
-                                             PatchSummaryToEventTransformer._patch_summary_to_event_transformer)
+        PatchSummaryToEventTransformer._current_task = task
+        if not task:
+            return
+        if not PatchSummaryToEventTransformer.__patched:
+            PatchSummaryToEventTransformer.__patched = True
+            # make sure we patched the SummaryToEventTransformer
+            PatchSummaryToEventTransformer._patch_summary_to_event_transformer()
+            PostImportHookPatching.add_on_import('tensorflow',
+                                                 PatchSummaryToEventTransformer._patch_summary_to_event_transformer)
+            PostImportHookPatching.add_on_import('torch',
+                                                 PatchSummaryToEventTransformer._patch_summary_to_event_transformer)
+            PostImportHookPatching.add_on_import('tensorboardX',
+                                                 PatchSummaryToEventTransformer._patch_summary_to_event_transformer)

     @staticmethod
     def _patch_summary_to_event_transformer():
@@ -1002,7 +1007,7 @@ class PatchSummaryToEventTransformer(object):

     @staticmethod
     def _patched_add_eventT(self, *args, **kwargs):
-        if not hasattr(self, 'clearml') or not PatchSummaryToEventTransformer.__main_task:
+        if not hasattr(self, 'clearml') or not PatchSummaryToEventTransformer._current_task:
             return PatchSummaryToEventTransformer._original_add_eventT(self, *args, **kwargs)
         if not self.clearml:  # noqa
             # noinspection PyBroadException
@@ -1010,7 +1015,7 @@ class PatchSummaryToEventTransformer(object):
                 logdir = self.get_logdir()
             except Exception:
                 logdir = None
-            self.clearml = EventTrainsWriter(PatchSummaryToEventTransformer.__main_task.get_logger(),
+            self.clearml = EventTrainsWriter(PatchSummaryToEventTransformer._current_task.get_logger(),
                                              logdir=logdir, **PatchSummaryToEventTransformer.defaults_dict)
         # noinspection PyBroadException
         try:
@@ -1021,7 +1026,7 @@ class PatchSummaryToEventTransformer(object):

     @staticmethod
     def _patched_add_eventX(self, *args, **kwargs):
-        if not hasattr(self, 'clearml') or not PatchSummaryToEventTransformer.__main_task:
+        if not hasattr(self, 'clearml') or not PatchSummaryToEventTransformer._current_task:
             return PatchSummaryToEventTransformer._original_add_eventX(self, *args, **kwargs)
         if not self.clearml:
             # noinspection PyBroadException
@@ -1029,7 +1034,7 @@ class PatchSummaryToEventTransformer(object):
                 logdir = self.get_logdir()
             except Exception:
                 logdir = None
-            self.clearml = EventTrainsWriter(PatchSummaryToEventTransformer.__main_task.get_logger(),
+            self.clearml = EventTrainsWriter(PatchSummaryToEventTransformer._current_task.get_logger(),
                                              logdir=logdir, **PatchSummaryToEventTransformer.defaults_dict)
         # noinspection PyBroadException
         try:
@@ -1051,7 +1056,7 @@ class PatchSummaryToEventTransformer(object):
     @staticmethod
     def _patched_getattribute_(self, attr, get_base):
         # no main task, zero chance we have an ClearML event logger
-        if PatchSummaryToEventTransformer.__main_task is None:
+        if PatchSummaryToEventTransformer._current_task is None:
             return get_base(self, attr)

         # check if we already have an ClearML event logger
@@ -1068,7 +1073,7 @@ class PatchSummaryToEventTransformer(object):
             except Exception:
                 logdir = None
             defaults_dict = __dict__.get('_trains_defaults') or PatchSummaryToEventTransformer.defaults_dict
-            trains_event = EventTrainsWriter(PatchSummaryToEventTransformer.__main_task.get_logger(),
+            trains_event = EventTrainsWriter(PatchSummaryToEventTransformer._current_task.get_logger(),
                                              logdir=logdir, **defaults_dict)

             # order is important, the return value of ProxyEventsWriter is the last object in the list
@@ -1111,8 +1116,9 @@ class _ModelAdapter(object):


 class PatchModelCheckPointCallback(object):
-    __main_task = None
+    _current_task = None
     __original_getattribute = None
+    __patched = False
     defaults_dict = dict(
         config_text=None,
         config_dict=None,
@@ -1132,11 +1138,15 @@ class PatchModelCheckPointCallback(object):
     @staticmethod
     def update_current_task(task, **kwargs):
         PatchModelCheckPointCallback.defaults_dict.update(kwargs)
-        PatchModelCheckPointCallback.__main_task = task
-        # make sure we patched the SummaryToEventTransformer
-        PatchModelCheckPointCallback._patch_model_checkpoint()
-        PostImportHookPatching.add_on_import('keras', PatchModelCheckPointCallback._patch_model_checkpoint)
-        PostImportHookPatching.add_on_import('tensorflow', PatchModelCheckPointCallback._patch_model_checkpoint)
+        PatchModelCheckPointCallback._current_task = task
+        if not task:
+            return
+        if not PatchModelCheckPointCallback.__patched:
+            PatchModelCheckPointCallback.__patched = True
+            # make sure we patched the SummaryToEventTransformer
+            PatchModelCheckPointCallback._patch_model_checkpoint()
+            PostImportHookPatching.add_on_import('keras', PatchModelCheckPointCallback._patch_model_checkpoint)
+            PostImportHookPatching.add_on_import('tensorflow', PatchModelCheckPointCallback._patch_model_checkpoint)

     @staticmethod
     def _patch_model_checkpoint():
@@ -1176,7 +1186,7 @@ class PatchModelCheckPointCallback(object):
         get_base = PatchModelCheckPointCallback.__original_getattribute

         # no main task, zero chance we have an ClearML event logger
-        if PatchModelCheckPointCallback.__main_task is None:
+        if PatchModelCheckPointCallback._current_task is None:
             return get_base(self, attr)

         # check if we already have an ClearML event logger
@@ -1189,17 +1199,17 @@ class PatchModelCheckPointCallback(object):
             base_model = __dict__['model']
             defaults_dict = __dict__.get('_trains_defaults') or PatchModelCheckPointCallback.defaults_dict
             output_model = OutputModel(
-                PatchModelCheckPointCallback.__main_task,
+                PatchModelCheckPointCallback._current_task,
                 config_text=defaults_dict.get('config_text'),
                 config_dict=defaults_dict.get('config_dict'),
                 name=defaults_dict.get('name'),
                 comment=defaults_dict.get('comment'),
                 label_enumeration=defaults_dict.get('label_enumeration') or
-                PatchModelCheckPointCallback.__main_task.get_labels_enumeration(),
+                PatchModelCheckPointCallback._current_task.get_labels_enumeration(),
                 framework=Framework.keras,
             )
             output_model.set_upload_destination(
-                PatchModelCheckPointCallback.__main_task.get_output_destination(raise_on_error=False))
+                PatchModelCheckPointCallback._current_task.get_output_destination(raise_on_error=False))
             trains_model = _ModelAdapter(base_model, output_model)

             # order is important, the return value of ProxyEventsWriter is the last object in the list
@@ -1209,26 +1219,32 @@

 # noinspection PyProtectedMember,PyUnresolvedReferences
 class PatchTensorFlowEager(object):
-    __main_task = None
+    _current_task = None
     __original_fn_scalar = None
     __original_fn_hist = None
     __original_fn_image = None
+    __original_fn_write_summary = None
     __trains_event_writer = {}
     __tf_tb_writer_id_to_logdir = {}
+    __patched = False
     defaults_dict = dict(
         report_freq=1, image_report_freq=1, histogram_update_freq_multiplier=5,
         histogram_granularity=50)

     @staticmethod
     def update_current_task(task, **kwargs):
-        if task != PatchTensorFlowEager.__main_task:
+        if task != PatchTensorFlowEager._current_task:
             PatchTensorFlowEager.__trains_event_writer = {}

         PatchTensorFlowEager.defaults_dict.update(kwargs)
-        PatchTensorFlowEager.__main_task = task
-        # make sure we patched the SummaryToEventTransformer
-        PatchTensorFlowEager._patch_summary_ops()
-        PostImportHookPatching.add_on_import('tensorflow', PatchTensorFlowEager._patch_summary_ops)
+        PatchTensorFlowEager._current_task = task
+        if not task:
+            return
+        if not PatchTensorFlowEager.__patched:
+            PatchTensorFlowEager.__patched = True
+            # make sure we patched the SummaryToEventTransformer
+            PatchTensorFlowEager._patch_summary_ops()
+            PostImportHookPatching.add_on_import('tensorflow', PatchTensorFlowEager._patch_summary_ops)

     @staticmethod
     def _patch_summary_ops():
@@ -1245,7 +1261,7 @@ class PatchTensorFlowEager(object):
             gen_summary_ops.write_image_summary = PatchTensorFlowEager._write_image_summary
             PatchTensorFlowEager.__original_fn_hist = gen_summary_ops.write_histogram_summary
             gen_summary_ops.write_histogram_summary = PatchTensorFlowEager._write_hist_summary
-            PatchTensorFlowEager.__write_summary = gen_summary_ops.write_summary
+            PatchTensorFlowEager.__original_fn_write_summary = gen_summary_ops.write_summary
             gen_summary_ops.write_summary = PatchTensorFlowEager._write_summary
             gen_summary_ops.create_summary_file_writer = partial(IsTensorboardInit._patched_tb__init__,
                                                                  gen_summary_ops.create_summary_file_writer)
@@ -1270,6 +1286,8 @@ class PatchTensorFlowEager(object):

     @staticmethod
     def _create_summary_file_writer(original_fn, *args, **kwargs):
+        if not PatchTensorFlowEager._current_task:
+            return original_fn(*args, **kwargs)
         # noinspection PyBroadException
         try:
             a_logdir = None
@@ -1294,7 +1312,7 @@ class PatchTensorFlowEager(object):

     @staticmethod
     def _get_event_writer(writer):
-        if not PatchTensorFlowEager.__main_task:
+        if not PatchTensorFlowEager._current_task:
             return None
         if not PatchTensorFlowEager.__trains_event_writer.get(id(writer)):
             # noinspection PyBroadException
@@ -1324,7 +1342,7 @@ class PatchTensorFlowEager(object):
                 logdir = None

             PatchTensorFlowEager.__trains_event_writer[id(writer)] = EventTrainsWriter(
-                logger=PatchTensorFlowEager.__main_task.get_logger(), logdir=logdir,
+                logger=PatchTensorFlowEager._current_task.get_logger(), logdir=logdir,
                 **PatchTensorFlowEager.defaults_dict)
         return PatchTensorFlowEager.__trains_event_writer[id(writer)]

@@ -1337,6 +1355,10 @@ class PatchTensorFlowEager(object):

     @staticmethod
     def _write_summary(writer, step, tensor, tag, summary_metadata, name=None, **kwargs):
+        if not PatchTensorFlowEager._current_task:
+            return PatchTensorFlowEager.__original_fn_write_summary(
+                writer, step, tensor, tag, summary_metadata, name, **kwargs
+            )
         event_writer = PatchTensorFlowEager._get_event_writer(writer)
         # make sure we can get the tensors values
         if event_writer and isinstance(step, int) or hasattr(step, 'numpy'):
@@ -1370,13 +1392,17 @@ class PatchTensorFlowEager(object):
                         step=int(step.numpy()) if not isinstance(step, int) else step,
                         values=None, audio_data=audio_bytes)
                 else:
-                    pass  # print('unsupported plugin_type', plugin_type)
+                    pass
             except Exception:
                 pass
-        return PatchTensorFlowEager.__write_summary(writer, step, tensor, tag, summary_metadata, name, **kwargs)
+        return PatchTensorFlowEager.__original_fn_write_summary(
+            writer, step, tensor, tag, summary_metadata, name, **kwargs
+        )

     @staticmethod
     def _write_scalar_summary(writer, step, tag, value, name=None, **kwargs):
+        if not PatchTensorFlowEager._current_task:
+            return PatchTensorFlowEager.__original_fn_scalar(writer, step, tag, value, name, **kwargs)
         event_writer = PatchTensorFlowEager._get_event_writer(writer)
         if event_writer and isinstance(step, int) or hasattr(step, 'numpy'):
             try:
@@ -1417,6 +1443,8 @@ class PatchTensorFlowEager(object):

     @staticmethod
     def _write_hist_summary(writer, step, tag, values, name, **kwargs):
+        if not PatchTensorFlowEager._current_task:
+            return PatchTensorFlowEager.__original_fn_hist(writer, step, tag, values, name, **kwargs)
         event_writer = PatchTensorFlowEager._get_event_writer(writer)
         if event_writer and isinstance(step, int) or hasattr(step, 'numpy'):
             try:
@@ -1459,6 +1487,9 @@ class PatchTensorFlowEager(object):

     @staticmethod
     def _write_image_summary(writer, step, tag, tensor, bad_color, max_images, name, **kwargs):
+        if not PatchTensorFlowEager._current_task:
+            return PatchTensorFlowEager.__original_fn_image(
+                writer, step, tag, tensor, bad_color, max_images, name, **kwargs)
         event_writer = PatchTensorFlowEager._get_event_writer(writer)
         if event_writer and isinstance(step, int) or hasattr(step, 'numpy'):
             try:
@@ -1526,13 +1557,15 @@

 # noinspection PyPep8Naming,SpellCheckingInspection
 class PatchKerasModelIO(object):
-    __main_task = None
+    _current_task = None
     __patched_keras = None
     __patched_tensorflow = None

     @staticmethod
     def update_current_task(task, **_):
-        PatchKerasModelIO.__main_task = task
+        PatchKerasModelIO._current_task = task
+        if not task:
+            return
         PatchKerasModelIO._patch_model_checkpoint()
         PostImportHookPatching.add_on_import('tensorflow', PatchKerasModelIO._patch_model_checkpoint)
         PostImportHookPatching.add_on_import('keras', PatchKerasModelIO._patch_model_checkpoint)
@@ -1692,7 +1725,7 @@ class PatchKerasModelIO(object):
     def _updated_config(original_fn, self):
         config = original_fn(self)
         # check if we have main task
-        if PatchKerasModelIO.__main_task is None:
+        if PatchKerasModelIO._current_task is None:
             return config

         try:
@@ -1717,10 +1750,10 @@ class PatchKerasModelIO(object):
             else:
                 # todo: support multiple models for the same task
                 self.trains_out_model.append(OutputModel(
-                    task=PatchKerasModelIO.__main_task,
+                    task=PatchKerasModelIO._current_task,
                     config_dict=safe_config,
-                    name=PatchKerasModelIO.__main_task.name + ' ' + model_name_id,
-                    label_enumeration=PatchKerasModelIO.__main_task.get_labels_enumeration(),
+                    name=PatchKerasModelIO._current_task.name + ' ' + model_name_id,
+                    label_enumeration=PatchKerasModelIO._current_task.get_labels_enumeration(),
                     framework=Framework.keras,
                 ))
         except Exception as ex:
@@ -1738,7 +1771,7 @@ class PatchKerasModelIO(object):
             self = _Empty()

         # check if we have main task
-        if PatchKerasModelIO.__main_task is None:
+        if PatchKerasModelIO._current_task is None:
             return self

         try:
@@ -1751,10 +1784,10 @@ class PatchKerasModelIO(object):
         # check if object already has InputModel
         self.trains_in_model = InputModel.empty(
             config_dict=config_dict,
-            label_enumeration=PatchKerasModelIO.__main_task.get_labels_enumeration(),
+            label_enumeration=PatchKerasModelIO._current_task.get_labels_enumeration(),
         )
         # todo: support multiple models for the same task
-        PatchKerasModelIO.__main_task.connect(self.trains_in_model)
+        PatchKerasModelIO._current_task.connect(self.trains_in_model)
         # if we are running remotely we should deserialize the object
         # because someone might have changed the configuration
         # Hack: disabled
@@ -1781,7 +1814,7 @@ class PatchKerasModelIO(object):
     @staticmethod
     def _load_weights(original_fn, self, *args, **kwargs):
         # check if we have main task
-        if PatchKerasModelIO.__main_task is None:
+        if PatchKerasModelIO._current_task is None:
             return original_fn(self, *args, **kwargs)

         # get filepath
@@ -1794,7 +1827,7 @@ class PatchKerasModelIO(object):
         if False and running_remotely():
             # register/load model weights
             filepath = WeightsFileHandler.restore_weights_file(self, filepath, Framework.keras,
-                                                               PatchKerasModelIO.__main_task)
+                                                               PatchKerasModelIO._current_task)
             if 'filepath' in kwargs:
                 kwargs['filepath'] = filepath
             else:
@@ -1805,11 +1838,13 @@ class PatchKerasModelIO(object):
         # try to load the files, if something happened exception will be raised before we register the file
         model = original_fn(self, *args, **kwargs)
         # register/load model weights
-        WeightsFileHandler.restore_weights_file(self, filepath, Framework.keras, PatchKerasModelIO.__main_task)
+        WeightsFileHandler.restore_weights_file(self, filepath, Framework.keras, PatchKerasModelIO._current_task)
         return model

     @staticmethod
     def _save(original_fn, self, *args, **kwargs):
+        if not PatchKerasModelIO._current_task:
+            return original_fn(self, *args, **kwargs)
         if hasattr(self, 'trains_out_model') and self.trains_out_model:
             # noinspection PyProtectedMember
             self.trains_out_model[-1]._processed = False
@@ -1823,12 +1858,13 @@ class PatchKerasModelIO(object):
     @staticmethod
     def _save_weights(original_fn, self, *args, **kwargs):
         original_fn(self, *args, **kwargs)
-        PatchKerasModelIO._update_outputmodel(self, *args, **kwargs)
+        if PatchKerasModelIO._current_task:
+            PatchKerasModelIO._update_outputmodel(self, *args, **kwargs)

     @staticmethod
     def _update_outputmodel(self, *args, **kwargs):
         # check if we have main task
-        if PatchKerasModelIO.__main_task is None:
+        if not PatchKerasModelIO._current_task:
             return

         try:
@@ -1851,7 +1887,7 @@ class PatchKerasModelIO(object):

             if filepath:
                 WeightsFileHandler.create_output_model(
-                    self, filepath, Framework.keras, PatchKerasModelIO.__main_task,
+                    self, filepath, Framework.keras, PatchKerasModelIO._current_task,
                     config_obj=config or None, singlefile=True)

         except Exception as ex:
@@ -1860,12 +1896,12 @@ class PatchKerasModelIO(object):
     @staticmethod
     def _save_model(original_fn, model, filepath, *args, **kwargs):
         original_fn(model, filepath, *args, **kwargs)
-        if PatchKerasModelIO.__main_task:
+        if PatchKerasModelIO._current_task:
             PatchKerasModelIO._update_outputmodel(model, filepath)

     @staticmethod
     def _load_model(original_fn, filepath, *args, **kwargs):
-        if not PatchKerasModelIO.__main_task:
+        if not PatchKerasModelIO._current_task:
             return original_fn(filepath, *args, **kwargs)

         empty = _Empty()
@@ -1873,12 +1909,12 @@ class PatchKerasModelIO(object):
         if False and running_remotely():
             # register/load model weights
             filepath = WeightsFileHandler.restore_weights_file(empty, filepath, Framework.keras,
-                                                               PatchKerasModelIO.__main_task)
+                                                               PatchKerasModelIO._current_task)
             model = original_fn(filepath, *args, **kwargs)
         else:
             model = original_fn(filepath, *args, **kwargs)
             # register/load model weights
-            WeightsFileHandler.restore_weights_file(empty, filepath, Framework.keras, PatchKerasModelIO.__main_task)
+            WeightsFileHandler.restore_weights_file(empty, filepath, Framework.keras, PatchKerasModelIO._current_task)
         # update the input model object
         if empty.trains_in_model:
             # noinspection PyBroadException
@ -1891,12 +1927,14 @@ class PatchKerasModelIO(object):
|
||||
|
||||
|
||||
class PatchTensorflowModelIO(object):
|
||||
__main_task = None
|
||||
_current_task = None
|
||||
__patched = None
|
||||
|
||||
@staticmethod
|
||||
def update_current_task(task, **_):
|
||||
PatchTensorflowModelIO.__main_task = task
|
||||
PatchTensorflowModelIO._current_task = task
|
||||
if not task:
|
||||
return
|
||||
PatchTensorflowModelIO._patch_model_checkpoint()
|
||||
PostImportHookPatching.add_on_import('tensorflow', PatchTensorflowModelIO._patch_model_checkpoint)
|
||||
|
||||
@ -2028,29 +2066,31 @@ class PatchTensorflowModelIO(object):
|
||||
@staticmethod
|
||||
def _save(original_fn, self, sess, save_path, *args, **kwargs):
|
||||
saved_path = original_fn(self, sess, save_path, *args, **kwargs)
|
||||
if not saved_path:
|
||||
if not saved_path or not PatchTensorflowModelIO._current_task:
|
||||
return saved_path
|
||||
# store output Model
|
||||
return WeightsFileHandler.create_output_model(self, saved_path, Framework.tensorflow,
|
||||
PatchTensorflowModelIO.__main_task)
|
||||
PatchTensorflowModelIO._current_task)
|
||||
|
||||
@staticmethod
|
||||
def _save_model(original_fn, obj, export_dir, *args, **kwargs):
|
||||
original_fn(obj, export_dir, *args, **kwargs)
|
||||
if not PatchKerasModelIO._current_task:
|
||||
return
|
||||
# store output Model
|
||||
WeightsFileHandler.create_output_model(obj, export_dir, Framework.tensorflow,
|
||||
PatchTensorflowModelIO.__main_task)
|
||||
PatchTensorflowModelIO._current_task)
|
||||
|
||||
@staticmethod
|
||||
def _restore(original_fn, self, sess, save_path, *args, **kwargs):
|
||||
if PatchTensorflowModelIO.__main_task is None:
|
||||
if PatchTensorflowModelIO._current_task is None:
|
||||
return original_fn(self, sess, save_path, *args, **kwargs)
|
||||
|
||||
# Hack: disabled
|
||||
if False and running_remotely():
|
||||
# register/load model weights
|
||||
save_path = WeightsFileHandler.restore_weights_file(self, save_path, Framework.tensorflow,
|
||||
PatchTensorflowModelIO.__main_task)
|
||||
PatchTensorflowModelIO._current_task)
|
||||
# load model
|
||||
return original_fn(self, sess, save_path, *args, **kwargs)
|
||||
|
||||
@ -2058,12 +2098,12 @@ class PatchTensorflowModelIO(object):
|
||||
model = original_fn(self, sess, save_path, *args, **kwargs)
|
||||
# register/load model weights
|
||||
WeightsFileHandler.restore_weights_file(self, save_path, Framework.tensorflow,
|
||||
PatchTensorflowModelIO.__main_task)
|
||||
PatchTensorflowModelIO._current_task)
|
||||
return model
|
||||
|
||||
    @staticmethod
    def _load_lt_2_0(original_fn, sess, tags=None, export_dir=None, *args, **saver_kwargs):
        if PatchTensorflowModelIO.__main_task is None:
        if PatchTensorflowModelIO._current_task is None:
            return original_fn(sess, tags, export_dir, *args, **saver_kwargs)

        # register input model
@ -2072,7 +2112,7 @@ class PatchTensorflowModelIO(object):
        if False and running_remotely():
            export_dir = WeightsFileHandler.restore_weights_file(
                empty, export_dir, Framework.tensorflow,
                PatchTensorflowModelIO.__main_task
                PatchTensorflowModelIO._current_task
            )
            model = original_fn(sess, tags, export_dir, *args, **saver_kwargs)
        else:
@ -2080,7 +2120,7 @@ class PatchTensorflowModelIO(object):
            model = original_fn(sess, tags, export_dir, *args, **saver_kwargs)
            WeightsFileHandler.restore_weights_file(
                empty, export_dir, Framework.tensorflow,
                PatchTensorflowModelIO.__main_task
                PatchTensorflowModelIO._current_task
            )

        if empty.trains_in_model:
@ -2093,7 +2133,7 @@ class PatchTensorflowModelIO(object):

    @staticmethod
    def _load(original_fn, export_dir, *args, **saver_kwargs):
        if PatchTensorflowModelIO.__main_task is None:
        if PatchTensorflowModelIO._current_task is None:
            return original_fn(export_dir, *args, **saver_kwargs)

        # register input model
@ -2102,7 +2142,7 @@ class PatchTensorflowModelIO(object):
        if False and running_remotely():
            export_dir = WeightsFileHandler.restore_weights_file(
                empty, export_dir, Framework.tensorflow,
                PatchTensorflowModelIO.__main_task
                PatchTensorflowModelIO._current_task
            )
            model = original_fn(export_dir, *args, **saver_kwargs)
        else:
@ -2110,7 +2150,7 @@ class PatchTensorflowModelIO(object):
            model = original_fn(export_dir, *args, **saver_kwargs)
            WeightsFileHandler.restore_weights_file(
                empty, export_dir, Framework.tensorflow,
                PatchTensorflowModelIO.__main_task
                PatchTensorflowModelIO._current_task
            )

        if empty.trains_in_model:
@ -2124,24 +2164,24 @@ class PatchTensorflowModelIO(object):
    @staticmethod
    def _ckpt_save(original_fn, self, file_prefix, *args, **kwargs):
        checkpoint_path = original_fn(self, file_prefix, *args, **kwargs)
        if PatchTensorflowModelIO.__main_task is None:
        if PatchTensorflowModelIO._current_task is None:
            return checkpoint_path
        WeightsFileHandler.create_output_model(self, checkpoint_path, Framework.tensorflow,
                                               PatchTensorflowModelIO.__main_task)
                                               PatchTensorflowModelIO._current_task)
        return checkpoint_path

    @staticmethod
    def _ckpt_write(original_fn, self, file_prefix, *args, **kwargs):
        checkpoint_path = original_fn(self, file_prefix, *args, **kwargs)
        if PatchTensorflowModelIO.__main_task is None:
        if PatchTensorflowModelIO._current_task is None:
            return checkpoint_path
        WeightsFileHandler.create_output_model(self, checkpoint_path, Framework.tensorflow,
                                               PatchTensorflowModelIO.__main_task)
                                               PatchTensorflowModelIO._current_task)
        return checkpoint_path

    @staticmethod
    def _ckpt_restore(original_fn, self, save_path, *args, **kwargs):
        if PatchTensorflowModelIO.__main_task is None:
        if PatchTensorflowModelIO._current_task is None:
            return original_fn(self, save_path, *args, **kwargs)

        # register input model
@ -2149,13 +2189,13 @@ class PatchTensorflowModelIO(object):
        # Hack: disabled
        if False and running_remotely():
            save_path = WeightsFileHandler.restore_weights_file(empty, save_path, Framework.tensorflow,
                                                                PatchTensorflowModelIO.__main_task)
                                                                PatchTensorflowModelIO._current_task)
            model = original_fn(self, save_path, *args, **kwargs)
        else:
            # try to load model before registering it, in case it fails.
            model = original_fn(self, save_path, *args, **kwargs)
            WeightsFileHandler.restore_weights_file(empty, save_path, Framework.tensorflow,
                                                    PatchTensorflowModelIO.__main_task)
                                                    PatchTensorflowModelIO._current_task)

        if empty.trains_in_model:
            # noinspection PyBroadException
@ -2167,12 +2207,14 @@ class PatchTensorflowModelIO(object):

class PatchTensorflow2ModelIO(object):
    __main_task = None
    _current_task = None
    __patched = None

    @staticmethod
    def update_current_task(task, **_):
        PatchTensorflow2ModelIO.__main_task = task
        PatchTensorflow2ModelIO._current_task = task
        if not task:
            return
        PatchTensorflow2ModelIO._patch_model_checkpoint()
        PostImportHookPatching.add_on_import('tensorflow', PatchTensorflow2ModelIO._patch_model_checkpoint)

@ -2210,18 +2252,20 @@ class PatchTensorflow2ModelIO(object):
    @staticmethod
    def _save(original_fn, self, file_prefix, *args, **kwargs):
        model = original_fn(self, file_prefix, *args, **kwargs)
        if not PatchTensorflow2ModelIO._current_task:
            return model
        # store output Model
        # noinspection PyBroadException
        try:
            WeightsFileHandler.create_output_model(self, file_prefix, Framework.tensorflow,
                                                   PatchTensorflow2ModelIO.__main_task)
                                                   PatchTensorflow2ModelIO._current_task)
        except Exception:
            pass
        return model

    @staticmethod
    def _restore(original_fn, self, save_path, *args, **kwargs):
        if PatchTensorflow2ModelIO.__main_task is None:
        if not PatchTensorflow2ModelIO._current_task:
            return original_fn(self, save_path, *args, **kwargs)

        # Hack: disabled
@ -2230,7 +2274,7 @@ class PatchTensorflow2ModelIO(object):
            # noinspection PyBroadException
            try:
                save_path = WeightsFileHandler.restore_weights_file(self, save_path, Framework.tensorflow,
                                                                    PatchTensorflow2ModelIO.__main_task)
                                                                    PatchTensorflow2ModelIO._current_task)
            except Exception:
                pass
            # load model
@ -2242,7 +2286,7 @@ class PatchTensorflow2ModelIO(object):
        # noinspection PyBroadException
        try:
            WeightsFileHandler.restore_weights_file(self, save_path, Framework.tensorflow,
                                                    PatchTensorflow2ModelIO.__main_task)
                                                    PatchTensorflow2ModelIO._current_task)
        except Exception:
            pass
        return model
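Note: the TensorFlow 2 handlers add one more layer on top of the task guard: reporting runs inside a broad try/except, so a reporting failure can never break the user's save/restore. A condensed sketch of that shape, assuming hypothetical get_task and report helpers (not clearml API):

def make_guarded(original_fn, get_task, report):
    # wrap original_fn so reporting happens only while a task is bound,
    # and a reporting failure never propagates to the caller
    def wrapper(*args, **kwargs):
        ret = original_fn(*args, **kwargs)
        if not get_task():
            return ret
        # noinspection PyBroadException
        try:
            report(ret)
        except Exception:
            pass  # reporting must never break user code
        return ret
    return wrapper

# usage sketch: checkpoint.save = make_guarded(checkpoint.save, lambda: current_task, print)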
@ -11,13 +11,15 @@ from ...model import Framework


class PatchXGBoostModelIO(PatchBaseModelIO):
    __main_task = None
    _current_task = None
    __patched = None
    __callback_cls = None

    @staticmethod
    def update_current_task(task, **kwargs):
        PatchXGBoostModelIO.__main_task = task
        PatchXGBoostModelIO._current_task = task
        if not task:
            return
        PatchXGBoostModelIO._patch_model_io()
        PostImportHookPatching.add_on_import('xgboost', PatchXGBoostModelIO._patch_model_io)

@ -55,7 +57,7 @@ class PatchXGBoostModelIO(PatchBaseModelIO):
    @staticmethod
    def _save(original_fn, obj, f, *args, **kwargs):
        ret = original_fn(obj, f, *args, **kwargs)
        if not PatchXGBoostModelIO.__main_task:
        if not PatchXGBoostModelIO._current_task:
            return ret

        if isinstance(f, six.string_types):
@ -76,12 +78,15 @@ class PatchXGBoostModelIO(PatchBaseModelIO):
            model_name = Path(filename).stem
        except Exception:
            model_name = None
        WeightsFileHandler.create_output_model(obj, filename, Framework.xgboost, PatchXGBoostModelIO.__main_task,
        WeightsFileHandler.create_output_model(obj, filename, Framework.xgboost, PatchXGBoostModelIO._current_task,
                                               singlefile=True, model_name=model_name)
        return ret

    @staticmethod
    def _load(original_fn, f, *args, **kwargs):
        if not PatchXGBoostModelIO._current_task:
            return original_fn(f, *args, **kwargs)

        if isinstance(f, six.string_types):
            filename = f
        elif hasattr(f, 'name'):
@ -91,21 +96,18 @@ class PatchXGBoostModelIO(PatchBaseModelIO):
        else:
            filename = None

        if not PatchXGBoostModelIO.__main_task:
            return original_fn(f, *args, **kwargs)

        # register input model
        empty = _Empty()
        # Hack: disabled
        if False and running_remotely():
            filename = WeightsFileHandler.restore_weights_file(empty, filename, Framework.xgboost,
                                                               PatchXGBoostModelIO.__main_task)
                                                               PatchXGBoostModelIO._current_task)
            model = original_fn(filename or f, *args, **kwargs)
        else:
            # try to load model before registering, in case we fail
            model = original_fn(f, *args, **kwargs)
            WeightsFileHandler.restore_weights_file(empty, filename, Framework.xgboost,
                                                    PatchXGBoostModelIO.__main_task)
                                                    PatchXGBoostModelIO._current_task)

        if empty.trains_in_model:
            # noinspection PyBroadException
@ -117,10 +119,12 @@ class PatchXGBoostModelIO(PatchBaseModelIO):

    @staticmethod
    def _train(original_fn, *args, **kwargs):
        if not PatchXGBoostModelIO._current_task:
            return original_fn(*args, **kwargs)
        if PatchXGBoostModelIO.__callback_cls:
            callbacks = kwargs.get('callbacks') or []
            kwargs['callbacks'] = callbacks + [
                PatchXGBoostModelIO.__callback_cls(task=PatchXGBoostModelIO.__main_task)
                PatchXGBoostModelIO.__callback_cls(task=PatchXGBoostModelIO._current_task)
            ]
        return original_fn(*args, **kwargs)

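Note: _train injects the logging callback only while a task is bound, and it appends to the caller's callbacks rather than replacing them. A sketch under those assumptions — LoggingCallback and the standalone patched_train signature are illustrative, not the xgboost or clearml API:

class LoggingCallback(object):
    # illustrative stand-in for the callback class clearml registers
    def __init__(self, task):
        self.task = task


def patched_train(original_fn, current_task, *args, **kwargs):
    if not current_task:
        # unbound (e.g. after Task.close()): run the original training call untouched
        return original_fn(*args, **kwargs)
    callbacks = kwargs.get('callbacks') or []
    # append, so user-supplied callbacks keep running alongside the logger
    kwargs['callbacks'] = callbacks + [LoggingCallback(task=current_task)]
    return original_fn(*args, **kwargs)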
@ -2,7 +2,7 @@ import io
import sys
from functools import partial

from ..config import running_remotely, get_remote_task_id, DEV_TASK_NO_REUSE
from ..config import running_remotely, DEV_TASK_NO_REUSE
from ..debugging.log import LoggerRoot


@ -46,6 +46,8 @@ class PatchHydra(object):
    def update_current_task(task):
        # set current Task before patching
        PatchHydra._current_task = task
        if not task:
            return
        if PatchHydra.patch_hydra():
            # check if we have an untracked state, store it.
            if PatchHydra._last_untracked_state.get('connect'):
@ -66,9 +68,6 @@ class PatchHydra(object):
        # noinspection PyBroadException
        try:
            if running_remotely():
                if not PatchHydra._current_task:
                    from ..task import Task
                    PatchHydra._current_task = Task.get_task(task_id=get_remote_task_id())
                # get the _parameter_allow_full_edit casted back to boolean
                connected_config = dict()
                connected_config[PatchHydra._parameter_allow_full_edit] = False
@ -126,8 +125,9 @@ class PatchHydra(object):

        if pre_app_task_init_call and not running_remotely():
            LoggerRoot.get_base_logger(PatchHydra).info(
                'Task.init called outside of Hydra-App. For full Hydra multi-run support, '
                'move the Task.init call into the Hydra-App main function')
                "Task.init called outside of Hydra-App. For full Hydra multi-run support, "
                "move the Task.init call into the Hydra-App main function"
            )

        kwargs["config"] = config
        kwargs["task_function"] = partial(PatchHydra._patched_task_function, task_function,)

@ -77,6 +77,8 @@ class PatchedJoblib(object):
    @staticmethod
    def update_current_task(task):
        PatchedJoblib._current_task = task
        if not task:
            return
        PatchedJoblib.patch_joblib()

    @staticmethod

@ -14,7 +14,7 @@ from ..utilities.proxy_object import flatten_dictionary

class PatchJsonArgParse(object):
    _args = {}
    _main_task = None
    _current_task = None
    _args_sep = "/"
    _args_type = {}
    _commands_sep = "."
@ -24,14 +24,19 @@ class PatchJsonArgParse(object):
    __remote_task_params = {}
    __patched = False

    @classmethod
    def update_current_task(cls, task):
        cls._current_task = task
        if not task:
            return
        cls.patch(task)

    @classmethod
    def patch(cls, task):
        if ArgumentParser is None:
            return

        if task:
            cls._main_task = task
            PatchJsonArgParse._update_task_args()
            PatchJsonArgParse._update_task_args()

        if not cls.__patched:
            cls.__patched = True
@ -39,14 +44,16 @@ class PatchJsonArgParse(object):

    @classmethod
    def _update_task_args(cls):
        if running_remotely() or not cls._main_task or not cls._args:
        if running_remotely() or not cls._current_task or not cls._args:
            return
        args = {cls._section_name + cls._args_sep + k: v for k, v in cls._args.items()}
        args_type = {cls._section_name + cls._args_sep + k: v for k, v in cls._args_type.items()}
        cls._main_task._set_parameters(args, __update=True, __parameters_types=args_type)
        cls._current_task._set_parameters(args, __update=True, __parameters_types=args_type)

    @staticmethod
    def _parse_args(original_fn, obj, *args, **kwargs):
        if not PatchJsonArgParse._current_task:
            return original_fn(obj, *args, **kwargs)
        if len(args) == 1:
            kwargs["args"] = args[0]
            args = []

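Note: PatchJsonArgParse can observe parse_args() calls before any task exists, so patch() flushes the arguments buffered so far each time a task binds. A condensed sketch of that buffering — PatchArgsSketch and record() are illustrative, and the print stands in for the real task._set_parameters call:

class PatchArgsSketch(object):
    _current_task = None
    _args = {}

    @classmethod
    def update_current_task(cls, task):
        cls._current_task = task
        if not task:
            return
        cls._flush_args()  # push anything parsed before the task was bound

    @classmethod
    def record(cls, key, value):
        # called from the patched parse_args(); buffers when no task is bound
        cls._args[key] = value
        cls._flush_args()

    @classmethod
    def _flush_args(cls):
        if not cls._current_task or not cls._args:
            return
        print("would flush %r to task %r" % (cls._args, cls._current_task))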
@ -207,12 +207,15 @@ class PatchedMatplotlib:
            PatchedMatplotlib._current_task = task
            PostImportHookPatching.add_on_import('matplotlib.pyplot', PatchedMatplotlib.patch_matplotlib)
            PostImportHookPatching.add_on_import('matplotlib.pylab', PatchedMatplotlib.patch_matplotlib)
        elif PatchedMatplotlib.patch_matplotlib():
        else:
            PatchedMatplotlib.patch_matplotlib()
            PatchedMatplotlib._current_task = task

    @staticmethod
    def patched_imshow(*args, **kw):
        ret = PatchedMatplotlib._patched_original_imshow(*args, **kw)
        if not PatchedMatplotlib._current_task:
            return ret
        try:
            from matplotlib import _pylab_helpers
            # store on the plot that this is an imshow plot
@ -227,6 +230,8 @@ class PatchedMatplotlib:
    @staticmethod
    def patched_savefig(self, *args, **kw):
        ret = PatchedMatplotlib._patched_original_savefig(self, *args, **kw)
        if not PatchedMatplotlib._current_task:
            return ret
        # noinspection PyBroadException
        try:
            fname = kw.get('fname') or args[0]
@ -256,6 +261,8 @@ class PatchedMatplotlib:

    @staticmethod
    def patched_figure_show(self, *args, **kw):
        if not PatchedMatplotlib._current_task:
            return PatchedMatplotlib._patched_original_figure(self, *args, **kw)
        tid = threading._get_ident() if six.PY2 else threading.get_ident()
        if PatchedMatplotlib._recursion_guard.get(tid):
            # we are inside a guard, do nothing
@ -269,6 +276,8 @@ class PatchedMatplotlib:

    @staticmethod
    def patched_show(*args, **kw):
        if not PatchedMatplotlib._current_task:
            return PatchedMatplotlib._patched_original_plot(*args, **kw)
        tid = threading._get_ident() if six.PY2 else threading.get_ident()
        PatchedMatplotlib._recursion_guard[tid] = True
        # noinspection PyBroadException

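Note: the switch from `elif PatchedMatplotlib.patch_matplotlib():` to an unconditional `else:` branch matters for re-binding: patch_matplotlib() may return False once the patch is already in place, but the new task must still be stored. A small sketch of the failure mode the new form avoids — PatchedPlot is illustrative, assuming patch() returns False when already applied:

class PatchedPlot(object):
    _patched = False
    _current_task = None

    @staticmethod
    def patch():
        if PatchedPlot._patched:
            return False  # second Task.init: nothing left to patch
        PatchedPlot._patched = True
        return True

    @staticmethod
    def update_current_task(task):
        # with the old "elif PatchedPlot.patch():" form, a re-init would skip
        # the assignment and leave _current_task pointing at the closed task
        PatchedPlot.patch()
        PatchedPlot._current_task = task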
@ -626,8 +626,7 @@ class Task(_Task):
            task.__register_at_exit(task._at_exit)

            # always patch OS forking because of ProcessPool and the like
            PatchOsFork.patch_fork()

            PatchOsFork.patch_fork(task)
            if auto_connect_frameworks:
                def should_connect(*keys):
                    """
@ -708,13 +707,15 @@ class Task(_Task):
                make_deterministic(task.get_random_seed())

            if auto_connect_arg_parser:
                EnvironmentBind.update_current_task(Task.__main_task)
                EnvironmentBind.update_current_task(task)

                PatchJsonArgParse.update_current_task(task)

                # Patch ArgParser to be aware of the current task
                argparser_update_currenttask(Task.__main_task)
                PatchClick.patch(Task.__main_task)
                PatchFire.patch(Task.__main_task)
                PatchJsonArgParse.patch(Task.__main_task)
                argparser_update_currenttask(task)

                PatchClick.patch(task)
                PatchFire.patch(task)

                # set excluded arguments
                if isinstance(auto_connect_arg_parser, dict):
@ -1686,6 +1687,22 @@ class Task(_Task):
        # noinspection PyProtectedMember
        Logger._remove_std_logger()

        # unbind everything
        PatchHydra.update_current_task(None)
        PatchedJoblib.update_current_task(None)
        PatchedMatplotlib.update_current_task(None)
        PatchAbsl.update_current_task(None)
        TensorflowBinding.update_current_task(None)
        PatchPyTorchModelIO.update_current_task(None)
        PatchMegEngineModelIO.update_current_task(None)
        PatchXGBoostModelIO.update_current_task(None)
        PatchCatBoostModelIO.update_current_task(None)
        PatchFastai.update_current_task(None)
        PatchLIGHTgbmModelIO.update_current_task(None)
        EnvironmentBind.update_current_task(None)
        PatchJsonArgParse.update_current_task(None)
        PatchOsFork.patch_fork(None)

    def delete(
        self,
        delete_artifacts_and_models=True,
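Note: taken together, Task.close() now detaches every framework hook instead of leaving it reporting to a closed task, and a later Task.init() re-binds them. Illustrative usage (project and task names are placeholders):

from clearml import Task

task = Task.init(project_name='examples', task_name='first run')
# ... framework saves / plots are logged to `task` ...
task.close()  # runs the Patch*.update_current_task(None) calls above

task2 = Task.init(project_name='examples', task_name='second run')
# hooks re-bind; new saves / plots are logged to `task2`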