Mirror of https://github.com/clearml/clearml, synced 2025-04-10 07:26:03 +00:00

Restructured Logger with nice clean interface.
Breaking changes: Logger no longer supports the info/error/warning helpers; console() is replaced with report_text().

This commit is contained in:
parent 1dbe962879
commit 0b4f00af4d
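
For readers migrating, the breaking change in one sketch (assumes an initialized trains Task; the old per-level helpers such as info/warn/error are gone, and console() is now report_text()):

    from trains import Task

    task = Task.init(project_name="examples", task_name="logger demo")
    logger = task.get_logger()

    # before this commit:
    #   logger.console("hello")            # print to console and send to the log
    #   logger.warn("careful")             # removed together with info/error/etc.

    # after this commit:
    logger.report_text("hello")                       # logged and printed
    logger.report_text("quiet", print_console=False)  # logged only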
@@ -21,7 +21,7 @@ except ImportError:
 logger = Task.current_task().get_logger()

 # log text
-logger.console("hello")
+logger.report_text("hello")

 # report scalar values
 logger.report_scalar("example_scalar", "series A", iteration=0, value=100)
@@ -49,11 +49,11 @@ logger.report_scatter3d("example_scatter_3d", "series_xyz", iteration=1, scatter

 # reporting images
 m = np.eye(256, 256, dtype=np.float)
-logger.report_image_and_upload("test case", "image float", iteration=1, matrix=m)
+logger.report_image("test case", "image float", iteration=1, matrix=m)
 m = np.eye(256, 256, dtype=np.uint8)*255
-logger.report_image_and_upload("test case", "image uint8", iteration=1, matrix=m)
+logger.report_image("test case", "image uint8", iteration=1, matrix=m)
 m = np.concatenate((np.atleast_3d(m), np.zeros((256, 256, 2), dtype=np.uint8)), axis=2)
-logger.report_image_and_upload("test case", "image color red", iteration=1, matrix=m)
+logger.report_image("test case", "image color red", iteration=1, matrix=m)

 # flush reports (otherwise it will be flushed in the background, every couple of seconds)
 logger.flush()
@@ -81,7 +81,7 @@ class CallResult(object):
             # response.validate()
         except Exception as e:
             if logger:
-                logger.warn('Failed parsing response: %s' % str(e))
+                logger.warning('Failed parsing response: %s' % str(e))
         return cls(meta=meta, response=response, response_data=response_data, request_cls=request_cls, session=session)

     def ok(self):
@@ -215,7 +215,7 @@ class Session(TokenManager):
                 res.status_code == requests.codes.service_unavailable
                 and self.config.get("api.http.wait_on_maintenance_forever", True)
             ):
-                self._logger.warn(
+                self._logger.warning(
                     "Service unavailable: {} is undergoing maintenance, retrying...".format(
                         host
                     )
@@ -50,7 +50,7 @@ class S3BucketConfig(object):
         configs = [cls(**entry) for entry in dict_list]
         valid_configs = [conf for conf in configs if conf.is_valid()]
         if log and len(valid_configs) < len(configs):
-            log.warn(
+            log.warning(
                 "Invalid bucket configurations detected for {}".format(
                     ", ".join(
                         "/".join((config.host, config.bucket))
trains/backend_interface/logger.py (new file, 212 lines)
@@ -0,0 +1,212 @@
import logging
import sys
import threading

from ..backend_interface.task.development.worker import DevWorker
from ..backend_interface.task.log import TaskHandler
from ..config import running_remotely


class StdStreamPatch(object):
    _stdout_proxy = None
    _stderr_proxy = None
    _stdout_original_write = None

    @staticmethod
    def patch_std_streams(logger):
        if DevWorker.report_stdout and not PrintPatchLogger.patched and not running_remotely():
            StdStreamPatch._stdout_proxy = PrintPatchLogger(sys.stdout, logger, level=logging.INFO)
            StdStreamPatch._stderr_proxy = PrintPatchLogger(sys.stderr, logger, level=logging.ERROR)
            logger._task_handler = TaskHandler(logger._task.session, logger._task.id, capacity=100)
            # noinspection PyBroadException
            try:
                if StdStreamPatch._stdout_original_write is None:
                    StdStreamPatch._stdout_original_write = sys.stdout.write
                # this will only work in python 3, guard it with try/catch
                if not hasattr(sys.stdout, '_original_write'):
                    sys.stdout._original_write = sys.stdout.write
                sys.stdout.write = StdStreamPatch._stdout__patched__write__
                if not hasattr(sys.stderr, '_original_write'):
                    sys.stderr._original_write = sys.stderr.write
                sys.stderr.write = StdStreamPatch._stderr__patched__write__
            except Exception:
                pass
            sys.stdout = StdStreamPatch._stdout_proxy
            sys.stderr = StdStreamPatch._stderr_proxy
            # patch the base streams of sys (this way colorama will keep its ANSI colors)
            # noinspection PyBroadException
            try:
                sys.__stderr__ = sys.stderr
            except Exception:
                pass
            # noinspection PyBroadException
            try:
                sys.__stdout__ = sys.stdout
            except Exception:
                pass

            # now check if we have loguru and make it re-register the handlers,
            # because it stores the stream.write function internally, which we can't patch
            # noinspection PyBroadException
            try:
                from loguru import logger
                register_stderr = None
                register_stdout = None
                for k, v in logger._handlers.items():
                    if v._name == '<stderr>':
                        register_stderr = k
                    elif v._name == '<stdout>':
                        register_stdout = k
                if register_stderr is not None:
                    logger.remove(register_stderr)
                    logger.add(sys.stderr)
                if register_stdout is not None:
                    logger.remove(register_stdout)
                    logger.add(sys.stdout)
            except Exception:
                pass

        elif DevWorker.report_stdout and not running_remotely():
            logger._task_handler = TaskHandler(logger._task.session, logger._task.id, capacity=100)
            if StdStreamPatch._stdout_proxy:
                StdStreamPatch._stdout_proxy.connect(logger)
            if StdStreamPatch._stderr_proxy:
                StdStreamPatch._stderr_proxy.connect(logger)

    @staticmethod
    def remove_std_logger():
        if isinstance(sys.stdout, PrintPatchLogger):
            # noinspection PyBroadException
            try:
                sys.stdout.connect(None)
            except Exception:
                pass
        if isinstance(sys.stderr, PrintPatchLogger):
            # noinspection PyBroadException
            try:
                sys.stderr.connect(None)
            except Exception:
                pass

    @staticmethod
    def stdout_original_write(*args, **kwargs):
        if StdStreamPatch._stdout_original_write:
            StdStreamPatch._stdout_original_write(*args, **kwargs)

    @staticmethod
    def _stdout__patched__write__(*args, **kwargs):
        if StdStreamPatch._stdout_proxy:
            return StdStreamPatch._stdout_proxy.write(*args, **kwargs)
        return sys.stdout._original_write(*args, **kwargs)

    @staticmethod
    def _stderr__patched__write__(*args, **kwargs):
        if StdStreamPatch._stderr_proxy:
            return StdStreamPatch._stderr_proxy.write(*args, **kwargs)
        return sys.stderr._original_write(*args, **kwargs)


class PrintPatchLogger(object):
    """
    Allows patching a stream into the logger.
    Used for capturing and logging stdout and stderr when running in development mode pseudo worker.
    """
    patched = False
    lock = threading.Lock()
    recursion_protect_lock = threading.RLock()

    def __init__(self, stream, logger=None, level=logging.INFO):
        PrintPatchLogger.patched = True
        self._terminal = stream
        self._log = logger
        self._log_level = level
        self._cur_line = ''

    def write(self, message):
        # make sure that we do not end up in infinite loop (i.e. log._console ends up calling us)
        if self._log and not PrintPatchLogger.recursion_protect_lock._is_owned():
            try:
                self.lock.acquire()
                with PrintPatchLogger.recursion_protect_lock:
                    if hasattr(self._terminal, '_original_write'):
                        self._terminal._original_write(message)
                    else:
                        self._terminal.write(message)

                do_flush = '\n' in message
                do_cr = '\r' in message
                self._cur_line += message
                if (not do_flush and not do_cr) or not message:
                    return
                last_lf = self._cur_line.rindex('\n' if do_flush else '\r')
                next_line = self._cur_line[last_lf + 1:]
                cur_line = self._cur_line[:last_lf + 1].rstrip()
                self._cur_line = next_line
            finally:
                self.lock.release()

            if cur_line:
                with PrintPatchLogger.recursion_protect_lock:
                    # noinspection PyBroadException
                    try:
                        if self._log:
                            self._log._console(cur_line, level=self._log_level, omit_console=True)
                    except Exception:
                        # what can we do, nothing
                        pass
        else:
            if hasattr(self._terminal, '_original_write'):
                self._terminal._original_write(message)
            else:
                self._terminal.write(message)

    def connect(self, logger):
        self._cur_line = ''
        self._log = logger

    def __getattr__(self, attr):
        if attr in ['_log', '_terminal', '_log_level', '_cur_line']:
            return self.__dict__.get(attr)
        return getattr(self._terminal, attr)

    def __setattr__(self, key, value):
        if key in ['_log', '_terminal', '_log_level', '_cur_line']:
            self.__dict__[key] = value
        else:
            return setattr(self._terminal, key, value)


class LogFlusher(threading.Thread):
    def __init__(self, logger, period, **kwargs):
        super(LogFlusher, self).__init__(**kwargs)
        self.daemon = True

        self._period = period
        self._logger = logger
        self._exit_event = threading.Event()

    @property
    def period(self):
        return self._period

    def run(self):
        self._logger.flush()
        # store original wait period
        while True:
            period = self._period
            while not self._exit_event.wait(period or 1.0):
                self._logger.flush()
            # if period is negative or None, we should exit
            if self._period is None or self._period < 0:
                break
            # if period was changed, we should restart
            self._exit_event.clear()

    def exit(self):
        self._period = None
        self._exit_event.set()

    def set_period(self, period):
        self._period = period
        # make sure we exit the previous wait
        self._exit_event.set()
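
A minimal sketch of how LogFlusher is driven (the dummy logger below is illustrative only; inside trains, Logger.__init__ calls StdStreamPatch.patch_std_streams(self) and set_flush_period() owns the LogFlusher, as shown later in this diff):

    import time

    class _DummyLogger(object):
        def flush(self):
            print("flushed")

    # flush every 2 seconds on a daemon thread until exit() is called
    flusher = LogFlusher(_DummyLogger(), period=2.0)
    flusher.start()
    time.sleep(5)
    flusher.set_period(1.0)   # wakes the wait and restarts with the new period
    flusher.exit()
    flusher.join()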
@@ -177,6 +177,8 @@ class UploadEvent(MetricsEventAdapter):

     def __init__(self, metric, variant, image_data, local_image_path=None, iter=0, upload_uri=None,
                  image_file_history_size=None, delete_after_upload=False, **kwargs):
+        # param override_filename: override uploaded file name (notice extension will be added from local path
+        # param override_filename_ext: override uploaded file extension
         if image_data is not None and not hasattr(image_data, 'shape'):
             raise ValueError('Image must have a shape attribute')
         self._image_data = image_data
@@ -197,8 +199,10 @@ class UploadEvent(MetricsEventAdapter):

         # get upload uri upfront, either predefined image format or local file extension
         # e.g.: image.png -> .png or image.raw.gz -> .raw.gz
-        image_format = self._format.lower() if self._image_data is not None else \
-            '.' + '.'.join(pathlib2.Path(self._local_image_path).parts[-1].split('.')[1:])
+        image_format = kwargs.pop('override_filename_ext', None)
+        if image_format is None:
+            image_format = self._format.lower() if self._image_data is not None else \
+                '.' + '.'.join(pathlib2.Path(self._local_image_path).parts[-1].split('.')[1:])
         self._upload_filename = str(pathlib2.Path(self._filename).with_suffix(image_format))

         self._override_storage_key_prefix = kwargs.pop('override_storage_key_prefix', None)
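
Both new kwargs are consumed with kwargs.pop(), so existing callers are unaffected. A hypothetical call, for illustration only (UploadEvent is internal, and its call sites are not part of this diff):

    # force the uploaded file extension instead of deriving it:
    event = UploadEvent('test case', 'image float', image_data=m, iter=1,
                        override_filename_ext='.jpeg')
    # without override_filename_ext, the extension falls back to self._format
    # (for in-memory data) or to the local file's own extension.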
@@ -80,6 +80,6 @@ class AccessMixin(object):
             expected_num_of_classes += 1 if int(index) > 0 else 0
         num_of_classes = int(max(model_labels.values()))
         if num_of_classes != expected_num_of_classes:
-            self.log.warn('The highest label index is %d, while there are %d non-bg labels' %
-                          (num_of_classes, expected_num_of_classes))
+            self.log.warning('The highest label index is %d, while there are %d non-bg labels' %
+                             (num_of_classes, expected_num_of_classes))
         return num_of_classes + 1  # +1 is meant for bg!
@@ -51,7 +51,7 @@ class TaskStopSignal(object):
             if self._task_reset_state_counter >= self._number_of_consecutive_reset_tests:
                 return TaskStopReason.reset

-            self.task.get_logger().warning(
+            self.task.log.warning(
                 "Task {} was reset! if state is consistent we shall terminate.".format(self.task.id),
             )
         else:
@@ -160,7 +160,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
                 conf = get_config_for_bucket(base_url=output_dest)
                 if not conf:
                     msg = 'Failed resolving output destination (no credentials found for %s)' % output_dest
-                    self.log.warn(msg)
+                    self.log.warning(msg)
                     if raise_errors:
                         raise Exception(msg)
                 else:
@@ -187,12 +187,12 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
             latest_version = CheckPackageUpdates.check_new_package_available(only_once=True)
             if latest_version:
                 if not latest_version[1]:
-                    self.get_logger().console(
+                    self.get_logger().report_text(
                         'TRAINS new package available: UPGRADE to v{} is recommended!'.format(
                             latest_version[0]),
                     )
                 else:
-                    self.get_logger().console(
+                    self.get_logger().report_text(
                         'TRAINS-SERVER new version available: upgrade to v{} is recommended!'.format(
                             latest_version[0]),
                     )
@@ -205,7 +205,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
             check_package_update_thread.start()
         result = ScriptInfo.get(log=self.log)
         for msg in result.warning_messages:
-            self.get_logger().console(msg)
+            self.get_logger().report_text(msg)

         self.data.script = result.script
         # Since we might run asynchronously, don't use self.data (lest someone else
@@ -418,16 +418,17 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):

     def update_model_desc(self, new_model_desc_file=None):
         """ Change the task's model_desc """
-        execution = self._get_task_property('execution')
-        p = Path(new_model_desc_file)
-        if not p.is_file():
-            raise IOError('mode_desc file %s cannot be found' % new_model_desc_file)
-        new_model_desc = p.read_text()
-        model_desc_key = list(execution.model_desc.keys())[0] if execution.model_desc else 'design'
-        execution.model_desc[model_desc_key] = new_model_desc
+        with self._edit_lock:
+            execution = self._get_task_property('execution')
+            p = Path(new_model_desc_file)
+            if not p.is_file():
+                raise IOError('mode_desc file %s cannot be found' % new_model_desc_file)
+            new_model_desc = p.read_text()
+            model_desc_key = list(execution.model_desc.keys())[0] if execution.model_desc else 'design'
+            execution.model_desc[model_desc_key] = new_model_desc

-        res = self._edit(execution=execution)
-        return res.response
+            res = self._edit(execution=execution)
+            return res.response

     def update_output_model(self, model_uri, name=None, comment=None, tags=None):
         """
@@ -536,16 +537,17 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
             model = None
             model_id = ''

-        # store model id
-        self.data.execution.model = model_id
+        with self._edit_lock:
+            # store model id
+            self.data.execution.model = model_id

-        # Auto populate input field from model, if they are empty
-        if update_task_design and not self.data.execution.model_desc:
-            self.data.execution.model_desc = model.design if model else ''
-        if update_task_labels and not self.data.execution.model_labels:
-            self.data.execution.model_labels = model.labels if model else {}
+            # Auto populate input field from model, if they are empty
+            if update_task_design and not self.data.execution.model_desc:
+                self.data.execution.model_desc = model.design if model else ''
+            if update_task_labels and not self.data.execution.model_labels:
+                self.data.execution.model_labels = model.labels if model else {}

-        self._edit(execution=self.data.execution)
+            self._edit(execution=self.data.execution)

     def set_parameters(self, *args, **kwargs):
         """
@@ -580,12 +582,13 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
         # force cast all variables to strings (so that we can later edit them in UI)
         parameters = {k: str(v) if v is not None else "" for k, v in parameters.items()}

-        execution = self.data.execution
-        if execution is None:
-            execution = tasks.Execution(parameters=parameters)
-        else:
-            execution.parameters = parameters
-        self._edit(execution=execution)
+        with self._edit_lock:
+            execution = self.data.execution
+            if execution is None:
+                execution = tasks.Execution(parameters=parameters)
+            else:
+                execution.parameters = parameters
+            self._edit(execution=execution)

     def set_parameter(self, name, value, description=None):
         """
@@ -630,14 +633,15 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
         :param dict enumeration: For example: {str(label): integer(id)}
         """
         enumeration = enumeration or {}
-        execution = self.data.execution
-        if enumeration is None:
-            return
-        if not (isinstance(enumeration, dict)
-                and all(isinstance(k, six.string_types) and isinstance(v, int) for k, v in enumeration.items())):
-            raise ValueError('Expected label to be a dict[str => int]')
-        execution.model_labels = enumeration
-        self._edit(execution=execution)
+        with self._edit_lock:
+            execution = self.data.execution
+            if enumeration is None:
+                return
+            if not (isinstance(enumeration, dict)
+                    and all(isinstance(k, six.string_types) and isinstance(v, int) for k, v in enumeration.items())):
+                raise ValueError('Expected label to be a dict[str => int]')
+            execution.model_labels = enumeration
+            self._edit(execution=execution)

     def set_artifacts(self, artifacts_list=None):
         """
@@ -650,16 +654,18 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
         if not (isinstance(artifacts_list, (list, tuple))
                 and all(isinstance(a, tasks.Artifact) for a in artifacts_list)):
             raise ValueError('Expected artifacts to [tasks.Artifacts]')
-        execution = self.data.execution
-        execution.artifacts = artifacts_list
-        self._edit(execution=execution)
+        with self._edit_lock:
+            execution = self.data.execution
+            execution.artifacts = artifacts_list
+            self._edit(execution=execution)

     def _set_model_design(self, design=None):
-        execution = self.data.execution
-        if design is not None:
-            execution.model_desc = Model._wrap_design(design)
+        with self._edit_lock:
+            execution = self.data.execution
+            if design is not None:
+                execution.model_desc = Model._wrap_design(design)

-        self._edit(execution=execution)
+            self._edit(execution=execution)

     def get_labels_enumeration(self):
         """
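
The hunks above all apply one pattern: every read-modify-write of task.data.execution now runs under self._edit_lock, so a concurrent caller cannot slip in between reading the execution object and pushing it back with _edit(). Schematically (the recurring shape, not an additional API):

    def _update_execution(self, mutate):
        with self._edit_lock:                # serialize read-modify-write
            execution = self.data.execution
            mutate(execution)                # e.g. execution.parameters = {...}
            self._edit(execution=execution)  # push the edit to the backend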
@@ -36,8 +36,8 @@ def get_single_result(entity, query, results, log=None, show_results=10, raise_o
         log = get_logger()

     if len(results) > 1:
-        log.warn('More than one {entity} found when searching for `{query}`'
-                 ' (showing first {show_results} {entity}s follow)'.format(**locals()))
+        log.warning('More than one {entity} found when searching for `{query}`'
+                    ' (showing first {show_results} {entity}s follow)'.format(**locals()))
         if sort_by_date:
             # sort results based on timestamp and return the newest one
             if hasattr(results[0], 'last_update'):
@@ -49,7 +49,7 @@ def get_single_result(entity, query, results, log=None, show_results=10, raise_o

         for i, obj in enumerate(o if isinstance(o, dict) else o.to_dict() for o in results[:show_results]):
             selected = 'Selected' if i == 0 else 'Additionally found'
-            log.warn('{selected} {entity} `{obj[name]}` (id={obj[id]})'.format(**locals()))
+            log.warning('{selected} {entity} `{obj[name]}` (id={obj[id]})'.format(**locals()))

         if raise_on_error:
             raise ValueError('More than one {entity}s found when searching for ``{query}`'.format(**locals()))
@@ -5,13 +5,13 @@ import threading
 from collections import defaultdict
 from functools import partial
 from io import BytesIO
 from logging import ERROR, WARNING, getLogger
 from typing import Any

 import numpy as np
 from PIL import Image

-from ..frameworks import _patched_call, WeightsFileHandler, _Empty, TrainsFrameworkAdapter
+from ...debugging.log import LoggerRoot
+from ..frameworks import _patched_call, WeightsFileHandler, _Empty
 from ..import_bind import PostImportHookPatching
 from ...config import running_remotely
 from ...model import InputModel, OutputModel, Framework
@@ -187,7 +187,8 @@ class EventTrainsWriter(object):
                 else:
                     val = val[:, :, [0, 1, 2]]
             except Exception:
-                self._logger.warning('Failed decoding debug image [%d, %d, %d]' % (width, height, color_channels))
+                LoggerRoot.get_base_logger().warning('Failed decoding debug image [%d, %d, %d]'
+                                                     % (width, height, color_channels))
                 val = None
             return val

@@ -213,7 +214,7 @@ class EventTrainsWriter(object):
             tile_size = res.shape[0] * res.shape[1]
             img_data_np = res.reshape(tile_size, tile_size, -1)

-        self._logger.report_image_and_upload(
+        self._logger.report_image(
             title=title,
             series=series,
             iteration=step,
@@ -419,7 +420,7 @@ class EventTrainsWriter(object):
             msg_dict.pop('wallTime', None)
             keys_list = [key for key in msg_dict.keys() if len(key) > 0]
             keys_list = ', '.join(keys_list)
-            self._logger.debug('event summary not found, message type unsupported: %s' % keys_list)
+            LoggerRoot.get_base_logger().debug('event summary not found, message type unsupported: %s' % keys_list)
             return
         value_dicts = summary.get('value')
         walltime = walltime or msg_dict.get('step')
@@ -431,19 +432,20 @@ class EventTrainsWriter(object):
                 step = int(event.step)
             else:
                 step = 0
-                self._logger.debug('Recieved event without step, assuming step = {}'.format(step), WARNING)
+                LoggerRoot.get_base_logger().debug('Received event without step, assuming step = {}'.format(step))
         else:
             step = int(step)
         self._max_step = max(self._max_step, step)
         if value_dicts is None:
-            self._logger.debug("Summary with arrived without 'value'", ERROR)
+            LoggerRoot.get_base_logger().debug("Summary arrived without 'value'")
             return

         for vdict in value_dicts:
             tag = vdict.pop('tag', None)
             if tag is None:
                 # we should not get here
-                self._logger.debug('No tag for \'value\' existing keys %s' % ', '.join(vdict.keys()))
+                LoggerRoot.get_base_logger().debug('No tag for \'value\' existing keys %s'
+                                                   % ', '.join(vdict.keys()))
                 continue
             metric, values = get_data(vdict, supported_metrics)
             if metric == 'simpleValue':
@@ -459,7 +461,8 @@ class EventTrainsWriter(object):
             elif metric == 'tensor' and values.get('dtype') == 'DT_FLOAT':
                 self._add_plot(tag, step, values, vdict)
             else:
-                self._logger.debug('Event unsupported. tag = %s, vdict keys [%s]' % (tag, ', '.join(vdict.keys)))
+                LoggerRoot.get_base_logger().debug('Event unsupported. tag = %s, vdict keys [%s]'
+                                                   % (tag, ', '.join(vdict.keys)))
                 continue

     def get_logdir(self):
@@ -589,7 +592,7 @@ class PatchSummaryToEventTransformer(object):
                 setattr(SummaryToEventTransformer, 'trains',
                         property(PatchSummaryToEventTransformer.trains_object))
             except Exception as ex:
-                getLogger(TrainsFrameworkAdapter).debug(str(ex))
+                LoggerRoot.get_base_logger().debug(str(ex))

         if 'torch' in sys.modules:
             try:
@@ -603,7 +606,7 @@ class PatchSummaryToEventTransformer(object):
                 # this is a new version of TensorflowX
                 pass
             except Exception as ex:
-                getLogger(TrainsFrameworkAdapter).debug(str(ex))
+                LoggerRoot.get_base_logger().debug(str(ex))

         if 'tensorboardX' in sys.modules:
             try:
@@ -619,7 +622,7 @@ class PatchSummaryToEventTransformer(object):
                 # this is a new version of TensorflowX
                 pass
             except Exception as ex:
-                getLogger(TrainsFrameworkAdapter).debug(str(ex))
+                LoggerRoot.get_base_logger().debug(str(ex))

         if PatchSummaryToEventTransformer.__original_getattributeX is None:
             try:
@@ -633,7 +636,7 @@ class PatchSummaryToEventTransformer(object):
                 # this is a new version of TensorflowX
                 pass
             except Exception as ex:
-                getLogger(TrainsFrameworkAdapter).debug(str(ex))
+                LoggerRoot.get_base_logger().debug(str(ex))

     @staticmethod
     def _patched_add_eventT(self, *args, **kwargs):
@@ -717,7 +720,7 @@ class _ModelAdapter(object):
         super(_ModelAdapter, self).__init__()
         super(_ModelAdapter, self).__setattr__('_model', model)
         super(_ModelAdapter, self).__setattr__('_output_model', output_model)
-        super(_ModelAdapter, self).__setattr__('_logger', getLogger('TrainsModelAdapter'))
+        super(_ModelAdapter, self).__setattr__('_logger', LoggerRoot.get_base_logger())

     def __getattr__(self, attr):
         return getattr(self._model, attr)
@@ -800,7 +803,7 @@ class PatchModelCheckPointCallback(object):
                     property(PatchModelCheckPointCallback.trains_object))

         except Exception as ex:
-            getLogger(TrainsFrameworkAdapter).warning(str(ex))
+            LoggerRoot.get_base_logger().warning(str(ex))

     @staticmethod
     def _patched_getattribute(self, attr):
@@ -878,7 +881,7 @@ class PatchTensorFlowEager(object):
         except ImportError:
             pass
         except Exception as ex:
-            getLogger(TrainsFrameworkAdapter).debug(str(ex))
+            LoggerRoot.get_base_logger().debug(str(ex))

     @staticmethod
     def _get_event_writer(writer):
@@ -905,7 +908,7 @@ class PatchTensorFlowEager(object):
             try:
                 event_writer._add_scalar(tag=str(tag), step=int(step.numpy()), scalar_data=value.numpy())
             except Exception as ex:
-                getLogger(TrainsFrameworkAdapter).warning(str(ex))
+                LoggerRoot.get_base_logger().warning(str(ex))
         return PatchTensorFlowEager.__original_fn_scalar(writer, step, tag, value, name, **kwargs)

     @staticmethod
@@ -915,7 +918,7 @@ class PatchTensorFlowEager(object):
             try:
                 event_writer._add_histogram(tag=str(tag), step=int(step.numpy()), histo_data=values.numpy())
             except Exception as ex:
-                getLogger(TrainsFrameworkAdapter).warning(str(ex))
+                LoggerRoot.get_base_logger().warning(str(ex))
         return PatchTensorFlowEager.__original_fn_hist(writer, step, tag, values, name, **kwargs)

     @staticmethod
@@ -926,7 +929,7 @@ class PatchTensorFlowEager(object):
                 event_writer._add_image_numpy(tag=str(tag), step=int(step.numpy()), img_data_np=tensor.numpy(),
                                               max_keep_images=max_images)
             except Exception as ex:
-                getLogger(TrainsFrameworkAdapter).warning(str(ex))
+                LoggerRoot.get_base_logger().warning(str(ex))
         return PatchTensorFlowEager.__original_fn_image(writer, step, tag, tensor, bad_color, max_images, name,
                                                         **kwargs)

@@ -1024,7 +1027,7 @@ class PatchKerasModelIO(object):
             keras_saving.save_model = _patched_call(keras_saving.save_model, PatchKerasModelIO._save_model)
             keras_saving.load_model = _patched_call(keras_saving.load_model, PatchKerasModelIO._load_model)
         except Exception as ex:
-            getLogger(TrainsFrameworkAdapter).warning(str(ex))
+            LoggerRoot.get_base_logger().warning(str(ex))

     @staticmethod
     def _updated_config(original_fn, self):
@@ -1052,7 +1055,7 @@ class PatchKerasModelIO(object):
                     framework=Framework.keras,
                 )
         except Exception as ex:
-            getLogger(TrainsFrameworkAdapter).warning(str(ex))
+            LoggerRoot.get_base_logger().warning(str(ex))

         return config

@@ -1102,7 +1105,7 @@ class PatchKerasModelIO(object):
                 return model

         except Exception as ex:
-            getLogger(TrainsFrameworkAdapter).warning(str(ex))
+            LoggerRoot.get_base_logger().warning(str(ex))

         return self

@@ -1184,7 +1187,7 @@ class PatchKerasModelIO(object):
             # if anyone asks, we were here
             self.trains_out_model._processed = True
         except Exception as ex:
-            getLogger(TrainsFrameworkAdapter).warning(str(ex))
+            LoggerRoot.get_base_logger().warning(str(ex))

     @staticmethod
     def _save_model(original_fn, model, filepath, *args, **kwargs):
@@ -329,19 +329,19 @@ class PatchedMatplotlib:
                     PatchedMatplotlib._global_image_counter += 1
                     title = plot_title or 'untitled %d' % PatchedMatplotlib._global_image_counter

-                    logger.report_image_and_upload(title=title, series='plot image', path=image,
-                                                   delete_after_upload=True,
-                                                   iteration=PatchedMatplotlib._global_image_counter
-                                                   if plot_title else 0)
+                    logger.report_image(title=title, series='plot image', local_path=image,
+                                        delete_after_upload=True,
+                                        iteration=PatchedMatplotlib._global_image_counter
+                                        if plot_title else 0)
                 else:
                     # send the plot as plotly with embedded image
                     PatchedMatplotlib._global_plot_counter += 1
                     title = plot_title or 'untitled %d' % PatchedMatplotlib._global_plot_counter

-                    logger.report_image_plot_and_upload(title=title, series='plot image', path=image,
-                                                        delete_after_upload=True,
-                                                        iteration=PatchedMatplotlib._global_plot_counter
-                                                        if plot_title else 0)
+                    logger._report_image_plot_and_upload(title=title, series='plot image', path=image,
+                                                         delete_after_upload=True,
+                                                         iteration=PatchedMatplotlib._global_plot_counter
+                                                         if plot_title else 0)

             except Exception:
                 # plotly failed
@@ -18,7 +18,7 @@ def get_cache_dir():
     cache_base_dir = Path(
         expandvars(
             expanduser(
-                config.get("storage.cache.default_base_dir") or DEFAULT_CACHE_DIR
+                TRAINS_CACHE_DIR.get() or config.get("storage.cache.default_base_dir") or DEFAULT_CACHE_DIR
             )
         )
     )
@@ -7,7 +7,6 @@ from pathlib2 import Path
 SESSION_CACHE_FILE = ".session.json"
 DEFAULT_CACHE_DIR = str(Path(tempfile.gettempdir()) / "trains_cache")

-
 TASK_ID_ENV_VAR = EnvEntry("TRAINS_TASK_ID", "ALG_TASK_ID")
 LOG_TO_BACKEND_ENV_VAR = EnvEntry("TRAINS_LOG_TASK_TO_BACKEND", "ALG_LOG_TASK_TO_BACKEND", type=bool)
 NODE_ID_ENV_VAR = EnvEntry("TRAINS_NODE_ID", "ALG_NODE_ID", type=int)
@@ -16,6 +15,7 @@ LOG_STDERR_REDIRECT_LEVEL = EnvEntry("TRAINS_LOG_STDERR_REDIRECT_LEVEL", "ALG_LO
 DEV_WORKER_NAME = EnvEntry("TRAINS_WORKER_NAME", "ALG_WORKER_NAME")
 DEV_TASK_NO_REUSE = EnvEntry("TRAINS_TASK_NO_REUSE", "ALG_TASK_NO_REUSE", type=bool)
 TASK_LOG_ENVIRONMENT = EnvEntry("TRAINS_LOG_ENVIRONMENT", "ALG_LOG_ENVIRONMENT", type=str)
+TRAINS_CACHE_DIR = EnvEntry("TRAINS_CACHE_DIR", "ALG_CACHE_DIR")

 LOG_LEVEL_ENV_VAR = EnvEntry("TRAINS_LOG_LEVEL", "ALG_LOG_LEVEL", converter=or_(int, str))
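
Together these two hunks give the environment variable precedence over the configuration file for the cache location. A sketch (the import path of get_cache_dir is assumed from this diff; the variable must be set before trains reads its config):

    import os
    os.environ['TRAINS_CACHE_DIR'] = '/mnt/scratch/trains_cache'

    from trains.config import get_cache_dir  # import path assumed
    print(get_cache_dir())                   # -> /mnt/scratch/trains_cache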
trains/logger.py (749 lines)
@@ -1,12 +1,9 @@
 import logging
-import re
-import sys
-import threading
-from functools import wraps

 import numpy as np
 from pathlib2 import Path

+from .backend_interface.logger import StdStreamPatch, LogFlusher
 from .debugging.log import LoggerRoot
 from .backend_interface.task.development.worker import DevWorker
 from .backend_interface.task.log import TaskHandler
@@ -17,36 +14,6 @@ from .backend_interface.task import Task as _Task
 from .config import running_remotely, get_cache_dir


-def _safe_names(func):
-    """
-    Validate the form of title and series parameters.
-
-    This decorator assert that a method receives 'title' and 'series' as its
-    first positional arguments, and that their values have only legal characters.
-
-    '\', '/' and ':' will be replaced automatically by '_'
-    Whitespace chars will be replaced automatically by ' '
-    """
-    _replacements = {
-        '_': re.compile(r"[/\\:]"),
-        ' ': re.compile(r"[\s]"),
-    }
-
-    def _make_safe(value):
-        for repl, regex in _replacements.items():
-            value = regex.sub(repl, value)
-        return value
-
-    @wraps(func)
-    def fixed_names(self, title, series, *args, **kwargs):
-        title = _make_safe(title)
-        series = _make_safe(series)
-
-        func(self, title, series, *args, **kwargs)
-
-    return fixed_names
-
-
 class Logger(object):
     """
     Console log and metric statistics interface.
@@ -56,9 +23,6 @@ class Logger(object):
     **Usage:** :func:`Logger.current_logger` or :func:`Task.get_logger`
     """
     SeriesInfo = SeriesInfo
-    _stdout_proxy = None
-    _stderr_proxy = None
-    _stdout_original_write = None

     def __init__(self, private_task):
         """
@@ -75,67 +39,11 @@ class Logger(object):
         self._report_worker = None
         self._task_handler = None

-        if DevWorker.report_stdout and not PrintPatchLogger.patched and not running_remotely():
-            Logger._stdout_proxy = PrintPatchLogger(sys.stdout, self, level=logging.INFO)
-            Logger._stderr_proxy = PrintPatchLogger(sys.stderr, self, level=logging.ERROR)
-            self._task_handler = TaskHandler(self._task.session, self._task.id, capacity=100)
-            # noinspection PyBroadException
-            try:
-                if Logger._stdout_original_write is None:
-                    Logger._stdout_original_write = sys.stdout.write
-                # this will only work in python 3, guard it with try/catch
-                if not hasattr(sys.stdout, '_original_write'):
-                    sys.stdout._original_write = sys.stdout.write
-                sys.stdout.write = stdout__patched__write__
-                if not hasattr(sys.stderr, '_original_write'):
-                    sys.stderr._original_write = sys.stderr.write
-                sys.stderr.write = stderr__patched__write__
-            except Exception:
-                pass
-            sys.stdout = Logger._stdout_proxy
-            sys.stderr = Logger._stderr_proxy
-            # patch the base streams of sys (this way colorama will keep its ANSI colors)
-            # noinspection PyBroadException
-            try:
-                sys.__stderr__ = sys.stderr
-            except Exception:
-                pass
-            # noinspection PyBroadException
-            try:
-                sys.__stdout__ = sys.stdout
-            except Exception:
-                pass
-
-            # now check if we have loguru and make it re-register the handlers
-            # because it sores internally the stream.write function, which we cant patch
-            # noinspection PyBroadException
-            try:
-                from loguru import logger
-                register_stderr = None
-                register_stdout = None
-                for k, v in logger._handlers.items():
-                    if v._name == '<stderr>':
-                        register_stderr = k
-                    elif v._name == '<stdout>':
-                        register_stderr = k
-                if register_stderr is not None:
-                    logger.remove(register_stderr)
-                    logger.add(sys.stderr)
-                if register_stdout is not None:
-                    logger.remove(register_stdout)
-                    logger.add(sys.stdout)
-            except Exception:
-                pass
-
-        elif DevWorker.report_stdout and not running_remotely():
-            self._task_handler = TaskHandler(self._task.session, self._task.id, capacity=100)
-            if Logger._stdout_proxy:
-                Logger._stdout_proxy.connect(self)
-            if Logger._stderr_proxy:
-                Logger._stderr_proxy.connect(self)
+        StdStreamPatch.patch_std_streams(self)

     @classmethod
     def current_logger(cls):
         # type: () -> Logger
         """
         Return a logger object for the current task. Can be called from anywhere in the code
@@ -147,92 +55,24 @@ class Logger(object):
             return None
         return task.get_logger()

-    def console(self, msg, level=logging.INFO, omit_console=False, *args, **kwargs):
+    def report_text(self, msg, level=logging.INFO, print_console=True, *args, **_):
         """
-        print text to log (same as print to console, and also prints to console)
+        print text to log and optionally also prints to console

-        :param msg: text to print to the console (always send to the backend and displayed in console)
-        :param level: logging level, default: logging.INFO
-        :param omit_console: If True we only send 'msg' to log (no console print)
+        :param str msg: text to print to the console (always send to the backend and displayed in console)
+        :param int level: logging level, default: logging.INFO
+        :param bool print_console: If True we also print 'msg' to console
         """
-        try:
-            level = int(level)
-        except (TypeError, ValueError):
-            self._task.log.log(level=logging.ERROR,
-                               msg='Logger failed casting log level "%s" to integer' % str(level))
-            level = logging.INFO
-
-        if not running_remotely():
-            # noinspection PyBroadException
-            try:
-                record = self._task.log.makeRecord(
-                    "console", level=level, fn='', lno=0, func='', msg=msg, args=args, exc_info=None
-                )
-                # find the task handler that matches our task
-                if not self._task_handler:
-                    self._task_handler = [h for h in LoggerRoot.get_base_logger().handlers
-                                          if isinstance(h, TaskHandler) and h.task_id == self._task.id][0]
-                self._task_handler.emit(record)
-            except Exception:
-                LoggerRoot.get_base_logger().warning(msg='Logger failed sending log: [level %s]: "%s"'
-                                                     % (str(level), str(msg)))
-
-        if not omit_console:
-            # if we are here and we grabbed the stdout, we need to print the real thing
-            if DevWorker.report_stdout and not running_remotely():
-                # noinspection PyBroadException
-                try:
-                    # make sure we are writing to the original stdout
-                    Logger._stdout_original_write(str(msg)+'\n')
-                except Exception:
-                    pass
-            else:
-                print(str(msg))
-
-        # if task was not started, we have to start it
-        self._start_task_if_needed()
-
-    def report_text(self, msg, level=logging.INFO, print_console=False, *args, **_):
-        return self.console(msg, level, not print_console, *args, **_)
-
-    def debug(self, msg, *args, **kwargs):
-        """ Print information to the log. This is the same as console(msg, logging.DEBUG) """
-        self._task.log.log(msg=msg, level=logging.DEBUG, *args, **kwargs)
-
-    def info(self, msg, *args, **kwargs):
-        """ Print information to the log. This is the same as console(msg, logging.INFO) """
-        self._task.log.log(msg=msg, level=logging.INFO, *args, **kwargs)
-
-    def warn(self, msg, *args, **kwargs):
-        """ Print a warning to the log. This is the same as console(msg, logging.WARNING) """
-        self._task.log.log(msg=msg, level=logging.WARNING, *args, **kwargs)
-
-    warning = warn
-
-    def error(self, msg, *args, **kwargs):
-        """ Print an error to the log. This is the same as console(msg, logging.ERROR) """
-        self._task.log.log(msg=msg, level=logging.ERROR, *args, **kwargs)
-
-    def fatal(self, msg, *args, **kwargs):
-        """ Print a fatal error to the log. This is the same as console(msg, logging.FATAL) """
-        self._task.log.log(msg=msg, level=logging.FATAL, *args, **kwargs)
-
-    def critical(self, msg, *args, **kwargs):
-        """ Print a critical error to the log. This is the same as console(msg, logging.CRITICAL) """
-        self._task.log.log(msg=msg, level=logging.CRITICAL, *args, **kwargs)
+        return self._console(msg, level, not print_console, *args, **_)

     def report_scalar(self, title, series, value, iteration):
         """
         Report a scalar value

-        :param title: Title (AKA metric)
-        :type title: str
-        :param series: Series (AKA variant)
-        :type series: str
-        :param value: Reported value
-        :type value: float
-        :param iteration: Iteration number
-        :type value: int
+        :param str title: Title (AKA metric)
+        :param str series: Series (AKA variant)
+        :param float value: Reported value
+        :param int iteration: Iteration number
         """

         # if task was not started, we have to start it
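
In short, report_text keeps the old console() behavior, the body now lives in the private _console(), and the console flag is flipped from omit_console to print_console (defaulting to True). A usage sketch:

    import logging

    logger = Task.current_task().get_logger()
    logger.report_text("pipeline started")                  # logged + printed
    logger.report_text("verbose detail", level=logging.DEBUG,
                       print_console=False)                 # logged only
    # per-level helpers (info/warn/error/...) are gone; internal code in this
    # commit switches to the standard python logger, e.g. self.log.warning(...)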
@@ -244,18 +84,12 @@ class Logger(object):
         """
         Report a histogram plot

-        :param title: Title (AKA metric)
-        :type title: str
-        :param series: Series (AKA variant)
-        :type series: str
-        :param values: Reported values (or numpy array)
-        :type values: [float]
-        :param iteration: Iteration number
-        :type iteration: int
-        :param labels: optional, labels for each bar group.
-        :type labels: list of strings.
-        :param xlabels: optional label per entry in the vector (bucket in the histogram)
-        :type xlabels: list of strings.
+        :param str title: Title (AKA metric)
+        :param str series: Series (AKA variant)
+        :param list(float) values: Reported values (or numpy array)
+        :param int iteration: Iteration number
+        :param list(str) labels: optional, labels for each bar group.
+        :param list(str) xlabels: optional label per entry in the vector (bucket in the histogram)
         """
         return self.report_histogram(title, series, values, iteration, labels=labels, xlabels=xlabels)

@@ -263,18 +97,12 @@ class Logger(object):
         """
         Report a histogram plot

-        :param title: Title (AKA metric)
-        :type title: str
-        :param series: Series (AKA variant)
-        :type series: str
-        :param values: Reported values (or numpy array)
-        :type values: [float]
-        :param iteration: Iteration number
-        :type iteration: int
-        :param labels: optional, labels for each bar group.
-        :type labels: list of strings.
-        :param xlabels: optional label per entry in the vector (bucket in the histogram)
-        :type xlabels: list of strings.
+        :param str title: Title (AKA metric)
+        :param str series: Series (AKA variant)
+        :param list(float) values: Reported values (or numpy array)
+        :param int iteration: Iteration number
+        :param list(str) labels: optional, labels for each bar group.
+        :param list(str) xlabels: optional label per entry in the vector (bucket in the histogram)
         """

         if not isinstance(values, np.ndarray):
@@ -292,24 +120,19 @@ class Logger(object):
             xlabels=xlabels,
         )

-    def report_line_plot(self, title, series, iteration, xaxis, yaxis, mode='lines', reverse_xaxis=False, comment=None):
+    def report_line_plot(self, title, series, iteration, xaxis, yaxis, mode='lines',
+                         reverse_xaxis=False, comment=None):
         """
         Report a (possibly multiple) line plot.

-        :param title: Title (AKA metric)
-        :type title: str
-        :param series: All the series' data, one for each line in the plot.
-        :type series: An iterable of LineSeriesInfo.
-        :param iteration: Iteration number
-        :type iteration: int
-        :param xaxis: optional x-axis title
-        :param yaxis: optional y-axis title
-        :param mode: scatter plot with 'lines'/'markers'/'lines+markers'
-        :type mode: str
-        :param reverse_xaxis: If true X axis will be displayed from high to low (reversed)
-        :type reverse_xaxis: bool
-        :param comment: comment underneath the title
-        :type comment: str
+        :param str title: Title (AKA metric)
+        :param list(LineSeriesInfo) series: All the series' data, one for each line in the plot.
+        :param int iteration: Iteration number
+        :param str xaxis: optional x-axis title
+        :param str yaxis: optional y-axis title
+        :param str mode: scatter plot with 'lines'/'markers'/'lines+markers'
+        :param bool reverse_xaxis: If true X axis will be displayed from high to low (reversed)
+        :param str comment: comment underneath the title
         """

         series = [self.SeriesInfo(**s) if isinstance(s, dict) else s for s in series]
@@ -333,21 +156,15 @@ class Logger(object):
         """
         Report a 2d scatter graph (with lines)

-        :param title: Title (AKA metric)
-        :type title: str
-        :param series: Series (AKA variant)
-        :type series: str
-        :param scatter: A scattered data: list of (pairs of x,y) (or numpy array)
-        :type scatter: ndarray or list
-        :param iteration: Iteration number
-        :type iteration: int
-        :param xaxis: optional x-axis title
-        :param yaxis: optional y-axis title
-        :param labels: label (text) per point in the scatter (in the same order)
-        :param mode: scatter plot with 'lines'/'markers'/'lines+markers'
-        :type mode: str
-        :param comment: comment underneath the title
-        :type comment: str
+        :param str title: Title (AKA metric)
+        :param str series: Series (AKA variant)
+        :param np.ndarray scatter: A scattered data: list of (pairs of x,y) (or numpy array)
+        :param int iteration: Iteration number
+        :param str xaxis: optional x-axis title
+        :param str yaxis: optional y-axis title
+        :param list(str) labels: label (text) per point in the scatter (in the same order)
+        :param str mode: scatter plot with 'lines'/'markers'/'lines+markers'
+        :param str comment: comment underneath the title
         """

         if not isinstance(scatter, np.ndarray):
@@ -375,18 +192,15 @@ class Logger(object):
         """
         Report a 3d scatter graph (with markers)

-        :param title: Title (AKA metric)
-        :type title: str
-        :param series: Series (AKA variant)
-        :type series: str
-        :param scatter: A scattered data: list of (pairs of x,y,z) (or numpy array) or list of series [[(x1,y1,z1)...]]
-        :type scatter: ndarray or list
-        :param iteration: Iteration number
-        :type iteration: int
-        :param labels: label (text) per point in the scatter (in the same order)
-        :param mode: scatter plot with 'lines'/'markers'/'lines+markers'
-        :param fill: fill area under the curve
-        :param comment: comment underneath the title
+        :param str title: Title (AKA metric)
+        :param str series: Series (AKA variant)
+        :param np.ndarray scatter: A scattered data: list of (pairs of x,y,z) (or numpy array)
+            or list of series [[(x1,y1,z1)...]]
+        :param int iteration: Iteration number
+        :param list(str) labels: label (text) per point in the scatter (in the same order)
+        :param str mode: scatter plot with 'lines'/'markers'/'lines+markers'
+        :param bool fill: fill area under the curve
+        :param str comment: comment underneath the title
         """
         # check if multiple series
         multi_series = (
@@ -429,17 +243,13 @@ class Logger(object):
         """
         Report a heat-map matrix

-        :param title: Title (AKA metric)
-        :type title: str
-        :param series: Series (AKA variant)
-        :type series: str
-        :param matrix: A heat-map matrix (example: confusion matrix)
-        :type matrix: ndarray
-        :param iteration: Iteration number
-        :type iteration: int
-        :param xlabels: optional label per column of the matrix
-        :param ylabels: optional label per row of the matrix
-        :param comment: comment underneath the title
+        :param str title: Title (AKA metric)
+        :param str series: Series (AKA variant)
+        :param np.ndarray matrix: A heat-map matrix (example: confusion matrix)
+        :param int iteration: Iteration number
+        :param list(str) xlabels: optional label per column of the matrix
+        :param list(str) ylabels: optional label per row of the matrix
+        :param str comment: comment underneath the title
         """

         if not isinstance(matrix, np.ndarray):
@@ -463,16 +273,12 @@ class Logger(object):
         Same as report_confusion_matrix
         Report a heat-map matrix

-        :param title: Title (AKA metric)
-        :type title: str
-        :param series: Series (AKA variant)
-        :type series: str
-        :param matrix: A heat-map matrix (example: confusion matrix)
-        :type matrix: ndarray
-        :param iteration: Iteration number
-        :type iteration: int
-        :param xlabels: optional label per column of the matrix
-        :param ylabels: optional label per row of the matrix
+        :param str title: Title (AKA metric)
+        :param str series: Series (AKA variant)
+        :param np.ndarray matrix: A heat-map matrix (example: confusion matrix)
+        :param int iteration: Iteration number
+        :param list(str) xlabels: optional label per column of the matrix
+        :param list(str) ylabels: optional label per row of the matrix
         """
         return self.report_confusion_matrix(title, series, matrix, iteration, xlabels=xlabels, ylabels=ylabels)

@@ -481,21 +287,17 @@ class Logger(object):
         """
         Report a 3d surface (same data as heat-map matrix, only presented differently)

-        :param title: Title (AKA metric)
-        :type title: str
-        :param series: Series (AKA variant)
-        :type series: str
-        :param matrix: A heat-map matrix (example: confusion matrix)
-        :type matrix: ndarray
-        :param iteration: Iteration number
-        :type iteration: int
-        :param xlabels: optional label per column of the matrix
-        :param ylabels: optional label per row of the matrix
-        :param xtitle: optional x-axis title
-        :param ytitle: optional y-axis title
-        :param ztitle: optional z-axis title
-        :param camera: X,Y,Z camera position. def: (1,1,1)
-        :param comment: comment underneath the title
+        :param str title: Title (AKA metric)
+        :param str series: Series (AKA variant)
+        :param np.ndarray matrix: A heat-map matrix (example: confusion matrix)
+        :param int iteration: Iteration number
+        :param list(str) xlabels: optional label per column of the matrix
+        :param list(str) ylabels: optional label per row of the matrix
+        :param str xtitle: optional x-axis title
+        :param str ytitle: optional y-axis title
+        :param str ztitle: optional z-axis title
+        :param list(float) camera: X,Y,Z camera position. def: (1,1,1)
+        :param str comment: comment underneath the title
         """

         if not isinstance(matrix, np.ndarray):
@@ -518,56 +320,24 @@ class Logger(object):
             comment=comment,
         )

-    @_safe_names
-    def report_image(self, title, series, src, iteration):
-        """
-        Report an image, and register the 'src' as url content.
-
-        :param title: Title (AKA metric)
-        :type title: str
-        :param series: Series (AKA variant)
-        :type series: str
-        :param src: Image source URI. This URI will be used by the webapp and workers when trying to obtain the image
-            for presentation of processing. Currently only http(s), file and s3 schemes are supported.
-        :type src: str
-        :param iteration: Iteration number
-        :type iteration: int
-        """
-
-        # if task was not started, we have to start it
-        self._start_task_if_needed()
-
-        self._task.reporter.report_image(
-            title=title,
-            series=series,
-            src=src,
-            iter=iteration,
-        )
-
-    @_safe_names
-    def report_image_and_upload(self, title, series, iteration, path=None, matrix=None, max_image_history=None,
-                                delete_after_upload=False):
+    def report_image(self, title, series, iteration, local_path=None, matrix=None, max_image_history=None,
+                     delete_after_upload=False):
         """
         Report an image and upload its contents.

         Image is uploaded to a preconfigured bucket (see setup_upload()) with a key (filename)
         describing the task ID, title, series and iteration.

-        :param title: Title (AKA metric)
-        :type title: str
-        :param series: Series (AKA variant)
-        :type series: str
-        :param iteration: Iteration number
-        :type iteration: int
-        :param path: A path to an image file. Required unless matrix is provided.
-        :type path: str
-        :param matrix: A 3D numpy.ndarray object containing image data (RGB). Required unless filename is provided.
-        :type matrix: str
-        :param max_image_history: maximum number of image to store per metric/variant combination
-            use negative value for unlimited. default is set in global configuration (default=5)
-        :type max_image_history: int
-        :param delete_after_upload: if True, one the file was uploaded the local copy will be deleted
-        :type delete_after_upload: boolean
+        :param str title: Title (AKA metric)
+        :param str series: Series (AKA variant)
+        :param int iteration: Iteration number
+        :param str local_path: A path to an image file. Required unless matrix is provided.
+        :param np.ndarray matrix: A 3D numpy.ndarray object containing image data (RGB).
+            Required unless filename is provided.
+        :param int max_image_history: maximum number of images to store per metric/variant combination
+            use negative value for unlimited. default is set in global configuration (default=5)
+        :param bool delete_after_upload: if True, once the file was uploaded the local copy will be deleted
         """

         # if task was not started, we have to start it
@@ -584,7 +354,7 @@ class Logger(object):
         self._task.reporter.report_image_and_upload(
             title=title,
             series=series,
-            path=path,
+            path=local_path,
             matrix=matrix,
             iter=iteration,
             upload_uri=upload_uri,
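
The renamed report_image now covers both upload paths through one signature; a sketch of the two call styles (assuming an upload destination is configured):

    import numpy as np

    # from a local file:
    logger.report_image("test case", "from file", iteration=1,
                        local_path="/tmp/debug.png", delete_after_upload=True)

    # from an in-memory array:
    m = np.eye(256, 256, dtype=np.uint8) * 255
    logger.report_image("test case", "from matrix", iteration=1, matrix=m)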
@ -592,7 +362,138 @@ class Logger(object):
|
||||
delete_after_upload=delete_after_upload,
|
||||
)
|
||||
|
||||
def report_image_plot_and_upload(self, title, series, iteration, path=None, matrix=None, max_image_history=None,
|
||||
def set_default_upload_destination(self, uri):
|
||||
"""
|
||||
Set the uri to upload all the debug images to.
|
||||
|
||||
Images are uploaded separately to the destination storage (e.g. s3,gc,file) and then
|
||||
a link to the uploaded image is sent in the report
|
||||
Notice: credentials for the upload destination will be pooled from the
|
||||
global configuration file (i.e. ~/trains.conf)
|
||||
|
||||
:param str uri: example: 's3://bucket/directory/' or 'file:///tmp/debug/'
|
||||
:return: True if destination scheme is supported (i.e. s3:// file:// gc:// etc...)
|
||||
"""
|
||||
|
||||
# Create the storage helper
|
||||
storage = StorageHelper.get(uri)
|
||||
|
||||
# Verify that we can upload to this destination
|
||||
uri = storage.verify_upload(folder_uri=uri)
|
||||
|
||||
self._default_upload_destination = uri
|
||||
|
||||
def get_default_upload_destination(self):
|
||||
"""
|
||||
Get the uri to upload all the debug images to.
|
||||
|
||||
Images are uploaded separately to the destination storage (e.g. s3,gc,file) and then
|
||||
a link to the uploaded image is sent in the report
|
||||
Notice: credentials for the upload destination will be pooled from the
|
||||
global configuration file (i.e. ~/trains.conf)
|
||||
|
||||
:return: Uri (str) example: 's3://bucket/directory/' or 'file:///tmp/debug/' etc...
|
||||
"""
|
||||
return self._default_upload_destination or self._task._get_default_report_storage_uri()
|
||||
|
||||
def flush(self):
|
||||
"""
|
||||
Flush cached reports and console outputs to backend.
|
||||
|
||||
:return: True if successful
|
||||
"""
|
||||
self._flush_stdout_handler()
|
||||
if self._task:
|
||||
return self._task.flush()
|
||||
return False
|
||||
|
||||
    def get_flush_period(self):
        """
        :return: logger flush period in seconds
        """
        if self._flusher:
            return self._flusher.period
        return None

    def set_flush_period(self, period):
        """
        Set the period of the logger flush.

        :param float period: The period to flush the logger in seconds. If None or 0,
            there will be no periodic flush.
        """
        if self._task.is_main_task() and DevWorker.report_stdout and DevWorker.report_period and \
                not running_remotely() and period is not None:
            period = min(period or DevWorker.report_period, DevWorker.report_period)

        if not period:
            if self._flusher:
                self._flusher.exit()
                self._flusher = None
        elif self._flusher:
            self._flusher.set_period(period)
        else:
            self._flusher = LogFlusher(self, period)
            self._flusher.start()
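A short usage sketch of the flush controls, continuing from the Task.init sketch above (values are illustrative; in development mode the period may be capped by DevWorker.report_period):

logger = task.get_logger()
logger.set_flush_period(10.0)      # flush cached reports roughly every 10 seconds
print(logger.get_flush_period())   # 10.0, or the capped value in development mode
logger.set_flush_period(None)      # stop the periodic flusher entirely
logger.flush()                     # force an immediate flush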
    def report_image_and_upload(self, title, series, iteration, path=None, matrix=None, max_image_history=None,
                                delete_after_upload=False):
        """
        Backwards compatibility, please use report_image instead
        """
        self.report_image(title=title, series=series, iteration=iteration, local_path=path, matrix=matrix,
                          max_image_history=max_image_history, delete_after_upload=delete_after_upload)
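Migration note: existing callers keep working through the shim above, but new code should call report_image directly. A minimal sketch, reusing the logger from the earlier sketch (the image is synthetic):

import numpy as np

m = np.random.randint(0, 255, size=(256, 256, 3), dtype=np.uint8)
# old name, still routed through the shim:
logger.report_image_and_upload('test case', 'rgb noise', iteration=1, matrix=m)
# preferred new call (note that path= became local_path=):
logger.report_image('test case', 'rgb noise', iteration=1, matrix=m)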
    @classmethod
    def _remove_std_logger(cls):
        StdStreamPatch.remove_std_logger()
    def _console(self, msg, level=logging.INFO, omit_console=False, *args, **kwargs):
        """
        Print text to the log; the message is always sent to the backend, and also printed
        to the console unless omit_console is set.

        :param msg: text to print to the console (always sent to the backend and displayed in console)
        :param level: logging level, default: logging.INFO
        :param omit_console: If True we only send 'msg' to log (no console print)
        """
        try:
            level = int(level)
        except (TypeError, ValueError):
            self._task.log.log(level=logging.ERROR,
                               msg='Logger failed casting log level "%s" to integer' % str(level))
            level = logging.INFO

        if not running_remotely():
            # noinspection PyBroadException
            try:
                record = self._task.log.makeRecord(
                    "console", level=level, fn='', lno=0, func='', msg=msg, args=args, exc_info=None
                )
                # find the task handler that matches our task
                if not self._task_handler:
                    self._task_handler = [h for h in LoggerRoot.get_base_logger().handlers
                                          if isinstance(h, TaskHandler) and h.task_id == self._task.id][0]
                self._task_handler.emit(record)
            except Exception:
                LoggerRoot.get_base_logger().warning(msg='Logger failed sending log: [level %s]: "%s"'
                                                         % (str(level), str(msg)))

        if not omit_console:
            # if we are here and we grabbed the stdout, we need to print the real thing
            if DevWorker.report_stdout and not running_remotely():
                # noinspection PyBroadException
                try:
                    # make sure we are writing to the original stdout
                    StdStreamPatch.stdout_original_write(str(msg) + '\n')
                except Exception:
                    pass
            else:
                print(str(msg))

        # if task was not started, we have to start it
        self._start_task_if_needed()
    def _report_image_plot_and_upload(self, title, series, iteration, path=None, matrix=None, max_image_history=None,
                                      delete_after_upload=False):
        """
        Report an image, upload its contents, and present it in the plots section using plotly
@ -639,7 +540,7 @@ class Logger(object):
            delete_after_upload=delete_after_upload,
        )

    def report_file_and_upload(self, title, series, iteration, path=None, max_file_history=None,
    def _report_file_and_upload(self, title, series, iteration, path=None, max_file_history=None,
                               delete_after_upload=False):
        """
        Upload a file and report it as a link in the debug images section.
@ -684,92 +585,6 @@ class Logger(object):
            delete_after_upload=delete_after_upload,
        )
    def set_default_upload_destination(self, uri):
        """
        Set the uri to upload all the debug images to.

        Images are uploaded separately to the destination storage (e.g. s3,gc,file) and then
        a link to the uploaded image is sent in the report
        Notice: credentials for the upload destination will be pooled from the
        global configuration file (i.e. ~/trains.conf)

        :param uri: example: 's3://bucket/directory/' or 'file:///tmp/debug/'
        :return: True if destination scheme is supported (i.e. s3:// file:// gc:// etc...)
        """

        # Create the storage helper
        storage = StorageHelper.get(uri)

        # Verify that we can upload to this destination
        uri = storage.verify_upload(folder_uri=uri)

        self._default_upload_destination = uri

    def get_default_upload_destination(self):
        """
        Get the uri to upload all the debug images to.

        Images are uploaded separately to the destination storage (e.g. s3,gc,file) and then
        a link to the uploaded image is sent in the report
        Notice: credentials for the upload destination will be pooled from the
        global configuration file (i.e. ~/trains.conf)

        :return: Uri (str) example: 's3://bucket/directory/' or 'file:///tmp/debug/' etc...
        """
        return self._default_upload_destination or self._task._get_default_report_storage_uri()
    def flush(self):
        """
        Flush cached reports and console outputs to backend.

        :return: True if successful
        """
        self._flush_stdout_handler()
        if self._task:
            return self._task.flush()
        return False

    def get_flush_period(self):
        if self._flusher:
            return self._flusher.period
        return None

    def set_flush_period(self, period):
        """
        Set the period of the logger flush.

        :param period: The period to flush the logger in seconds. If None or 0,
            There will be no periodic flush.
        """
        if self._task.is_main_task() and DevWorker.report_stdout and DevWorker.report_period and \
                not running_remotely() and period is not None:
            period = min(period or DevWorker.report_period, DevWorker.report_period)

        if not period:
            if self._flusher:
                self._flusher.exit()
                self._flusher = None
        elif self._flusher:
            self._flusher.set_period(period)
        else:
            self._flusher = _Flusher(self, period)
            self._flusher.start()
    @classmethod
    def _remove_std_logger(self):
        if isinstance(sys.stdout, PrintPatchLogger):
            # noinspection PyBroadException
            try:
                sys.stdout.connect(None)
            except Exception:
                pass
        if isinstance(sys.stderr, PrintPatchLogger):
            # noinspection PyBroadException
            try:
                sys.stderr.connect(None)
            except Exception:
                pass

    def _start_task_if_needed(self):
        # do not refresh the task status read from cached variable _status
        if str(self._task._status) == str(tasks.TaskStatusEnum.created):
@ -780,121 +595,3 @@ class Logger(object):
    def _flush_stdout_handler(self):
        if self._task_handler and DevWorker.report_stdout:
            self._task_handler.flush()
def stdout__patched__write__(*args, **kwargs):
    if Logger._stdout_proxy:
        return Logger._stdout_proxy.write(*args, **kwargs)
    return sys.stdout._original_write(*args, **kwargs)


def stderr__patched__write__(*args, **kwargs):
    if Logger._stderr_proxy:
        return Logger._stderr_proxy.write(*args, **kwargs)
    return sys.stderr._original_write(*args, **kwargs)
class PrintPatchLogger(object):
    """
    Allows patching a stream into the logger.
    Used for capturing and logging stdout and stderr when running in development mode pseudo worker.
    """
    patched = False
    lock = threading.Lock()
    recursion_protect_lock = threading.RLock()

    def __init__(self, stream, logger=None, level=logging.INFO):
        PrintPatchLogger.patched = True
        self._terminal = stream
        self._log = logger
        self._log_level = level
        self._cur_line = ''

    def write(self, message):
        # make sure that we do not end up in infinite loop (i.e. log.console ends up calling us)
        if self._log and not PrintPatchLogger.recursion_protect_lock._is_owned():
            try:
                self.lock.acquire()
                with PrintPatchLogger.recursion_protect_lock:
                    if hasattr(self._terminal, '_original_write'):
                        self._terminal._original_write(message)
                    else:
                        self._terminal.write(message)

                do_flush = '\n' in message
                do_cr = '\r' in message
                self._cur_line += message
                if (not do_flush and not do_cr) or not message:
                    return
                last_lf = self._cur_line.rindex('\n' if do_flush else '\r')
                next_line = self._cur_line[last_lf + 1:]
                cur_line = self._cur_line[:last_lf + 1].rstrip()
                self._cur_line = next_line
            finally:
                self.lock.release()

            if cur_line:
                with PrintPatchLogger.recursion_protect_lock:
                    # noinspection PyBroadException
                    try:
                        if self._log:
                            self._log.console(cur_line, level=self._log_level, omit_console=True)
                    except Exception:
                        # what can we do, nothing
                        pass
        else:
            if hasattr(self._terminal, '_original_write'):
                self._terminal._original_write(message)
            else:
                self._terminal.write(message)

    def connect(self, logger):
        self._cur_line = ''
        self._log = logger

    def __getattr__(self, attr):
        if attr in ['_log', '_terminal', '_log_level', '_cur_line']:
            return self.__dict__.get(attr)
        return getattr(self._terminal, attr)

    def __setattr__(self, key, value):
        if key in ['_log', '_terminal', '_log_level', '_cur_line']:
            self.__dict__[key] = value
        else:
            return setattr(self._terminal, key, value)
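The proxy above echoes everything to the real terminal while buffering per line, and forwards complete lines to the logger with omit_console=True so the console call cannot loop back into write(). An illustrative wiring sketch (my_logger is a placeholder for a Logger instance; the actual patching now lives in StdStreamPatch):

import logging
import sys

proxy = PrintPatchLogger(sys.stdout, logger=None, level=logging.INFO)
sys.stdout = proxy               # every print() now flows through proxy.write()
print('echoed to the terminal')  # buffered until a '\n' or '\r' completes the line
proxy.connect(my_logger)         # attach a logger so complete lines reach the backend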
class _Flusher(threading.Thread):
    def __init__(self, logger, period, **kwargs):
        super(_Flusher, self).__init__(**kwargs)
        self.daemon = True

        self._period = period
        self._logger = logger
        self._exit_event = threading.Event()

    @property
    def period(self):
        return self._period

    def run(self):
        self._logger.flush()
        # store original wait period
        while True:
            period = self._period
            while not self._exit_event.wait(period or 1.0):
                self._logger.flush()
            # check if period is negative or None we should exit
            if self._period is None or self._period < 0:
                break
            # check if period was changed, we should restart
            self._exit_event.clear()

    def exit(self):
        self._period = None
        self._exit_event.set()

    def set_period(self, period):
        self._period = period
        # make sure we exit the previous wait
        self._exit_event.set()
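Lifecycle sketch for this (now removed) flusher thread; the values are illustrative, and the same pattern applies to its LogFlusher replacement:

flusher = _Flusher(logger, period=5.0)
flusher.start()           # daemon thread: flush the logger roughly every 5 seconds
flusher.set_period(1.0)   # wakes the wait loop, which restarts with the new period
flusher.exit()            # sets the period to None and lets the thread terminate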
@ -340,7 +340,6 @@ class InputModel(BaseModel):
        name=None,
        tags=None,
        comment=None,
        logger=None,
        is_package=False,
        create_as_published=False,
        framework=None,
@ -367,7 +366,6 @@ class InputModel(BaseModel):
        :param name: optional, name for the newly imported model
        :param tags: optional, list of strings as tags
        :param comment: optional, string description for the model
        :param logger: The logger to use. If None, use the default logger
        :param is_package: Boolean. Indicates that the imported weights file is a package.
            If True, and a new model was created, a package tag will be added.
        :param create_as_published: Boolean. If True, and a new model is created, it will be published.
@ -386,8 +384,7 @@ class InputModel(BaseModel):
        ))

        if result.response.models:
            if not logger:
                logger = get_logger()
            logger = get_logger()

            logger.debug('A model with uri "{}" already exists. Selecting it'.format(weights_url))
@ -1,16 +1,15 @@
import atexit
import os
import re
import signal
import sys
import threading
import time
from argparse import ArgumentParser
from collections import OrderedDict, Callable
from typing import Optional

import psutil
import six
from pathlib2 import Path

from .binding.joblib_bind import PatchedJoblib
from .backend_api.services import tasks, projects
@ -29,7 +28,7 @@ from .errors import UsageError
from .logger import Logger
from .model import InputModel, OutputModel, ARCHIVED_TAG
from .task_parameters import TaskParameters
from .binding.artifacts import Artifacts
from .binding.artifacts import Artifacts, Artifact
from .binding.environ_bind import EnvironmentBind, PatchOsFork
from .binding.absl_bind import PatchAbsl
from .utilities.args import argparser_parseargs_called, get_argparser_last_args, \
@ -41,6 +40,7 @@ from .binding.frameworks.xgboost_bind import PatchXGBoostModelIO
from .binding.matplotlib_bind import PatchedMatplotlib
from .utilities.resource_monitor import ResourceMonitor
from .utilities.seed import make_deterministic
from .utilities.dicts import ReadOnlyDict

NotSet = object()
@ -113,6 +113,7 @@ class Task(_Task):

    @classmethod
    def current_task(cls):
        # type: () -> Task
        """
        Return the Current Task object for the main execution task (task context).
        :return: Task() object or None
@ -279,7 +280,7 @@ class Task(_Task):
            # The logger will automatically take care of all patching (we just need to make sure to initialize it)
            logger = task.get_logger()
            # show the debug metrics page in the log, it is very convenient
            logger.console(
            logger.report_text(
                'TRAINS results page: {}/projects/{}/experiments/{}/output/log'.format(
                    task._get_app_server(),
                    task.project if task.project is not None else '*',
@ -439,12 +440,12 @@ class Task(_Task):
        task._setup_log(replace_existing=True)
        logger = task.get_logger()
        if closed_old_task:
            logger.console('TRAINS Task: Closing old development task id={}'.format(default_task.get('id')))
            logger.report_text('TRAINS Task: Closing old development task id={}'.format(default_task.get('id')))
        # print warning, reusing/creating a task
        if default_task_id:
            logger.console('TRAINS Task: overwriting (reusing) task id=%s' % task.id)
            logger.report_text('TRAINS Task: overwriting (reusing) task id=%s' % task.id)
        else:
            logger.console('TRAINS Task: created new task id=%s' % task.id)
            logger.report_text('TRAINS Task: created new task id=%s' % task.id)

        # update current repository and put warning into logs
        if in_dev_mode and cls.__detect_repo_async:
@ -462,8 +463,8 @@ class Task(_Task):
            thread.start()
        return task
    @staticmethod
    def get_task(task_id=None, project_name=None, task_name=None):
    @classmethod
    def get_task(cls, task_id=None, project_name=None, task_name=None):
        """
        Returns a Task object based on either task_id (system uuid) or task name

@ -472,7 +473,7 @@ class Task(_Task):
        :param task_name: task name (str) within the selected project
        :return: Task() object
        """
        return Task.__get_task(task_id=task_id, project_name=project_name, task_name=task_name)
        return cls.__get_task(task_id=task_id, project_name=project_name, task_name=task_name)
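A quick sketch of the two lookup styles (the id and names below are placeholders):

from trains import Task

by_id = Task.get_task(task_id='aabbccddeeff00112233445566778899')
by_name = Task.get_task(project_name='examples', task_name='my experiment')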
    @property
    def output_uri(self):
@ -490,10 +491,14 @@ class Task(_Task):
    @property
    def artifacts(self):
        """
        dictionary of Task artifacts (name, artifact)
        read-only dictionary of Task artifacts (name, artifact)
        :return: dict
        """
        return self._artifacts_manager.artifacts
        if not Session.check_min_api_version('2.3'):
            return ReadOnlyDict()
        if not self.data.execution or not self.data.execution.artifacts:
            return ReadOnlyDict()
        return ReadOnlyDict([(a.key, Artifact(a)) for a in self.data.execution.artifacts])
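Reading the property is unchanged for callers; only the backing store moved to the backend execution record. A short sketch, reusing the task from the earlier sketch (assumes a server with api version >= 2.3):

artifacts = task.artifacts            # ReadOnlyDict mapping name -> Artifact
for name, artifact in artifacts.items():
    print(name, artifact)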
    def set_comment(self, comment):
        """
@ -553,6 +558,7 @@ class Task(_Task):
        raise Exception('Unsupported mutable type %s: no connect function found' % type(mutable).__name__)

    def get_logger(self, flush_period=NotSet):
        # type: (Optional[float]) -> Logger
        """
        Get a logger object for reporting, based on the task

@ -663,6 +669,15 @@ class Task(_Task):
        """
        self._artifacts_manager.unregister_artifact(name=name)
    def get_registered_artifacts(self):
        """
        dictionary of Task registered artifacts (name, artifact object)
        Notice: these objects can be modified; changes will be uploaded automatically

        :return: dict
        """
        return self._artifacts_manager.registered_artifacts

    def upload_artifact(self, name, artifact_object, metadata=None, delete_after_upload=False):
        """
        Add a static artifact to the Task. The artifact file/object will be uploaded in the background
@ -671,6 +686,7 @@ class Task(_Task):
        :param str name: Artifact name. Notice! it will override a previous artifact if the name already exists
        :param object artifact_object: Artifact object to upload. Currently supports:
            - string / pathlib2.Path are treated as path to artifact file to upload
              If a wildcard or a folder is passed, a zip file containing the local files will be created and uploaded.
            - dict will be stored as .json,
            - pandas.DataFrame will be stored as .csv.gz (compressed CSV file),
            - numpy.ndarray will be stored as .npz,
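A usage sketch covering the supported object types, again reusing the task from the earlier sketch (artifact names and values are illustrative):

import numpy as np
import pandas as pd

task.upload_artifact('config', {'lr': 0.01, 'batch': 32})        # stored as .json
task.upload_artifact('table', pd.DataFrame({'a': [1, 2, 3]}))    # stored as .csv.gz
task.upload_artifact('array', np.random.rand(10, 10))            # stored as .npz
task.upload_artifact('weights', '/tmp/weights.bin')              # uploads the local file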
@ -937,7 +953,7 @@ class Task(_Task):
        if self._at_exit_called:
            return

        self.get_logger().warn(
        self.log.warning(
            "### TASK STOPPED - USER ABORTED - {} ###".format(
                stop_reason.upper().replace('_', ' ')
            )
@ -1009,7 +1025,7 @@ class Task(_Task):
        # signal artifacts upload, and stop daemon
        self._artifacts_manager.stop(wait=True)
        # print artifacts summary
        self.get_logger().console(self._artifacts_manager.summary)
        self.get_logger().report_text(self._artifacts_manager.summary)

    def _at_exit(self):
        """
@ -1,6 +1,7 @@
from __future__ import absolute_import, division, print_function

import collections
import json
import re
import threading

@ -313,9 +314,11 @@ class CheckPackageUpdates(object):

        # noinspection PyBroadException
        try:
            from ..version import __version__
            cls._package_version_checked = True
            # Sending the request only for statistics
            update_statistics = threading.Thread(target=CheckPackageUpdates.get_version_from_updates_server)
            update_statistics = threading.Thread(target=CheckPackageUpdates.get_version_from_updates_server,
                                                 args=(__version__,))
            update_statistics.daemon = True
            update_statistics.start()
@ -323,7 +326,6 @@ class CheckPackageUpdates(object):

        releases = [Version(r) for r in releases]
        latest_version = sorted(releases)
        from ..version import __version__
        cur_version = Version(__version__)
        if not cur_version.is_devrelease and not cur_version.is_prerelease:
            latest_version = [r for r in latest_version if not r.is_devrelease and not r.is_prerelease]
@ -336,8 +338,9 @@ class CheckPackageUpdates(object):
            return None

    @staticmethod
    def get_version_from_updates_server():
    def get_version_from_updates_server(cur_version):
        try:
            _ = requests.get('https://updates.trainsai.io/updates', timeout=1.0)
            _ = requests.get('https://updates.trainsai.io/updates',
                             params=json.dumps({'versions': {'trains': str(cur_version)}}), timeout=1.0)
        except Exception:
            pass
@ -3,6 +3,19 @@
_epsilon = 0.00001


class ReadOnlyDict(dict):
    def __readonly__(self, *args, **kwargs):
        raise ValueError("This is a read only dictionary")
    __setitem__ = __readonly__
    __delitem__ = __readonly__
    pop = __readonly__
    popitem = __readonly__
    clear = __readonly__
    update = __readonly__
    setdefault = __readonly__
    del __readonly__


class Logs:
    _logs_instances = []
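Behavior sketch for the new ReadOnlyDict: lookups behave like a plain dict, while every mutating method raises:

d = ReadOnlyDict({'a': 1})
print(d['a'])  # 1 - reads work as usual
try:
    d['b'] = 2  # any mutator (update/pop/popitem/clear/setdefault too) raises
except ValueError as e:
    print(e)    # "This is a read only dictionary"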
@ -1,3 +1,5 @@
import logging
import warnings
from time import time
from threading import Thread, Event

@ -32,8 +34,8 @@ class ResourceMonitor(object):
        self._gpustat_fail = 0
        self._gpustat = gpustat
        if not self._gpustat:
            self._task.get_logger().console('TRAINS Monitor: GPU monitoring is not available, '
                                            'run "pip install gpustat"')
            self._task.get_logger().report_text('TRAINS Monitor: GPU monitoring is not available, '
                                                'run "pip install gpustat"')

    def start(self):
        self._exit_event.clear()
@ -73,8 +75,8 @@ class ResourceMonitor(object):
        if IsTensorboardInit.tensorboard_used():
            fallback_to_sec_as_iterations = False
        elif seconds_since_started >= self._wait_for_first_iteration:
            self._task.get_logger().console('TRAINS Monitor: Could not detect iteration reporting, '
                                            'falling back to iterations as seconds-from-start')
            self._task.get_logger().report_text('TRAINS Monitor: Could not detect iteration reporting, '
                                                'falling back to iterations as seconds-from-start')
            fallback_to_sec_as_iterations = True

        clear_readouts = True
@ -168,9 +170,11 @@ class ResourceMonitor(object):
        stats["memory_free_gb"] = bytes_to_megabytes(virtual_memory.available) / 1024
        disk_use_percentage = psutil.disk_usage(Text(Path.home())).percent
        stats["disk_free_percent"] = 100.0 - disk_use_percentage
        sensor_stat = (
            psutil.sensors_temperatures() if hasattr(psutil, "sensors_temperatures") else {}
        )
        with warnings.catch_warnings():
            if logging.root.level > logging.DEBUG:  # If the logging level is above debug, ignore
                # psutil.sensors_temperatures warnings
                warnings.simplefilter("ignore", category=RuntimeWarning)
            sensor_stat = (psutil.sensors_temperatures() if hasattr(psutil, "sensors_temperatures") else {})
        if "coretemp" in sensor_stat and len(sensor_stat["coretemp"]):
            stats["cpu_temperature"] = max([float(t.current) for t in sensor_stat["coretemp"]])

@ -197,8 +201,8 @@ class ResourceMonitor(object):
        # something happened and we can't use gpu stats,
        self._gpustat_fail += 1
        if self._gpustat_fail >= 3:
            self._task.get_logger().console('TRAINS Monitor: GPU monitoring failed getting GPU reading, '
                                            'switching off GPU monitoring')
            self._task.get_logger().report_text('TRAINS Monitor: GPU monitoring failed getting GPU reading, '
                                                'switching off GPU monitoring')
            self._gpustat = None

        return stats
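The catch_warnings change above is a self-contained pattern worth noting; an equivalent standalone sketch:

import logging
import warnings

import psutil

with warnings.catch_warnings():
    if logging.root.level > logging.DEBUG:
        # unless we are debugging, hide RuntimeWarning raised by sensors_temperatures
        warnings.simplefilter("ignore", category=RuntimeWarning)
    temps = psutil.sensors_temperatures() if hasattr(psutil, "sensors_temperatures") else {}
print(sorted(temps.keys()))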