Add Task.get_reported_console_output() and Task.get_reported_scalars()

This commit is contained in:
allegroai 2020-05-22 10:34:45 +03:00
parent 2d22efcead
commit 072abfd6fd

View File

@ -929,6 +929,70 @@ class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
self._data.status = status
return str(status)
def get_reported_scalars(self, max_samples=0, x_axis='iter'):
"""
Return a nested dictionary for the scalar graphs,
where the first key is the graph title and the second is the series name.
Value is a dict with 'x': values and 'y': values
Notice: This call is not cached, any call will retrieve all the scalar reports from the back-end.
If the Task has many scalars reported, it might take long for the call to return.
Example:
{'title': {'series': {
'x': [0, 1 ,2],
'y': [10, 11 ,12],
}}}
:param int max_samples: Maximum samples per series to return. Default is 0 returning all scalars.
With sample limit, average scalar values inside sampling window.
:param str x_axis: scalar x_axis, possible values:
'iter': iteration (default), 'timestamp': seconds from start, 'iso_time': absolute time
:return dict: Nested scalar graphs: dict[title(str), dict[series(str), dict[axis(str), list(float)]]]
"""
if x_axis not in ('iter', 'timestamp', 'iso_time'):
raise ValueError("Scalar x-axis supported values are: 'iter', 'timestamp', 'iso_time'")
# send request
res = self.send(
events.ScalarMetricsIterHistogramRequest(task=self.id, key=x_axis, samples=max(0, max_samples))
)
response = res.wait()
if not response.ok() or not response.response_data:
return {}
return response.response_data
def get_reported_console_output(self, number_of_reports=1):
"""
Return a list of console outputs reported by the Task.
Returned console outputs are retrieved from the most updated console outputs.
:param int number_of_reports: number of reports to return, default 1, the last (most updated) console output
:return list: List of strings each entry corresponds to one report.
"""
res = self.send(
events.GetTaskLogRequest(
task=self.id,
order='asc',
from_='tail',
batch_size=number_of_reports,)
)
response = res.wait()
if not response.ok() or not response.response_data.get('events'):
return []
lines = [r.get('msg', '') for r in response.response_data['events']]
return lines
@classmethod
def add_requirements(cls, package_name, package_version=None):
"""
Force package in requirements list. If version is not specified, use the installed package version if found.
:param str package_name: Package name to add to the "Installed Packages" section of the task
:param package_version: Package version requirements. If None use the installed version
"""
cls._force_requirements[package_name] = package_version
def _get_models(self, model_type='output'):
model_type = model_type.lower().strip()
assert model_type == 'output' or model_type == 'input'