mirror of
https://github.com/clearml/clearml
synced 2025-05-24 13:54:16 +00:00
Unbind all patches when Task.close()
is called
This commit is contained in:
parent
3c396b1f7e
commit
5dff9dd404
@ -8,12 +8,15 @@ from ..config import running_remotely
|
|||||||
class PatchAbsl(object):
|
class PatchAbsl(object):
|
||||||
_original_DEFINE_flag = None
|
_original_DEFINE_flag = None
|
||||||
_original_FLAGS_parse_call = None
|
_original_FLAGS_parse_call = None
|
||||||
_task = None
|
_current_task = None
|
||||||
|
__patched = False
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def update_current_task(cls, current_task):
|
def update_current_task(cls, task):
|
||||||
cls._task = current_task
|
cls._current_task = task
|
||||||
cls._patch_absl()
|
if not cls.__patched:
|
||||||
|
cls._patch_absl()
|
||||||
|
cls.__patched = True
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _patch_absl(cls):
|
def _patch_absl(cls):
|
||||||
@ -53,7 +56,7 @@ class PatchAbsl(object):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _patched_define_flag(*args, **kwargs):
|
def _patched_define_flag(*args, **kwargs):
|
||||||
if not PatchAbsl._task or not PatchAbsl._original_DEFINE_flag:
|
if not PatchAbsl._current_task or not PatchAbsl._original_DEFINE_flag:
|
||||||
if PatchAbsl._original_DEFINE_flag:
|
if PatchAbsl._original_DEFINE_flag:
|
||||||
return PatchAbsl._original_DEFINE_flag(*args, **kwargs)
|
return PatchAbsl._original_DEFINE_flag(*args, **kwargs)
|
||||||
else:
|
else:
|
||||||
@ -73,7 +76,7 @@ class PatchAbsl(object):
|
|||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
if param_name and flag:
|
if param_name and flag:
|
||||||
param_dict = PatchAbsl._task._arguments.copy_to_dict(
|
param_dict = PatchAbsl._current_task._arguments.copy_to_dict(
|
||||||
{param_name: flag.value}, prefix=_Arguments._prefix_tf_defines)
|
{param_name: flag.value}, prefix=_Arguments._prefix_tf_defines)
|
||||||
flag.value = param_dict.get(param_name, flag.value)
|
flag.value = param_dict.get(param_name, flag.value)
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -82,13 +85,17 @@ class PatchAbsl(object):
|
|||||||
else:
|
else:
|
||||||
if flag and param_name:
|
if flag and param_name:
|
||||||
value = flag.value
|
value = flag.value
|
||||||
PatchAbsl._task.update_parameters({_Arguments._prefix_tf_defines + param_name: value}, )
|
PatchAbsl._current_task.update_parameters(
|
||||||
|
{_Arguments._prefix_tf_defines + param_name: value},
|
||||||
|
)
|
||||||
ret = PatchAbsl._original_DEFINE_flag(*args, **kwargs)
|
ret = PatchAbsl._original_DEFINE_flag(*args, **kwargs)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _patched_FLAGS_parse_call(self, *args, **kwargs):
|
def _patched_FLAGS_parse_call(self, *args, **kwargs):
|
||||||
ret = PatchAbsl._original_FLAGS_parse_call(self, *args, **kwargs)
|
ret = PatchAbsl._original_FLAGS_parse_call(self, *args, **kwargs)
|
||||||
|
if not PatchAbsl._current_task:
|
||||||
|
return ret
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
PatchAbsl._update_current_flags(self)
|
PatchAbsl._update_current_flags(self)
|
||||||
@ -98,13 +105,13 @@ class PatchAbsl(object):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _update_current_flags(cls, FLAGS):
|
def _update_current_flags(cls, FLAGS):
|
||||||
if not cls._task:
|
if not cls._current_task:
|
||||||
return
|
return
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
if running_remotely():
|
if running_remotely():
|
||||||
param_dict = dict((k, FLAGS[k].value) for k in FLAGS)
|
param_dict = dict((k, FLAGS[k].value) for k in FLAGS)
|
||||||
param_dict = cls._task._arguments.copy_to_dict(param_dict, prefix=_Arguments._prefix_tf_defines)
|
param_dict = cls._current_task._arguments.copy_to_dict(param_dict, prefix=_Arguments._prefix_tf_defines)
|
||||||
for k, v in param_dict.items():
|
for k, v in param_dict.items():
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
@ -127,7 +134,7 @@ class PatchAbsl(object):
|
|||||||
param_types = dict([(k, FLAGS[k].flag_type() or None) for k in FLAGS])
|
param_types = dict([(k, FLAGS[k].flag_type() or None) for k in FLAGS])
|
||||||
except Exception:
|
except Exception:
|
||||||
param_types = None
|
param_types = None
|
||||||
cls._task._arguments.copy_from_dict(
|
cls._current_task._arguments.copy_from_dict(
|
||||||
parameters,
|
parameters,
|
||||||
prefix=_Arguments._prefix_tf_defines,
|
prefix=_Arguments._prefix_tf_defines,
|
||||||
descriptions=descriptions, param_types=param_types,
|
descriptions=descriptions, param_types=param_types,
|
||||||
|
@ -16,7 +16,7 @@ class PatchClick:
|
|||||||
_num_commands = 0
|
_num_commands = 0
|
||||||
_command_type = 'click.Command'
|
_command_type = 'click.Command'
|
||||||
_section_name = 'Args'
|
_section_name = 'Args'
|
||||||
_main_task = None
|
_current_task = None
|
||||||
__remote_task_params = None
|
__remote_task_params = None
|
||||||
__remote_task_params_dict = {}
|
__remote_task_params_dict = {}
|
||||||
__patched = False
|
__patched = False
|
||||||
@ -26,8 +26,8 @@ class PatchClick:
|
|||||||
if Command is None:
|
if Command is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
cls._current_task = task
|
||||||
if task:
|
if task:
|
||||||
cls._main_task = task
|
|
||||||
PatchClick._update_task_args()
|
PatchClick._update_task_args()
|
||||||
|
|
||||||
if not cls.__patched:
|
if not cls.__patched:
|
||||||
@ -53,11 +53,11 @@ class PatchClick:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _update_task_args(cls):
|
def _update_task_args(cls):
|
||||||
if running_remotely() or not cls._main_task or not cls._args:
|
if running_remotely() or not cls._current_task or not cls._args:
|
||||||
return
|
return
|
||||||
param_val, param_types, param_desc = cls.args()
|
param_val, param_types, param_desc = cls.args()
|
||||||
# noinspection PyProtectedMember
|
# noinspection PyProtectedMember
|
||||||
cls._main_task._set_parameters(
|
cls._current_task._set_parameters(
|
||||||
param_val,
|
param_val,
|
||||||
__update=True,
|
__update=True,
|
||||||
__parameters_descriptions=param_desc,
|
__parameters_descriptions=param_desc,
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
|
from time import sleep
|
||||||
|
|
||||||
import six
|
import six
|
||||||
|
|
||||||
@ -7,26 +8,29 @@ from ..utilities.process.mp import BackgroundMonitor
|
|||||||
|
|
||||||
|
|
||||||
class EnvironmentBind(object):
|
class EnvironmentBind(object):
|
||||||
_task = None
|
_current_task = None
|
||||||
_environment_section = 'Environment'
|
_environment_section = 'Environment'
|
||||||
|
__patched = False
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def update_current_task(cls, current_task):
|
def update_current_task(cls, task):
|
||||||
cls._task = current_task
|
cls._current_task = task
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
cls._bind_environment()
|
if not cls.__patched:
|
||||||
|
cls.__patched = True
|
||||||
|
cls._bind_environment()
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _bind_environment(cls):
|
def _bind_environment(cls):
|
||||||
if not cls._task:
|
if not cls._current_task:
|
||||||
return
|
return
|
||||||
|
|
||||||
# get ENVIRONMENT and put it into the OS environment
|
# get ENVIRONMENT and put it into the OS environment
|
||||||
if running_remotely():
|
if running_remotely():
|
||||||
params = cls._task.get_parameters_as_dict()
|
params = cls._current_task.get_parameters_as_dict()
|
||||||
if params and cls._environment_section in params:
|
if params and cls._environment_section in params:
|
||||||
# put back into os:
|
# put back into os:
|
||||||
os.environ.update(params[cls._environment_section])
|
os.environ.update(params[cls._environment_section])
|
||||||
@ -55,14 +59,18 @@ class EnvironmentBind(object):
|
|||||||
elif match in os.environ:
|
elif match in os.environ:
|
||||||
env_param.update({match: os.environ.get(match)})
|
env_param.update({match: os.environ.get(match)})
|
||||||
# store os environments
|
# store os environments
|
||||||
cls._task.connect(env_param, cls._environment_section)
|
cls._current_task.connect(env_param, cls._environment_section)
|
||||||
|
|
||||||
|
|
||||||
class PatchOsFork(object):
|
class PatchOsFork(object):
|
||||||
_original_fork = None
|
_original_fork = None
|
||||||
|
_current_task = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def patch_fork(cls):
|
def patch_fork(cls, task):
|
||||||
|
cls._current_task = task
|
||||||
|
if not task:
|
||||||
|
return
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
# only once
|
# only once
|
||||||
@ -88,11 +96,13 @@ class PatchOsFork(object):
|
|||||||
Task._wait_for_deferred(task)
|
Task._wait_for_deferred(task)
|
||||||
|
|
||||||
ret = PatchOsFork._original_fork(*args, **kwargs)
|
ret = PatchOsFork._original_fork(*args, **kwargs)
|
||||||
|
if not PatchOsFork._current_task:
|
||||||
|
return ret
|
||||||
# Make sure the new process stdout is logged
|
# Make sure the new process stdout is logged
|
||||||
if not ret:
|
if not ret:
|
||||||
# force creating a Task
|
# force creating a Task
|
||||||
task = Task.current_task()
|
task = Task.current_task()
|
||||||
if task is None:
|
if not task:
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
# # Hack: now make sure we setup the reporter threads (Log+Reporter)
|
# # Hack: now make sure we setup the reporter threads (Log+Reporter)
|
||||||
@ -106,7 +116,7 @@ class PatchOsFork(object):
|
|||||||
# just make sure we flush the internal state (the at exist caught by the external signal does the rest
|
# just make sure we flush the internal state (the at exist caught by the external signal does the rest
|
||||||
# in theory we should not have to do any of that, but for some reason if we do not
|
# in theory we should not have to do any of that, but for some reason if we do not
|
||||||
# the signal is never caught by the signal call backs, not sure why....
|
# the signal is never caught by the signal call backs, not sure why....
|
||||||
|
sleep(0.1)
|
||||||
# Since at_exist handlers do not work on forked processes, we have to manually call them here
|
# Since at_exist handlers do not work on forked processes, we have to manually call them here
|
||||||
if task:
|
if task:
|
||||||
try:
|
try:
|
||||||
|
@ -20,7 +20,7 @@ class PatchFire:
|
|||||||
_section_name = "Args"
|
_section_name = "Args"
|
||||||
_args_sep = "/"
|
_args_sep = "/"
|
||||||
_commands_sep = "."
|
_commands_sep = "."
|
||||||
_main_task = None
|
_current_task = None
|
||||||
__remote_task_params = None
|
__remote_task_params = None
|
||||||
__remote_task_params_dict = {}
|
__remote_task_params_dict = {}
|
||||||
__patched = False
|
__patched = False
|
||||||
@ -38,8 +38,8 @@ class PatchFire:
|
|||||||
if fire is None:
|
if fire is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
cls._current_task = task
|
||||||
if task:
|
if task:
|
||||||
cls._main_task = task
|
|
||||||
cls._update_task_args()
|
cls._update_task_args()
|
||||||
|
|
||||||
if not cls.__patched:
|
if not cls.__patched:
|
||||||
@ -53,7 +53,7 @@ class PatchFire:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _update_task_args(cls):
|
def _update_task_args(cls):
|
||||||
if running_remotely() or not cls._main_task:
|
if running_remotely() or not cls._current_task:
|
||||||
return
|
return
|
||||||
args = {}
|
args = {}
|
||||||
parameters_types = {}
|
parameters_types = {}
|
||||||
@ -118,7 +118,7 @@ class PatchFire:
|
|||||||
parameters_types = {**parameters_types, **unused_paramenters_types}
|
parameters_types = {**parameters_types, **unused_paramenters_types}
|
||||||
|
|
||||||
# noinspection PyProtectedMember
|
# noinspection PyProtectedMember
|
||||||
cls._main_task._set_parameters(
|
cls._current_task._set_parameters(
|
||||||
args,
|
args,
|
||||||
__update=True,
|
__update=True,
|
||||||
__parameters_types=parameters_types,
|
__parameters_types=parameters_types,
|
||||||
@ -126,25 +126,25 @@ class PatchFire:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def __Fire(original_fn, component, args_, parsed_flag_args, context, name, *args, **kwargs): # noqa
|
def __Fire(original_fn, component, args_, parsed_flag_args, context, name, *args, **kwargs): # noqa
|
||||||
if running_remotely():
|
if not running_remotely():
|
||||||
command = PatchFire._load_task_params()
|
return original_fn(component, args_, parsed_flag_args, context, name, *args, **kwargs)
|
||||||
if command is not None:
|
command = PatchFire._load_task_params()
|
||||||
replaced_args = command.split(PatchFire._commands_sep)
|
if command is not None:
|
||||||
else:
|
replaced_args = command.split(PatchFire._commands_sep)
|
||||||
replaced_args = []
|
else:
|
||||||
for param in PatchFire.__remote_task_params[PatchFire._section_name].values():
|
replaced_args = []
|
||||||
if command is not None and param.type == PatchFire._command_arg_type_template % command:
|
for param in PatchFire.__remote_task_params[PatchFire._section_name].values():
|
||||||
replaced_args.append("--" + param.name[len(command + PatchFire._args_sep):])
|
if command is not None and param.type == PatchFire._command_arg_type_template % command:
|
||||||
value = PatchFire.__remote_task_params_dict[param.name]
|
replaced_args.append("--" + param.name[len(command + PatchFire._args_sep):])
|
||||||
if len(value) > 0:
|
value = PatchFire.__remote_task_params_dict[param.name]
|
||||||
replaced_args.append(value)
|
if len(value) > 0:
|
||||||
if param.type == PatchFire._shared_arg_type:
|
replaced_args.append(value)
|
||||||
replaced_args.append("--" + param.name)
|
if param.type == PatchFire._shared_arg_type:
|
||||||
value = PatchFire.__remote_task_params_dict[param.name]
|
replaced_args.append("--" + param.name)
|
||||||
if len(value) > 0:
|
value = PatchFire.__remote_task_params_dict[param.name]
|
||||||
replaced_args.append(value)
|
if len(value) > 0:
|
||||||
return original_fn(component, replaced_args, parsed_flag_args, context, name, *args, **kwargs)
|
replaced_args.append(value)
|
||||||
return original_fn(component, args_, parsed_flag_args, context, name, *args, **kwargs)
|
return original_fn(component, replaced_args, parsed_flag_args, context, name, *args, **kwargs)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def __CallAndUpdateTrace( # noqa
|
def __CallAndUpdateTrace( # noqa
|
||||||
|
@ -11,13 +11,15 @@ from ...model import Framework
|
|||||||
|
|
||||||
|
|
||||||
class PatchCatBoostModelIO(PatchBaseModelIO):
|
class PatchCatBoostModelIO(PatchBaseModelIO):
|
||||||
__main_task = None
|
_current_task = None
|
||||||
__patched = None
|
__patched = None
|
||||||
__callback_cls = None
|
__callback_cls = None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def update_current_task(task, **kwargs):
|
def update_current_task(task, **kwargs):
|
||||||
PatchCatBoostModelIO.__main_task = task
|
PatchCatBoostModelIO._current_task = task
|
||||||
|
if not task:
|
||||||
|
return
|
||||||
PatchCatBoostModelIO._patch_model_io()
|
PatchCatBoostModelIO._patch_model_io()
|
||||||
PostImportHookPatching.add_on_import("catboost", PatchCatBoostModelIO._patch_model_io)
|
PostImportHookPatching.add_on_import("catboost", PatchCatBoostModelIO._patch_model_io)
|
||||||
|
|
||||||
@ -38,16 +40,16 @@ class PatchCatBoostModelIO(PatchBaseModelIO):
|
|||||||
CatBoost.fit = _patched_call(CatBoost.fit, PatchCatBoostModelIO._fit)
|
CatBoost.fit = _patched_call(CatBoost.fit, PatchCatBoostModelIO._fit)
|
||||||
CatBoostClassifier.fit = _patched_call(CatBoostClassifier.fit, PatchCatBoostModelIO._fit)
|
CatBoostClassifier.fit = _patched_call(CatBoostClassifier.fit, PatchCatBoostModelIO._fit)
|
||||||
CatBoostRegressor.fit = _patched_call(CatBoostRegressor.fit, PatchCatBoostModelIO._fit)
|
CatBoostRegressor.fit = _patched_call(CatBoostRegressor.fit, PatchCatBoostModelIO._fit)
|
||||||
CatBoostRanker.fit = _patched_call(CatBoostRegressor.fit, PatchCatBoostModelIO._fit)
|
CatBoostRanker.fit = _patched_call(CatBoostRanker.fit, PatchCatBoostModelIO._fit)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger = PatchCatBoostModelIO.__main_task.get_logger()
|
logger = PatchCatBoostModelIO._current_task.get_logger()
|
||||||
logger.report_text("Failed patching Catboost. Exception is: '" + str(e) + "'")
|
logger.report_text("Failed patching Catboost. Exception is: '" + str(e) + "'")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _save(original_fn, obj, f, *args, **kwargs):
|
def _save(original_fn, obj, f, *args, **kwargs):
|
||||||
# see https://catboost.ai/en/docs/concepts/python-reference_catboost_save_model
|
# see https://catboost.ai/en/docs/concepts/python-reference_catboost_save_model
|
||||||
ret = original_fn(obj, f, *args, **kwargs)
|
ret = original_fn(obj, f, *args, **kwargs)
|
||||||
if not PatchCatBoostModelIO.__main_task:
|
if not PatchCatBoostModelIO._current_task:
|
||||||
return ret
|
return ret
|
||||||
if isinstance(f, six.string_types):
|
if isinstance(f, six.string_types):
|
||||||
filename = f
|
filename = f
|
||||||
@ -60,13 +62,16 @@ class PatchCatBoostModelIO(PatchBaseModelIO):
|
|||||||
except Exception:
|
except Exception:
|
||||||
model_name = None
|
model_name = None
|
||||||
WeightsFileHandler.create_output_model(
|
WeightsFileHandler.create_output_model(
|
||||||
obj, filename, Framework.catboost, PatchCatBoostModelIO.__main_task, singlefile=True, model_name=model_name
|
obj, filename, Framework.catboost, PatchCatBoostModelIO._current_task, singlefile=True, model_name=model_name
|
||||||
)
|
)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _load(original_fn, f, *args, **kwargs):
|
def _load(original_fn, f, *args, **kwargs):
|
||||||
# see https://catboost.ai/en/docs/concepts/python-reference_catboost_load_model
|
# see https://catboost.ai/en/docs/concepts/python-reference_catboost_load_model
|
||||||
|
if not PatchCatBoostModelIO._current_task:
|
||||||
|
return original_fn(f, *args, **kwargs)
|
||||||
|
|
||||||
if isinstance(f, six.string_types):
|
if isinstance(f, six.string_types):
|
||||||
filename = f
|
filename = f
|
||||||
elif len(args) >= 1 and isinstance(args[0], six.string_types):
|
elif len(args) >= 1 and isinstance(args[0], six.string_types):
|
||||||
@ -74,13 +79,10 @@ class PatchCatBoostModelIO(PatchBaseModelIO):
|
|||||||
else:
|
else:
|
||||||
filename = None
|
filename = None
|
||||||
|
|
||||||
if not PatchCatBoostModelIO.__main_task:
|
|
||||||
return original_fn(f, *args, **kwargs)
|
|
||||||
|
|
||||||
# register input model
|
# register input model
|
||||||
empty = _Empty()
|
empty = _Empty()
|
||||||
model = original_fn(f, *args, **kwargs)
|
model = original_fn(f, *args, **kwargs)
|
||||||
WeightsFileHandler.restore_weights_file(empty, filename, Framework.catboost, PatchCatBoostModelIO.__main_task)
|
WeightsFileHandler.restore_weights_file(empty, filename, Framework.catboost, PatchCatBoostModelIO._current_task)
|
||||||
if empty.trains_in_model:
|
if empty.trains_in_model:
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
@ -91,13 +93,15 @@ class PatchCatBoostModelIO(PatchBaseModelIO):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _fit(original_fn, obj, *args, **kwargs):
|
def _fit(original_fn, obj, *args, **kwargs):
|
||||||
|
if not PatchCatBoostModelIO._current_task:
|
||||||
|
return original_fn(obj, *args, **kwargs)
|
||||||
callbacks = kwargs.get("callbacks") or []
|
callbacks = kwargs.get("callbacks") or []
|
||||||
kwargs["callbacks"] = callbacks + [PatchCatBoostModelIO.__callback_cls(task=PatchCatBoostModelIO.__main_task)]
|
kwargs["callbacks"] = callbacks + [PatchCatBoostModelIO.__callback_cls(task=PatchCatBoostModelIO._current_task)]
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
return original_fn(obj, *args, **kwargs)
|
return original_fn(obj, *args, **kwargs)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger = PatchCatBoostModelIO.__main_task.get_logger()
|
logger = PatchCatBoostModelIO._current_task.get_logger()
|
||||||
logger.report_text(
|
logger.report_text(
|
||||||
"Catboost metrics logging is not supported for GPU. "
|
"Catboost metrics logging is not supported for GPU. "
|
||||||
"See https://github.com/catboost/catboost/issues/1792"
|
"See https://github.com/catboost/catboost/issues/1792"
|
||||||
|
@ -25,11 +25,9 @@ class PatchFastai(object):
|
|||||||
try:
|
try:
|
||||||
if Version(fastai.__version__) < Version("2.0.0"):
|
if Version(fastai.__version__) < Version("2.0.0"):
|
||||||
PatchFastaiV1.update_current_task(task)
|
PatchFastaiV1.update_current_task(task)
|
||||||
PatchFastaiV1.patch_model_callback()
|
|
||||||
PostImportHookPatching.add_on_import("fastai", PatchFastaiV1.patch_model_callback)
|
PostImportHookPatching.add_on_import("fastai", PatchFastaiV1.patch_model_callback)
|
||||||
else:
|
else:
|
||||||
PatchFastaiV2.update_current_task(task)
|
PatchFastaiV2.update_current_task(task)
|
||||||
PatchFastaiV2.patch_model_callback()
|
|
||||||
PostImportHookPatching.add_on_import("fastai", PatchFastaiV2.patch_model_callback)
|
PostImportHookPatching.add_on_import("fastai", PatchFastaiV2.patch_model_callback)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
@ -38,11 +36,17 @@ class PatchFastai(object):
|
|||||||
class PatchFastaiV1(object):
|
class PatchFastaiV1(object):
|
||||||
__metrics_names = {}
|
__metrics_names = {}
|
||||||
__gradient_hist_helpers = {}
|
__gradient_hist_helpers = {}
|
||||||
_main_task = None
|
_current_task = None
|
||||||
|
__patched = False
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def update_current_task(task, **_):
|
def update_current_task(task, **_):
|
||||||
PatchFastaiV1._main_task = task
|
PatchFastaiV1._current_task = task
|
||||||
|
if not task:
|
||||||
|
return
|
||||||
|
if not PatchFastaiV1.__patched:
|
||||||
|
PatchFastaiV1.__patched = True
|
||||||
|
PatchFastaiV1.patch_model_callback()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def patch_model_callback():
|
def patch_model_callback():
|
||||||
@ -52,7 +56,6 @@ class PatchFastaiV1(object):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
from fastai.basic_train import Recorder
|
from fastai.basic_train import Recorder
|
||||||
|
|
||||||
Recorder.on_batch_end = _patched_call(Recorder.on_batch_end, PatchFastaiV1._on_batch_end)
|
Recorder.on_batch_end = _patched_call(Recorder.on_batch_end, PatchFastaiV1._on_batch_end)
|
||||||
Recorder.on_backward_end = _patched_call(Recorder.on_backward_end, PatchFastaiV1._on_backward_end)
|
Recorder.on_backward_end = _patched_call(Recorder.on_backward_end, PatchFastaiV1._on_backward_end)
|
||||||
Recorder.on_epoch_end = _patched_call(Recorder.on_epoch_end, PatchFastaiV1._on_epoch_end)
|
Recorder.on_epoch_end = _patched_call(Recorder.on_epoch_end, PatchFastaiV1._on_epoch_end)
|
||||||
@ -65,7 +68,7 @@ class PatchFastaiV1(object):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def _on_train_begin(original_fn, recorder, *args, **kwargs):
|
def _on_train_begin(original_fn, recorder, *args, **kwargs):
|
||||||
original_fn(recorder, *args, **kwargs)
|
original_fn(recorder, *args, **kwargs)
|
||||||
if not PatchFastaiV1._main_task:
|
if not PatchFastaiV1._current_task:
|
||||||
return
|
return
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
@ -84,7 +87,7 @@ class PatchFastaiV1(object):
|
|||||||
|
|
||||||
original_fn(recorder, *args, **kwargs)
|
original_fn(recorder, *args, **kwargs)
|
||||||
|
|
||||||
if not PatchFastaiV1._main_task:
|
if not PatchFastaiV1._current_task:
|
||||||
return
|
return
|
||||||
|
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
@ -112,7 +115,7 @@ class PatchFastaiV1(object):
|
|||||||
min_gradient=gradient_stats[:, 5].min(),
|
min_gradient=gradient_stats[:, 5].min(),
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = PatchFastaiV1._main_task.get_logger()
|
logger = PatchFastaiV1._current_task.get_logger()
|
||||||
iteration = kwargs.get("iteration", 0)
|
iteration = kwargs.get("iteration", 0)
|
||||||
for name, val in stats_report.items():
|
for name, val in stats_report.items():
|
||||||
logger.report_scalar(title="model_stats_gradients", series=name, value=val, iteration=iteration)
|
logger.report_scalar(title="model_stats_gradients", series=name, value=val, iteration=iteration)
|
||||||
@ -122,26 +125,26 @@ class PatchFastaiV1(object):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def _on_epoch_end(original_fn, recorder, *args, **kwargs):
|
def _on_epoch_end(original_fn, recorder, *args, **kwargs):
|
||||||
original_fn(recorder, *args, **kwargs)
|
original_fn(recorder, *args, **kwargs)
|
||||||
if not PatchFastaiV1._main_task:
|
if not PatchFastaiV1._current_task:
|
||||||
return
|
return
|
||||||
|
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
logger = PatchFastaiV1._main_task.get_logger()
|
logger = PatchFastaiV1._current_task.get_logger()
|
||||||
iteration = kwargs.get("iteration")
|
iteration = kwargs.get("iteration")
|
||||||
for series, value in zip(
|
for series, value in zip(
|
||||||
PatchFastaiV1.__metrics_names[id(recorder)],
|
PatchFastaiV1.__metrics_names[id(recorder)],
|
||||||
[kwargs.get("smooth_loss")] + kwargs.get("last_metrics", []),
|
[kwargs.get("smooth_loss")] + kwargs.get("last_metrics", []),
|
||||||
):
|
):
|
||||||
logger.report_scalar(title="metrics", series=series, value=value, iteration=iteration)
|
logger.report_scalar(title="metrics", series=series, value=value, iteration=iteration)
|
||||||
PatchFastaiV1._main_task.flush()
|
PatchFastaiV1._current_task.flush()
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _on_batch_end(original_fn, recorder, *args, **kwargs):
|
def _on_batch_end(original_fn, recorder, *args, **kwargs):
|
||||||
original_fn(recorder, *args, **kwargs)
|
original_fn(recorder, *args, **kwargs)
|
||||||
if not PatchFastaiV1._main_task:
|
if not PatchFastaiV1._current_task:
|
||||||
return
|
return
|
||||||
|
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
@ -150,7 +153,7 @@ class PatchFastaiV1(object):
|
|||||||
if iteration == 0 or not kwargs.get("train"):
|
if iteration == 0 or not kwargs.get("train"):
|
||||||
return
|
return
|
||||||
|
|
||||||
logger = PatchFastaiV1._main_task.get_logger()
|
logger = PatchFastaiV1._current_task.get_logger()
|
||||||
logger.report_scalar(
|
logger.report_scalar(
|
||||||
title="metrics",
|
title="metrics",
|
||||||
series="train_loss",
|
series="train_loss",
|
||||||
@ -174,11 +177,17 @@ class PatchFastaiV1(object):
|
|||||||
|
|
||||||
|
|
||||||
class PatchFastaiV2(object):
|
class PatchFastaiV2(object):
|
||||||
_main_task = None
|
_current_task = None
|
||||||
|
__patched = False
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def update_current_task(task, **_):
|
def update_current_task(task, **_):
|
||||||
PatchFastaiV2._main_task = task
|
PatchFastaiV2._current_task = task
|
||||||
|
if not task:
|
||||||
|
return
|
||||||
|
if not PatchFastaiV2.__patched:
|
||||||
|
PatchFastaiV2.__patched = True
|
||||||
|
PatchFastaiV2.patch_model_callback()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def patch_model_callback():
|
def patch_model_callback():
|
||||||
@ -212,13 +221,15 @@ class PatchFastaiV2(object):
|
|||||||
self.logger = noop
|
self.logger = noop
|
||||||
self.__id = str(PatchFastaiV2.PatchFastaiCallbacks.__id)
|
self.__id = str(PatchFastaiV2.PatchFastaiCallbacks.__id)
|
||||||
PatchFastaiV2.PatchFastaiCallbacks.__id += 1
|
PatchFastaiV2.PatchFastaiCallbacks.__id += 1
|
||||||
self.__gradient_hist_helper = WeightsGradientHistHelper(PatchFastaiV2._main_task.get_logger())
|
self.__gradient_hist_helper = WeightsGradientHistHelper(PatchFastaiV2._current_task.get_logger())
|
||||||
|
|
||||||
def after_batch(self):
|
def after_batch(self):
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
super().after_batch() # noqa
|
super().after_batch() # noqa
|
||||||
logger = PatchFastaiV2._main_task.get_logger()
|
if not PatchFastaiV2._current_task:
|
||||||
|
return
|
||||||
|
logger = PatchFastaiV2._current_task.get_logger()
|
||||||
if not self.training: # noqa
|
if not self.training: # noqa
|
||||||
return
|
return
|
||||||
self.__train_iter += 1
|
self.__train_iter += 1
|
||||||
@ -254,7 +265,9 @@ class PatchFastaiV2(object):
|
|||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
super().after_epoch() # noqa
|
super().after_epoch() # noqa
|
||||||
logger = PatchFastaiV2._main_task.get_logger()
|
if not PatchFastaiV2._current_task:
|
||||||
|
return
|
||||||
|
logger = PatchFastaiV2._current_task.get_logger()
|
||||||
for metric in self._valid_mets: # noqa
|
for metric in self._valid_mets: # noqa
|
||||||
logger.report_scalar(
|
logger.report_scalar(
|
||||||
title="metrics_" + self.__id,
|
title="metrics_" + self.__id,
|
||||||
@ -270,7 +283,9 @@ class PatchFastaiV2(object):
|
|||||||
try:
|
try:
|
||||||
if hasattr(fastai.learner.Recorder, "before_step"):
|
if hasattr(fastai.learner.Recorder, "before_step"):
|
||||||
super().before_step() # noqa
|
super().before_step() # noqa
|
||||||
logger = PatchFastaiV2._main_task.get_logger()
|
if not PatchFastaiV2._current_task:
|
||||||
|
return
|
||||||
|
logger = PatchFastaiV2._current_task.get_logger()
|
||||||
gradients = [
|
gradients = [
|
||||||
x.grad.clone().detach().cpu() for x in self.learn.model.parameters() if x.grad is not None
|
x.grad.clone().detach().cpu() for x in self.learn.model.parameters() if x.grad is not None
|
||||||
] # noqa
|
] # noqa
|
||||||
|
@ -11,12 +11,14 @@ from ...model import Framework
|
|||||||
|
|
||||||
|
|
||||||
class PatchLIGHTgbmModelIO(PatchBaseModelIO):
|
class PatchLIGHTgbmModelIO(PatchBaseModelIO):
|
||||||
__main_task = None
|
_current_task = None
|
||||||
__patched = None
|
__patched = None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def update_current_task(task, **kwargs):
|
def update_current_task(task, **kwargs):
|
||||||
PatchLIGHTgbmModelIO.__main_task = task
|
PatchLIGHTgbmModelIO._current_task = task
|
||||||
|
if not task:
|
||||||
|
return
|
||||||
PatchLIGHTgbmModelIO._patch_model_io()
|
PatchLIGHTgbmModelIO._patch_model_io()
|
||||||
PostImportHookPatching.add_on_import('lightgbm', PatchLIGHTgbmModelIO._patch_model_io)
|
PostImportHookPatching.add_on_import('lightgbm', PatchLIGHTgbmModelIO._patch_model_io)
|
||||||
|
|
||||||
@ -43,7 +45,7 @@ class PatchLIGHTgbmModelIO(PatchBaseModelIO):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def _save(original_fn, obj, f, *args, **kwargs):
|
def _save(original_fn, obj, f, *args, **kwargs):
|
||||||
ret = original_fn(obj, f, *args, **kwargs)
|
ret = original_fn(obj, f, *args, **kwargs)
|
||||||
if not PatchLIGHTgbmModelIO.__main_task:
|
if not PatchLIGHTgbmModelIO._current_task:
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
if isinstance(f, six.string_types):
|
if isinstance(f, six.string_types):
|
||||||
@ -64,12 +66,15 @@ class PatchLIGHTgbmModelIO(PatchBaseModelIO):
|
|||||||
model_name = Path(filename).stem
|
model_name = Path(filename).stem
|
||||||
except Exception:
|
except Exception:
|
||||||
model_name = None
|
model_name = None
|
||||||
WeightsFileHandler.create_output_model(obj, filename, Framework.lightgbm, PatchLIGHTgbmModelIO.__main_task,
|
WeightsFileHandler.create_output_model(obj, filename, Framework.lightgbm, PatchLIGHTgbmModelIO._current_task,
|
||||||
singlefile=True, model_name=model_name)
|
singlefile=True, model_name=model_name)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _load(original_fn, model_file, *args, **kwargs):
|
def _load(original_fn, model_file, *args, **kwargs):
|
||||||
|
if not PatchLIGHTgbmModelIO._current_task:
|
||||||
|
return original_fn(model_file, *args, **kwargs)
|
||||||
|
|
||||||
if isinstance(model_file, six.string_types):
|
if isinstance(model_file, six.string_types):
|
||||||
filename = model_file
|
filename = model_file
|
||||||
elif hasattr(model_file, 'name'):
|
elif hasattr(model_file, 'name'):
|
||||||
@ -79,21 +84,18 @@ class PatchLIGHTgbmModelIO(PatchBaseModelIO):
|
|||||||
else:
|
else:
|
||||||
filename = None
|
filename = None
|
||||||
|
|
||||||
if not PatchLIGHTgbmModelIO.__main_task:
|
|
||||||
return original_fn(model_file, *args, **kwargs)
|
|
||||||
|
|
||||||
# register input model
|
# register input model
|
||||||
empty = _Empty()
|
empty = _Empty()
|
||||||
# Hack: disabled
|
# Hack: disabled
|
||||||
if False and running_remotely():
|
if False and running_remotely():
|
||||||
filename = WeightsFileHandler.restore_weights_file(empty, filename, Framework.xgboost,
|
filename = WeightsFileHandler.restore_weights_file(empty, filename, Framework.xgboost,
|
||||||
PatchLIGHTgbmModelIO.__main_task)
|
PatchLIGHTgbmModelIO._current_task)
|
||||||
model = original_fn(model_file=filename or model_file, *args, **kwargs)
|
model = original_fn(model_file=filename or model_file, *args, **kwargs)
|
||||||
else:
|
else:
|
||||||
# try to load model before registering, in case we fail
|
# try to load model before registering, in case we fail
|
||||||
model = original_fn(model_file=model_file, *args, **kwargs)
|
model = original_fn(model_file=model_file, *args, **kwargs)
|
||||||
WeightsFileHandler.restore_weights_file(empty, filename, Framework.lightgbm,
|
WeightsFileHandler.restore_weights_file(empty, filename, Framework.lightgbm,
|
||||||
PatchLIGHTgbmModelIO.__main_task)
|
PatchLIGHTgbmModelIO._current_task)
|
||||||
|
|
||||||
if empty.trains_in_model:
|
if empty.trains_in_model:
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
@ -110,7 +112,7 @@ class PatchLIGHTgbmModelIO(PatchBaseModelIO):
|
|||||||
# logging the results to scalars section
|
# logging the results to scalars section
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
logger = PatchLIGHTgbmModelIO.__main_task.get_logger()
|
logger = PatchLIGHTgbmModelIO._current_task.get_logger()
|
||||||
iteration = env.iteration
|
iteration = env.iteration
|
||||||
for data_title, data_series, value, _ in env.evaluation_result_list:
|
for data_title, data_series, value, _ in env.evaluation_result_list:
|
||||||
logger.report_scalar(title=data_title, series=data_series, value="{:.6f}".format(value),
|
logger.report_scalar(title=data_title, series=data_series, value="{:.6f}".format(value),
|
||||||
@ -121,12 +123,12 @@ class PatchLIGHTgbmModelIO(PatchBaseModelIO):
|
|||||||
|
|
||||||
kwargs.setdefault("callbacks", []).append(trains_lightgbm_callback())
|
kwargs.setdefault("callbacks", []).append(trains_lightgbm_callback())
|
||||||
ret = original_fn(*args, **kwargs)
|
ret = original_fn(*args, **kwargs)
|
||||||
if not PatchLIGHTgbmModelIO.__main_task:
|
if not PatchLIGHTgbmModelIO._current_task:
|
||||||
return ret
|
return ret
|
||||||
params = args[0] if args else kwargs.get('params', {})
|
params = args[0] if args else kwargs.get('params', {})
|
||||||
for k, v in params.items():
|
for k, v in params.items():
|
||||||
if isinstance(v, set):
|
if isinstance(v, set):
|
||||||
params[k] = list(v)
|
params[k] = list(v)
|
||||||
if params:
|
if params:
|
||||||
PatchLIGHTgbmModelIO.__main_task.connect(params)
|
PatchLIGHTgbmModelIO._current_task.connect(params)
|
||||||
return ret
|
return ret
|
||||||
|
@ -10,14 +10,14 @@ from ...model import Framework
|
|||||||
|
|
||||||
|
|
||||||
class PatchMegEngineModelIO(PatchBaseModelIO):
|
class PatchMegEngineModelIO(PatchBaseModelIO):
|
||||||
__main_task = None
|
_current_task = None
|
||||||
__patched = None
|
__patched = None
|
||||||
|
|
||||||
# __patched_lightning = None
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def update_current_task(task, **_):
|
def update_current_task(task, **_):
|
||||||
PatchMegEngineModelIO.__main_task = task
|
PatchMegEngineModelIO._current_task = task
|
||||||
|
if not task:
|
||||||
|
return
|
||||||
PatchMegEngineModelIO._patch_model_io()
|
PatchMegEngineModelIO._patch_model_io()
|
||||||
PostImportHookPatching.add_on_import(
|
PostImportHookPatching.add_on_import(
|
||||||
'megengine', PatchMegEngineModelIO._patch_model_io
|
'megengine', PatchMegEngineModelIO._patch_model_io
|
||||||
@ -58,7 +58,7 @@ class PatchMegEngineModelIO(PatchBaseModelIO):
|
|||||||
ret = original_fn(obj, f, *args, **kwargs)
|
ret = original_fn(obj, f, *args, **kwargs)
|
||||||
|
|
||||||
# if there is no main task or this is a nested call
|
# if there is no main task or this is a nested call
|
||||||
if not PatchMegEngineModelIO.__main_task:
|
if not PatchMegEngineModelIO._current_task:
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
@ -93,7 +93,7 @@ class PatchMegEngineModelIO(PatchBaseModelIO):
|
|||||||
|
|
||||||
WeightsFileHandler.create_output_model(
|
WeightsFileHandler.create_output_model(
|
||||||
obj, filename, Framework.megengine,
|
obj, filename, Framework.megengine,
|
||||||
PatchMegEngineModelIO.__main_task,
|
PatchMegEngineModelIO._current_task,
|
||||||
singlefile=True, model_name=model_name,
|
singlefile=True, model_name=model_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -102,7 +102,7 @@ class PatchMegEngineModelIO(PatchBaseModelIO):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def _load(original_fn, f, *args, **kwargs):
|
def _load(original_fn, f, *args, **kwargs):
|
||||||
# if there is no main task or this is a nested call
|
# if there is no main task or this is a nested call
|
||||||
if not PatchMegEngineModelIO.__main_task:
|
if not PatchMegEngineModelIO._current_task:
|
||||||
return original_fn(f, *args, **kwargs)
|
return original_fn(f, *args, **kwargs)
|
||||||
|
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
@ -124,7 +124,7 @@ class PatchMegEngineModelIO(PatchBaseModelIO):
|
|||||||
model = original_fn(f, *args, **kwargs)
|
model = original_fn(f, *args, **kwargs)
|
||||||
WeightsFileHandler.restore_weights_file(
|
WeightsFileHandler.restore_weights_file(
|
||||||
empty, filename, Framework.megengine,
|
empty, filename, Framework.megengine,
|
||||||
PatchMegEngineModelIO.__main_task
|
PatchMegEngineModelIO._current_task
|
||||||
)
|
)
|
||||||
|
|
||||||
if empty.trains_in_model:
|
if empty.trains_in_model:
|
||||||
@ -139,7 +139,7 @@ class PatchMegEngineModelIO(PatchBaseModelIO):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def _load_from_obj(original_fn, obj, f, *args, **kwargs):
|
def _load_from_obj(original_fn, obj, f, *args, **kwargs):
|
||||||
# if there is no main task or this is a nested call
|
# if there is no main task or this is a nested call
|
||||||
if not PatchMegEngineModelIO.__main_task:
|
if not PatchMegEngineModelIO._current_task:
|
||||||
return original_fn(obj, f, *args, **kwargs)
|
return original_fn(obj, f, *args, **kwargs)
|
||||||
|
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
@ -161,7 +161,7 @@ class PatchMegEngineModelIO(PatchBaseModelIO):
|
|||||||
model = original_fn(obj, f, *args, **kwargs)
|
model = original_fn(obj, f, *args, **kwargs)
|
||||||
WeightsFileHandler.restore_weights_file(
|
WeightsFileHandler.restore_weights_file(
|
||||||
empty, filename, Framework.megengine,
|
empty, filename, Framework.megengine,
|
||||||
PatchMegEngineModelIO.__main_task,
|
PatchMegEngineModelIO._current_task,
|
||||||
)
|
)
|
||||||
|
|
||||||
if empty.trains_in_model:
|
if empty.trains_in_model:
|
||||||
|
@ -11,13 +11,15 @@ from ...model import Framework
|
|||||||
|
|
||||||
|
|
||||||
class PatchPyTorchModelIO(PatchBaseModelIO):
|
class PatchPyTorchModelIO(PatchBaseModelIO):
|
||||||
__main_task = None
|
_current_task = None
|
||||||
__patched = None
|
__patched = None
|
||||||
__patched_lightning = None
|
__patched_lightning = None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def update_current_task(task, **_):
|
def update_current_task(task, **_):
|
||||||
PatchPyTorchModelIO.__main_task = task
|
PatchPyTorchModelIO._current_task = task
|
||||||
|
if not task:
|
||||||
|
return
|
||||||
PatchPyTorchModelIO._patch_model_io()
|
PatchPyTorchModelIO._patch_model_io()
|
||||||
PatchPyTorchModelIO._patch_lightning_io()
|
PatchPyTorchModelIO._patch_lightning_io()
|
||||||
PostImportHookPatching.add_on_import('torch', PatchPyTorchModelIO._patch_model_io)
|
PostImportHookPatching.add_on_import('torch', PatchPyTorchModelIO._patch_model_io)
|
||||||
@ -112,7 +114,7 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
|
|||||||
ret = original_fn(obj, f, *args, **kwargs)
|
ret = original_fn(obj, f, *args, **kwargs)
|
||||||
|
|
||||||
# if there is no main task or this is a nested call
|
# if there is no main task or this is a nested call
|
||||||
if not PatchPyTorchModelIO.__main_task:
|
if not PatchPyTorchModelIO._current_task:
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
# pytorch-lightning check if rank is zero
|
# pytorch-lightning check if rank is zero
|
||||||
@ -154,14 +156,14 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
|
|||||||
model_name = None
|
model_name = None
|
||||||
|
|
||||||
WeightsFileHandler.create_output_model(
|
WeightsFileHandler.create_output_model(
|
||||||
obj, filename, Framework.pytorch, PatchPyTorchModelIO.__main_task, singlefile=True, model_name=model_name)
|
obj, filename, Framework.pytorch, PatchPyTorchModelIO._current_task, singlefile=True, model_name=model_name)
|
||||||
|
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _load(original_fn, f, *args, **kwargs):
|
def _load(original_fn, f, *args, **kwargs):
|
||||||
# if there is no main task or this is a nested call
|
# if there is no main task or this is a nested call
|
||||||
if not PatchPyTorchModelIO.__main_task:
|
if not PatchPyTorchModelIO._current_task:
|
||||||
return original_fn(f, *args, **kwargs)
|
return original_fn(f, *args, **kwargs)
|
||||||
|
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
@ -182,13 +184,13 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
|
|||||||
# Hack: disabled
|
# Hack: disabled
|
||||||
if False and running_remotely():
|
if False and running_remotely():
|
||||||
filename = WeightsFileHandler.restore_weights_file(
|
filename = WeightsFileHandler.restore_weights_file(
|
||||||
empty, filename, Framework.pytorch, PatchPyTorchModelIO.__main_task)
|
empty, filename, Framework.pytorch, PatchPyTorchModelIO._current_task)
|
||||||
model = original_fn(filename or f, *args, **kwargs)
|
model = original_fn(filename or f, *args, **kwargs)
|
||||||
else:
|
else:
|
||||||
# try to load model before registering, in case we fail
|
# try to load model before registering, in case we fail
|
||||||
model = original_fn(f, *args, **kwargs)
|
model = original_fn(f, *args, **kwargs)
|
||||||
WeightsFileHandler.restore_weights_file(
|
WeightsFileHandler.restore_weights_file(
|
||||||
empty, filename, Framework.pytorch, PatchPyTorchModelIO.__main_task)
|
empty, filename, Framework.pytorch, PatchPyTorchModelIO._current_task)
|
||||||
|
|
||||||
if empty.trains_in_model:
|
if empty.trains_in_model:
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
@ -202,7 +204,7 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def _load_from_obj(original_fn, obj, f, *args, **kwargs):
|
def _load_from_obj(original_fn, obj, f, *args, **kwargs):
|
||||||
# if there is no main task or this is a nested call
|
# if there is no main task or this is a nested call
|
||||||
if not PatchPyTorchModelIO.__main_task:
|
if not PatchPyTorchModelIO._current_task:
|
||||||
return original_fn(obj, f, *args, **kwargs)
|
return original_fn(obj, f, *args, **kwargs)
|
||||||
|
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
@ -223,13 +225,13 @@ class PatchPyTorchModelIO(PatchBaseModelIO):
|
|||||||
# Hack: disabled
|
# Hack: disabled
|
||||||
if False and running_remotely():
|
if False and running_remotely():
|
||||||
filename = WeightsFileHandler.restore_weights_file(
|
filename = WeightsFileHandler.restore_weights_file(
|
||||||
empty, filename, Framework.pytorch, PatchPyTorchModelIO.__main_task)
|
empty, filename, Framework.pytorch, PatchPyTorchModelIO._current_task)
|
||||||
model = original_fn(obj, filename or f, *args, **kwargs)
|
model = original_fn(obj, filename or f, *args, **kwargs)
|
||||||
else:
|
else:
|
||||||
# try to load model before registering, in case we fail
|
# try to load model before registering, in case we fail
|
||||||
model = original_fn(obj, f, *args, **kwargs)
|
model = original_fn(obj, f, *args, **kwargs)
|
||||||
WeightsFileHandler.restore_weights_file(
|
WeightsFileHandler.restore_weights_file(
|
||||||
empty, filename, Framework.pytorch, PatchPyTorchModelIO.__main_task)
|
empty, filename, Framework.pytorch, PatchPyTorchModelIO._current_task)
|
||||||
|
|
||||||
if empty.trains_in_model:
|
if empty.trains_in_model:
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
|
@ -232,7 +232,7 @@ class EventTrainsWriter(object):
|
|||||||
TF SummaryWriter implementation that converts the tensorboard's summary into
|
TF SummaryWriter implementation that converts the tensorboard's summary into
|
||||||
ClearML events and reports the events (metrics) for an ClearML task (logger).
|
ClearML events and reports the events (metrics) for an ClearML task (logger).
|
||||||
"""
|
"""
|
||||||
__main_task = None
|
_current_task = None
|
||||||
__report_hparams = True
|
__report_hparams = True
|
||||||
_add_lock = threading.RLock()
|
_add_lock = threading.RLock()
|
||||||
_series_name_lookup = {}
|
_series_name_lookup = {}
|
||||||
@ -641,7 +641,7 @@ class EventTrainsWriter(object):
|
|||||||
session_start_info = parse_session_start_info_plugin_data(content)
|
session_start_info = parse_session_start_info_plugin_data(content)
|
||||||
session_start_info = MessageToDict(session_start_info)
|
session_start_info = MessageToDict(session_start_info)
|
||||||
hparams = session_start_info["hparams"]
|
hparams = session_start_info["hparams"]
|
||||||
EventTrainsWriter.__main_task.update_parameters(
|
EventTrainsWriter._current_task.update_parameters(
|
||||||
{"TB_hparams/{}".format(k): v for k, v in hparams.items()}
|
{"TB_hparams/{}".format(k): v for k, v in hparams.items()}
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -844,13 +844,13 @@ class EventTrainsWriter(object):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def update_current_task(cls, task, **kwargs):
|
def update_current_task(cls, task, **kwargs):
|
||||||
cls.__report_hparams = kwargs.get('report_hparams', False)
|
cls.__report_hparams = kwargs.get('report_hparams', False)
|
||||||
if cls.__main_task != task:
|
if cls._current_task != task:
|
||||||
with cls._add_lock:
|
with cls._add_lock:
|
||||||
cls._series_name_lookup = {}
|
cls._series_name_lookup = {}
|
||||||
cls._title_series_writers_lookup = {}
|
cls._title_series_writers_lookup = {}
|
||||||
cls._event_writers_id_to_logdir = {}
|
cls._event_writers_id_to_logdir = {}
|
||||||
cls._title_series_wraparound_counter = {}
|
cls._title_series_wraparound_counter = {}
|
||||||
cls.__main_task = task
|
cls._current_task = task
|
||||||
|
|
||||||
|
|
||||||
# noinspection PyCallingNonCallable
|
# noinspection PyCallingNonCallable
|
||||||
@ -905,9 +905,10 @@ class ProxyEventsWriter(object):
|
|||||||
|
|
||||||
# noinspection PyPep8Naming
|
# noinspection PyPep8Naming
|
||||||
class PatchSummaryToEventTransformer(object):
|
class PatchSummaryToEventTransformer(object):
|
||||||
__main_task = None
|
_current_task = None
|
||||||
__original_getattribute = None
|
__original_getattribute = None
|
||||||
__original_getattributeX = None
|
__original_getattributeX = None
|
||||||
|
__patched = False
|
||||||
_original_add_event = None
|
_original_add_event = None
|
||||||
_original_add_eventT = None
|
_original_add_eventT = None
|
||||||
_original_add_eventX = None
|
_original_add_eventX = None
|
||||||
@ -930,15 +931,19 @@ class PatchSummaryToEventTransformer(object):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def update_current_task(task, **kwargs):
|
def update_current_task(task, **kwargs):
|
||||||
PatchSummaryToEventTransformer.defaults_dict.update(kwargs)
|
PatchSummaryToEventTransformer.defaults_dict.update(kwargs)
|
||||||
PatchSummaryToEventTransformer.__main_task = task
|
PatchSummaryToEventTransformer._current_task = task
|
||||||
# make sure we patched the SummaryToEventTransformer
|
if not task:
|
||||||
PatchSummaryToEventTransformer._patch_summary_to_event_transformer()
|
return
|
||||||
PostImportHookPatching.add_on_import('tensorflow',
|
if not PatchSummaryToEventTransformer.__patched:
|
||||||
PatchSummaryToEventTransformer._patch_summary_to_event_transformer)
|
PatchSummaryToEventTransformer.__patched = True
|
||||||
PostImportHookPatching.add_on_import('torch',
|
# make sure we patched the SummaryToEventTransformer
|
||||||
PatchSummaryToEventTransformer._patch_summary_to_event_transformer)
|
PatchSummaryToEventTransformer._patch_summary_to_event_transformer()
|
||||||
PostImportHookPatching.add_on_import('tensorboardX',
|
PostImportHookPatching.add_on_import('tensorflow',
|
||||||
PatchSummaryToEventTransformer._patch_summary_to_event_transformer)
|
PatchSummaryToEventTransformer._patch_summary_to_event_transformer)
|
||||||
|
PostImportHookPatching.add_on_import('torch',
|
||||||
|
PatchSummaryToEventTransformer._patch_summary_to_event_transformer)
|
||||||
|
PostImportHookPatching.add_on_import('tensorboardX',
|
||||||
|
PatchSummaryToEventTransformer._patch_summary_to_event_transformer)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _patch_summary_to_event_transformer():
|
def _patch_summary_to_event_transformer():
|
||||||
@ -1002,7 +1007,7 @@ class PatchSummaryToEventTransformer(object):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _patched_add_eventT(self, *args, **kwargs):
|
def _patched_add_eventT(self, *args, **kwargs):
|
||||||
if not hasattr(self, 'clearml') or not PatchSummaryToEventTransformer.__main_task:
|
if not hasattr(self, 'clearml') or not PatchSummaryToEventTransformer._current_task:
|
||||||
return PatchSummaryToEventTransformer._original_add_eventT(self, *args, **kwargs)
|
return PatchSummaryToEventTransformer._original_add_eventT(self, *args, **kwargs)
|
||||||
if not self.clearml: # noqa
|
if not self.clearml: # noqa
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
@ -1010,7 +1015,7 @@ class PatchSummaryToEventTransformer(object):
|
|||||||
logdir = self.get_logdir()
|
logdir = self.get_logdir()
|
||||||
except Exception:
|
except Exception:
|
||||||
logdir = None
|
logdir = None
|
||||||
self.clearml = EventTrainsWriter(PatchSummaryToEventTransformer.__main_task.get_logger(),
|
self.clearml = EventTrainsWriter(PatchSummaryToEventTransformer._current_task.get_logger(),
|
||||||
logdir=logdir, **PatchSummaryToEventTransformer.defaults_dict)
|
logdir=logdir, **PatchSummaryToEventTransformer.defaults_dict)
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
@ -1021,7 +1026,7 @@ class PatchSummaryToEventTransformer(object):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _patched_add_eventX(self, *args, **kwargs):
|
def _patched_add_eventX(self, *args, **kwargs):
|
||||||
if not hasattr(self, 'clearml') or not PatchSummaryToEventTransformer.__main_task:
|
if not hasattr(self, 'clearml') or not PatchSummaryToEventTransformer._current_task:
|
||||||
return PatchSummaryToEventTransformer._original_add_eventX(self, *args, **kwargs)
|
return PatchSummaryToEventTransformer._original_add_eventX(self, *args, **kwargs)
|
||||||
if not self.clearml:
|
if not self.clearml:
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
@ -1029,7 +1034,7 @@ class PatchSummaryToEventTransformer(object):
|
|||||||
logdir = self.get_logdir()
|
logdir = self.get_logdir()
|
||||||
except Exception:
|
except Exception:
|
||||||
logdir = None
|
logdir = None
|
||||||
self.clearml = EventTrainsWriter(PatchSummaryToEventTransformer.__main_task.get_logger(),
|
self.clearml = EventTrainsWriter(PatchSummaryToEventTransformer._current_task.get_logger(),
|
||||||
logdir=logdir, **PatchSummaryToEventTransformer.defaults_dict)
|
logdir=logdir, **PatchSummaryToEventTransformer.defaults_dict)
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
@ -1051,7 +1056,7 @@ class PatchSummaryToEventTransformer(object):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def _patched_getattribute_(self, attr, get_base):
|
def _patched_getattribute_(self, attr, get_base):
|
||||||
# no main task, zero chance we have an ClearML event logger
|
# no main task, zero chance we have an ClearML event logger
|
||||||
if PatchSummaryToEventTransformer.__main_task is None:
|
if PatchSummaryToEventTransformer._current_task is None:
|
||||||
return get_base(self, attr)
|
return get_base(self, attr)
|
||||||
|
|
||||||
# check if we already have an ClearML event logger
|
# check if we already have an ClearML event logger
|
||||||
@ -1068,7 +1073,7 @@ class PatchSummaryToEventTransformer(object):
|
|||||||
except Exception:
|
except Exception:
|
||||||
logdir = None
|
logdir = None
|
||||||
defaults_dict = __dict__.get('_trains_defaults') or PatchSummaryToEventTransformer.defaults_dict
|
defaults_dict = __dict__.get('_trains_defaults') or PatchSummaryToEventTransformer.defaults_dict
|
||||||
trains_event = EventTrainsWriter(PatchSummaryToEventTransformer.__main_task.get_logger(),
|
trains_event = EventTrainsWriter(PatchSummaryToEventTransformer._current_task.get_logger(),
|
||||||
logdir=logdir, **defaults_dict)
|
logdir=logdir, **defaults_dict)
|
||||||
|
|
||||||
# order is important, the return value of ProxyEventsWriter is the last object in the list
|
# order is important, the return value of ProxyEventsWriter is the last object in the list
|
||||||
@ -1111,8 +1116,9 @@ class _ModelAdapter(object):
|
|||||||
|
|
||||||
|
|
||||||
class PatchModelCheckPointCallback(object):
|
class PatchModelCheckPointCallback(object):
|
||||||
__main_task = None
|
_current_task = None
|
||||||
__original_getattribute = None
|
__original_getattribute = None
|
||||||
|
__patched = False
|
||||||
defaults_dict = dict(
|
defaults_dict = dict(
|
||||||
config_text=None,
|
config_text=None,
|
||||||
config_dict=None,
|
config_dict=None,
|
||||||
@ -1132,11 +1138,15 @@ class PatchModelCheckPointCallback(object):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def update_current_task(task, **kwargs):
|
def update_current_task(task, **kwargs):
|
||||||
PatchModelCheckPointCallback.defaults_dict.update(kwargs)
|
PatchModelCheckPointCallback.defaults_dict.update(kwargs)
|
||||||
PatchModelCheckPointCallback.__main_task = task
|
PatchModelCheckPointCallback._current_task = task
|
||||||
# make sure we patched the SummaryToEventTransformer
|
if not task:
|
||||||
PatchModelCheckPointCallback._patch_model_checkpoint()
|
return
|
||||||
PostImportHookPatching.add_on_import('keras', PatchModelCheckPointCallback._patch_model_checkpoint)
|
if not PatchModelCheckPointCallback.__patched:
|
||||||
PostImportHookPatching.add_on_import('tensorflow', PatchModelCheckPointCallback._patch_model_checkpoint)
|
PatchModelCheckPointCallback.__patched = True
|
||||||
|
# make sure we patched the SummaryToEventTransformer
|
||||||
|
PatchModelCheckPointCallback._patch_model_checkpoint()
|
||||||
|
PostImportHookPatching.add_on_import('keras', PatchModelCheckPointCallback._patch_model_checkpoint)
|
||||||
|
PostImportHookPatching.add_on_import('tensorflow', PatchModelCheckPointCallback._patch_model_checkpoint)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _patch_model_checkpoint():
|
def _patch_model_checkpoint():
|
||||||
@ -1176,7 +1186,7 @@ class PatchModelCheckPointCallback(object):
|
|||||||
get_base = PatchModelCheckPointCallback.__original_getattribute
|
get_base = PatchModelCheckPointCallback.__original_getattribute
|
||||||
|
|
||||||
# no main task, zero chance we have an ClearML event logger
|
# no main task, zero chance we have an ClearML event logger
|
||||||
if PatchModelCheckPointCallback.__main_task is None:
|
if PatchModelCheckPointCallback._current_task is None:
|
||||||
return get_base(self, attr)
|
return get_base(self, attr)
|
||||||
|
|
||||||
# check if we already have an ClearML event logger
|
# check if we already have an ClearML event logger
|
||||||
@ -1189,17 +1199,17 @@ class PatchModelCheckPointCallback(object):
|
|||||||
base_model = __dict__['model']
|
base_model = __dict__['model']
|
||||||
defaults_dict = __dict__.get('_trains_defaults') or PatchModelCheckPointCallback.defaults_dict
|
defaults_dict = __dict__.get('_trains_defaults') or PatchModelCheckPointCallback.defaults_dict
|
||||||
output_model = OutputModel(
|
output_model = OutputModel(
|
||||||
PatchModelCheckPointCallback.__main_task,
|
PatchModelCheckPointCallback._current_task,
|
||||||
config_text=defaults_dict.get('config_text'),
|
config_text=defaults_dict.get('config_text'),
|
||||||
config_dict=defaults_dict.get('config_dict'),
|
config_dict=defaults_dict.get('config_dict'),
|
||||||
name=defaults_dict.get('name'),
|
name=defaults_dict.get('name'),
|
||||||
comment=defaults_dict.get('comment'),
|
comment=defaults_dict.get('comment'),
|
||||||
label_enumeration=defaults_dict.get('label_enumeration') or
|
label_enumeration=defaults_dict.get('label_enumeration') or
|
||||||
PatchModelCheckPointCallback.__main_task.get_labels_enumeration(),
|
PatchModelCheckPointCallback._current_task.get_labels_enumeration(),
|
||||||
framework=Framework.keras,
|
framework=Framework.keras,
|
||||||
)
|
)
|
||||||
output_model.set_upload_destination(
|
output_model.set_upload_destination(
|
||||||
PatchModelCheckPointCallback.__main_task.get_output_destination(raise_on_error=False))
|
PatchModelCheckPointCallback._current_task.get_output_destination(raise_on_error=False))
|
||||||
trains_model = _ModelAdapter(base_model, output_model)
|
trains_model = _ModelAdapter(base_model, output_model)
|
||||||
|
|
||||||
# order is important, the return value of ProxyEventsWriter is the last object in the list
|
# order is important, the return value of ProxyEventsWriter is the last object in the list
|
||||||
@ -1209,26 +1219,32 @@ class PatchModelCheckPointCallback(object):
|
|||||||
|
|
||||||
# noinspection PyProtectedMember,PyUnresolvedReferences
|
# noinspection PyProtectedMember,PyUnresolvedReferences
|
||||||
class PatchTensorFlowEager(object):
|
class PatchTensorFlowEager(object):
|
||||||
__main_task = None
|
_current_task = None
|
||||||
__original_fn_scalar = None
|
__original_fn_scalar = None
|
||||||
__original_fn_hist = None
|
__original_fn_hist = None
|
||||||
__original_fn_image = None
|
__original_fn_image = None
|
||||||
|
__original_fn_write_summary = None
|
||||||
__trains_event_writer = {}
|
__trains_event_writer = {}
|
||||||
__tf_tb_writer_id_to_logdir = {}
|
__tf_tb_writer_id_to_logdir = {}
|
||||||
|
__patched = False
|
||||||
defaults_dict = dict(
|
defaults_dict = dict(
|
||||||
report_freq=1, image_report_freq=1, histogram_update_freq_multiplier=5,
|
report_freq=1, image_report_freq=1, histogram_update_freq_multiplier=5,
|
||||||
histogram_granularity=50)
|
histogram_granularity=50)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def update_current_task(task, **kwargs):
|
def update_current_task(task, **kwargs):
|
||||||
if task != PatchTensorFlowEager.__main_task:
|
if task != PatchTensorFlowEager._current_task:
|
||||||
PatchTensorFlowEager.__trains_event_writer = {}
|
PatchTensorFlowEager.__trains_event_writer = {}
|
||||||
|
|
||||||
PatchTensorFlowEager.defaults_dict.update(kwargs)
|
PatchTensorFlowEager.defaults_dict.update(kwargs)
|
||||||
PatchTensorFlowEager.__main_task = task
|
PatchTensorFlowEager._current_task = task
|
||||||
# make sure we patched the SummaryToEventTransformer
|
if not task:
|
||||||
PatchTensorFlowEager._patch_summary_ops()
|
return
|
||||||
PostImportHookPatching.add_on_import('tensorflow', PatchTensorFlowEager._patch_summary_ops)
|
if not PatchTensorFlowEager.__patched:
|
||||||
|
PatchTensorFlowEager.__patched = True
|
||||||
|
# make sure we patched the SummaryToEventTransformer
|
||||||
|
PatchTensorFlowEager._patch_summary_ops()
|
||||||
|
PostImportHookPatching.add_on_import('tensorflow', PatchTensorFlowEager._patch_summary_ops)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _patch_summary_ops():
|
def _patch_summary_ops():
|
||||||
@ -1245,7 +1261,7 @@ class PatchTensorFlowEager(object):
|
|||||||
gen_summary_ops.write_image_summary = PatchTensorFlowEager._write_image_summary
|
gen_summary_ops.write_image_summary = PatchTensorFlowEager._write_image_summary
|
||||||
PatchTensorFlowEager.__original_fn_hist = gen_summary_ops.write_histogram_summary
|
PatchTensorFlowEager.__original_fn_hist = gen_summary_ops.write_histogram_summary
|
||||||
gen_summary_ops.write_histogram_summary = PatchTensorFlowEager._write_hist_summary
|
gen_summary_ops.write_histogram_summary = PatchTensorFlowEager._write_hist_summary
|
||||||
PatchTensorFlowEager.__write_summary = gen_summary_ops.write_summary
|
PatchTensorFlowEager.__original_fn_write_summary = gen_summary_ops.write_summary
|
||||||
gen_summary_ops.write_summary = PatchTensorFlowEager._write_summary
|
gen_summary_ops.write_summary = PatchTensorFlowEager._write_summary
|
||||||
gen_summary_ops.create_summary_file_writer = partial(IsTensorboardInit._patched_tb__init__,
|
gen_summary_ops.create_summary_file_writer = partial(IsTensorboardInit._patched_tb__init__,
|
||||||
gen_summary_ops.create_summary_file_writer)
|
gen_summary_ops.create_summary_file_writer)
|
||||||
@ -1270,6 +1286,8 @@ class PatchTensorFlowEager(object):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _create_summary_file_writer(original_fn, *args, **kwargs):
|
def _create_summary_file_writer(original_fn, *args, **kwargs):
|
||||||
|
if not PatchTensorFlowEager._current_task:
|
||||||
|
return original_fn(*args, **kwargs)
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
a_logdir = None
|
a_logdir = None
|
||||||
@ -1294,7 +1312,7 @@ class PatchTensorFlowEager(object):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_event_writer(writer):
|
def _get_event_writer(writer):
|
||||||
if not PatchTensorFlowEager.__main_task:
|
if not PatchTensorFlowEager._current_task:
|
||||||
return None
|
return None
|
||||||
if not PatchTensorFlowEager.__trains_event_writer.get(id(writer)):
|
if not PatchTensorFlowEager.__trains_event_writer.get(id(writer)):
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
@ -1324,7 +1342,7 @@ class PatchTensorFlowEager(object):
|
|||||||
logdir = None
|
logdir = None
|
||||||
|
|
||||||
PatchTensorFlowEager.__trains_event_writer[id(writer)] = EventTrainsWriter(
|
PatchTensorFlowEager.__trains_event_writer[id(writer)] = EventTrainsWriter(
|
||||||
logger=PatchTensorFlowEager.__main_task.get_logger(), logdir=logdir,
|
logger=PatchTensorFlowEager._current_task.get_logger(), logdir=logdir,
|
||||||
**PatchTensorFlowEager.defaults_dict)
|
**PatchTensorFlowEager.defaults_dict)
|
||||||
return PatchTensorFlowEager.__trains_event_writer[id(writer)]
|
return PatchTensorFlowEager.__trains_event_writer[id(writer)]
|
||||||
|
|
||||||
@ -1337,6 +1355,10 @@ class PatchTensorFlowEager(object):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _write_summary(writer, step, tensor, tag, summary_metadata, name=None, **kwargs):
|
def _write_summary(writer, step, tensor, tag, summary_metadata, name=None, **kwargs):
|
||||||
|
if not PatchTensorFlowEager._current_task:
|
||||||
|
return PatchTensorFlowEager.__original_fn_write_summary(
|
||||||
|
writer, step, tensor, tag, summary_metadata, name, **kwargs
|
||||||
|
)
|
||||||
event_writer = PatchTensorFlowEager._get_event_writer(writer)
|
event_writer = PatchTensorFlowEager._get_event_writer(writer)
|
||||||
# make sure we can get the tensors values
|
# make sure we can get the tensors values
|
||||||
if event_writer and isinstance(step, int) or hasattr(step, 'numpy'):
|
if event_writer and isinstance(step, int) or hasattr(step, 'numpy'):
|
||||||
@ -1370,13 +1392,17 @@ class PatchTensorFlowEager(object):
|
|||||||
step=int(step.numpy()) if not isinstance(step, int) else step,
|
step=int(step.numpy()) if not isinstance(step, int) else step,
|
||||||
values=None, audio_data=audio_bytes)
|
values=None, audio_data=audio_bytes)
|
||||||
else:
|
else:
|
||||||
pass # print('unsupported plugin_type', plugin_type)
|
pass
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
return PatchTensorFlowEager.__write_summary(writer, step, tensor, tag, summary_metadata, name, **kwargs)
|
return PatchTensorFlowEager.__original_fn_write_summary(
|
||||||
|
writer, step, tensor, tag, summary_metadata, name, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _write_scalar_summary(writer, step, tag, value, name=None, **kwargs):
|
def _write_scalar_summary(writer, step, tag, value, name=None, **kwargs):
|
||||||
|
if not PatchTensorFlowEager._current_task:
|
||||||
|
return PatchTensorFlowEager.__original_fn_scalar(writer, step, tag, value, name, **kwargs)
|
||||||
event_writer = PatchTensorFlowEager._get_event_writer(writer)
|
event_writer = PatchTensorFlowEager._get_event_writer(writer)
|
||||||
if event_writer and isinstance(step, int) or hasattr(step, 'numpy'):
|
if event_writer and isinstance(step, int) or hasattr(step, 'numpy'):
|
||||||
try:
|
try:
|
||||||
@ -1417,6 +1443,8 @@ class PatchTensorFlowEager(object):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _write_hist_summary(writer, step, tag, values, name, **kwargs):
|
def _write_hist_summary(writer, step, tag, values, name, **kwargs):
|
||||||
|
if not PatchTensorFlowEager._current_task:
|
||||||
|
return PatchTensorFlowEager.__original_fn_hist(writer, step, tag, values, name, **kwargs)
|
||||||
event_writer = PatchTensorFlowEager._get_event_writer(writer)
|
event_writer = PatchTensorFlowEager._get_event_writer(writer)
|
||||||
if event_writer and isinstance(step, int) or hasattr(step, 'numpy'):
|
if event_writer and isinstance(step, int) or hasattr(step, 'numpy'):
|
||||||
try:
|
try:
|
||||||
@ -1459,6 +1487,9 @@ class PatchTensorFlowEager(object):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _write_image_summary(writer, step, tag, tensor, bad_color, max_images, name, **kwargs):
|
def _write_image_summary(writer, step, tag, tensor, bad_color, max_images, name, **kwargs):
|
||||||
|
if not PatchTensorFlowEager._current_task:
|
||||||
|
return PatchTensorFlowEager.__original_fn_image(
|
||||||
|
writer, step, tag, tensor, bad_color, max_images, name, **kwargs)
|
||||||
event_writer = PatchTensorFlowEager._get_event_writer(writer)
|
event_writer = PatchTensorFlowEager._get_event_writer(writer)
|
||||||
if event_writer and isinstance(step, int) or hasattr(step, 'numpy'):
|
if event_writer and isinstance(step, int) or hasattr(step, 'numpy'):
|
||||||
try:
|
try:
|
||||||
@ -1526,13 +1557,15 @@ class PatchTensorFlowEager(object):
|
|||||||
|
|
||||||
# noinspection PyPep8Naming,SpellCheckingInspection
|
# noinspection PyPep8Naming,SpellCheckingInspection
|
||||||
class PatchKerasModelIO(object):
|
class PatchKerasModelIO(object):
|
||||||
__main_task = None
|
_current_task = None
|
||||||
__patched_keras = None
|
__patched_keras = None
|
||||||
__patched_tensorflow = None
|
__patched_tensorflow = None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def update_current_task(task, **_):
|
def update_current_task(task, **_):
|
||||||
PatchKerasModelIO.__main_task = task
|
PatchKerasModelIO._current_task = task
|
||||||
|
if not task:
|
||||||
|
return
|
||||||
PatchKerasModelIO._patch_model_checkpoint()
|
PatchKerasModelIO._patch_model_checkpoint()
|
||||||
PostImportHookPatching.add_on_import('tensorflow', PatchKerasModelIO._patch_model_checkpoint)
|
PostImportHookPatching.add_on_import('tensorflow', PatchKerasModelIO._patch_model_checkpoint)
|
||||||
PostImportHookPatching.add_on_import('keras', PatchKerasModelIO._patch_model_checkpoint)
|
PostImportHookPatching.add_on_import('keras', PatchKerasModelIO._patch_model_checkpoint)
|
||||||
@ -1692,7 +1725,7 @@ class PatchKerasModelIO(object):
|
|||||||
def _updated_config(original_fn, self):
|
def _updated_config(original_fn, self):
|
||||||
config = original_fn(self)
|
config = original_fn(self)
|
||||||
# check if we have main task
|
# check if we have main task
|
||||||
if PatchKerasModelIO.__main_task is None:
|
if PatchKerasModelIO._current_task is None:
|
||||||
return config
|
return config
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -1717,10 +1750,10 @@ class PatchKerasModelIO(object):
|
|||||||
else:
|
else:
|
||||||
# todo: support multiple models for the same task
|
# todo: support multiple models for the same task
|
||||||
self.trains_out_model.append(OutputModel(
|
self.trains_out_model.append(OutputModel(
|
||||||
task=PatchKerasModelIO.__main_task,
|
task=PatchKerasModelIO._current_task,
|
||||||
config_dict=safe_config,
|
config_dict=safe_config,
|
||||||
name=PatchKerasModelIO.__main_task.name + ' ' + model_name_id,
|
name=PatchKerasModelIO._current_task.name + ' ' + model_name_id,
|
||||||
label_enumeration=PatchKerasModelIO.__main_task.get_labels_enumeration(),
|
label_enumeration=PatchKerasModelIO._current_task.get_labels_enumeration(),
|
||||||
framework=Framework.keras,
|
framework=Framework.keras,
|
||||||
))
|
))
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
@ -1738,7 +1771,7 @@ class PatchKerasModelIO(object):
|
|||||||
self = _Empty()
|
self = _Empty()
|
||||||
|
|
||||||
# check if we have main task
|
# check if we have main task
|
||||||
if PatchKerasModelIO.__main_task is None:
|
if PatchKerasModelIO._current_task is None:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -1751,10 +1784,10 @@ class PatchKerasModelIO(object):
|
|||||||
# check if object already has InputModel
|
# check if object already has InputModel
|
||||||
self.trains_in_model = InputModel.empty(
|
self.trains_in_model = InputModel.empty(
|
||||||
config_dict=config_dict,
|
config_dict=config_dict,
|
||||||
label_enumeration=PatchKerasModelIO.__main_task.get_labels_enumeration(),
|
label_enumeration=PatchKerasModelIO._current_task.get_labels_enumeration(),
|
||||||
)
|
)
|
||||||
# todo: support multiple models for the same task
|
# todo: support multiple models for the same task
|
||||||
PatchKerasModelIO.__main_task.connect(self.trains_in_model)
|
PatchKerasModelIO._current_task.connect(self.trains_in_model)
|
||||||
# if we are running remotely we should deserialize the object
|
# if we are running remotely we should deserialize the object
|
||||||
# because someone might have changed the configuration
|
# because someone might have changed the configuration
|
||||||
# Hack: disabled
|
# Hack: disabled
|
||||||
@ -1781,7 +1814,7 @@ class PatchKerasModelIO(object):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def _load_weights(original_fn, self, *args, **kwargs):
|
def _load_weights(original_fn, self, *args, **kwargs):
|
||||||
# check if we have main task
|
# check if we have main task
|
||||||
if PatchKerasModelIO.__main_task is None:
|
if PatchKerasModelIO._current_task is None:
|
||||||
return original_fn(self, *args, **kwargs)
|
return original_fn(self, *args, **kwargs)
|
||||||
|
|
||||||
# get filepath
|
# get filepath
|
||||||
@ -1794,7 +1827,7 @@ class PatchKerasModelIO(object):
|
|||||||
if False and running_remotely():
|
if False and running_remotely():
|
||||||
# register/load model weights
|
# register/load model weights
|
||||||
filepath = WeightsFileHandler.restore_weights_file(self, filepath, Framework.keras,
|
filepath = WeightsFileHandler.restore_weights_file(self, filepath, Framework.keras,
|
||||||
PatchKerasModelIO.__main_task)
|
PatchKerasModelIO._current_task)
|
||||||
if 'filepath' in kwargs:
|
if 'filepath' in kwargs:
|
||||||
kwargs['filepath'] = filepath
|
kwargs['filepath'] = filepath
|
||||||
else:
|
else:
|
||||||
@ -1805,11 +1838,13 @@ class PatchKerasModelIO(object):
|
|||||||
# try to load the files, if something happened exception will be raised before we register the file
|
# try to load the files, if something happened exception will be raised before we register the file
|
||||||
model = original_fn(self, *args, **kwargs)
|
model = original_fn(self, *args, **kwargs)
|
||||||
# register/load model weights
|
# register/load model weights
|
||||||
WeightsFileHandler.restore_weights_file(self, filepath, Framework.keras, PatchKerasModelIO.__main_task)
|
WeightsFileHandler.restore_weights_file(self, filepath, Framework.keras, PatchKerasModelIO._current_task)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _save(original_fn, self, *args, **kwargs):
|
def _save(original_fn, self, *args, **kwargs):
|
||||||
|
if not PatchKerasModelIO._current_task:
|
||||||
|
return original_fn(self, *args, **kwargs)
|
||||||
if hasattr(self, 'trains_out_model') and self.trains_out_model:
|
if hasattr(self, 'trains_out_model') and self.trains_out_model:
|
||||||
# noinspection PyProtectedMember
|
# noinspection PyProtectedMember
|
||||||
self.trains_out_model[-1]._processed = False
|
self.trains_out_model[-1]._processed = False
|
||||||
@ -1823,12 +1858,13 @@ class PatchKerasModelIO(object):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def _save_weights(original_fn, self, *args, **kwargs):
|
def _save_weights(original_fn, self, *args, **kwargs):
|
||||||
original_fn(self, *args, **kwargs)
|
original_fn(self, *args, **kwargs)
|
||||||
PatchKerasModelIO._update_outputmodel(self, *args, **kwargs)
|
if PatchKerasModelIO._current_task:
|
||||||
|
PatchKerasModelIO._update_outputmodel(self, *args, **kwargs)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _update_outputmodel(self, *args, **kwargs):
|
def _update_outputmodel(self, *args, **kwargs):
|
||||||
# check if we have main task
|
# check if we have main task
|
||||||
if PatchKerasModelIO.__main_task is None:
|
if not PatchKerasModelIO._current_task:
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -1851,7 +1887,7 @@ class PatchKerasModelIO(object):
|
|||||||
|
|
||||||
if filepath:
|
if filepath:
|
||||||
WeightsFileHandler.create_output_model(
|
WeightsFileHandler.create_output_model(
|
||||||
self, filepath, Framework.keras, PatchKerasModelIO.__main_task,
|
self, filepath, Framework.keras, PatchKerasModelIO._current_task,
|
||||||
config_obj=config or None, singlefile=True)
|
config_obj=config or None, singlefile=True)
|
||||||
|
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
@ -1860,12 +1896,12 @@ class PatchKerasModelIO(object):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def _save_model(original_fn, model, filepath, *args, **kwargs):
|
def _save_model(original_fn, model, filepath, *args, **kwargs):
|
||||||
original_fn(model, filepath, *args, **kwargs)
|
original_fn(model, filepath, *args, **kwargs)
|
||||||
if PatchKerasModelIO.__main_task:
|
if PatchKerasModelIO._current_task:
|
||||||
PatchKerasModelIO._update_outputmodel(model, filepath)
|
PatchKerasModelIO._update_outputmodel(model, filepath)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _load_model(original_fn, filepath, *args, **kwargs):
|
def _load_model(original_fn, filepath, *args, **kwargs):
|
||||||
if not PatchKerasModelIO.__main_task:
|
if not PatchKerasModelIO._current_task:
|
||||||
return original_fn(filepath, *args, **kwargs)
|
return original_fn(filepath, *args, **kwargs)
|
||||||
|
|
||||||
empty = _Empty()
|
empty = _Empty()
|
||||||
@ -1873,12 +1909,12 @@ class PatchKerasModelIO(object):
|
|||||||
if False and running_remotely():
|
if False and running_remotely():
|
||||||
# register/load model weights
|
# register/load model weights
|
||||||
filepath = WeightsFileHandler.restore_weights_file(empty, filepath, Framework.keras,
|
filepath = WeightsFileHandler.restore_weights_file(empty, filepath, Framework.keras,
|
||||||
PatchKerasModelIO.__main_task)
|
PatchKerasModelIO._current_task)
|
||||||
model = original_fn(filepath, *args, **kwargs)
|
model = original_fn(filepath, *args, **kwargs)
|
||||||
else:
|
else:
|
||||||
model = original_fn(filepath, *args, **kwargs)
|
model = original_fn(filepath, *args, **kwargs)
|
||||||
# register/load model weights
|
# register/load model weights
|
||||||
WeightsFileHandler.restore_weights_file(empty, filepath, Framework.keras, PatchKerasModelIO.__main_task)
|
WeightsFileHandler.restore_weights_file(empty, filepath, Framework.keras, PatchKerasModelIO._current_task)
|
||||||
# update the input model object
|
# update the input model object
|
||||||
if empty.trains_in_model:
|
if empty.trains_in_model:
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
@ -1891,12 +1927,14 @@ class PatchKerasModelIO(object):
|
|||||||
|
|
||||||
|
|
||||||
class PatchTensorflowModelIO(object):
|
class PatchTensorflowModelIO(object):
|
||||||
__main_task = None
|
_current_task = None
|
||||||
__patched = None
|
__patched = None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def update_current_task(task, **_):
|
def update_current_task(task, **_):
|
||||||
PatchTensorflowModelIO.__main_task = task
|
PatchTensorflowModelIO._current_task = task
|
||||||
|
if not task:
|
||||||
|
return
|
||||||
PatchTensorflowModelIO._patch_model_checkpoint()
|
PatchTensorflowModelIO._patch_model_checkpoint()
|
||||||
PostImportHookPatching.add_on_import('tensorflow', PatchTensorflowModelIO._patch_model_checkpoint)
|
PostImportHookPatching.add_on_import('tensorflow', PatchTensorflowModelIO._patch_model_checkpoint)
|
||||||
|
|
||||||
@ -2028,29 +2066,31 @@ class PatchTensorflowModelIO(object):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def _save(original_fn, self, sess, save_path, *args, **kwargs):
|
def _save(original_fn, self, sess, save_path, *args, **kwargs):
|
||||||
saved_path = original_fn(self, sess, save_path, *args, **kwargs)
|
saved_path = original_fn(self, sess, save_path, *args, **kwargs)
|
||||||
if not saved_path:
|
if not saved_path or not PatchTensorflowModelIO._current_task:
|
||||||
return saved_path
|
return saved_path
|
||||||
# store output Model
|
# store output Model
|
||||||
return WeightsFileHandler.create_output_model(self, saved_path, Framework.tensorflow,
|
return WeightsFileHandler.create_output_model(self, saved_path, Framework.tensorflow,
|
||||||
PatchTensorflowModelIO.__main_task)
|
PatchTensorflowModelIO._current_task)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _save_model(original_fn, obj, export_dir, *args, **kwargs):
|
def _save_model(original_fn, obj, export_dir, *args, **kwargs):
|
||||||
original_fn(obj, export_dir, *args, **kwargs)
|
original_fn(obj, export_dir, *args, **kwargs)
|
||||||
|
if not PatchKerasModelIO._current_task:
|
||||||
|
return
|
||||||
# store output Model
|
# store output Model
|
||||||
WeightsFileHandler.create_output_model(obj, export_dir, Framework.tensorflow,
|
WeightsFileHandler.create_output_model(obj, export_dir, Framework.tensorflow,
|
||||||
PatchTensorflowModelIO.__main_task)
|
PatchTensorflowModelIO._current_task)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _restore(original_fn, self, sess, save_path, *args, **kwargs):
|
def _restore(original_fn, self, sess, save_path, *args, **kwargs):
|
||||||
if PatchTensorflowModelIO.__main_task is None:
|
if PatchTensorflowModelIO._current_task is None:
|
||||||
return original_fn(self, sess, save_path, *args, **kwargs)
|
return original_fn(self, sess, save_path, *args, **kwargs)
|
||||||
|
|
||||||
# Hack: disabled
|
# Hack: disabled
|
||||||
if False and running_remotely():
|
if False and running_remotely():
|
||||||
# register/load model weights
|
# register/load model weights
|
||||||
save_path = WeightsFileHandler.restore_weights_file(self, save_path, Framework.tensorflow,
|
save_path = WeightsFileHandler.restore_weights_file(self, save_path, Framework.tensorflow,
|
||||||
PatchTensorflowModelIO.__main_task)
|
PatchTensorflowModelIO._current_task)
|
||||||
# load model
|
# load model
|
||||||
return original_fn(self, sess, save_path, *args, **kwargs)
|
return original_fn(self, sess, save_path, *args, **kwargs)
|
||||||
|
|
||||||
@ -2058,12 +2098,12 @@ class PatchTensorflowModelIO(object):
|
|||||||
model = original_fn(self, sess, save_path, *args, **kwargs)
|
model = original_fn(self, sess, save_path, *args, **kwargs)
|
||||||
# register/load model weights
|
# register/load model weights
|
||||||
WeightsFileHandler.restore_weights_file(self, save_path, Framework.tensorflow,
|
WeightsFileHandler.restore_weights_file(self, save_path, Framework.tensorflow,
|
||||||
PatchTensorflowModelIO.__main_task)
|
PatchTensorflowModelIO._current_task)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _load_lt_2_0(original_fn, sess, tags=None, export_dir=None, *args, **saver_kwargs):
|
def _load_lt_2_0(original_fn, sess, tags=None, export_dir=None, *args, **saver_kwargs):
|
||||||
if PatchTensorflowModelIO.__main_task is None:
|
if PatchTensorflowModelIO._current_task is None:
|
||||||
return original_fn(sess, tags, export_dir, *args, **saver_kwargs)
|
return original_fn(sess, tags, export_dir, *args, **saver_kwargs)
|
||||||
|
|
||||||
# register input model
|
# register input model
|
||||||
@ -2072,7 +2112,7 @@ class PatchTensorflowModelIO(object):
|
|||||||
if False and running_remotely():
|
if False and running_remotely():
|
||||||
export_dir = WeightsFileHandler.restore_weights_file(
|
export_dir = WeightsFileHandler.restore_weights_file(
|
||||||
empty, export_dir, Framework.tensorflow,
|
empty, export_dir, Framework.tensorflow,
|
||||||
PatchTensorflowModelIO.__main_task
|
PatchTensorflowModelIO._current_task
|
||||||
)
|
)
|
||||||
model = original_fn(sess, tags, export_dir, *args, **saver_kwargs)
|
model = original_fn(sess, tags, export_dir, *args, **saver_kwargs)
|
||||||
else:
|
else:
|
||||||
@ -2080,7 +2120,7 @@ class PatchTensorflowModelIO(object):
|
|||||||
model = original_fn(sess, tags, export_dir, *args, **saver_kwargs)
|
model = original_fn(sess, tags, export_dir, *args, **saver_kwargs)
|
||||||
WeightsFileHandler.restore_weights_file(
|
WeightsFileHandler.restore_weights_file(
|
||||||
empty, export_dir, Framework.tensorflow,
|
empty, export_dir, Framework.tensorflow,
|
||||||
PatchTensorflowModelIO.__main_task
|
PatchTensorflowModelIO._current_task
|
||||||
)
|
)
|
||||||
|
|
||||||
if empty.trains_in_model:
|
if empty.trains_in_model:
|
||||||
@ -2093,7 +2133,7 @@ class PatchTensorflowModelIO(object):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _load(original_fn, export_dir, *args, **saver_kwargs):
|
def _load(original_fn, export_dir, *args, **saver_kwargs):
|
||||||
if PatchTensorflowModelIO.__main_task is None:
|
if PatchTensorflowModelIO._current_task is None:
|
||||||
return original_fn(export_dir, *args, **saver_kwargs)
|
return original_fn(export_dir, *args, **saver_kwargs)
|
||||||
|
|
||||||
# register input model
|
# register input model
|
||||||
@ -2102,7 +2142,7 @@ class PatchTensorflowModelIO(object):
|
|||||||
if False and running_remotely():
|
if False and running_remotely():
|
||||||
export_dir = WeightsFileHandler.restore_weights_file(
|
export_dir = WeightsFileHandler.restore_weights_file(
|
||||||
empty, export_dir, Framework.tensorflow,
|
empty, export_dir, Framework.tensorflow,
|
||||||
PatchTensorflowModelIO.__main_task
|
PatchTensorflowModelIO._current_task
|
||||||
)
|
)
|
||||||
model = original_fn(export_dir, *args, **saver_kwargs)
|
model = original_fn(export_dir, *args, **saver_kwargs)
|
||||||
else:
|
else:
|
||||||
@ -2110,7 +2150,7 @@ class PatchTensorflowModelIO(object):
|
|||||||
model = original_fn(export_dir, *args, **saver_kwargs)
|
model = original_fn(export_dir, *args, **saver_kwargs)
|
||||||
WeightsFileHandler.restore_weights_file(
|
WeightsFileHandler.restore_weights_file(
|
||||||
empty, export_dir, Framework.tensorflow,
|
empty, export_dir, Framework.tensorflow,
|
||||||
PatchTensorflowModelIO.__main_task
|
PatchTensorflowModelIO._current_task
|
||||||
)
|
)
|
||||||
|
|
||||||
if empty.trains_in_model:
|
if empty.trains_in_model:
|
||||||
@ -2124,24 +2164,24 @@ class PatchTensorflowModelIO(object):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def _ckpt_save(original_fn, self, file_prefix, *args, **kwargs):
|
def _ckpt_save(original_fn, self, file_prefix, *args, **kwargs):
|
||||||
checkpoint_path = original_fn(self, file_prefix, *args, **kwargs)
|
checkpoint_path = original_fn(self, file_prefix, *args, **kwargs)
|
||||||
if PatchTensorflowModelIO.__main_task is None:
|
if PatchTensorflowModelIO._current_task is None:
|
||||||
return checkpoint_path
|
return checkpoint_path
|
||||||
WeightsFileHandler.create_output_model(self, checkpoint_path, Framework.tensorflow,
|
WeightsFileHandler.create_output_model(self, checkpoint_path, Framework.tensorflow,
|
||||||
PatchTensorflowModelIO.__main_task)
|
PatchTensorflowModelIO._current_task)
|
||||||
return checkpoint_path
|
return checkpoint_path
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _ckpt_write(original_fn, self, file_prefix, *args, **kwargs):
|
def _ckpt_write(original_fn, self, file_prefix, *args, **kwargs):
|
||||||
checkpoint_path = original_fn(self, file_prefix, *args, **kwargs)
|
checkpoint_path = original_fn(self, file_prefix, *args, **kwargs)
|
||||||
if PatchTensorflowModelIO.__main_task is None:
|
if PatchTensorflowModelIO._current_task is None:
|
||||||
return checkpoint_path
|
return checkpoint_path
|
||||||
WeightsFileHandler.create_output_model(self, checkpoint_path, Framework.tensorflow,
|
WeightsFileHandler.create_output_model(self, checkpoint_path, Framework.tensorflow,
|
||||||
PatchTensorflowModelIO.__main_task)
|
PatchTensorflowModelIO._current_task)
|
||||||
return checkpoint_path
|
return checkpoint_path
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _ckpt_restore(original_fn, self, save_path, *args, **kwargs):
|
def _ckpt_restore(original_fn, self, save_path, *args, **kwargs):
|
||||||
if PatchTensorflowModelIO.__main_task is None:
|
if PatchTensorflowModelIO._current_task is None:
|
||||||
return original_fn(self, save_path, *args, **kwargs)
|
return original_fn(self, save_path, *args, **kwargs)
|
||||||
|
|
||||||
# register input model
|
# register input model
|
||||||
@ -2149,13 +2189,13 @@ class PatchTensorflowModelIO(object):
|
|||||||
# Hack: disabled
|
# Hack: disabled
|
||||||
if False and running_remotely():
|
if False and running_remotely():
|
||||||
save_path = WeightsFileHandler.restore_weights_file(empty, save_path, Framework.tensorflow,
|
save_path = WeightsFileHandler.restore_weights_file(empty, save_path, Framework.tensorflow,
|
||||||
PatchTensorflowModelIO.__main_task)
|
PatchTensorflowModelIO._current_task)
|
||||||
model = original_fn(self, save_path, *args, **kwargs)
|
model = original_fn(self, save_path, *args, **kwargs)
|
||||||
else:
|
else:
|
||||||
# try to load model before registering it, in case it fails.
|
# try to load model before registering it, in case it fails.
|
||||||
model = original_fn(self, save_path, *args, **kwargs)
|
model = original_fn(self, save_path, *args, **kwargs)
|
||||||
WeightsFileHandler.restore_weights_file(empty, save_path, Framework.tensorflow,
|
WeightsFileHandler.restore_weights_file(empty, save_path, Framework.tensorflow,
|
||||||
PatchTensorflowModelIO.__main_task)
|
PatchTensorflowModelIO._current_task)
|
||||||
|
|
||||||
if empty.trains_in_model:
|
if empty.trains_in_model:
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
@ -2167,12 +2207,14 @@ class PatchTensorflowModelIO(object):
|
|||||||
|
|
||||||
|
|
||||||
class PatchTensorflow2ModelIO(object):
|
class PatchTensorflow2ModelIO(object):
|
||||||
__main_task = None
|
_current_task = None
|
||||||
__patched = None
|
__patched = None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def update_current_task(task, **_):
|
def update_current_task(task, **_):
|
||||||
PatchTensorflow2ModelIO.__main_task = task
|
PatchTensorflow2ModelIO._current_task = task
|
||||||
|
if not task:
|
||||||
|
return
|
||||||
PatchTensorflow2ModelIO._patch_model_checkpoint()
|
PatchTensorflow2ModelIO._patch_model_checkpoint()
|
||||||
PostImportHookPatching.add_on_import('tensorflow', PatchTensorflow2ModelIO._patch_model_checkpoint)
|
PostImportHookPatching.add_on_import('tensorflow', PatchTensorflow2ModelIO._patch_model_checkpoint)
|
||||||
|
|
||||||
@ -2210,18 +2252,20 @@ class PatchTensorflow2ModelIO(object):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def _save(original_fn, self, file_prefix, *args, **kwargs):
|
def _save(original_fn, self, file_prefix, *args, **kwargs):
|
||||||
model = original_fn(self, file_prefix, *args, **kwargs)
|
model = original_fn(self, file_prefix, *args, **kwargs)
|
||||||
|
if not PatchTensorflow2ModelIO._current_task:
|
||||||
|
return model
|
||||||
# store output Model
|
# store output Model
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
WeightsFileHandler.create_output_model(self, file_prefix, Framework.tensorflow,
|
WeightsFileHandler.create_output_model(self, file_prefix, Framework.tensorflow,
|
||||||
PatchTensorflow2ModelIO.__main_task)
|
PatchTensorflow2ModelIO._current_task)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
return model
|
return model
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _restore(original_fn, self, save_path, *args, **kwargs):
|
def _restore(original_fn, self, save_path, *args, **kwargs):
|
||||||
if PatchTensorflow2ModelIO.__main_task is None:
|
if not PatchTensorflow2ModelIO._current_task:
|
||||||
return original_fn(self, save_path, *args, **kwargs)
|
return original_fn(self, save_path, *args, **kwargs)
|
||||||
|
|
||||||
# Hack: disabled
|
# Hack: disabled
|
||||||
@ -2230,7 +2274,7 @@ class PatchTensorflow2ModelIO(object):
|
|||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
save_path = WeightsFileHandler.restore_weights_file(self, save_path, Framework.tensorflow,
|
save_path = WeightsFileHandler.restore_weights_file(self, save_path, Framework.tensorflow,
|
||||||
PatchTensorflow2ModelIO.__main_task)
|
PatchTensorflow2ModelIO._current_task)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
# load model
|
# load model
|
||||||
@ -2242,7 +2286,7 @@ class PatchTensorflow2ModelIO(object):
|
|||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
WeightsFileHandler.restore_weights_file(self, save_path, Framework.tensorflow,
|
WeightsFileHandler.restore_weights_file(self, save_path, Framework.tensorflow,
|
||||||
PatchTensorflow2ModelIO.__main_task)
|
PatchTensorflow2ModelIO._current_task)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
return model
|
return model
|
||||||
|
@ -11,13 +11,15 @@ from ...model import Framework
|
|||||||
|
|
||||||
|
|
||||||
class PatchXGBoostModelIO(PatchBaseModelIO):
|
class PatchXGBoostModelIO(PatchBaseModelIO):
|
||||||
__main_task = None
|
_current_task = None
|
||||||
__patched = None
|
__patched = None
|
||||||
__callback_cls = None
|
__callback_cls = None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def update_current_task(task, **kwargs):
|
def update_current_task(task, **kwargs):
|
||||||
PatchXGBoostModelIO.__main_task = task
|
PatchXGBoostModelIO._current_task = task
|
||||||
|
if not task:
|
||||||
|
return
|
||||||
PatchXGBoostModelIO._patch_model_io()
|
PatchXGBoostModelIO._patch_model_io()
|
||||||
PostImportHookPatching.add_on_import('xgboost', PatchXGBoostModelIO._patch_model_io)
|
PostImportHookPatching.add_on_import('xgboost', PatchXGBoostModelIO._patch_model_io)
|
||||||
|
|
||||||
@ -55,7 +57,7 @@ class PatchXGBoostModelIO(PatchBaseModelIO):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def _save(original_fn, obj, f, *args, **kwargs):
|
def _save(original_fn, obj, f, *args, **kwargs):
|
||||||
ret = original_fn(obj, f, *args, **kwargs)
|
ret = original_fn(obj, f, *args, **kwargs)
|
||||||
if not PatchXGBoostModelIO.__main_task:
|
if not PatchXGBoostModelIO._current_task:
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
if isinstance(f, six.string_types):
|
if isinstance(f, six.string_types):
|
||||||
@ -76,12 +78,15 @@ class PatchXGBoostModelIO(PatchBaseModelIO):
|
|||||||
model_name = Path(filename).stem
|
model_name = Path(filename).stem
|
||||||
except Exception:
|
except Exception:
|
||||||
model_name = None
|
model_name = None
|
||||||
WeightsFileHandler.create_output_model(obj, filename, Framework.xgboost, PatchXGBoostModelIO.__main_task,
|
WeightsFileHandler.create_output_model(obj, filename, Framework.xgboost, PatchXGBoostModelIO._current_task,
|
||||||
singlefile=True, model_name=model_name)
|
singlefile=True, model_name=model_name)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _load(original_fn, f, *args, **kwargs):
|
def _load(original_fn, f, *args, **kwargs):
|
||||||
|
if not PatchXGBoostModelIO._current_task:
|
||||||
|
return original_fn(f, *args, **kwargs)
|
||||||
|
|
||||||
if isinstance(f, six.string_types):
|
if isinstance(f, six.string_types):
|
||||||
filename = f
|
filename = f
|
||||||
elif hasattr(f, 'name'):
|
elif hasattr(f, 'name'):
|
||||||
@ -91,21 +96,18 @@ class PatchXGBoostModelIO(PatchBaseModelIO):
|
|||||||
else:
|
else:
|
||||||
filename = None
|
filename = None
|
||||||
|
|
||||||
if not PatchXGBoostModelIO.__main_task:
|
|
||||||
return original_fn(f, *args, **kwargs)
|
|
||||||
|
|
||||||
# register input model
|
# register input model
|
||||||
empty = _Empty()
|
empty = _Empty()
|
||||||
# Hack: disabled
|
# Hack: disabled
|
||||||
if False and running_remotely():
|
if False and running_remotely():
|
||||||
filename = WeightsFileHandler.restore_weights_file(empty, filename, Framework.xgboost,
|
filename = WeightsFileHandler.restore_weights_file(empty, filename, Framework.xgboost,
|
||||||
PatchXGBoostModelIO.__main_task)
|
PatchXGBoostModelIO._current_task)
|
||||||
model = original_fn(filename or f, *args, **kwargs)
|
model = original_fn(filename or f, *args, **kwargs)
|
||||||
else:
|
else:
|
||||||
# try to load model before registering, in case we fail
|
# try to load model before registering, in case we fail
|
||||||
model = original_fn(f, *args, **kwargs)
|
model = original_fn(f, *args, **kwargs)
|
||||||
WeightsFileHandler.restore_weights_file(empty, filename, Framework.xgboost,
|
WeightsFileHandler.restore_weights_file(empty, filename, Framework.xgboost,
|
||||||
PatchXGBoostModelIO.__main_task)
|
PatchXGBoostModelIO._current_task)
|
||||||
|
|
||||||
if empty.trains_in_model:
|
if empty.trains_in_model:
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
@ -117,10 +119,12 @@ class PatchXGBoostModelIO(PatchBaseModelIO):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _train(original_fn, *args, **kwargs):
|
def _train(original_fn, *args, **kwargs):
|
||||||
|
if not PatchXGBoostModelIO._current_task:
|
||||||
|
return original_fn(*args, **kwargs)
|
||||||
if PatchXGBoostModelIO.__callback_cls:
|
if PatchXGBoostModelIO.__callback_cls:
|
||||||
callbacks = kwargs.get('callbacks') or []
|
callbacks = kwargs.get('callbacks') or []
|
||||||
kwargs['callbacks'] = callbacks + [
|
kwargs['callbacks'] = callbacks + [
|
||||||
PatchXGBoostModelIO.__callback_cls(task=PatchXGBoostModelIO.__main_task)
|
PatchXGBoostModelIO.__callback_cls(task=PatchXGBoostModelIO._current_task)
|
||||||
]
|
]
|
||||||
return original_fn(*args, **kwargs)
|
return original_fn(*args, **kwargs)
|
||||||
|
|
||||||
|
@ -2,7 +2,7 @@ import io
|
|||||||
import sys
|
import sys
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
from ..config import running_remotely, get_remote_task_id, DEV_TASK_NO_REUSE
|
from ..config import running_remotely, DEV_TASK_NO_REUSE
|
||||||
from ..debugging.log import LoggerRoot
|
from ..debugging.log import LoggerRoot
|
||||||
|
|
||||||
|
|
||||||
@ -46,6 +46,8 @@ class PatchHydra(object):
|
|||||||
def update_current_task(task):
|
def update_current_task(task):
|
||||||
# set current Task before patching
|
# set current Task before patching
|
||||||
PatchHydra._current_task = task
|
PatchHydra._current_task = task
|
||||||
|
if not task:
|
||||||
|
return
|
||||||
if PatchHydra.patch_hydra():
|
if PatchHydra.patch_hydra():
|
||||||
# check if we have an untracked state, store it.
|
# check if we have an untracked state, store it.
|
||||||
if PatchHydra._last_untracked_state.get('connect'):
|
if PatchHydra._last_untracked_state.get('connect'):
|
||||||
@ -66,9 +68,6 @@ class PatchHydra(object):
|
|||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
if running_remotely():
|
if running_remotely():
|
||||||
if not PatchHydra._current_task:
|
|
||||||
from ..task import Task
|
|
||||||
PatchHydra._current_task = Task.get_task(task_id=get_remote_task_id())
|
|
||||||
# get the _parameter_allow_full_edit casted back to boolean
|
# get the _parameter_allow_full_edit casted back to boolean
|
||||||
connected_config = dict()
|
connected_config = dict()
|
||||||
connected_config[PatchHydra._parameter_allow_full_edit] = False
|
connected_config[PatchHydra._parameter_allow_full_edit] = False
|
||||||
@ -126,8 +125,9 @@ class PatchHydra(object):
|
|||||||
|
|
||||||
if pre_app_task_init_call and not running_remotely():
|
if pre_app_task_init_call and not running_remotely():
|
||||||
LoggerRoot.get_base_logger(PatchHydra).info(
|
LoggerRoot.get_base_logger(PatchHydra).info(
|
||||||
'Task.init called outside of Hydra-App. For full Hydra multi-run support, '
|
"Task.init called outside of Hydra-App. For full Hydra multi-run support, "
|
||||||
'move the Task.init call into the Hydra-App main function')
|
"move the Task.init call into the Hydra-App main function"
|
||||||
|
)
|
||||||
|
|
||||||
kwargs["config"] = config
|
kwargs["config"] = config
|
||||||
kwargs["task_function"] = partial(PatchHydra._patched_task_function, task_function,)
|
kwargs["task_function"] = partial(PatchHydra._patched_task_function, task_function,)
|
||||||
|
@ -77,6 +77,8 @@ class PatchedJoblib(object):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def update_current_task(task):
|
def update_current_task(task):
|
||||||
PatchedJoblib._current_task = task
|
PatchedJoblib._current_task = task
|
||||||
|
if not task:
|
||||||
|
return
|
||||||
PatchedJoblib.patch_joblib()
|
PatchedJoblib.patch_joblib()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -14,7 +14,7 @@ from ..utilities.proxy_object import flatten_dictionary
|
|||||||
|
|
||||||
class PatchJsonArgParse(object):
|
class PatchJsonArgParse(object):
|
||||||
_args = {}
|
_args = {}
|
||||||
_main_task = None
|
_current_task = None
|
||||||
_args_sep = "/"
|
_args_sep = "/"
|
||||||
_args_type = {}
|
_args_type = {}
|
||||||
_commands_sep = "."
|
_commands_sep = "."
|
||||||
@ -24,14 +24,19 @@ class PatchJsonArgParse(object):
|
|||||||
__remote_task_params = {}
|
__remote_task_params = {}
|
||||||
__patched = False
|
__patched = False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def update_current_task(cls, task):
|
||||||
|
cls._current_task = task
|
||||||
|
if not task:
|
||||||
|
return
|
||||||
|
cls.patch(task)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def patch(cls, task):
|
def patch(cls, task):
|
||||||
if ArgumentParser is None:
|
if ArgumentParser is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
if task:
|
PatchJsonArgParse._update_task_args()
|
||||||
cls._main_task = task
|
|
||||||
PatchJsonArgParse._update_task_args()
|
|
||||||
|
|
||||||
if not cls.__patched:
|
if not cls.__patched:
|
||||||
cls.__patched = True
|
cls.__patched = True
|
||||||
@ -39,14 +44,16 @@ class PatchJsonArgParse(object):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _update_task_args(cls):
|
def _update_task_args(cls):
|
||||||
if running_remotely() or not cls._main_task or not cls._args:
|
if running_remotely() or not cls._current_task or not cls._args:
|
||||||
return
|
return
|
||||||
args = {cls._section_name + cls._args_sep + k: v for k, v in cls._args.items()}
|
args = {cls._section_name + cls._args_sep + k: v for k, v in cls._args.items()}
|
||||||
args_type = {cls._section_name + cls._args_sep + k: v for k, v in cls._args_type.items()}
|
args_type = {cls._section_name + cls._args_sep + k: v for k, v in cls._args_type.items()}
|
||||||
cls._main_task._set_parameters(args, __update=True, __parameters_types=args_type)
|
cls._current_task._set_parameters(args, __update=True, __parameters_types=args_type)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _parse_args(original_fn, obj, *args, **kwargs):
|
def _parse_args(original_fn, obj, *args, **kwargs):
|
||||||
|
if not PatchJsonArgParse._current_task:
|
||||||
|
return original_fn(obj, *args, **kwargs)
|
||||||
if len(args) == 1:
|
if len(args) == 1:
|
||||||
kwargs["args"] = args[0]
|
kwargs["args"] = args[0]
|
||||||
args = []
|
args = []
|
||||||
|
@ -207,12 +207,15 @@ class PatchedMatplotlib:
|
|||||||
PatchedMatplotlib._current_task = task
|
PatchedMatplotlib._current_task = task
|
||||||
PostImportHookPatching.add_on_import('matplotlib.pyplot', PatchedMatplotlib.patch_matplotlib)
|
PostImportHookPatching.add_on_import('matplotlib.pyplot', PatchedMatplotlib.patch_matplotlib)
|
||||||
PostImportHookPatching.add_on_import('matplotlib.pylab', PatchedMatplotlib.patch_matplotlib)
|
PostImportHookPatching.add_on_import('matplotlib.pylab', PatchedMatplotlib.patch_matplotlib)
|
||||||
elif PatchedMatplotlib.patch_matplotlib():
|
else:
|
||||||
|
PatchedMatplotlib.patch_matplotlib()
|
||||||
PatchedMatplotlib._current_task = task
|
PatchedMatplotlib._current_task = task
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def patched_imshow(*args, **kw):
|
def patched_imshow(*args, **kw):
|
||||||
ret = PatchedMatplotlib._patched_original_imshow(*args, **kw)
|
ret = PatchedMatplotlib._patched_original_imshow(*args, **kw)
|
||||||
|
if not PatchedMatplotlib._current_task:
|
||||||
|
return ret
|
||||||
try:
|
try:
|
||||||
from matplotlib import _pylab_helpers
|
from matplotlib import _pylab_helpers
|
||||||
# store on the plot that this is an imshow plot
|
# store on the plot that this is an imshow plot
|
||||||
@ -227,6 +230,8 @@ class PatchedMatplotlib:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def patched_savefig(self, *args, **kw):
|
def patched_savefig(self, *args, **kw):
|
||||||
ret = PatchedMatplotlib._patched_original_savefig(self, *args, **kw)
|
ret = PatchedMatplotlib._patched_original_savefig(self, *args, **kw)
|
||||||
|
if not PatchedMatplotlib._current_task:
|
||||||
|
return ret
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
fname = kw.get('fname') or args[0]
|
fname = kw.get('fname') or args[0]
|
||||||
@ -256,6 +261,8 @@ class PatchedMatplotlib:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def patched_figure_show(self, *args, **kw):
|
def patched_figure_show(self, *args, **kw):
|
||||||
|
if not PatchedMatplotlib._current_task:
|
||||||
|
return PatchedMatplotlib._patched_original_figure(self, *args, **kw)
|
||||||
tid = threading._get_ident() if six.PY2 else threading.get_ident()
|
tid = threading._get_ident() if six.PY2 else threading.get_ident()
|
||||||
if PatchedMatplotlib._recursion_guard.get(tid):
|
if PatchedMatplotlib._recursion_guard.get(tid):
|
||||||
# we are inside a gaurd do nothing
|
# we are inside a gaurd do nothing
|
||||||
@ -269,6 +276,8 @@ class PatchedMatplotlib:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def patched_show(*args, **kw):
|
def patched_show(*args, **kw):
|
||||||
|
if not PatchedMatplotlib._current_task:
|
||||||
|
return PatchedMatplotlib._patched_original_plot(*args, **kw)
|
||||||
tid = threading._get_ident() if six.PY2 else threading.get_ident()
|
tid = threading._get_ident() if six.PY2 else threading.get_ident()
|
||||||
PatchedMatplotlib._recursion_guard[tid] = True
|
PatchedMatplotlib._recursion_guard[tid] = True
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
||||||
|
@ -626,8 +626,7 @@ class Task(_Task):
|
|||||||
task.__register_at_exit(task._at_exit)
|
task.__register_at_exit(task._at_exit)
|
||||||
|
|
||||||
# always patch OS forking because of ProcessPool and the alike
|
# always patch OS forking because of ProcessPool and the alike
|
||||||
PatchOsFork.patch_fork()
|
PatchOsFork.patch_fork(task)
|
||||||
|
|
||||||
if auto_connect_frameworks:
|
if auto_connect_frameworks:
|
||||||
def should_connect(*keys):
|
def should_connect(*keys):
|
||||||
"""
|
"""
|
||||||
@ -708,13 +707,15 @@ class Task(_Task):
|
|||||||
make_deterministic(task.get_random_seed())
|
make_deterministic(task.get_random_seed())
|
||||||
|
|
||||||
if auto_connect_arg_parser:
|
if auto_connect_arg_parser:
|
||||||
EnvironmentBind.update_current_task(Task.__main_task)
|
EnvironmentBind.update_current_task(task)
|
||||||
|
|
||||||
|
PatchJsonArgParse.update_current_task(task)
|
||||||
|
|
||||||
# Patch ArgParser to be aware of the current task
|
# Patch ArgParser to be aware of the current task
|
||||||
argparser_update_currenttask(Task.__main_task)
|
argparser_update_currenttask(task)
|
||||||
PatchClick.patch(Task.__main_task)
|
|
||||||
PatchFire.patch(Task.__main_task)
|
PatchClick.patch(task)
|
||||||
PatchJsonArgParse.patch(Task.__main_task)
|
PatchFire.patch(task)
|
||||||
|
|
||||||
# set excluded arguments
|
# set excluded arguments
|
||||||
if isinstance(auto_connect_arg_parser, dict):
|
if isinstance(auto_connect_arg_parser, dict):
|
||||||
@ -1686,6 +1687,22 @@ class Task(_Task):
|
|||||||
# noinspection PyProtectedMember
|
# noinspection PyProtectedMember
|
||||||
Logger._remove_std_logger()
|
Logger._remove_std_logger()
|
||||||
|
|
||||||
|
# unbind everything
|
||||||
|
PatchHydra.update_current_task(None)
|
||||||
|
PatchedJoblib.update_current_task(None)
|
||||||
|
PatchedMatplotlib.update_current_task(None)
|
||||||
|
PatchAbsl.update_current_task(None)
|
||||||
|
TensorflowBinding.update_current_task(None)
|
||||||
|
PatchPyTorchModelIO.update_current_task(None)
|
||||||
|
PatchMegEngineModelIO.update_current_task(None)
|
||||||
|
PatchXGBoostModelIO.update_current_task(None)
|
||||||
|
PatchCatBoostModelIO.update_current_task(None)
|
||||||
|
PatchFastai.update_current_task(None)
|
||||||
|
PatchLIGHTgbmModelIO.update_current_task(None)
|
||||||
|
EnvironmentBind.update_current_task(None)
|
||||||
|
PatchJsonArgParse.update_current_task(None)
|
||||||
|
PatchOsFork.patch_fork(None)
|
||||||
|
|
||||||
def delete(
|
def delete(
|
||||||
self,
|
self,
|
||||||
delete_artifacts_and_models=True,
|
delete_artifacts_and_models=True,
|
||||||
|
Loading…
Reference in New Issue
Block a user