From 0b4f00af4dd76983f5e0a60fb4992f9563576cab Mon Sep 17 00:00:00 2001
From: allegroai <>
Date: Mon, 23 Sep 2019 18:40:13 +0300
Subject: [PATCH] Restructured Logger with a nice, clean interface.

Breaking changes:
- Logger no longer supports info/error/warning
- console() replaced with report_text()
---
 examples/manual_reporting.py                  |   8 +-
 trains/backend_api/session/callresult.py      |   2 +-
 trains/backend_api/session/session.py         |   2 +-
 trains/backend_config/bucket_config.py        |   2 +-
 trains/backend_interface/logger.py            | 212 +++++
 trains/backend_interface/metrics/events.py    |   8 +-
 trains/backend_interface/task/access.py       |   4 +-
 .../task/development/stop_signal.py           |   2 +-
 trains/backend_interface/task/task.py         |  90 ++-
 trains/backend_interface/util.py              |   6 +-
 trains/binding/frameworks/tensorflow_bind.py  |  49 +-
 trains/binding/matplotlib_bind.py             |  16 +-
 trains/config/__init__.py                     |   2 +-
 trains/config/defs.py                         |   2 +-
 trains/logger.py                              | 749 ++++++------------
 trains/model.py                               |   5 +-
 trains/task.py                                |  44 +-
 trains/utilities/check_updates.py             |  11 +-
 trains/utilities/dicts.py                     |  13 +
 trains/utilities/resource_monitor.py          |  22 +-
 20 files changed, 602 insertions(+), 647 deletions(-)
 create mode 100644 trains/backend_interface/logger.py

diff --git a/examples/manual_reporting.py b/examples/manual_reporting.py
index 9692501f..4ef19f13 100644
--- a/examples/manual_reporting.py
+++ b/examples/manual_reporting.py
@@ -21,7 +21,7 @@ except ImportError:
 logger = Task.current_task().get_logger()

 # log text
-logger.console("hello")
+logger.report_text("hello")

 # report scalar values
 logger.report_scalar("example_scalar", "series A", iteration=0, value=100)
@@ -49,11 +49,11 @@ logger.report_scatter3d("example_scatter_3d", "series_xyz", iteration=1, scatter

 # reporting images
 m = np.eye(256, 256, dtype=np.float)
-logger.report_image_and_upload("test case", "image float", iteration=1, matrix=m)
+logger.report_image("test case", "image float", iteration=1, matrix=m)
 m = np.eye(256, 256, dtype=np.uint8)*255
-logger.report_image_and_upload("test case", "image uint8", iteration=1, matrix=m)
+logger.report_image("test case", "image uint8", iteration=1, matrix=m)
 m = np.concatenate((np.atleast_3d(m), np.zeros((256, 256, 2), dtype=np.uint8)), axis=2)
-logger.report_image_and_upload("test case", "image color red", iteration=1, matrix=m)
+logger.report_image("test case", "image color red", iteration=1, matrix=m)

 # flush reports (otherwise it will be flushed in the background, every couple of seconds)
 logger.flush()
diff --git a/trains/backend_api/session/callresult.py b/trains/backend_api/session/callresult.py
index 40bd5d1d..834f166e 100644
--- a/trains/backend_api/session/callresult.py
+++ b/trains/backend_api/session/callresult.py
@@ -81,7 +81,7 @@ class CallResult(object):
             # response.validate()
         except Exception as e:
             if logger:
-                logger.warn('Failed parsing response: %s' % str(e))
+                logger.warning('Failed parsing response: %s' % str(e))
         return cls(meta=meta, response=response, response_data=response_data, request_cls=request_cls, session=session)

     def ok(self):
diff --git a/trains/backend_api/session/session.py b/trains/backend_api/session/session.py
index b3e50d8c..5c8d5a2f 100644
--- a/trains/backend_api/session/session.py
+++ b/trains/backend_api/session/session.py
@@ -215,7 +215,7 @@ class Session(TokenManager):
                 res.status_code == requests.codes.service_unavailable
                 and self.config.get("api.http.wait_on_maintenance_forever", True)
             ):
-                self._logger.warn(
+                self._logger.warning(
                     "Service unavailable: {} is undergoing maintenance, retrying...".format(
                         host
                     )
                 )
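For callers migrating from the old interface, the renames amount to the following (a minimal sketch; the Task.init arguments are illustrative, the method names follow this patch):

```python
import logging
from trains import Task

task = Task.init(project_name='examples', task_name='logger migration')  # illustrative
logger = task.get_logger()

logger.report_text("hello")                           # was: logger.console("hello")
logger.report_text("careful", level=logging.WARNING)  # was: logger.warn("careful")
```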
diff --git a/trains/backend_config/bucket_config.py b/trains/backend_config/bucket_config.py
index 24380a78..7f644d67 100644
--- a/trains/backend_config/bucket_config.py
+++ b/trains/backend_config/bucket_config.py
@@ -50,7 +50,7 @@ class S3BucketConfig(object):
         configs = [cls(**entry) for entry in dict_list]
         valid_configs = [conf for conf in configs if conf.is_valid()]
         if log and len(valid_configs) < len(configs):
-            log.warn(
+            log.warning(
                 "Invalid bucket configurations detected for {}".format(
                     ", ".join(
                         "/".join((config.host, config.bucket))
diff --git a/trains/backend_interface/logger.py b/trains/backend_interface/logger.py
new file mode 100644
index 00000000..391921ae
--- /dev/null
+++ b/trains/backend_interface/logger.py
@@ -0,0 +1,212 @@
+import logging
+import sys
+import threading
+
+from ..backend_interface.task.development.worker import DevWorker
+from ..backend_interface.task.log import TaskHandler
+from ..config import running_remotely
+
+
+class StdStreamPatch(object):
+    _stdout_proxy = None
+    _stderr_proxy = None
+    _stdout_original_write = None
+
+    @staticmethod
+    def patch_std_streams(logger):
+        if DevWorker.report_stdout and not PrintPatchLogger.patched and not running_remotely():
+            StdStreamPatch._stdout_proxy = PrintPatchLogger(sys.stdout, logger, level=logging.INFO)
+            StdStreamPatch._stderr_proxy = PrintPatchLogger(sys.stderr, logger, level=logging.ERROR)
+            logger._task_handler = TaskHandler(logger._task.session, logger._task.id, capacity=100)
+            # noinspection PyBroadException
+            try:
+                if StdStreamPatch._stdout_original_write is None:
+                    StdStreamPatch._stdout_original_write = sys.stdout.write
+                # this will only work in python 3, guard it with try/catch
+                if not hasattr(sys.stdout, '_original_write'):
+                    sys.stdout._original_write = sys.stdout.write
+                sys.stdout.write = StdStreamPatch._stdout__patched__write__
+                if not hasattr(sys.stderr, '_original_write'):
+                    sys.stderr._original_write = sys.stderr.write
+                sys.stderr.write = StdStreamPatch._stderr__patched__write__
+            except Exception:
+                pass
+            sys.stdout = StdStreamPatch._stdout_proxy
+            sys.stderr = StdStreamPatch._stderr_proxy
+            # patch the base streams of sys (this way colorama will keep its ANSI colors)
+            # noinspection PyBroadException
+            try:
+                sys.__stderr__ = sys.stderr
+            except Exception:
+                pass
+            # noinspection PyBroadException
+            try:
+                sys.__stdout__ = sys.stdout
+            except Exception:
+                pass
+
+            # now check if we have loguru and make it re-register the handlers,
+            # because it stores the stream.write function internally, which we can't patch
+            # noinspection PyBroadException
+            try:
+                from loguru import logger
+                register_stderr = None
+                register_stdout = None
+                for k, v in logger._handlers.items():
+                    if v._name == '<stderr>':
+                        register_stderr = k
+                    elif v._name == '<stdout>':
+                        register_stdout = k
+                if register_stderr is not None:
+                    logger.remove(register_stderr)
+                    logger.add(sys.stderr)
+                if register_stdout is not None:
+                    logger.remove(register_stdout)
+                    logger.add(sys.stdout)
+            except Exception:
+                pass
+
+        elif DevWorker.report_stdout and not running_remotely():
+            logger._task_handler = TaskHandler(logger._task.session, logger._task.id, capacity=100)
+            if StdStreamPatch._stdout_proxy:
+                StdStreamPatch._stdout_proxy.connect(logger)
+            if StdStreamPatch._stderr_proxy:
+                StdStreamPatch._stderr_proxy.connect(logger)
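The patching above swaps sys.stdout/sys.stderr for proxies while keeping the original write, so output still reaches the terminal. Stripped to its essentials, the idea looks roughly like this (illustrative names, not the patch's code):

```python
import sys

class EchoProxy(object):
    def __init__(self, stream, emit):
        self._stream = stream
        self._emit = emit

    def write(self, message):
        self._stream.write(message)  # keep normal console behaviour
        self._emit(message)          # forward a copy to the logging side

    def __getattr__(self, attr):
        # delegate flush(), isatty(), encoding, ... to the real stream
        return getattr(self._stream, attr)

captured = []
sys.stdout = EchoProxy(sys.stdout, captured.append)
print('hello')                   # reaches the terminal *and* `captured`
sys.stdout = sys.stdout._stream  # restore the original stream
```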
+    @staticmethod
+    def remove_std_logger():
+        if isinstance(sys.stdout, PrintPatchLogger):
+            # noinspection PyBroadException
+            try:
+                sys.stdout.connect(None)
+            except Exception:
+                pass
+        if isinstance(sys.stderr, PrintPatchLogger):
+            # noinspection PyBroadException
+            try:
+                sys.stderr.connect(None)
+            except Exception:
+                pass
+
+    @staticmethod
+    def stdout_original_write(*args, **kwargs):
+        if StdStreamPatch._stdout_original_write:
+            StdStreamPatch._stdout_original_write(*args, **kwargs)
+
+    @staticmethod
+    def _stdout__patched__write__(*args, **kwargs):
+        if StdStreamPatch._stdout_proxy:
+            return StdStreamPatch._stdout_proxy.write(*args, **kwargs)
+        return sys.stdout._original_write(*args, **kwargs)
+
+    @staticmethod
+    def _stderr__patched__write__(*args, **kwargs):
+        if StdStreamPatch._stderr_proxy:
+            return StdStreamPatch._stderr_proxy.write(*args, **kwargs)
+        return sys.stderr._original_write(*args, **kwargs)
+
+
+class PrintPatchLogger(object):
+    """
+    Allows patching a stream into the logger.
+    Used for capturing and logging stdout and stderr when running in development mode pseudo worker.
+    """
+    patched = False
+    lock = threading.Lock()
+    recursion_protect_lock = threading.RLock()
+
+    def __init__(self, stream, logger=None, level=logging.INFO):
+        PrintPatchLogger.patched = True
+        self._terminal = stream
+        self._log = logger
+        self._log_level = level
+        self._cur_line = ''
+
+    def write(self, message):
+        # make sure that we do not end up in an infinite loop (i.e. log._console ends up calling us)
+        if self._log and not PrintPatchLogger.recursion_protect_lock._is_owned():
+            try:
+                self.lock.acquire()
+                with PrintPatchLogger.recursion_protect_lock:
+                    if hasattr(self._terminal, '_original_write'):
+                        self._terminal._original_write(message)
+                    else:
+                        self._terminal.write(message)
+
+                do_flush = '\n' in message
+                do_cr = '\r' in message
+                self._cur_line += message
+                if (not do_flush and not do_cr) or not message:
+                    return
+                last_lf = self._cur_line.rindex('\n' if do_flush else '\r')
+                next_line = self._cur_line[last_lf + 1:]
+                cur_line = self._cur_line[:last_lf + 1].rstrip()
+                self._cur_line = next_line
+            finally:
+                self.lock.release()
+
+            if cur_line:
+                with PrintPatchLogger.recursion_protect_lock:
+                    # noinspection PyBroadException
+                    try:
+                        if self._log:
+                            self._log._console(cur_line, level=self._log_level, omit_console=True)
+                    except Exception:
+                        # what can we do, nothing
+                        pass
+        else:
+            if hasattr(self._terminal, '_original_write'):
+                self._terminal._original_write(message)
+            else:
+                self._terminal.write(message)
+
+    def connect(self, logger):
+        self._cur_line = ''
+        self._log = logger
+
+    def __getattr__(self, attr):
+        if attr in ['_log', '_terminal', '_log_level', '_cur_line']:
+            return self.__dict__.get(attr)
+        return getattr(self._terminal, attr)
+
+    def __setattr__(self, key, value):
+        if key in ['_log', '_terminal', '_log_level', '_cur_line']:
+            self.__dict__[key] = value
+        else:
+            return setattr(self._terminal, key, value)
+
+
+class LogFlusher(threading.Thread):
+    def __init__(self, logger, period, **kwargs):
+        super(LogFlusher, self).__init__(**kwargs)
+        self.daemon = True
+
+        self._period = period
+        self._logger = logger
+        self._exit_event = threading.Event()
+
+    @property
+    def period(self):
+        return self._period
+
+    def run(self):
+        self._logger.flush()
+        # store original wait period
+        while True:
+            period = self._period
+            while not self._exit_event.wait(period or 1.0):
+                self._logger.flush()
+            # if period is negative or None we should exit
+            if self._period is None or self._period < 0:
+                break
+            # if period was changed, restart the wait
+            self._exit_event.clear()
+
+    def exit(self):
+        self._period = None
+        self._exit_event.set()
+
+    def set_period(self, period):
+        self._period = period
+        # make sure we exit the previous wait
+        self._exit_event.set()
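LogFlusher doubles its exit event as a wake-up signal: set_period() and exit() both set the event, and the run loop decides whether to restart the wait or return. Assuming the class is importable from trains.backend_interface.logger as added above, usage looks roughly like:

```python
from trains.backend_interface.logger import LogFlusher

class _FakeLogger(object):  # stand-in; any object with flush() works
    def flush(self):
        print('flushed')

flusher = LogFlusher(_FakeLogger(), period=5.0)
flusher.start()
flusher.set_period(1.0)  # wakes the wait, loop restarts with the new period
flusher.exit()           # period -> None, loop breaks, thread terminates
flusher.join()
```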
diff --git a/trains/backend_interface/metrics/events.py b/trains/backend_interface/metrics/events.py
index 5c990208..d72e88db 100644
--- a/trains/backend_interface/metrics/events.py
+++ b/trains/backend_interface/metrics/events.py
@@ -177,6 +177,8 @@ class UploadEvent(MetricsEventAdapter):

     def __init__(self, metric, variant, image_data, local_image_path=None, iter=0, upload_uri=None,
                  image_file_history_size=None, delete_after_upload=False, **kwargs):
+        # param override_filename: override the uploaded file name (notice: the extension is taken from the local path)
+        # param override_filename_ext: override the uploaded file extension
         if image_data is not None and not hasattr(image_data, 'shape'):
             raise ValueError('Image must have a shape attribute')
         self._image_data = image_data
@@ -197,8 +199,10 @@ class UploadEvent(MetricsEventAdapter):

         # get upload uri upfront, either predefined image format or local file extension
         # e.g.: image.png -> .png or image.raw.gz -> .raw.gz
-        image_format = self._format.lower() if self._image_data is not None else \
-            '.' + '.'.join(pathlib2.Path(self._local_image_path).parts[-1].split('.')[1:])
+        image_format = kwargs.pop('override_filename_ext', None)
+        if image_format is None:
+            image_format = self._format.lower() if self._image_data is not None else \
+                '.' + '.'.join(pathlib2.Path(self._local_image_path).parts[-1].split('.')[1:])
         self._upload_filename = str(pathlib2.Path(self._filename).with_suffix(image_format))

         self._override_storage_key_prefix = kwargs.pop('override_storage_key_prefix', None)
diff --git a/trains/backend_interface/task/access.py b/trains/backend_interface/task/access.py
index 3d19097c..c4a552f0 100644
--- a/trains/backend_interface/task/access.py
+++ b/trains/backend_interface/task/access.py
@@ -80,6 +80,6 @@ class AccessMixin(object):
             expected_num_of_classes += 1 if int(index) > 0 else 0
         num_of_classes = int(max(model_labels.values()))
         if num_of_classes != expected_num_of_classes:
-            self.log.warn('The highest label index is %d, while there are %d non-bg labels' %
-                          (num_of_classes, expected_num_of_classes))
+            self.log.warning('The highest label index is %d, while there are %d non-bg labels' %
+                             (num_of_classes, expected_num_of_classes))
         return num_of_classes + 1  # +1 is meant for bg!
diff --git a/trains/backend_interface/task/development/stop_signal.py b/trains/backend_interface/task/development/stop_signal.py
index 98af69a6..f4566ba3 100644
--- a/trains/backend_interface/task/development/stop_signal.py
+++ b/trains/backend_interface/task/development/stop_signal.py
@@ -51,7 +51,7 @@ class TaskStopSignal(object):
             if self._task_reset_state_counter >= self._number_of_consecutive_reset_tests:
                 return TaskStopReason.reset

-            self.task.get_logger().warning(
+            self.task.log.warning(
                 "Task {} was reset! If state is consistent we shall terminate.".format(self.task.id),
             )
         else:
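The fallback branch above derives the upload extension from everything after the first dot of the file name, so multi-part suffixes survive; for example (input path illustrative):

```python
import pathlib2

local_image_path = '/tmp/image.raw.gz'  # illustrative input
ext = '.' + '.'.join(pathlib2.Path(local_image_path).parts[-1].split('.')[1:])
print(ext)                                           # .raw.gz
print(pathlib2.Path('event_0001').with_suffix(ext))  # event_0001.raw.gz
```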
diff --git a/trains/backend_interface/task/task.py b/trains/backend_interface/task/task.py
index b6608d49..4bba695a 100644
--- a/trains/backend_interface/task/task.py
+++ b/trains/backend_interface/task/task.py
@@ -160,7 +160,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
                 conf = get_config_for_bucket(base_url=output_dest)
                 if not conf:
                     msg = 'Failed resolving output destination (no credentials found for %s)' % output_dest
-                    self.log.warn(msg)
+                    self.log.warning(msg)
                     if raise_errors:
                         raise Exception(msg)
                 else:
@@ -187,12 +187,12 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
                 latest_version = CheckPackageUpdates.check_new_package_available(only_once=True)
                 if latest_version:
                     if not latest_version[1]:
-                        self.get_logger().console(
+                        self.get_logger().report_text(
                             'TRAINS new package available: UPGRADE to v{} is recommended!'.format(
                                 latest_version[0]),
                         )
                     else:
-                        self.get_logger().console(
+                        self.get_logger().report_text(
                             'TRAINS-SERVER new version available: upgrade to v{} is recommended!'.format(
                                 latest_version[0]),
                         )
@@ -205,7 +205,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
             check_package_update_thread.start()
         result = ScriptInfo.get(log=self.log)
         for msg in result.warning_messages:
-            self.get_logger().console(msg)
+            self.get_logger().report_text(msg)

         self.data.script = result.script
         # Since we might run asynchronously, don't use self.data (lest someone else
@@ -418,16 +418,17 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):

     def update_model_desc(self, new_model_desc_file=None):
         """ Change the task's model_desc """
-        execution = self._get_task_property('execution')
-        p = Path(new_model_desc_file)
-        if not p.is_file():
-            raise IOError('mode_desc file %s cannot be found' % new_model_desc_file)
-        new_model_desc = p.read_text()
-        model_desc_key = list(execution.model_desc.keys())[0] if execution.model_desc else 'design'
-        execution.model_desc[model_desc_key] = new_model_desc
+        with self._edit_lock:
+            execution = self._get_task_property('execution')
+            p = Path(new_model_desc_file)
+            if not p.is_file():
+                raise IOError('model_desc file %s cannot be found' % new_model_desc_file)
+            new_model_desc = p.read_text()
+            model_desc_key = list(execution.model_desc.keys())[0] if execution.model_desc else 'design'
+            execution.model_desc[model_desc_key] = new_model_desc

-        res = self._edit(execution=execution)
-        return res.response
+            res = self._edit(execution=execution)
+            return res.response

     def update_output_model(self, model_uri, name=None, comment=None, tags=None):
         """
@@ -536,16 +537,17 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
             model = None
             model_id = ''

-        # store model id
-        self.data.execution.model = model_id
+        with self._edit_lock:
+            # store model id
+            self.data.execution.model = model_id

-        # Auto populate input field from model, if they are empty
-        if update_task_design and not self.data.execution.model_desc:
-            self.data.execution.model_desc = model.design if model else ''
-        if update_task_labels and not self.data.execution.model_labels:
-            self.data.execution.model_labels = model.labels if model else {}
+            # Auto populate input field from model, if they are empty
+            if update_task_design and not self.data.execution.model_desc:
+                self.data.execution.model_desc = model.design if model else ''
+            if update_task_labels and not self.data.execution.model_labels:
+                self.data.execution.model_labels = model.labels if model else {}

-        self._edit(execution=self.data.execution)
+            self._edit(execution=self.data.execution)
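Each of these edits is a read-modify-write on task.execution, so without the lock two threads could interleave and silently drop one update. The guarded shape, reduced to a runnable toy (module-level stand-ins for the task's attributes):

```python
import threading

_edit_lock = threading.RLock()
_execution = {'parameters': {}}  # stand-in for self.data.execution

def set_parameters(parameters):
    # values are cast to str so they can later be edited in the UI
    parameters = {k: str(v) if v is not None else "" for k, v in parameters.items()}
    with _edit_lock:  # serialize the read-modify-write
        _execution['parameters'] = parameters

set_parameters({'learning_rate': 0.001, 'epochs': 10})
print(_execution['parameters'])  # {'learning_rate': '0.001', 'epochs': '10'}
```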
     def set_parameters(self, *args, **kwargs):
         """
@@ -580,12 +582,13 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
         # force cast all variables to strings (so that we can later edit them in UI)
         parameters = {k: str(v) if v is not None else "" for k, v in parameters.items()}

-        execution = self.data.execution
-        if execution is None:
-            execution = tasks.Execution(parameters=parameters)
-        else:
-            execution.parameters = parameters
-        self._edit(execution=execution)
+        with self._edit_lock:
+            execution = self.data.execution
+            if execution is None:
+                execution = tasks.Execution(parameters=parameters)
+            else:
+                execution.parameters = parameters
+            self._edit(execution=execution)

     def set_parameter(self, name, value, description=None):
         """
@@ -630,14 +633,15 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
         :param dict enumeration: For example: {str(label): integer(id)}
         """
         enumeration = enumeration or {}
-        execution = self.data.execution
-        if enumeration is None:
-            return
-        if not (isinstance(enumeration, dict)
-                and all(isinstance(k, six.string_types) and isinstance(v, int) for k, v in enumeration.items())):
-            raise ValueError('Expected label to be a dict[str => int]')
-        execution.model_labels = enumeration
-        self._edit(execution=execution)
+        with self._edit_lock:
+            execution = self.data.execution
+            if enumeration is None:
+                return
+            if not (isinstance(enumeration, dict)
+                    and all(isinstance(k, six.string_types) and isinstance(v, int) for k, v in enumeration.items())):
+                raise ValueError('Expected label to be a dict[str => int]')
+            execution.model_labels = enumeration
+            self._edit(execution=execution)

     def set_artifacts(self, artifacts_list=None):
         """
@@ -650,16 +654,18 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
         if not (isinstance(artifacts_list, (list, tuple))
                 and all(isinstance(a, tasks.Artifact) for a in artifacts_list)):
             raise ValueError('Expected artifacts to [tasks.Artifacts]')
-        execution = self.data.execution
-        execution.artifacts = artifacts_list
-        self._edit(execution=execution)
+        with self._edit_lock:
+            execution = self.data.execution
+            execution.artifacts = artifacts_list
+            self._edit(execution=execution)

     def _set_model_design(self, design=None):
-        execution = self.data.execution
-        if design is not None:
-            execution.model_desc = Model._wrap_design(design)
+        with self._edit_lock:
+            execution = self.data.execution
+            if design is not None:
+                execution.model_desc = Model._wrap_design(design)

-        self._edit(execution=execution)
+            self._edit(execution=execution)

     def get_labels_enumeration(self):
         """
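The enumeration setter enforces a plain str-to-int mapping; the check it applies is equivalent to:

```python
import six

def valid_enumeration(enumeration):
    # mirrors the validation above: dict of str -> int
    return (isinstance(enumeration, dict)
            and all(isinstance(k, six.string_types) and isinstance(v, int)
                    for k, v in enumeration.items()))

assert valid_enumeration({'background': 0, 'cat': 1, 'dog': 2})
assert not valid_enumeration({'cat': '1'})  # values must be ints, not strings
```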
diff --git a/trains/backend_interface/util.py b/trains/backend_interface/util.py
index 5660c5b3..09afe664 100644
--- a/trains/backend_interface/util.py
+++ b/trains/backend_interface/util.py
@@ -36,8 +36,8 @@ def get_single_result(entity, query, results, log=None, show_results=10, raise_o
         log = get_logger()

     if len(results) > 1:
-        log.warn('More than one {entity} found when searching for `{query}`'
-                 ' (showing first {show_results} {entity}s follow)'.format(**locals()))
+        log.warning('More than one {entity} found when searching for `{query}`'
+                    ' (showing first {show_results} {entity}s follow)'.format(**locals()))
         if sort_by_date:
             # sort results based on timestamp and return the newest one
             if hasattr(results[0], 'last_update'):
@@ -49,7 +49,7 @@ def get_single_result(entity, query, results, log=None, show_results=10, raise_o
         for i, obj in enumerate(o if isinstance(o, dict) else o.to_dict() for o in results[:show_results]):
             selected = 'Selected' if i == 0 else 'Additionally found'
-            log.warn('{selected} {entity} `{obj[name]}` (id={obj[id]})'.format(**locals()))
+            log.warning('{selected} {entity} `{obj[name]}` (id={obj[id]})'.format(**locals()))

     if raise_on_error:
         raise ValueError('More than one {entity}s found when searching for ``{query}`'.format(**locals()))
diff --git a/trains/binding/frameworks/tensorflow_bind.py b/trains/binding/frameworks/tensorflow_bind.py
index dbf748ce..6385e214 100644
--- a/trains/binding/frameworks/tensorflow_bind.py
+++ b/trains/binding/frameworks/tensorflow_bind.py
@@ -5,13 +5,13 @@ import threading
 from collections import defaultdict
 from functools import partial
 from io import BytesIO
-from logging import ERROR, WARNING, getLogger
 from typing import Any

 import numpy as np
 from PIL import Image

-from ..frameworks import _patched_call, WeightsFileHandler, _Empty, TrainsFrameworkAdapter
+from ...debugging.log import LoggerRoot
+from ..frameworks import _patched_call, WeightsFileHandler, _Empty
 from ..import_bind import PostImportHookPatching
 from ...config import running_remotely
 from ...model import InputModel, OutputModel, Framework
@@ -187,7 +187,8 @@ class EventTrainsWriter(object):
                 else:
                     val = val[:, :, [0, 1, 2]]
             except Exception:
-                self._logger.warning('Failed decoding debug image [%d, %d, %d]' % (width, height, color_channels))
+                LoggerRoot.get_base_logger().warning('Failed decoding debug image [%d, %d, %d]'
+                                                     % (width, height, color_channels))
                 val = None
             return val
@@ -213,7 +214,7 @@ class EventTrainsWriter(object):
             tile_size = res.shape[0] * res.shape[1]
             img_data_np = res.reshape(tile_size, tile_size, -1)

-        self._logger.report_image_and_upload(
+        self._logger.report_image(
             title=title,
             series=series,
             iteration=step,
@@ -419,7 +420,7 @@ class EventTrainsWriter(object):
             msg_dict.pop('wallTime', None)
             keys_list = [key for key in msg_dict.keys() if len(key) > 0]
             keys_list = ', '.join(keys_list)
-            self._logger.debug('event summary not found, message type unsupported: %s' % keys_list)
+            LoggerRoot.get_base_logger().debug('event summary not found, message type unsupported: %s' % keys_list)
             return
         value_dicts = summary.get('value')
         walltime = walltime or msg_dict.get('step')
@@ -431,19 +432,20 @@ class EventTrainsWriter(object):
             step = int(event.step)
         else:
             step = 0
-            self._logger.debug('Recieved event without step, assuming step = {}'.format(step), WARNING)
+            LoggerRoot.get_base_logger().debug('Received event without step, assuming step = {}'.format(step))
         else:
             step = int(step)
         self._max_step = max(self._max_step, step)
         if value_dicts is None:
-            self._logger.debug("Summary with arrived without 'value'", ERROR)
+            LoggerRoot.get_base_logger().debug("Summary arrived without 'value'")
             return

         for vdict in value_dicts:
             tag = vdict.pop('tag', None)
             if tag is None:
                 # we should not get here
-                self._logger.debug('No tag for \'value\' existing keys %s' % ', '.join(vdict.keys()))
+                LoggerRoot.get_base_logger().debug('No tag for \'value\' existing keys %s'
+                                                   % ', '.join(vdict.keys()))
                 continue
             metric, values = get_data(vdict, supported_metrics)
             if metric == 'simpleValue':
@@ -459,7 +461,8 @@ class EventTrainsWriter(object):
             elif metric == 'tensor' and values.get('dtype') == 'DT_FLOAT':
                 self._add_plot(tag, step, values, vdict)
             else:
-                self._logger.debug('Event unsupported. tag = %s, vdict keys [%s]' % (tag, ', '.join(vdict.keys)))
+                LoggerRoot.get_base_logger().debug('Event unsupported. tag = %s, vdict keys [%s]'
+                                                   % (tag, ', '.join(vdict.keys())))
                 continue

     def get_logdir(self):
@@ -589,7 +592,7 @@ class PatchSummaryToEventTransformer(object):
                 setattr(SummaryToEventTransformer, 'trains',
                         property(PatchSummaryToEventTransformer.trains_object))
             except Exception as ex:
-                getLogger(TrainsFrameworkAdapter).debug(str(ex))
+                LoggerRoot.get_base_logger().debug(str(ex))

         if 'torch' in sys.modules:
             try:
@@ -603,7 +606,7 @@ class PatchSummaryToEventTransformer(object):
                 # this is a new version of TensorflowX
                 pass
             except Exception as ex:
-                getLogger(TrainsFrameworkAdapter).debug(str(ex))
+                LoggerRoot.get_base_logger().debug(str(ex))

         if 'tensorboardX' in sys.modules:
             try:
@@ -619,7 +622,7 @@ class PatchSummaryToEventTransformer(object):
                 # this is a new version of TensorflowX
                 pass
             except Exception as ex:
-                getLogger(TrainsFrameworkAdapter).debug(str(ex))
+                LoggerRoot.get_base_logger().debug(str(ex))

         if PatchSummaryToEventTransformer.__original_getattributeX is None:
             try:
@@ -633,7 +636,7 @@ class PatchSummaryToEventTransformer(object):
                 # this is a new version of TensorflowX
                 pass
             except Exception as ex:
-                getLogger(TrainsFrameworkAdapter).debug(str(ex))
+                LoggerRoot.get_base_logger().debug(str(ex))

     @staticmethod
     def _patched_add_eventT(self, *args, **kwargs):
@@ -717,7 +720,7 @@ class _ModelAdapter(object):
         super(_ModelAdapter, self).__init__()
         super(_ModelAdapter, self).__setattr__('_model', model)
         super(_ModelAdapter, self).__setattr__('_output_model', output_model)
-        super(_ModelAdapter, self).__setattr__('_logger', getLogger('TrainsModelAdapter'))
+        super(_ModelAdapter, self).__setattr__('_logger', LoggerRoot.get_base_logger())

     def __getattr__(self, attr):
         return getattr(self._model, attr)
@@ -800,7 +803,7 @@ class PatchModelCheckPointCallback(object):
                         property(PatchModelCheckPointCallback.trains_object))

             except Exception as ex:
-                getLogger(TrainsFrameworkAdapter).warning(str(ex))
+                LoggerRoot.get_base_logger().warning(str(ex))

     @staticmethod
     def _patched_getattribute(self, attr):
@@ -878,7 +881,7 @@ class PatchTensorFlowEager(object):
             except ImportError:
                 pass
             except Exception as ex:
-                getLogger(TrainsFrameworkAdapter).debug(str(ex))
+                LoggerRoot.get_base_logger().debug(str(ex))

     @staticmethod
     def _get_event_writer(writer):
@@ -905,7 +908,7 @@ class PatchTensorFlowEager(object):
             try:
                 event_writer._add_scalar(tag=str(tag), step=int(step.numpy()), scalar_data=value.numpy())
             except Exception as ex:
-                getLogger(TrainsFrameworkAdapter).warning(str(ex))
+                LoggerRoot.get_base_logger().warning(str(ex))
         return PatchTensorFlowEager.__original_fn_scalar(writer, step, tag, value, name, **kwargs)

     @staticmethod
@@ -915,7 +918,7 @@ class PatchTensorFlowEager(object):
             try:
                 event_writer._add_histogram(tag=str(tag), step=int(step.numpy()), histo_data=values.numpy())
             except Exception as ex:
-                getLogger(TrainsFrameworkAdapter).warning(str(ex))
+                LoggerRoot.get_base_logger().warning(str(ex))
         return PatchTensorFlowEager.__original_fn_hist(writer, step, tag, values, name, **kwargs)

     @staticmethod
@@ -926,7 +929,7 @@ class PatchTensorFlowEager(object):
                 event_writer._add_image_numpy(tag=str(tag), step=int(step.numpy()), img_data_np=tensor.numpy(),
                                               max_keep_images=max_images)
             except Exception as ex:
-                getLogger(TrainsFrameworkAdapter).warning(str(ex))
+                LoggerRoot.get_base_logger().warning(str(ex))
         return PatchTensorFlowEager.__original_fn_image(writer, step, tag, tensor, bad_color, max_images,
                                                         name, **kwargs)

@@ -1024,7 +1027,7 @@ class PatchKerasModelIO(object):
             keras_saving.save_model = _patched_call(keras_saving.save_model, PatchKerasModelIO._save_model)
             keras_saving.load_model = _patched_call(keras_saving.load_model, PatchKerasModelIO._load_model)
         except Exception as ex:
-            getLogger(TrainsFrameworkAdapter).warning(str(ex))
+            LoggerRoot.get_base_logger().warning(str(ex))

     @staticmethod
     def _updated_config(original_fn, self):
@@ -1052,7 +1055,7 @@ class PatchKerasModelIO(object):
                 framework=Framework.keras,
             )
         except Exception as ex:
-            getLogger(TrainsFrameworkAdapter).warning(str(ex))
+            LoggerRoot.get_base_logger().warning(str(ex))

         return config
@@ -1102,7 +1105,7 @@ class PatchKerasModelIO(object):
             return model
         except Exception as ex:
-            getLogger(TrainsFrameworkAdapter).warning(str(ex))
+            LoggerRoot.get_base_logger().warning(str(ex))

         return self
@@ -1184,7 +1187,7 @@ class PatchKerasModelIO(object):
             # if anyone asks, we were here
             self.trains_out_model._processed = True
         except Exception as ex:
-            getLogger(TrainsFrameworkAdapter).warning(str(ex))
+            LoggerRoot.get_base_logger().warning(str(ex))

     @staticmethod
     def _save_model(original_fn, model, filepath, *args, **kwargs):
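All of these bindings go through the _patched_call helper, which wraps an original function so the patch receives it as the first argument. The general shape of the pattern is roughly this (a sketch of the idea, not the helper's exact code):

```python
def patched_call(original_fn, patched_fn):
    def _inner(*args, **kwargs):
        # hand the original through so the patch can call it
        return patched_fn(original_fn, *args, **kwargs)
    return _inner

def _save_model(original_fn, model, filepath, *args, **kwargs):
    print('saving model to %s' % filepath)  # a reporting hook would go here
    return original_fn(model, filepath, *args, **kwargs)

# e.g. keras_saving.save_model = patched_call(keras_saving.save_model, _save_model)
```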
diff --git a/trains/binding/matplotlib_bind.py b/trains/binding/matplotlib_bind.py
index e3e01c2b..c8483fc8 100644
--- a/trains/binding/matplotlib_bind.py
+++ b/trains/binding/matplotlib_bind.py
@@ -329,19 +329,19 @@ class PatchedMatplotlib:
                     PatchedMatplotlib._global_image_counter += 1
                     title = plot_title or 'untitled %d' % PatchedMatplotlib._global_image_counter
-                    logger.report_image_and_upload(title=title, series='plot image', path=image,
-                                                   delete_after_upload=True,
-                                                   iteration=PatchedMatplotlib._global_image_counter
-                                                   if plot_title else 0)
+                    logger.report_image(title=title, series='plot image', local_path=image,
+                                        delete_after_upload=True,
+                                        iteration=PatchedMatplotlib._global_image_counter
+                                        if plot_title else 0)
                 else:
                     # send the plot as plotly with embedded image
                     PatchedMatplotlib._global_plot_counter += 1
                     title = plot_title or 'untitled %d' % PatchedMatplotlib._global_plot_counter
-                    logger.report_image_plot_and_upload(title=title, series='plot image', path=image,
-                                                        delete_after_upload=True,
-                                                        iteration=PatchedMatplotlib._global_plot_counter
-                                                        if plot_title else 0)
+                    logger._report_image_plot_and_upload(title=title, series='plot image', path=image,
+                                                         delete_after_upload=True,
+                                                         iteration=PatchedMatplotlib._global_plot_counter
+                                                         if plot_title else 0)

         except Exception:
             # plotly failed
diff --git a/trains/config/__init__.py b/trains/config/__init__.py
index f11e0eb2..0647c550 100644
--- a/trains/config/__init__.py
+++ b/trains/config/__init__.py
@@ -18,7 +18,7 @@ def get_cache_dir():
     cache_base_dir = Path(
         expandvars(
             expanduser(
-                config.get("storage.cache.default_base_dir") or DEFAULT_CACHE_DIR
+                TRAINS_CACHE_DIR.get() or config.get("storage.cache.default_base_dir") or DEFAULT_CACHE_DIR
             )
         )
     )
diff --git a/trains/config/defs.py b/trains/config/defs.py
index b13f94bd..acc1f55c 100644
--- a/trains/config/defs.py
+++ b/trains/config/defs.py
@@ -7,7 +7,6 @@ from pathlib2 import Path
 SESSION_CACHE_FILE = ".session.json"
 DEFAULT_CACHE_DIR = str(Path(tempfile.gettempdir()) / "trains_cache")
-
 TASK_ID_ENV_VAR = EnvEntry("TRAINS_TASK_ID", "ALG_TASK_ID")
 LOG_TO_BACKEND_ENV_VAR = EnvEntry("TRAINS_LOG_TASK_TO_BACKEND", "ALG_LOG_TASK_TO_BACKEND", type=bool)
 NODE_ID_ENV_VAR = EnvEntry("TRAINS_NODE_ID", "ALG_NODE_ID", type=int)
@@ -16,6 +15,7 @@ LOG_STDERR_REDIRECT_LEVEL = EnvEntry("TRAINS_LOG_STDERR_REDIRECT_LEVEL", "ALG_LO
 DEV_WORKER_NAME = EnvEntry("TRAINS_WORKER_NAME", "ALG_WORKER_NAME")
 DEV_TASK_NO_REUSE = EnvEntry("TRAINS_TASK_NO_REUSE", "ALG_TASK_NO_REUSE", type=bool)
 TASK_LOG_ENVIRONMENT = EnvEntry("TRAINS_LOG_ENVIRONMENT", "ALG_LOG_ENVIRONMENT", type=str)
+TRAINS_CACHE_DIR = EnvEntry("TRAINS_CACHE_DIR", "ALG_CACHE_DIR")

 LOG_LEVEL_ENV_VAR = EnvEntry("TRAINS_LOG_LEVEL", "ALG_LOG_LEVEL", converter=or_(int, str))
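With the new EnvEntry, the cache location resolves in order: the TRAINS_CACHE_DIR (or ALG_CACHE_DIR) environment variable, then the storage.cache.default_base_dir config value, then the tempdir default. For example (the assumption here is that the variable is set before the cache directory is first resolved):

```python
import os
os.environ['TRAINS_CACHE_DIR'] = '/data/trains_cache'  # set before first use

from trains.config import get_cache_dir
print(get_cache_dir())  # resolves under /data/trains_cache
```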
diff --git a/trains/logger.py b/trains/logger.py
index 7a45ab51..edc12c76 100644
--- a/trains/logger.py
+++ b/trains/logger.py
@@ -1,12 +1,9 @@
 import logging
-import re
-import sys
-import threading
-from functools import wraps

 import numpy as np
 from pathlib2 import Path

+from .backend_interface.logger import StdStreamPatch, LogFlusher
 from .debugging.log import LoggerRoot
 from .backend_interface.task.development.worker import DevWorker
 from .backend_interface.task.log import TaskHandler
@@ -17,36 +14,6 @@ from .backend_interface.task import Task as _Task
 from .config import running_remotely, get_cache_dir


-def _safe_names(func):
-    """
-    Validate the form of title and series parameters.
-
-    This decorator assert that a method receives 'title' and 'series' as its
-    first positional arguments, and that their values have only legal characters.
-
-    '\', '/' and ':' will be replaced automatically by '_'
-    Whitespace chars will be replaced automatically by ' '
-    """
-    _replacements = {
-        '_': re.compile(r"[/\\:]"),
-        ' ': re.compile(r"[\s]"),
-    }
-
-    def _make_safe(value):
-        for repl, regex in _replacements.items():
-            value = regex.sub(repl, value)
-        return value
-
-    @wraps(func)
-    def fixed_names(self, title, series, *args, **kwargs):
-        title = _make_safe(title)
-        series = _make_safe(series)
-
-        func(self, title, series, *args, **kwargs)
-
-    return fixed_names
-
-
 class Logger(object):
     """
     Console log and metric statistics interface.
@@ -56,9 +23,6 @@ class Logger(object):
     **Usage:** :func:`Logger.current_logger` or :func:`Task.get_logger`
     """
     SeriesInfo = SeriesInfo
-    _stdout_proxy = None
-    _stderr_proxy = None
-    _stdout_original_write = None

     def __init__(self, private_task):
         """
@@ -75,67 +39,11 @@ class Logger(object):
         self._report_worker = None
         self._task_handler = None

-        if DevWorker.report_stdout and not PrintPatchLogger.patched and not running_remotely():
-            Logger._stdout_proxy = PrintPatchLogger(sys.stdout, self, level=logging.INFO)
-            Logger._stderr_proxy = PrintPatchLogger(sys.stderr, self, level=logging.ERROR)
-            self._task_handler = TaskHandler(self._task.session, self._task.id, capacity=100)
-            # noinspection PyBroadException
-            try:
-                if Logger._stdout_original_write is None:
-                    Logger._stdout_original_write = sys.stdout.write
-                # this will only work in python 3, guard it with try/catch
-                if not hasattr(sys.stdout, '_original_write'):
-                    sys.stdout._original_write = sys.stdout.write
-                sys.stdout.write = stdout__patched__write__
-                if not hasattr(sys.stderr, '_original_write'):
-                    sys.stderr._original_write = sys.stderr.write
-                sys.stderr.write = stderr__patched__write__
-            except Exception:
-                pass
-            sys.stdout = Logger._stdout_proxy
-            sys.stderr = Logger._stderr_proxy
-            # patch the base streams of sys (this way colorama will keep its ANSI colors)
-            # noinspection PyBroadException
-            try:
-                sys.__stderr__ = sys.stderr
-            except Exception:
-                pass
-            # noinspection PyBroadException
-            try:
-                sys.__stdout__ = sys.stdout
-            except Exception:
-                pass
-
-            # now check if we have loguru and make it re-register the handlers
-            # because it sores internally the stream.write function, which we cant patch
-            # noinspection PyBroadException
-            try:
-                from loguru import logger
-                register_stderr = None
-                register_stdout = None
-                for k, v in logger._handlers.items():
-                    if v._name == '<stderr>':
-                        register_stderr = k
-                    elif v._name == '<stdout>':
-                        register_stderr = k
-                if register_stderr is not None:
-                    logger.remove(register_stderr)
-                    logger.add(sys.stderr)
-                if register_stdout is not None:
-                    logger.remove(register_stdout)
-                    logger.add(sys.stdout)
-            except Exception:
-                pass
-
-        elif DevWorker.report_stdout and not running_remotely():
-            self._task_handler = TaskHandler(self._task.session, self._task.id, capacity=100)
-            if Logger._stdout_proxy:
-                Logger._stdout_proxy.connect(self)
-            if Logger._stderr_proxy:
-                Logger._stderr_proxy.connect(self)
+        StdStreamPatch.patch_std_streams(self)

     @classmethod
     def current_logger(cls):
+        # type: () -> Logger
         """
         Return a logger object for the current task. Can be called from anywhere in the code

@@ -147,92 +55,24 @@ class Logger(object):
             return None
         return task.get_logger()

-    def console(self, msg, level=logging.INFO, omit_console=False, *args, **kwargs):
+    def report_text(self, msg, level=logging.INFO, print_console=True, *args, **_):
         """
-        print text to log (same as print to console, and also prints to console)
+        Print text to log, and optionally also print to the console

-        :param msg: text to print to the console (always send to the backend and displayed in console)
-        :param level: logging level, default: logging.INFO
-        :param omit_console: If True we only send 'msg' to log (no console print)
+        :param str msg: text to print to the console (always send to the backend and displayed in console)
+        :param int level: logging level, default: logging.INFO
+        :param bool print_console: If True we also print 'msg' to the console
         """
-        try:
-            level = int(level)
-        except (TypeError, ValueError):
-            self._task.log.log(level=logging.ERROR,
-                               msg='Logger failed casting log level "%s" to integer' % str(level))
-            level = logging.INFO
-
-        if not running_remotely():
-            # noinspection PyBroadException
-            try:
-                record = self._task.log.makeRecord(
-                    "console", level=level, fn='', lno=0, func='', msg=msg, args=args, exc_info=None
-                )
-                # find the task handler that matches our task
-                if not self._task_handler:
-                    self._task_handler = [h for h in LoggerRoot.get_base_logger().handlers
-                                          if isinstance(h, TaskHandler) and h.task_id == self._task.id][0]
-                self._task_handler.emit(record)
-            except Exception:
-                LoggerRoot.get_base_logger().warning(msg='Logger failed sending log: [level %s]: "%s"'
-                                                     % (str(level), str(msg)))
-
-        if not omit_console:
-            # if we are here and we grabbed the stdout, we need to print the real thing
-            if DevWorker.report_stdout and not running_remotely():
-                # noinspection PyBroadException
-                try:
-                    # make sure we are writing to the original stdout
-                    Logger._stdout_original_write(str(msg)+'\n')
-                except Exception:
-                    pass
-            else:
-                print(str(msg))
-
-        # if task was not started, we have to start it
-        self._start_task_if_needed()
-
-    def report_text(self, msg, level=logging.INFO, print_console=False, *args, **_):
-        return self.console(msg, level, not print_console, *args, **_)
-
-    def debug(self, msg, *args, **kwargs):
-        """ Print information to the log. This is the same as console(msg, logging.DEBUG) """
-        self._task.log.log(msg=msg, level=logging.DEBUG, *args, **kwargs)
-
-    def info(self, msg, *args, **kwargs):
-        """ Print information to the log. This is the same as console(msg, logging.INFO) """
-        self._task.log.log(msg=msg, level=logging.INFO, *args, **kwargs)
-
-    def warn(self, msg, *args, **kwargs):
-        """ Print a warning to the log. This is the same as console(msg, logging.WARNING) """
-        self._task.log.log(msg=msg, level=logging.WARNING, *args, **kwargs)
-
-    warning = warn
-
-    def error(self, msg, *args, **kwargs):
-        """ Print an error to the log. This is the same as console(msg, logging.ERROR) """
-        self._task.log.log(msg=msg, level=logging.ERROR, *args, **kwargs)
-
-    def fatal(self, msg, *args, **kwargs):
-        """ Print a fatal error to the log. This is the same as console(msg, logging.FATAL) """
-        self._task.log.log(msg=msg, level=logging.FATAL, *args, **kwargs)
-
-    def critical(self, msg, *args, **kwargs):
-        """ Print a critical error to the log. This is the same as console(msg, logging.CRITICAL) """
-        self._task.log.log(msg=msg, level=logging.CRITICAL, *args, **kwargs)
+        return self._console(msg, level, not print_console, *args, **_)
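Note that print_console now defaults to True, so a bare report_text() both logs and prints; pass print_console=False to send a line to the task log only:

```python
import logging
from trains import Task

logger = Task.current_task().get_logger()  # assumes a task was already initialized
logger.report_text('stored in the task log only', print_console=False)
logger.report_text('logged and echoed to console', level=logging.DEBUG)
```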
     def report_scalar(self, title, series, value, iteration):
         """
         Report a scalar value

-        :param title: Title (AKA metric)
-        :type title: str
-        :param series: Series (AKA variant)
-        :type series: str
-        :param value: Reported value
-        :type value: float
-        :param iteration: Iteration number
-        :type value: int
+        :param str title: Title (AKA metric)
+        :param str series: Series (AKA variant)
+        :param float value: Reported value
+        :param int iteration: Iteration number
         """

         # if task was not started, we have to start it
@@ -244,18 +84,12 @@
         """
         Report a histogram plot

-        :param title: Title (AKA metric)
-        :type title: str
-        :param series: Series (AKA variant)
-        :type series: str
-        :param values: Reported values (or numpy array)
-        :type values: [float]
-        :param iteration: Iteration number
-        :type iteration: int
-        :param labels: optional, labels for each bar group.
-        :type labels: list of strings.
-        :param xlabels: optional label per entry in the vector (bucket in the histogram)
-        :type xlabels: list of strings.
+        :param str title: Title (AKA metric)
+        :param str series: Series (AKA variant)
+        :param list(float) values: Reported values (or numpy array)
+        :param int iteration: Iteration number
+        :param list(str) labels: optional, labels for each bar group.
+        :param list(str) xlabels: optional label per entry in the vector (bucket in the histogram)
         """
         return self.report_histogram(title, series, values, iteration, labels=labels, xlabels=xlabels)
@@ -263,18 +97,12 @@
         """
         Report a histogram plot

-        :param title: Title (AKA metric)
-        :type title: str
-        :param series: Series (AKA variant)
-        :type series: str
-        :param values: Reported values (or numpy array)
-        :type values: [float]
-        :param iteration: Iteration number
-        :type iteration: int
-        :param labels: optional, labels for each bar group.
-        :type labels: list of strings.
-        :param xlabels: optional label per entry in the vector (bucket in the histogram)
-        :type xlabels: list of strings.
+        :param str title: Title (AKA metric)
+        :param str series: Series (AKA variant)
+        :param list(float) values: Reported values (or numpy array)
+        :param int iteration: Iteration number
+        :param list(str) labels: optional, labels for each bar group.
+        :param list(str) xlabels: optional label per entry in the vector (bucket in the histogram)
         """

         if not isinstance(values, np.ndarray):
@@ -292,24 +120,19 @@
             xlabels=xlabels,
         )

-    def report_line_plot(self, title, series, iteration, xaxis, yaxis, mode='lines', reverse_xaxis=False, comment=None):
+    def report_line_plot(self, title, series, iteration, xaxis, yaxis, mode='lines',
+                         reverse_xaxis=False, comment=None):
         """
         Report a (possibly multiple) line plot.

-        :param title: Title (AKA metric)
-        :type title: str
-        :param series: All the series' data, one for each line in the plot.
-        :type series: An iterable of LineSeriesInfo.
-        :param iteration: Iteration number
-        :type iteration: int
-        :param xaxis: optional x-axis title
-        :param yaxis: optional y-axis title
-        :param mode: scatter plot with 'lines'/'markers'/'lines+markers'
-        :type mode: str
-        :param reverse_xaxis: If true X axis will be displayed from high to low (reversed)
-        :type reverse_xaxis: bool
-        :param comment: comment underneath the title
-        :type comment: str
+        :param str title: Title (AKA metric)
+        :param list(LineSeriesInfo) series: All the series' data, one for each line in the plot.
+        :param int iteration: Iteration number
+        :param str xaxis: optional x-axis title
+        :param str yaxis: optional y-axis title
+        :param str mode: scatter plot with 'lines'/'markers'/'lines+markers'
+        :param bool reverse_xaxis: If True the X axis will be displayed from high to low (reversed)
+        :param str comment: comment underneath the title
         """

         series = [self.SeriesInfo(**s) if isinstance(s, dict) else s for s in series]
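A typical call against the tightened signatures, using report_histogram (values may be a list or a numpy array; the logger is obtained as in the earlier examples):

```python
import numpy as np

logger.report_histogram(
    'test histogram', 'series A',
    values=np.random.randint(10, size=10),
    iteration=1,
    xlabels=[str(i) for i in range(10)])
```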
@@ -333,21 +156,15 @@
         """
         Report a 2d scatter graph (with lines)

-        :param title: Title (AKA metric)
-        :type title: str
-        :param series: Series (AKA variant)
-        :type series: str
-        :param scatter: A scattered data: list of (pairs of x,y) (or numpy array)
-        :type scatter: ndarray or list
-        :param iteration: Iteration number
-        :type iteration: int
-        :param xaxis: optional x-axis title
-        :param yaxis: optional y-axis title
-        :param labels: label (text) per point in the scatter (in the same order)
-        :param mode: scatter plot with 'lines'/'markers'/'lines+markers'
-        :type mode: str
-        :param comment: comment underneath the title
-        :type comment: str
+        :param str title: Title (AKA metric)
+        :param str series: Series (AKA variant)
+        :param np.ndarray scatter: The scatter data: a list of (x, y) pairs (or a numpy array)
+        :param int iteration: Iteration number
+        :param str xaxis: optional x-axis title
+        :param str yaxis: optional y-axis title
+        :param list(str) labels: label (text) per point in the scatter (in the same order)
+        :param str mode: scatter plot with 'lines'/'markers'/'lines+markers'
+        :param str comment: comment underneath the title
         """

         if not isinstance(scatter, np.ndarray):
@@ -375,18 +192,15 @@
         """
         Report a 3d scatter graph (with markers)

-        :param title: Title (AKA metric)
-        :type title: str
-        :param series: Series (AKA variant)
-        :type series: str
-        :param scatter: A scattered data: list of (pairs of x,y,z) (or numpy array) or list of series [[(x1,y1,z1)...]]
-        :type scatter: ndarray or list
-        :param iteration: Iteration number
-        :type iteration: int
-        :param labels: label (text) per point in the scatter (in the same order)
-        :param mode: scatter plot with 'lines'/'markers'/'lines+markers'
-        :param fill: fill area under the curve
-        :param comment: comment underneath the title
+        :param str title: Title (AKA metric)
+        :param str series: Series (AKA variant)
+        :param np.ndarray scatter: The scatter data: a list of (x, y, z) triplets (or a numpy array)
+            or a list of series [[(x1,y1,z1)...]]
+        :param int iteration: Iteration number
+        :param list(str) labels: label (text) per point in the scatter (in the same order)
+        :param str mode: scatter plot with 'lines'/'markers'/'lines+markers'
+        :param bool fill: fill area under the curve
+        :param str comment: comment underneath the title
         """
         # check if multiple series
         multi_series = (
@@ -429,7 +243,13 @@ class Logger(object):
         """
         Report a heat-map matrix

-        :param title: Title (AKA metric)
-        :type title: str
-        :param series: Series (AKA variant)
-        :type series: str
-        :param matrix: A heat-map matrix (example: confusion matrix)
-        :type matrix: ndarray
-        :param iteration: Iteration number
-        :type iteration: int
-        :param xlabels: optional label per column of the matrix
-        :param ylabels: optional label per row of the matrix
-        :param comment: comment underneath the title
+        :param str title: Title (AKA metric)
+        :param str series: Series (AKA variant)
+        :param np.ndarray matrix: A heat-map matrix (example: confusion matrix)
+        :param int iteration: Iteration number
+        :param list(str) xlabels: optional label per column of the matrix
+        :param list(str) ylabels: optional label per row of the matrix
+        :param str comment: comment underneath the title
         """

         if not isinstance(matrix, np.ndarray):
@@ -463,16 +273,12 @@
         Same as report_confusion_matrix
         Report a heat-map matrix

-        :param title: Title (AKA metric)
-        :type title: str
-        :param series: Series (AKA variant)
-        :type series: str
-        :param matrix: A heat-map matrix (example: confusion matrix)
-        :type matrix: ndarray
-        :param iteration: Iteration number
-        :type iteration: int
-        :param xlabels: optional label per column of the matrix
-        :param ylabels: optional label per row of the matrix
+        :param str title: Title (AKA metric)
+        :param str series: Series (AKA variant)
+        :param np.ndarray matrix: A heat-map matrix (example: confusion matrix)
+        :param int iteration: Iteration number
+        :param list(str) xlabels: optional label per column of the matrix
+        :param list(str) ylabels: optional label per row of the matrix
         """
         return self.report_confusion_matrix(title, series, matrix, iteration, xlabels=xlabels, ylabels=ylabels)
@@ -481,21 +287,17 @@
         """
         Report a 3d surface (same data as heat-map matrix, only presented differently)

-        :param title: Title (AKA metric)
-        :type title: str
-        :param series: Series (AKA variant)
-        :type series: str
-        :param matrix: A heat-map matrix (example: confusion matrix)
-        :type matrix: ndarray
-        :param iteration: Iteration number
-        :type iteration: int
-        :param xlabels: optional label per column of the matrix
-        :param ylabels: optional label per row of the matrix
-        :param xtitle: optional x-axis title
-        :param ytitle: optional y-axis title
-        :param ztitle: optional z-axis title
-        :param camera: X,Y,Z camera position. def: (1,1,1)
-        :param comment: comment underneath the title
+        :param str title: Title (AKA metric)
+        :param str series: Series (AKA variant)
+        :param np.ndarray matrix: A heat-map matrix (example: confusion matrix)
+        :param int iteration: Iteration number
+        :param list(str) xlabels: optional label per column of the matrix
+        :param list(str) ylabels: optional label per row of the matrix
+        :param str xtitle: optional x-axis title
+        :param str ytitle: optional y-axis title
+        :param str ztitle: optional z-axis title
+        :param list(float) camera: X,Y,Z camera position. default: (1,1,1)
+        :param str comment: comment underneath the title
         """

         if not isinstance(matrix, np.ndarray):
@@ -518,56 +320,24 @@
             comment=comment,
         )
-    @_safe_names
-    def report_image(self, title, series, src, iteration):
-        """
-        Report an image, and register the 'src' as url content.
-
-        :param title: Title (AKA metric)
-        :type title: str
-        :param series: Series (AKA variant)
-        :type series: str
-        :param src: Image source URI. This URI will be used by the webapp and workers when trying to obtain the image \
-            for presentation of processing. Currently only http(s), file and s3 schemes are supported.
-        :type src: str
-        :param iteration: Iteration number
-        :type iteration: int
-        """
-
-        # if task was not started, we have to start it
-        self._start_task_if_needed()
-
-        self._task.reporter.report_image(
-            title=title,
-            series=series,
-            src=src,
-            iter=iteration,
-        )
-
-    @_safe_names
-    def report_image_and_upload(self, title, series, iteration, path=None, matrix=None, max_image_history=None,
-                                delete_after_upload=False):
+    def report_image(self, title, series, iteration, local_path=None, matrix=None, max_image_history=None,
+                     delete_after_upload=False):
         """
         Report an image and upload its contents.

         Image is uploaded to a preconfigured bucket (see setup_upload()) with a key (filename)
         describing the task ID, title, series and iteration.

-        :param title: Title (AKA metric)
-        :type title: str
-        :param series: Series (AKA variant)
-        :type series: str
-        :param iteration: Iteration number
-        :type iteration: int
-        :param path: A path to an image file. Required unless matrix is provided.
-        :type path: str
-        :param matrix: A 3D numpy.ndarray object containing image data (RGB). Required unless filename is provided.
-        :type matrix: str
-        :param max_image_history: maximum number of image to store per metric/variant combination \
-        use negative value for unlimited. default is set in global configuration (default=5)
-        :type max_image_history: int
-        :param delete_after_upload: if True, one the file was uploaded the local copy will be deleted
-        :type delete_after_upload: boolean
+        :param str title: Title (AKA metric)
+        :param str series: Series (AKA variant)
+        :param int iteration: Iteration number
+        :param str local_path: A path to an image file. Required unless matrix is provided.
+        :param np.ndarray matrix: A 3D numpy.ndarray object containing image data (RGB).
+            Required unless local_path is provided.
+        :param int max_image_history: maximum number of images to store per metric/variant combination;
+            use a negative value for unlimited. default is set in global configuration (default=5)
+        :param bool delete_after_upload: if True, once the file has been uploaded the local copy will be deleted
         """

         # if task was not started, we have to start it
@@ -584,7 +354,7 @@ class Logger(object):
         self._task.reporter.report_image_and_upload(
             title=title,
             series=series,
-            path=path,
+            path=local_path,
             matrix=matrix,
             iter=iteration,
             upload_uri=upload_uri,
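The renamed report_image accepts either an in-memory array or a file on disk (paths here are illustrative; the logger is obtained as in the earlier examples):

```python
import numpy as np

# from an in-memory RGB array
logger.report_image('debug', 'from matrix', iteration=1,
                    matrix=np.zeros((64, 64, 3), dtype=np.uint8))

# from a local file, deleting the local copy once it has been uploaded
logger.report_image('debug', 'from file', iteration=1,
                    local_path='/tmp/sample.jpg', delete_after_upload=True)
```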
@@ -592,7 +362,138 @@ class Logger(object):
             delete_after_upload=delete_after_upload,
         )

-    def report_image_plot_and_upload(self, title, series, iteration, path=None, matrix=None, max_image_history=None,
+    def set_default_upload_destination(self, uri):
+        """
+        Set the uri to upload all the debug images to.
+
+        Images are uploaded separately to the destination storage (e.g. s3,gc,file) and then
+        a link to the uploaded image is sent in the report.
+        Notice: credentials for the upload destination will be pulled from the
+        global configuration file (i.e. ~/trains.conf)
+
+        :param str uri: example: 's3://bucket/directory/' or 'file:///tmp/debug/'
+        :return: True if destination scheme is supported (i.e. s3:// file:// gc:// etc...)
+        """
+
+        # Create the storage helper
+        storage = StorageHelper.get(uri)
+
+        # Verify that we can upload to this destination
+        uri = storage.verify_upload(folder_uri=uri)
+
+        self._default_upload_destination = uri
+
+    def get_default_upload_destination(self):
+        """
+        Get the uri to upload all the debug images to.
+
+        Images are uploaded separately to the destination storage (e.g. s3,gc,file) and then
+        a link to the uploaded image is sent in the report.
+        Notice: credentials for the upload destination will be pulled from the
+        global configuration file (i.e. ~/trains.conf)
+
+        :return: Uri (str) example: 's3://bucket/directory/' or 'file:///tmp/debug/' etc...
+        """
+        return self._default_upload_destination or self._task._get_default_report_storage_uri()
+
+    def flush(self):
+        """
+        Flush cached reports and console outputs to backend.
+
+        :return: True if successful
+        """
+        self._flush_stdout_handler()
+        if self._task:
+            return self._task.flush()
+        return False
+
+    def get_flush_period(self):
+        """
+        :return: logger flush period in seconds
+        """
+        if self._flusher:
+            return self._flusher.period
+        return None
+
+    def set_flush_period(self, period):
+        """
+        Set the period of the logger flush.
+
+        :param float period: The period to flush the logger in seconds. If None or 0,
+            there will be no periodic flush.
+        """
+        if self._task.is_main_task() and DevWorker.report_stdout and DevWorker.report_period and \
+                not running_remotely() and period is not None:
+            period = min(period or DevWorker.report_period, DevWorker.report_period)
+
+        if not period:
+            if self._flusher:
+                self._flusher.exit()
+                self._flusher = None
+        elif self._flusher:
+            self._flusher.set_period(period)
+        else:
+            self._flusher = LogFlusher(self, period)
+            self._flusher.start()
+
+    def report_image_and_upload(self, title, series, iteration, path=None, matrix=None, max_image_history=None,
+                                delete_after_upload=False):
+        """
+        Backwards compatibility, please use report_image instead
+        """
+        self.report_image(title=title, series=series, iteration=iteration, local_path=path, matrix=matrix,
+                          max_image_history=max_image_history, delete_after_upload=delete_after_upload)
+
+    @classmethod
+    def _remove_std_logger(cls):
+        StdStreamPatch.remove_std_logger()
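Typical wiring for the upload destination and flush controls (bucket URI illustrative; credentials are read from ~/trains.conf):

```python
logger.set_default_upload_destination('s3://my-bucket/debug-images/')
logger.set_flush_period(10.0)  # capped by DevWorker.report_period in dev mode
logger.flush()                 # push anything still buffered right away
```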
+    def _console(self, msg, level=logging.INFO, omit_console=False, *args, **kwargs):
+        """
+        Print text to log (by default also prints to the console)
+
+        :param msg: text to print to the console (always send to the backend and displayed in console)
+        :param level: logging level, default: logging.INFO
+        :param omit_console: If True we only send 'msg' to log (no console print)
+        """
+        try:
+            level = int(level)
+        except (TypeError, ValueError):
+            self._task.log.log(level=logging.ERROR,
+                               msg='Logger failed casting log level "%s" to integer' % str(level))
+            level = logging.INFO
+
+        if not running_remotely():
+            # noinspection PyBroadException
+            try:
+                record = self._task.log.makeRecord(
+                    "console", level=level, fn='', lno=0, func='', msg=msg, args=args, exc_info=None
+                )
+                # find the task handler that matches our task
+                if not self._task_handler:
+                    self._task_handler = [h for h in LoggerRoot.get_base_logger().handlers
+                                          if isinstance(h, TaskHandler) and h.task_id == self._task.id][0]
+                self._task_handler.emit(record)
+            except Exception:
+                LoggerRoot.get_base_logger().warning(msg='Logger failed sending log: [level %s]: "%s"'
+                                                     % (str(level), str(msg)))
+
+        if not omit_console:
+            # if we are here and we grabbed the stdout, we need to print the real thing
+            if DevWorker.report_stdout and not running_remotely():
+                # noinspection PyBroadException
+                try:
+                    # make sure we are writing to the original stdout
+                    StdStreamPatch.stdout_original_write(str(msg)+'\n')
+                except Exception:
+                    pass
+            else:
+                print(str(msg))
+
+        # if task was not started, we have to start it
+        self._start_task_if_needed()
+
+    def _report_image_plot_and_upload(self, title, series, iteration, path=None, matrix=None, max_image_history=None,
                                       delete_after_upload=False):
         """
         Report an image, upload its contents, and present in plots section using plotly
@@ -639,7 +540,7 @@
             delete_after_upload=delete_after_upload,
         )

-    def report_file_and_upload(self, title, series, iteration, path=None, max_file_history=None,
+    def _report_file_and_upload(self, title, series, iteration, path=None, max_file_history=None,
                                delete_after_upload=False):
         """
         Upload a file and report it as link in the debug images section.
@@ -684,92 +585,6 @@
             delete_after_upload=delete_after_upload,
         )

-    def set_default_upload_destination(self, uri):
-        """
-        Set the uri to upload all the debug images to.
-
-        Images are uploaded separately to the destination storage (e.g. s3,gc,file) and then
-        a link to the uploaded image is sent in the report
-        Notice: credentials for the upload destination will be pooled from the
-        global configuration file (i.e. ~/trains.conf)
-
-        :param uri: example: 's3://bucket/directory/' or 'file:///tmp/debug/'
-        :return: True if destination scheme is supported (i.e. s3:// file:// gc:// etc...)
-        """
-
-        # Create the storage helper
-        storage = StorageHelper.get(uri)
-
-        # Verify that we can upload to this destination
-        uri = storage.verify_upload(folder_uri=uri)
-
-        self._default_upload_destination = uri
-
-    def get_default_upload_destination(self):
-        """
-        Get the uri to upload all the debug images to.
-
-        Images are uploaded separately to the destination storage (e.g. s3,gc,file) and then
-        a link to the uploaded image is sent in the report
-        Notice: credentials for the upload destination will be pooled from the
-        global configuration file (i.e. ~/trains.conf)
-
-        :return: Uri (str) example: 's3://bucket/directory/' or 'file:///tmp/debug/' etc...
-        """
-        return self._default_upload_destination or self._task._get_default_report_storage_uri()
-
-    def flush(self):
-        """
-        Flush cached reports and console outputs to backend.
-
-        :return: True if successful
-        """
-        self._flush_stdout_handler()
-        if self._task:
-            return self._task.flush()
-        return False
-
-    def get_flush_period(self):
-        if self._flusher:
-            return self._flusher.period
-        return None
-
-    def set_flush_period(self, period):
-        """
-        Set the period of the logger flush.
-
-        :param period: The period to flush the logger in seconds. If None or 0,
-        There will be no periodic flush.
- """ - if self._task.is_main_task() and DevWorker.report_stdout and DevWorker.report_period and \ - not running_remotely() and period is not None: - period = min(period or DevWorker.report_period, DevWorker.report_period) - - if not period: - if self._flusher: - self._flusher.exit() - self._flusher = None - elif self._flusher: - self._flusher.set_period(period) - else: - self._flusher = _Flusher(self, period) - self._flusher.start() - - @classmethod - def _remove_std_logger(self): - if isinstance(sys.stdout, PrintPatchLogger): - # noinspection PyBroadException - try: - sys.stdout.connect(None) - except Exception: - pass - if isinstance(sys.stderr, PrintPatchLogger): - # noinspection PyBroadException - try: - sys.stderr.connect(None) - except Exception: - pass - def _start_task_if_needed(self): # do not refresh the task status read from cached variable _status if str(self._task._status) == str(tasks.TaskStatusEnum.created): @@ -780,121 +595,3 @@ class Logger(object): def _flush_stdout_handler(self): if self._task_handler and DevWorker.report_stdout: self._task_handler.flush() - - -def stdout__patched__write__(*args, **kwargs): - if Logger._stdout_proxy: - return Logger._stdout_proxy.write(*args, **kwargs) - return sys.stdout._original_write(*args, **kwargs) - - -def stderr__patched__write__(*args, **kwargs): - if Logger._stderr_proxy: - return Logger._stderr_proxy.write(*args, **kwargs) - return sys.stderr._original_write(*args, **kwargs) - - -class PrintPatchLogger(object): - """ - Allowed patching a stream into the logger. - Used for capturing and logging stdin and stderr when running in development mode pseudo worker. - """ - patched = False - lock = threading.Lock() - recursion_protect_lock = threading.RLock() - - def __init__(self, stream, logger=None, level=logging.INFO): - PrintPatchLogger.patched = True - self._terminal = stream - self._log = logger - self._log_level = level - self._cur_line = '' - - def write(self, message): - # make sure that we do not end up in infinite loop (i.e. 
-class PrintPatchLogger(object):
- """
- Allowed patching a stream into the logger.
- Used for capturing and logging stdin and stderr when running in development mode pseudo worker.
- """
- patched = False
- lock = threading.Lock()
- recursion_protect_lock = threading.RLock()
-
- def __init__(self, stream, logger=None, level=logging.INFO):
- PrintPatchLogger.patched = True
- self._terminal = stream
- self._log = logger
- self._log_level = level
- self._cur_line = ''
-
- def write(self, message):
- # make sure that we do not end up in infinite loop (i.e. log.console ends up calling us)
- if self._log and not PrintPatchLogger.recursion_protect_lock._is_owned():
- try:
- self.lock.acquire()
- with PrintPatchLogger.recursion_protect_lock:
- if hasattr(self._terminal, '_original_write'):
- self._terminal._original_write(message)
- else:
- self._terminal.write(message)
-
- do_flush = '\n' in message
- do_cr = '\r' in message
- self._cur_line += message
- if (not do_flush and not do_cr) or not message:
- return
- last_lf = self._cur_line.rindex('\n' if do_flush else '\r')
- next_line = self._cur_line[last_lf + 1:]
- cur_line = self._cur_line[:last_lf + 1].rstrip()
- self._cur_line = next_line
- finally:
- self.lock.release()
-
- if cur_line:
- with PrintPatchLogger.recursion_protect_lock:
- # noinspection PyBroadException
- try:
- if self._log:
- self._log.console(cur_line, level=self._log_level, omit_console=True)
- except Exception:
- # what can we do, nothing
- pass
- else:
- if hasattr(self._terminal, '_original_write'):
- self._terminal._original_write(message)
- else:
- self._terminal.write(message)
-
- def connect(self, logger):
- self._cur_line = ''
- self._log = logger
-
- def __getattr__(self, attr):
- if attr in ['_log', '_terminal', '_log_level', '_cur_line']:
- return self.__dict__.get(attr)
- return getattr(self._terminal, attr)
-
- def __setattr__(self, key, value):
- if key in ['_log', '_terminal', '_log_level', '_cur_line']:
- self.__dict__[key] = value
- else:
- return setattr(self._terminal, key, value)
-
-
-class _Flusher(threading.Thread):
- def __init__(self, logger, period, **kwargs):
- super(_Flusher, self).__init__(**kwargs)
- self.daemon = True
-
- self._period = period
- self._logger = logger
- self._exit_event = threading.Event()
-
- @property
- def period(self):
- return self._period
-
- def run(self):
- self._logger.flush()
- # store original wait period
- while True:
- period = self._period
- while not self._exit_event.wait(period or 1.0):
- self._logger.flush()
- # check if period is negative or None we should exit
- if self._period is None or self._period < 0:
- break
- # check if period was changed, we should restart
- self._exit_event.clear()
-
- def exit(self):
- self._period = None
- self._exit_event.set()
-
- def set_period(self, period):
- self._period = period
- # make sure we exit the previous wait
- self._exit_event.set()
diff --git a/trains/model.py b/trains/model.py
index 15cdb212..0c54f708 100644
--- a/trains/model.py
+++ b/trains/model.py
@@ -340,7 +340,6 @@ class InputModel(BaseModel):
 name=None,
 tags=None,
 comment=None,
- logger=None,
 is_package=False,
 create_as_published=False,
 framework=None,
@@ -367,7 +366,6 @@
 :param name: optional, name for the newly imported model
 :param tags: optional, list of strings as tags
 :param comment: optional, string description for the model
- :param logger: The logger to use. If None, use the default logger
 :param is_package: Boolean. Indicates that the imported weights file is a package.
 If True, and a new model was created, a package tag will be added.
 :param create_as_published: Boolean. If True, and a new model is created, it will be published.
@@ -386,8 +384,7 @@
 ))
 if result.response.models:
- if not logger:
- logger = get_logger()
+ logger = get_logger()
 logger.debug('A model with uri "{}" already exists. Selecting it'.format(weights_url))
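A migration sketch for the dropped argument; the call site and weights url are hypothetical:

    from trains import InputModel

    # before this patch:
    #   model = InputModel.import_model(weights_url='s3://bucket/model.pb', logger=my_logger)
    # after it, the logger argument is gone and the default logger is used internally:
    model = InputModel.import_model(weights_url='s3://bucket/model.pb', name='imported model')
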
diff --git a/trains/task.py b/trains/task.py
index 2f183db0..fe27529f 100644
--- a/trains/task.py
+++ b/trains/task.py
@@ -1,16 +1,15 @@
 import atexit
 import os
-import re
 import signal
 import sys
 import threading
 import time
 from argparse import ArgumentParser
 from collections import OrderedDict, Callable
+from typing import Optional
 import psutil
 import six
-from pathlib2 import Path
 from .binding.joblib_bind import PatchedJoblib
 from .backend_api.services import tasks, projects
@@ -29,7 +28,7 @@
 from .errors import UsageError
 from .logger import Logger
 from .model import InputModel, OutputModel, ARCHIVED_TAG
 from .task_parameters import TaskParameters
-from .binding.artifacts import Artifacts
+from .binding.artifacts import Artifacts, Artifact
 from .binding.environ_bind import EnvironmentBind, PatchOsFork
 from .binding.absl_bind import PatchAbsl
 from .utilities.args import argparser_parseargs_called, get_argparser_last_args, \
@@ -41,6 +40,7 @@
 from .binding.frameworks.xgboost_bind import PatchXGBoostModelIO
 from .binding.matplotlib_bind import PatchedMatplotlib
 from .utilities.resource_monitor import ResourceMonitor
 from .utilities.seed import make_deterministic
+from .utilities.dicts import ReadOnlyDict
 NotSet = object()
@@ -113,6 +113,7 @@ class Task(_Task):
 @classmethod
 def current_task(cls):
+ # type: () -> Task
 """
 Return the Current Task object for the main execution task (task context).
 :return: Task() object or None
@@ -279,7 +280,7 @@
 # The logger will automatically take care of all patching (we just need to make sure to initialize it)
 logger = task.get_logger()
 # show the debug metrics page in the log, it is very convenient
- logger.console(
+ logger.report_text(
 'TRAINS results page: {}/projects/{}/experiments/{}/output/log'.format(
 task._get_app_server(),
 task.project if task.project is not None else '*',
@@ -439,12 +440,12 @@
 task._setup_log(replace_existing=True)
 logger = task.get_logger()
 if closed_old_task:
- logger.console('TRAINS Task: Closing old development task id={}'.format(default_task.get('id')))
+ logger.report_text('TRAINS Task: Closing old development task id={}'.format(default_task.get('id')))
 # print warning, reusing/creating a task
 if default_task_id:
- logger.console('TRAINS Task: overwriting (reusing) task id=%s' % task.id)
+ logger.report_text('TRAINS Task: overwriting (reusing) task id=%s' % task.id)
 else:
- logger.console('TRAINS Task: created new task id=%s' % task.id)
+ logger.report_text('TRAINS Task: created new task id=%s' % task.id)
 # update current repository and put warning into logs if in_dev_mode
 if in_dev_mode and cls.__detect_repo_async:
@@ -462,8 +463,8 @@
 thread.start()
 return task
- @staticmethod
- def get_task(task_id=None, project_name=None, task_name=None):
+ @classmethod
+ def get_task(cls, task_id=None, project_name=None, task_name=None):
 """
 Returns Task object based on either task_id (system uuid) or task name
 :param task_id: unique task id string (if exists other parameters are ignored)
 :param project_name: project name (str) the task belongs to
 :param task_name: task name (str) within the selected project
 :return: Task() object
 """
- return Task.__get_task(task_id=task_id, project_name=project_name, task_name=task_name)
+ return cls.__get_task(task_id=task_id, project_name=project_name, task_name=task_name)
 @property
 def output_uri(self):
@@ -490,10 +491,14 @@
 @property
 def artifacts(self):
 """
- dictionary of Task artifacts (name, artifact)
+ read-only dictionary of Task artifacts (name, artifact)
 :return: dict
 """
- return self._artifacts_manager.artifacts
+ if not Session.check_min_api_version('2.3'):
+ return ReadOnlyDict()
+ if not self.data.execution or not self.data.execution.artifacts:
+ return ReadOnlyDict()
+ return ReadOnlyDict([(a.key, Artifact(a)) for a in self.data.execution.artifacts])
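A quick sketch of the new read-only behavior, assuming the task already holds an artifact named 'data':

    from trains import Task

    task = Task.current_task()
    artifact = task.artifacts.get('data')  # lookups work exactly as before
    task.artifacts['data'] = artifact      # now raises ValueError ("This is a read only dictionary")
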
 def set_comment(self, comment):
 """
@@ -553,6 +558,7 @@
 raise Exception('Unsupported mutable type %s: no connect function found' % type(mutable).__name__)
 def get_logger(self, flush_period=NotSet):
+ # type: (Optional[float]) -> Logger
 """
 get a logger object for reporting based on the task
@@ -663,6 +669,15 @@
 """
 self._artifacts_manager.unregister_artifact(name=name)
+ def get_registered_artifacts(self):
+ """
+ dictionary of Task registered artifacts (name, artifact object)
+ Notice: these objects can be modified; changes will be uploaded automatically
+
+ :return: dict
+ """
+ return self._artifacts_manager.registered_artifacts
+
 def upload_artifact(self, name, artifact_object, metadata=None, delete_after_upload=False):
 """
 Add static artifact to Task. Artifact file/object will be uploaded in the background
 :param str name: Artifact name. Notice! it will override previous artifact if name already exists
 :param object artifact_object: Artifact object to upload. Currently supports:
 - string / pathlib2.Path are treated as path to artifact file to upload
+ If a wildcard or a folder is passed, a zip file containing the local files will be created and uploaded.
 - dict will be stored as .json,
 - pandas.DataFrame will be stored as .csv.gz (compressed CSV file),
 - numpy.ndarray will be stored as .npz,
@@ -937,7 +953,7 @@
 if self._at_exit_called:
 return
- self.get_logger().warn(
+ self.log.warning(
 "### TASK STOPPED - USER ABORTED - {} ###".format(
 stop_reason.upper().replace('_', ' ')
 )
@@ -1009,7 +1025,7 @@
 # signal artifacts upload, and stop daemon
 self._artifacts_manager.stop(wait=True)
 # print artifacts summary
- self.get_logger().console(self._artifacts_manager.summary)
+ self.get_logger().report_text(self._artifacts_manager.summary)
 def _at_exit(self):
 """
diff --git a/trains/utilities/check_updates.py b/trains/utilities/check_updates.py
index 830b5d88..523a2856 100644
--- a/trains/utilities/check_updates.py
+++ b/trains/utilities/check_updates.py
@@ -1,6 +1,7 @@
 from __future__ import absolute_import, division, print_function
 import collections
+import json
 import re
 import threading
@@ -313,9 +314,11 @@
 # noinspection PyBroadException
 try:
+ from ..version import __version__
 cls._package_version_checked = True
 # Sending the request only for statistics
- update_statistics = threading.Thread(target=CheckPackageUpdates.get_version_from_updates_server)
+ update_statistics = threading.Thread(target=CheckPackageUpdates.get_version_from_updates_server,
+ args=(__version__,))
 update_statistics.daemon = True
 update_statistics.start()
@@ -323,7 +326,6 @@
 releases = [Version(r) for r in releases]
 latest_version = sorted(releases)
- from ..version import __version__
 cur_version = Version(__version__)
 if not cur_version.is_devrelease and not cur_version.is_prerelease:
 latest_version = [r for r in latest_version if not r.is_devrelease and not r.is_prerelease]
@@ -336,8 +338,9 @@
 return None
 @staticmethod
- def get_version_from_updates_server():
+ def get_version_from_updates_server(cur_version):
 try:
- _ = requests.get('https://updates.trainsai.io/updates', timeout=1.0)
+ _ = requests.get('https://updates.trainsai.io/updates',
+ params=json.dumps({'versions': {'trains': str(cur_version)}}), timeout=1.0)
 except Exception:
 pass
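A usage sketch for the artifact helpers added above; the artifact name and dict payload are placeholders:

    from trains import Task

    task = Task.current_task()

    # static artifact: a dict is serialized to .json and uploaded in the background
    task.upload_artifact('data', artifact_object={'split': 'train', 'size': 50000})

    # artifacts registered via register_artifact() stay tracked and re-upload on change
    for name, artifact in task.get_registered_artifacts().items():
        print(name, artifact)
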
diff --git a/trains/utilities/dicts.py b/trains/utilities/dicts.py
index 5c5649fd..fb0e2f15 100644
--- a/trains/utilities/dicts.py
+++ b/trains/utilities/dicts.py
@@ -3,6 +3,19 @@
 _epsilon = 0.00001
+class ReadOnlyDict(dict):
+ def __readonly__(self, *args, **kwargs):
+ raise ValueError("This is a read only dictionary")
+ __setitem__ = __readonly__
+ __delitem__ = __readonly__
+ pop = __readonly__
+ popitem = __readonly__
+ clear = __readonly__
+ update = __readonly__
+ setdefault = __readonly__
+ del __readonly__
+
+
 class Logs:
 _logs_instances = []
diff --git a/trains/utilities/resource_monitor.py b/trains/utilities/resource_monitor.py
index bda33b9d..a40a790f 100644
--- a/trains/utilities/resource_monitor.py
+++ b/trains/utilities/resource_monitor.py
@@ -1,3 +1,5 @@
+import logging
+import warnings
 from time import time
 from threading import Thread, Event
@@ -32,8 +34,8 @@ class ResourceMonitor(object):
 self._gpustat_fail = 0
 self._gpustat = gpustat
 if not self._gpustat:
- self._task.get_logger().console('TRAINS Monitor: GPU monitoring is not available, '
- 'run \"pip install gpustat\"')
+ self._task.get_logger().report_text('TRAINS Monitor: GPU monitoring is not available, '
+ 'run \"pip install gpustat\"')
 def start(self):
 self._exit_event.clear()
@@ -73,8 +75,8 @@
 if IsTensorboardInit.tensorboard_used():
 fallback_to_sec_as_iterations = False
 elif seconds_since_started >= self._wait_for_first_iteration:
- self._task.get_logger().console('TRAINS Monitor: Could not detect iteration reporting, '
- 'falling back to iterations as seconds-from-start')
+ self._task.get_logger().report_text('TRAINS Monitor: Could not detect iteration reporting, '
+ 'falling back to iterations as seconds-from-start')
 fallback_to_sec_as_iterations = True
 clear_readouts = True
@@ -168,9 +170,11 @@
 stats["memory_free_gb"] = bytes_to_megabytes(virtual_memory.available) / 1024
 disk_use_percentage = psutil.disk_usage(Text(Path.home())).percent
 stats["disk_free_percent"] = 100.0-disk_use_percentage
- sensor_stat = (
- psutil.sensors_temperatures() if hasattr(psutil, "sensors_temperatures") else {}
- )
+ with warnings.catch_warnings():
+ if logging.root.level > logging.DEBUG: # if the logging level is higher than DEBUG, ignore
+ # psutil.sensors_temperatures warnings
+ warnings.simplefilter("ignore", category=RuntimeWarning)
+ sensor_stat = (psutil.sensors_temperatures() if hasattr(psutil, "sensors_temperatures") else {})
 if "coretemp" in sensor_stat and len(sensor_stat["coretemp"]):
 stats["cpu_temperature"] = max([float(t.current) for t in sensor_stat["coretemp"]])
@@ -197,8 +201,8 @@
 # something happened and we can't use gpu stats,
 self._gpustat_fail += 1
 if self._gpustat_fail >= 3:
- self._task.get_logger().console('TRAINS Monitor: GPU monitoring failed getting GPU reading, '
- 'switching off GPU monitoring')
+ self._task.get_logger().report_text('TRAINS Monitor: GPU monitoring failed getting GPU reading, '
+ 'switching off GPU monitoring')
 self._gpustat = None
 return stats
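The sensors_temperatures change above generalizes to a small reusable pattern. A hedged sketch follows; the helper name quiet_call is ours, not part of the patch:

    import logging
    import warnings

    def quiet_call(fn, *args, **kwargs):
        # suppress RuntimeWarning noise unless the root logger is at DEBUG or below
        with warnings.catch_warnings():
            if logging.root.level > logging.DEBUG:
                warnings.simplefilter('ignore', category=RuntimeWarning)
            return fn(*args, **kwargs)

    # e.g.: temperatures = quiet_call(psutil.sensors_temperatures)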