Black formatting

This commit is contained in:
allegroai 2024-01-20 15:18:30 +02:00
parent e7ec688022
commit 0806902f8e

View File

@ -26,7 +26,7 @@ from .storage.helper import StorageHelper
from .utilities.plotly_reporter import SeriesInfo from .utilities.plotly_reporter import SeriesInfo
# Make sure that DeprecationWarning within this package always gets printed # Make sure that DeprecationWarning within this package always gets printed
warnings.filterwarnings('always', category=DeprecationWarning, module=__name__) warnings.filterwarnings("always", category=DeprecationWarning, module=__name__)
if TYPE_CHECKING: if TYPE_CHECKING:
@ -56,9 +56,10 @@ class Logger(object):
""" """
SeriesInfo = SeriesInfo SeriesInfo = SeriesInfo
_tensorboard_logging_auto_group_scalars = False _tensorboard_logging_auto_group_scalars = False
_tensorboard_single_series_per_graph = deferred_config('metrics.tensorboard_single_series_per_graph', False) _tensorboard_single_series_per_graph = deferred_config("metrics.tensorboard_single_series_per_graph", False)
def __init__(self, private_task, connect_stdout=True, connect_stderr=True, connect_logging=False): def __init__(self, private_task, connect_stdout=True, connect_stderr=True, connect_logging=False):
""" """
@ -66,8 +67,9 @@ class Logger(object):
**Do not construct Logger manually!** **Do not construct Logger manually!**
Please use :meth:`Logger.get_current` Please use :meth:`Logger.get_current`
""" """
assert isinstance(private_task, _Task), \ assert isinstance(
'Logger object cannot be instantiated externally, use Logger.current_logger()' private_task, _Task
), "Logger object cannot be instantiated externally, use Logger.current_logger()"
super(Logger, self).__init__() super(Logger, self).__init__()
self._task = private_task self._task = private_task
self._default_upload_destination = None self._default_upload_destination = None
@ -75,16 +77,19 @@ class Logger(object):
self._report_worker = None self._report_worker = None
self._graph_titles = {} self._graph_titles = {}
self._tensorboard_series_force_prefix = None self._tensorboard_series_force_prefix = None
self._task_handler = TaskHandler(task=self._task, capacity=100) \ self._task_handler = (
if private_task.is_main_task() or (connect_stdout or connect_stderr or connect_logging) else None TaskHandler(task=self._task, capacity=100)
if private_task.is_main_task() or (connect_stdout or connect_stderr or connect_logging)
else None
)
self._connect_std_streams = connect_stdout or connect_stderr self._connect_std_streams = connect_stdout or connect_stderr
self._connect_logging = connect_logging self._connect_logging = connect_logging
self._default_max_sample_history = None self._default_max_sample_history = None
# Make sure urllib is never in debug/info, # Make sure urllib is never in debug/info,
disable_urllib3_info = config.get('log.disable_urllib3_info', True) disable_urllib3_info = config.get("log.disable_urllib3_info", True)
if disable_urllib3_info and logging.getLogger('urllib3').isEnabledFor(logging.INFO): if disable_urllib3_info and logging.getLogger("urllib3").isEnabledFor(logging.INFO):
logging.getLogger('urllib3').setLevel(logging.WARNING) logging.getLogger("urllib3").setLevel(logging.WARNING)
if self._task.is_main_task(): if self._task.is_main_task():
StdStreamPatch.patch_std_streams(self, connect_stdout=connect_stdout, connect_stderr=connect_stderr) StdStreamPatch.patch_std_streams(self, connect_stdout=connect_stdout, connect_stderr=connect_stderr)
@ -112,6 +117,7 @@ class Logger(object):
:return: The Logger object (a singleton) for the current running Task. :return: The Logger object (a singleton) for the current running Task.
""" """
from .task import Task from .task import Task
task = Task.current_task() task = Task.current_task()
if not task: if not task:
return None return None
@ -181,7 +187,7 @@ class Logger(object):
:param name: Metric's name :param name: Metric's name
:param value: Metric's value :param value: Metric's value
""" """
return self.report_scalar(title="Summary", series=name, value=value, iteration=-2**31) return self.report_scalar(title="Summary", series=name, value=value, iteration=-(2**31))
def report_vector( def report_vector(
self, self,
@ -229,13 +235,22 @@ class Logger(object):
example: extra_layout={'showlegend': False, 'plot_bgcolor': 'yellow'} example: extra_layout={'showlegend': False, 'plot_bgcolor': 'yellow'}
""" """
warnings.warn( warnings.warn(
":meth:`Logger.report_vector` is deprecated;" ":meth:`Logger.report_vector` is deprecated; use :meth:`Logger.report_histogram` instead.",
"use :meth:`Logger.report_histogram` instead.", DeprecationWarning,
DeprecationWarning
) )
self._touch_title_series(title, series) self._touch_title_series(title, series)
return self.report_histogram(title, series, values, iteration or 0, labels=labels, xlabels=xlabels, return self.report_histogram(
xaxis=xaxis, yaxis=yaxis, mode=mode, extra_layout=extra_layout) title,
series,
values,
iteration or 0,
labels=labels,
xlabels=xlabels,
xaxis=xaxis,
yaxis=yaxis,
mode=mode,
extra_layout=extra_layout
)
def report_histogram( def report_histogram(
self, self,
@ -301,7 +316,7 @@ class Logger(object):
xlabels=xlabels, xlabels=xlabels,
xtitle=xaxis, xtitle=xaxis,
ytitle=yaxis, ytitle=yaxis,
mode=mode or 'group', mode=mode or "group",
data_args=data_args, data_args=data_args,
layout_config=extra_layout, layout_config=extra_layout,
) )
@ -374,10 +389,7 @@ class Logger(object):
) )
""" """
mutually_exclusive( mutually_exclusive(UsageError, _check_none=True, table_plot=table_plot, csv=csv, url=url)
UsageError, _check_none=True,
table_plot=table_plot, csv=csv, url=url
)
table = table_plot table = table_plot
if url or csv: if url or csv:
if not pd: if not pd:
@ -417,7 +429,7 @@ class Logger(object):
series, # type: Sequence[SeriesInfo] series, # type: Sequence[SeriesInfo]
xaxis, # type: str xaxis, # type: str
yaxis, # type: str yaxis, # type: str
mode='lines', # type: str mode="lines", # type: str
iteration=None, # type: Optional[int] iteration=None, # type: Optional[int]
reverse_xaxis=False, # type: bool reverse_xaxis=False, # type: bool
comment=None, # type: Optional[str] comment=None, # type: Optional[str]
@ -464,7 +476,7 @@ class Logger(object):
# if task was not started, we have to start it # if task was not started, we have to start it
self._start_task_if_needed() self._start_task_if_needed()
self._touch_title_series(title, series[0].name if series else '') self._touch_title_series(title, series[0].name if series else "")
# noinspection PyProtectedMember # noinspection PyProtectedMember
return self._task._reporter.report_line_plot( return self._task._reporter.report_line_plot(
title=title, title=title,
@ -487,7 +499,7 @@ class Logger(object):
xaxis=None, # type: Optional[str] xaxis=None, # type: Optional[str]
yaxis=None, # type: Optional[str] yaxis=None, # type: Optional[str]
labels=None, # type: Optional[List[str]] labels=None, # type: Optional[List[str]]
mode='lines', # type: str mode="lines", # type: str
comment=None, # type: Optional[str] comment=None, # type: Optional[str]
extra_layout=None, # type: Optional[dict] extra_layout=None, # type: Optional[dict]
): ):
@ -567,7 +579,7 @@ class Logger(object):
yaxis=None, # type: Optional[str] yaxis=None, # type: Optional[str]
zaxis=None, # type: Optional[str] zaxis=None, # type: Optional[str]
labels=None, # type: Optional[List[str]] labels=None, # type: Optional[List[str]]
mode='markers', # type: str mode="markers", # type: str
fill=False, # type: bool fill=False, # type: bool
comment=None, # type: Optional[str] comment=None, # type: Optional[str]
extra_layout=None, # type: Optional[dict] extra_layout=None, # type: Optional[dict]
@ -605,16 +617,9 @@ class Logger(object):
example: extra_layout={'xaxis': {'type': 'date', 'range': ['2020-01-01', '2020-01-31']}} example: extra_layout={'xaxis': {'type': 'date', 'range': ['2020-01-01', '2020-01-31']}}
""" """
# check if multiple series # check if multiple series
multi_series = ( multi_series = isinstance(scatter, list) and (
isinstance(scatter, list)
and (
isinstance(scatter[0], np.ndarray) isinstance(scatter[0], np.ndarray)
or ( or (scatter[0] and isinstance(scatter[0], list) and isinstance(scatter[0][0], list))
scatter[0]
and isinstance(scatter[0], list)
and isinstance(scatter[0][0], list)
)
)
) )
if not multi_series: if not multi_series:
@ -690,7 +695,7 @@ class Logger(object):
matrix = np.array(matrix) matrix = np.array(matrix)
if extra_layout is None: if extra_layout is None:
extra_layout = {'texttemplate': '%{z}'} extra_layout = {"texttemplate": "%{z}"}
# if task was not started, we have to start it # if task was not started, we have to start it
self._start_task_if_needed() self._start_task_if_needed()
@ -744,15 +749,22 @@ class Logger(object):
example: extra_layout={'xaxis': {'type': 'date', 'range': ['2020-01-01', '2020-01-31']}} example: extra_layout={'xaxis': {'type': 'date', 'range': ['2020-01-01', '2020-01-31']}}
""" """
warnings.warn( warnings.warn(
":meth:`Logger.report_matrix` is deprecated;" ":meth:`Logger.report_matrix` is deprecated;" "use :meth:`Logger.report_confusion_matrix` instead.",
"use :meth:`Logger.report_confusion_matrix` instead.",
DeprecationWarning DeprecationWarning
) )
self._touch_title_series(title, series) self._touch_title_series(title, series)
return self.report_confusion_matrix(title, series, matrix, iteration or 0, return self.report_confusion_matrix(
xaxis=xaxis, yaxis=yaxis, xlabels=xlabels, ylabels=ylabels, title,
series,
matrix,
iteration or 0,
xaxis=xaxis,
yaxis=yaxis,
xlabels=xlabels,
ylabels=ylabels,
yaxis_reversed=yaxis_reversed, yaxis_reversed=yaxis_reversed,
extra_layout=extra_layout) extra_layout=extra_layout
)
def report_surface( def report_surface(
self, self,
@ -830,7 +842,7 @@ class Logger(object):
matrix=None, # type: Optional[np.ndarray] matrix=None, # type: Optional[np.ndarray]
max_image_history=None, # type: Optional[int] max_image_history=None, # type: Optional[int]
delete_after_upload=False, # type: bool delete_after_upload=False, # type: bool
url=None # type: Optional[str] url=None, # type: Optional[str]
): ):
""" """
For explicit reporting, report an image and upload its contents. For explicit reporting, report an image and upload its contents.
@ -875,8 +887,7 @@ class Logger(object):
- ``False`` - Do not delete after upload. (default) - ``False`` - Do not delete after upload. (default)
""" """
mutually_exclusive( mutually_exclusive(
UsageError, _check_none=True, UsageError, _check_none=True, local_path=local_path or None, url=url or None, image=image, matrix=matrix
local_path=local_path or None, url=url or None, image=image, matrix=matrix
) )
if matrix is not None: if matrix is not None:
warnings.warn("'matrix' variable is deprecated; use 'image' instead.", DeprecationWarning) warnings.warn("'matrix' variable is deprecated; use 'image' instead.", DeprecationWarning)
@ -902,7 +913,7 @@ class Logger(object):
else: else:
upload_uri = self.get_default_upload_destination() upload_uri = self.get_default_upload_destination()
if not upload_uri: if not upload_uri:
upload_uri = Path(get_cache_dir()) / 'debug_images' upload_uri = Path(get_cache_dir()) / "debug_images"
upload_uri.mkdir(parents=True, exist_ok=True) upload_uri.mkdir(parents=True, exist_ok=True)
# Verify that we can upload to this destination # Verify that we can upload to this destination
upload_uri = str(upload_uri) upload_uri = str(upload_uri)
@ -934,7 +945,7 @@ class Logger(object):
file_extension=None, # type: Optional[str] file_extension=None, # type: Optional[str]
max_history=None, # type: Optional[int] max_history=None, # type: Optional[int]
delete_after_upload=False, # type: bool delete_after_upload=False, # type: bool
url=None # type: Optional[str] url=None, # type: Optional[str]
): ):
""" """
Report media upload its contents, including images, audio, and video. Report media upload its contents, including images, audio, and video.
@ -966,8 +977,11 @@ class Logger(object):
""" """
mutually_exclusive( mutually_exclusive(
UsageError, _check_none=True, UsageError,
local_path=local_path or None, url=url or None, stream=stream, _check_none=True,
local_path=local_path or None,
url=url or None,
stream=stream
) )
if stream is not None and not file_extension: if stream is not None and not file_extension:
raise ValueError("No file extension provided for stream media upload") raise ValueError("No file extension provided for stream media upload")
@ -989,7 +1003,7 @@ class Logger(object):
else: else:
upload_uri = self.get_default_upload_destination() upload_uri = self.get_default_upload_destination()
if not upload_uri: if not upload_uri:
upload_uri = Path(get_cache_dir()) / 'debug_images' upload_uri = Path(get_cache_dir()) / "debug_images"
upload_uri.mkdir(parents=True, exist_ok=True) upload_uri.mkdir(parents=True, exist_ok=True)
# Verify that we can upload to this destination # Verify that we can upload to this destination
upload_uri = str(upload_uri) upload_uri = str(upload_uri)
@ -1033,7 +1047,7 @@ class Logger(object):
plot = figure if isinstance(figure, dict) else figure.to_plotly_json() plot = figure if isinstance(figure, dict) else figure.to_plotly_json()
# noinspection PyBroadException # noinspection PyBroadException
try: try:
plot['layout']['title'] = series plot["layout"]["title"] = series
except Exception: except Exception:
pass pass
# noinspection PyProtectedMember # noinspection PyProtectedMember
@ -1078,8 +1092,7 @@ class Logger(object):
figure=figure, figure=figure,
iter=iteration or 0, iter=iteration or 0,
logger=self, logger=self,
force_save_as_image=False if report_interactive and not report_image force_save_as_image=False if report_interactive and not report_image else ("png" if report_image else True),
else ('png' if report_image else True),
) )
def set_default_upload_destination(self, uri): def set_default_upload_destination(self, uri):
@ -1208,14 +1221,21 @@ class Logger(object):
path=None, # type: Optional[str] path=None, # type: Optional[str]
matrix=None, # type: Optional[Union[np.ndarray, Image.Image]] matrix=None, # type: Optional[Union[np.ndarray, Image.Image]]
max_image_history=None, # type: Optional[int] max_image_history=None, # type: Optional[int]
delete_after_upload=False # type: bool delete_after_upload=False, # type: bool
): ):
""" """
.. deprecated:: 0.13.0 .. deprecated:: 0.13.0
Use :meth:`Logger.report_image` instead Use :meth:`Logger.report_image` instead
""" """
self.report_image(title=title, series=series, iteration=iteration or 0, local_path=path, image=matrix, self.report_image(
max_image_history=max_image_history, delete_after_upload=delete_after_upload) title=title,
series=series,
iteration=iteration or 0,
local_path=path,
image=matrix,
max_image_history=max_image_history,
delete_after_upload=delete_after_upload
)
def capture_logging(self): def capture_logging(self):
# type: () -> "_LoggingContext" # type: () -> "_LoggingContext"
@ -1224,6 +1244,7 @@ class Logger(object):
:return: a ContextManager :return: a ContextManager
""" """
class _LoggingContext(object): class _LoggingContext(object):
def __init__(self, a_logger): def __init__(self, a_logger):
self.logger = a_logger self.logger = a_logger
@ -1285,6 +1306,7 @@ class Logger(object):
:param force: If True, all matplotlib figures are converted automatically to non-interactive plots. :param force: If True, all matplotlib figures are converted automatically to non-interactive plots.
""" """
from clearml.backend_interface.metrics import Reporter from clearml.backend_interface.metrics import Reporter
Reporter.matplotlib_force_report_non_interactive(force=force) Reporter.matplotlib_force_report_non_interactive(force=force)
@classmethod @classmethod
@ -1327,8 +1349,7 @@ class Logger(object):
try: try:
return int(level) return int(level)
except (TypeError, ValueError): except (TypeError, ValueError):
self._task.log.log(level=logging.ERROR, self._task.log.log(level=logging.ERROR, msg='Logger failed casting log level "%s" to integer' % str(level))
msg='Logger failed casting log level "%s" to integer' % str(level))
return logging.INFO return logging.INFO
def _console(self, msg, level=logging.INFO, omit_console=False, force_send=False, *args, **_): def _console(self, msg, level=logging.INFO, omit_console=False, force_send=False, *args, **_):
@ -1356,7 +1377,7 @@ class Logger(object):
# noinspection PyBroadException # noinspection PyBroadException
try: try:
record = self._task.log.makeRecord( record = self._task.log.makeRecord(
"console", level=level, fn='', lno=0, func='', msg=msg, args=args, exc_info=None "console", level=level, fn="", lno=0, func="", msg=msg, args=args, exc_info=None
) )
# find the task handler that matches our task # find the task handler that matches our task
self._task_handler.emit(record) self._task_handler.emit(record)
@ -1366,7 +1387,8 @@ class Logger(object):
try: try:
# make sure we are writing to the original stdout # make sure we are writing to the original stdout
StdStreamPatch.stderr_original_write( StdStreamPatch.stderr_original_write(
'clearml.Logger failed sending log [level {}]: "{}"\n'.format(level, msg)) 'clearml.Logger failed sending log [level {}]: "{}"\n'.format(level, msg)
)
except Exception: except Exception:
pass pass
else: else:
@ -1379,7 +1401,7 @@ class Logger(object):
# noinspection PyBroadException # noinspection PyBroadException
try: try:
# make sure we are writing to the original stdout # make sure we are writing to the original stdout
StdStreamPatch.stdout_original_write(str(msg) + '\n') StdStreamPatch.stdout_original_write(str(msg) + "\n")
except Exception: except Exception:
pass pass
else: else:
@ -1396,7 +1418,7 @@ class Logger(object):
path=None, # type: Optional[str] path=None, # type: Optional[str]
matrix=None, # type: Optional[np.ndarray] matrix=None, # type: Optional[np.ndarray]
max_image_history=None, # type: Optional[int] max_image_history=None, # type: Optional[int]
delete_after_upload=False # type: bool delete_after_upload=False, # type: bool
): ):
""" """
Report an image, upload its contents, and present in plots section using plotly Report an image, upload its contents, and present in plots section using plotly
@ -1418,7 +1440,7 @@ class Logger(object):
self._start_task_if_needed() self._start_task_if_needed()
upload_uri = self.get_default_upload_destination() upload_uri = self.get_default_upload_destination()
if not upload_uri: if not upload_uri:
upload_uri = Path(get_cache_dir()) / 'debug_images' upload_uri = Path(get_cache_dir()) / "debug_images"
upload_uri.mkdir(parents=True, exist_ok=True) upload_uri.mkdir(parents=True, exist_ok=True)
# Verify that we can upload to this destination # Verify that we can upload to this destination
upload_uri = str(upload_uri) upload_uri = str(upload_uri)
@ -1444,7 +1466,7 @@ class Logger(object):
iteration=None, # type: Optional[int] iteration=None, # type: Optional[int]
path=None, # type: Optional[str] path=None, # type: Optional[str]
max_file_history=None, # type: Optional[int] max_file_history=None, # type: Optional[int]
delete_after_upload=False # type: bool delete_after_upload=False, # type: bool
): ):
""" """
Upload a file and report it as link in the debug images section. Upload a file and report it as link in the debug images section.
@ -1465,7 +1487,7 @@ class Logger(object):
self._start_task_if_needed() self._start_task_if_needed()
upload_uri = self.get_default_upload_destination() upload_uri = self.get_default_upload_destination()
if not upload_uri: if not upload_uri:
upload_uri = Path(get_cache_dir()) / 'debug_images' upload_uri = Path(get_cache_dir()) / "debug_images"
upload_uri.mkdir(parents=True, exist_ok=True) upload_uri.mkdir(parents=True, exist_ok=True)
# Verify that we can upload to this destination # Verify that we can upload to this destination
upload_uri = str(upload_uri) upload_uri = str(upload_uri)