diff --git a/examples/manual_reporting.py b/examples/manual_reporting.py index 1da9080d..d7e875ff 100644 --- a/examples/manual_reporting.py +++ b/examples/manual_reporting.py @@ -1,15 +1,16 @@ # TRAINS - Example of manual graphs and statistics reporting # +import os from PIL import Image import numpy as np import logging from trains import Task -task = Task.init(project_name='examples', task_name='Manual reporting') +task = Task.init(project_name="examples", task_name="Manual reporting") # standard python logging -logging.info('This is an info message') +logging.info("This is an info message") # this is loguru test example try: @@ -30,7 +31,8 @@ logger.report_scalar("example_scalar", "series A", iteration=1, value=200) # report histogram histogram = np.random.randint(10, size=10) -logger.report_histogram("example_histogram", "random histogram", iteration=1, values=histogram) +logger.report_histogram("example_histogram", "random histogram", iteration=1, values=histogram, + xaxis="title x", yaxis="title y") # report confusion matrix confusion = np.random.randint(10, size=(10, 10)) @@ -38,15 +40,17 @@ logger.report_matrix("example_confusion", "ignored", iteration=1, matrix=confusi # report 3d surface logger.report_surface("example_surface", "series1", iteration=1, matrix=confusion, - xtitle='title X', ytitle='title Y', ztitle='title Z') + xtitle="title X", ytitle="title Y", ztitle="title Z") # report 2d scatter plot scatter2d = np.hstack((np.atleast_2d(np.arange(0, 10)).T, np.random.randint(10, size=(10, 1)))) -logger.report_scatter2d("example_scatter", "series_xy", iteration=1, scatter=scatter2d) +logger.report_scatter2d("example_scatter", "series_xy", iteration=1, scatter=scatter2d, + xaxis="title x", yaxis="title y") # report 3d scatter plot scatter3d = np.random.randint(10, size=(10, 3)) -logger.report_scatter3d("example_scatter_3d", "series_xyz", iteration=1, scatter=scatter3d) +logger.report_scatter3d("example_scatter_3d", "series_xyz", iteration=1, scatter=scatter3d, + xaxis="title x", yaxis="title y", zaxis="title z") # reporting images m = np.eye(256, 256, dtype=np.float) @@ -55,7 +59,7 @@ m = np.eye(256, 256, dtype=np.uint8)*255 logger.report_image("test case", "image uint8", iteration=1, image=m) m = np.concatenate((np.atleast_3d(m), np.zeros((256, 256, 2), dtype=np.uint8)), axis=2) logger.report_image("test case", "image color red", iteration=1, image=m) -image_open = Image.open('./samples/picasso.jpg') +image_open = Image.open(os.path.join("samples", "picasso.jpg")) logger.report_image("test case", "image PIL", iteration=1, image=image_open) # flush reports (otherwise it will be flushed in the background, every couple of seconds) logger.flush() diff --git a/trains/backend_interface/metrics/reporter.py b/trains/backend_interface/metrics/reporter.py index 46af27d0..9b1a196e 100644 --- a/trains/backend_interface/metrics/reporter.py +++ b/trains/backend_interface/metrics/reporter.py @@ -224,7 +224,8 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan delete_after_upload=delete_after_upload, **kwargs) self._report(ev) - def report_histogram(self, title, series, histogram, iter, labels=None, xlabels=None, comment=None): + def report_histogram(self, title, series, histogram, iter, labels=None, xlabels=None, + xtitle=None, ytitle=None, comment=None): """ Report an histogram bar plot :param title: Title (AKA metric) @@ -240,12 +241,16 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan :type labels: list of strings. :param xlabels: The labels of the x axis. :type xlabels: List of strings. + :param str xtitle: optional x-axis title + :param str ytitle: optional y-axis title :param comment: comment underneath the title :type comment: str """ plotly_dict = create_2d_histogram_plot( np_row_wise=histogram, title=title, + xtitle=xtitle, + ytitle=ytitle, labels=labels, series=series, xlabels=xlabels, diff --git a/trains/logger.py b/trains/logger.py index f63afdc1..8f3a9176 100644 --- a/trains/logger.py +++ b/trains/logger.py @@ -87,7 +87,8 @@ class Logger(object): self._touch_title_series(title, series) return self._task.reporter.report_scalar(title=title, series=series, value=float(value), iter=iteration) - def report_vector(self, title, series, values, iteration, labels=None, xlabels=None): + def report_vector(self, title, series, values, iteration, labels=None, xlabels=None, + xaxis=None, yaxis=None): """ Report a histogram plot @@ -97,11 +98,15 @@ class Logger(object): :param int iteration: Iteration number :param list(str) labels: optional, labels for each bar group. :param list(str) xlabels: optional label per entry in the vector (bucket in the histogram) + :param str xaxis: optional x-axis title + :param str yaxis: optional y-axis title """ self._touch_title_series(title, series) - return self.report_histogram(title, series, values, iteration, labels=labels, xlabels=xlabels) + return self.report_histogram(title, series, values, iteration, labels=labels, xlabels=xlabels, + xaxis=xaxis, yaxis=yaxis) - def report_histogram(self, title, series, values, iteration, labels=None, xlabels=None): + def report_histogram(self, title, series, values, iteration, labels=None, xlabels=None, + xaxis=None, yaxis=None): """ Report a histogram plot @@ -111,6 +116,8 @@ class Logger(object): :param int iteration: Iteration number :param list(str) labels: optional, labels for each bar group. :param list(str) xlabels: optional label per entry in the vector (bucket in the histogram) + :param str xaxis: optional x-axis title + :param str yaxis: optional y-axis title """ if not isinstance(values, np.ndarray): @@ -126,6 +133,8 @@ class Logger(object): iter=iteration, labels=labels, xlabels=xlabels, + xtitle=xaxis, + ytitle=yaxis, ) def report_line_plot(self, title, series, iteration, xaxis, yaxis, mode='lines', @@ -195,8 +204,8 @@ class Logger(object): comment=comment, ) - def report_scatter3d(self, title, series, scatter, iteration, labels=None, mode='markers', - fill=False, comment=None): + def report_scatter3d(self, title, series, scatter, iteration, xaxis=None, yaxis=None, zaxis=None, + labels=None, mode='markers', fill=False, comment=None): """ Report a 3d scatter graph (with markers) @@ -205,6 +214,9 @@ class Logger(object): :param np.ndarray scatter: A scattered data: list of (pairs of x,y,z) (or numpy array) or list of series [[(x1,y1,z1)...]] :param int iteration: Iteration number + :param str xaxis: optional x-axis title + :param str yaxis: optional y-axis title + :param str zaxis: optional z-axis title :param list(str) labels: label (text) per point in the scatter (in the same order) :param str mode: scatter plot with 'lines'/'markers'/'lines+markers' :param bool fill: fill area under the curve @@ -245,6 +257,9 @@ class Logger(object): mode=mode, fill=fill, comment=comment, + xtitle=xaxis, + ytitle=yaxis, + ztitle=zaxis, ) def report_confusion_matrix(self, title, series, matrix, iteration, xlabels=None, ylabels=None, comment=None):