From f663add27d97d256adf1333194f5754fbbee06da Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Mon, 19 Aug 2019 21:26:29 +0300 Subject: [PATCH] Simplify examples --- examples/joblib_example.py | 19 ++++++++++++++++++- examples/manual_reporting.py | 13 ++++++------- 2 files changed, 24 insertions(+), 8 deletions(-) diff --git a/examples/joblib_example.py b/examples/joblib_example.py index fc1eaf00..a502fe10 100644 --- a/examples/joblib_example.py +++ b/examples/joblib_example.py @@ -6,6 +6,8 @@ except ImportError: from sklearn import datasets from sklearn.linear_model import LogisticRegression from sklearn.model_selection import train_test_split +import numpy as np +import matplotlib.pyplot as plt from trains import Task @@ -25,4 +27,19 @@ joblib.dump(model, 'model.pkl', compress=True) loaded_model = joblib.load('model.pkl') result = loaded_model.score(X_test, y_test) -print(result) +x_min, x_max = X[:, 0].min() - .5, X[:, 0].max() + .5 +y_min, y_max = X[:, 1].min() - .5, X[:, 1].max() + .5 +h = .02 # step size in the mesh +xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h)) +plt.figure(1, figsize=(4, 3)) + +plt.scatter(X[:, 0], X[:, 1], c=y, edgecolors='k', cmap=plt.cm.Paired) +plt.xlabel('Sepal length') +plt.ylabel('Sepal width') + +plt.xlim(xx.min(), xx.max()) +plt.ylim(yy.min(), yy.max()) +plt.xticks(()) +plt.yticks(()) + +plt.show() \ No newline at end of file diff --git a/examples/manual_reporting.py b/examples/manual_reporting.py index 1feba570..9692501f 100644 --- a/examples/manual_reporting.py +++ b/examples/manual_reporting.py @@ -8,17 +8,12 @@ from trains import Task task = Task.init(project_name='examples', task_name='Manual reporting') # standard python logging -logging.getLogger().setLevel('DEBUG') -logging.debug('This is a debug message') logging.info('This is an info message') -logging.warning('This is a warning message') -logging.error('This is an error message') -logging.critical('This is a critical message') # this is loguru test example try: from loguru import logger - logger.debug("That's it, beautiful and simple logging! (using ANSI colors)") + logger.info("That's it, beautiful and simple logging! (using ANSI colors)") except ImportError: pass @@ -40,6 +35,10 @@ logger.report_histogram("example_histogram", "random histogram", iteration=1, va confusion = np.random.randint(10, size=(10, 10)) logger.report_matrix("example_confusion", "ignored", iteration=1, matrix=confusion) +# report 3d surface +logger.report_surface("example_surface", "series1", iteration=1, matrix=confusion, + xtitle='title X', ytitle='title Y', ztitle='title Z') + # report 2d scatter plot scatter2d = np.hstack((np.atleast_2d(np.arange(0, 10)).T, np.random.randint(10, size=(10, 1)))) logger.report_scatter2d("example_scatter", "series_xy", iteration=1, scatter=scatter2d) @@ -48,7 +47,7 @@ logger.report_scatter2d("example_scatter", "series_xy", iteration=1, scatter=sca scatter3d = np.random.randint(10, size=(10, 3)) logger.report_scatter3d("example_scatter_3d", "series_xyz", iteration=1, scatter=scatter3d) -# report images +# reporting images m = np.eye(256, 256, dtype=np.float) logger.report_image_and_upload("test case", "image float", iteration=1, matrix=m) m = np.eye(256, 256, dtype=np.uint8)*255