
Add Task.init auto_connect_streams controlling stdout/stderr/logging capture. Issue

This commit is contained in:
allegroai 2020-11-20 15:50:33 +02:00
parent e9920e27ed
commit 6e012cb205
7 changed files with 305 additions and 164 deletions
trains
backend_interface
config/default
logger.py
task.py
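For reference, a minimal usage sketch of the new argument (names and values are placeholders; the dictionary form mirrors the docstring added in this commit):

from trains import Task

# capture stdout and stderr, but do not hook the `logging` module explicitly
task = Task.init(
    project_name='examples',
    task_name='auto_connect_streams demo',
    auto_connect_streams={'stdout': True, 'stderr': True, 'logging': False},
)

print('this line is captured and reported to the task console log')
# passing auto_connect_streams=False would leave stdout/stderr untouched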

View File

@ -1,10 +1,10 @@
import logging import logging
import sys import sys
import threading import threading
from time import time
from ..backend_interface.task.development.worker import DevWorker from ..binding.frameworks import _patched_call # noqa
from ..backend_interface.task.log import TaskHandler from ..config import running_remotely, config
from ..config import running_remotely
class StdStreamPatch(object): class StdStreamPatch(object):
@ -14,51 +14,62 @@ class StdStreamPatch(object):
_stderr_original_write = None _stderr_original_write = None
@staticmethod @staticmethod
def patch_std_streams(logger): def patch_std_streams(a_logger, connect_stdout=True, connect_stderr=True):
if DevWorker.report_stdout and not PrintPatchLogger.patched and not running_remotely(): if (connect_stdout or connect_stderr) and not PrintPatchLogger.patched and not running_remotely():
StdStreamPatch._stdout_proxy = PrintPatchLogger(sys.stdout, logger, level=logging.INFO) StdStreamPatch._stdout_proxy = PrintPatchLogger(sys.stdout, a_logger, level=logging.INFO) \
StdStreamPatch._stderr_proxy = PrintPatchLogger(sys.stderr, logger, level=logging.ERROR) if connect_stdout else None
logger._task_handler = TaskHandler(task=logger._task, capacity=100) StdStreamPatch._stderr_proxy = PrintPatchLogger(sys.stderr, a_logger, level=logging.ERROR) \
if connect_stderr else None
if StdStreamPatch._stdout_proxy:
# noinspection PyBroadException # noinspection PyBroadException
try: try:
if StdStreamPatch._stdout_original_write is None: if StdStreamPatch._stdout_original_write is None:
StdStreamPatch._stdout_original_write = sys.stdout.write StdStreamPatch._stdout_original_write = sys.stdout.write
if StdStreamPatch._stderr_original_write is None:
StdStreamPatch._stderr_original_write = sys.stderr.write
# this will only work in python 3, guard it with try/catch # this will only work in python 3, guard it with try/catch
if not hasattr(sys.stdout, '_original_write'): if not hasattr(sys.stdout, '_original_write'):
sys.stdout._original_write = sys.stdout.write sys.stdout._original_write = sys.stdout.write
sys.stdout.write = StdStreamPatch._stdout__patched__write__ sys.stdout.write = StdStreamPatch._stdout__patched__write__
except Exception:
pass
sys.stdout = StdStreamPatch._stdout_proxy
# noinspection PyBroadException
try:
sys.__stdout__ = sys.stdout
except Exception:
pass
if StdStreamPatch._stderr_proxy:
# noinspection PyBroadException
try:
if StdStreamPatch._stderr_original_write is None:
StdStreamPatch._stderr_original_write = sys.stderr.write
if not hasattr(sys.stderr, '_original_write'): if not hasattr(sys.stderr, '_original_write'):
sys.stderr._original_write = sys.stderr.write sys.stderr._original_write = sys.stderr.write
sys.stderr.write = StdStreamPatch._stderr__patched__write__ sys.stderr.write = StdStreamPatch._stderr__patched__write__
except Exception: except Exception:
pass pass
sys.stdout = StdStreamPatch._stdout_proxy
sys.stderr = StdStreamPatch._stderr_proxy sys.stderr = StdStreamPatch._stderr_proxy
# patch the base streams of sys (this way colorama will keep its ANSI colors) # patch the base streams of sys (this way colorama will keep its ANSI colors)
# noinspection PyBroadException # noinspection PyBroadException
try: try:
sys.__stderr__ = sys.stderr sys.__stderr__ = sys.stderr
except Exception: except Exception:
pass pass
# noinspection PyBroadException
try:
sys.__stdout__ = sys.stdout
except Exception:
pass
# now check if we have loguru and make it re-register the handlers # now check if we have loguru and make it re-register the handlers
# because it stores internally the stream.write function, which we cant patch # because it stores internally the stream.write function, which we cant patch
# noinspection PyBroadException # noinspection PyBroadException
try: try:
from loguru import logger from loguru import logger # noqa
register_stderr = None register_stderr = None
register_stdout = None register_stdout = None
for k, v in logger._handlers.items(): for k, v in logger._handlers.items(): # noqa
if v._name == '<stderr>': if connect_stderr and v._name == '<stderr>': # noqa
register_stderr = k register_stderr = k
elif v._name == '<stdout>': elif connect_stdout and v._name == '<stdout>': # noqa
register_stderr = k register_stderr = k
if register_stderr is not None: if register_stderr is not None:
logger.remove(register_stderr) logger.remove(register_stderr)
@ -69,12 +80,30 @@ class StdStreamPatch(object):
except Exception: except Exception:
pass pass
elif DevWorker.report_stdout and not running_remotely(): elif (connect_stdout or connect_stderr) and not running_remotely():
logger._task_handler = TaskHandler(task=logger._task, capacity=100) if StdStreamPatch._stdout_proxy and connect_stdout:
if StdStreamPatch._stdout_proxy: StdStreamPatch._stdout_proxy.connect(a_logger)
StdStreamPatch._stdout_proxy.connect(logger) if StdStreamPatch._stderr_proxy and connect_stderr:
if StdStreamPatch._stderr_proxy: StdStreamPatch._stderr_proxy.connect(a_logger)
StdStreamPatch._stderr_proxy.connect(logger)
@staticmethod
def patch_logging_formatter(a_logger, logging_handler=None):
if not logging_handler:
import logging
logging_handler = logging.Handler
logging_handler.format = _patched_call(logging_handler.format, HandlerFormat(a_logger))
@staticmethod
def remove_patch_logging_formatter(logging_handler=None):
if not logging_handler:
import logging
logging_handler = logging.Handler
# remove the patched function. Hack: calling the patched logging.Handler.format() with no arguments returns the original function
# noinspection PyBroadException
try:
logging_handler.format = logging_handler.format() # noqa
except Exception:
pass
@staticmethod @staticmethod
def remove_std_logger(logger=None): def remove_std_logger(logger=None):
@ -105,13 +134,33 @@ class StdStreamPatch(object):
def _stdout__patched__write__(*args, **kwargs): def _stdout__patched__write__(*args, **kwargs):
if StdStreamPatch._stdout_proxy: if StdStreamPatch._stdout_proxy:
return StdStreamPatch._stdout_proxy.write(*args, **kwargs) return StdStreamPatch._stdout_proxy.write(*args, **kwargs)
return sys.stdout._original_write(*args, **kwargs) return sys.stdout._original_write(*args, **kwargs) # noqa
@staticmethod @staticmethod
def _stderr__patched__write__(*args, **kwargs): def _stderr__patched__write__(*args, **kwargs):
if StdStreamPatch._stderr_proxy: if StdStreamPatch._stderr_proxy:
return StdStreamPatch._stderr_proxy.write(*args, **kwargs) return StdStreamPatch._stderr_proxy.write(*args, **kwargs)
return sys.stderr._original_write(*args, **kwargs) return sys.stderr._original_write(*args, **kwargs) # noqa
class HandlerFormat(object):
def __init__(self, logger):
self._logger = logger
def __call__(self, original_format_func, *args):
# hack: get back the original function, so we can remove it
if all(a is None for a in args):
return original_format_func
if len(args) == 1:
record = args[0]
msg = original_format_func(record)
else:
handler = args[0]
record = args[1]
msg = original_format_func(handler, record)
self._logger.report_text(msg=msg, level=record.levelno, print_console=False)
return msg
class PrintPatchLogger(object): class PrintPatchLogger(object):
@ -122,6 +171,7 @@ class PrintPatchLogger(object):
patched = False patched = False
lock = threading.Lock() lock = threading.Lock()
recursion_protect_lock = threading.RLock() recursion_protect_lock = threading.RLock()
lf_flush_period = config.get("development.worker.console_lf_flush_period", 0)
def __init__(self, stream, logger=None, level=logging.INFO): def __init__(self, stream, logger=None, level=logging.INFO):
PrintPatchLogger.patched = True PrintPatchLogger.patched = True
@ -129,23 +179,35 @@ class PrintPatchLogger(object):
self._log = logger self._log = logger
self._log_level = level self._log_level = level
self._cur_line = '' self._cur_line = ''
self._force_lf_flush = False
def write(self, message): def write(self, message):
# make sure that we do not end up in infinite loop (i.e. log.console ends up calling us) # make sure that we do not end up in infinite loop (i.e. log.console ends up calling us)
if self._log and not PrintPatchLogger.recursion_protect_lock._is_owned(): if self._log and not PrintPatchLogger.recursion_protect_lock._is_owned(): # noqa
try: try:
self.lock.acquire() self.lock.acquire()
with PrintPatchLogger.recursion_protect_lock: with PrintPatchLogger.recursion_protect_lock:
if hasattr(self._terminal, '_original_write'): if hasattr(self._terminal, '_original_write'):
self._terminal._original_write(message) self._terminal._original_write(message) # noqa
else: else:
self._terminal.write(message) self._terminal.write(message)
do_flush = '\n' in message do_flush = '\n' in message
do_cr = '\r' in message do_cr = '\r' in message
self._cur_line += message self._cur_line += message
if (not do_flush and not do_cr) or not message:
if not do_flush and do_cr and PrintPatchLogger.lf_flush_period and self._force_lf_flush:
self._cur_line += '\n'
do_flush = True
self._force_lf_flush = False
if (not do_flush and (PrintPatchLogger.lf_flush_period or not do_cr)) or not message:
return return
if PrintPatchLogger.lf_flush_period and self._cur_line:
self._cur_line = '\n'.join(line.split('\r')[-1] for line in self._cur_line.split('\n'))
last_lf = self._cur_line.rindex('\n' if do_flush else '\r') last_lf = self._cur_line.rindex('\n' if do_flush else '\r')
next_line = self._cur_line[last_lf + 1:] next_line = self._cur_line[last_lf + 1:]
cur_line = self._cur_line[:last_lf + 1].rstrip() cur_line = self._cur_line[:last_lf + 1].rstrip()
@ -158,13 +220,14 @@ class PrintPatchLogger(object):
# noinspection PyBroadException # noinspection PyBroadException
try: try:
if self._log: if self._log:
# noinspection PyProtectedMember
self._log._console(cur_line, level=self._log_level, omit_console=True) self._log._console(cur_line, level=self._log_level, omit_console=True)
except Exception: except Exception:
# what can we do, nothing # what can we do, nothing
pass pass
else: else:
if hasattr(self._terminal, '_original_write'): if hasattr(self._terminal, '_original_write'):
self._terminal._original_write(message) self._terminal._original_write(message) # noqa
else: else:
self._terminal.write(message) self._terminal.write(message)
@ -177,13 +240,16 @@ class PrintPatchLogger(object):
if not logger or self._log == logger: if not logger or self._log == logger:
self.connect(None) self.connect(None)
def force_lf_flush(self):
self._force_lf_flush = True
def __getattr__(self, attr): def __getattr__(self, attr):
if attr in ['_log', '_terminal', '_log_level', '_cur_line']: if attr in ['_log', '_terminal', '_log_level', '_cur_line', '_cr_overwrite', '_force_lf_flush']:
return self.__dict__.get(attr) return self.__dict__.get(attr)
return getattr(self._terminal, attr) return getattr(self._terminal, attr)
def __setattr__(self, key, value): def __setattr__(self, key, value):
if key in ['_log', '_terminal', '_log_level', '_cur_line']: if key in ['_log', '_terminal', '_log_level', '_cur_line', '_cr_overwrite', '_force_lf_flush']:
self.__dict__[key] = value self.__dict__[key] = value
else: else:
return setattr(self._terminal, key, value) return setattr(self._terminal, key, value)
@ -197,6 +263,11 @@ class LogFlusher(threading.Thread):
self._period = period self._period = period
self._logger = logger self._logger = logger
self._exit_event = threading.Event() self._exit_event = threading.Event()
self._lf_last_flush = 0
try:
self._lf_flush_period = float(PrintPatchLogger.lf_flush_period)
except (ValueError, TypeError):
self._lf_flush_period = 0
@property @property
def period(self): def period(self):
@ -208,7 +279,15 @@ class LogFlusher(threading.Thread):
while True: while True:
period = self._period period = self._period
while not self._exit_event.wait(period or 1.0): while not self._exit_event.wait(period or 1.0):
if self._lf_flush_period and time() - self._lf_last_flush > self._lf_flush_period:
if isinstance(sys.stdout, PrintPatchLogger):
sys.stdout.force_lf_flush()
if isinstance(sys.stderr, PrintPatchLogger):
sys.stderr.force_lf_flush()
self._lf_last_flush = time()
# now signal the real flush
self._logger.flush() self._logger.flush()
# check if period is negative or None we should exit # check if period is negative or None we should exit
if self._period is None or self._period < 0: if self._period is None or self._period < 0:
break break
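The console_lf_flush_period handling above buffers carriage-return updates and keeps only the last state of each line before reporting. A standalone sketch of that collapsing step (plain Python, independent of PrintPatchLogger):

def collapse_cr(buffered):
    # keep only the text after the last '\r' on every line, the way a terminal
    # would render repeated carriage-return updates such as progress bars
    return '\n'.join(line.split('\r')[-1] for line in buffered.split('\n'))

assert collapse_cr('epoch 1: 10%\repoch 1: 50%\repoch 1: 100%\ndone') == 'epoch 1: 100%\ndone'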

View File

@ -1,6 +1,7 @@
import abc import abc
import hashlib import hashlib
import time import time
from logging import getLevelName
from multiprocessing import Lock from multiprocessing import Lock
import attr import attr
@ -132,6 +133,24 @@ class ScalarEvent(MetricsEventAdapter):
**self._get_base_dict()) **self._get_base_dict())
class ConsoleEvent(MetricsEventAdapter):
""" Console log event adapter """
def __init__(self, message, level, worker, **kwargs):
self._value = str(message)
self._level = getLevelName(level) if isinstance(level, int) else str(level)
self._worker = worker
super(ConsoleEvent, self).__init__(metric=None, variant=None, iter=0, **kwargs)
def get_api_event(self):
return events.TaskLogEvent(
task=self._task,
timestamp=self._timestamp,
level=self._level,
worker=self._worker,
msg=self._value)
class VectorEvent(MetricsEventAdapter): class VectorEvent(MetricsEventAdapter):
""" Vector event adapter """ """ Vector event adapter """

View File

@ -1,8 +1,9 @@
import json import json
import logging
import math import math
try: try:
from collections.abc import Iterable from collections.abc import Iterable # noqa
except ImportError: except ImportError:
from collections import Iterable from collections import Iterable
@ -17,7 +18,8 @@ from ...utilities.plotly_reporter import create_2d_histogram_plot, create_value_
create_2d_scatter_series, create_3d_scatter_series, create_line_plot, plotly_scatter3d_layout_dict, \ create_2d_scatter_series, create_3d_scatter_series, create_line_plot, plotly_scatter3d_layout_dict, \
create_image_plot, create_plotly_table create_image_plot, create_plotly_table
from ...utilities.py3_interop import AbstractContextManager from ...utilities.py3_interop import AbstractContextManager
from .events import ScalarEvent, VectorEvent, ImageEvent, PlotEvent, ImageEventNoUpload, UploadEvent, MediaEvent from .events import ScalarEvent, VectorEvent, ImageEvent, PlotEvent, ImageEventNoUpload, \
UploadEvent, MediaEvent, ConsoleEvent
from ...config import config from ...config import config
@ -143,7 +145,7 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
:param value: Reported value :param value: Reported value
:type value: float :type value: float
:param iter: Iteration number :param iter: Iteration number
:type value: int :type iter: int
""" """
ev = ScalarEvent(metric=self._normalize_name(title), variant=self._normalize_name(series), value=value, ev = ScalarEvent(metric=self._normalize_name(title), variant=self._normalize_name(series), value=value,
iter=iter) iter=iter)
@ -157,9 +159,9 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
:param series: Series (AKA variant) :param series: Series (AKA variant)
:type series: str :type series: str
:param values: Reported values :param values: Reported values
:type value: [float] :type values: [float]
:param iter: Iteration number :param iter: Iteration number
:type value: int :type iter: int
""" """
if not isinstance(values, Iterable): if not isinstance(values, Iterable):
raise ValueError('values: expected an iterable') raise ValueError('values: expected an iterable')
@ -190,7 +192,7 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
:type plot: str or dict :type plot: str or dict
:param iter: Iteration number :param iter: Iteration number
:param round_digits: number of digits after the dot to leave :param round_digits: number of digits after the dot to leave
:type value: int :type round_digits: int
""" """
def floatstr(o): def floatstr(o):
if o != o: if o != o:
@ -251,7 +253,7 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
for presentation of processing. Currently only http(s), file and s3 schemes are supported. for presentation of processing. Currently only http(s), file and s3 schemes are supported.
:type src: str :type src: str
:param iter: Iteration number :param iter: Iteration number
:type value: int :type iter: int
""" """
ev = ImageEventNoUpload(metric=self._normalize_name(title), variant=self._normalize_name(series), iter=iter, ev = ImageEventNoUpload(metric=self._normalize_name(title), variant=self._normalize_name(series), iter=iter,
src=src) src=src)
@ -268,7 +270,7 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
for presentation of processing. Currently only http(s), file and s3 schemes are supported. for presentation of processing. Currently only http(s), file and s3 schemes are supported.
:type src: str :type src: str
:param iter: Iteration number :param iter: Iteration number
:type value: int :type iter: int
""" """
ev = ImageEventNoUpload(metric=self._normalize_name(title), variant=self._normalize_name(series), iter=iter, ev = ImageEventNoUpload(metric=self._normalize_name(title), variant=self._normalize_name(series), iter=iter,
src=src) src=src)
@ -290,6 +292,7 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
:type path: str :type path: str
:param image: Image data. Required unless filename is provided. :param image: Image data. Required unless filename is provided.
:type image: A PIL.Image.Image object or a 3D numpy.ndarray object :type image: A PIL.Image.Image object or a 3D numpy.ndarray object
:param upload_uri: Destination URL
:param max_image_history: maximum number of image to store per metric/variant combination :param max_image_history: maximum number of image to store per metric/variant combination
use negative value for unlimited. default is set in global configuration (default=5) use negative value for unlimited. default is set in global configuration (default=5)
:param delete_after_upload: if True, one the file was uploaded the local copy will be deleted :param delete_after_upload: if True, one the file was uploaded the local copy will be deleted
@ -321,6 +324,7 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
:param path: A path to an image file. Required unless matrix is provided. :param path: A path to an image file. Required unless matrix is provided.
:type path: str :type path: str
:param stream: File/String stream :param stream: File/String stream
:param upload_uri: Destination URL
:param file_extension: file extension to use when stream is passed :param file_extension: file extension to use when stream is passed
:param max_history: maximum number of files to store per metric/variant combination :param max_history: maximum number of files to store per metric/variant combination
use negative value for unlimited. default is set in global configuration (default=5) use negative value for unlimited. default is set in global configuration (default=5)
@ -353,7 +357,7 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
A row for each dataset(bar in a bar group). A column for each bucket. A row for each dataset(bar in a bar group). A column for each bucket.
:type histogram: numpy array :type histogram: numpy array
:param iter: Iteration number :param iter: Iteration number
:type value: int :type iter: int
:param labels: The labels for each bar group. :param labels: The labels for each bar group.
:type labels: list of strings. :type labels: list of strings.
:param xlabels: The labels of the x axis. :param xlabels: The labels of the x axis.
@ -466,7 +470,7 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
:param series: Series (AKA variant) :param series: Series (AKA variant)
:type series: str :type series: str
:param data: A scattered data: pairs of x,y as rows in a numpy array :param data: A scattered data: pairs of x,y as rows in a numpy array
:type scatter: ndarray :type data: ndarray
:param iter: Iteration number :param iter: Iteration number
:type iter: int :type iter: int
:param mode: (type str) 'lines'/'markers'/'lines+markers' :param mode: (type str) 'lines'/'markers'/'lines+markers'
@ -520,16 +524,17 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
:param xtitle: optional x-axis title :param xtitle: optional x-axis title
:param ytitle: optional y-axis title :param ytitle: optional y-axis title
:param ztitle: optional z-axis title :param ztitle: optional z-axis title
:param fill: optional
:param comment: comment underneath the title :param comment: comment underneath the title
:param layout_config: optional dictionary for layout configuration, passed directly to plotly :param layout_config: optional dictionary for layout configuration, passed directly to plotly
:type layout_config: dict or None :type layout_config: dict or None
""" """
data_series = data if isinstance(data, list) else [data] data_series = data if isinstance(data, list) else [data]
def get_labels(i): def get_labels(a_i):
if labels and isinstance(labels, list): if labels and isinstance(labels, list):
try: try:
item = labels[i] item = labels[a_i]
except IndexError: except IndexError:
item = labels[-1] item = labels[-1]
if isinstance(item, list): if isinstance(item, list):
@ -666,11 +671,11 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
:param series: Series (AKA variant) :param series: Series (AKA variant)
:type series: str :type series: str
:param iter: Iteration number :param iter: Iteration number
:type value: int :type iter: int
:param path: A path to an image file. Required unless matrix is provided. :param path: A path to an image file. Required unless matrix is provided.
:type path: str :type path: str
:param matrix: A 3D numpy.ndarray object containing image data (RGB). Required unless filename is provided. :param matrix: A 3D numpy.ndarray object containing image data (RGB). Required unless filename is provided.
:type matrix: str :type matrix: np.ndarray
:param upload_uri: upload image destination (str) :param upload_uri: upload image destination (str)
:type upload_uri: str :type upload_uri: str
:param max_image_history: maximum number of image to store per metric/variant combination :param max_image_history: maximum number of image to store per metric/variant combination
@ -686,8 +691,8 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
file_history_size=max_image_history) file_history_size=max_image_history)
if matrix is not None: if matrix is not None:
width = matrix.shape[1] width = matrix.shape[1] # noqa
height = matrix.shape[0] height = matrix.shape[0] # noqa
else: else:
# noinspection PyBroadException # noinspection PyBroadException
try: try:
@ -722,6 +727,21 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
iter=iter, iter=iter,
) )
def report_console(self, message, level=logging.INFO):
"""
Report a console log message
:param message: message text
:type message: str
:param level: log level (int or string)
:type level: int or str
"""
ev = ConsoleEvent(
message=message,
level=level,
worker=self.session.worker,
)
self._report(ev)
@classmethod @classmethod
def _normalize_name(cls, name): def _normalize_name(cls, name):
return name return name
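The new report_console path is what Logger._console falls back to when running remotely (see the Logger._console changes below). A rough sketch of the call, assuming an already-initialized task and using the internal _reporter attribute:

import logging
from trains import Task

task = Task.init(project_name='examples', task_name='console event demo')  # placeholder names
# internal API: send a console line to the backend as a log event
task._reporter.report_console(message='hello from the reporter', level=logging.INFO)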

View File

@ -172,57 +172,9 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
if running_remotely() or DevWorker.report_stdout: if running_remotely() or DevWorker.report_stdout:
log_to_backend = False log_to_backend = False
self._log_to_backend = log_to_backend self._log_to_backend = log_to_backend
self._setup_log(default_log_to_backend=log_to_backend)
self._artifacts_manager = Artifacts(self) self._artifacts_manager = Artifacts(self)
self._hyper_params_manager = HyperParams(self) self._hyper_params_manager = HyperParams(self)
def _setup_log(self, default_log_to_backend=None, replace_existing=False):
"""
Setup logging facilities for this task.
:param default_log_to_backend: Should this task log to the backend. If not specified, value for this option
will be obtained from the environment, with this value acting as a default in case configuration for this is
missing.
If the value for this option is false, we won't touch the current logger configuration regarding TaskHandler(s)
:param replace_existing: If True and another task is already logging to the backend, replace the handler with
a handler for this task.
"""
# Make sure urllib is never in debug/info,
disable_urllib3_info = config.get('log.disable_urllib3_info', True)
if disable_urllib3_info and logging.getLogger('urllib3').isEnabledFor(logging.INFO):
logging.getLogger('urllib3').setLevel(logging.WARNING)
log_to_backend = get_log_to_backend(default=default_log_to_backend) or self._log_to_backend
if not log_to_backend:
return
# Handle the root logger and our own logger. We use set() to make sure we create no duplicates
# in case these are the same logger...
loggers = {logging.getLogger(), LoggerRoot.get_base_logger()}
# Find all TaskHandler handlers for these loggers
handlers = {logger: h for logger in loggers for h in logger.handlers if isinstance(h, TaskHandler)}
if handlers and not replace_existing:
# Handlers exist and we shouldn't replace them
return
# Remove all handlers, we'll add new ones
for logger, handler in handlers.items():
logger.removeHandler(handler)
# Create a handler that will be used in all loggers. Since our handler is a buffering handler, using more
# than one instance to report to the same task will result in out-of-order log reports (grouped by whichever
# handler instance handled them)
backend_handler = TaskHandler(task=self)
# Add backend handler to both loggers:
# 1. to root logger root logger
# 2. to our own logger as well, since our logger is not propagated to the root logger
# (if we propagate our logger will be caught be the root handlers as well, and
# we do not want that)
for logger in loggers:
logger.addHandler(backend_handler)
def _validate(self, check_output_dest_credentials=True): def _validate(self, check_output_dest_credentials=True):
raise_errors = self._raise_on_validation_errors raise_errors = self._raise_on_validation_errors
output_dest = self.get_output_destination(raise_on_error=False, log_on_error=False) output_dest = self.get_output_destination(raise_on_error=False, log_on_error=False)

View File

@ -176,6 +176,10 @@
# Log all stdout & stderr # Log all stdout & stderr
log_stdout: true log_stdout: true
# Carriage return (\r) support. If zero (0), every \r is treated as \n and flushed to the backend
# Otherwise, buffered carriage-return (\r) updates are flushed every X seconds (default: 10)
console_lf_flush_period: 10
# compatibility feature, report memory usage for the entire machine # compatibility feature, report memory usage for the entire machine
# default (false), report only on the running process and its sub-processes # default (false), report only on the running process and its sub-processes
report_global_mem_used: false report_global_mem_used: false

View File

@ -5,13 +5,15 @@ from typing import Any, Sequence, Union, List, Optional, Tuple, Dict, TYPE_CHECK
import numpy as np import numpy as np
import six import six
from PIL import Image
from pathlib2 import Path
from .debugging.log import LoggerRoot
try: try:
import pandas as pd import pandas as pd
except ImportError: except ImportError:
pd = None pd = None
from PIL import Image
from pathlib2 import Path
from .backend_interface.logger import StdStreamPatch, LogFlusher from .backend_interface.logger import StdStreamPatch, LogFlusher
from .backend_interface.task import Task as _Task from .backend_interface.task import Task as _Task
@ -19,7 +21,6 @@ from .backend_interface.task.development.worker import DevWorker
from .backend_interface.task.log import TaskHandler from .backend_interface.task.log import TaskHandler
from .backend_interface.util import mutually_exclusive from .backend_interface.util import mutually_exclusive
from .config import running_remotely, get_cache_dir, config from .config import running_remotely, get_cache_dir, config
from .debugging.log import LoggerRoot
from .errors import UsageError from .errors import UsageError
from .storage.helper import StorageHelper from .storage.helper import StorageHelper
from .utilities.plotly_reporter import SeriesInfo from .utilities.plotly_reporter import SeriesInfo
@ -29,9 +30,8 @@ warnings.filterwarnings('always', category=DeprecationWarning, module=__name__)
if TYPE_CHECKING: if TYPE_CHECKING:
from matplotlib.figure import Figure as MatplotlibFigure from matplotlib.figure import Figure as MatplotlibFigure # noqa
from matplotlib import pyplot from matplotlib import pyplot # noqa
from plotly.graph_objects import Figure
class Logger(object): class Logger(object):
@ -60,7 +60,7 @@ class Logger(object):
_tensorboard_logging_auto_group_scalars = False _tensorboard_logging_auto_group_scalars = False
_tensorboard_single_series_per_graph = config.get('metrics.tensorboard_single_series_per_graph', False) _tensorboard_single_series_per_graph = config.get('metrics.tensorboard_single_series_per_graph', False)
def __init__(self, private_task): def __init__(self, private_task, connect_stdout=True, connect_stderr=True, connect_logging=False):
""" """
.. warning:: .. warning::
**Do not construct Logger manually!** **Do not construct Logger manually!**
@ -73,12 +73,26 @@ class Logger(object):
self._default_upload_destination = None self._default_upload_destination = None
self._flusher = None self._flusher = None
self._report_worker = None self._report_worker = None
self._task_handler = None
self._graph_titles = {} self._graph_titles = {}
self._tensorboard_series_force_prefix = None self._tensorboard_series_force_prefix = None
self._task_handler = TaskHandler(task=self._task, capacity=100)
self._connect_std_streams = connect_stdout or connect_stderr
self._connect_logging = connect_logging
if self._task.is_main_task(): # Make sure urllib is never in debug/info,
StdStreamPatch.patch_std_streams(self) disable_urllib3_info = config.get('log.disable_urllib3_info', True)
if disable_urllib3_info and logging.getLogger('urllib3').isEnabledFor(logging.INFO):
logging.getLogger('urllib3').setLevel(logging.WARNING)
StdStreamPatch.patch_std_streams(self, connect_stdout=connect_stdout, connect_stderr=connect_stderr)
if self._connect_logging:
StdStreamPatch.patch_logging_formatter(self)
elif not self._connect_std_streams:
# make sure that at least the main trains logger is connected
base_logger = LoggerRoot.get_base_logger()
if base_logger and base_logger.handlers:
StdStreamPatch.patch_logging_formatter(self, base_logger.handlers[0])
@classmethod @classmethod
def current_logger(cls): def current_logger(cls):
@ -381,6 +395,7 @@ class Logger(object):
example: extra_layout={'xaxis': {'type': 'date', 'range': ['2020-01-01', '2020-01-31']}} example: extra_layout={'xaxis': {'type': 'date', 'range': ['2020-01-01', '2020-01-31']}}
""" """
# noinspection PyArgumentList
series = [self.SeriesInfo(**s) if isinstance(s, dict) else s for s in series] series = [self.SeriesInfo(**s) if isinstance(s, dict) else s for s in series]
# if task was not started, we have to start it # if task was not started, we have to start it
@ -832,7 +847,7 @@ class Logger(object):
upload_uri = storage.verify_upload(folder_uri=upload_uri) upload_uri = storage.verify_upload(folder_uri=upload_uri)
if isinstance(image, Image.Image): if isinstance(image, Image.Image):
image = np.array(image) image = np.array(image) # noqa
# noinspection PyProtectedMember # noinspection PyProtectedMember
self._task._reporter.report_image_and_upload( self._task._reporter.report_image_and_upload(
title=title, title=title,
@ -1068,7 +1083,7 @@ class Logger(object):
:param float period: The period to flush the logger in seconds. To set no periodic flush, :param float period: The period to flush the logger in seconds. To set no periodic flush,
specify ``None`` or ``0``. specify ``None`` or ``0``.
""" """
if self._task.is_main_task() and DevWorker.report_stdout and DevWorker.report_period and \ if self._task.is_main_task() and self._task_handler and DevWorker.report_period and \
not running_remotely() and period is not None: not running_remotely() and period is not None:
period = min(period or DevWorker.report_period, DevWorker.report_period) period = min(period or DevWorker.report_period, DevWorker.report_period)
@ -1099,6 +1114,30 @@ class Logger(object):
self.report_image(title=title, series=series, iteration=iteration, local_path=path, image=matrix, self.report_image(title=title, series=series, iteration=iteration, local_path=path, image=matrix,
max_image_history=max_image_history, delete_after_upload=delete_after_upload) max_image_history=max_image_history, delete_after_upload=delete_after_upload)
def capture_logging(self):
# type: () -> "_LoggingContext"
"""
Return a context manager that captures all log records reported via the ``logging`` module while the context is active
:return: a ContextManager
"""
class _LoggingContext(object):
def __init__(self, a_logger):
self.logger = a_logger
def __enter__(self, *_, **__):
if not self.logger:
return
StdStreamPatch.patch_logging_formatter(self.logger)
def __exit__(self, *_, **__):
if not self.logger:
return
StdStreamPatch.remove_patch_logging_formatter()
# Do nothing if we already have full logging support
return _LoggingContext(None if self._connect_logging else self)
@classmethod @classmethod
def tensorboard_auto_group_scalars(cls, group_scalars=False): def tensorboard_auto_group_scalars(cls, group_scalars=False):
# type: (bool) -> None # type: (bool) -> None
@ -1137,7 +1176,7 @@ class Logger(object):
def _remove_std_logger(cls): def _remove_std_logger(cls):
StdStreamPatch.remove_std_logger() StdStreamPatch.remove_std_logger()
def _console(self, msg, level=logging.INFO, omit_console=False, *args, **kwargs): def _console(self, msg, level=logging.INFO, omit_console=False, *args, **_):
# type: (str, int, bool, Any, Any) -> None # type: (str, int, bool, Any, Any) -> None
""" """
print text to log (same as print to console, and also prints to console) print text to log (same as print to console, and also prints to console)
@ -1158,29 +1197,32 @@ class Logger(object):
msg='Logger failed casting log level "%s" to integer' % str(level)) msg='Logger failed casting log level "%s" to integer' % str(level))
level = logging.INFO level = logging.INFO
if not running_remotely(): # noinspection PyProtectedMember
if not running_remotely() or not self._task._is_remote_main_task():
if self._task_handler:
# noinspection PyBroadException # noinspection PyBroadException
try: try:
record = self._task.log.makeRecord( record = self._task.log.makeRecord(
"console", level=level, fn='', lno=0, func='', msg=msg, args=args, exc_info=None "console", level=level, fn='', lno=0, func='', msg=msg, args=args, exc_info=None
) )
# find the task handler that matches our task # find the task handler that matches our task
if not self._task_handler:
self._task_handler = [h for h in LoggerRoot.get_base_logger().handlers
if isinstance(h, TaskHandler) and h.task_id == self._task.id][0]
self._task_handler.emit(record) self._task_handler.emit(record)
except Exception: except Exception:
# avoid infinite loop, output directly to stderr # avoid infinite loop, output directly to stderr
# noinspection PyBroadException
try: try:
# make sure we are writing to the original stdout # make sure we are writing to the original stdout
StdStreamPatch.stderr_original_write( StdStreamPatch.stderr_original_write(
'trains.Logger failed sending log [level {}]: "{}"\n'.format(level, msg)) 'trains.Logger failed sending log [level {}]: "{}"\n'.format(level, msg))
except Exception: except Exception:
pass pass
else:
# noinspection PyProtectedMember
self._task._reporter.report_console(message=msg, level=level)
if not omit_console: if not omit_console:
# if we are here and we grabbed the stdout, we need to print the real thing # if we are here and we grabbed the stdout, we need to print the real thing
if DevWorker.report_stdout and not running_remotely(): if self._connect_std_streams and not running_remotely():
# noinspection PyBroadException # noinspection PyBroadException
try: try:
# make sure we are writing to the original stdout # make sure we are writing to the original stdout
@ -1218,7 +1260,7 @@ class Logger(object):
:param path: A path to an image file. Required unless matrix is provided. :param path: A path to an image file. Required unless matrix is provided.
:type path: str :type path: str
:param matrix: A 3D numpy.ndarray object containing image data (RGB). Required unless filename is provided. :param matrix: A 3D numpy.ndarray object containing image data (RGB). Required unless filename is provided.
:type matrix: str :type matrix: np.array
:param max_image_history: maximum number of image to store per metric/variant combination \ :param max_image_history: maximum number of image to store per metric/variant combination \
use negative value for unlimited. default is set in global configuration (default=5) use negative value for unlimited. default is set in global configuration (default=5)
:type max_image_history: int :type max_image_history: int
@ -1306,14 +1348,15 @@ class Logger(object):
pass pass
def _flush_stdout_handler(self): def _flush_stdout_handler(self):
if self._task_handler and DevWorker.report_stdout: if self._task_handler:
self._task_handler.flush() self._task_handler.flush()
def _close_stdout_handler(self, wait=True): def _close_stdout_handler(self, wait=True):
# detach the sys stdout/stderr # detach the sys stdout/stderr
if self._connect_std_streams:
StdStreamPatch.remove_std_logger(self) StdStreamPatch.remove_std_logger(self)
if self._task_handler and DevWorker.report_stdout: if self._task_handler:
t = self._task_handler t = self._task_handler
self._task_handler = None self._task_handler = None
t.close(wait) t.close(wait)
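The new Logger.capture_logging() context manager temporarily hooks the logging module even when 'logging' was not enabled in auto_connect_streams. A hedged usage sketch (project/task names are placeholders):

import logging
from trains import Task

task = Task.init(project_name='examples', task_name='capture_logging demo')

with task.get_logger().capture_logging():
    # records emitted here are forwarded to the task's console log
    logging.getLogger('my_app').warning('captured warning')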

View File

@ -190,6 +190,7 @@ class Task(_Task):
auto_connect_arg_parser=True, # type: Union[bool, Mapping[str, bool]] auto_connect_arg_parser=True, # type: Union[bool, Mapping[str, bool]]
auto_connect_frameworks=True, # type: Union[bool, Mapping[str, bool]] auto_connect_frameworks=True, # type: Union[bool, Mapping[str, bool]]
auto_resource_monitoring=True, # type: bool auto_resource_monitoring=True, # type: bool
auto_connect_streams=True, # type: Union[bool, Mapping[str, bool]]
): ):
# type: (...) -> Task # type: (...) -> Task
""" """
@ -350,6 +351,24 @@ class Task(_Task):
- ``False`` - Do not automatically create. - ``False`` - Do not automatically create.
- Class Type - Create ResourceMonitor object of the specified class type. - Class Type - Create ResourceMonitor object of the specified class type.
:param auto_connect_streams: Control the automatic logging of stdout and stderr
The values are:
- ``True`` - Automatically connect (default)
- ``False`` - Do not automatically connect
- A dictionary - In addition to a boolean, you can use a dictionary for fine-grained control of stdout,
stderr and logging capture. The dictionary keys are 'stdout', 'stderr' and 'logging'; the values are booleans.
Keys missing from the dictionary default to ``False``, and an empty dictionary defaults to ``False``.
Notice that the default behaviour captures stdout/stderr only; output from the
`logging` module is captured as a by-product of the stderr capture.
For example:
.. code-block:: py
auto_connect_streams={'stdout': True, 'stderr': True, 'logging': False}
:return: The main execution Task (Task context). :return: The main execution Task (Task context).
""" """
@ -390,7 +409,7 @@ class Task(_Task):
# create a new logger (to catch stdout/err) # create a new logger (to catch stdout/err)
cls.__main_task._logger = None cls.__main_task._logger = None
cls.__main_task.__reporter = None cls.__main_task.__reporter = None
cls.__main_task.get_logger() cls.__main_task._get_logger(auto_connect_streams=auto_connect_streams)
cls.__main_task._artifacts_manager = Artifacts(cls.__main_task) cls.__main_task._artifacts_manager = Artifacts(cls.__main_task)
# unregister signal hooks, they cause subprocess to hang # unregister signal hooks, they cause subprocess to hang
# noinspection PyProtectedMember # noinspection PyProtectedMember
@ -453,7 +472,8 @@ class Task(_Task):
continue_last_task=continue_last_task, continue_last_task=continue_last_task,
detect_repo=False if ( detect_repo=False if (
isinstance(auto_connect_frameworks, dict) and isinstance(auto_connect_frameworks, dict) and
not auto_connect_frameworks.get('detect_repository', True)) else True not auto_connect_frameworks.get('detect_repository', True)) else True,
auto_connect_streams=auto_connect_streams,
) )
# set defaults # set defaults
if cls._offline_mode: if cls._offline_mode:
@ -545,7 +565,7 @@ class Task(_Task):
# Make sure we start the logger, it will patch the main logging object and pipe all output # Make sure we start the logger, it will patch the main logging object and pipe all output
# if we are running locally and using development mode worker, we will pipe all stdout to logger. # if we are running locally and using development mode worker, we will pipe all stdout to logger.
# The logger will automatically take care of all patching (we just need to make sure to initialize it) # The logger will automatically take care of all patching (we just need to make sure to initialize it)
logger = task.get_logger() logger = task._get_logger(auto_connect_streams=auto_connect_streams)
# show the debug metrics page in the log, it is very convenient # show the debug metrics page in the log, it is very convenient
if not is_sub_process_task_id: if not is_sub_process_task_id:
if cls._offline_mode: if cls._offline_mode:
@ -2110,7 +2130,7 @@ class Task(_Task):
@classmethod @classmethod
def _create_dev_task( def _create_dev_task(
cls, default_project_name, default_task_name, default_task_type, cls, default_project_name, default_task_name, default_task_type,
reuse_last_task_id, continue_last_task=False, detect_repo=True, reuse_last_task_id, continue_last_task=False, detect_repo=True, auto_connect_streams=True
): ):
if not default_project_name or not default_task_name: if not default_project_name or not default_task_name:
# get project name and task name from repository name and entry_point # get project name and task name from repository name and entry_point
@ -2230,8 +2250,7 @@ class Task(_Task):
task.reload() task.reload()
# force update of base logger to this current task (this is the main logger task) # force update of base logger to this current task (this is the main logger task)
task._setup_log(replace_existing=True) logger = task._get_logger(auto_connect_streams=auto_connect_streams)
logger = task.get_logger()
if closed_old_task: if closed_old_task:
logger.report_text('TRAINS Task: Closing old development task id={}'.format(default_task.get('id'))) logger.report_text('TRAINS Task: Closing old development task id={}'.format(default_task.get('id')))
# print warning, reusing/creating a task # print warning, reusing/creating a task
@ -2271,8 +2290,8 @@ class Task(_Task):
return task return task
def _get_logger(self, flush_period=NotSet): def _get_logger(self, flush_period=NotSet, auto_connect_streams=False):
# type: (Optional[float]) -> Logger # type: (Optional[float], Union[bool, dict]) -> Logger
""" """
get a logger object for reporting based on the task get a logger object for reporting based on the task
@ -2288,10 +2307,15 @@ class Task(_Task):
# do not recreate logger after task was closed/quit # do not recreate logger after task was closed/quit
if self._at_exit_called: if self._at_exit_called:
raise ValueError("Cannot use Task Logger after task was closed") raise ValueError("Cannot use Task Logger after task was closed")
# force update of base logger to this current task (this is the main logger task)
self._setup_log(replace_existing=self.is_main_task())
# Get a logger object # Get a logger object
self._logger = Logger(private_task=self) self._logger = Logger(
private_task=self,
connect_stdout=(auto_connect_streams is True) or
(isinstance(auto_connect_streams, dict) and auto_connect_streams.get('stdout', False)),
connect_stderr=(auto_connect_streams is True) or
(isinstance(auto_connect_streams, dict) and auto_connect_streams.get('stderr', False)),
connect_logging=isinstance(auto_connect_streams, dict) and auto_connect_streams.get('logging', False),
)
# make sure we set our reported to async mode # make sure we set our reported to async mode
# we make sure we flush it in self._at_exit # we make sure we flush it in self._at_exit
self._reporter.async_enable = True self._reporter.async_enable = True
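For clarity, the bool-or-dict resolution performed in _get_logger above amounts to the following (a standalone sketch, not part of the commit):

def resolve_stream_flags(auto_connect_streams):
    # True  -> capture stdout and stderr; `logging` rides along via the stderr capture
    # dict  -> each key ('stdout', 'stderr', 'logging') defaults to False when missing
    # False -> capture nothing
    as_dict = auto_connect_streams if isinstance(auto_connect_streams, dict) else {}
    connect_stdout = auto_connect_streams is True or as_dict.get('stdout', False)
    connect_stderr = auto_connect_streams is True or as_dict.get('stderr', False)
    connect_logging = as_dict.get('logging', False)
    return connect_stdout, connect_stderr, connect_logging

assert resolve_stream_flags(True) == (True, True, False)
assert resolve_stream_flags({'logging': True}) == (False, False, True)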