Mirror of https://github.com/clearml/clearml (synced 2025-04-21 14:54:23 +00:00)
Report TensorBoard text logging as debug samples (.txt files) instead of as console output.
parent 934771184d
commit 2f5b519cd8
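This change makes TensorBoard text summaries surface as .txt debug samples attached to the task, instead of being echoed into the console log. A minimal sketch of how the new path is exercised, assuming TensorFlow 2.x in eager mode and the current clearml package name (at the time of this commit the package was still called trains):

    import tensorflow as tf
    from clearml import Task

    # Task.init() installs the TensorBoard bindings patched in this commit
    task = Task.init(project_name='examples', task_name='tb text logging')

    writer = tf.summary.create_file_writer('./tb_logs')
    with writer.as_default():
        # the summary tensor carries 'text' plugin metadata; the patched writer
        # now routes it to EventTrainsWriter._add_text(), which uploads the
        # decoded string as a .txt debug sample instead of printing it
        tf.summary.text('run_notes', 'learning rate dropped to 1e-4', step=3)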
@@ -192,7 +192,8 @@ class UploadEvent(MetricsEventAdapter):
                  file_history_size=None, delete_after_upload=False, **kwargs):
         # param override_filename: override uploaded file name (notice extension will be added from local path
         # param override_filename_ext: override uploaded file extension
-        if image_data is not None and (not hasattr(image_data, 'shape') and not isinstance(image_data, six.BytesIO)):
+        if image_data is not None and (
+                not hasattr(image_data, 'shape') and not isinstance(image_data, (six.StringIO, six.BytesIO))):
             raise ValueError('Image must have a shape attribute')
         self._image_data = image_data
         self._local_image_path = local_image_path
@@ -263,7 +264,7 @@ class UploadEvent(MetricsEventAdapter):
         last_count = self._get_metric_count(self.metric, self.variant, next=False)
         if abs(self._count - last_count) > self._file_history_size:
             output = None
-        elif isinstance(self._image_data, six.BytesIO):
+        elif isinstance(self._image_data, (six.StringIO, six.BytesIO)):
             output = self._image_data
         elif self._image_data is not None:
             image_data = self._image_data
@@ -277,7 +277,7 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
         :type iter: int
         :param path: A path to an image file. Required unless matrix is provided.
         :type path: str
-        :param stream: File stream
+        :param stream: File/String stream
         :param file_extension: file extension to use when stream is passed
         :param max_history: maximum number of files to store per metric/variant combination
                         use negative value for unlimited. default is set in global configuration (default=5)
@@ -288,6 +288,9 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
             raise ValueError('Upload configuration is required (use setup_upload())')
         if len([x for x in (path, stream) if x is not None]) != 1:
             raise ValueError('Expected only one of [filename, stream]')
+        if isinstance(stream, six.string_types):
+            stream = six.StringIO(stream)
+
         kwargs = dict(metric=self._normalize_name(title), variant=self._normalize_name(series), iter=iter,
                       file_history_size=max_history)
         ev = MediaEvent(stream=stream, upload_uri=upload_uri, local_image_path=path,
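With the two hunks above, UploadEvent and the reporter accept six.StringIO wherever six.BytesIO was already accepted, so text payloads travel the existing file-upload path; a plain string passed as stream is wrapped in six.StringIO automatically before the MediaEvent is built.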
@@ -51,6 +51,7 @@ class IsTensorboardInit(object):
         return original_init(self, *args, **kwargs)


+# noinspection PyProtectedMember
 class WeightsGradientHistHelper(object):
     def __init__(self, logger, report_freq=100, histogram_update_freq_multiplier=10, histogram_granularity=50):
         self._logger = logger
@@ -138,6 +139,7 @@ class WeightsGradientHistHelper(object):
         if minmax is None:
             minmax = hist[:, 0].min(), hist[:, 0].max()
         else:
+            # noinspection PyUnresolvedReferences
             minmax = min(minmax[0], hist[:, 0].min()), max(minmax[1], hist[:, 0].max())

         # update the cache
@@ -185,6 +187,7 @@ class WeightsGradientHistHelper(object):
             camera=(-0.1, +1.3, 1.4))


+# noinspection PyMethodMayBeStatic,PyProtectedMember,SpellCheckingInspection
 class EventTrainsWriter(object):
     """
     TF SummaryWriter implementation that converts the tensorboard's summary into
@@ -347,6 +350,7 @@ class EventTrainsWriter(object):
             image = np.asarray(im)
             output.close()
             if height > 0 and width > 0:
+                # noinspection PyArgumentList
                 val = image.reshape(height, width, -1).astype(np.uint8)
             else:
                 val = image.astype(np.uint8)
@@ -369,6 +373,7 @@ class EventTrainsWriter(object):
             return val

     def _add_image_numpy(self, tag, step, img_data_np, max_keep_images=None):
+        # type: (str, int, np.ndarray, int) -> ()
         # only report images every specific interval
         if step % self.image_report_freq != 0:
             return None
@@ -390,6 +395,7 @@ class EventTrainsWriter(object):
         if img_data_np.ndim == 4:
             dims = img_data_np.shape
             stack_dim = int(np.sqrt(dims[0]))
+            # noinspection PyArgumentList
             res = img_data_np.reshape(stack_dim, stack_dim, *dims[1:]).transpose((0, 2, 1, 3, 4))
             tile_size = res.shape[0] * res.shape[1]
             img_data_np = res.reshape(tile_size, tile_size, -1)
@@ -555,6 +561,23 @@ class EventTrainsWriter(object):
             max_history=self.max_keep_images,
         )

+    def _add_text(self, tag, step, tensor_bytes):
+        # noinspection PyProtectedMember
+        title, series = self.tag_splitter(tag, num_split_parts=3, default_title='Text', logdir_header='title',
+                                          auto_reduce_num_split=True,
+                                          force_add_prefix=self._logger._get_tensorboard_series_prefix())
+        step = self._fix_step_counter(title, series, step)
+
+        text = tensor_bytes.decode('utf-8', errors='replace')
+        self._logger.report_media(
+            title=title,
+            series=series,
+            iteration=step,
+            stream=six.StringIO(text),
+            file_extension='.txt',
+            max_history=self.max_keep_images,
+        )
+
     @staticmethod
     def _fix_step_counter(title, series, step):
         key = (title, series)
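For reference, the new _add_text() above boils down to a public-API call; a hedged sketch of the equivalent explicit upload from user code (parameter names follow the Logger.report_media signature updated in the last hunk of this commit; clearml package naming assumed):

    import six
    from clearml import Task

    logger = Task.init(project_name='examples', task_name='manual text sample').get_logger()

    # upload a text blob as a .txt debug sample, mirroring _add_text()
    logger.report_media(
        title='Text',
        series='run_notes',
        iteration=3,
        stream=six.StringIO('learning rate dropped to 1e-4'),
        file_extension='.txt',
        max_history=5,
    )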
@@ -573,7 +596,7 @@ class EventTrainsWriter(object):
         wraparound_counter['last_step'] = step
         return step + wraparound_counter['adjust_counter']

-    def add_event(self, event, step=None, walltime=None, **kwargs):
+    def add_event(self, event, step=None, walltime=None, **_):
         supported_metrics = {
             'simpleValue', 'image', 'histo', 'tensor', 'audio'
         }
@@ -603,6 +626,7 @@ class EventTrainsWriter(object):
                 'event summary not found, message type unsupported: %s' % keys_list)
             return
         value_dicts = summary.get('value')
+        # noinspection PyUnusedLocal
         walltime = walltime or msg_dict.get('step')
         step = step or msg_dict.get('step')
         if step is None:
@@ -646,9 +670,8 @@ class EventTrainsWriter(object):
                     self._generic_tensor_type_name_lookup[tag] = plugin_type
                     self._add_audio(tag, step, None, tensor_bytes)
                 elif plugin_type == 'text':
-                    # text, just print to console
-                    text = tensor_bytes.decode('utf-8', errors='replace')
-                    self._logger.report_text(msg='SUMMARY LOG: {} {}'.format(tag, text), print_console=False)
+                    self._generic_tensor_type_name_lookup[tag] = plugin_type
+                    self._add_text(tag, step, tensor_bytes)
                 else:
                     # we do not support it
                     pass
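Previously the 'text' plugin branch decoded the tensor bytes and pushed them through report_text(..., print_console=False), i.e. into the task's console log; it now records the plugin type (as the audio branch does) and defers to _add_text(), which uploads the payload as a .txt debug sample.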
@@ -700,6 +723,7 @@ class EventTrainsWriter(object):
         # ~/torch/utils/tensorboard/summary.py
         def _clean_tag(name):
             import re as _re
+            # noinspection RegExpRedundantEscape
             _INVALID_TAG_CHARACTERS = _re.compile(r'[^-/\w\.]')
             if name is not None:
                 new_name = _INVALID_TAG_CHARACTERS.sub('_', name)
@@ -711,6 +735,7 @@ class EventTrainsWriter(object):
             return name

         main_path = self._logdir
+        # noinspection PyBroadException
         try:
             main_path = _clean_tag(main_path)
             origin_tag = main_path.rpartition("/")[2].replace(title_prefix, "", 1)
@@ -723,6 +748,7 @@ class EventTrainsWriter(object):
         return origin_tag


+# noinspection PyCallingNonCallable
 class ProxyEventsWriter(object):
     def __init__(self, events):
         IsTensorboardInit.set_tensorboard_used()
@@ -771,6 +797,7 @@ class ProxyEventsWriter(object):
         return ret


+# noinspection PyPep8Naming
 class PatchSummaryToEventTransformer(object):
     __main_task = None
     __original_getattribute = None
@@ -785,6 +812,7 @@ class PatchSummaryToEventTransformer(object):
     @staticmethod
     def trains_object(self):
         if isinstance(self.event_writer, ProxyEventsWriter):
+            # noinspection PyProtectedMember
             trains_writer = [e for e in self.event_writer._events if isinstance(e, EventTrainsWriter)]
             return trains_writer[0] if trains_writer else None
         elif isinstance(self.event_writer, EventTrainsWriter):
@@ -824,6 +852,7 @@ class PatchSummaryToEventTransformer(object):
         try:
             # only patch once
             if PatchSummaryToEventTransformer._original_add_eventT is None:
+                # noinspection PyUnresolvedReferences
                 from torch.utils.tensorboard.writer import FileWriter as FileWriterT
                 PatchSummaryToEventTransformer._original_add_eventT = FileWriterT.add_event
                 FileWriterT.add_event = PatchSummaryToEventTransformer._patched_add_eventT
@@ -838,6 +867,7 @@ class PatchSummaryToEventTransformer(object):
         try:
             # only patch once
             if PatchSummaryToEventTransformer.__original_getattributeX is None:
+                # noinspection PyUnresolvedReferences
                 from tensorboardX.writer import SummaryToEventTransformer as SummaryToEventTransformerX
                 PatchSummaryToEventTransformer.__original_getattributeX = \
                     SummaryToEventTransformerX.__getattribute__
@@ -869,6 +899,7 @@ class PatchSummaryToEventTransformer(object):
         if not hasattr(self, 'trains') or not PatchSummaryToEventTransformer.__main_task:
             return PatchSummaryToEventTransformer._original_add_eventT(self, *args, **kwargs)
         if not self.trains:
+            # noinspection PyBroadException
             try:
                 logdir = self.get_logdir()
             except Exception:
@@ -887,6 +918,7 @@ class PatchSummaryToEventTransformer(object):
         if not hasattr(self, 'trains') or not PatchSummaryToEventTransformer.__main_task:
             return PatchSummaryToEventTransformer._original_add_eventX(self, *args, **kwargs)
         if not self.trains:
+            # noinspection PyBroadException
             try:
                 logdir = self.get_logdir()
             except Exception:
@@ -924,6 +956,7 @@ class PatchSummaryToEventTransformer(object):

         # patch the events writer field, and add a double Event Logger (Trains and original)
         base_eventwriter = __dict__['event_writer']
+        # noinspection PyBroadException
         try:
             logdir = base_eventwriter.get_logdir()
         except Exception:
@@ -984,6 +1017,7 @@ class PatchModelCheckPointCallback(object):
     @staticmethod
     def trains_object(self):
         if isinstance(self.model, _ModelAdapter):
+            # noinspection PyProtectedMember
             return self.model._output_model
         if not self.__dict__.get('_trains_defaults'):
             self.__dict__['_trains_defaults'] = {}
@@ -1067,6 +1101,7 @@ class PatchModelCheckPointCallback(object):
         return get_base(self, attr)


+# noinspection PyProtectedMember,PyUnresolvedReferences
 class PatchTensorFlowEager(object):
     __main_task = None
     __original_fn_scalar = None
@@ -1171,6 +1206,11 @@ class PatchTensorFlowEager(object):
                     tag=str(tag), step=int(step.numpy()) if not isinstance(step, int) else step,
                     hist_data=tensor.numpy()
                 )
+            elif plugin_type.endswith('text'):
+                event_writer._add_text(
+                    tag=str(tag), step=int(step.numpy()) if not isinstance(step, int) else step,
+                    tensor_bytes=tensor.numpy()
+                )
             elif 'audio' in plugin_type:
                 audio_bytes_list = [a for a in tensor.numpy().flatten() if a]
                 for i, audio_bytes in enumerate(audio_bytes_list):
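This extends the same routing to the TF 2.x eager-mode path: text summaries arriving through the patched eager summary op are forwarded to _add_text() with the raw tensor bytes, the same entry point used by the event-file path above.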
@@ -1241,13 +1281,14 @@ class PatchTensorFlowEager(object):
                                          max_keep_images=kwargs.get('max_images'))


+# noinspection PyPep8Naming,SpellCheckingInspection
 class PatchKerasModelIO(object):
     __main_task = None
     __patched_keras = None
     __patched_tensorflow = None

     @staticmethod
-    def update_current_task(task, **kwargs):
+    def update_current_task(task, **_):
         PatchKerasModelIO.__main_task = task
         PatchKerasModelIO._patch_model_checkpoint()
         PostImportHookPatching.add_on_import('tensorflow', PatchKerasModelIO._patch_model_checkpoint)
@@ -1314,6 +1355,7 @@ class PatchKerasModelIO(object):
             Sequential._updated_config = _patched_call(Sequential._updated_config,
                                                        PatchKerasModelIO._updated_config)
             if hasattr(Sequential.from_config, '__func__'):
+                # noinspection PyUnresolvedReferences
                 Sequential.from_config = classmethod(_patched_call(Sequential.from_config.__func__,
                                                                    PatchKerasModelIO._from_config))
             else:
@@ -1322,6 +1364,7 @@ class PatchKerasModelIO(object):
         if Network is not None:
             Network._updated_config = _patched_call(Network._updated_config, PatchKerasModelIO._updated_config)
             if hasattr(Sequential.from_config, '__func__'):
+                # noinspection PyUnresolvedReferences
                 Network.from_config = classmethod(_patched_call(Network.from_config.__func__,
                                                                 PatchKerasModelIO._from_config))
             else:
@@ -1539,7 +1582,7 @@ class PatchTensorflowModelIO(object):
     __patched = None

     @staticmethod
-    def update_current_task(task, **kwargs):
+    def update_current_task(task, **_):
         PatchTensorflowModelIO.__main_task = task
         PatchTensorflowModelIO._patch_model_checkpoint()
         PostImportHookPatching.add_on_import('tensorflow', PatchTensorflowModelIO._patch_model_checkpoint)
@@ -1557,6 +1600,7 @@ class PatchTensorflowModelIO(object):
         try:
             # hack: make sure tensorflow.__init__ is called
             import tensorflow
+            # noinspection PyUnresolvedReferences
             from tensorflow.python.training.saver import Saver
             # noinspection PyBroadException
             try:
@@ -1605,6 +1649,7 @@ class PatchTensorflowModelIO(object):
             import tensorflow  # noqa: F811
             # actual import
             from tensorflow.saved_model import load  # noqa: F401
+            # noinspection PyUnresolvedReferences
             import tensorflow.saved_model as saved_model_load
             saved_model_load.load = _patched_call(saved_model_load.load, PatchTensorflowModelIO._load)
         except ImportError:
@@ -1617,6 +1662,7 @@ class PatchTensorflowModelIO(object):
             # make sure we import the correct version of save
             import tensorflow  # noqa: F811
             # actual import
+            # noinspection PyUnresolvedReferences
             from tensorflow.saved_model import loader as loader1
             loader1.load = _patched_call(loader1.load, PatchTensorflowModelIO._load)
         except ImportError:
@@ -1629,6 +1675,7 @@ class PatchTensorflowModelIO(object):
             # make sure we import the correct version of save
             import tensorflow  # noqa: F811
             # actual import
+            # noinspection PyUnresolvedReferences
             from tensorflow.compat.v1.saved_model import loader as loader2
             loader2.load = _patched_call(loader2.load, PatchTensorflowModelIO._load)
         except ImportError:
@@ -1772,7 +1819,7 @@ class PatchTensorflow2ModelIO(object):
     __patched = None

     @staticmethod
-    def update_current_task(task, **kwargs):
+    def update_current_task(task, **_):
         PatchTensorflow2ModelIO.__main_task = task
         PatchTensorflow2ModelIO._patch_model_checkpoint()
         PostImportHookPatching.add_on_import('tensorflow', PatchTensorflow2ModelIO._patch_model_checkpoint)
@@ -1812,6 +1859,7 @@ class PatchTensorflow2ModelIO(object):
     def _save(original_fn, self, file_prefix, *args, **kwargs):
         model = original_fn(self, file_prefix, *args, **kwargs)
         # store output Model
+        # noinspection PyBroadException
         try:
             WeightsFileHandler.create_output_model(self, file_prefix, Framework.tensorflow,
                                                    PatchTensorflow2ModelIO.__main_task)
@@ -1827,6 +1875,7 @@ class PatchTensorflow2ModelIO(object):
         # Hack: disabled
         if False and running_remotely():
             # register/load model weights
+            # noinspection PyBroadException
             try:
                 save_path = WeightsFileHandler.restore_weights_file(self, save_path, Framework.tensorflow,
                                                                     PatchTensorflow2ModelIO.__main_task)
@@ -1838,6 +1887,7 @@ class PatchTensorflow2ModelIO(object):
         # load model, if something is wrong, exception will be raised before we register the input model
         model = original_fn(self, save_path, *args, **kwargs)
         # register/load model weights
+        # noinspection PyBroadException
         try:
             WeightsFileHandler.restore_weights_file(self, save_path, Framework.tensorflow,
                                                     PatchTensorflow2ModelIO.__main_task)
@@ -828,7 +828,7 @@ class Logger(object):
             series,  # type: str
             iteration,  # type: int
             local_path=None,  # type: Optional[str]
-            stream=None,  # type: Optional[six.BytesIO]
+            stream=None,  # type: Optional[Union[six.BytesIO, six.StringIO]]
             file_extension=None,  # type: Optional[str]
             max_history=None,  # type: Optional[int]
             delete_after_upload=False,  # type: bool
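Net effect: text summaries now land alongside images and audio as per-iteration debug samples, one .txt file per metric/variant pair subject to max_history, rather than accumulating in the console output.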