Mirror of https://github.com/clearml/clearml (synced 2025-04-21 14:54:23 +00:00)
Add Task.init auto_connect_streams controlling stdout/stderr/logging capture. Issue #181
parent e9920e27ed
commit 6e012cb205
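For orientation before the diff: with this change, stream capture is chosen at `Task.init` time. A minimal usage sketch of the new argument (the project and task names here are placeholders):

```python
from trains import Task

# default: capture both stdout and stderr
task = Task.init(project_name='examples', task_name='console capture')

# fine-grained: keep stderr and `logging` records, drop raw stdout
# (per the docstring below, keys missing from the dict default to False)
task = Task.init(
    project_name='examples',
    task_name='console capture',
    auto_connect_streams={'stdout': False, 'stderr': True, 'logging': True},
)
```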
@@ -1,10 +1,10 @@
 import logging
 import sys
 import threading
+from time import time

-from ..backend_interface.task.development.worker import DevWorker
-from ..backend_interface.task.log import TaskHandler
-from ..config import running_remotely
+from ..binding.frameworks import _patched_call  # noqa
+from ..config import running_remotely, config


 class StdStreamPatch(object):
@@ -14,51 +14,62 @@ class StdStreamPatch(object):
     _stderr_original_write = None

     @staticmethod
-    def patch_std_streams(logger):
-        if DevWorker.report_stdout and not PrintPatchLogger.patched and not running_remotely():
-            StdStreamPatch._stdout_proxy = PrintPatchLogger(sys.stdout, logger, level=logging.INFO)
-            StdStreamPatch._stderr_proxy = PrintPatchLogger(sys.stderr, logger, level=logging.ERROR)
-            logger._task_handler = TaskHandler(task=logger._task, capacity=100)
+    def patch_std_streams(a_logger, connect_stdout=True, connect_stderr=True):
+        if (connect_stdout or connect_stderr) and not PrintPatchLogger.patched and not running_remotely():
+            StdStreamPatch._stdout_proxy = PrintPatchLogger(sys.stdout, a_logger, level=logging.INFO) \
+                if connect_stdout else None
+            StdStreamPatch._stderr_proxy = PrintPatchLogger(sys.stderr, a_logger, level=logging.ERROR) \
+                if connect_stderr else None
+
+        if StdStreamPatch._stdout_proxy:
             # noinspection PyBroadException
             try:
                 if StdStreamPatch._stdout_original_write is None:
                     StdStreamPatch._stdout_original_write = sys.stdout.write
                 if StdStreamPatch._stderr_original_write is None:
                     StdStreamPatch._stderr_original_write = sys.stderr.write

                 # this will only work in python 3, guard it with try/except
                 if not hasattr(sys.stdout, '_original_write'):
                     sys.stdout._original_write = sys.stdout.write
                 sys.stdout.write = StdStreamPatch._stdout__patched__write__
             except Exception:
                 pass
             sys.stdout = StdStreamPatch._stdout_proxy
+            # noinspection PyBroadException
+            try:
+                sys.__stdout__ = sys.stdout
+            except Exception:
+                pass
+
+        if StdStreamPatch._stderr_proxy:
+            # noinspection PyBroadException
+            try:
+                if StdStreamPatch._stderr_original_write is None:
+                    StdStreamPatch._stderr_original_write = sys.stderr.write
                 if not hasattr(sys.stderr, '_original_write'):
                     sys.stderr._original_write = sys.stderr.write
                 sys.stderr.write = StdStreamPatch._stderr__patched__write__
             except Exception:
                 pass
             sys.stdout = StdStreamPatch._stdout_proxy
             sys.stderr = StdStreamPatch._stderr_proxy

         # patch the base streams of sys (this way colorama will keep its ANSI colors)
         # noinspection PyBroadException
         try:
             sys.__stderr__ = sys.stderr
         except Exception:
             pass
         # noinspection PyBroadException
         try:
             sys.__stdout__ = sys.stdout
         except Exception:
             pass

         # now check if we have loguru and make it re-register the handlers,
         # because it stores the stream.write function internally, which we can't patch
         # noinspection PyBroadException
         try:
-            from loguru import logger
+            from loguru import logger  # noqa
             register_stderr = None
             register_stdout = None
-            for k, v in logger._handlers.items():
-                if v._name == '<stderr>':
+            for k, v in logger._handlers.items():  # noqa
+                if connect_stderr and v._name == '<stderr>':  # noqa
                     register_stderr = k
-                elif v._name == '<stdout>':
+                elif connect_stdout and v._name == '<stdout>':  # noqa
                     register_stdout = k
             if register_stderr is not None:
                 logger.remove(register_stderr)
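The hunk above swaps sys.stdout/sys.stderr for PrintPatchLogger proxies and stashes the original write methods. A self-contained sketch of the same proxy pattern, independent of trains internals:

```python
import sys

class StreamProxy(object):
    """Forward writes to the real stream while mirroring them to a callback."""
    def __init__(self, stream, callback):
        self._stream = stream
        self._callback = callback

    def write(self, message):
        self._stream.write(message)   # keep normal console behaviour
        self._callback(message)       # mirror the text for reporting

    def __getattr__(self, attr):
        # delegate flush(), isatty(), etc. to the wrapped stream
        return getattr(self._stream, attr)

captured = []
sys.stdout = StreamProxy(sys.stdout, captured.append)
print('hello')                      # printed to screen and appended to `captured`
sys.stdout = sys.stdout._stream     # restore the original stream
```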
@@ -69,12 +80,30 @@ class StdStreamPatch(object):
         except Exception:
             pass

-        elif DevWorker.report_stdout and not running_remotely():
-            logger._task_handler = TaskHandler(task=logger._task, capacity=100)
-            if StdStreamPatch._stdout_proxy:
-                StdStreamPatch._stdout_proxy.connect(logger)
-            if StdStreamPatch._stderr_proxy:
-                StdStreamPatch._stderr_proxy.connect(logger)
+        elif (connect_stdout or connect_stderr) and not running_remotely():
+            if StdStreamPatch._stdout_proxy and connect_stdout:
+                StdStreamPatch._stdout_proxy.connect(a_logger)
+            if StdStreamPatch._stderr_proxy and connect_stderr:
+                StdStreamPatch._stderr_proxy.connect(a_logger)
+
+    @staticmethod
+    def patch_logging_formatter(a_logger, logging_handler=None):
+        if not logging_handler:
+            import logging
+            logging_handler = logging.Handler
+        logging_handler.format = _patched_call(logging_handler.format, HandlerFormat(a_logger))
+
+    @staticmethod
+    def remove_patch_logging_formatter(logging_handler=None):
+        if not logging_handler:
+            import logging
+            logging_handler = logging.Handler
+        # remove the function; hack: calling the patched logging.Handler.format() returns the original function
+        # noinspection PyBroadException
+        try:
+            logging_handler.format = logging_handler.format()  # noqa
+        except Exception:
+            pass

     @staticmethod
     def remove_std_logger(logger=None):
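patch_logging_formatter wraps logging.Handler.format through _patched_call so every formatted record is also reported via HandlerFormat. A rough sketch of that wrapping technique, with a stand-in for _patched_call (the names here are illustrative, not the trains helper):

```python
import logging

def patched_call(original, callback):
    """Return a wrapper that hands the original function plus args to `callback`."""
    def _inner(*args, **kwargs):
        return callback(original, *args, **kwargs)
    return _inner

def report_format(original_format, handler, record):
    msg = original_format(handler, record)  # call the unbound Handler.format
    # a real implementation would ship `msg` / `record.levelno` to the backend here
    return msg

logging.Handler.format = patched_call(logging.Handler.format, report_format)
logging.basicConfig()
logging.getLogger().warning('formatted, intercepted, and still printed')
```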
@@ -105,13 +134,33 @@ class StdStreamPatch(object):
     def _stdout__patched__write__(*args, **kwargs):
         if StdStreamPatch._stdout_proxy:
             return StdStreamPatch._stdout_proxy.write(*args, **kwargs)
-        return sys.stdout._original_write(*args, **kwargs)
+        return sys.stdout._original_write(*args, **kwargs)  # noqa

     @staticmethod
     def _stderr__patched__write__(*args, **kwargs):
         if StdStreamPatch._stderr_proxy:
             return StdStreamPatch._stderr_proxy.write(*args, **kwargs)
-        return sys.stderr._original_write(*args, **kwargs)
+        return sys.stderr._original_write(*args, **kwargs)  # noqa


+class HandlerFormat(object):
+    def __init__(self, logger):
+        self._logger = logger
+
+    def __call__(self, original_format_func, *args):
+        # hack: get back the original function, so we can remove it
+        if all(a is None for a in args):
+            return original_format_func
+        if len(args) == 1:
+            record = args[0]
+            msg = original_format_func(record)
+        else:
+            handler = args[0]
+            record = args[1]
+            msg = original_format_func(handler, record)
+
+        self._logger.report_text(msg=msg, level=record.levelno, print_console=False)
+        return msg
+
+
 class PrintPatchLogger(object):
@@ -122,6 +171,7 @@ class PrintPatchLogger(object):
     patched = False
     lock = threading.Lock()
     recursion_protect_lock = threading.RLock()
+    lf_flush_period = config.get("development.worker.console_lf_flush_period", 0)

     def __init__(self, stream, logger=None, level=logging.INFO):
         PrintPatchLogger.patched = True
@@ -129,23 +179,35 @@ class PrintPatchLogger(object):
         self._log = logger
         self._log_level = level
         self._cur_line = ''
+        self._force_lf_flush = False

     def write(self, message):
         # make sure that we do not end up in infinite loop (i.e. log.console ends up calling us)
-        if self._log and not PrintPatchLogger.recursion_protect_lock._is_owned():
+        if self._log and not PrintPatchLogger.recursion_protect_lock._is_owned():  # noqa
             try:
                 self.lock.acquire()
                 with PrintPatchLogger.recursion_protect_lock:
                     if hasattr(self._terminal, '_original_write'):
-                        self._terminal._original_write(message)
+                        self._terminal._original_write(message)  # noqa
                     else:
                         self._terminal.write(message)

                 do_flush = '\n' in message
                 do_cr = '\r' in message
                 self._cur_line += message
-                if (not do_flush and not do_cr) or not message:
+
+                if not do_flush and do_cr and PrintPatchLogger.lf_flush_period and self._force_lf_flush:
+                    self._cur_line += '\n'
+                    do_flush = True
+
+                self._force_lf_flush = False
+
+                if (not do_flush and (PrintPatchLogger.lf_flush_period or not do_cr)) or not message:
                     return

+                if PrintPatchLogger.lf_flush_period and self._cur_line:
+                    self._cur_line = '\n'.join(line.split('\r')[-1] for line in self._cur_line.split('\n'))
+
                 last_lf = self._cur_line.rindex('\n' if do_flush else '\r')
                 next_line = self._cur_line[last_lf + 1:]
                 cur_line = self._cur_line[:last_lf + 1].rstrip()
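Most of the new write() logic deals with carriage-return progress bars: when lf_flush_period is set, \r updates are buffered and collapsed so only the last rewrite of each line is reported. The collapsing expression can be verified on its own:

```python
# the \r-collapsing expression used in PrintPatchLogger.write()
cur_line = 'progress 10%\rprogress 50%\rprogress 100%\ndone\n'
collapsed = '\n'.join(line.split('\r')[-1] for line in cur_line.split('\n'))
print(repr(collapsed))  # 'progress 100%\ndone\n' -- only the final \r state survives
```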
@@ -158,13 +220,14 @@ class PrintPatchLogger(object):
            # noinspection PyBroadException
            try:
                if self._log:
+                   # noinspection PyProtectedMember
                    self._log._console(cur_line, level=self._log_level, omit_console=True)
            except Exception:
                # what can we do, nothing
                pass
        else:
            if hasattr(self._terminal, '_original_write'):
-               self._terminal._original_write(message)
+               self._terminal._original_write(message)  # noqa
            else:
                self._terminal.write(message)

@@ -177,13 +240,16 @@ class PrintPatchLogger(object):
         if not logger or self._log == logger:
             self.connect(None)

+    def force_lf_flush(self):
+        self._force_lf_flush = True
+
     def __getattr__(self, attr):
-        if attr in ['_log', '_terminal', '_log_level', '_cur_line']:
+        if attr in ['_log', '_terminal', '_log_level', '_cur_line', '_cr_overwrite', '_force_lf_flush']:
             return self.__dict__.get(attr)
         return getattr(self._terminal, attr)

     def __setattr__(self, key, value):
-        if key in ['_log', '_terminal', '_log_level', '_cur_line']:
+        if key in ['_log', '_terminal', '_log_level', '_cur_line', '_cr_overwrite', '_force_lf_flush']:
             self.__dict__[key] = value
         else:
             return setattr(self._terminal, key, value)
@@ -197,6 +263,11 @@ class LogFlusher(threading.Thread):
         self._period = period
         self._logger = logger
         self._exit_event = threading.Event()
+        self._lf_last_flush = 0
+        try:
+            self._lf_flush_period = float(PrintPatchLogger.lf_flush_period)
+        except (ValueError, TypeError):
+            self._lf_flush_period = 0

     @property
     def period(self):
@@ -208,7 +279,15 @@ class LogFlusher(threading.Thread):
         while True:
             period = self._period
             while not self._exit_event.wait(period or 1.0):
+                if self._lf_flush_period and time() - self._lf_last_flush > self._lf_flush_period:
+                    if isinstance(sys.stdout, PrintPatchLogger):
+                        sys.stdout.force_lf_flush()
+                    if isinstance(sys.stderr, PrintPatchLogger):
+                        sys.stderr.force_lf_flush()
+                    self._lf_last_flush = time()
+                # now signal the real flush
                 self._logger.flush()

             # check if period is negative or None we should exit
             if self._period is None or self._period < 0:
                 break
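LogFlusher's run loop now serves two periods at once: the regular logger flush, and a forced line-feed flush on the patched streams every lf_flush_period seconds. A stripped-down sketch of that dual-period polling loop (illustration only, not the trains class):

```python
import threading
from time import time

class DualPeriodFlusher(threading.Thread):
    def __init__(self, flush, period=1.0, lf_period=10.0):
        super(DualPeriodFlusher, self).__init__()
        self.daemon = True
        self._flush = flush            # the regular flush callback
        self._period = period          # wake-up period in seconds
        self._lf_period = lf_period    # forced line-feed flush period
        self._lf_last = 0
        self._exit = threading.Event()

    def run(self):
        while not self._exit.wait(self._period):
            if self._lf_period and time() - self._lf_last > self._lf_period:
                # the real code calls force_lf_flush() on the patched streams here
                self._lf_last = time()
            self._flush()

    def stop(self):
        self._exit.set()
```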
@@ -1,6 +1,7 @@
 import abc
 import hashlib
 import time
+from logging import getLevelName
 from multiprocessing import Lock

 import attr
@@ -132,6 +133,24 @@ class ScalarEvent(MetricsEventAdapter):
             **self._get_base_dict())


+class ConsoleEvent(MetricsEventAdapter):
+    """ Console log event adapter """
+
+    def __init__(self, message, level, worker, **kwargs):
+        self._value = str(message)
+        self._level = getLevelName(level) if isinstance(level, int) else str(level)
+        self._worker = worker
+        super(ConsoleEvent, self).__init__(metric=None, variant=None, iter=0, **kwargs)
+
+    def get_api_event(self):
+        return events.TaskLogEvent(
+            task=self._task,
+            timestamp=self._timestamp,
+            level=self._level,
+            worker=self._worker,
+            msg=self._value)
+
+
 class VectorEvent(MetricsEventAdapter):
     """ Vector event adapter """
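ConsoleEvent normalizes the level with logging.getLevelName, which maps known integers to names and keeps everything else printable:

```python
from logging import getLevelName

print(getLevelName(20))   # 'INFO'
print(getLevelName(40))   # 'ERROR'
print(getLevelName(35))   # 'Level 35' -- unknown ints still yield a readable string
print(str('WARNING'))     # non-int levels are simply passed through str()
```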
@@ -1,8 +1,9 @@
 import json
+import logging
 import math

 try:
-    from collections.abc import Iterable
+    from collections.abc import Iterable  # noqa
 except ImportError:
     from collections import Iterable
@@ -17,7 +18,8 @@ from ...utilities.plotly_reporter import create_2d_histogram_plot, create_value_
     create_2d_scatter_series, create_3d_scatter_series, create_line_plot, plotly_scatter3d_layout_dict, \
     create_image_plot, create_plotly_table
 from ...utilities.py3_interop import AbstractContextManager
-from .events import ScalarEvent, VectorEvent, ImageEvent, PlotEvent, ImageEventNoUpload, UploadEvent, MediaEvent
+from .events import ScalarEvent, VectorEvent, ImageEvent, PlotEvent, ImageEventNoUpload, \
+    UploadEvent, MediaEvent, ConsoleEvent
 from ...config import config
@@ -143,7 +145,7 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
         :param value: Reported value
         :type value: float
         :param iter: Iteration number
-        :type value: int
+        :type iter: int
         """
         ev = ScalarEvent(metric=self._normalize_name(title), variant=self._normalize_name(series), value=value,
                          iter=iter)
@@ -157,9 +159,9 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
         :param series: Series (AKA variant)
         :type series: str
         :param values: Reported values
-        :type value: [float]
+        :type values: [float]
         :param iter: Iteration number
-        :type value: int
+        :type iter: int
         """
         if not isinstance(values, Iterable):
             raise ValueError('values: expected an iterable')
@@ -190,7 +192,7 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
         :type plot: str or dict
         :param iter: Iteration number
         :param round_digits: number of digits after the dot to leave
-        :type value: int
+        :type round_digits: int
         """
         def floatstr(o):
             if o != o:
@@ -251,7 +253,7 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
             for presentation of processing. Currently only http(s), file and s3 schemes are supported.
         :type src: str
         :param iter: Iteration number
-        :type value: int
+        :type iter: int
         """
         ev = ImageEventNoUpload(metric=self._normalize_name(title), variant=self._normalize_name(series), iter=iter,
                                 src=src)
@@ -268,7 +270,7 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
             for presentation of processing. Currently only http(s), file and s3 schemes are supported.
         :type src: str
         :param iter: Iteration number
-        :type value: int
+        :type iter: int
         """
         ev = ImageEventNoUpload(metric=self._normalize_name(title), variant=self._normalize_name(series), iter=iter,
                                 src=src)
@@ -290,6 +292,7 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
         :type path: str
         :param image: Image data. Required unless filename is provided.
         :type image: A PIL.Image.Image object or a 3D numpy.ndarray object
+        :param upload_uri: Destination URL
         :param max_image_history: maximum number of images to store per metric/variant combination
             use negative value for unlimited. default is set in global configuration (default=5)
         :param delete_after_upload: if True, once the file was uploaded the local copy will be deleted
@@ -321,6 +324,7 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
         :param path: A path to an image file. Required unless matrix is provided.
         :type path: str
         :param stream: File/String stream
+        :param upload_uri: Destination URL
         :param file_extension: file extension to use when stream is passed
         :param max_history: maximum number of files to store per metric/variant combination
             use negative value for unlimited. default is set in global configuration (default=5)
@@ -353,7 +357,7 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
             A row for each dataset (bar in a bar group). A column for each bucket.
         :type histogram: numpy array
         :param iter: Iteration number
-        :type value: int
+        :type iter: int
         :param labels: The labels for each bar group.
         :type labels: list of strings.
         :param xlabels: The labels of the x axis.
@@ -466,7 +470,7 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
         :param series: Series (AKA variant)
         :type series: str
         :param data: A scattered data: pairs of x,y as rows in a numpy array
-        :type scatter: ndarray
+        :type data: ndarray
         :param iter: Iteration number
         :type iter: int
         :param mode: (type str) 'lines'/'markers'/'lines+markers'
@@ -520,16 +524,17 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
         :param xtitle: optional x-axis title
         :param ytitle: optional y-axis title
         :param ztitle: optional z-axis title
         :param fill: optional
         :param comment: comment underneath the title
         :param layout_config: optional dictionary for layout configuration, passed directly to plotly
         :type layout_config: dict or None
         """
         data_series = data if isinstance(data, list) else [data]

-        def get_labels(i):
+        def get_labels(a_i):
             if labels and isinstance(labels, list):
                 try:
-                    item = labels[i]
+                    item = labels[a_i]
                 except IndexError:
                     item = labels[-1]
                 if isinstance(item, list):
@@ -666,11 +671,11 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
         :param series: Series (AKA variant)
         :type series: str
         :param iter: Iteration number
-        :type value: int
+        :type iter: int
         :param path: A path to an image file. Required unless matrix is provided.
         :type path: str
         :param matrix: A 3D numpy.ndarray object containing image data (RGB). Required unless filename is provided.
-        :type matrix: str
+        :type matrix: np.ndarray
         :param upload_uri: upload image destination (str)
         :type upload_uri: str
         :param max_image_history: maximum number of images to store per metric/variant combination
@@ -686,8 +691,8 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
                           file_history_size=max_image_history)

         if matrix is not None:
-            width = matrix.shape[1]
-            height = matrix.shape[0]
+            width = matrix.shape[1]  # noqa
+            height = matrix.shape[0]  # noqa
         else:
             # noinspection PyBroadException
             try:
@@ -722,6 +727,21 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
             iter=iter,
         )

+    def report_console(self, message, level=logging.INFO):
+        """
+        Report a console log message
+        :param message: message text
+        :type message: str
+        :param level: log level (int or string)
+        :type level: int
+        """
+        ev = ConsoleEvent(
+            message=message,
+            level=level,
+            worker=self.session.worker,
+        )
+        self._report(ev)
+
     @classmethod
     def _normalize_name(cls, name):
         return name
@@ -172,57 +172,9 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
-        if running_remotely() or DevWorker.report_stdout:
-            log_to_backend = False
         self._log_to_backend = log_to_backend
-        self._setup_log(default_log_to_backend=log_to_backend)
         self._artifacts_manager = Artifacts(self)
         self._hyper_params_manager = HyperParams(self)

-    def _setup_log(self, default_log_to_backend=None, replace_existing=False):
-        """
-        Setup logging facilities for this task.
-        :param default_log_to_backend: Should this task log to the backend. If not specified, value for this option
-        will be obtained from the environment, with this value acting as a default in case configuration for this is
-        missing.
-        If the value for this option is false, we won't touch the current logger configuration regarding TaskHandler(s)
-        :param replace_existing: If True and another task is already logging to the backend, replace the handler with
-        a handler for this task.
-        """
-        # Make sure urllib is never in debug/info,
-        disable_urllib3_info = config.get('log.disable_urllib3_info', True)
-        if disable_urllib3_info and logging.getLogger('urllib3').isEnabledFor(logging.INFO):
-            logging.getLogger('urllib3').setLevel(logging.WARNING)
-
-        log_to_backend = get_log_to_backend(default=default_log_to_backend) or self._log_to_backend
-        if not log_to_backend:
-            return
-
-        # Handle the root logger and our own logger. We use set() to make sure we create no duplicates
-        # in case these are the same logger...
-        loggers = {logging.getLogger(), LoggerRoot.get_base_logger()}
-
-        # Find all TaskHandler handlers for these loggers
-        handlers = {logger: h for logger in loggers for h in logger.handlers if isinstance(h, TaskHandler)}
-
-        if handlers and not replace_existing:
-            # Handlers exist and we shouldn't replace them
-            return
-
-        # Remove all handlers, we'll add new ones
-        for logger, handler in handlers.items():
-            logger.removeHandler(handler)
-
-        # Create a handler that will be used in all loggers. Since our handler is a buffering handler, using more
-        # than one instance to report to the same task will result in out-of-order log reports (grouped by whichever
-        # handler instance handled them)
-        backend_handler = TaskHandler(task=self)
-
-        # Add backend handler to both loggers:
-        # 1. to the root logger
-        # 2. to our own logger as well, since our logger is not propagated to the root logger
-        #    (if we propagate, our logger will be caught by the root handlers as well, and
-        #    we do not want that)
-        for logger in loggers:
-            logger.addHandler(backend_handler)
-
     def _validate(self, check_output_dest_credentials=True):
         raise_errors = self._raise_on_validation_errors
         output_dest = self.get_output_destination(raise_on_error=False, log_on_error=False)
@@ -176,6 +176,10 @@
     # Log all stdout & stderr
     log_stdout: true

+    # Line feed (\r) support. If zero (0), \r is treated as \n and flushed to the backend
+    # Line feed flush support in seconds: flush consecutive line feeds (\r) every X seconds (default: 10)
+    console_lf_flush_period: 10
+
     # compatibility feature, report memory usage for the entire machine
     # default (false), report only on the running process and its sub-processes
     report_global_mem_used: false
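The new configuration key can be read back exactly the way PrintPatchLogger reads it (assuming a trains installation with this configuration on the path):

```python
from trains.config import config

# 0 disables the \r collapsing behaviour; the shipped default is 10 seconds
period = config.get('development.worker.console_lf_flush_period', 0)
print(period)
```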
@@ -5,13 +5,15 @@ from typing import Any, Sequence, Union, List, Optional, Tuple, Dict, TYPE_CHECK

 import numpy as np
 import six
-from PIL import Image
-from pathlib2 import Path
+
+from .debugging.log import LoggerRoot

 try:
     import pandas as pd
 except ImportError:
     pd = None
+from PIL import Image
+from pathlib2 import Path

 from .backend_interface.logger import StdStreamPatch, LogFlusher
 from .backend_interface.task import Task as _Task
@@ -19,7 +21,6 @@ from .backend_interface.task.development.worker import DevWorker
 from .backend_interface.task.log import TaskHandler
 from .backend_interface.util import mutually_exclusive
 from .config import running_remotely, get_cache_dir, config
-from .debugging.log import LoggerRoot
 from .errors import UsageError
 from .storage.helper import StorageHelper
 from .utilities.plotly_reporter import SeriesInfo
@@ -29,9 +30,8 @@ warnings.filterwarnings('always', category=DeprecationWarning, module=__name__)


 if TYPE_CHECKING:
-    from matplotlib.figure import Figure as MatplotlibFigure
-    from matplotlib import pyplot
-    from plotly.graph_objects import Figure
+    from matplotlib.figure import Figure as MatplotlibFigure  # noqa
+    from matplotlib import pyplot  # noqa


 class Logger(object):
@@ -60,7 +60,7 @@ class Logger(object):
     _tensorboard_logging_auto_group_scalars = False
     _tensorboard_single_series_per_graph = config.get('metrics.tensorboard_single_series_per_graph', False)

-    def __init__(self, private_task):
+    def __init__(self, private_task, connect_stdout=True, connect_stderr=True, connect_logging=False):
         """
         .. warning::
             **Do not construct Logger manually!**
@@ -73,12 +73,26 @@ class Logger(object):
         self._default_upload_destination = None
         self._flusher = None
         self._report_worker = None
-        self._task_handler = None
         self._graph_titles = {}
         self._tensorboard_series_force_prefix = None
+        self._task_handler = TaskHandler(task=self._task, capacity=100)
+        self._connect_std_streams = connect_stdout or connect_stderr
+        self._connect_logging = connect_logging

         if self._task.is_main_task():
-            StdStreamPatch.patch_std_streams(self)
+            # Make sure urllib is never in debug/info,
+            disable_urllib3_info = config.get('log.disable_urllib3_info', True)
+            if disable_urllib3_info and logging.getLogger('urllib3').isEnabledFor(logging.INFO):
+                logging.getLogger('urllib3').setLevel(logging.WARNING)
+
+            StdStreamPatch.patch_std_streams(self, connect_stdout=connect_stdout, connect_stderr=connect_stderr)
+
+            if self._connect_logging:
+                StdStreamPatch.patch_logging_formatter(self)
+            elif not self._connect_std_streams:
+                # make sure that at least the main trains logger is connected
+                base_logger = LoggerRoot.get_base_logger()
+                if base_logger and base_logger.handlers:
+                    StdStreamPatch.patch_logging_formatter(self, base_logger.handlers[0])

     @classmethod
     def current_logger(cls):
@@ -381,6 +395,7 @@ class Logger(object):
             example: extra_layout={'xaxis': {'type': 'date', 'range': ['2020-01-01', '2020-01-31']}}
         """

+        # noinspection PyArgumentList
         series = [self.SeriesInfo(**s) if isinstance(s, dict) else s for s in series]

         # if task was not started, we have to start it
@@ -832,7 +847,7 @@ class Logger(object):
         upload_uri = storage.verify_upload(folder_uri=upload_uri)

         if isinstance(image, Image.Image):
-            image = np.array(image)
+            image = np.array(image)  # noqa
         # noinspection PyProtectedMember
         self._task._reporter.report_image_and_upload(
             title=title,
@@ -1068,7 +1083,7 @@ class Logger(object):
         :param float period: The period to flush the logger in seconds. To set no periodic flush,
             specify ``None`` or ``0``.
         """
-        if self._task.is_main_task() and DevWorker.report_stdout and DevWorker.report_period and \
+        if self._task.is_main_task() and self._task_handler and DevWorker.report_period and \
                 not running_remotely() and period is not None:
             period = min(period or DevWorker.report_period, DevWorker.report_period)
@@ -1099,6 +1114,30 @@ class Logger(object):
         self.report_image(title=title, series=series, iteration=iteration, local_path=path, image=matrix,
                           max_image_history=max_image_history, delete_after_upload=delete_after_upload)

+    def capture_logging(self):
+        # type: () -> "_LoggingContext"
+        """
+        Return context capturing all the logs (via logging) reported under the context
+
+        :return: a ContextManager
+        """
+        class _LoggingContext(object):
+            def __init__(self, a_logger):
+                self.logger = a_logger
+
+            def __enter__(self, *_, **__):
+                if not self.logger:
+                    return
+                StdStreamPatch.patch_logging_formatter(self.logger)
+
+            def __exit__(self, *_, **__):
+                if not self.logger:
+                    return
+                StdStreamPatch.remove_patch_logging_formatter()
+
+        # Do nothing if we already have full logging support
+        return _LoggingContext(None if self._connect_logging else self)
+
     @classmethod
     def tensorboard_auto_group_scalars(cls, group_scalars=False):
         # type: (bool) -> None
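capture_logging() gives scoped capture when full `logging` connection was not requested at init time. A minimal usage sketch (project and task names are placeholders):

```python
import logging
from trains import Task

task = Task.init(project_name='examples', task_name='scoped logging capture')

# records emitted inside the block are reported to the task
with task.get_logger().capture_logging():
    logging.getLogger().warning('captured by the task')

logging.getLogger().warning('no longer captured')
```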
@@ -1137,7 +1176,7 @@ class Logger(object):
     def _remove_std_logger(cls):
         StdStreamPatch.remove_std_logger()

-    def _console(self, msg, level=logging.INFO, omit_console=False, *args, **kwargs):
+    def _console(self, msg, level=logging.INFO, omit_console=False, *args, **_):
         # type: (str, int, bool, Any, Any) -> None
         """
         print text to log (same as print to console, and also prints to console)
@@ -1158,29 +1197,32 @@ class Logger(object):
                 msg='Logger failed casting log level "%s" to integer' % str(level))
             level = logging.INFO

-        if not running_remotely():
+        # noinspection PyProtectedMember
+        if not running_remotely() or not self._task._is_remote_main_task():
+            if self._task_handler:
                 # noinspection PyBroadException
                 try:
                     record = self._task.log.makeRecord(
                         "console", level=level, fn='', lno=0, func='', msg=msg, args=args, exc_info=None
                     )
-                    # find the task handler that matches our task
-                    if not self._task_handler:
-                        self._task_handler = [h for h in LoggerRoot.get_base_logger().handlers
-                                              if isinstance(h, TaskHandler) and h.task_id == self._task.id][0]
                     self._task_handler.emit(record)
                 except Exception:
                     # avoid infinite loop, output directly to stderr
                     # noinspection PyBroadException
                     try:
                         # make sure we are writing to the original stdout
                         StdStreamPatch.stderr_original_write(
                             'trains.Logger failed sending log [level {}]: "{}"\n'.format(level, msg))
                     except Exception:
                         pass
+            else:
+                # noinspection PyProtectedMember
+                self._task._reporter.report_console(message=msg, level=level)

         if not omit_console:
             # if we are here and we grabbed the stdout, we need to print the real thing
-            if DevWorker.report_stdout and not running_remotely():
+            if self._connect_std_streams and not running_remotely():
                 # noinspection PyBroadException
                 try:
                     # make sure we are writing to the original stdout
@@ -1218,7 +1260,7 @@ class Logger(object):
         :param path: A path to an image file. Required unless matrix is provided.
         :type path: str
         :param matrix: A 3D numpy.ndarray object containing image data (RGB). Required unless filename is provided.
-        :type matrix: str
+        :type matrix: np.array
         :param max_image_history: maximum number of images to store per metric/variant combination \
             use negative value for unlimited. default is set in global configuration (default=5)
         :type max_image_history: int
@@ -1306,14 +1348,15 @@ class Logger(object):
             pass

     def _flush_stdout_handler(self):
-        if self._task_handler and DevWorker.report_stdout:
+        if self._task_handler:
             self._task_handler.flush()

     def _close_stdout_handler(self, wait=True):
+        # detach the sys stdout/stderr
+        if self._connect_std_streams:
+            StdStreamPatch.remove_std_logger(self)
+
-        if self._task_handler and DevWorker.report_stdout:
+        if self._task_handler:
             t = self._task_handler
             self._task_handler = None
             t.close(wait)
@@ -190,6 +190,7 @@ class Task(_Task):
             auto_connect_arg_parser=True,  # type: Union[bool, Mapping[str, bool]]
             auto_connect_frameworks=True,  # type: Union[bool, Mapping[str, bool]]
             auto_resource_monitoring=True,  # type: bool
+            auto_connect_streams=True,  # type: Union[bool, Mapping[str, bool]]
     ):
         # type: (...) -> Task
         """
@@ -350,6 +351,24 @@ class Task(_Task):
             - ``False`` - Do not automatically create.
             - Class Type - Create ResourceMonitor object of the specified class type.

+        :param auto_connect_streams: Control the automatic logging of stdout and stderr
+
+            The values are:
+
+            - ``True`` - Automatically connect (default)
+            - ``False`` - Do not automatically connect
+            - A dictionary - In addition to a boolean, you can use a dictionary for fine-grained control of stdout
+              and stderr. The dictionary keys are 'stdout', 'stderr' and 'logging'; the values are booleans.
+              Keys missing from the dictionary default to ``False``, and an empty dictionary defaults to ``False``.
+              Note that by default stdout/stderr are logged, and the `logging` module is captured as a by-product
+              of the stderr logging.
+
+            For example:
+
+            .. code-block:: py
+
+               auto_connect_streams={'stdout': True, 'stderr': True, 'logging': False}
+
         :return: The main execution Task (Task context).
         """
@@ -390,7 +409,7 @@ class Task(_Task):
             # create a new logger (to catch stdout/err)
             cls.__main_task._logger = None
             cls.__main_task.__reporter = None
-            cls.__main_task.get_logger()
+            cls.__main_task._get_logger(auto_connect_streams=auto_connect_streams)
             cls.__main_task._artifacts_manager = Artifacts(cls.__main_task)
             # unregister signal hooks, they cause subprocess to hang
             # noinspection PyProtectedMember
@@ -453,7 +472,8 @@ class Task(_Task):
                     continue_last_task=continue_last_task,
                     detect_repo=False if (
                             isinstance(auto_connect_frameworks, dict) and
-                            not auto_connect_frameworks.get('detect_repository', True)) else True
+                            not auto_connect_frameworks.get('detect_repository', True)) else True,
+                    auto_connect_streams=auto_connect_streams,
                 )
                 # set defaults
                 if cls._offline_mode:
@@ -545,7 +565,7 @@ class Task(_Task):
         # Make sure we start the logger, it will patch the main logging object and pipe all output
         # if we are running locally and using development mode worker, we will pipe all stdout to logger.
         # The logger will automatically take care of all patching (we just need to make sure to initialize it)
-        logger = task.get_logger()
+        logger = task._get_logger(auto_connect_streams=auto_connect_streams)
         # show the debug metrics page in the log, it is very convenient
         if not is_sub_process_task_id:
             if cls._offline_mode:
@@ -2110,7 +2130,7 @@ class Task(_Task):
     @classmethod
     def _create_dev_task(
             cls, default_project_name, default_task_name, default_task_type,
-            reuse_last_task_id, continue_last_task=False, detect_repo=True,
+            reuse_last_task_id, continue_last_task=False, detect_repo=True, auto_connect_streams=True
     ):
         if not default_project_name or not default_task_name:
             # get project name and task name from repository name and entry_point
@@ -2230,8 +2250,7 @@ class Task(_Task):
             task.reload()

         # force update of base logger to this current task (this is the main logger task)
-        task._setup_log(replace_existing=True)
-        logger = task.get_logger()
+        logger = task._get_logger(auto_connect_streams=auto_connect_streams)
         if closed_old_task:
             logger.report_text('TRAINS Task: Closing old development task id={}'.format(default_task.get('id')))
         # print warning, reusing/creating a task
@@ -2271,8 +2290,8 @@ class Task(_Task):

         return task

-    def _get_logger(self, flush_period=NotSet):
-        # type: (Optional[float]) -> Logger
+    def _get_logger(self, flush_period=NotSet, auto_connect_streams=False):
+        # type: (Optional[float], Union[bool, dict]) -> Logger
         """
         get a logger object for reporting based on the task

@@ -2288,10 +2307,15 @@ class Task(_Task):
             # do not recreate logger after task was closed/quit
             if self._at_exit_called:
                 raise ValueError("Cannot use Task Logger after task was closed")
-            # force update of base logger to this current task (this is the main logger task)
-            self._setup_log(replace_existing=self.is_main_task())
             # Get a logger object
-            self._logger = Logger(private_task=self)
+            self._logger = Logger(
+                private_task=self,
+                connect_stdout=(auto_connect_streams is True) or
+                               (isinstance(auto_connect_streams, dict) and auto_connect_streams.get('stdout', False)),
+                connect_stderr=(auto_connect_streams is True) or
+                               (isinstance(auto_connect_streams, dict) and auto_connect_streams.get('stderr', False)),
+                connect_logging=isinstance(auto_connect_streams, dict) and auto_connect_streams.get('logging', False),
+            )
             # make sure we set our reporter to async mode
             # we make sure we flush it in self._at_exit
             self._reporter.async_enable = True
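The translation from the single auto_connect_streams argument to the three Logger flags is plain boolean/dict resolution; restated as a standalone helper for clarity (a hypothetical function mirroring the expressions above):

```python
def resolve_streams(auto_connect_streams):
    """Mirror Task._get_logger's resolution of auto_connect_streams (sketch)."""
    as_dict = isinstance(auto_connect_streams, dict)
    return dict(
        connect_stdout=(auto_connect_streams is True) or
                       (as_dict and auto_connect_streams.get('stdout', False)),
        connect_stderr=(auto_connect_streams is True) or
                       (as_dict and auto_connect_streams.get('stderr', False)),
        connect_logging=as_dict and auto_connect_streams.get('logging', False),
    )

assert resolve_streams(True) == dict(connect_stdout=True, connect_stderr=True, connect_logging=False)
assert resolve_streams({}) == dict(connect_stdout=False, connect_stderr=False, connect_logging=False)
assert resolve_streams({'logging': True})['connect_logging'] is True
```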