Add x_axis_label support in scalar iter charts

This commit is contained in:
clearml 2024-12-31 22:05:30 +02:00
parent 478f6b531b
commit 8c29ebaece
10 changed files with 48 additions and 6 deletions

View File

@ -201,6 +201,8 @@ class EventBLL(object):
invalid_iteration_error = f"Iteration number should not exceed {MAX_LONG}"
for event in events:
x_axis_label = event.pop("x_axis_label", None)
# remove spaces from event type
event_type = event.get("type")
if event_type is None:
@ -296,6 +298,7 @@ class EventBLL(object):
self._update_last_scalar_events_for_task(
last_events=task_last_scalar_events[task_or_model_id],
event=event,
x_axis_label=x_axis_label,
)
actions.append(es_action)
@ -431,7 +434,7 @@ class EventBLL(object):
return False
return True
def _update_last_scalar_events_for_task(self, last_events, event):
def _update_last_scalar_events_for_task(self, last_events, event, x_axis_label=None):
"""
Update last_events structure with the provided event details if this event is more
recent than the currently stored event for its metric/variant combination.
@ -463,6 +466,8 @@ class EventBLL(object):
last_event["value"] = value
last_event["iter"] = event_iter
last_event["timestamp"] = event_timestamp
if x_axis_label is not None:
last_event["x_axis_label"] = x_axis_label
first_value_iter = last_event.get("first_value_iter")
if first_value_iter is None or event_iter < first_value_iter:

View File

@ -24,6 +24,8 @@ from apiserver.bll.event.scalar_key import ScalarKey, ScalarKeyEnum
from apiserver.bll.query import Builder as QueryBuilder
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.model import Model
from apiserver.database.model.task.task import Task
from apiserver.utilities.dicts import nested_get
log = config.logger(__file__)
@ -43,6 +45,7 @@ class EventMetrics:
samples: int,
key: ScalarKeyEnum,
metric_variants: MetricVariants = None,
model_events: bool = False,
) -> dict:
"""
Get scalar metric histogram per metric and variant
@ -60,6 +63,7 @@ class EventMetrics:
samples=samples,
key=ScalarKey.resolve(key),
metric_variants=metric_variants,
model_events=model_events,
)
def _get_scalar_average_per_iter_core(
@ -71,6 +75,7 @@ class EventMetrics:
key: ScalarKey,
run_parallel: bool = True,
metric_variants: MetricVariants = None,
model_events: bool = False,
) -> dict:
intervals = self._get_task_metric_intervals(
company_id=company_id,
@ -102,7 +107,22 @@ class EventMetrics:
)
ret = defaultdict(dict)
if not metrics:
return ret
last_metrics = {}
cls_ = Model if model_events else Task
task = cls_.objects(id=task_id).only("last_metrics").first()
if task and task.last_metrics:
for m_data in task.last_metrics.values():
for v_data in m_data.values():
last_metrics[(v_data.metric, v_data.variant)] = v_data
for metric_key, metric_values in metrics:
for variant_key, data in metric_values.items():
last_metrics_data = last_metrics.get((metric_key, variant_key))
if last_metrics_data and last_metrics_data.x_axis_label is not None:
data["x_axis_label"] = last_metrics_data.x_axis_label
ret[metric_key].update(metric_values)
return ret
@ -113,6 +133,7 @@ class EventMetrics:
samples,
key: ScalarKeyEnum,
metric_variants: MetricVariants = None,
model_events: bool = False,
):
"""
Compare scalar metrics for different tasks per metric and variant
@ -136,6 +157,7 @@ class EventMetrics:
key=ScalarKey.resolve(key),
metric_variants=metric_variants,
run_parallel=False,
model_events=model_events,
)
task_ids, company_ids = zip(
*(
@ -165,7 +187,7 @@ class EventMetrics:
self,
companies: TaskCompanies,
metric_variants: MetricVariants = None,
) -> Mapping[str, dict]:
) -> Mapping[str, Sequence[dict]]:
"""
For the requested tasks return all the events delivered for the single iteration (-2**31)
"""

View File

@ -408,7 +408,7 @@ class TaskBLL:
task's last iteration value.
:param last_iteration_max: Last reported iteration. Use this to conditionally set a value only
if the current task's last iteration value is smaller than the provided value.
:param last_scalar_values: Last reported metrics summary for scalar events (value, metric, variant).
:param last_scalar_events: Last reported metrics summary for scalar events (value, metric, variant).
:param last_events: Last reported metrics summary (value, metric, event type).
:param extra_updates: Extra task updates to include in this update call.
:return:

View File

@ -395,7 +395,7 @@ def get_last_metric_updates(
is_min=(key == "min_value"),
is_first=(key == "first_value"),
)
elif key in ("metric", "variant", "value"):
elif key in ("metric", "variant", "value", "x_axis_label"):
extra_updates[f"set__{path}__{key}"] = value
count = variant_data.get("count")

View File

@ -28,6 +28,7 @@ class MetricEvent(EmbeddedDocument):
first_value_iteration = IntField()
count = IntField()
mean_value = FloatField()
x_axis_label = StringField()
class EventStats(EmbeddedDocument):

View File

@ -299,6 +299,10 @@ last_metrics_event {
description: "The total count of reported values"
type: integer
}
x_axis_label {
description: "The user defined value for the X-Axis name stored with the event"
type: string
}
}
}
last_metrics_variants {

View File

@ -27,13 +27,17 @@ _definitions {
type: string
}
variant {
description: "E.g. 'class_1', 'total', 'average"
description: "E.g. 'class_1', 'total', 'average'"
type: string
}
value {
description: ""
type: number
}
x_axis_label {
description: "Custom X-Axis label to be used when displaying the scalars histogram"
type: string
}
}
}
metrics_vector_event {

View File

@ -490,6 +490,7 @@ def scalar_metrics_iter_histogram(
samples=request.samples,
key=request.key,
metric_variants=_get_metric_variants_from_request(request.metrics),
model_events=request.model_events,
)
call.result.data = metrics
@ -540,12 +541,13 @@ def multi_task_scalar_metrics_iter_histogram(
samples=request.samples,
key=request.key,
metric_variants=_get_metric_variants_from_request(request.metrics),
model_events=request.model_events,
)
)
def _get_single_value_metrics_response(
companies: TaskCompanies, value_metrics: Mapping[str, dict]
companies: TaskCompanies, value_metrics: Mapping[str, Sequence[dict]]
) -> Sequence[dict]:
task_names = {
task.id: task.name for task in itertools.chain.from_iterable(companies.values())

View File

@ -282,6 +282,7 @@ def get_task_data(call: APICall, company_id, request: GetTasksDataRequest):
metric_variants=_get_metric_variants_from_request(
request.scalar_metrics_iter_histogram.metrics
),
model_events=request.model_events,
)
if request.single_value_metrics:

View File

@ -246,6 +246,7 @@ class TestTaskEvents(TestService):
"variant": f"Variant{variant_idx}",
"value": iteration,
"model_event": True,
"x_axis_label": f"Label_{metric_idx}_{variant_idx}"
}
for iteration in range(2)
for metric_idx in range(5)
@ -274,6 +275,7 @@ class TestTaskEvents(TestService):
variant_data = metric_data.Variant0
self.assertEqual(variant_data.x, [0, 1])
self.assertEqual(variant_data.y, [0.0, 1.0])
self.assertEqual(variant_data.x_axis_label, "Label_0_0")
model_data = self.api.models.get_all_ex(
id=[model], only_fields=["last_metrics", "last_iteration"]
@ -285,6 +287,7 @@ class TestTaskEvents(TestService):
self.assertEqual(1, metric_data.max_value_iteration)
self.assertEqual(0, metric_data.min_value)
self.assertEqual(0, metric_data.min_value_iteration)
self.assertEqual("Label_4_4", metric_data.x_axis_label)
self._assert_log_events(task=task, expected_total=1)