mirror of
https://github.com/clearml/clearml-server
synced 2025-04-27 17:31:25 +00:00
Add x_axis_label support in scalar iter charts
This commit is contained in:
parent
478f6b531b
commit
8c29ebaece
@ -201,6 +201,8 @@ class EventBLL(object):
|
||||
invalid_iteration_error = f"Iteration number should not exceed {MAX_LONG}"
|
||||
|
||||
for event in events:
|
||||
x_axis_label = event.pop("x_axis_label", None)
|
||||
|
||||
# remove spaces from event type
|
||||
event_type = event.get("type")
|
||||
if event_type is None:
|
||||
@ -296,6 +298,7 @@ class EventBLL(object):
|
||||
self._update_last_scalar_events_for_task(
|
||||
last_events=task_last_scalar_events[task_or_model_id],
|
||||
event=event,
|
||||
x_axis_label=x_axis_label,
|
||||
)
|
||||
|
||||
actions.append(es_action)
|
||||
@ -431,7 +434,7 @@ class EventBLL(object):
|
||||
return False
|
||||
return True
|
||||
|
||||
def _update_last_scalar_events_for_task(self, last_events, event):
|
||||
def _update_last_scalar_events_for_task(self, last_events, event, x_axis_label=None):
|
||||
"""
|
||||
Update last_events structure with the provided event details if this event is more
|
||||
recent than the currently stored event for its metric/variant combination.
|
||||
@ -463,6 +466,8 @@ class EventBLL(object):
|
||||
last_event["value"] = value
|
||||
last_event["iter"] = event_iter
|
||||
last_event["timestamp"] = event_timestamp
|
||||
if x_axis_label is not None:
|
||||
last_event["x_axis_label"] = x_axis_label
|
||||
|
||||
first_value_iter = last_event.get("first_value_iter")
|
||||
if first_value_iter is None or event_iter < first_value_iter:
|
||||
|
@ -24,6 +24,8 @@ from apiserver.bll.event.scalar_key import ScalarKey, ScalarKeyEnum
|
||||
from apiserver.bll.query import Builder as QueryBuilder
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.task.task import Task
|
||||
from apiserver.utilities.dicts import nested_get
|
||||
|
||||
log = config.logger(__file__)
|
||||
@ -43,6 +45,7 @@ class EventMetrics:
|
||||
samples: int,
|
||||
key: ScalarKeyEnum,
|
||||
metric_variants: MetricVariants = None,
|
||||
model_events: bool = False,
|
||||
) -> dict:
|
||||
"""
|
||||
Get scalar metric histogram per metric and variant
|
||||
@ -60,6 +63,7 @@ class EventMetrics:
|
||||
samples=samples,
|
||||
key=ScalarKey.resolve(key),
|
||||
metric_variants=metric_variants,
|
||||
model_events=model_events,
|
||||
)
|
||||
|
||||
def _get_scalar_average_per_iter_core(
|
||||
@ -71,6 +75,7 @@ class EventMetrics:
|
||||
key: ScalarKey,
|
||||
run_parallel: bool = True,
|
||||
metric_variants: MetricVariants = None,
|
||||
model_events: bool = False,
|
||||
) -> dict:
|
||||
intervals = self._get_task_metric_intervals(
|
||||
company_id=company_id,
|
||||
@ -102,7 +107,22 @@ class EventMetrics:
|
||||
)
|
||||
|
||||
ret = defaultdict(dict)
|
||||
if not metrics:
|
||||
return ret
|
||||
|
||||
last_metrics = {}
|
||||
cls_ = Model if model_events else Task
|
||||
task = cls_.objects(id=task_id).only("last_metrics").first()
|
||||
if task and task.last_metrics:
|
||||
for m_data in task.last_metrics.values():
|
||||
for v_data in m_data.values():
|
||||
last_metrics[(v_data.metric, v_data.variant)] = v_data
|
||||
|
||||
for metric_key, metric_values in metrics:
|
||||
for variant_key, data in metric_values.items():
|
||||
last_metrics_data = last_metrics.get((metric_key, variant_key))
|
||||
if last_metrics_data and last_metrics_data.x_axis_label is not None:
|
||||
data["x_axis_label"] = last_metrics_data.x_axis_label
|
||||
ret[metric_key].update(metric_values)
|
||||
|
||||
return ret
|
||||
@ -113,6 +133,7 @@ class EventMetrics:
|
||||
samples,
|
||||
key: ScalarKeyEnum,
|
||||
metric_variants: MetricVariants = None,
|
||||
model_events: bool = False,
|
||||
):
|
||||
"""
|
||||
Compare scalar metrics for different tasks per metric and variant
|
||||
@ -136,6 +157,7 @@ class EventMetrics:
|
||||
key=ScalarKey.resolve(key),
|
||||
metric_variants=metric_variants,
|
||||
run_parallel=False,
|
||||
model_events=model_events,
|
||||
)
|
||||
task_ids, company_ids = zip(
|
||||
*(
|
||||
@ -165,7 +187,7 @@ class EventMetrics:
|
||||
self,
|
||||
companies: TaskCompanies,
|
||||
metric_variants: MetricVariants = None,
|
||||
) -> Mapping[str, dict]:
|
||||
) -> Mapping[str, Sequence[dict]]:
|
||||
"""
|
||||
For the requested tasks return all the events delivered for the single iteration (-2**31)
|
||||
"""
|
||||
|
@ -408,7 +408,7 @@ class TaskBLL:
|
||||
task's last iteration value.
|
||||
:param last_iteration_max: Last reported iteration. Use this to conditionally set a value only
|
||||
if the current task's last iteration value is smaller than the provided value.
|
||||
:param last_scalar_values: Last reported metrics summary for scalar events (value, metric, variant).
|
||||
:param last_scalar_events: Last reported metrics summary for scalar events (value, metric, variant).
|
||||
:param last_events: Last reported metrics summary (value, metric, event type).
|
||||
:param extra_updates: Extra task updates to include in this update call.
|
||||
:return:
|
||||
|
@ -395,7 +395,7 @@ def get_last_metric_updates(
|
||||
is_min=(key == "min_value"),
|
||||
is_first=(key == "first_value"),
|
||||
)
|
||||
elif key in ("metric", "variant", "value"):
|
||||
elif key in ("metric", "variant", "value", "x_axis_label"):
|
||||
extra_updates[f"set__{path}__{key}"] = value
|
||||
|
||||
count = variant_data.get("count")
|
||||
|
@ -28,6 +28,7 @@ class MetricEvent(EmbeddedDocument):
|
||||
first_value_iteration = IntField()
|
||||
count = IntField()
|
||||
mean_value = FloatField()
|
||||
x_axis_label = StringField()
|
||||
|
||||
|
||||
class EventStats(EmbeddedDocument):
|
||||
|
@ -299,6 +299,10 @@ last_metrics_event {
|
||||
description: "The total count of reported values"
|
||||
type: integer
|
||||
}
|
||||
x_axis_label {
|
||||
description: The user defined value for the X-Axis name stored with the event
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
last_metrics_variants {
|
||||
|
@ -27,13 +27,17 @@ _definitions {
|
||||
type: string
|
||||
}
|
||||
variant {
|
||||
description: "E.g. 'class_1', 'total', 'average"
|
||||
description: "E.g. 'class_1', 'total', 'average'"
|
||||
type: string
|
||||
}
|
||||
value {
|
||||
description: ""
|
||||
type: number
|
||||
}
|
||||
x_axis_label {
|
||||
description: "Custom X-Axis label to be used when displaying the scalars histogram"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
metrics_vector_event {
|
||||
|
@ -490,6 +490,7 @@ def scalar_metrics_iter_histogram(
|
||||
samples=request.samples,
|
||||
key=request.key,
|
||||
metric_variants=_get_metric_variants_from_request(request.metrics),
|
||||
model_events=request.model_events,
|
||||
)
|
||||
call.result.data = metrics
|
||||
|
||||
@ -540,12 +541,13 @@ def multi_task_scalar_metrics_iter_histogram(
|
||||
samples=request.samples,
|
||||
key=request.key,
|
||||
metric_variants=_get_metric_variants_from_request(request.metrics),
|
||||
model_events=request.model_events,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _get_single_value_metrics_response(
|
||||
companies: TaskCompanies, value_metrics: Mapping[str, dict]
|
||||
companies: TaskCompanies, value_metrics: Mapping[str, Sequence[dict]]
|
||||
) -> Sequence[dict]:
|
||||
task_names = {
|
||||
task.id: task.name for task in itertools.chain.from_iterable(companies.values())
|
||||
|
@ -282,6 +282,7 @@ def get_task_data(call: APICall, company_id, request: GetTasksDataRequest):
|
||||
metric_variants=_get_metric_variants_from_request(
|
||||
request.scalar_metrics_iter_histogram.metrics
|
||||
),
|
||||
model_events=request.model_events,
|
||||
)
|
||||
|
||||
if request.single_value_metrics:
|
||||
|
@ -246,6 +246,7 @@ class TestTaskEvents(TestService):
|
||||
"variant": f"Variant{variant_idx}",
|
||||
"value": iteration,
|
||||
"model_event": True,
|
||||
"x_axis_label": f"Label_{metric_idx}_{variant_idx}"
|
||||
}
|
||||
for iteration in range(2)
|
||||
for metric_idx in range(5)
|
||||
@ -274,6 +275,7 @@ class TestTaskEvents(TestService):
|
||||
variant_data = metric_data.Variant0
|
||||
self.assertEqual(variant_data.x, [0, 1])
|
||||
self.assertEqual(variant_data.y, [0.0, 1.0])
|
||||
self.assertEqual(variant_data.x_axis_label, "Label_0_0")
|
||||
|
||||
model_data = self.api.models.get_all_ex(
|
||||
id=[model], only_fields=["last_metrics", "last_iteration"]
|
||||
@ -285,6 +287,7 @@ class TestTaskEvents(TestService):
|
||||
self.assertEqual(1, metric_data.max_value_iteration)
|
||||
self.assertEqual(0, metric_data.min_value)
|
||||
self.assertEqual(0, metric_data.min_value_iteration)
|
||||
self.assertEqual("Label_4_4", metric_data.x_axis_label)
|
||||
|
||||
self._assert_log_events(task=task, expected_total=1)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user