mirror of
https://github.com/clearml/clearml
synced 2025-06-23 01:55:38 +00:00
Refactored binding, better support for matplotlib jupyter binding
This commit is contained in:
parent
ff8652f39f
commit
a77b470500
@ -1,5 +1,5 @@
|
|||||||
""" absl-py FLAGS binding utility functions """
|
""" absl-py FLAGS binding utility functions """
|
||||||
from trains.backend_interface.task.args import _Arguments
|
from ..backend_interface.task.args import _Arguments
|
||||||
from ..config import running_remotely
|
from ..config import running_remotely
|
||||||
|
|
||||||
|
|
178
trains/binding/frameworks/__init__.py
Normal file
178
trains/binding/frameworks/__init__.py
Normal file
@ -0,0 +1,178 @@
|
|||||||
|
import threading
|
||||||
|
import weakref
|
||||||
|
from logging import getLogger
|
||||||
|
|
||||||
|
import six
|
||||||
|
from pathlib2 import Path
|
||||||
|
|
||||||
|
from ...config import running_remotely
|
||||||
|
from ...model import InputModel, OutputModel
|
||||||
|
|
||||||
|
TrainsFrameworkAdapter = 'TrainsFrameworkAdapter'
|
||||||
|
_recursion_guard = {}
|
||||||
|
|
||||||
|
|
||||||
|
def _patched_call(original_fn, patched_fn):
|
||||||
|
def _inner_patch(*args, **kwargs):
|
||||||
|
ident = threading._get_ident() if six.PY2 else threading.get_ident()
|
||||||
|
if ident in _recursion_guard:
|
||||||
|
return original_fn(*args, **kwargs)
|
||||||
|
_recursion_guard[ident] = 1
|
||||||
|
ret = None
|
||||||
|
try:
|
||||||
|
ret = patched_fn(original_fn, *args, **kwargs)
|
||||||
|
except Exception as ex:
|
||||||
|
raise ex
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
_recursion_guard.pop(ident)
|
||||||
|
except KeyError:
|
||||||
|
pass
|
||||||
|
return ret
|
||||||
|
|
||||||
|
return _inner_patch
|
||||||
|
|
||||||
|
|
||||||
|
class _Empty(object):
|
||||||
|
def __init__(self):
|
||||||
|
self.trains_in_model = None
|
||||||
|
|
||||||
|
|
||||||
|
class WeightsFileHandler(object):
|
||||||
|
_model_out_store_lookup = {}
|
||||||
|
_model_in_store_lookup = {}
|
||||||
|
_model_store_lookup_lock = threading.Lock()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def restore_weights_file(model, filepath, framework, task):
|
||||||
|
if task is None:
|
||||||
|
return filepath
|
||||||
|
|
||||||
|
if not filepath:
|
||||||
|
getLogger(TrainsFrameworkAdapter).warning("Could retrieve model location, model not restored")
|
||||||
|
return filepath
|
||||||
|
|
||||||
|
try:
|
||||||
|
WeightsFileHandler._model_store_lookup_lock.acquire()
|
||||||
|
|
||||||
|
# check if object already has InputModel
|
||||||
|
trains_in_model, ref_model = WeightsFileHandler._model_in_store_lookup.get(id(model), (None, None))
|
||||||
|
if ref_model is not None and model != ref_model():
|
||||||
|
# old id pop it - it was probably reused because the object is dead
|
||||||
|
WeightsFileHandler._model_in_store_lookup.pop(id(model))
|
||||||
|
trains_in_model, ref_model = None, None
|
||||||
|
|
||||||
|
# check if object already has InputModel
|
||||||
|
model_name_id = getattr(model, 'name', '')
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
config_text = None
|
||||||
|
config_dict = trains_in_model.config_dict if trains_in_model else None
|
||||||
|
except Exception:
|
||||||
|
config_dict = None
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
config_text = trains_in_model.config_text if trains_in_model else None
|
||||||
|
except Exception:
|
||||||
|
config_text = None
|
||||||
|
trains_in_model = InputModel.import_model(
|
||||||
|
weights_url=filepath,
|
||||||
|
config_dict=config_dict,
|
||||||
|
config_text=config_text,
|
||||||
|
name=task.name + ' ' + model_name_id,
|
||||||
|
label_enumeration=task.get_labels_enumeration(),
|
||||||
|
framework=framework,
|
||||||
|
create_as_published=False,
|
||||||
|
)
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
ref_model = weakref.ref(model)
|
||||||
|
except Exception:
|
||||||
|
ref_model = None
|
||||||
|
WeightsFileHandler._model_in_store_lookup[id(model)] = (trains_in_model, ref_model)
|
||||||
|
# todo: support multiple models for the same task
|
||||||
|
task.connect(trains_in_model)
|
||||||
|
# if we are running remotely we should deserialize the object
|
||||||
|
# because someone might have changed the config_dict
|
||||||
|
if running_remotely():
|
||||||
|
# reload the model
|
||||||
|
model_config = trains_in_model.config_dict
|
||||||
|
# verify that this is the same model so we are not deserializing a diff model
|
||||||
|
if (config_dict and config_dict.get('config') and model_config and model_config.get('config') and
|
||||||
|
config_dict.get('config').get('name') == model_config.get('config').get('name')) or \
|
||||||
|
(not config_dict and not model_config):
|
||||||
|
filepath = trains_in_model.get_weights()
|
||||||
|
# update filepath to point to downloaded weights file
|
||||||
|
# actual model weights loading will be done outside the try/exception block
|
||||||
|
except Exception as ex:
|
||||||
|
getLogger(TrainsFrameworkAdapter).warning(str(ex))
|
||||||
|
finally:
|
||||||
|
WeightsFileHandler._model_store_lookup_lock.release()
|
||||||
|
|
||||||
|
return filepath
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_output_model(model, saved_path, framework, task, singlefile=False, model_name=None):
|
||||||
|
if task is None:
|
||||||
|
return saved_path
|
||||||
|
|
||||||
|
try:
|
||||||
|
WeightsFileHandler._model_store_lookup_lock.acquire()
|
||||||
|
|
||||||
|
# check if object already has InputModel
|
||||||
|
trains_out_model, ref_model = WeightsFileHandler._model_out_store_lookup.get(id(model), (None, None))
|
||||||
|
if ref_model is not None and model != ref_model():
|
||||||
|
# old id pop it - it was probably reused because the object is dead
|
||||||
|
WeightsFileHandler._model_out_store_lookup.pop(id(model))
|
||||||
|
trains_out_model, ref_model = None, None
|
||||||
|
|
||||||
|
# check if object already has InputModel
|
||||||
|
if trains_out_model is None:
|
||||||
|
trains_out_model = OutputModel(
|
||||||
|
task=task,
|
||||||
|
# config_dict=config,
|
||||||
|
name=(task.name + ' - ' + model_name) if model_name else None,
|
||||||
|
label_enumeration=task.get_labels_enumeration(),
|
||||||
|
framework=framework, )
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
ref_model = weakref.ref(model)
|
||||||
|
except Exception:
|
||||||
|
ref_model = None
|
||||||
|
WeightsFileHandler._model_out_store_lookup[id(model)] = (trains_out_model, ref_model)
|
||||||
|
|
||||||
|
if not saved_path:
|
||||||
|
getLogger(TrainsFrameworkAdapter).warning("Could retrieve model location, stored as unknown ")
|
||||||
|
return saved_path
|
||||||
|
|
||||||
|
# check if we have output storage, and generate list of files to upload
|
||||||
|
if trains_out_model.upload_storage_uri:
|
||||||
|
if Path(saved_path).is_dir():
|
||||||
|
files = [str(f) for f in Path(saved_path).rglob('*') if f.is_file()]
|
||||||
|
elif singlefile:
|
||||||
|
files = [str(Path(saved_path).absolute())]
|
||||||
|
else:
|
||||||
|
files = [str(f) for f in Path(saved_path).parent.glob(str(Path(saved_path).name) + '.*')]
|
||||||
|
else:
|
||||||
|
files = None
|
||||||
|
|
||||||
|
# upload files if we found them, or just register the original path
|
||||||
|
if files:
|
||||||
|
if len(files) > 1:
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
target_filename = Path(saved_path).stem
|
||||||
|
except Exception:
|
||||||
|
target_filename = None
|
||||||
|
trains_out_model.update_weights_package(weights_filenames=files, auto_delete_file=False,
|
||||||
|
target_filename=target_filename)
|
||||||
|
else:
|
||||||
|
trains_out_model.update_weights(weights_filename=files[0], auto_delete_file=False)
|
||||||
|
else:
|
||||||
|
trains_out_model.update_weights(weights_filename=None, register_uri=saved_path)
|
||||||
|
except Exception as ex:
|
||||||
|
getLogger(TrainsFrameworkAdapter).warning(str(ex))
|
||||||
|
finally:
|
||||||
|
WeightsFileHandler._model_store_lookup_lock.release()
|
||||||
|
|
||||||
|
return saved_path
|
101
trains/binding/frameworks/pytorch_bind.py
Normal file
101
trains/binding/frameworks/pytorch_bind.py
Normal file
@ -0,0 +1,101 @@
|
|||||||
|
import sys
|
||||||
|
|
||||||
|
import six
|
||||||
|
from pathlib2 import Path
|
||||||
|
|
||||||
|
from ..frameworks import _patched_call, WeightsFileHandler, _Empty
|
||||||
|
from ..import_bind import PostImportHookPatching
|
||||||
|
from ...config import running_remotely
|
||||||
|
from ...model import Framework
|
||||||
|
|
||||||
|
|
||||||
|
class PatchPyTorchModelIO(object):
|
||||||
|
__main_task = None
|
||||||
|
__patched = None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def update_current_task(task, **_):
|
||||||
|
PatchPyTorchModelIO.__main_task = task
|
||||||
|
PatchPyTorchModelIO._patch_model_io()
|
||||||
|
PostImportHookPatching.add_on_import('torch', PatchPyTorchModelIO._patch_model_io)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _patch_model_io():
|
||||||
|
if PatchPyTorchModelIO.__patched:
|
||||||
|
return
|
||||||
|
|
||||||
|
if 'torch' not in sys.modules:
|
||||||
|
return
|
||||||
|
|
||||||
|
PatchPyTorchModelIO.__patched = True
|
||||||
|
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
# hack: make sure tensorflow.__init__ is called
|
||||||
|
import torch
|
||||||
|
torch.save = _patched_call(torch.save, PatchPyTorchModelIO._save)
|
||||||
|
torch.load = _patched_call(torch.load, PatchPyTorchModelIO._load)
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
except Exception:
|
||||||
|
pass # print('Failed patching pytorch')
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _save(original_fn, obj, f, *args, **kwargs):
|
||||||
|
ret = original_fn(obj, f, *args, **kwargs)
|
||||||
|
if not PatchPyTorchModelIO.__main_task:
|
||||||
|
return ret
|
||||||
|
|
||||||
|
if isinstance(f, six.string_types):
|
||||||
|
filename = f
|
||||||
|
elif hasattr(f, 'name'):
|
||||||
|
filename = f.name
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
f.flush()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
filename = None
|
||||||
|
|
||||||
|
# give the model a descriptive name based on the file name
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
model_name = Path(filename).stem
|
||||||
|
except Exception:
|
||||||
|
model_name = None
|
||||||
|
WeightsFileHandler.create_output_model(obj, filename, Framework.pytorch, PatchPyTorchModelIO.__main_task,
|
||||||
|
singlefile=True, model_name=model_name)
|
||||||
|
return ret
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _load(original_fn, f, *args, **kwargs):
|
||||||
|
if isinstance(f, six.string_types):
|
||||||
|
filename = f
|
||||||
|
elif hasattr(f, 'name'):
|
||||||
|
filename = f.name
|
||||||
|
else:
|
||||||
|
filename = None
|
||||||
|
|
||||||
|
if not PatchPyTorchModelIO.__main_task:
|
||||||
|
return original_fn(f, *args, **kwargs)
|
||||||
|
|
||||||
|
# register input model
|
||||||
|
empty = _Empty()
|
||||||
|
if running_remotely():
|
||||||
|
filename = WeightsFileHandler.restore_weights_file(empty, filename, Framework.pytorch,
|
||||||
|
PatchPyTorchModelIO.__main_task)
|
||||||
|
model = original_fn(filename or f, *args, **kwargs)
|
||||||
|
else:
|
||||||
|
# try to load model before registering, in case we fail
|
||||||
|
model = original_fn(filename or f, *args, **kwargs)
|
||||||
|
WeightsFileHandler.restore_weights_file(empty, filename, Framework.pytorch,
|
||||||
|
PatchPyTorchModelIO.__main_task)
|
||||||
|
|
||||||
|
if empty.trains_in_model:
|
||||||
|
# noinspection PyBroadException
|
||||||
|
try:
|
||||||
|
model.trains_in_model = empty.trains_in_model
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return model
|
@ -1,256 +1,25 @@
|
|||||||
import base64
|
import base64
|
||||||
import sys
|
import sys
|
||||||
import threading
|
import threading
|
||||||
import weakref
|
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from logging import ERROR, WARNING, getLogger
|
from logging import ERROR, WARNING, getLogger
|
||||||
from pathlib2 import Path
|
from typing import Any
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import six
|
import six
|
||||||
|
from pathlib2 import Path
|
||||||
|
|
||||||
from ..config import running_remotely
|
from ..frameworks import _patched_call, WeightsFileHandler, _Empty, TrainsFrameworkAdapter
|
||||||
from ..model import InputModel, OutputModel, Framework
|
from ..import_bind import PostImportHookPatching
|
||||||
|
from ...config import running_remotely
|
||||||
|
from ...model import InputModel, OutputModel, Framework
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from google.protobuf.json_format import MessageToDict
|
from google.protobuf.json_format import MessageToDict
|
||||||
except ImportError:
|
except ImportError:
|
||||||
MessageToDict = None
|
MessageToDict = None
|
||||||
|
|
||||||
if six.PY2:
|
|
||||||
# python2.x
|
|
||||||
import __builtin__ as builtins
|
|
||||||
else:
|
|
||||||
# python3.x
|
|
||||||
import builtins
|
|
||||||
|
|
||||||
|
|
||||||
TrainsFrameworkAdapter = 'TrainsFrameworkAdapter'
|
|
||||||
_recursion_guard = {}
|
|
||||||
|
|
||||||
|
|
||||||
class _Empty(object):
|
|
||||||
def __init__(self):
|
|
||||||
self.trains_in_model = None
|
|
||||||
|
|
||||||
|
|
||||||
class PostImportHookPatching(object):
|
|
||||||
_patched = False
|
|
||||||
_post_import_hooks = defaultdict(list)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _init_hook():
|
|
||||||
if PostImportHookPatching._patched:
|
|
||||||
return
|
|
||||||
PostImportHookPatching._patched = True
|
|
||||||
|
|
||||||
if six.PY2:
|
|
||||||
# python2.x
|
|
||||||
builtins.__org_import__ = builtins.__import__
|
|
||||||
builtins.__import__ = PostImportHookPatching._patched_import2
|
|
||||||
else:
|
|
||||||
# python3.x
|
|
||||||
builtins.__org_import__ = builtins.__import__
|
|
||||||
builtins.__import__ = PostImportHookPatching._patched_import3
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _patched_import2(name, globals={}, locals={}, fromlist=[], level=-1):
|
|
||||||
already_imported = name in sys.modules
|
|
||||||
mod = builtins.__org_import__(
|
|
||||||
name,
|
|
||||||
globals=globals,
|
|
||||||
locals=locals,
|
|
||||||
fromlist=fromlist,
|
|
||||||
level=level)
|
|
||||||
|
|
||||||
if not already_imported and name in PostImportHookPatching._post_import_hooks:
|
|
||||||
for hook in PostImportHookPatching._post_import_hooks[name]:
|
|
||||||
hook()
|
|
||||||
return mod
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _patched_import3(name, globals=None, locals=None, fromlist=(), level=0):
|
|
||||||
already_imported = name in sys.modules
|
|
||||||
mod = builtins.__org_import__(
|
|
||||||
name,
|
|
||||||
globals=globals,
|
|
||||||
locals=locals,
|
|
||||||
fromlist=fromlist,
|
|
||||||
level=level)
|
|
||||||
|
|
||||||
if not already_imported and name in PostImportHookPatching._post_import_hooks:
|
|
||||||
for hook in PostImportHookPatching._post_import_hooks[name]:
|
|
||||||
hook()
|
|
||||||
return mod
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def add_on_import(name, func):
|
|
||||||
PostImportHookPatching._init_hook()
|
|
||||||
if not name in PostImportHookPatching._post_import_hooks or \
|
|
||||||
func not in PostImportHookPatching._post_import_hooks[name]:
|
|
||||||
PostImportHookPatching._post_import_hooks[name].append(func)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def remove_on_import(name, func):
|
|
||||||
if name in PostImportHookPatching._post_import_hooks and func in PostImportHookPatching._post_import_hooks[name]:
|
|
||||||
PostImportHookPatching._post_import_hooks[name].remove(func)
|
|
||||||
|
|
||||||
|
|
||||||
def _patched_call(original_fn, patched_fn):
|
|
||||||
def _inner_patch(*args, **kwargs):
|
|
||||||
ident = threading._get_ident() if six.PY2 else threading.get_ident()
|
|
||||||
if ident in _recursion_guard:
|
|
||||||
return original_fn(*args, **kwargs)
|
|
||||||
_recursion_guard[ident] = 1
|
|
||||||
ret = None
|
|
||||||
try:
|
|
||||||
ret = patched_fn(original_fn, *args, **kwargs)
|
|
||||||
except Exception as ex:
|
|
||||||
raise ex
|
|
||||||
finally:
|
|
||||||
try:
|
|
||||||
_recursion_guard.pop(ident)
|
|
||||||
except KeyError:
|
|
||||||
pass
|
|
||||||
return ret
|
|
||||||
return _inner_patch
|
|
||||||
|
|
||||||
|
|
||||||
class WeightsFileHandler(object):
|
|
||||||
_model_out_store_lookup = {}
|
|
||||||
_model_in_store_lookup = {}
|
|
||||||
_model_store_lookup_lock = threading.Lock()
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def restore_weights_file(model, filepath, framework, task):
|
|
||||||
if task is None:
|
|
||||||
return filepath
|
|
||||||
|
|
||||||
if not filepath:
|
|
||||||
getLogger(TrainsFrameworkAdapter).warning("Could retrieve model location, model not restored")
|
|
||||||
return filepath
|
|
||||||
|
|
||||||
try:
|
|
||||||
WeightsFileHandler._model_store_lookup_lock.acquire()
|
|
||||||
|
|
||||||
# check if object already has InputModel
|
|
||||||
trains_in_model, ref_model = WeightsFileHandler._model_in_store_lookup.get(id(model), (None, None))
|
|
||||||
if ref_model is not None and model != ref_model():
|
|
||||||
# old id pop it - it was probably reused because the object is dead
|
|
||||||
WeightsFileHandler._model_in_store_lookup.pop(id(model))
|
|
||||||
trains_in_model, ref_model = None, None
|
|
||||||
|
|
||||||
# check if object already has InputModel
|
|
||||||
model_name_id = getattr(model, 'name', '')
|
|
||||||
try:
|
|
||||||
config_text = None
|
|
||||||
config_dict = trains_in_model.config_dict if trains_in_model else None
|
|
||||||
except Exception:
|
|
||||||
config_dict = None
|
|
||||||
try:
|
|
||||||
config_text = trains_in_model.config_text if trains_in_model else None
|
|
||||||
except Exception:
|
|
||||||
config_text = None
|
|
||||||
trains_in_model = InputModel.import_model(
|
|
||||||
weights_url=filepath,
|
|
||||||
config_dict=config_dict,
|
|
||||||
config_text=config_text,
|
|
||||||
name=task.name + ' ' + model_name_id,
|
|
||||||
label_enumeration=task.get_labels_enumeration(),
|
|
||||||
framework=framework,
|
|
||||||
create_as_published=False,
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
ref_model = weakref.ref(model)
|
|
||||||
except Exception:
|
|
||||||
ref_model = None
|
|
||||||
WeightsFileHandler._model_in_store_lookup[id(model)] = (trains_in_model, ref_model)
|
|
||||||
# todo: support multiple models for the same task
|
|
||||||
task.connect(trains_in_model)
|
|
||||||
# if we are running remotely we should deserialize the object
|
|
||||||
# because someone might have changed the config_dict
|
|
||||||
if running_remotely():
|
|
||||||
# reload the model
|
|
||||||
model_config = trains_in_model.config_dict
|
|
||||||
# verify that this is the same model so we are not deserializing a diff model
|
|
||||||
if (config_dict and config_dict.get('config') and model_config and model_config.get('config') and
|
|
||||||
config_dict.get('config').get('name') == model_config.get('config').get('name')) or \
|
|
||||||
(not config_dict and not model_config):
|
|
||||||
filepath = trains_in_model.get_weights()
|
|
||||||
# update filepath to point to downloaded weights file
|
|
||||||
# actual model weights loading will be done outside the try/exception block
|
|
||||||
except Exception as ex:
|
|
||||||
getLogger(TrainsFrameworkAdapter).warning(str(ex))
|
|
||||||
finally:
|
|
||||||
WeightsFileHandler._model_store_lookup_lock.release()
|
|
||||||
|
|
||||||
return filepath
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def create_output_model(model, saved_path, framework, task, singlefile=False, model_name=None):
|
|
||||||
if task is None:
|
|
||||||
return saved_path
|
|
||||||
|
|
||||||
try:
|
|
||||||
WeightsFileHandler._model_store_lookup_lock.acquire()
|
|
||||||
|
|
||||||
# check if object already has InputModel
|
|
||||||
trains_out_model, ref_model = WeightsFileHandler._model_out_store_lookup.get(id(model), (None, None))
|
|
||||||
if ref_model is not None and model != ref_model():
|
|
||||||
# old id pop it - it was probably reused because the object is dead
|
|
||||||
WeightsFileHandler._model_out_store_lookup.pop(id(model))
|
|
||||||
trains_out_model, ref_model = None, None
|
|
||||||
|
|
||||||
# check if object already has InputModel
|
|
||||||
if trains_out_model is None:
|
|
||||||
trains_out_model = OutputModel(
|
|
||||||
task=task,
|
|
||||||
# config_dict=config,
|
|
||||||
name=(task.name + ' - ' + model_name) if model_name else None,
|
|
||||||
label_enumeration=task.get_labels_enumeration(),
|
|
||||||
framework=framework,)
|
|
||||||
try:
|
|
||||||
ref_model = weakref.ref(model)
|
|
||||||
except Exception:
|
|
||||||
ref_model = None
|
|
||||||
WeightsFileHandler._model_out_store_lookup[id(model)] = (trains_out_model, ref_model)
|
|
||||||
|
|
||||||
if not saved_path:
|
|
||||||
getLogger(TrainsFrameworkAdapter).warning("Could retrieve model location, stored as unknown ")
|
|
||||||
return saved_path
|
|
||||||
|
|
||||||
# check if we have output storage, and generate list of files to upload
|
|
||||||
if trains_out_model.upload_storage_uri:
|
|
||||||
if Path(saved_path).is_dir():
|
|
||||||
files = [str(f) for f in Path(saved_path).rglob('*') if f.is_file()]
|
|
||||||
elif singlefile:
|
|
||||||
files = [str(Path(saved_path).absolute())]
|
|
||||||
else:
|
|
||||||
files = [str(f) for f in Path(saved_path).parent.glob(str(Path(saved_path).name)+'.*')]
|
|
||||||
else:
|
|
||||||
files = None
|
|
||||||
|
|
||||||
# upload files if we found them, or just register the original path
|
|
||||||
if files:
|
|
||||||
if len(files) > 1:
|
|
||||||
try:
|
|
||||||
target_filename = Path(saved_path).stem
|
|
||||||
except Exception:
|
|
||||||
target_filename = None
|
|
||||||
trains_out_model.update_weights_package(weights_filenames=files, auto_delete_file=False,
|
|
||||||
target_filename=target_filename)
|
|
||||||
else:
|
|
||||||
trains_out_model.update_weights(weights_filename=files[0], auto_delete_file=False)
|
|
||||||
else:
|
|
||||||
trains_out_model.update_weights(weights_filename=None, register_uri=saved_path)
|
|
||||||
except Exception as ex:
|
|
||||||
getLogger(TrainsFrameworkAdapter).warning(str(ex))
|
|
||||||
finally:
|
|
||||||
WeightsFileHandler._model_store_lookup_lock.release()
|
|
||||||
|
|
||||||
return saved_path
|
|
||||||
|
|
||||||
|
|
||||||
class EventTrainsWriter(object):
|
class EventTrainsWriter(object):
|
||||||
"""
|
"""
|
||||||
@ -271,7 +40,7 @@ class EventTrainsWriter(object):
|
|||||||
def tag_splitter(tag, num_split_parts, split_char='/', join_char='_', default_title='variant'):
|
def tag_splitter(tag, num_split_parts, split_char='/', join_char='_', default_title='variant'):
|
||||||
"""
|
"""
|
||||||
Split a tf.summary tag line to variant and metric.
|
Split a tf.summary tag line to variant and metric.
|
||||||
Variant is the first part of the splitted tag, metric is the second.
|
Variant is the first part of the split tag, metric is the second.
|
||||||
:param str tag:
|
:param str tag:
|
||||||
:param int num_split_parts:
|
:param int num_split_parts:
|
||||||
:param str split_char: a character to split the tag on
|
:param str split_char: a character to split the tag on
|
||||||
@ -313,6 +82,7 @@ class EventTrainsWriter(object):
|
|||||||
self._max_step = 0
|
self._max_step = 0
|
||||||
|
|
||||||
def _decode_image(self, img_str, width, height, color_channels):
|
def _decode_image(self, img_str, width, height, color_channels):
|
||||||
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
image_string = np.asarray(bytearray(base64.b64decode(img_str)), dtype=np.uint8)
|
image_string = np.asarray(bytearray(base64.b64decode(img_str)), dtype=np.uint8)
|
||||||
image = cv2.imdecode(image_string, cv2.IMREAD_COLOR)
|
image = cv2.imdecode(image_string, cv2.IMREAD_COLOR)
|
||||||
@ -495,6 +265,7 @@ class EventTrainsWriter(object):
|
|||||||
camera=(-0.1, +1.3, 1.4))
|
camera=(-0.1, +1.3, 1.4))
|
||||||
|
|
||||||
def _add_plot(self, tag, step, values, vdict):
|
def _add_plot(self, tag, step, values, vdict):
|
||||||
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
plot_values = np.frombuffer(base64.b64decode(values['tensorContent'].encode('utf-8')),
|
plot_values = np.frombuffer(base64.b64decode(values['tensorContent'].encode('utf-8')),
|
||||||
dtype=np.float32)
|
dtype=np.float32)
|
||||||
@ -749,7 +520,8 @@ class PatchSummaryToEventTransformer(object):
|
|||||||
# only patch once
|
# only patch once
|
||||||
if PatchSummaryToEventTransformer.__original_getattributeX is None:
|
if PatchSummaryToEventTransformer.__original_getattributeX is None:
|
||||||
from tensorboardX.writer import SummaryToEventTransformer as SummaryToEventTransformerX
|
from tensorboardX.writer import SummaryToEventTransformer as SummaryToEventTransformerX
|
||||||
PatchSummaryToEventTransformer.__original_getattributeX = SummaryToEventTransformerX.__getattribute__
|
PatchSummaryToEventTransformer.__original_getattributeX = \
|
||||||
|
SummaryToEventTransformerX.__getattribute__
|
||||||
SummaryToEventTransformerX.__getattribute__ = PatchSummaryToEventTransformer._patched_getattributeX
|
SummaryToEventTransformerX.__getattribute__ = PatchSummaryToEventTransformer._patched_getattributeX
|
||||||
setattr(SummaryToEventTransformerX, 'trains',
|
setattr(SummaryToEventTransformerX, 'trains',
|
||||||
property(PatchSummaryToEventTransformer.trains_object))
|
property(PatchSummaryToEventTransformer.trains_object))
|
||||||
@ -780,6 +552,7 @@ class PatchSummaryToEventTransformer(object):
|
|||||||
if not self.trains:
|
if not self.trains:
|
||||||
self.trains = EventTrainsWriter(PatchSummaryToEventTransformer.__main_task.get_logger(),
|
self.trains = EventTrainsWriter(PatchSummaryToEventTransformer.__main_task.get_logger(),
|
||||||
**PatchSummaryToEventTransformer.defaults_dict)
|
**PatchSummaryToEventTransformer.defaults_dict)
|
||||||
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
self.trains.add_event(*args, **kwargs)
|
self.trains.add_event(*args, **kwargs)
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -793,6 +566,7 @@ class PatchSummaryToEventTransformer(object):
|
|||||||
if not self.trains:
|
if not self.trains:
|
||||||
self.trains = EventTrainsWriter(PatchSummaryToEventTransformer.__main_task.get_logger(),
|
self.trains = EventTrainsWriter(PatchSummaryToEventTransformer.__main_task.get_logger(),
|
||||||
**PatchSummaryToEventTransformer.defaults_dict)
|
**PatchSummaryToEventTransformer.defaults_dict)
|
||||||
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
self.trains.add_event(*args, **kwargs)
|
self.trains.add_event(*args, **kwargs)
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -1313,6 +1087,7 @@ class PatchKerasModelIO(object):
|
|||||||
WeightsFileHandler.restore_weights_file(empty, filepath, Framework.keras, PatchKerasModelIO.__main_task)
|
WeightsFileHandler.restore_weights_file(empty, filepath, Framework.keras, PatchKerasModelIO.__main_task)
|
||||||
# update the input model object
|
# update the input model object
|
||||||
if empty.trains_in_model:
|
if empty.trains_in_model:
|
||||||
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
model.trains_in_model = empty.trains_in_model
|
model.trains_in_model = empty.trains_in_model
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -1340,15 +1115,17 @@ class PatchTensorflowModelIO(object):
|
|||||||
return
|
return
|
||||||
|
|
||||||
PatchTensorflowModelIO.__patched = True
|
PatchTensorflowModelIO.__patched = True
|
||||||
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
# hack: make sure tensorflow.__init__ is called
|
# hack: make sure tensorflow.__init__ is called
|
||||||
import tensorflow
|
import tensorflow
|
||||||
from tensorflow.python.training.saver import Saver
|
from tensorflow.python.training.saver import Saver
|
||||||
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
Saver.save = _patched_call(Saver.save, PatchTensorflowModelIO._save)
|
Saver.save = _patched_call(Saver.save, PatchTensorflowModelIO._save)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
Saver.restore = _patched_call(Saver.restore, PatchTensorflowModelIO._restore)
|
Saver.restore = _patched_call(Saver.restore, PatchTensorflowModelIO._restore)
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -1358,6 +1135,7 @@ class PatchTensorflowModelIO(object):
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass # print('Failed patching tensorflow')
|
pass # print('Failed patching tensorflow')
|
||||||
|
|
||||||
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
# make sure we import the correct version of save
|
# make sure we import the correct version of save
|
||||||
import tensorflow
|
import tensorflow
|
||||||
@ -1365,6 +1143,7 @@ class PatchTensorflowModelIO(object):
|
|||||||
# actual import
|
# actual import
|
||||||
import tensorflow.saved_model.experimental as saved_model
|
import tensorflow.saved_model.experimental as saved_model
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
# make sure we import the correct version of save
|
# make sure we import the correct version of save
|
||||||
import tensorflow
|
import tensorflow
|
||||||
@ -1383,6 +1162,7 @@ class PatchTensorflowModelIO(object):
|
|||||||
if saved_model is not None:
|
if saved_model is not None:
|
||||||
saved_model.save = _patched_call(saved_model.save, PatchTensorflowModelIO._save_model)
|
saved_model.save = _patched_call(saved_model.save, PatchTensorflowModelIO._save_model)
|
||||||
|
|
||||||
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
# make sure we import the correct version of save
|
# make sure we import the correct version of save
|
||||||
import tensorflow
|
import tensorflow
|
||||||
@ -1395,6 +1175,7 @@ class PatchTensorflowModelIO(object):
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass # print('Failed patching tensorflow')
|
pass # print('Failed patching tensorflow')
|
||||||
|
|
||||||
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
# make sure we import the correct version of save
|
# make sure we import the correct version of save
|
||||||
import tensorflow
|
import tensorflow
|
||||||
@ -1406,6 +1187,7 @@ class PatchTensorflowModelIO(object):
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass # print('Failed patching tensorflow')
|
pass # print('Failed patching tensorflow')
|
||||||
|
|
||||||
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
# make sure we import the correct version of save
|
# make sure we import the correct version of save
|
||||||
import tensorflow
|
import tensorflow
|
||||||
@ -1417,17 +1199,21 @@ class PatchTensorflowModelIO(object):
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass # print('Failed patching tensorflow')
|
pass # print('Failed patching tensorflow')
|
||||||
|
|
||||||
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
import tensorflow
|
import tensorflow
|
||||||
from tensorflow.train import Checkpoint
|
from tensorflow.train import Checkpoint
|
||||||
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
Checkpoint.save = _patched_call(Checkpoint.save, PatchTensorflowModelIO._ckpt_save)
|
Checkpoint.save = _patched_call(Checkpoint.save, PatchTensorflowModelIO._ckpt_save)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
Checkpoint.restore = _patched_call(Checkpoint.restore, PatchTensorflowModelIO._ckpt_restore)
|
Checkpoint.restore = _patched_call(Checkpoint.restore, PatchTensorflowModelIO._ckpt_restore)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
Checkpoint.write = _patched_call(Checkpoint.write, PatchTensorflowModelIO._ckpt_write)
|
Checkpoint.write = _patched_call(Checkpoint.write, PatchTensorflowModelIO._ckpt_write)
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -1490,6 +1276,7 @@ class PatchTensorflowModelIO(object):
|
|||||||
PatchTensorflowModelIO.__main_task)
|
PatchTensorflowModelIO.__main_task)
|
||||||
|
|
||||||
if empty.trains_in_model:
|
if empty.trains_in_model:
|
||||||
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
model.trains_in_model = empty.trains_in_model
|
model.trains_in_model = empty.trains_in_model
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -1532,6 +1319,7 @@ class PatchTensorflowModelIO(object):
|
|||||||
PatchTensorflowModelIO.__main_task)
|
PatchTensorflowModelIO.__main_task)
|
||||||
|
|
||||||
if empty.trains_in_model:
|
if empty.trains_in_model:
|
||||||
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
model.trains_in_model = empty.trains_in_model
|
model.trains_in_model = empty.trains_in_model
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -1558,7 +1346,7 @@ class PatchPyTorchModelIO(object):
|
|||||||
return
|
return
|
||||||
|
|
||||||
PatchPyTorchModelIO.__patched = True
|
PatchPyTorchModelIO.__patched = True
|
||||||
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
# hack: make sure tensorflow.__init__ is called
|
# hack: make sure tensorflow.__init__ is called
|
||||||
import torch
|
import torch
|
||||||
@ -1579,6 +1367,7 @@ class PatchPyTorchModelIO(object):
|
|||||||
filename = f
|
filename = f
|
||||||
elif hasattr(f, 'name'):
|
elif hasattr(f, 'name'):
|
||||||
filename = f.name
|
filename = f.name
|
||||||
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
f.flush()
|
f.flush()
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -1586,7 +1375,8 @@ class PatchPyTorchModelIO(object):
|
|||||||
else:
|
else:
|
||||||
filename = None
|
filename = None
|
||||||
|
|
||||||
# if the model a screptive name based on the file name
|
# give the model a descriptive name based on the file name
|
||||||
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
model_name = Path(filename).stem
|
model_name = Path(filename).stem
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -1620,6 +1410,7 @@ class PatchPyTorchModelIO(object):
|
|||||||
PatchPyTorchModelIO.__main_task)
|
PatchPyTorchModelIO.__main_task)
|
||||||
|
|
||||||
if empty.trains_in_model:
|
if empty.trains_in_model:
|
||||||
|
# noinspection PyBroadException
|
||||||
try:
|
try:
|
||||||
model.trains_in_model = empty.trains_in_model
|
model.trains_in_model = empty.trains_in_model
|
||||||
except Exception:
|
except Exception:
|
73
trains/binding/import_bind.py
Normal file
73
trains/binding/import_bind.py
Normal file
@ -0,0 +1,73 @@
|
|||||||
|
import sys
|
||||||
|
from collections import defaultdict
|
||||||
|
import six
|
||||||
|
|
||||||
|
if six.PY2:
|
||||||
|
# python2.x
|
||||||
|
import __builtin__ as builtins
|
||||||
|
else:
|
||||||
|
# python3.x
|
||||||
|
import builtins
|
||||||
|
|
||||||
|
|
||||||
|
class PostImportHookPatching(object):
|
||||||
|
_patched = False
|
||||||
|
_post_import_hooks = defaultdict(list)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _init_hook():
|
||||||
|
if PostImportHookPatching._patched:
|
||||||
|
return
|
||||||
|
PostImportHookPatching._patched = True
|
||||||
|
|
||||||
|
if six.PY2:
|
||||||
|
# python2.x
|
||||||
|
builtins.__org_import__ = builtins.__import__
|
||||||
|
builtins.__import__ = PostImportHookPatching._patched_import2
|
||||||
|
else:
|
||||||
|
# python3.x
|
||||||
|
builtins.__org_import__ = builtins.__import__
|
||||||
|
builtins.__import__ = PostImportHookPatching._patched_import3
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _patched_import2(name, globals={}, locals={}, fromlist=[], level=-1):
|
||||||
|
already_imported = name in sys.modules
|
||||||
|
mod = builtins.__org_import__(
|
||||||
|
name,
|
||||||
|
globals=globals,
|
||||||
|
locals=locals,
|
||||||
|
fromlist=fromlist,
|
||||||
|
level=level)
|
||||||
|
|
||||||
|
if not already_imported and name in PostImportHookPatching._post_import_hooks:
|
||||||
|
for hook in PostImportHookPatching._post_import_hooks[name]:
|
||||||
|
hook()
|
||||||
|
return mod
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _patched_import3(name, globals=None, locals=None, fromlist=(), level=0):
|
||||||
|
already_imported = name in sys.modules
|
||||||
|
mod = builtins.__org_import__(
|
||||||
|
name,
|
||||||
|
globals=globals,
|
||||||
|
locals=locals,
|
||||||
|
fromlist=fromlist,
|
||||||
|
level=level)
|
||||||
|
|
||||||
|
if not already_imported and name in PostImportHookPatching._post_import_hooks:
|
||||||
|
for hook in PostImportHookPatching._post_import_hooks[name]:
|
||||||
|
hook()
|
||||||
|
return mod
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def add_on_import(name, func):
|
||||||
|
PostImportHookPatching._init_hook()
|
||||||
|
if not name in PostImportHookPatching._post_import_hooks or \
|
||||||
|
func not in PostImportHookPatching._post_import_hooks[name]:
|
||||||
|
PostImportHookPatching._post_import_hooks[name].append(func)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def remove_on_import(name, func):
|
||||||
|
if name in PostImportHookPatching._post_import_hooks and func in PostImportHookPatching._post_import_hooks[name]:
|
||||||
|
PostImportHookPatching._post_import_hooks[name].remove(func)
|
||||||
|
|
@ -12,6 +12,8 @@ from ..config import running_remotely
|
|||||||
class PatchedMatplotlib:
|
class PatchedMatplotlib:
|
||||||
_patched_original_plot = None
|
_patched_original_plot = None
|
||||||
__patched_original_imshow = None
|
__patched_original_imshow = None
|
||||||
|
__patched_original_draw_all = None
|
||||||
|
__patched_draw_all_recursion_guard = False
|
||||||
_global_plot_counter = -1
|
_global_plot_counter = -1
|
||||||
_global_image_counter = -1
|
_global_image_counter = -1
|
||||||
_current_task = None
|
_current_task = None
|
||||||
@ -45,18 +47,18 @@ class PatchedMatplotlib:
|
|||||||
|
|
||||||
if running_remotely():
|
if running_remotely():
|
||||||
# disable GUI backend - make headless
|
# disable GUI backend - make headless
|
||||||
sys.modules['matplotlib'].rcParams['backend'] = 'agg'
|
matplotlib.rcParams['backend'] = 'agg'
|
||||||
import matplotlib.pyplot
|
import matplotlib.pyplot
|
||||||
sys.modules['matplotlib'].pyplot.switch_backend('agg')
|
matplotlib.pyplot.switch_backend('agg')
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
from matplotlib import _pylab_helpers
|
from matplotlib import _pylab_helpers
|
||||||
if six.PY2:
|
if six.PY2:
|
||||||
PatchedMatplotlib._patched_original_plot = staticmethod(sys.modules['matplotlib'].pyplot.show)
|
PatchedMatplotlib._patched_original_plot = staticmethod(plt.show)
|
||||||
PatchedMatplotlib._patched_original_imshow = staticmethod(sys.modules['matplotlib'].pyplot.imshow)
|
PatchedMatplotlib._patched_original_imshow = staticmethod(plt.imshow)
|
||||||
else:
|
else:
|
||||||
PatchedMatplotlib._patched_original_plot = sys.modules['matplotlib'].pyplot.show
|
PatchedMatplotlib._patched_original_plot = plt.show
|
||||||
PatchedMatplotlib._patched_original_imshow = sys.modules['matplotlib'].pyplot.imshow
|
PatchedMatplotlib._patched_original_imshow = plt.imshow
|
||||||
sys.modules['matplotlib'].pyplot.show = PatchedMatplotlib.patched_show
|
plt.show = PatchedMatplotlib.patched_show
|
||||||
# sys.modules['matplotlib'].pyplot.imshow = PatchedMatplotlib.patched_imshow
|
# sys.modules['matplotlib'].pyplot.imshow = PatchedMatplotlib.patched_imshow
|
||||||
# patch plotly so we know it failed us.
|
# patch plotly so we know it failed us.
|
||||||
from plotly.matplotlylib import renderer
|
from plotly.matplotlylib import renderer
|
||||||
@ -71,7 +73,11 @@ class PatchedMatplotlib:
|
|||||||
from IPython import get_ipython
|
from IPython import get_ipython
|
||||||
ip = get_ipython()
|
ip = get_ipython()
|
||||||
if ip and matplotlib.is_interactive():
|
if ip and matplotlib.is_interactive():
|
||||||
ip.events.register('post_execute', PatchedMatplotlib.ipython_post_execute_hook)
|
# instead of hooking ipython, we should hook the matplotlib
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
PatchedMatplotlib.__patched_original_draw_all = plt.draw_all
|
||||||
|
plt.draw_all = PatchedMatplotlib.__patched_draw_all
|
||||||
|
# ip.events.register('post_execute', PatchedMatplotlib.ipython_post_execute_hook)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -188,6 +194,19 @@ class PatchedMatplotlib:
|
|||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def __patched_draw_all(*args, **kwargs):
|
||||||
|
recursion_guard = PatchedMatplotlib.__patched_draw_all_recursion_guard
|
||||||
|
if not recursion_guard:
|
||||||
|
PatchedMatplotlib.__patched_draw_all_recursion_guard = True
|
||||||
|
|
||||||
|
ret = PatchedMatplotlib.__patched_original_draw_all(*args, **kwargs)
|
||||||
|
|
||||||
|
if not recursion_guard:
|
||||||
|
PatchedMatplotlib.ipython_post_execute_hook()
|
||||||
|
PatchedMatplotlib.__patched_draw_all_recursion_guard = False
|
||||||
|
return ret
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def ipython_post_execute_hook():
|
def ipython_post_execute_hook():
|
||||||
# noinspection PyBroadException
|
# noinspection PyBroadException
|
@ -27,12 +27,13 @@ from .errors import UsageError
|
|||||||
from .logger import Logger
|
from .logger import Logger
|
||||||
from .model import InputModel, OutputModel, ARCHIVED_TAG
|
from .model import InputModel, OutputModel, ARCHIVED_TAG
|
||||||
from .task_parameters import TaskParameters
|
from .task_parameters import TaskParameters
|
||||||
from .utilities.absl_bind import PatchAbsl
|
from .binding.absl_bind import PatchAbsl
|
||||||
from .utilities.args import argparser_parseargs_called, get_argparser_last_args, \
|
from .utilities.args import argparser_parseargs_called, get_argparser_last_args, \
|
||||||
argparser_update_currenttask
|
argparser_update_currenttask
|
||||||
from .utilities.frameworks import PatchSummaryToEventTransformer, PatchTensorFlowEager, PatchKerasModelIO, \
|
from .binding.frameworks.pytorch_bind import PatchPyTorchModelIO
|
||||||
PatchTensorflowModelIO, PatchPyTorchModelIO
|
from .binding.frameworks.tensorflow_bind import PatchSummaryToEventTransformer, PatchTensorFlowEager, \
|
||||||
from .utilities.matplotlib_bind import PatchedMatplotlib
|
PatchKerasModelIO, PatchTensorflowModelIO
|
||||||
|
from .binding.matplotlib_bind import PatchedMatplotlib
|
||||||
from .utilities.seed import make_deterministic
|
from .utilities.seed import make_deterministic
|
||||||
|
|
||||||
NotSet = object()
|
NotSet = object()
|
||||||
|
Loading…
Reference in New Issue
Block a user