Add x/y/z title for scatter 2d/3d plots

This commit is contained in:
allegroai 2019-10-25 15:11:26 +03:00
parent 07f4b86d51
commit cb3167bdd8
3 changed files with 37 additions and 13 deletions

View File

@ -1,15 +1,16 @@
# TRAINS - Example of manual graphs and statistics reporting
#
import os
from PIL import Image
import numpy as np
import logging
from trains import Task
task = Task.init(project_name='examples', task_name='Manual reporting')
task = Task.init(project_name="examples", task_name="Manual reporting")
# standard python logging
logging.info('This is an info message')
logging.info("This is an info message")
# this is loguru test example
try:
@ -30,7 +31,8 @@ logger.report_scalar("example_scalar", "series A", iteration=1, value=200)
# report histogram
histogram = np.random.randint(10, size=10)
logger.report_histogram("example_histogram", "random histogram", iteration=1, values=histogram)
logger.report_histogram("example_histogram", "random histogram", iteration=1, values=histogram,
xaxis="title x", yaxis="title y")
# report confusion matrix
confusion = np.random.randint(10, size=(10, 10))
@ -38,15 +40,17 @@ logger.report_matrix("example_confusion", "ignored", iteration=1, matrix=confusi
# report 3d surface
logger.report_surface("example_surface", "series1", iteration=1, matrix=confusion,
xtitle='title X', ytitle='title Y', ztitle='title Z')
xtitle="title X", ytitle="title Y", ztitle="title Z")
# report 2d scatter plot
scatter2d = np.hstack((np.atleast_2d(np.arange(0, 10)).T, np.random.randint(10, size=(10, 1))))
logger.report_scatter2d("example_scatter", "series_xy", iteration=1, scatter=scatter2d)
logger.report_scatter2d("example_scatter", "series_xy", iteration=1, scatter=scatter2d,
xaxis="title x", yaxis="title y")
# report 3d scatter plot
scatter3d = np.random.randint(10, size=(10, 3))
logger.report_scatter3d("example_scatter_3d", "series_xyz", iteration=1, scatter=scatter3d)
logger.report_scatter3d("example_scatter_3d", "series_xyz", iteration=1, scatter=scatter3d,
xaxis="title x", yaxis="title y", zaxis="title z")
# reporting images
m = np.eye(256, 256, dtype=np.float)
@ -55,7 +59,7 @@ m = np.eye(256, 256, dtype=np.uint8)*255
logger.report_image("test case", "image uint8", iteration=1, image=m)
m = np.concatenate((np.atleast_3d(m), np.zeros((256, 256, 2), dtype=np.uint8)), axis=2)
logger.report_image("test case", "image color red", iteration=1, image=m)
image_open = Image.open('./samples/picasso.jpg')
image_open = Image.open(os.path.join("samples", "picasso.jpg"))
logger.report_image("test case", "image PIL", iteration=1, image=image_open)
# flush reports (otherwise it will be flushed in the background, every couple of seconds)
logger.flush()

View File

@ -224,7 +224,8 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
delete_after_upload=delete_after_upload, **kwargs)
self._report(ev)
def report_histogram(self, title, series, histogram, iter, labels=None, xlabels=None, comment=None):
def report_histogram(self, title, series, histogram, iter, labels=None, xlabels=None,
xtitle=None, ytitle=None, comment=None):
"""
Report an histogram bar plot
:param title: Title (AKA metric)
@ -240,12 +241,16 @@ class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncMan
:type labels: list of strings.
:param xlabels: The labels of the x axis.
:type xlabels: List of strings.
:param str xtitle: optional x-axis title
:param str ytitle: optional y-axis title
:param comment: comment underneath the title
:type comment: str
"""
plotly_dict = create_2d_histogram_plot(
np_row_wise=histogram,
title=title,
xtitle=xtitle,
ytitle=ytitle,
labels=labels,
series=series,
xlabels=xlabels,

View File

@ -87,7 +87,8 @@ class Logger(object):
self._touch_title_series(title, series)
return self._task.reporter.report_scalar(title=title, series=series, value=float(value), iter=iteration)
def report_vector(self, title, series, values, iteration, labels=None, xlabels=None):
def report_vector(self, title, series, values, iteration, labels=None, xlabels=None,
xaxis=None, yaxis=None):
"""
Report a histogram plot
@ -97,11 +98,15 @@ class Logger(object):
:param int iteration: Iteration number
:param list(str) labels: optional, labels for each bar group.
:param list(str) xlabels: optional label per entry in the vector (bucket in the histogram)
:param str xaxis: optional x-axis title
:param str yaxis: optional y-axis title
"""
self._touch_title_series(title, series)
return self.report_histogram(title, series, values, iteration, labels=labels, xlabels=xlabels)
return self.report_histogram(title, series, values, iteration, labels=labels, xlabels=xlabels,
xaxis=xaxis, yaxis=yaxis)
def report_histogram(self, title, series, values, iteration, labels=None, xlabels=None):
def report_histogram(self, title, series, values, iteration, labels=None, xlabels=None,
xaxis=None, yaxis=None):
"""
Report a histogram plot
@ -111,6 +116,8 @@ class Logger(object):
:param int iteration: Iteration number
:param list(str) labels: optional, labels for each bar group.
:param list(str) xlabels: optional label per entry in the vector (bucket in the histogram)
:param str xaxis: optional x-axis title
:param str yaxis: optional y-axis title
"""
if not isinstance(values, np.ndarray):
@ -126,6 +133,8 @@ class Logger(object):
iter=iteration,
labels=labels,
xlabels=xlabels,
xtitle=xaxis,
ytitle=yaxis,
)
def report_line_plot(self, title, series, iteration, xaxis, yaxis, mode='lines',
@ -195,8 +204,8 @@ class Logger(object):
comment=comment,
)
def report_scatter3d(self, title, series, scatter, iteration, labels=None, mode='markers',
fill=False, comment=None):
def report_scatter3d(self, title, series, scatter, iteration, xaxis=None, yaxis=None, zaxis=None,
labels=None, mode='markers', fill=False, comment=None):
"""
Report a 3d scatter graph (with markers)
@ -205,6 +214,9 @@ class Logger(object):
:param np.ndarray scatter: A scattered data: list of (pairs of x,y,z) (or numpy array)
or list of series [[(x1,y1,z1)...]]
:param int iteration: Iteration number
:param str xaxis: optional x-axis title
:param str yaxis: optional y-axis title
:param str zaxis: optional z-axis title
:param list(str) labels: label (text) per point in the scatter (in the same order)
:param str mode: scatter plot with 'lines'/'markers'/'lines+markers'
:param bool fill: fill area under the curve
@ -245,6 +257,9 @@ class Logger(object):
mode=mode,
fill=fill,
comment=comment,
xtitle=xaxis,
ytitle=yaxis,
ztitle=zaxis,
)
def report_confusion_matrix(self, title, series, matrix, iteration, xlabels=None, ylabels=None, comment=None):