Restructured Logger with a nice, clean interface.

Breaking changes: Logger no longer supports info/error/warning; console() has been replaced with report_text().
This commit is contained in:
allegroai 2019-09-23 18:40:13 +03:00
parent 1dbe962879
commit 0b4f00af4d
20 changed files with 602 additions and 647 deletions
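For reference, a minimal migration sketch for the renamed calls (the Task.init arguments and the reported values are illustrative, not part of this commit):

```python
from trains import Task

task = Task.init(project_name='examples', task_name='logger migration demo')
logger = task.get_logger()

# before this commit:
#   logger.console('hello')
#   logger.report_image_and_upload('test case', 'image', iteration=1, matrix=m)
# after this commit:
logger.report_text('hello')                                    # replaces logger.console()
logger.report_scalar('example_scalar', 'series A', iteration=0, value=100)
logger.flush()                                                  # optional explicit flush
```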

View File

@ -21,7 +21,7 @@ except ImportError:
logger = Task.current_task().get_logger()
# log text
logger.console("hello")
logger.report_text("hello")
# report scalar values
logger.report_scalar("example_scalar", "series A", iteration=0, value=100)
@ -49,11 +49,11 @@ logger.report_scatter3d("example_scatter_3d", "series_xyz", iteration=1, scatter
# reporting images
m = np.eye(256, 256, dtype=np.float)
logger.report_image_and_upload("test case", "image float", iteration=1, matrix=m)
logger.report_image("test case", "image float", iteration=1, matrix=m)
m = np.eye(256, 256, dtype=np.uint8)*255
logger.report_image_and_upload("test case", "image uint8", iteration=1, matrix=m)
logger.report_image("test case", "image uint8", iteration=1, matrix=m)
m = np.concatenate((np.atleast_3d(m), np.zeros((256, 256, 2), dtype=np.uint8)), axis=2)
logger.report_image_and_upload("test case", "image color red", iteration=1, matrix=m)
logger.report_image("test case", "image color red", iteration=1, matrix=m)
# flush reports (otherwise it will be flushed in the background, every couple of seconds)
logger.flush()

View File

@ -81,7 +81,7 @@ class CallResult(object):
# response.validate()
except Exception as e:
if logger:
logger.warn('Failed parsing response: %s' % str(e))
logger.warning('Failed parsing response: %s' % str(e))
return cls(meta=meta, response=response, response_data=response_data, request_cls=request_cls, session=session)
def ok(self):

View File

@ -215,7 +215,7 @@ class Session(TokenManager):
res.status_code == requests.codes.service_unavailable
and self.config.get("api.http.wait_on_maintenance_forever", True)
):
self._logger.warn(
self._logger.warning(
"Service unavailable: {} is undergoing maintenance, retrying...".format(
host
)

View File

@ -50,7 +50,7 @@ class S3BucketConfig(object):
configs = [cls(**entry) for entry in dict_list]
valid_configs = [conf for conf in configs if conf.is_valid()]
if log and len(valid_configs) < len(configs):
log.warn(
log.warning(
"Invalid bucket configurations detected for {}".format(
", ".join(
"/".join((config.host, config.bucket))

View File

@ -0,0 +1,212 @@
import logging
import sys
import threading
from ..backend_interface.task.development.worker import DevWorker
from ..backend_interface.task.log import TaskHandler
from ..config import running_remotely
class StdStreamPatch(object):
_stdout_proxy = None
_stderr_proxy = None
_stdout_original_write = None
@staticmethod
def patch_std_streams(logger):
if DevWorker.report_stdout and not PrintPatchLogger.patched and not running_remotely():
StdStreamPatch._stdout_proxy = PrintPatchLogger(sys.stdout, logger, level=logging.INFO)
StdStreamPatch._stderr_proxy = PrintPatchLogger(sys.stderr, logger, level=logging.ERROR)
logger._task_handler = TaskHandler(logger._task.session, logger._task.id, capacity=100)
# noinspection PyBroadException
try:
if StdStreamPatch._stdout_original_write is None:
StdStreamPatch._stdout_original_write = sys.stdout.write
# this will only work in python 3, guard it with try/catch
if not hasattr(sys.stdout, '_original_write'):
sys.stdout._original_write = sys.stdout.write
sys.stdout.write = StdStreamPatch._stdout__patched__write__
if not hasattr(sys.stderr, '_original_write'):
sys.stderr._original_write = sys.stderr.write
sys.stderr.write = StdStreamPatch._stderr__patched__write__
except Exception:
pass
sys.stdout = StdStreamPatch._stdout_proxy
sys.stderr = StdStreamPatch._stderr_proxy
# patch the base streams of sys (this way colorama will keep its ANSI colors)
# noinspection PyBroadException
try:
sys.__stderr__ = sys.stderr
except Exception:
pass
# noinspection PyBroadException
try:
sys.__stdout__ = sys.stdout
except Exception:
pass
# now check if we have loguru and make it re-register the handlers
# because it stores the stream.write function internally, which we cannot patch
# noinspection PyBroadException
try:
from loguru import logger
register_stderr = None
register_stdout = None
for k, v in logger._handlers.items():
if v._name == '<stderr>':
register_stderr = k
elif v._name == '<stdout>':
register_stdout = k
if register_stderr is not None:
logger.remove(register_stderr)
logger.add(sys.stderr)
if register_stdout is not None:
logger.remove(register_stdout)
logger.add(sys.stdout)
except Exception:
pass
elif DevWorker.report_stdout and not running_remotely():
logger._task_handler = TaskHandler(logger._task.session, logger._task.id, capacity=100)
if StdStreamPatch._stdout_proxy:
StdStreamPatch._stdout_proxy.connect(logger)
if StdStreamPatch._stderr_proxy:
StdStreamPatch._stderr_proxy.connect(logger)
@staticmethod
def remove_std_logger():
if isinstance(sys.stdout, PrintPatchLogger):
# noinspection PyBroadException
try:
sys.stdout.connect(None)
except Exception:
pass
if isinstance(sys.stderr, PrintPatchLogger):
# noinspection PyBroadException
try:
sys.stderr.connect(None)
except Exception:
pass
@staticmethod
def stdout_original_write(*args, **kwargs):
if StdStreamPatch._stdout_original_write:
StdStreamPatch._stdout_original_write(*args, **kwargs)
@staticmethod
def _stdout__patched__write__(*args, **kwargs):
if StdStreamPatch._stdout_proxy:
return StdStreamPatch._stdout_proxy.write(*args, **kwargs)
return sys.stdout._original_write(*args, **kwargs)
@staticmethod
def _stderr__patched__write__(*args, **kwargs):
if StdStreamPatch._stderr_proxy:
return StdStreamPatch._stderr_proxy.write(*args, **kwargs)
return sys.stderr._original_write(*args, **kwargs)
class PrintPatchLogger(object):
"""
Allows patching a stream into the logger.
Used for capturing and logging stdout and stderr when running in development mode as a pseudo worker.
"""
patched = False
lock = threading.Lock()
recursion_protect_lock = threading.RLock()
def __init__(self, stream, logger=None, level=logging.INFO):
PrintPatchLogger.patched = True
self._terminal = stream
self._log = logger
self._log_level = level
self._cur_line = ''
def write(self, message):
# make sure that we do not end up in an infinite loop (i.e. log._console ends up calling us)
if self._log and not PrintPatchLogger.recursion_protect_lock._is_owned():
try:
self.lock.acquire()
with PrintPatchLogger.recursion_protect_lock:
if hasattr(self._terminal, '_original_write'):
self._terminal._original_write(message)
else:
self._terminal.write(message)
do_flush = '\n' in message
do_cr = '\r' in message
self._cur_line += message
if (not do_flush and not do_cr) or not message:
return
last_lf = self._cur_line.rindex('\n' if do_flush else '\r')
next_line = self._cur_line[last_lf + 1:]
cur_line = self._cur_line[:last_lf + 1].rstrip()
self._cur_line = next_line
finally:
self.lock.release()
if cur_line:
with PrintPatchLogger.recursion_protect_lock:
# noinspection PyBroadException
try:
if self._log:
self._log._console(cur_line, level=self._log_level, omit_console=True)
except Exception:
# what can we do, nothing
pass
else:
if hasattr(self._terminal, '_original_write'):
self._terminal._original_write(message)
else:
self._terminal.write(message)
def connect(self, logger):
self._cur_line = ''
self._log = logger
def __getattr__(self, attr):
if attr in ['_log', '_terminal', '_log_level', '_cur_line']:
return self.__dict__.get(attr)
return getattr(self._terminal, attr)
def __setattr__(self, key, value):
if key in ['_log', '_terminal', '_log_level', '_cur_line']:
self.__dict__[key] = value
else:
return setattr(self._terminal, key, value)
class LogFlusher(threading.Thread):
def __init__(self, logger, period, **kwargs):
super(LogFlusher, self).__init__(**kwargs)
self.daemon = True
self._period = period
self._logger = logger
self._exit_event = threading.Event()
@property
def period(self):
return self._period
def run(self):
self._logger.flush()
# store original wait period
while True:
period = self._period
while not self._exit_event.wait(period or 1.0):
self._logger.flush()
# if the period is negative or None we should exit
if self._period is None or self._period < 0:
break
# if the period was changed, restart the wait
self._exit_event.clear()
def exit(self):
self._period = None
self._exit_event.set()
def set_period(self, period):
self._period = period
# make sure we exit the previous wait
self._exit_event.set()
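PrintPatchLogger above forwards every write to the original stream immediately but only ships complete lines (terminated by '\n' or '\r') to the task logger. A standalone sketch of that line-buffering proxy technique, using only the standard library (LineBufferProxy is a hypothetical name, not a trains class):

```python
import logging
import sys

logging.basicConfig(level=logging.INFO)


class LineBufferProxy(object):
    """Forward writes to the wrapped stream and emit only complete lines to a logger."""

    def __init__(self, stream, logger, level=logging.INFO):
        self._stream = stream
        self._logger = logger
        self._level = level
        self._buffer = ''

    def write(self, message):
        self._stream.write(message)            # keep normal console behaviour
        self._buffer += message
        if '\n' not in self._buffer:
            return                             # wait until we have a full line
        last_lf = self._buffer.rindex('\n')
        complete, self._buffer = self._buffer[:last_lf], self._buffer[last_lf + 1:]
        for line in complete.splitlines():
            if line.strip():
                self._logger.log(self._level, line)

    def __getattr__(self, attr):
        # delegate flush(), isatty(), etc. to the wrapped stream
        return getattr(self._stream, attr)


sys.stdout = LineBufferProxy(sys.__stdout__, logging.getLogger('stdout'))
print('hello from the proxied stdout')
```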

View File

@ -177,6 +177,8 @@ class UploadEvent(MetricsEventAdapter):
def __init__(self, metric, variant, image_data, local_image_path=None, iter=0, upload_uri=None,
image_file_history_size=None, delete_after_upload=False, **kwargs):
# :param override_filename: override the uploaded file name (note: the extension is appended from the local path)
# :param override_filename_ext: override the uploaded file extension
if image_data is not None and not hasattr(image_data, 'shape'):
raise ValueError('Image must have a shape attribute')
self._image_data = image_data
@ -197,8 +199,10 @@ class UploadEvent(MetricsEventAdapter):
# get upload uri upfront, either predefined image format or local file extension
# e.g.: image.png -> .png or image.raw.gz -> .raw.gz
image_format = self._format.lower() if self._image_data is not None else \
'.' + '.'.join(pathlib2.Path(self._local_image_path).parts[-1].split('.')[1:])
image_format = kwargs.pop('override_filename_ext', None)
if image_format is None:
image_format = self._format.lower() if self._image_data is not None else \
'.' + '.'.join(pathlib2.Path(self._local_image_path).parts[-1].split('.')[1:])
self._upload_filename = str(pathlib2.Path(self._filename).with_suffix(image_format))
self._override_storage_key_prefix = kwargs.pop('override_storage_key_prefix', None)
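The new override_filename_ext keyword supplies the upload suffix only when explicitly passed; otherwise the extension is still derived from the image format or from the local file name. A small sketch of that suffix handling with the standard pathlib (the function and file names are illustrative):

```python
from pathlib import Path


def upload_filename(base_name, local_path=None, image_format='.png', override_ext=None):
    # precedence: explicit override, then the configured image format,
    # then everything after the first dot of the local file name
    ext = override_ext
    if ext is None:
        if local_path is None:
            ext = image_format
        else:
            ext = '.' + '.'.join(Path(local_path).name.split('.')[1:])
    return str(Path(base_name).with_suffix(ext))


print(upload_filename('metric_series_iter0'))                            # metric_series_iter0.png
print(upload_filename('metric_series_iter0', local_path='img.raw.gz'))   # metric_series_iter0.raw.gz
print(upload_filename('metric_series_iter0', override_ext='.jpeg'))      # metric_series_iter0.jpeg
```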

View File

@ -80,6 +80,6 @@ class AccessMixin(object):
expected_num_of_classes += 1 if int(index) > 0 else 0
num_of_classes = int(max(model_labels.values()))
if num_of_classes != expected_num_of_classes:
self.log.warn('The highest label index is %d, while there are %d non-bg labels' %
(num_of_classes, expected_num_of_classes))
self.log.warning('The highest label index is %d, while there are %d non-bg labels' %
(num_of_classes, expected_num_of_classes))
return num_of_classes + 1 # +1 is meant for bg!

View File

@ -51,7 +51,7 @@ class TaskStopSignal(object):
if self._task_reset_state_counter >= self._number_of_consecutive_reset_tests:
return TaskStopReason.reset
self.task.get_logger().warning(
self.task.log.warning(
"Task {} was reset! if state is consistent we shall terminate.".format(self.task.id),
)
else:

View File

@ -160,7 +160,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
conf = get_config_for_bucket(base_url=output_dest)
if not conf:
msg = 'Failed resolving output destination (no credentials found for %s)' % output_dest
self.log.warn(msg)
self.log.warning(msg)
if raise_errors:
raise Exception(msg)
else:
@ -187,12 +187,12 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
latest_version = CheckPackageUpdates.check_new_package_available(only_once=True)
if latest_version:
if not latest_version[1]:
self.get_logger().console(
self.get_logger().report_text(
'TRAINS new package available: UPGRADE to v{} is recommended!'.format(
latest_version[0]),
)
else:
self.get_logger().console(
self.get_logger().report_text(
'TRAINS-SERVER new version available: upgrade to v{} is recommended!'.format(
latest_version[0]),
)
@ -205,7 +205,7 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
check_package_update_thread.start()
result = ScriptInfo.get(log=self.log)
for msg in result.warning_messages:
self.get_logger().console(msg)
self.get_logger().report_text(msg)
self.data.script = result.script
# Since we might run asynchronously, don't use self.data (lest someone else
@ -418,16 +418,17 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
def update_model_desc(self, new_model_desc_file=None):
""" Change the task's model_desc """
execution = self._get_task_property('execution')
p = Path(new_model_desc_file)
if not p.is_file():
raise IOError('mode_desc file %s cannot be found' % new_model_desc_file)
new_model_desc = p.read_text()
model_desc_key = list(execution.model_desc.keys())[0] if execution.model_desc else 'design'
execution.model_desc[model_desc_key] = new_model_desc
with self._edit_lock:
execution = self._get_task_property('execution')
p = Path(new_model_desc_file)
if not p.is_file():
raise IOError('model_desc file %s cannot be found' % new_model_desc_file)
new_model_desc = p.read_text()
model_desc_key = list(execution.model_desc.keys())[0] if execution.model_desc else 'design'
execution.model_desc[model_desc_key] = new_model_desc
res = self._edit(execution=execution)
return res.response
res = self._edit(execution=execution)
return res.response
def update_output_model(self, model_uri, name=None, comment=None, tags=None):
"""
@ -536,16 +537,17 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
model = None
model_id = ''
# store model id
self.data.execution.model = model_id
with self._edit_lock:
# store model id
self.data.execution.model = model_id
# Auto populate input field from model, if they are empty
if update_task_design and not self.data.execution.model_desc:
self.data.execution.model_desc = model.design if model else ''
if update_task_labels and not self.data.execution.model_labels:
self.data.execution.model_labels = model.labels if model else {}
# Auto populate input field from model, if they are empty
if update_task_design and not self.data.execution.model_desc:
self.data.execution.model_desc = model.design if model else ''
if update_task_labels and not self.data.execution.model_labels:
self.data.execution.model_labels = model.labels if model else {}
self._edit(execution=self.data.execution)
self._edit(execution=self.data.execution)
def set_parameters(self, *args, **kwargs):
"""
@ -580,12 +582,13 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
# force cast all variables to strings (so that we can later edit them in UI)
parameters = {k: str(v) if v is not None else "" for k, v in parameters.items()}
execution = self.data.execution
if execution is None:
execution = tasks.Execution(parameters=parameters)
else:
execution.parameters = parameters
self._edit(execution=execution)
with self._edit_lock:
execution = self.data.execution
if execution is None:
execution = tasks.Execution(parameters=parameters)
else:
execution.parameters = parameters
self._edit(execution=execution)
def set_parameter(self, name, value, description=None):
"""
@ -630,14 +633,15 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
:param dict enumeration: For example: {str(label): integer(id)}
"""
enumeration = enumeration or {}
execution = self.data.execution
if enumeration is None:
return
if not (isinstance(enumeration, dict)
and all(isinstance(k, six.string_types) and isinstance(v, int) for k, v in enumeration.items())):
raise ValueError('Expected label to be a dict[str => int]')
execution.model_labels = enumeration
self._edit(execution=execution)
with self._edit_lock:
execution = self.data.execution
if enumeration is None:
return
if not (isinstance(enumeration, dict)
and all(isinstance(k, six.string_types) and isinstance(v, int) for k, v in enumeration.items())):
raise ValueError('Expected label to be a dict[str => int]')
execution.model_labels = enumeration
self._edit(execution=execution)
def set_artifacts(self, artifacts_list=None):
"""
@ -650,16 +654,18 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
if not (isinstance(artifacts_list, (list, tuple))
and all(isinstance(a, tasks.Artifact) for a in artifacts_list)):
raise ValueError('Expected artifacts to [tasks.Artifacts]')
execution = self.data.execution
execution.artifacts = artifacts_list
self._edit(execution=execution)
with self._edit_lock:
execution = self.data.execution
execution.artifacts = artifacts_list
self._edit(execution=execution)
def _set_model_design(self, design=None):
execution = self.data.execution
if design is not None:
execution.model_desc = Model._wrap_design(design)
with self._edit_lock:
execution = self.data.execution
if design is not None:
execution.model_desc = Model._wrap_design(design)
self._edit(execution=execution)
self._edit(execution=execution)
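Most of the Task changes in this file wrap the read-modify-write of self.data.execution in self._edit_lock, so concurrent updates cannot interleave. A generic sketch of that pattern with a plain threading lock (TaskLike and its methods are placeholders, not the trains API):

```python
import threading


class TaskLike(object):
    """Placeholder illustrating the edit-lock pattern used above."""

    def __init__(self):
        self._edit_lock = threading.RLock()
        self.parameters = {}

    def _edit(self, parameters):
        # stand-in for the backend call that persists the change
        self.parameters = dict(parameters)

    def set_parameters(self, **kwargs):
        # read, mutate and persist under one lock so concurrent callers
        # never overwrite each other's half-applied edits
        with self._edit_lock:
            params = dict(self.parameters)
            params.update({k: str(v) for k, v in kwargs.items()})
            self._edit(parameters=params)


task = TaskLike()
task.set_parameters(lr=0.01, batch_size=32)
print(task.parameters)   # {'lr': '0.01', 'batch_size': '32'}
```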
def get_labels_enumeration(self):
"""

View File

@ -36,8 +36,8 @@ def get_single_result(entity, query, results, log=None, show_results=10, raise_o
log = get_logger()
if len(results) > 1:
log.warn('More than one {entity} found when searching for `{query}`'
' (showing first {show_results} {entity}s follow)'.format(**locals()))
log.warning('More than one {entity} found when searching for `{query}`'
' (showing first {show_results} {entity}s follow)'.format(**locals()))
if sort_by_date:
# sort results based on timestamp and return the newest one
if hasattr(results[0], 'last_update'):
@ -49,7 +49,7 @@ def get_single_result(entity, query, results, log=None, show_results=10, raise_o
for i, obj in enumerate(o if isinstance(o, dict) else o.to_dict() for o in results[:show_results]):
selected = 'Selected' if i == 0 else 'Additionally found'
log.warn('{selected} {entity} `{obj[name]}` (id={obj[id]})'.format(**locals()))
log.warning('{selected} {entity} `{obj[name]}` (id={obj[id]})'.format(**locals()))
if raise_on_error:
raise ValueError('More than one {entity} found when searching for `{query}`'.format(**locals()))

View File

@ -5,13 +5,13 @@ import threading
from collections import defaultdict
from functools import partial
from io import BytesIO
from logging import ERROR, WARNING, getLogger
from typing import Any
import numpy as np
from PIL import Image
from ..frameworks import _patched_call, WeightsFileHandler, _Empty, TrainsFrameworkAdapter
from ...debugging.log import LoggerRoot
from ..frameworks import _patched_call, WeightsFileHandler, _Empty
from ..import_bind import PostImportHookPatching
from ...config import running_remotely
from ...model import InputModel, OutputModel, Framework
@ -187,7 +187,8 @@ class EventTrainsWriter(object):
else:
val = val[:, :, [0, 1, 2]]
except Exception:
self._logger.warning('Failed decoding debug image [%d, %d, %d]' % (width, height, color_channels))
LoggerRoot.get_base_logger().warning('Failed decoding debug image [%d, %d, %d]'
% (width, height, color_channels))
val = None
return val
@ -213,7 +214,7 @@ class EventTrainsWriter(object):
tile_size = res.shape[0] * res.shape[1]
img_data_np = res.reshape(tile_size, tile_size, -1)
self._logger.report_image_and_upload(
self._logger.report_image(
title=title,
series=series,
iteration=step,
@ -419,7 +420,7 @@ class EventTrainsWriter(object):
msg_dict.pop('wallTime', None)
keys_list = [key for key in msg_dict.keys() if len(key) > 0]
keys_list = ', '.join(keys_list)
self._logger.debug('event summary not found, message type unsupported: %s' % keys_list)
LoggerRoot.get_base_logger().debug('event summary not found, message type unsupported: %s' % keys_list)
return
value_dicts = summary.get('value')
walltime = walltime or msg_dict.get('step')
@ -431,19 +432,20 @@ class EventTrainsWriter(object):
step = int(event.step)
else:
step = 0
self._logger.debug('Recieved event without step, assuming step = {}'.format(step), WARNING)
LoggerRoot.get_base_logger().debug('Received event without step, assuming step = {}'.format(step))
else:
step = int(step)
self._max_step = max(self._max_step, step)
if value_dicts is None:
self._logger.debug("Summary with arrived without 'value'", ERROR)
LoggerRoot.get_base_logger().debug("Summary arrived without 'value'")
return
for vdict in value_dicts:
tag = vdict.pop('tag', None)
if tag is None:
# we should not get here
self._logger.debug('No tag for \'value\' existing keys %s' % ', '.join(vdict.keys()))
LoggerRoot.get_base_logger().debug('No tag for \'value\' existing keys %s'
% ', '.join(vdict.keys()))
continue
metric, values = get_data(vdict, supported_metrics)
if metric == 'simpleValue':
@ -459,7 +461,8 @@ class EventTrainsWriter(object):
elif metric == 'tensor' and values.get('dtype') == 'DT_FLOAT':
self._add_plot(tag, step, values, vdict)
else:
self._logger.debug('Event unsupported. tag = %s, vdict keys [%s]' % (tag, ', '.join(vdict.keys)))
LoggerRoot.get_base_logger().debug('Event unsupported. tag = %s, vdict keys [%s]'
% (tag, ', '.join(vdict.keys())))
continue
def get_logdir(self):
@ -589,7 +592,7 @@ class PatchSummaryToEventTransformer(object):
setattr(SummaryToEventTransformer, 'trains',
property(PatchSummaryToEventTransformer.trains_object))
except Exception as ex:
getLogger(TrainsFrameworkAdapter).debug(str(ex))
LoggerRoot.get_base_logger().debug(str(ex))
if 'torch' in sys.modules:
try:
@ -603,7 +606,7 @@ class PatchSummaryToEventTransformer(object):
# this is a new version of TensorflowX
pass
except Exception as ex:
getLogger(TrainsFrameworkAdapter).debug(str(ex))
LoggerRoot.get_base_logger().debug(str(ex))
if 'tensorboardX' in sys.modules:
try:
@ -619,7 +622,7 @@ class PatchSummaryToEventTransformer(object):
# this is a new version of TensorflowX
pass
except Exception as ex:
getLogger(TrainsFrameworkAdapter).debug(str(ex))
LoggerRoot.get_base_logger().debug(str(ex))
if PatchSummaryToEventTransformer.__original_getattributeX is None:
try:
@ -633,7 +636,7 @@ class PatchSummaryToEventTransformer(object):
# this is a new version of TensorflowX
pass
except Exception as ex:
getLogger(TrainsFrameworkAdapter).debug(str(ex))
LoggerRoot.get_base_logger().debug(str(ex))
@staticmethod
def _patched_add_eventT(self, *args, **kwargs):
@ -717,7 +720,7 @@ class _ModelAdapter(object):
super(_ModelAdapter, self).__init__()
super(_ModelAdapter, self).__setattr__('_model', model)
super(_ModelAdapter, self).__setattr__('_output_model', output_model)
super(_ModelAdapter, self).__setattr__('_logger', getLogger('TrainsModelAdapter'))
super(_ModelAdapter, self).__setattr__('_logger', LoggerRoot.get_base_logger())
def __getattr__(self, attr):
return getattr(self._model, attr)
@ -800,7 +803,7 @@ class PatchModelCheckPointCallback(object):
property(PatchModelCheckPointCallback.trains_object))
except Exception as ex:
getLogger(TrainsFrameworkAdapter).warning(str(ex))
LoggerRoot.get_base_logger().warning(str(ex))
@staticmethod
def _patched_getattribute(self, attr):
@ -878,7 +881,7 @@ class PatchTensorFlowEager(object):
except ImportError:
pass
except Exception as ex:
getLogger(TrainsFrameworkAdapter).debug(str(ex))
LoggerRoot.get_base_logger().debug(str(ex))
@staticmethod
def _get_event_writer(writer):
@ -905,7 +908,7 @@ class PatchTensorFlowEager(object):
try:
event_writer._add_scalar(tag=str(tag), step=int(step.numpy()), scalar_data=value.numpy())
except Exception as ex:
getLogger(TrainsFrameworkAdapter).warning(str(ex))
LoggerRoot.get_base_logger().warning(str(ex))
return PatchTensorFlowEager.__original_fn_scalar(writer, step, tag, value, name, **kwargs)
@staticmethod
@ -915,7 +918,7 @@ class PatchTensorFlowEager(object):
try:
event_writer._add_histogram(tag=str(tag), step=int(step.numpy()), histo_data=values.numpy())
except Exception as ex:
getLogger(TrainsFrameworkAdapter).warning(str(ex))
LoggerRoot.get_base_logger().warning(str(ex))
return PatchTensorFlowEager.__original_fn_hist(writer, step, tag, values, name, **kwargs)
@staticmethod
@ -926,7 +929,7 @@ class PatchTensorFlowEager(object):
event_writer._add_image_numpy(tag=str(tag), step=int(step.numpy()), img_data_np=tensor.numpy(),
max_keep_images=max_images)
except Exception as ex:
getLogger(TrainsFrameworkAdapter).warning(str(ex))
LoggerRoot.get_base_logger().warning(str(ex))
return PatchTensorFlowEager.__original_fn_image(writer, step, tag, tensor, bad_color, max_images, name,
**kwargs)
@ -1024,7 +1027,7 @@ class PatchKerasModelIO(object):
keras_saving.save_model = _patched_call(keras_saving.save_model, PatchKerasModelIO._save_model)
keras_saving.load_model = _patched_call(keras_saving.load_model, PatchKerasModelIO._load_model)
except Exception as ex:
getLogger(TrainsFrameworkAdapter).warning(str(ex))
LoggerRoot.get_base_logger().warning(str(ex))
@staticmethod
def _updated_config(original_fn, self):
@ -1052,7 +1055,7 @@ class PatchKerasModelIO(object):
framework=Framework.keras,
)
except Exception as ex:
getLogger(TrainsFrameworkAdapter).warning(str(ex))
LoggerRoot.get_base_logger().warning(str(ex))
return config
@ -1102,7 +1105,7 @@ class PatchKerasModelIO(object):
return model
except Exception as ex:
getLogger(TrainsFrameworkAdapter).warning(str(ex))
LoggerRoot.get_base_logger().warning(str(ex))
return self
@ -1184,7 +1187,7 @@ class PatchKerasModelIO(object):
# if anyone asks, we were here
self.trains_out_model._processed = True
except Exception as ex:
getLogger(TrainsFrameworkAdapter).warning(str(ex))
LoggerRoot.get_base_logger().warning(str(ex))
@staticmethod
def _save_model(original_fn, model, filepath, *args, **kwargs):
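The TensorFlow/Keras binding changes above replace the per-object loggers (self._logger, getLogger(TrainsFrameworkAdapter)) with the shared LoggerRoot.get_base_logger(). A simplified sketch of such a process-wide base-logger accessor using the standard logging module (a stand-in, not the trains implementation):

```python
import logging


class LoggerRootSketch(object):
    _base_logger = None

    @classmethod
    def get_base_logger(cls, level=logging.INFO):
        # create the shared logger once; every caller gets the same instance
        if cls._base_logger is None:
            logger = logging.getLogger('trains')
            logger.setLevel(level)
            handler = logging.StreamHandler()
            handler.setFormatter(logging.Formatter('%(levelname)s - %(message)s'))
            logger.addHandler(handler)
            cls._base_logger = logger
        return cls._base_logger


LoggerRootSketch.get_base_logger().warning('Failed decoding debug image [256, 256, 3]')
```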

View File

@ -329,19 +329,19 @@ class PatchedMatplotlib:
PatchedMatplotlib._global_image_counter += 1
title = plot_title or 'untitled %d' % PatchedMatplotlib._global_image_counter
logger.report_image_and_upload(title=title, series='plot image', path=image,
delete_after_upload=True,
iteration=PatchedMatplotlib._global_image_counter
if plot_title else 0)
logger.report_image(title=title, series='plot image', local_path=image,
delete_after_upload=True,
iteration=PatchedMatplotlib._global_image_counter
if plot_title else 0)
else:
# send the plot as plotly with embedded image
PatchedMatplotlib._global_plot_counter += 1
title = plot_title or 'untitled %d' % PatchedMatplotlib._global_plot_counter
logger.report_image_plot_and_upload(title=title, series='plot image', path=image,
delete_after_upload=True,
iteration=PatchedMatplotlib._global_plot_counter
if plot_title else 0)
logger._report_image_plot_and_upload(title=title, series='plot image', path=image,
delete_after_upload=True,
iteration=PatchedMatplotlib._global_plot_counter
if plot_title else 0)
except Exception:
# plotly failed
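PatchedMatplotlib now reports rendered figures through the public report_image(local_path=...) call, while the plotly variant becomes the private _report_image_plot_and_upload. A hedged usage sketch of the public path (the project/task names and file path are illustrative):

```python
import matplotlib
matplotlib.use('Agg')            # render off-screen
import matplotlib.pyplot as plt
from trains import Task

task = Task.init(project_name='examples', task_name='matplotlib demo')
logger = task.get_logger()

plt.plot([0, 1, 2], [0, 1, 4])
plt.savefig('/tmp/plot.png')

# upload the rendered figure as a debug image (signature as introduced in this commit)
logger.report_image(title='training curve', series='plot image', iteration=0,
                    local_path='/tmp/plot.png', delete_after_upload=True)
```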

View File

@ -18,7 +18,7 @@ def get_cache_dir():
cache_base_dir = Path(
expandvars(
expanduser(
config.get("storage.cache.default_base_dir") or DEFAULT_CACHE_DIR
TRAINS_CACHE_DIR.get() or config.get("storage.cache.default_base_dir") or DEFAULT_CACHE_DIR
)
)
)
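The cache directory now honours the TRAINS_CACHE_DIR environment variable before falling back to the configuration file and the built-in default. A sketch of that precedence using plain os.environ (the config lookup is simplified to a dict):

```python
import os
import tempfile
from os.path import expanduser, expandvars

DEFAULT_CACHE_DIR = os.path.join(tempfile.gettempdir(), "trains_cache")
config = {"storage.cache.default_base_dir": None}   # stand-in for the trains config object


def get_cache_dir():
    # precedence: environment variable, then configuration file, then default
    base = (os.environ.get("TRAINS_CACHE_DIR")
            or config.get("storage.cache.default_base_dir")
            or DEFAULT_CACHE_DIR)
    return expandvars(expanduser(base))


print(get_cache_dir())
```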

View File

@ -7,7 +7,6 @@ from pathlib2 import Path
SESSION_CACHE_FILE = ".session.json"
DEFAULT_CACHE_DIR = str(Path(tempfile.gettempdir()) / "trains_cache")
TASK_ID_ENV_VAR = EnvEntry("TRAINS_TASK_ID", "ALG_TASK_ID")
LOG_TO_BACKEND_ENV_VAR = EnvEntry("TRAINS_LOG_TASK_TO_BACKEND", "ALG_LOG_TASK_TO_BACKEND", type=bool)
NODE_ID_ENV_VAR = EnvEntry("TRAINS_NODE_ID", "ALG_NODE_ID", type=int)
@ -16,6 +15,7 @@ LOG_STDERR_REDIRECT_LEVEL = EnvEntry("TRAINS_LOG_STDERR_REDIRECT_LEVEL", "ALG_LO
DEV_WORKER_NAME = EnvEntry("TRAINS_WORKER_NAME", "ALG_WORKER_NAME")
DEV_TASK_NO_REUSE = EnvEntry("TRAINS_TASK_NO_REUSE", "ALG_TASK_NO_REUSE", type=bool)
TASK_LOG_ENVIRONMENT = EnvEntry("TRAINS_LOG_ENVIRONMENT", "ALG_LOG_ENVIRONMENT", type=str)
TRAINS_CACHE_DIR = EnvEntry("TRAINS_CACHE_DIR", "ALG_CACHE_DIR")
LOG_LEVEL_ENV_VAR = EnvEntry("TRAINS_LOG_LEVEL", "ALG_LOG_LEVEL", converter=or_(int, str))

View File

@ -1,12 +1,9 @@
import logging
import re
import sys
import threading
from functools import wraps
import numpy as np
from pathlib2 import Path
from .backend_interface.logger import StdStreamPatch, LogFlusher
from .debugging.log import LoggerRoot
from .backend_interface.task.development.worker import DevWorker
from .backend_interface.task.log import TaskHandler
@ -17,36 +14,6 @@ from .backend_interface.task import Task as _Task
from .config import running_remotely, get_cache_dir
def _safe_names(func):
"""
Validate the form of title and series parameters.
This decorator assert that a method receives 'title' and 'series' as its
first positional arguments, and that their values have only legal characters.
'\', '/' and ':' will be replaced automatically by '_'
Whitespace chars will be replaced automatically by ' '
"""
_replacements = {
'_': re.compile(r"[/\\:]"),
' ': re.compile(r"[\s]"),
}
def _make_safe(value):
for repl, regex in _replacements.items():
value = regex.sub(repl, value)
return value
@wraps(func)
def fixed_names(self, title, series, *args, **kwargs):
title = _make_safe(title)
series = _make_safe(series)
func(self, title, series, *args, **kwargs)
return fixed_names
class Logger(object):
"""
Console log and metric statistics interface.
@ -56,9 +23,6 @@ class Logger(object):
**Usage:** :func:`Logger.current_logger` or :func:`Task.get_logger`
"""
SeriesInfo = SeriesInfo
_stdout_proxy = None
_stderr_proxy = None
_stdout_original_write = None
def __init__(self, private_task):
"""
@ -75,67 +39,11 @@ class Logger(object):
self._report_worker = None
self._task_handler = None
if DevWorker.report_stdout and not PrintPatchLogger.patched and not running_remotely():
Logger._stdout_proxy = PrintPatchLogger(sys.stdout, self, level=logging.INFO)
Logger._stderr_proxy = PrintPatchLogger(sys.stderr, self, level=logging.ERROR)
self._task_handler = TaskHandler(self._task.session, self._task.id, capacity=100)
# noinspection PyBroadException
try:
if Logger._stdout_original_write is None:
Logger._stdout_original_write = sys.stdout.write
# this will only work in python 3, guard it with try/catch
if not hasattr(sys.stdout, '_original_write'):
sys.stdout._original_write = sys.stdout.write
sys.stdout.write = stdout__patched__write__
if not hasattr(sys.stderr, '_original_write'):
sys.stderr._original_write = sys.stderr.write
sys.stderr.write = stderr__patched__write__
except Exception:
pass
sys.stdout = Logger._stdout_proxy
sys.stderr = Logger._stderr_proxy
# patch the base streams of sys (this way colorama will keep its ANSI colors)
# noinspection PyBroadException
try:
sys.__stderr__ = sys.stderr
except Exception:
pass
# noinspection PyBroadException
try:
sys.__stdout__ = sys.stdout
except Exception:
pass
# now check if we have loguru and make it re-register the handlers
# because it sores internally the stream.write function, which we cant patch
# noinspection PyBroadException
try:
from loguru import logger
register_stderr = None
register_stdout = None
for k, v in logger._handlers.items():
if v._name == '<stderr>':
register_stderr = k
elif v._name == '<stdout>':
register_stderr = k
if register_stderr is not None:
logger.remove(register_stderr)
logger.add(sys.stderr)
if register_stdout is not None:
logger.remove(register_stdout)
logger.add(sys.stdout)
except Exception:
pass
elif DevWorker.report_stdout and not running_remotely():
self._task_handler = TaskHandler(self._task.session, self._task.id, capacity=100)
if Logger._stdout_proxy:
Logger._stdout_proxy.connect(self)
if Logger._stderr_proxy:
Logger._stderr_proxy.connect(self)
StdStreamPatch.patch_std_streams(self)
@classmethod
def current_logger(cls):
# type: () -> Logger
"""
Return a logger object for the current task. Can be called from anywhere in the code
@ -147,92 +55,24 @@ class Logger(object):
return None
return task.get_logger()
def console(self, msg, level=logging.INFO, omit_console=False, *args, **kwargs):
def report_text(self, msg, level=logging.INFO, print_console=True, *args, **_):
"""
print text to log (same as print to console, and also prints to console)
print text to the log and optionally also print to the console
:param msg: text to print to the console (always send to the backend and displayed in console)
:param level: logging level, default: logging.INFO
:param omit_console: If True we only send 'msg' to log (no console print)
:param str msg: text to print (always sent to the backend and displayed in the console)
:param int level: logging level, default: logging.INFO
:param bool print_console: If True, 'msg' is also printed to the console
"""
try:
level = int(level)
except (TypeError, ValueError):
self._task.log.log(level=logging.ERROR,
msg='Logger failed casting log level "%s" to integer' % str(level))
level = logging.INFO
if not running_remotely():
# noinspection PyBroadException
try:
record = self._task.log.makeRecord(
"console", level=level, fn='', lno=0, func='', msg=msg, args=args, exc_info=None
)
# find the task handler that matches our task
if not self._task_handler:
self._task_handler = [h for h in LoggerRoot.get_base_logger().handlers
if isinstance(h, TaskHandler) and h.task_id == self._task.id][0]
self._task_handler.emit(record)
except Exception:
LoggerRoot.get_base_logger().warning(msg='Logger failed sending log: [level %s]: "%s"'
% (str(level), str(msg)))
if not omit_console:
# if we are here and we grabbed the stdout, we need to print the real thing
if DevWorker.report_stdout and not running_remotely():
# noinspection PyBroadException
try:
# make sure we are writing to the original stdout
Logger._stdout_original_write(str(msg)+'\n')
except Exception:
pass
else:
print(str(msg))
# if task was not started, we have to start it
self._start_task_if_needed()
def report_text(self, msg, level=logging.INFO, print_console=False, *args, **_):
return self.console(msg, level, not print_console, *args, **_)
def debug(self, msg, *args, **kwargs):
""" Print information to the log. This is the same as console(msg, logging.DEBUG) """
self._task.log.log(msg=msg, level=logging.DEBUG, *args, **kwargs)
def info(self, msg, *args, **kwargs):
""" Print information to the log. This is the same as console(msg, logging.INFO) """
self._task.log.log(msg=msg, level=logging.INFO, *args, **kwargs)
def warn(self, msg, *args, **kwargs):
""" Print a warning to the log. This is the same as console(msg, logging.WARNING) """
self._task.log.log(msg=msg, level=logging.WARNING, *args, **kwargs)
warning = warn
def error(self, msg, *args, **kwargs):
""" Print an error to the log. This is the same as console(msg, logging.ERROR) """
self._task.log.log(msg=msg, level=logging.ERROR, *args, **kwargs)
def fatal(self, msg, *args, **kwargs):
""" Print a fatal error to the log. This is the same as console(msg, logging.FATAL) """
self._task.log.log(msg=msg, level=logging.FATAL, *args, **kwargs)
def critical(self, msg, *args, **kwargs):
""" Print a critical error to the log. This is the same as console(msg, logging.CRITICAL) """
self._task.log.log(msg=msg, level=logging.CRITICAL, *args, **kwargs)
return self._console(msg, level, not print_console, *args, **_)
def report_scalar(self, title, series, value, iteration):
"""
Report a scalar value
:param title: Title (AKA metric)
:type title: str
:param series: Series (AKA variant)
:type series: str
:param value: Reported value
:type value: float
:param iteration: Iteration number
:type value: int
:param str title: Title (AKA metric)
:param str series: Series (AKA variant)
:param float value: Reported value
:param int iteration: Iteration number
"""
# if task was not started, we have to start it
@ -244,18 +84,12 @@ class Logger(object):
"""
Report a histogram plot
:param title: Title (AKA metric)
:type title: str
:param series: Series (AKA variant)
:type series: str
:param values: Reported values (or numpy array)
:type values: [float]
:param iteration: Iteration number
:type iteration: int
:param labels: optional, labels for each bar group.
:type labels: list of strings.
:param xlabels: optional label per entry in the vector (bucket in the histogram)
:type xlabels: list of strings.
:param str title: Title (AKA metric)
:param str series: Series (AKA variant)
:param list(float) values: Reported values (or numpy array)
:param int iteration: Iteration number
:param list(str) labels: optional, labels for each bar group.
:param list(str) xlabels: optional label per entry in the vector (bucket in the histogram)
"""
return self.report_histogram(title, series, values, iteration, labels=labels, xlabels=xlabels)
@ -263,18 +97,12 @@ class Logger(object):
"""
Report a histogram plot
:param title: Title (AKA metric)
:type title: str
:param series: Series (AKA variant)
:type series: str
:param values: Reported values (or numpy array)
:type values: [float]
:param iteration: Iteration number
:type iteration: int
:param labels: optional, labels for each bar group.
:type labels: list of strings.
:param xlabels: optional label per entry in the vector (bucket in the histogram)
:type xlabels: list of strings.
:param str title: Title (AKA metric)
:param str series: Series (AKA variant)
:param list(float) values: Reported values (or numpy array)
:param int iteration: Iteration number
:param list(str) labels: optional, labels for each bar group.
:param list(str) xlabels: optional label per entry in the vector (bucket in the histogram)
"""
if not isinstance(values, np.ndarray):
@ -292,24 +120,19 @@ class Logger(object):
xlabels=xlabels,
)
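A hedged usage sketch for the histogram reporter documented above (argument names follow the docstring; the data is illustrative):

```python
import numpy as np
from trains import Task

logger = Task.init(project_name='examples', task_name='histogram demo').get_logger()

values = np.random.randint(0, 10, size=10)
logger.report_histogram('example_histogram', 'random series',
                        values=values, iteration=1,
                        xlabels=[str(i) for i in range(10)])
```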
def report_line_plot(self, title, series, iteration, xaxis, yaxis, mode='lines', reverse_xaxis=False, comment=None):
def report_line_plot(self, title, series, iteration, xaxis, yaxis, mode='lines',
reverse_xaxis=False, comment=None):
"""
Report a line plot (possibly with multiple series).
:param title: Title (AKA metric)
:type title: str
:param series: All the series' data, one for each line in the plot.
:type series: An iterable of LineSeriesInfo.
:param iteration: Iteration number
:type iteration: int
:param xaxis: optional x-axis title
:param yaxis: optional y-axis title
:param mode: scatter plot with 'lines'/'markers'/'lines+markers'
:type mode: str
:param reverse_xaxis: If true X axis will be displayed from high to low (reversed)
:type reverse_xaxis: bool
:param comment: comment underneath the title
:type comment: str
:param str title: Title (AKA metric)
:param list(LineSeriesInfo) series: All the series' data, one for each line in the plot.
:param int iteration: Iteration number
:param str xaxis: optional x-axis title
:param str yaxis: optional y-axis title
:param str mode: scatter plot with 'lines'/'markers'/'lines+markers'
:param bool reverse_xaxis: If true X axis will be displayed from high to low (reversed)
:param str comment: comment underneath the title
"""
series = [self.SeriesInfo(**s) if isinstance(s, dict) else s for s in series]
@ -333,21 +156,15 @@ class Logger(object):
"""
Report a 2d scatter graph (with lines)
:param title: Title (AKA metric)
:type title: str
:param series: Series (AKA variant)
:type series: str
:param scatter: A scattered data: list of (pairs of x,y) (or numpy array)
:type scatter: ndarray or list
:param iteration: Iteration number
:type iteration: int
:param xaxis: optional x-axis title
:param yaxis: optional y-axis title
:param labels: label (text) per point in the scatter (in the same order)
:param mode: scatter plot with 'lines'/'markers'/'lines+markers'
:type mode: str
:param comment: comment underneath the title
:type comment: str
:param str title: Title (AKA metric)
:param str series: Series (AKA variant)
:param np.ndarray scatter: The scatter data: a list of (x, y) pairs (or a numpy array)
:param int iteration: Iteration number
:param str xaxis: optional x-axis title
:param str yaxis: optional y-axis title
:param list(str) labels: label (text) per point in the scatter (in the same order)
:param str mode: scatter plot with 'lines'/'markers'/'lines+markers'
:param str comment: comment underneath the title
"""
if not isinstance(scatter, np.ndarray):
@ -375,18 +192,15 @@ class Logger(object):
"""
Report a 3d scatter graph (with markers)
:param title: Title (AKA metric)
:type title: str
:param series: Series (AKA variant)
:type series: str
:param scatter: A scattered data: list of (pairs of x,y,z) (or numpy array) or list of series [[(x1,y1,z1)...]]
:type scatter: ndarray or list
:param iteration: Iteration number
:type iteration: int
:param labels: label (text) per point in the scatter (in the same order)
:param mode: scatter plot with 'lines'/'markers'/'lines+markers'
:param fill: fill area under the curve
:param comment: comment underneath the title
:param str title: Title (AKA metric)
:param str series: Series (AKA variant)
:param np.ndarray scatter: The scatter data: a list of (x, y, z) triplets (or a numpy array),
or a list of series [[(x1,y1,z1)...]]
:param int iteration: Iteration number
:param list(str) labels: label (text) per point in the scatter (in the same order)
:param str mode: scatter plot with 'lines'/'markers'/'lines+markers'
:param bool fill: fill area under the curve
:param str comment: comment underneath the title
"""
# check if multiple series
multi_series = (
@ -429,17 +243,13 @@ class Logger(object):
"""
Report a heat-map matrix
:param title: Title (AKA metric)
:type title: str
:param series: Series (AKA variant)
:type series: str
:param matrix: A heat-map matrix (example: confusion matrix)
:type matrix: ndarray
:param iteration: Iteration number
:type iteration: int
:param xlabels: optional label per column of the matrix
:param ylabels: optional label per row of the matrix
:param comment: comment underneath the title
:param str title: Title (AKA metric)
:param str series: Series (AKA variant)
:param np.ndarray matrix: A heat-map matrix (example: confusion matrix)
:param int iteration: Iteration number
:param list(str) xlabels: optional label per column of the matrix
:param list(str) ylabels: optional label per row of the matrix
:param str comment: comment underneath the title
"""
if not isinstance(matrix, np.ndarray):
@ -463,16 +273,12 @@ class Logger(object):
Same as report_confusion_matrix
Report a heat-map matrix
:param title: Title (AKA metric)
:type title: str
:param series: Series (AKA variant)
:type series: str
:param matrix: A heat-map matrix (example: confusion matrix)
:type matrix: ndarray
:param iteration: Iteration number
:type iteration: int
:param xlabels: optional label per column of the matrix
:param ylabels: optional label per row of the matrix
:param str title: Title (AKA metric)
:param str series: Series (AKA variant)
:param np.ndarray matrix: A heat-map matrix (example: confusion matrix)
:param int iteration: Iteration number
:param list(str) xlabels: optional label per column of the matrix
:param list(str) ylabels: optional label per row of the matrix
"""
return self.report_confusion_matrix(title, series, matrix, iteration, xlabels=xlabels, ylabels=ylabels)
@ -481,21 +287,17 @@ class Logger(object):
"""
Report a 3d surface (same data as heat-map matrix, only presented differently)
:param title: Title (AKA metric)
:type title: str
:param series: Series (AKA variant)
:type series: str
:param matrix: A heat-map matrix (example: confusion matrix)
:type matrix: ndarray
:param iteration: Iteration number
:type iteration: int
:param xlabels: optional label per column of the matrix
:param ylabels: optional label per row of the matrix
:param xtitle: optional x-axis title
:param ytitle: optional y-axis title
:param ztitle: optional z-axis title
:param camera: X,Y,Z camera position. def: (1,1,1)
:param comment: comment underneath the title
:param str title: Title (AKA metric)
:param str series: Series (AKA variant)
:param np.ndarray matrix: A heat-map matrix (example: confusion matrix)
:param int iteration: Iteration number
:param list(str) xlabels: optional label per column of the matrix
:param list(str) ylabels: optional label per row of the matrix
:param str xtitle: optional x-axis title
:param str ytitle: optional y-axis title
:param str ztitle: optional z-axis title
:param list(float) camera: X,Y,Z camera position. def: (1,1,1)
:param str comment: comment underneath the title
"""
if not isinstance(matrix, np.ndarray):
@ -518,56 +320,24 @@ class Logger(object):
comment=comment,
)
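A hedged usage sketch for the matrix reporters documented above (argument names follow the docstrings; the values are illustrative):

```python
import numpy as np
from trains import Task

logger = Task.init(project_name='examples', task_name='matrix demo').get_logger()

confusion = np.random.randint(0, 50, size=(4, 4))
logger.report_confusion_matrix('example_confusion', 'series', iteration=1, matrix=confusion,
                               xlabels=['A', 'B', 'C', 'D'], ylabels=['A', 'B', 'C', 'D'])
logger.report_surface('example_surface', 'series', iteration=1, matrix=confusion,
                      xtitle='title x', ytitle='title y', ztitle='title z')
```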
@_safe_names
def report_image(self, title, series, src, iteration):
"""
Report an image, and register the 'src' as url content.
:param title: Title (AKA metric)
:type title: str
:param series: Series (AKA variant)
:type series: str
:param src: Image source URI. This URI will be used by the webapp and workers when trying to obtain the image \
for presentation of processing. Currently only http(s), file and s3 schemes are supported.
:type src: str
:param iteration: Iteration number
:type iteration: int
"""
# if task was not started, we have to start it
self._start_task_if_needed()
self._task.reporter.report_image(
title=title,
series=series,
src=src,
iter=iteration,
)
@_safe_names
def report_image_and_upload(self, title, series, iteration, path=None, matrix=None, max_image_history=None,
delete_after_upload=False):
def report_image(self, title, series, iteration, local_path=None, matrix=None, max_image_history=None,
delete_after_upload=False):
"""
Report an image and upload its contents.
Image is uploaded to a preconfigured bucket (see setup_upload()) with a key (filename)
describing the task ID, title, series and iteration.
:param title: Title (AKA metric)
:type title: str
:param series: Series (AKA variant)
:type series: str
:param iteration: Iteration number
:type iteration: int
:param path: A path to an image file. Required unless matrix is provided.
:type path: str
:param matrix: A 3D numpy.ndarray object containing image data (RGB). Required unless filename is provided.
:type matrix: str
:param max_image_history: maximum number of image to store per metric/variant combination \
use negative value for unlimited. default is set in global configuration (default=5)
:type max_image_history: int
:param delete_after_upload: if True, one the file was uploaded the local copy will be deleted
:type delete_after_upload: boolean
:param str title: Title (AKA metric)
:param str series: Series (AKA variant)
:param int iteration: Iteration number
:param str local_path: A path to an image file. Required unless matrix is provided.
:param np.ndarray matrix: A 3D numpy.ndarray object containing image data (RGB).
Required unless local_path is provided.
:param int max_image_history: maximum number of images to store per metric/variant combination;
use a negative value for unlimited. The default is set in the global configuration (default=5)
:param bool delete_after_upload: if True, the local copy is deleted once the file has been uploaded
"""
# if task was not started, we have to start it
@ -584,7 +354,7 @@ class Logger(object):
self._task.reporter.report_image_and_upload(
title=title,
series=series,
path=path,
path=local_path,
matrix=matrix,
iter=iteration,
upload_uri=upload_uri,
@ -592,7 +362,138 @@ class Logger(object):
delete_after_upload=delete_after_upload,
)
def report_image_plot_and_upload(self, title, series, iteration, path=None, matrix=None, max_image_history=None,
def set_default_upload_destination(self, uri):
"""
Set the uri to upload all the debug images to.
Images are uploaded separately to the destination storage (e.g. s3,gc,file) and then
a link to the uploaded image is sent in the report
Notice: credentials for the upload destination will be pulled from the
global configuration file (i.e. ~/trains.conf)
:param str uri: example: 's3://bucket/directory/' or 'file:///tmp/debug/'
:return: True if destination scheme is supported (i.e. s3:// file:// gc:// etc...)
"""
# Create the storage helper
storage = StorageHelper.get(uri)
# Verify that we can upload to this destination
uri = storage.verify_upload(folder_uri=uri)
self._default_upload_destination = uri
def get_default_upload_destination(self):
"""
Get the uri to upload all the debug images to.
Images are uploaded separately to the destination storage (e.g. s3,gc,file) and then
a link to the uploaded image is sent in the report
Notice: credentials for the upload destination will be pulled from the
global configuration file (i.e. ~/trains.conf)
:return: Uri (str) example: 's3://bucket/directory/' or 'file:///tmp/debug/' etc...
"""
return self._default_upload_destination or self._task._get_default_report_storage_uri()
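A hedged usage sketch for the upload-destination helpers above (the destination URI is illustrative):

```python
from trains import Task

logger = Task.init(project_name='examples', task_name='upload demo').get_logger()

# debug images reported from now on are uploaded under this URI
logger.set_default_upload_destination('file:///tmp/debug_images/')
print(logger.get_default_upload_destination())
```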
def flush(self):
"""
Flush cached reports and console outputs to backend.
:return: True if successful
"""
self._flush_stdout_handler()
if self._task:
return self._task.flush()
return False
def get_flush_period(self):
"""
:return: logger flush period in seconds
"""
if self._flusher:
return self._flusher.period
return None
def set_flush_period(self, period):
"""
Set the period of the logger flush.
:param float period: The period to flush the logger in seconds. If None or 0,
there will be no periodic flush.
"""
if self._task.is_main_task() and DevWorker.report_stdout and DevWorker.report_period and \
not running_remotely() and period is not None:
period = min(period or DevWorker.report_period, DevWorker.report_period)
if not period:
if self._flusher:
self._flusher.exit()
self._flusher = None
elif self._flusher:
self._flusher.set_period(period)
else:
self._flusher = LogFlusher(self, period)
self._flusher.start()
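A hedged usage sketch for the flush controls above (the period is illustrative):

```python
from trains import Task

logger = Task.init(project_name='examples', task_name='flush demo').get_logger()

logger.set_flush_period(10.0)        # flush buffered reports every 10 seconds
print(logger.get_flush_period())
logger.flush()                       # or force an immediate flush
```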
def report_image_and_upload(self, title, series, iteration, path=None, matrix=None, max_image_history=None,
delete_after_upload=False):
"""
Backwards compatibility, please use report_image instead
"""
self.report_image(title=title, series=series, iteration=iteration, local_path=path, matrix=matrix,
max_image_history=max_image_history, delete_after_upload=delete_after_upload)
@classmethod
def _remove_std_logger(cls):
StdStreamPatch.remove_std_logger()
def _console(self, msg, level=logging.INFO, omit_console=False, *args, **kwargs):
"""
Print text to the log (and, unless omit_console is set, also to the console)
:param msg: text to send to the backend log (and print to the console)
:param level: logging level, default: logging.INFO
:param omit_console: If True, only send 'msg' to the log (no console print)
"""
try:
level = int(level)
except (TypeError, ValueError):
self._task.log.log(level=logging.ERROR,
msg='Logger failed casting log level "%s" to integer' % str(level))
level = logging.INFO
if not running_remotely():
# noinspection PyBroadException
try:
record = self._task.log.makeRecord(
"console", level=level, fn='', lno=0, func='', msg=msg, args=args, exc_info=None
)
# find the task handler that matches our task
if not self._task_handler:
self._task_handler = [h for h in LoggerRoot.get_base_logger().handlers
if isinstance(h, TaskHandler) and h.task_id == self._task.id][0]
self._task_handler.emit(record)
except Exception:
LoggerRoot.get_base_logger().warning(msg='Logger failed sending log: [level %s]: "%s"'
% (str(level), str(msg)))
if not omit_console:
# if we are here and we grabbed the stdout, we need to print the real thing
if DevWorker.report_stdout and not running_remotely():
# noinspection PyBroadException
try:
# make sure we are writing to the original stdout
StdStreamPatch.stdout_original_write(str(msg)+'\n')
except Exception:
pass
else:
print(str(msg))
# if task was not started, we have to start it
self._start_task_if_needed()
def _report_image_plot_and_upload(self, title, series, iteration, path=None, matrix=None, max_image_history=None,
delete_after_upload=False):
"""
Report an image, upload its contents, and present in plots section using plotly
@ -639,7 +540,7 @@ class Logger(object):
delete_after_upload=delete_after_upload,
)
def report_file_and_upload(self, title, series, iteration, path=None, max_file_history=None,
def _report_file_and_upload(self, title, series, iteration, path=None, max_file_history=None,
delete_after_upload=False):
"""
Upload a file and report it as a link in the debug images section.
@ -684,92 +585,6 @@ class Logger(object):
delete_after_upload=delete_after_upload,
)
def set_default_upload_destination(self, uri):
"""
Set the uri to upload all the debug images to.
Images are uploaded separately to the destination storage (e.g. s3,gc,file) and then
a link to the uploaded image is sent in the report
Notice: credentials for the upload destination will be pooled from the
global configuration file (i.e. ~/trains.conf)
:param uri: example: 's3://bucket/directory/' or 'file:///tmp/debug/'
:return: True if destination scheme is supported (i.e. s3:// file:// gc:// etc...)
"""
# Create the storage helper
storage = StorageHelper.get(uri)
# Verify that we can upload to this destination
uri = storage.verify_upload(folder_uri=uri)
self._default_upload_destination = uri
def get_default_upload_destination(self):
"""
Get the uri to upload all the debug images to.
Images are uploaded separately to the destination storage (e.g. s3,gc,file) and then
a link to the uploaded image is sent in the report
Notice: credentials for the upload destination will be pooled from the
global configuration file (i.e. ~/trains.conf)
:return: Uri (str) example: 's3://bucket/directory/' or 'file:///tmp/debug/' etc...
"""
return self._default_upload_destination or self._task._get_default_report_storage_uri()
def flush(self):
"""
Flush cached reports and console outputs to backend.
:return: True if successful
"""
self._flush_stdout_handler()
if self._task:
return self._task.flush()
return False
def get_flush_period(self):
if self._flusher:
return self._flusher.period
return None
def set_flush_period(self, period):
"""
Set the period of the logger flush.
:param period: The period to flush the logger in seconds. If None or 0,
There will be no periodic flush.
"""
if self._task.is_main_task() and DevWorker.report_stdout and DevWorker.report_period and \
not running_remotely() and period is not None:
period = min(period or DevWorker.report_period, DevWorker.report_period)
if not period:
if self._flusher:
self._flusher.exit()
self._flusher = None
elif self._flusher:
self._flusher.set_period(period)
else:
self._flusher = _Flusher(self, period)
self._flusher.start()
@classmethod
def _remove_std_logger(self):
if isinstance(sys.stdout, PrintPatchLogger):
# noinspection PyBroadException
try:
sys.stdout.connect(None)
except Exception:
pass
if isinstance(sys.stderr, PrintPatchLogger):
# noinspection PyBroadException
try:
sys.stderr.connect(None)
except Exception:
pass
def _start_task_if_needed(self):
# do not refresh the task status read from cached variable _status
if str(self._task._status) == str(tasks.TaskStatusEnum.created):
@ -780,121 +595,3 @@ class Logger(object):
def _flush_stdout_handler(self):
if self._task_handler and DevWorker.report_stdout:
self._task_handler.flush()
def stdout__patched__write__(*args, **kwargs):
if Logger._stdout_proxy:
return Logger._stdout_proxy.write(*args, **kwargs)
return sys.stdout._original_write(*args, **kwargs)
def stderr__patched__write__(*args, **kwargs):
if Logger._stderr_proxy:
return Logger._stderr_proxy.write(*args, **kwargs)
return sys.stderr._original_write(*args, **kwargs)
class PrintPatchLogger(object):
"""
Allows patching a stream into the logger.
Used for capturing and logging stdout and stderr when running in development mode as a pseudo worker.
"""
patched = False
lock = threading.Lock()
recursion_protect_lock = threading.RLock()
def __init__(self, stream, logger=None, level=logging.INFO):
PrintPatchLogger.patched = True
self._terminal = stream
self._log = logger
self._log_level = level
self._cur_line = ''
def write(self, message):
# make sure that we do not end up in infinite loop (i.e. log.console ends up calling us)
if self._log and not PrintPatchLogger.recursion_protect_lock._is_owned():
try:
self.lock.acquire()
with PrintPatchLogger.recursion_protect_lock:
if hasattr(self._terminal, '_original_write'):
self._terminal._original_write(message)
else:
self._terminal.write(message)
do_flush = '\n' in message
do_cr = '\r' in message
self._cur_line += message
if (not do_flush and not do_cr) or not message:
return
last_lf = self._cur_line.rindex('\n' if do_flush else '\r')
next_line = self._cur_line[last_lf + 1:]
cur_line = self._cur_line[:last_lf + 1].rstrip()
self._cur_line = next_line
finally:
self.lock.release()
if cur_line:
with PrintPatchLogger.recursion_protect_lock:
# noinspection PyBroadException
try:
if self._log:
self._log.console(cur_line, level=self._log_level, omit_console=True)
except Exception:
# what can we do, nothing
pass
else:
if hasattr(self._terminal, '_original_write'):
self._terminal._original_write(message)
else:
self._terminal.write(message)
def connect(self, logger):
self._cur_line = ''
self._log = logger
def __getattr__(self, attr):
if attr in ['_log', '_terminal', '_log_level', '_cur_line']:
return self.__dict__.get(attr)
return getattr(self._terminal, attr)
def __setattr__(self, key, value):
if key in ['_log', '_terminal', '_log_level', '_cur_line']:
self.__dict__[key] = value
else:
return setattr(self._terminal, key, value)
class _Flusher(threading.Thread):
def __init__(self, logger, period, **kwargs):
super(_Flusher, self).__init__(**kwargs)
self.daemon = True
self._period = period
self._logger = logger
self._exit_event = threading.Event()
@property
def period(self):
return self._period
def run(self):
self._logger.flush()
# store original wait period
while True:
period = self._period
while not self._exit_event.wait(period or 1.0):
self._logger.flush()
# check if period is negative or None we should exit
if self._period is None or self._period < 0:
break
# check if period was changed, we should restart
self._exit_event.clear()
def exit(self):
self._period = None
self._exit_event.set()
def set_period(self, period):
self._period = period
# make sure we exit the previous wait
self._exit_event.set()

View File

@ -340,7 +340,6 @@ class InputModel(BaseModel):
name=None,
tags=None,
comment=None,
logger=None,
is_package=False,
create_as_published=False,
framework=None,
@ -367,7 +366,6 @@ class InputModel(BaseModel):
:param name: optional, name for the newly imported model
:param tags: optional, list of strings as tags
:param comment: optional, string description for the model
:param logger: The logger to use. If None, use the default logger
:param is_package: Boolean. Indicates that the imported weights file is a package.
If True, and a new model was created, a package tag will be added.
:param create_as_published: Boolean. If True, and a new model is created, it will be published.
@ -386,8 +384,7 @@ class InputModel(BaseModel):
))
if result.response.models:
if not logger:
logger = get_logger()
logger = get_logger()
logger.debug('A model with uri "{}" already exists. Selecting it'.format(weights_url))

View File

@ -1,16 +1,15 @@
import atexit
import os
import re
import signal
import sys
import threading
import time
from argparse import ArgumentParser
from collections import OrderedDict, Callable
from typing import Optional
import psutil
import six
from pathlib2 import Path
from .binding.joblib_bind import PatchedJoblib
from .backend_api.services import tasks, projects
@ -29,7 +28,7 @@ from .errors import UsageError
from .logger import Logger
from .model import InputModel, OutputModel, ARCHIVED_TAG
from .task_parameters import TaskParameters
from .binding.artifacts import Artifacts
from .binding.artifacts import Artifacts, Artifact
from .binding.environ_bind import EnvironmentBind, PatchOsFork
from .binding.absl_bind import PatchAbsl
from .utilities.args import argparser_parseargs_called, get_argparser_last_args, \
@ -41,6 +40,7 @@ from .binding.frameworks.xgboost_bind import PatchXGBoostModelIO
from .binding.matplotlib_bind import PatchedMatplotlib
from .utilities.resource_monitor import ResourceMonitor
from .utilities.seed import make_deterministic
from .utilities.dicts import ReadOnlyDict
NotSet = object()
@ -113,6 +113,7 @@ class Task(_Task):
@classmethod
def current_task(cls):
# type: () -> Task
"""
Return the current Task object for the main execution task (task context).
:return: Task() object or None
@ -279,7 +280,7 @@ class Task(_Task):
# The logger will automatically take care of all patching (we just need to make sure to initialize it)
logger = task.get_logger()
# print a link to the results page in the log, it is very convenient
logger.console(
logger.report_text(
'TRAINS results page: {}/projects/{}/experiments/{}/output/log'.format(
task._get_app_server(),
task.project if task.project is not None else '*',
@ -439,12 +440,12 @@ class Task(_Task):
task._setup_log(replace_existing=True)
logger = task.get_logger()
if closed_old_task:
logger.console('TRAINS Task: Closing old development task id={}'.format(default_task.get('id')))
logger.report_text('TRAINS Task: Closing old development task id={}'.format(default_task.get('id')))
# print a warning about reusing or creating a task
if default_task_id:
logger.console('TRAINS Task: overwriting (reusing) task id=%s' % task.id)
logger.report_text('TRAINS Task: overwriting (reusing) task id=%s' % task.id)
else:
logger.console('TRAINS Task: created new task id=%s' % task.id)
logger.report_text('TRAINS Task: created new task id=%s' % task.id)
# update the current repository and put a warning into the logs
if in_dev_mode and cls.__detect_repo_async:
@ -462,8 +463,8 @@ class Task(_Task):
thread.start()
return task
@staticmethod
def get_task(task_id=None, project_name=None, task_name=None):
@classmethod
def get_task(cls, task_id=None, project_name=None, task_name=None):
"""
Returns a Task object based on either a task_id (system UUID) or a project name and task name
@ -472,7 +473,7 @@ class Task(_Task):
:param task_name: task name (str) within the selected project
:return: Task() object
"""
return Task.__get_task(task_id=task_id, project_name=project_name, task_name=task_name)
return cls.__get_task(task_id=task_id, project_name=project_name, task_name=task_name)
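# A short usage sketch of the two documented call forms; the id and names below
# are hypothetical placeholders.
from trains import Task

task_by_id = Task.get_task(task_id='1234567890abcdef12345678')
task_by_name = Task.get_task(project_name='examples', task_name='my experiment')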
@property
def output_uri(self):
@ -490,10 +491,14 @@ class Task(_Task):
@property
def artifacts(self):
"""
dictionary of Task artifacts (name, artifact)
read-only dictionary of Task artifacts (name, artifact)
:return: dict
"""
return self._artifacts_manager.artifacts
if not Session.check_min_api_version('2.3'):
return ReadOnlyDict()
if not self.data.execution or not self.data.execution.artifacts:
return ReadOnlyDict()
return ReadOnlyDict([(a.key, Artifact(a)) for a in self.data.execution.artifacts])
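# A short sketch of reading the (now read-only) artifacts mapping; iteration and
# key lookups behave like a plain dict, while any mutation raises ValueError.
task = Task.current_task()
for name, artifact in task.artifacts.items():
    print(name, artifact)
# task.artifacts['new_entry'] = None  # would raise ValueError (ReadOnlyDict)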
def set_comment(self, comment):
"""
@ -553,6 +558,7 @@ class Task(_Task):
raise Exception('Unsupported mutable type %s: no connect function found' % type(mutable).__name__)
def get_logger(self, flush_period=NotSet):
# type: (Optional[float]) -> Logger
"""
Get a logger object for reporting, based on the task
@ -663,6 +669,15 @@ class Task(_Task):
"""
self._artifacts_manager.unregister_artifact(name=name)
def get_registered_artifacts(self):
"""
Dictionary of the Task's registered (dynamic) artifacts (name, artifact object)
Note that these objects can be modified; changes will be uploaded automatically
:return: dict
"""
return self._artifacts_manager.registered_artifacts
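# A short sketch of the dynamic-artifact flow, assuming the companion
# register_artifact() call; 'training_stats' is a hypothetical name and the
# DataFrame is purely illustrative. In-place changes are re-uploaded automatically.
import pandas as pd

task = Task.current_task()
stats = pd.DataFrame({'epoch': [0], 'loss': [1.0]})
task.register_artifact('training_stats', stats)
# later, mutate the same object through the registered-artifacts mapping
task.get_registered_artifacts()['training_stats'].loc[1] = [1, 0.8]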
def upload_artifact(self, name, artifact_object, metadata=None, delete_after_upload=False):
"""
Add a static artifact to the Task. The artifact file/object will be uploaded in the background
@ -671,6 +686,7 @@ class Task(_Task):
:param str name: Artifact name. Note: it will override a previous artifact if the name already exists
:param object artifact_object: Artifact object to upload. Currently supports:
- string / pathlib2.Path is treated as a path to the artifact file to upload.
If a wildcard or a folder is passed, a zip file containing the local files will be created and uploaded.
- dict will be stored as .json,
- pandas.DataFrame will be stored as .csv.gz (compressed CSV file),
- numpy.ndarray will be stored as .npz,
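# A short sketch covering the object types listed above; the artifact names and the
# wildcard path are hypothetical, and each object is serialized as described.
import numpy as np
import pandas as pd

task = Task.current_task()
task.upload_artifact('config', {'lr': 0.001, 'batch': 32})           # dict -> .json
task.upload_artifact('predictions', pd.DataFrame({'y': [0, 1, 1]}))  # DataFrame -> .csv.gz
task.upload_artifact('embeddings', np.zeros((8, 16)))                # ndarray -> .npz
task.upload_artifact('checkpoints', '/tmp/checkpoints/*.bin')        # wildcard -> zip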
@ -937,7 +953,7 @@ class Task(_Task):
if self._at_exit_called:
return
self.get_logger().warn(
self.log.warning(
"### TASK STOPPED - USER ABORTED - {} ###".format(
stop_reason.upper().replace('_', ' ')
)
@ -1009,7 +1025,7 @@ class Task(_Task):
# signal artifacts upload, and stop daemon
self._artifacts_manager.stop(wait=True)
# print artifacts summary
self.get_logger().console(self._artifacts_manager.summary)
self.get_logger().report_text(self._artifacts_manager.summary)
def _at_exit(self):
"""

View File

@ -1,6 +1,7 @@
from __future__ import absolute_import, division, print_function
import collections
import json
import re
import threading
@ -313,9 +314,11 @@ class CheckPackageUpdates(object):
# noinspection PyBroadException
try:
from ..version import __version__
cls._package_version_checked = True
# Sending the request only for statistics
update_statistics = threading.Thread(target=CheckPackageUpdates.get_version_from_updates_server)
update_statistics = threading.Thread(target=CheckPackageUpdates.get_version_from_updates_server,
args=(__version__,))
update_statistics.daemon = True
update_statistics.start()
@ -323,7 +326,6 @@ class CheckPackageUpdates(object):
releases = [Version(r) for r in releases]
latest_version = sorted(releases)
from ..version import __version__
cur_version = Version(__version__)
if not cur_version.is_devrelease and not cur_version.is_prerelease:
latest_version = [r for r in latest_version if not r.is_devrelease and not r.is_prerelease]
@ -336,8 +338,9 @@ class CheckPackageUpdates(object):
return None
@staticmethod
def get_version_from_updates_server():
def get_version_from_updates_server(cur_version):
try:
_ = requests.get('https://updates.trainsai.io/updates', timeout=1.0)
_ = requests.get('https://updates.trainsai.io/updates',
params=json.dumps({'versions': {'trains': str(cur_version)}}), timeout=1.0)
except Exception:
pass

View File

@ -3,6 +3,19 @@
_epsilon = 0.00001
class ReadOnlyDict(dict):
def __readonly__(self, *args, **kwargs):
raise ValueError("This is a read only dictionary")
__setitem__ = __readonly__
__delitem__ = __readonly__
pop = __readonly__
popitem = __readonly__
clear = __readonly__
update = __readonly__
setdefault = __readonly__
del __readonly__
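# A short sketch of the intended behavior: reads work like a plain dict, while any
# mutation path raises ValueError.
ro = ReadOnlyDict(a=1, b=2)
assert ro['a'] == 1
try:
    ro['c'] = 3
except ValueError:
    pass  # "This is a read only dictionary"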
class Logs:
_logs_instances = []

View File

@ -1,3 +1,5 @@
import logging
import warnings
from time import time
from threading import Thread, Event
@ -32,8 +34,8 @@ class ResourceMonitor(object):
self._gpustat_fail = 0
self._gpustat = gpustat
if not self._gpustat:
self._task.get_logger().console('TRAINS Monitor: GPU monitoring is not available, '
'run \"pip install gpustat\"')
self._task.get_logger().report_text('TRAINS Monitor: GPU monitoring is not available, '
'run \"pip install gpustat\"')
def start(self):
self._exit_event.clear()
@ -73,8 +75,8 @@ class ResourceMonitor(object):
if IsTensorboardInit.tensorboard_used():
fallback_to_sec_as_iterations = False
elif seconds_since_started >= self._wait_for_first_iteration:
self._task.get_logger().console('TRAINS Monitor: Could not detect iteration reporting, '
'falling back to iterations as seconds-from-start')
self._task.get_logger().report_text('TRAINS Monitor: Could not detect iteration reporting, '
'falling back to iterations as seconds-from-start')
fallback_to_sec_as_iterations = True
clear_readouts = True
@ -168,9 +170,11 @@ class ResourceMonitor(object):
stats["memory_free_gb"] = bytes_to_megabytes(virtual_memory.available) / 1024
disk_use_percentage = psutil.disk_usage(Text(Path.home())).percent
stats["disk_free_percent"] = 100.0-disk_use_percentage
sensor_stat = (
psutil.sensors_temperatures() if hasattr(psutil, "sensors_temperatures") else {}
)
with warnings.catch_warnings():
if logging.root.level > logging.DEBUG:  # if the logging level is above DEBUG, ignore
# psutil.sensors_temperatures warnings
warnings.simplefilter("ignore", category=RuntimeWarning)
sensor_stat = (psutil.sensors_temperatures() if hasattr(psutil, "sensors_temperatures") else {})
if "coretemp" in sensor_stat and len(sensor_stat["coretemp"]):
stats["cpu_temperature"] = max([float(t.current) for t in sensor_stat["coretemp"]])
@ -197,8 +201,8 @@ class ResourceMonitor(object):
# something happened and we can't use gpu stats,
self._gpustat_fail += 1
if self._gpustat_fail >= 3:
self._task.get_logger().console('TRAINS Monitor: GPU monitoring failed getting GPU reading, '
'switching off GPU monitoring')
self._task.get_logger().report_text('TRAINS Monitor: GPU monitoring failed getting GPU reading, '
'switching off GPU monitoring')
self._gpustat = None
return stats