clearml/trains/logger.py

772 lines
28 KiB
Python

import logging
import re
import sys
import threading
from functools import wraps
import numpy as np
from pathlib2 import Path
from .debugging.log import LoggerRoot
from .backend_interface.task.development.worker import DevWorker
from .backend_interface.task.log import TaskHandler
from .storage import StorageHelper
from .utilities.plotly_reporter import SeriesInfo
from .backend_api.services import tasks
from .backend_interface.task import Task as _Task
from .config import running_remotely, get_cache_dir
def _safe_names(func):
"""
Validate the form of title and series parameters.
This decorator assert that a method receives 'title' and 'series' as its
first positional arguments, and that their values have only legal characters.
'\', '/' and ':' will be replaced automatically by '_'
Whitespace chars will be replaced automatically by ' '
"""
_replacements = {
'_': re.compile(r"[/\\:]"),
' ': re.compile(r"[\s]"),
}
def _make_safe(value):
for repl, regex in _replacements.items():
value = regex.sub(repl, value)
return value
@wraps(func)
def fixed_names(self, title, series, *args, **kwargs):
title = _make_safe(title)
series = _make_safe(series)
func(self, title, series, *args, **kwargs)
return fixed_names
class Logger(object):
    """
    Console log and metric statistics interface.

    This is how we send graphs/plots/text to the system, later we can compare the performance of different tasks.

    **Usage: Task.get_logger()**
    """
    SeriesInfo = SeriesInfo
    # stdout/stderr proxies shared by all Logger instances; installed by the first Logger that patches the streams
    _stdout_proxy = None
    _stderr_proxy = None
    # stdout.write captured before patching, so console() can print without going through the proxy
    _stdout_original_write = None

    def __init__(self, private_task):
        """
        **Do not construct Logger manually!**
        please use Task.get_logger()
        """
        assert isinstance(private_task, _Task), \
            'Logger object cannot be instantiated externally, use Task.get_logger()'
        super(Logger, self).__init__()
        self._task = private_task
        self._default_upload_destination = None
        self._flusher = None
        self._report_worker = None
        self._task_handler = None

        if DevWorker.report_stdout and not PrintPatchLogger.patched and not running_remotely():
            # first Logger in this process: wrap stdout/stderr with proxies that forward lines to us
            Logger._stdout_proxy = PrintPatchLogger(sys.stdout, self, level=logging.INFO)
            Logger._stderr_proxy = PrintPatchLogger(sys.stderr, self, level=logging.ERROR)
            self._task_handler = TaskHandler(self._task.session, self._task.id, capacity=100)
            # noinspection PyBroadException
            try:
                Logger._stdout_original_write = sys.stdout.write
                # this will only work in python 3, guard it with try/catch
                sys.stdout._original_write = sys.stdout.write
                sys.stdout.write = stdout__patched__write__
                sys.stderr._original_write = sys.stderr.write
                sys.stderr.write = stderr__patched__write__
            except Exception:
                pass
            sys.stdout = Logger._stdout_proxy
            sys.stderr = Logger._stderr_proxy
        elif DevWorker.report_stdout and not running_remotely():
            # proxies were already installed by an earlier Logger: redirect them to this one
            self._task_handler = TaskHandler(self._task.session, self._task.id, capacity=100)
            if Logger._stdout_proxy:
                Logger._stdout_proxy.connect(self)
            if Logger._stderr_proxy:
                Logger._stderr_proxy.connect(self)

    def console(self, msg, level=logging.INFO, omit_console=False, *args, **kwargs):
        """
        Print text to the log (sent to the backend) and, unless omitted, to the console.

        :param msg: text to print to the console (always sent to the backend and displayed in console)
        :param level: logging level, default: logging.INFO
        :param omit_console: If True we only send 'msg' to log (no console print)
        """
        try:
            level = int(level)
        except (TypeError, ValueError):
            self._task.log.log(level=logging.ERROR,
                               msg='Logger failed casting log level "%s" to integer' % str(level))
            level = logging.INFO

        # noinspection PyBroadException
        try:
            record = self._task.log.makeRecord(
                "console", level=level, fn='', lno=0, func='', msg=msg, args=args, exc_info=None
            )
            # find the task handler
            if not self._task_handler:
                self._task_handler = [
                    h for h in LoggerRoot.get_base_logger().handlers if isinstance(h, TaskHandler)
                ][0]
            self._task_handler.emit(record)
        except Exception:
            self._task.log.log(level=logging.ERROR,
                               msg='Logger failed sending log: [level %s]: "%s"' % (str(level), str(msg)))

        if not omit_console:
            # If we grabbed stdout, write to the real stream directly: print() would go through the proxy
            # and be logged a second time. If stdout was never patched (no saved original write),
            # a plain print is safe and avoids silently dropping the message.
            if DevWorker.report_stdout and Logger._stdout_original_write:
                # noinspection PyBroadException
                try:
                    # make sure we are writing to the original stdout
                    Logger._stdout_original_write(str(msg) + '\n')
                except Exception:
                    pass
            else:
                print(str(msg))

        # if task was not started, we have to start it
        self._start_task_if_needed()

    def report_text(self, msg, level=logging.INFO, print_console=False, *args, **_):
        """
        Report text to the log; same as console() with the console print inverted into ``print_console``.

        :param msg: text to report
        :param level: logging level, default: logging.INFO
        :param print_console: If True also print 'msg' to the console
        """
        return self.console(msg, level, not print_console, *args, **_)

    def debug(self, msg, *args, **kwargs):
        """ Print information to the log. This is the same as console(msg, logging.DEBUG) """
        # level and msg are passed positionally so extra format *args cannot collide with them
        self._task.log.log(logging.DEBUG, msg, *args, **kwargs)

    def info(self, msg, *args, **kwargs):
        """ Print information to the log. This is the same as console(msg, logging.INFO) """
        self._task.log.log(logging.INFO, msg, *args, **kwargs)

    def warn(self, msg, *args, **kwargs):
        """ Print a warning to the log. This is the same as console(msg, logging.WARNING) """
        self._task.log.log(logging.WARNING, msg, *args, **kwargs)

    warning = warn

    def error(self, msg, *args, **kwargs):
        """ Print an error to the log. This is the same as console(msg, logging.ERROR) """
        self._task.log.log(logging.ERROR, msg, *args, **kwargs)

    def fatal(self, msg, *args, **kwargs):
        """ Print a fatal error to the log. This is the same as console(msg, logging.FATAL) """
        self._task.log.log(logging.FATAL, msg, *args, **kwargs)

    def critical(self, msg, *args, **kwargs):
        """ Print a critical error to the log. This is the same as console(msg, logging.CRITICAL) """
        self._task.log.log(logging.CRITICAL, msg, *args, **kwargs)

    def report_scalar(self, title, series, value, iteration):
        """
        Report a scalar value

        :param title: Title (AKA metric)
        :type title: str
        :param series: Series (AKA variant)
        :type series: str
        :param value: Reported value (converted with float())
        :type value: float
        :param iteration: Iteration number
        :type iteration: int
        """
        # if task was not started, we have to start it
        self._start_task_if_needed()
        return self._task.reporter.report_scalar(title=title, series=series, value=float(value), iter=iteration)

    def report_vector(self, title, series, values, iteration, labels=None, xlabels=None):
        """
        Report a histogram plot (alias of report_histogram)

        :param title: Title (AKA metric)
        :type title: str
        :param series: Series (AKA variant)
        :type series: str
        :param values: Reported values (or numpy array)
        :type values: [float]
        :param iteration: Iteration number
        :type iteration: int
        :param labels: optional, labels for each bar group.
        :type labels: list of strings.
        :param xlabels: optional label per entry in the vector (bucket in the histogram)
        :type xlabels: list of strings.
        """
        return self.report_histogram(title, series, values, iteration, labels=labels, xlabels=xlabels)

    def report_histogram(self, title, series, values, iteration, labels=None, xlabels=None):
        """
        Report a histogram plot

        :param title: Title (AKA metric)
        :type title: str
        :param series: Series (AKA variant)
        :type series: str
        :param values: Reported values (or numpy array)
        :type values: [float]
        :param iteration: Iteration number
        :type iteration: int
        :param labels: optional, labels for each bar group.
        :type labels: list of strings.
        :param xlabels: optional label per entry in the vector (bucket in the histogram)
        :type xlabels: list of strings.
        """
        if not isinstance(values, np.ndarray):
            values = np.array(values)

        # if task was not started, we have to start it
        self._start_task_if_needed()
        return self._task.reporter.report_histogram(
            title=title,
            series=series,
            histogram=values,
            iter=iteration,
            labels=labels,
            xlabels=xlabels,
        )

    def report_line_plot(self, title, series, iteration, xaxis, yaxis, mode='lines', reverse_xaxis=False, comment=None):
        """
        Report a (possibly multiple) line plot.

        :param title: Title (AKA metric)
        :type title: str
        :param series: All the series' data, one for each line in the plot.
            dict entries are converted with ``Logger.SeriesInfo(**entry)``.
        :type series: An iterable of LineSeriesInfo.
        :param iteration: Iteration number
        :type iteration: int
        :param xaxis: x-axis title
        :param yaxis: y-axis title
        :param mode: scatter plot with 'lines'/'markers'/'lines+markers'
        :type mode: str
        :param reverse_xaxis: If true X axis will be displayed from high to low (reversed)
        :type reverse_xaxis: bool
        :param comment: comment underneath the title
        :type comment: str
        """
        series = [self.SeriesInfo(**s) if isinstance(s, dict) else s for s in series]

        # if task was not started, we have to start it
        self._start_task_if_needed()
        return self._task.reporter.report_line_plot(
            title=title,
            series=series,
            iter=iteration,
            xtitle=xaxis,
            ytitle=yaxis,
            mode=mode,
            reverse_xaxis=reverse_xaxis,
            comment=comment,
        )

    def report_scatter2d(self, title, series, scatter, iteration, xaxis=None, yaxis=None, labels=None,
                         mode='lines', comment=None):
        """
        Report a 2d scatter graph (with lines)

        :param title: Title (AKA metric)
        :type title: str
        :param series: Series (AKA variant)
        :type series: str
        :param scatter: A scattered data: list of (pairs of x,y) (or numpy array)
        :type scatter: ndarray or list
        :param iteration: Iteration number
        :type iteration: int
        :param xaxis: optional x-axis title
        :param yaxis: optional y-axis title
        :param labels: label (text) per point in the scatter (in the same order)
        :param mode: scatter plot with 'lines'/'markers'/'lines+markers'
        :type mode: str
        :param comment: comment underneath the title
        :type comment: str
        """
        if not isinstance(scatter, np.ndarray):
            if not isinstance(scatter, list):
                scatter = list(scatter)
            scatter = np.array(scatter)

        # if task was not started, we have to start it
        self._start_task_if_needed()
        return self._task.reporter.report_2d_scatter(
            title=title,
            series=series,
            data=scatter.astype(np.float32),
            iter=iteration,
            mode=mode,
            xtitle=xaxis,
            ytitle=yaxis,
            labels=labels,
            comment=comment,
        )

    def report_scatter3d(self, title, series, scatter, iteration, labels=None, mode='markers',
                         fill=False, comment=None):
        """
        Report a 3d scatter graph (with markers)

        :param title: Title (AKA metric)
        :type title: str
        :param series: Series (AKA variant)
        :type series: str
        :param scatter: A scattered data: list of (pairs of x,y,z) (or numpy array) or list of series [[(x1,y1,z1)...]]
        :type scatter: ndarray or list
        :param iteration: Iteration number
        :type iteration: int
        :param labels: label (text) per point in the scatter (in the same order)
        :param mode: scatter plot with 'lines'/'markers'/'lines+markers'
        :param fill: fill area under the curve
        :param comment: comment underneath the title
        """
        # check if multiple series (a non-empty list whose first item is an array or a list of points);
        # the length check keeps an empty list from raising IndexError
        multi_series = (
            isinstance(scatter, list)
            and len(scatter) > 0
            and (
                isinstance(scatter[0], np.ndarray)
                or (
                    scatter[0]
                    and isinstance(scatter[0], list)
                    and isinstance(scatter[0][0], list)
                )
            )
        )

        if not multi_series:
            if not isinstance(scatter, np.ndarray):
                if not isinstance(scatter, list):
                    scatter = list(scatter)
                scatter = np.array(scatter)
            try:
                scatter = scatter.astype(np.float32)
            except ValueError:
                pass

        # if task was not started, we have to start it
        self._start_task_if_needed()
        return self._task.reporter.report_3d_scatter(
            title=title,
            series=series,
            data=scatter,
            iter=iteration,
            labels=labels,
            mode=mode,
            fill=fill,
            comment=comment,
        )

    def report_confusion_matrix(self, title, series, matrix, iteration, xlabels=None, ylabels=None, comment=None):
        """
        Report a heat-map matrix

        :param title: Title (AKA metric)
        :type title: str
        :param series: Series (AKA variant)
        :type series: str
        :param matrix: A heat-map matrix (example: confusion matrix)
        :type matrix: ndarray
        :param iteration: Iteration number
        :type iteration: int
        :param xlabels: optional label per column of the matrix
        :param ylabels: optional label per row of the matrix
        :param comment: comment underneath the title
        """
        if not isinstance(matrix, np.ndarray):
            matrix = np.array(matrix)

        # if task was not started, we have to start it
        self._start_task_if_needed()
        return self._task.reporter.report_value_matrix(
            title=title,
            series=series,
            data=matrix.astype(np.float32),
            iter=iteration,
            xlabels=xlabels,
            ylabels=ylabels,
            comment=comment,
        )

    def report_matrix(self, title, series, matrix, iteration, xlabels=None, ylabels=None, comment=None):
        """
        Same as report_confusion_matrix
        Report a heat-map matrix

        :param title: Title (AKA metric)
        :type title: str
        :param series: Series (AKA variant)
        :type series: str
        :param matrix: A heat-map matrix (example: confusion matrix)
        :type matrix: ndarray
        :param iteration: Iteration number
        :type iteration: int
        :param xlabels: optional label per column of the matrix
        :param ylabels: optional label per row of the matrix
        :param comment: optional comment underneath the title
        """
        return self.report_confusion_matrix(title, series, matrix, iteration,
                                            xlabels=xlabels, ylabels=ylabels, comment=comment)

    def report_surface(self, title, series, matrix, iteration, xlabels=None, ylabels=None,
                       xtitle=None, ytitle=None, camera=None, comment=None):
        """
        Report a 3d surface (same data as heat-map matrix, only presented differently)

        :param title: Title (AKA metric)
        :type title: str
        :param series: Series (AKA variant)
        :type series: str
        :param matrix: A heat-map matrix (example: confusion matrix)
        :type matrix: ndarray
        :param iteration: Iteration number
        :type iteration: int
        :param xlabels: optional label per column of the matrix
        :param ylabels: optional label per row of the matrix
        :param xtitle: optional x-axis title
        :param ytitle: optional y-axis title
        :param camera: X,Y,Z camera position. def: (1,1,1)
        :param comment: comment underneath the title
        """
        if not isinstance(matrix, np.ndarray):
            matrix = np.array(matrix)

        # if task was not started, we have to start it
        self._start_task_if_needed()
        return self._task.reporter.report_value_surface(
            title=title,
            series=series,
            data=matrix.astype(np.float32),
            iter=iteration,
            xlabels=xlabels,
            ylabels=ylabels,
            xtitle=xtitle,
            ytitle=ytitle,
            camera=camera,
            comment=comment,
        )

    @_safe_names
    def report_image(self, title, series, src, iteration):
        """
        Report an image, and register the 'src' as url content.

        :param title: Title (AKA metric)
        :type title: str
        :param series: Series (AKA variant)
        :type series: str
        :param src: Image source URI. This URI will be used by the webapp and workers when trying to obtain the image \
            for presentation of processing. Currently only http(s), file and s3 schemes are supported.
        :type src: str
        :param iteration: Iteration number
        :type iteration: int
        """
        # if task was not started, we have to start it
        self._start_task_if_needed()
        self._task.reporter.report_image(
            title=title,
            series=series,
            src=src,
            iter=iteration,
        )

    @_safe_names
    def report_image_and_upload(self, title, series, iteration, path=None, matrix=None, max_image_history=None):
        """
        Report an image and upload its contents.

        Image is uploaded to a preconfigured bucket (see setup_upload()) with a key (filename)
        describing the task ID, title, series and iteration.

        :param title: Title (AKA metric)
        :type title: str
        :param series: Series (AKA variant)
        :type series: str
        :param iteration: Iteration number
        :type iteration: int
        :param path: A path to an image file. Required unless matrix is provided.
        :type path: str
        :param matrix: A 3D numpy.ndarray object containing image data (RGB). Required unless filename is provided.
        :type matrix: numpy.ndarray
        :param max_image_history: maximum number of image to store per metric/variant combination \
            use negative value for unlimited. default is set in global configuration (default=5)
        :type max_image_history: int
        """
        # if task was not started, we have to start it
        self._start_task_if_needed()
        upload_uri = self._resolve_upload_uri()

        self._task.reporter.report_image_and_upload(
            title=title,
            series=series,
            path=path,
            matrix=matrix,
            iter=iteration,
            upload_uri=upload_uri,
            max_image_history=max_image_history,
        )

    def report_image_plot_and_upload(self, title, series, iteration, path=None, matrix=None, max_image_history=None):
        """
        Report an image, upload its contents, and present in plots section using plotly

        Image is uploaded to a preconfigured bucket (see setup_upload()) with a key (filename)
        describing the task ID, title, series and iteration.

        :param title: Title (AKA metric)
        :type title: str
        :param series: Series (AKA variant)
        :type series: str
        :param iteration: Iteration number
        :type iteration: int
        :param path: A path to an image file. Required unless matrix is provided.
        :type path: str
        :param matrix: A 3D numpy.ndarray object containing image data (RGB). Required unless filename is provided.
        :type matrix: numpy.ndarray
        :param max_image_history: maximum number of image to store per metric/variant combination \
            use negative value for unlimited. default is set in global configuration (default=5)
        :type max_image_history: int
        """
        # if task was not started, we have to start it
        self._start_task_if_needed()
        upload_uri = self._resolve_upload_uri()

        self._task.reporter.report_image_plot_and_upload(
            title=title,
            series=series,
            path=path,
            matrix=matrix,
            iter=iteration,
            upload_uri=upload_uri,
            max_image_history=max_image_history,
        )

    def set_default_upload_destination(self, uri):
        """
        Set the uri to upload all the debug images to.

        Images are uploaded separately to the destination storage (e.g. s3,gc,file) and then
        a link to the uploaded image is sent in the report
        Notice: credentials for the upload destination will be pooled from the
        global configuration file (i.e. ~/trains.conf)

        The destination is checked with ``StorageHelper.verify_upload`` before it is stored;
        the (possibly normalized) URI it returns is what gets used for later uploads.

        :param uri: example: 's3://bucket/directory/' or 'file:///tmp/debug/'
        :return: None
        """
        # Create the storage helper
        storage = StorageHelper.get(uri)

        # Verify that we can upload to this destination
        uri = storage.verify_upload(folder_uri=uri)

        self._default_upload_destination = uri

    def flush(self):
        """
        Flush cached reports and console outputs to backend.

        :return: True if successful
        """
        self._flush_stdout_handler()
        if self._task:
            return self._task.flush()
        return False

    def get_flush_period(self):
        """
        Return the current periodic-flush period in seconds, or None if no periodic flush is running.
        """
        if self._flusher:
            return self._flusher.period
        return None

    def set_flush_period(self, period):
        """
        Set the period of the logger flush.

        :param period: The period to flush the logger in seconds. If None or 0,
            there will be no periodic flush. For the main task in local development mode
            (stdout reporting enabled), the period is capped at the dev-worker report period,
            and 0 falls back to that report period.
        """
        if self._task.is_main_task() and DevWorker.report_stdout and DevWorker.report_period and \
                not running_remotely() and period is not None:
            period = min(period or DevWorker.report_period, DevWorker.report_period)

        if not period:
            # stop any running periodic flush
            if self._flusher:
                self._flusher.exit()
                self._flusher = None
        elif self._flusher:
            self._flusher.set_period(period)
        else:
            self._flusher = _Flusher(self, period)
            self._flusher.start()

    @classmethod
    def _remove_std_logger(cls):
        """Disconnect the installed stdout/stderr proxies (if any) from their Logger."""
        for stream in (sys.stdout, sys.stderr):
            if isinstance(stream, PrintPatchLogger):
                # noinspection PyBroadException
                try:
                    stream.connect(None)
                except Exception:
                    pass

    def _resolve_upload_uri(self):
        """
        Return the destination URI for uploaded debug images.

        Uses the configured default upload destination, then the task's default report storage URI;
        if neither is set, falls back to a local ``debug_images`` folder under the cache directory,
        which is created and verified through StorageHelper.
        """
        upload_uri = self._default_upload_destination or self._task._get_default_report_storage_uri()
        if not upload_uri:
            upload_uri = Path(get_cache_dir()) / 'debug_images'
            upload_uri.mkdir(parents=True, exist_ok=True)
            # Verify that we can upload to this destination
            upload_uri = str(upload_uri)
            storage = StorageHelper.get(upload_uri)
            upload_uri = storage.verify_upload(folder_uri=upload_uri)
        return upload_uri

    def _start_task_if_needed(self):
        """Mark the task as started (and notify dev mode) if it is still in the 'created' state."""
        if self._task._status == tasks.TaskStatusEnum.created:
            self._task.mark_started()

        self._task._dev_mode_task_start()

    def _flush_stdout_handler(self):
        """Flush buffered console records through the task handler when stdout reporting is enabled."""
        if self._task_handler and DevWorker.report_stdout:
            self._task_handler.flush()
def stdout__patched__write__(*args, **kwargs):
    """
    Replacement for ``sys.stdout.write`` installed by Logger.

    Delegates to the stdout proxy when one exists; otherwise uses the stream's saved original write.
    """
    proxy = Logger._stdout_proxy
    if not proxy:
        return sys.stdout._original_write(*args, **kwargs)
    return proxy.write(*args, **kwargs)
def stderr__patched__write__(*args, **kwargs):
    """
    Replacement for ``sys.stderr.write`` installed by Logger.

    Delegates to the stderr proxy when one exists; otherwise uses the stream's saved original write.
    """
    proxy = Logger._stderr_proxy
    if not proxy:
        return sys.stderr._original_write(*args, **kwargs)
    return proxy.write(*args, **kwargs)
class PrintPatchLogger(object):
    """
    Allowed patching a stream into the logger.
    Used for capturing and logging stdout and stderr when running in development mode pseudo worker.
    """
    patched = False
    # class-level: shared by every proxy instance (i.e. both the stdout and the stderr proxy)
    lock = threading.Lock()
    # held while writing to the terminal / forwarding to Logger.console, so nested writes from the
    # same thread pass straight through instead of being captured again
    recursion_protect_lock = threading.RLock()

    def __init__(self, stream, logger=None, level=logging.INFO):
        PrintPatchLogger.patched = True
        self._terminal = stream
        self._log = logger
        self._log_level = level
        self._cur_line = ''

    def write(self, message):
        """
        Write ``message`` to the wrapped stream and forward each completed line to the connected Logger.

        Text is buffered until a '\\n' (or '\\r') arrives; the completed part is stripped of trailing
        whitespace and sent to ``Logger.console(..., omit_console=True)`` since it was already echoed.

        :param message: text to write
        :return: the wrapped stream's write result (the number of characters written for Python 3
            text streams), so the proxy honours the file-object ``write`` contract
        """
        # make sure that we do not end up in infinite loop (i.e. log.console ends up calling us)
        if not self._log or PrintPatchLogger.recursion_protect_lock._is_owned():
            return self._write_to_terminal(message)

        with self.lock:
            with PrintPatchLogger.recursion_protect_lock:
                written = self._write_to_terminal(message)

            self._cur_line += message
            do_flush = '\n' in message
            if not do_flush and '\r' not in message:
                # no line terminator yet, keep buffering
                return written
            # everything up to the last terminator is complete; the remainder stays buffered
            last_lf = self._cur_line.rindex('\n' if do_flush else '\r')
            cur_line = self._cur_line[:last_lf + 1].rstrip()
            self._cur_line = self._cur_line[last_lf + 1:]

        if cur_line:
            with PrintPatchLogger.recursion_protect_lock:
                self._log.console(cur_line, level=self._log_level, omit_console=True)
        return written

    def _write_to_terminal(self, message):
        """Write to the wrapped stream, bypassing a patched ``write`` if the original was saved on it."""
        if hasattr(self._terminal, '_original_write'):
            return self._terminal._original_write(message)
        return self._terminal.write(message)

    def connect(self, logger):
        """Redirect captured lines to ``logger`` (or stop forwarding with None), flushing the previous one."""
        if self._log:
            self._log._flush_stdout_handler()
        self._log = logger

    def __getattr__(self, attr):
        if attr in ['_log', '_terminal', '_log_level', '_cur_line']:
            return self.__dict__.get(attr)
        # everything else is delegated to the wrapped stream
        return getattr(self._terminal, attr)

    def __setattr__(self, key, value):
        if key in ['_log', '_terminal', '_log_level', '_cur_line']:
            self.__dict__[key] = value
        else:
            return setattr(self._terminal, key, value)
class _Flusher(threading.Thread):
def __init__(self, logger, period, **kwargs):
super(_Flusher, self).__init__(**kwargs)
self.daemon = True
self._period = period
self._logger = logger
self._exit_event = threading.Event()
@property
def period(self):
return self._period
def run(self):
self._logger.flush()
# store original wait period
while True:
period = self._period
while not self._exit_event.wait(period or 1.0):
self._logger.flush()
# check if period is negative or None we should exit
if self._period is None or self._period < 0:
break
# check if period was changed, we should restart
self._exit_event.clear()
def exit(self):
self._period = None
self._exit_event.set()
def set_period(self, period):
self._period = period
# make sure we exit the previous wait
self._exit_event.set()