Add Task.init auto_connect_streams controlling stdout/stderr/logging capture. Issue #181

allegroai 2020-11-20 15:50:33 +02:00
parent e9920e27ed
commit 6e012cb205
7 changed files with 305 additions and 164 deletions

View File

@ -1,10 +1,10 @@
import logging
import sys
import threading
from time import time
from ..backend_interface.task.development.worker import DevWorker
from ..backend_interface.task.log import TaskHandler
from ..config import running_remotely
from ..binding.frameworks import _patched_call # noqa
from ..config import running_remotely, config
class StdStreamPatch(object):
@ -14,51 +14,62 @@ class StdStreamPatch(object):
_stderr_original_write = None
@staticmethod
def patch_std_streams(logger):
if DevWorker.report_stdout and not PrintPatchLogger.patched and not running_remotely():
StdStreamPatch._stdout_proxy = PrintPatchLogger(sys.stdout, logger, level=logging.INFO)
StdStreamPatch._stderr_proxy = PrintPatchLogger(sys.stderr, logger, level=logging.ERROR)
logger._task_handler = TaskHandler(task=logger._task, capacity=100)
def patch_std_streams(a_logger, connect_stdout=True, connect_stderr=True):
if (connect_stdout or connect_stderr) and not PrintPatchLogger.patched and not running_remotely():
StdStreamPatch._stdout_proxy = PrintPatchLogger(sys.stdout, a_logger, level=logging.INFO) \
if connect_stdout else None
StdStreamPatch._stderr_proxy = PrintPatchLogger(sys.stderr, a_logger, level=logging.ERROR) \
if connect_stderr else None
if StdStreamPatch._stdout_proxy:
# noinspection PyBroadException
try:
if StdStreamPatch._stdout_original_write is None:
StdStreamPatch._stdout_original_write = sys.stdout.write
if StdStreamPatch._stderr_original_write is None:
StdStreamPatch._stderr_original_write = sys.stderr.write
# this will only work in Python 3, guard it with try/except
if not hasattr(sys.stdout, '_original_write'):
sys.stdout._original_write = sys.stdout.write
sys.stdout.write = StdStreamPatch._stdout__patched__write__
except Exception:
pass
sys.stdout = StdStreamPatch._stdout_proxy
# noinspection PyBroadException
try:
sys.__stdout__ = sys.stdout
except Exception:
pass
if StdStreamPatch._stderr_proxy:
# noinspection PyBroadException
try:
if StdStreamPatch._stderr_original_write is None:
StdStreamPatch._stderr_original_write = sys.stderr.write
if not hasattr(sys.stderr, '_original_write'):
sys.stderr._original_write = sys.stderr.write
sys.stderr.write = StdStreamPatch._stderr__patched__write__
except Exception:
pass
sys.stdout = StdStreamPatch._stdout_proxy
sys.stderr = StdStreamPatch._stderr_proxy
# patch the base streams of sys (this way colorama will keep its ANSI colors)
# noinspection PyBroadException
try:
sys.__stderr__ = sys.stderr
except Exception:
pass
# noinspection PyBroadException
try:
sys.__stdout__ = sys.stdout
except Exception:
pass
# now check if we have loguru and make it re-register the handlers
# because it internally stores the stream.write function, which we can't patch
# noinspection PyBroadException
try:
from loguru import logger
from loguru import logger # noqa
register_stderr = None
register_stdout = None
for k, v in logger._handlers.items():
if v._name == '<stderr>':
for k, v in logger._handlers.items(): # noqa
if connect_stderr and v._name == '<stderr>': # noqa
register_stderr = k
elif v._name == '<stdout>':
elif connect_stdout and v._name == '<stdout>': # noqa
register_stdout = k
if register_stderr is not None:
logger.remove(register_stderr)
@ -69,12 +80,30 @@ class StdStreamPatch(object):
except Exception:
pass
elif DevWorker.report_stdout and not running_remotely():
logger._task_handler = TaskHandler(task=logger._task, capacity=100)
if StdStreamPatch._stdout_proxy:
StdStreamPatch._stdout_proxy.connect(logger)
if StdStreamPatch._stderr_proxy:
StdStreamPatch._stderr_proxy.connect(logger)
elif (connect_stdout or connect_stderr) and not running_remotely():
if StdStreamPatch._stdout_proxy and connect_stdout:
StdStreamPatch._stdout_proxy.connect(a_logger)
if StdStreamPatch._stderr_proxy and connect_stderr:
StdStreamPatch._stderr_proxy.connect(a_logger)
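In effect, once the streams are patched, anything written to stdout/stderr still reaches the terminal and is also mirrored to the task logger. A conceptual sketch of the proxy idea, as an illustrative stand-in for PrintPatchLogger rather than the library code:

import sys

class _EchoProxy(object):
    # illustrative stand-in for PrintPatchLogger: forward writes to the
    # real stream and mirror the text to a reporting callback
    def __init__(self, stream, report):
        self._stream = stream
        self._report = report

    def write(self, text):
        self._stream.write(text)  # keep normal console behavior
        self._report(text)        # mirror to the task logger

    def __getattr__(self, attr):
        # delegate everything else (flush, isatty, ...) to the real stream
        return getattr(self._stream, attr)

sys.stdout = _EchoProxy(sys.stdout, report=lambda t: None)  # report() would send to the backend
print('hello')  # reaches both the terminal and the report callback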
@staticmethod
def patch_logging_formatter(a_logger, logging_handler=None):
if not logging_handler:
import logging
logging_handler = logging.Handler
logging_handler.format = _patched_call(logging_handler.format, HandlerFormat(a_logger))
@staticmethod
def remove_patch_logging_formatter(logging_handler=None):
if not logging_handler:
import logging
logging_handler = logging.Handler
# remove the patched function. Hack: calling the patched logging.Handler.format() with no arguments returns the original function
# noinspection PyBroadException
try:
logging_handler.format = logging_handler.format() # noqa
except Exception:
pass
@staticmethod
def remove_std_logger(logger=None):
@ -105,13 +134,33 @@ class StdStreamPatch(object):
def _stdout__patched__write__(*args, **kwargs):
if StdStreamPatch._stdout_proxy:
return StdStreamPatch._stdout_proxy.write(*args, **kwargs)
return sys.stdout._original_write(*args, **kwargs)
return sys.stdout._original_write(*args, **kwargs) # noqa
@staticmethod
def _stderr__patched__write__(*args, **kwargs):
if StdStreamPatch._stderr_proxy:
return StdStreamPatch._stderr_proxy.write(*args, **kwargs)
return sys.stderr._original_write(*args, **kwargs)
return sys.stderr._original_write(*args, **kwargs) # noqa
class HandlerFormat(object):
def __init__(self, logger):
self._logger = logger
def __call__(self, original_format_func, *args):
# hack: when called with all-None arguments, return the original function so the patch can be removed
if all(a is None for a in args):
return original_format_func
if len(args) == 1:
record = args[0]
msg = original_format_func(record)
else:
handler = args[0]
record = args[1]
msg = original_format_func(handler, record)
self._logger.report_text(msg=msg, level=record.levelno, print_console=False)
return msg
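_patched_call itself is imported from ..binding.frameworks and is not shown in this diff; a minimal sketch of the wrapper shape HandlerFormat relies on (an assumed form, not the actual implementation):

def _patched_call(original_func, patched_callback):
    # route every call through patched_callback, handing it the
    # original function as the first argument
    def _inner(*args, **kwargs):
        return patched_callback(original_func, *args, **kwargs)
    return _inner

With this shape, calling the patched logging.Handler.format() with no arguments makes HandlerFormat.__call__ take its all-None branch and return original_format_func, which is exactly the unpatch hack remove_patch_logging_formatter uses.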
class PrintPatchLogger(object):
@ -122,6 +171,7 @@ class PrintPatchLogger(object):
patched = False
lock = threading.Lock()
recursion_protect_lock = threading.RLock()
lf_flush_period = config.get("development.worker.console_lf_flush_period", 0)
def __init__(self, stream, logger=None, level=logging.INFO):
PrintPatchLogger.patched = True
@ -129,23 +179,35 @@ class PrintPatchLogger(object):
self._log = logger
self._log_level = level
self._cur_line = ''
self._force_lf_flush = False
def write(self, message):
# make sure that we do not end up in an infinite loop (i.e. log.console ends up calling us)
if self._log and not PrintPatchLogger.recursion_protect_lock._is_owned():
if self._log and not PrintPatchLogger.recursion_protect_lock._is_owned(): # noqa
try:
self.lock.acquire()
with PrintPatchLogger.recursion_protect_lock:
if hasattr(self._terminal, '_original_write'):
self._terminal._original_write(message)
self._terminal._original_write(message) # noqa
else:
self._terminal.write(message)
do_flush = '\n' in message
do_cr = '\r' in message
self._cur_line += message
if (not do_flush and not do_cr) or not message:
if not do_flush and do_cr and PrintPatchLogger.lf_flush_period and self._force_lf_flush:
self._cur_line += '\n'
do_flush = True
self._force_lf_flush = False
if (not do_flush and (PrintPatchLogger.lf_flush_period or not do_cr)) or not message:
return
if PrintPatchLogger.lf_flush_period and self._cur_line:
self._cur_line = '\n'.join(line.split('\r')[-1] for line in self._cur_line.split('\n'))
last_lf = self._cur_line.rindex('\n' if do_flush else '\r')
next_line = self._cur_line[last_lf + 1:]
cur_line = self._cur_line[:last_lf + 1].rstrip()
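To see what the collapse above does: within each buffered line, everything before the last carriage return is dropped, so only the final state of a \r-driven progress line is reported. A standalone illustration:

buf = 'epoch 1: 10%\repoch 1: 50%\repoch 1: 100%\ndone\n'
collapsed = '\n'.join(line.split('\r')[-1] for line in buf.split('\n'))
assert collapsed == 'epoch 1: 100%\ndone\n'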
@ -158,13 +220,14 @@ class PrintPatchLogger(object):
# noinspection PyBroadException
try:
if self._log:
# noinspection PyProtectedMember
self._log._console(cur_line, level=self._log_level, omit_console=True)
except Exception:
# what can we do, nothing
pass
else:
if hasattr(self._terminal, '_original_write'):
self._terminal._original_write(message)
self._terminal._original_write(message) # noqa
else:
self._terminal.write(message)
@ -177,13 +240,16 @@ class PrintPatchLogger(object):
if not logger or self._log == logger:
self.connect(None)
def force_lf_flush(self):
self._force_lf_flush = True
def __getattr__(self, attr):
if attr in ['_log', '_terminal', '_log_level', '_cur_line']:
if attr in ['_log', '_terminal', '_log_level', '_cur_line', '_cr_overwrite', '_force_lf_flush']:
return self.__dict__.get(attr)
return getattr(self._terminal, attr)
def __setattr__(self, key, value):
if key in ['_log', '_terminal', '_log_level', '_cur_line']:
if key in ['_log', '_terminal', '_log_level', '_cur_line', '_cr_overwrite', '_force_lf_flush']:
self.__dict__[key] = value
else:
return setattr(self._terminal, key, value)
@ -197,6 +263,11 @@ class LogFlusher(threading.Thread):
self._period = period
self._logger = logger
self._exit_event = threading.Event()
self._lf_last_flush = 0
try:
self._lf_flush_period = float(PrintPatchLogger.lf_flush_period)
except (ValueError, TypeError):
self._lf_flush_period = 0
@property
def period(self):
@ -208,7 +279,15 @@ class LogFlusher(threading.Thread):
while True:
period = self._period
while not self._exit_event.wait(period or 1.0):
if self._lf_flush_period and time() - self._lf_last_flush > self._lf_flush_period:
if isinstance(sys.stdout, PrintPatchLogger):
sys.stdout.force_lf_flush()
if isinstance(sys.stderr, PrintPatchLogger):
sys.stderr.force_lf_flush()
self._lf_last_flush = time()
# now signal the real flush
self._logger.flush()
# if period is negative or None, we should exit
if self._period is None or self._period < 0:
break
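A minimal standalone sketch of the periodic force-flush loop above; the names are illustrative, and the real thread also handles period changes and exit signaling:

import sys
from time import time

def flush_loop(exit_event, flush, lf_flush_period=10.0, report_period=5.0):
    # exit_event: a threading.Event signaling shutdown; flush: the real flush callback
    lf_last_flush = 0.0
    while not exit_event.wait(report_period):
        if lf_flush_period and time() - lf_last_flush > lf_flush_period:
            for stream in (sys.stdout, sys.stderr):
                if hasattr(stream, 'force_lf_flush'):
                    stream.force_lf_flush()  # ask the proxies to emit buffered \r lines
            lf_last_flush = time()
        flush()  # now signal the real flush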

View File

@ -1,6 +1,7 @@
import abc
import hashlib
import time
from logging import getLevelName
from multiprocessing import Lock
import attr
@ -132,6 +133,24 @@ class ScalarEvent(MetricsEventAdapter):
**self._get_base_dict())
class ConsoleEvent(MetricsEventAdapter):
""" Console log event adapter """
def __init__(self, message, level, worker, **kwargs):
self._value = str(message)
self._level = getLevelName(level) if isinstance(level, int) else str(level)
self._worker = worker
super(ConsoleEvent, self).__init__(metric=None, variant=None, iter=0, **kwargs)
def get_api_event(self):
return events.TaskLogEvent(
task=self._task,
timestamp=self._timestamp,
level=self._level,
worker=self._worker,
msg=self._value)
class VectorEvent(MetricsEventAdapter):
""" Vector event adapter """

View File

@ -1,8 +1,9 @@
import json
import logging
import math
try:
from collections.abc import Iterable
from collections.abc import Iterable # noqa
except ImportError:
from collections import Iterable
@ -17,7 +18,8 @@ from ...utilities.plotly_reporter import create_2d_histogram_plot, create_value_
create_2d_scatter_series, create_3d_scatter_series, create_line_plot, plotly_scatter3d_layout_dict, \
create_image_plot, create_plotly_table
from ...utilities.py3_interop import AbstractContextManager
from .events import ScalarEvent, VectorEvent, ImageEvent, PlotEvent, ImageEventNoUpload, UploadEvent, MediaEvent
from .events import ScalarEvent, VectorEvent, ImageEvent, PlotEvent, ImageEventNoUpload, \
UploadEvent, MediaEvent, ConsoleEvent
from ...config import config
@ -143,7 +145,7 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
:param value: Reported value
:type value: float
:param iter: Iteration number
:type value: int
:type iter: int
"""
ev = ScalarEvent(metric=self._normalize_name(title), variant=self._normalize_name(series), value=value,
iter=iter)
@ -157,9 +159,9 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
:param series: Series (AKA variant)
:type series: str
:param values: Reported values
:type value: [float]
:type values: [float]
:param iter: Iteration number
:type value: int
:type iter: int
"""
if not isinstance(values, Iterable):
raise ValueError('values: expected an iterable')
@ -190,7 +192,7 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
:type plot: str or dict
:param iter: Iteration number
:param round_digits: number of digits after the dot to leave
:type value: int
:type round_digits: int
"""
def floatstr(o):
if o != o:
@ -251,7 +253,7 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
for presentation of processing. Currently only http(s), file and s3 schemes are supported.
:type src: str
:param iter: Iteration number
:type value: int
:type iter: int
"""
ev = ImageEventNoUpload(metric=self._normalize_name(title), variant=self._normalize_name(series), iter=iter,
src=src)
@ -268,7 +270,7 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
for presentation of processing. Currently only http(s), file and s3 schemes are supported.
:type src: str
:param iter: Iteration number
:type value: int
:type iter: int
"""
ev = ImageEventNoUpload(metric=self._normalize_name(title), variant=self._normalize_name(series), iter=iter,
src=src)
@ -290,6 +292,7 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
:type path: str
:param image: Image data. Required unless filename is provided.
:type image: A PIL.Image.Image object or a 3D numpy.ndarray object
:param upload_uri: Destination URL
:param max_image_history: maximum number of image to store per metric/variant combination
use negative value for unlimited. default is set in global configuration (default=5)
:param delete_after_upload: if True, one the file was uploaded the local copy will be deleted
@ -321,6 +324,7 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
:param path: A path to an image file. Required unless matrix is provided.
:type path: str
:param stream: File/String stream
:param upload_uri: Destination URL
:param file_extension: file extension to use when stream is passed
:param max_history: maximum number of files to store per metric/variant combination
use negative value for unlimited. default is set in global configuration (default=5)
@ -353,7 +357,7 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
A row for each dataset(bar in a bar group). A column for each bucket.
:type histogram: numpy array
:param iter: Iteration number
:type value: int
:type iter: int
:param labels: The labels for each bar group.
:type labels: list of strings.
:param xlabels: The labels of the x axis.
@ -466,7 +470,7 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
:param series: Series (AKA variant)
:type series: str
:param data: A scattered data: pairs of x,y as rows in a numpy array
:type scatter: ndarray
:type data: ndarray
:param iter: Iteration number
:type iter: int
:param mode: (type str) 'lines'/'markers'/'lines+markers'
@ -520,16 +524,17 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
:param xtitle: optional x-axis title
:param ytitle: optional y-axis title
:param ztitle: optional z-axis title
:param fill: optional
:param comment: comment underneath the title
:param layout_config: optional dictionary for layout configuration, passed directly to plotly
:type layout_config: dict or None
"""
data_series = data if isinstance(data, list) else [data]
def get_labels(i):
def get_labels(a_i):
if labels and isinstance(labels, list):
try:
item = labels[i]
item = labels[a_i]
except IndexError:
item = labels[-1]
if isinstance(item, list):
@ -666,11 +671,11 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
:param series: Series (AKA variant)
:type series: str
:param iter: Iteration number
:type value: int
:type iter: int
:param path: A path to an image file. Required unless matrix is provided.
:type path: str
:param matrix: A 3D numpy.ndarray object containing image data (RGB). Required unless filename is provided.
:type matrix: str
:type matrix: np.ndarray
:param upload_uri: upload image destination (str)
:type upload_uri: str
:param max_image_history: maximum number of image to store per metric/variant combination
@ -686,8 +691,8 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
file_history_size=max_image_history)
if matrix is not None:
width = matrix.shape[1]
height = matrix.shape[0]
width = matrix.shape[1] # noqa
height = matrix.shape[0] # noqa
else:
# noinspection PyBroadException
try:
@ -722,6 +727,21 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
iter=iter,
)
def report_console(self, message, level=logging.INFO):
"""
Report a console log message
:param message: message to report
:type message: str
:param level: log level (int or string)
:type level: int or str
"""
ev = ConsoleEvent(
message=message,
level=level,
worker=self.session.worker,
)
self._report(ev)
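A minimal usage sketch, where reporter stands for an assumed Reporter instance (normally reached internally through the task); the level may be a numeric logging level or a level name:

import logging

# 'reporter' is an assumed Reporter instance
reporter.report_console(message='training started', level=logging.INFO)  # int level
reporter.report_console(message='low disk space', level='WARNING')       # string level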
@classmethod
def _normalize_name(cls, name):
return name

View File

@ -172,57 +172,9 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
if running_remotely() or DevWorker.report_stdout:
log_to_backend = False
self._log_to_backend = log_to_backend
self._setup_log(default_log_to_backend=log_to_backend)
self._artifacts_manager = Artifacts(self)
self._hyper_params_manager = HyperParams(self)
def _setup_log(self, default_log_to_backend=None, replace_existing=False):
"""
Setup logging facilities for this task.
:param default_log_to_backend: Should this task log to the backend. If not specified, value for this option
will be obtained from the environment, with this value acting as a default in case configuration for this is
missing.
If the value for this option is false, we won't touch the current logger configuration regarding TaskHandler(s)
:param replace_existing: If True and another task is already logging to the backend, replace the handler with
a handler for this task.
"""
# Make sure urllib is never in debug/info,
disable_urllib3_info = config.get('log.disable_urllib3_info', True)
if disable_urllib3_info and logging.getLogger('urllib3').isEnabledFor(logging.INFO):
logging.getLogger('urllib3').setLevel(logging.WARNING)
log_to_backend = get_log_to_backend(default=default_log_to_backend) or self._log_to_backend
if not log_to_backend:
return
# Handle the root logger and our own logger. We use set() to make sure we create no duplicates
# in case these are the same logger...
loggers = {logging.getLogger(), LoggerRoot.get_base_logger()}
# Find all TaskHandler handlers for these loggers
handlers = {logger: h for logger in loggers for h in logger.handlers if isinstance(h, TaskHandler)}
if handlers and not replace_existing:
# Handlers exist and we shouldn't replace them
return
# Remove all handlers, we'll add new ones
for logger, handler in handlers.items():
logger.removeHandler(handler)
# Create a handler that will be used in all loggers. Since our handler is a buffering handler, using more
# than one instance to report to the same task will result in out-of-order log reports (grouped by whichever
# handler instance handled them)
backend_handler = TaskHandler(task=self)
# Add backend handler to both loggers:
# 1. to the root logger
# 2. to our own logger as well, since our logger is not propagated to the root logger
# (if we propagate, our logger will be caught by the root handlers as well, and
# we do not want that)
for logger in loggers:
logger.addHandler(backend_handler)
def _validate(self, check_output_dest_credentials=True):
raise_errors = self._raise_on_validation_errors
output_dest = self.get_output_destination(raise_on_error=False, log_on_error=False)

View File

@ -176,6 +176,10 @@
# Log all stdout & stderr
log_stdout: true
# Line feed (\r) flush period, in seconds. If zero (0), every \r is treated as \n and
# flushed to the backend; otherwise consecutive line feeds (\r) are flushed every X seconds (default: 10)
console_lf_flush_period: 10
# compatibility feature, report memory usage for the entire machine
# default (false), report only on the running process and its sub-processes
report_global_mem_used: false

View File

@ -5,13 +5,15 @@ from typing import Any, Sequence, Union, List, Optional, Tuple, Dict, TYPE_CHECK
import numpy as np
import six
from PIL import Image
from pathlib2 import Path
from .debugging.log import LoggerRoot
try:
import pandas as pd
except ImportError:
pd = None
from PIL import Image
from pathlib2 import Path
from .backend_interface.logger import StdStreamPatch, LogFlusher
from .backend_interface.task import Task as _Task
@ -19,7 +21,6 @@ from .backend_interface.task.development.worker import DevWorker
from .backend_interface.task.log import TaskHandler
from .backend_interface.util import mutually_exclusive
from .config import running_remotely, get_cache_dir, config
from .debugging.log import LoggerRoot
from .errors import UsageError
from .storage.helper import StorageHelper
from .utilities.plotly_reporter import SeriesInfo
@ -29,9 +30,8 @@ warnings.filterwarnings('always', category=DeprecationWarning, module=__name__)
if TYPE_CHECKING:
from matplotlib.figure import Figure as MatplotlibFigure
from matplotlib import pyplot
from plotly.graph_objects import Figure
from matplotlib.figure import Figure as MatplotlibFigure # noqa
from matplotlib import pyplot # noqa
class Logger(object):
@ -60,7 +60,7 @@ class Logger(object):
_tensorboard_logging_auto_group_scalars = False
_tensorboard_single_series_per_graph = config.get('metrics.tensorboard_single_series_per_graph', False)
def __init__(self, private_task):
def __init__(self, private_task, connect_stdout=True, connect_stderr=True, connect_logging=False):
"""
.. warning::
**Do not construct Logger manually!**
@ -73,12 +73,26 @@ class Logger(object):
self._default_upload_destination = None
self._flusher = None
self._report_worker = None
self._task_handler = None
self._graph_titles = {}
self._tensorboard_series_force_prefix = None
self._task_handler = TaskHandler(task=self._task, capacity=100)
self._connect_std_streams = connect_stdout or connect_stderr
self._connect_logging = connect_logging
if self._task.is_main_task():
StdStreamPatch.patch_std_streams(self)
# Make sure urllib is never in debug/info,
disable_urllib3_info = config.get('log.disable_urllib3_info', True)
if disable_urllib3_info and logging.getLogger('urllib3').isEnabledFor(logging.INFO):
logging.getLogger('urllib3').setLevel(logging.WARNING)
StdStreamPatch.patch_std_streams(self, connect_stdout=connect_stdout, connect_stderr=connect_stderr)
if self._connect_logging:
StdStreamPatch.patch_logging_formatter(self)
elif not self._connect_std_streams:
# make sure that at least the main trains logger is connected
base_logger = LoggerRoot.get_base_logger()
if base_logger and base_logger.handlers:
StdStreamPatch.patch_logging_formatter(self, base_logger.handlers[0])
@classmethod
def current_logger(cls):
@ -381,6 +395,7 @@ class Logger(object):
example: extra_layout={'xaxis': {'type': 'date', 'range': ['2020-01-01', '2020-01-31']}}
"""
# noinspection PyArgumentList
series = [self.SeriesInfo(**s) if isinstance(s, dict) else s for s in series]
# if task was not started, we have to start it
@ -832,7 +847,7 @@ class Logger(object):
upload_uri = storage.verify_upload(folder_uri=upload_uri)
if isinstance(image, Image.Image):
image = np.array(image)
image = np.array(image) # noqa
# noinspection PyProtectedMember
self._task._reporter.report_image_and_upload(
title=title,
@ -1068,7 +1083,7 @@ class Logger(object):
:param float period: The period to flush the logger in seconds. To set no periodic flush,
specify ``None`` or ``0``.
"""
if self._task.is_main_task() and DevWorker.report_stdout and DevWorker.report_period and \
if self._task.is_main_task() and self._task_handler and DevWorker.report_period and \
not running_remotely() and period is not None:
period = min(period or DevWorker.report_period, DevWorker.report_period)
@ -1099,6 +1114,30 @@ class Logger(object):
self.report_image(title=title, series=series, iteration=iteration, local_path=path, image=matrix,
max_image_history=max_image_history, delete_after_upload=delete_after_upload)
def capture_logging(self):
# type: () -> "_LoggingContext"
"""
Return a context manager that captures all logs reported through the logging module within the context
:return: a ContextManager
"""
class _LoggingContext(object):
def __init__(self, a_logger):
self.logger = a_logger
def __enter__(self, *_, **__):
if not self.logger:
return
StdStreamPatch.patch_logging_formatter(self.logger)
def __exit__(self, *_, **__):
if not self.logger:
return
StdStreamPatch.remove_patch_logging_formatter()
# Do nothing if we already have full logging support
return _LoggingContext(None if self._connect_logging else self)
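A usage sketch of the new context manager; project and task names are placeholders:

import logging
from trains import Task

task = Task.init(project_name='examples', task_name='capture demo')
with task.get_logger().capture_logging():
    # records emitted through logging inside this block are also
    # reported to the task as console output
    logging.getLogger('my_app').warning('captured inside the context')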
@classmethod
def tensorboard_auto_group_scalars(cls, group_scalars=False):
# type: (bool) -> None
@ -1137,7 +1176,7 @@ class Logger(object):
def _remove_std_logger(cls):
StdStreamPatch.remove_std_logger()
def _console(self, msg, level=logging.INFO, omit_console=False, *args, **kwargs):
def _console(self, msg, level=logging.INFO, omit_console=False, *args, **_):
# type: (str, int, bool, Any, Any) -> None
"""
print text to log (same as print to console, and also prints to console)
@ -1158,29 +1197,32 @@ class Logger(object):
msg='Logger failed casting log level "%s" to integer' % str(level))
level = logging.INFO
if not running_remotely():
# noinspection PyProtectedMember
if not running_remotely() or not self._task._is_remote_main_task():
if self._task_handler:
# noinspection PyBroadException
try:
record = self._task.log.makeRecord(
"console", level=level, fn='', lno=0, func='', msg=msg, args=args, exc_info=None
)
# find the task handler that matches our task
if not self._task_handler:
self._task_handler = [h for h in LoggerRoot.get_base_logger().handlers
if isinstance(h, TaskHandler) and h.task_id == self._task.id][0]
self._task_handler.emit(record)
except Exception:
# avoid infinite loop, output directly to stderr
# noinspection PyBroadException
try:
# make sure we are writing to the original stderr
StdStreamPatch.stderr_original_write(
'trains.Logger failed sending log [level {}]: "{}"\n'.format(level, msg))
except Exception:
pass
else:
# noinspection PyProtectedMember
self._task._reporter.report_console(message=msg, level=level)
if not omit_console:
# if we are here and we grabbed the stdout, we need to print the real thing
if DevWorker.report_stdout and not running_remotely():
if self._connect_std_streams and not running_remotely():
# noinspection PyBroadException
try:
# make sure we are writing to the original stdout
@ -1218,7 +1260,7 @@ class Logger(object):
:param path: A path to an image file. Required unless matrix is provided.
:type path: str
:param matrix: A 3D numpy.ndarray object containing image data (RGB). Required unless filename is provided.
:type matrix: str
:type matrix: np.array
:param max_image_history: maximum number of image to store per metric/variant combination \
use negative value for unlimited. default is set in global configuration (default=5)
:type max_image_history: int
@ -1306,14 +1348,15 @@ class Logger(object):
pass
def _flush_stdout_handler(self):
if self._task_handler and DevWorker.report_stdout:
if self._task_handler:
self._task_handler.flush()
def _close_stdout_handler(self, wait=True):
# detach the sys stdout/stderr
if self._connect_std_streams:
StdStreamPatch.remove_std_logger(self)
if self._task_handler and DevWorker.report_stdout:
if self._task_handler:
t = self._task_handler
self._task_handler = None
t.close(wait)

View File

@ -190,6 +190,7 @@ class Task(_Task):
auto_connect_arg_parser=True, # type: Union[bool, Mapping[str, bool]]
auto_connect_frameworks=True, # type: Union[bool, Mapping[str, bool]]
auto_resource_monitoring=True, # type: bool
auto_connect_streams=True, # type: Union[bool, Mapping[str, bool]]
):
# type: (...) -> Task
"""
@ -350,6 +351,24 @@ class Task(_Task):
- ``False`` - Do not automatically create.
- Class Type - Create ResourceMonitor object of the specified class type.
:param auto_connect_streams: Control the automatic logging of stdout and stderr
The values are:
- ``True`` - Automatically connect (default)
- ``False`` - Do not automatically connect
- A dictionary - In addition to a boolean, you can use a dictionary for fine-grained control of stdout and
stderr. The dictionary keys are 'stdout', 'stderr' and 'logging'; the values are booleans.
Keys missing from the dictionary default to ``False``, and an empty dictionary defaults to ``False``.
Note: by default stdout/stderr are logged, and output written through the
`logging` module is captured as a by-product of the stderr logging.
For example:
.. code-block:: py
auto_connect_streams={'stdout': True, 'stderr': True, 'logging': False}
:return: The main execution Task (Task context).
"""
@ -390,7 +409,7 @@ class Task(_Task):
# create a new logger (to catch stdout/err)
cls.__main_task._logger = None
cls.__main_task.__reporter = None
cls.__main_task.get_logger()
cls.__main_task._get_logger(auto_connect_streams=auto_connect_streams)
cls.__main_task._artifacts_manager = Artifacts(cls.__main_task)
# unregister signal hooks, they cause subprocess to hang
# noinspection PyProtectedMember
@ -453,7 +472,8 @@ class Task(_Task):
continue_last_task=continue_last_task,
detect_repo=False if (
isinstance(auto_connect_frameworks, dict) and
not auto_connect_frameworks.get('detect_repository', True)) else True
not auto_connect_frameworks.get('detect_repository', True)) else True,
auto_connect_streams=auto_connect_streams,
)
# set defaults
if cls._offline_mode:
@ -545,7 +565,7 @@ class Task(_Task):
# Make sure we start the logger; it will patch the main logging object and pipe all output.
# If we are running locally and using a development mode worker, we will pipe all stdout to the logger.
# The logger will automatically take care of all patching (we just need to make sure to initialize it)
logger = task.get_logger()
logger = task._get_logger(auto_connect_streams=auto_connect_streams)
# show the debug metrics page in the log, it is very convenient
if not is_sub_process_task_id:
if cls._offline_mode:
@ -2110,7 +2130,7 @@ class Task(_Task):
@classmethod
def _create_dev_task(
cls, default_project_name, default_task_name, default_task_type,
reuse_last_task_id, continue_last_task=False, detect_repo=True,
reuse_last_task_id, continue_last_task=False, detect_repo=True, auto_connect_streams=True
):
if not default_project_name or not default_task_name:
# get project name and task name from repository name and entry_point
@ -2230,8 +2250,7 @@ class Task(_Task):
task.reload()
# force update of base logger to this current task (this is the main logger task)
task._setup_log(replace_existing=True)
logger = task.get_logger()
logger = task._get_logger(auto_connect_streams=auto_connect_streams)
if closed_old_task:
logger.report_text('TRAINS Task: Closing old development task id={}'.format(default_task.get('id')))
# print warning, reusing/creating a task
@ -2271,8 +2290,8 @@ class Task(_Task):
return task
def _get_logger(self, flush_period=NotSet):
# type: (Optional[float]) -> Logger
def _get_logger(self, flush_period=NotSet, auto_connect_streams=False):
# type: (Optional[float], Union[bool, dict]) -> Logger
"""
get a logger object for reporting based on the task
@ -2288,10 +2307,15 @@ class Task(_Task):
# do not recreate logger after task was closed/quit
if self._at_exit_called:
raise ValueError("Cannot use Task Logger after task was closed")
# force update of base logger to this current task (this is the main logger task)
self._setup_log(replace_existing=self.is_main_task())
# Get a logger object
self._logger = Logger(private_task=self)
self._logger = Logger(
private_task=self,
connect_stdout=(auto_connect_streams is True) or
(isinstance(auto_connect_streams, dict) and auto_connect_streams.get('stdout', False)),
connect_stderr=(auto_connect_streams is True) or
(isinstance(auto_connect_streams, dict) and auto_connect_streams.get('stderr', False)),
connect_logging=isinstance(auto_connect_streams, dict) and auto_connect_streams.get('logging', False),
)
# make sure we set our reporter to async mode
# we make sure we flush it in self._at_exit
self._reporter.async_enable = True
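The resolution above can be checked in isolation; this standalone snippet mirrors the same expressions used to construct the Logger:

def resolve_streams(auto_connect_streams):
    # mirrors the connect_* expressions passed to Logger above
    is_dict = isinstance(auto_connect_streams, dict)
    connect_stdout = (auto_connect_streams is True) or (is_dict and auto_connect_streams.get('stdout', False))
    connect_stderr = (auto_connect_streams is True) or (is_dict and auto_connect_streams.get('stderr', False))
    connect_logging = is_dict and auto_connect_streams.get('logging', False)
    return connect_stdout, connect_stderr, connect_logging

assert resolve_streams(True) == (True, True, False)
assert resolve_streams(False) == (False, False, False)
assert resolve_streams({'logging': True}) == (False, False, True)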