From 072abfd6fd75968be2ccf91ab50b9e4a39d08b4a Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Fri, 22 May 2020 10:34:45 +0300 Subject: [PATCH] Add Task.get_reported_console_output() and Task.get_reported_scalars() --- trains/backend_interface/task/task.py | 64 +++++++++++++++++++++++++++ 1 file changed, 64 insertions(+) diff --git a/trains/backend_interface/task/task.py b/trains/backend_interface/task/task.py index 197a155c..f3912331 100644 --- a/trains/backend_interface/task/task.py +++ b/trains/backend_interface/task/task.py @@ -929,6 +929,70 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin): self._data.status = status return str(status) + def get_reported_scalars(self, max_samples=0, x_axis='iter'): + """ + Return a nested dictionary for the scalar graphs, + where the first key is the graph title and the second is the series name. + Value is a dict with 'x': values and 'y': values + + Notice: This call is not cached, any call will retrieve all the scalar reports from the back-end. + If the Task has many scalars reported, it might take long for the call to return. + + Example: + {'title': {'series': { + 'x': [0, 1 ,2], + 'y': [10, 11 ,12], + }}} + :param int max_samples: Maximum samples per series to return. Default is 0 returning all scalars. + With sample limit, average scalar values inside sampling window. + :param str x_axis: scalar x_axis, possible values: + 'iter': iteration (default), 'timestamp': seconds from start, 'iso_time': absolute time + :return dict: Nested scalar graphs: dict[title(str), dict[series(str), dict[axis(str), list(float)]]] + """ + if x_axis not in ('iter', 'timestamp', 'iso_time'): + raise ValueError("Scalar x-axis supported values are: 'iter', 'timestamp', 'iso_time'") + + # send request + res = self.send( + events.ScalarMetricsIterHistogramRequest(task=self.id, key=x_axis, samples=max(0, max_samples)) + ) + response = res.wait() + if not response.ok() or not response.response_data: + return {} + + return response.response_data + + def get_reported_console_output(self, number_of_reports=1): + """ + Return a list of console outputs reported by the Task. + Returned console outputs are retrieved from the most updated console outputs. + + :param int number_of_reports: number of reports to return, default 1, the last (most updated) console output + :return list: List of strings each entry corresponds to one report. + """ + res = self.send( + events.GetTaskLogRequest( + task=self.id, + order='asc', + from_='tail', + batch_size=number_of_reports,) + ) + response = res.wait() + if not response.ok() or not response.response_data.get('events'): + return [] + + lines = [r.get('msg', '') for r in response.response_data['events']] + return lines + + @classmethod + def add_requirements(cls, package_name, package_version=None): + """ + Force package in requirements list. If version is not specified, use the installed package version if found. + :param str package_name: Package name to add to the "Installed Packages" section of the task + :param package_version: Package version requirements. If None use the installed version + """ + cls._force_requirements[package_name] = package_version + def _get_models(self, model_type='output'): model_type = model_type.lower().strip() assert model_type == 'output' or model_type == 'input'