mirror of
https://github.com/clearml/clearml-server
synced 2025-06-16 19:18:06 +00:00
Add x_axis_label support in scalar iter charts
This commit is contained in:
parent
478f6b531b
commit
8c29ebaece
@ -201,6 +201,8 @@ class EventBLL(object):
|
|||||||
invalid_iteration_error = f"Iteration number should not exceed {MAX_LONG}"
|
invalid_iteration_error = f"Iteration number should not exceed {MAX_LONG}"
|
||||||
|
|
||||||
for event in events:
|
for event in events:
|
||||||
|
x_axis_label = event.pop("x_axis_label", None)
|
||||||
|
|
||||||
# remove spaces from event type
|
# remove spaces from event type
|
||||||
event_type = event.get("type")
|
event_type = event.get("type")
|
||||||
if event_type is None:
|
if event_type is None:
|
||||||
@ -296,6 +298,7 @@ class EventBLL(object):
|
|||||||
self._update_last_scalar_events_for_task(
|
self._update_last_scalar_events_for_task(
|
||||||
last_events=task_last_scalar_events[task_or_model_id],
|
last_events=task_last_scalar_events[task_or_model_id],
|
||||||
event=event,
|
event=event,
|
||||||
|
x_axis_label=x_axis_label,
|
||||||
)
|
)
|
||||||
|
|
||||||
actions.append(es_action)
|
actions.append(es_action)
|
||||||
@ -431,7 +434,7 @@ class EventBLL(object):
|
|||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _update_last_scalar_events_for_task(self, last_events, event):
|
def _update_last_scalar_events_for_task(self, last_events, event, x_axis_label=None):
|
||||||
"""
|
"""
|
||||||
Update last_events structure with the provided event details if this event is more
|
Update last_events structure with the provided event details if this event is more
|
||||||
recent than the currently stored event for its metric/variant combination.
|
recent than the currently stored event for its metric/variant combination.
|
||||||
@ -463,6 +466,8 @@ class EventBLL(object):
|
|||||||
last_event["value"] = value
|
last_event["value"] = value
|
||||||
last_event["iter"] = event_iter
|
last_event["iter"] = event_iter
|
||||||
last_event["timestamp"] = event_timestamp
|
last_event["timestamp"] = event_timestamp
|
||||||
|
if x_axis_label is not None:
|
||||||
|
last_event["x_axis_label"] = x_axis_label
|
||||||
|
|
||||||
first_value_iter = last_event.get("first_value_iter")
|
first_value_iter = last_event.get("first_value_iter")
|
||||||
if first_value_iter is None or event_iter < first_value_iter:
|
if first_value_iter is None or event_iter < first_value_iter:
|
||||||
|
@ -24,6 +24,8 @@ from apiserver.bll.event.scalar_key import ScalarKey, ScalarKeyEnum
|
|||||||
from apiserver.bll.query import Builder as QueryBuilder
|
from apiserver.bll.query import Builder as QueryBuilder
|
||||||
from apiserver.config_repo import config
|
from apiserver.config_repo import config
|
||||||
from apiserver.database.errors import translate_errors_context
|
from apiserver.database.errors import translate_errors_context
|
||||||
|
from apiserver.database.model.model import Model
|
||||||
|
from apiserver.database.model.task.task import Task
|
||||||
from apiserver.utilities.dicts import nested_get
|
from apiserver.utilities.dicts import nested_get
|
||||||
|
|
||||||
log = config.logger(__file__)
|
log = config.logger(__file__)
|
||||||
@ -43,6 +45,7 @@ class EventMetrics:
|
|||||||
samples: int,
|
samples: int,
|
||||||
key: ScalarKeyEnum,
|
key: ScalarKeyEnum,
|
||||||
metric_variants: MetricVariants = None,
|
metric_variants: MetricVariants = None,
|
||||||
|
model_events: bool = False,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
Get scalar metric histogram per metric and variant
|
Get scalar metric histogram per metric and variant
|
||||||
@ -60,6 +63,7 @@ class EventMetrics:
|
|||||||
samples=samples,
|
samples=samples,
|
||||||
key=ScalarKey.resolve(key),
|
key=ScalarKey.resolve(key),
|
||||||
metric_variants=metric_variants,
|
metric_variants=metric_variants,
|
||||||
|
model_events=model_events,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_scalar_average_per_iter_core(
|
def _get_scalar_average_per_iter_core(
|
||||||
@ -71,6 +75,7 @@ class EventMetrics:
|
|||||||
key: ScalarKey,
|
key: ScalarKey,
|
||||||
run_parallel: bool = True,
|
run_parallel: bool = True,
|
||||||
metric_variants: MetricVariants = None,
|
metric_variants: MetricVariants = None,
|
||||||
|
model_events: bool = False,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
intervals = self._get_task_metric_intervals(
|
intervals = self._get_task_metric_intervals(
|
||||||
company_id=company_id,
|
company_id=company_id,
|
||||||
@ -102,7 +107,22 @@ class EventMetrics:
|
|||||||
)
|
)
|
||||||
|
|
||||||
ret = defaultdict(dict)
|
ret = defaultdict(dict)
|
||||||
|
if not metrics:
|
||||||
|
return ret
|
||||||
|
|
||||||
|
last_metrics = {}
|
||||||
|
cls_ = Model if model_events else Task
|
||||||
|
task = cls_.objects(id=task_id).only("last_metrics").first()
|
||||||
|
if task and task.last_metrics:
|
||||||
|
for m_data in task.last_metrics.values():
|
||||||
|
for v_data in m_data.values():
|
||||||
|
last_metrics[(v_data.metric, v_data.variant)] = v_data
|
||||||
|
|
||||||
for metric_key, metric_values in metrics:
|
for metric_key, metric_values in metrics:
|
||||||
|
for variant_key, data in metric_values.items():
|
||||||
|
last_metrics_data = last_metrics.get((metric_key, variant_key))
|
||||||
|
if last_metrics_data and last_metrics_data.x_axis_label is not None:
|
||||||
|
data["x_axis_label"] = last_metrics_data.x_axis_label
|
||||||
ret[metric_key].update(metric_values)
|
ret[metric_key].update(metric_values)
|
||||||
|
|
||||||
return ret
|
return ret
|
||||||
@ -113,6 +133,7 @@ class EventMetrics:
|
|||||||
samples,
|
samples,
|
||||||
key: ScalarKeyEnum,
|
key: ScalarKeyEnum,
|
||||||
metric_variants: MetricVariants = None,
|
metric_variants: MetricVariants = None,
|
||||||
|
model_events: bool = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Compare scalar metrics for different tasks per metric and variant
|
Compare scalar metrics for different tasks per metric and variant
|
||||||
@ -136,6 +157,7 @@ class EventMetrics:
|
|||||||
key=ScalarKey.resolve(key),
|
key=ScalarKey.resolve(key),
|
||||||
metric_variants=metric_variants,
|
metric_variants=metric_variants,
|
||||||
run_parallel=False,
|
run_parallel=False,
|
||||||
|
model_events=model_events,
|
||||||
)
|
)
|
||||||
task_ids, company_ids = zip(
|
task_ids, company_ids = zip(
|
||||||
*(
|
*(
|
||||||
@ -165,7 +187,7 @@ class EventMetrics:
|
|||||||
self,
|
self,
|
||||||
companies: TaskCompanies,
|
companies: TaskCompanies,
|
||||||
metric_variants: MetricVariants = None,
|
metric_variants: MetricVariants = None,
|
||||||
) -> Mapping[str, dict]:
|
) -> Mapping[str, Sequence[dict]]:
|
||||||
"""
|
"""
|
||||||
For the requested tasks return all the events delivered for the single iteration (-2**31)
|
For the requested tasks return all the events delivered for the single iteration (-2**31)
|
||||||
"""
|
"""
|
||||||
|
@ -408,7 +408,7 @@ class TaskBLL:
|
|||||||
task's last iteration value.
|
task's last iteration value.
|
||||||
:param last_iteration_max: Last reported iteration. Use this to conditionally set a value only
|
:param last_iteration_max: Last reported iteration. Use this to conditionally set a value only
|
||||||
if the current task's last iteration value is smaller than the provided value.
|
if the current task's last iteration value is smaller than the provided value.
|
||||||
:param last_scalar_values: Last reported metrics summary for scalar events (value, metric, variant).
|
:param last_scalar_events: Last reported metrics summary for scalar events (value, metric, variant).
|
||||||
:param last_events: Last reported metrics summary (value, metric, event type).
|
:param last_events: Last reported metrics summary (value, metric, event type).
|
||||||
:param extra_updates: Extra task updates to include in this update call.
|
:param extra_updates: Extra task updates to include in this update call.
|
||||||
:return:
|
:return:
|
||||||
|
@ -395,7 +395,7 @@ def get_last_metric_updates(
|
|||||||
is_min=(key == "min_value"),
|
is_min=(key == "min_value"),
|
||||||
is_first=(key == "first_value"),
|
is_first=(key == "first_value"),
|
||||||
)
|
)
|
||||||
elif key in ("metric", "variant", "value"):
|
elif key in ("metric", "variant", "value", "x_axis_label"):
|
||||||
extra_updates[f"set__{path}__{key}"] = value
|
extra_updates[f"set__{path}__{key}"] = value
|
||||||
|
|
||||||
count = variant_data.get("count")
|
count = variant_data.get("count")
|
||||||
|
@ -28,6 +28,7 @@ class MetricEvent(EmbeddedDocument):
|
|||||||
first_value_iteration = IntField()
|
first_value_iteration = IntField()
|
||||||
count = IntField()
|
count = IntField()
|
||||||
mean_value = FloatField()
|
mean_value = FloatField()
|
||||||
|
x_axis_label = StringField()
|
||||||
|
|
||||||
|
|
||||||
class EventStats(EmbeddedDocument):
|
class EventStats(EmbeddedDocument):
|
||||||
|
@ -299,6 +299,10 @@ last_metrics_event {
|
|||||||
description: "The total count of reported values"
|
description: "The total count of reported values"
|
||||||
type: integer
|
type: integer
|
||||||
}
|
}
|
||||||
|
x_axis_label {
|
||||||
|
description: The user defined value for the X-Axis name stored with the event
|
||||||
|
type: string
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
last_metrics_variants {
|
last_metrics_variants {
|
||||||
|
@ -27,13 +27,17 @@ _definitions {
|
|||||||
type: string
|
type: string
|
||||||
}
|
}
|
||||||
variant {
|
variant {
|
||||||
description: "E.g. 'class_1', 'total', 'average"
|
description: "E.g. 'class_1', 'total', 'average'"
|
||||||
type: string
|
type: string
|
||||||
}
|
}
|
||||||
value {
|
value {
|
||||||
description: ""
|
description: ""
|
||||||
type: number
|
type: number
|
||||||
}
|
}
|
||||||
|
x_axis_label {
|
||||||
|
description: "Custom X-Axis label to be used when displaying the scalars histogram"
|
||||||
|
type: string
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
metrics_vector_event {
|
metrics_vector_event {
|
||||||
|
@ -490,6 +490,7 @@ def scalar_metrics_iter_histogram(
|
|||||||
samples=request.samples,
|
samples=request.samples,
|
||||||
key=request.key,
|
key=request.key,
|
||||||
metric_variants=_get_metric_variants_from_request(request.metrics),
|
metric_variants=_get_metric_variants_from_request(request.metrics),
|
||||||
|
model_events=request.model_events,
|
||||||
)
|
)
|
||||||
call.result.data = metrics
|
call.result.data = metrics
|
||||||
|
|
||||||
@ -540,12 +541,13 @@ def multi_task_scalar_metrics_iter_histogram(
|
|||||||
samples=request.samples,
|
samples=request.samples,
|
||||||
key=request.key,
|
key=request.key,
|
||||||
metric_variants=_get_metric_variants_from_request(request.metrics),
|
metric_variants=_get_metric_variants_from_request(request.metrics),
|
||||||
|
model_events=request.model_events,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _get_single_value_metrics_response(
|
def _get_single_value_metrics_response(
|
||||||
companies: TaskCompanies, value_metrics: Mapping[str, dict]
|
companies: TaskCompanies, value_metrics: Mapping[str, Sequence[dict]]
|
||||||
) -> Sequence[dict]:
|
) -> Sequence[dict]:
|
||||||
task_names = {
|
task_names = {
|
||||||
task.id: task.name for task in itertools.chain.from_iterable(companies.values())
|
task.id: task.name for task in itertools.chain.from_iterable(companies.values())
|
||||||
|
@ -282,6 +282,7 @@ def get_task_data(call: APICall, company_id, request: GetTasksDataRequest):
|
|||||||
metric_variants=_get_metric_variants_from_request(
|
metric_variants=_get_metric_variants_from_request(
|
||||||
request.scalar_metrics_iter_histogram.metrics
|
request.scalar_metrics_iter_histogram.metrics
|
||||||
),
|
),
|
||||||
|
model_events=request.model_events,
|
||||||
)
|
)
|
||||||
|
|
||||||
if request.single_value_metrics:
|
if request.single_value_metrics:
|
||||||
|
@ -246,6 +246,7 @@ class TestTaskEvents(TestService):
|
|||||||
"variant": f"Variant{variant_idx}",
|
"variant": f"Variant{variant_idx}",
|
||||||
"value": iteration,
|
"value": iteration,
|
||||||
"model_event": True,
|
"model_event": True,
|
||||||
|
"x_axis_label": f"Label_{metric_idx}_{variant_idx}"
|
||||||
}
|
}
|
||||||
for iteration in range(2)
|
for iteration in range(2)
|
||||||
for metric_idx in range(5)
|
for metric_idx in range(5)
|
||||||
@ -274,6 +275,7 @@ class TestTaskEvents(TestService):
|
|||||||
variant_data = metric_data.Variant0
|
variant_data = metric_data.Variant0
|
||||||
self.assertEqual(variant_data.x, [0, 1])
|
self.assertEqual(variant_data.x, [0, 1])
|
||||||
self.assertEqual(variant_data.y, [0.0, 1.0])
|
self.assertEqual(variant_data.y, [0.0, 1.0])
|
||||||
|
self.assertEqual(variant_data.x_axis_label, "Label_0_0")
|
||||||
|
|
||||||
model_data = self.api.models.get_all_ex(
|
model_data = self.api.models.get_all_ex(
|
||||||
id=[model], only_fields=["last_metrics", "last_iteration"]
|
id=[model], only_fields=["last_metrics", "last_iteration"]
|
||||||
@ -285,6 +287,7 @@ class TestTaskEvents(TestService):
|
|||||||
self.assertEqual(1, metric_data.max_value_iteration)
|
self.assertEqual(1, metric_data.max_value_iteration)
|
||||||
self.assertEqual(0, metric_data.min_value)
|
self.assertEqual(0, metric_data.min_value)
|
||||||
self.assertEqual(0, metric_data.min_value_iteration)
|
self.assertEqual(0, metric_data.min_value_iteration)
|
||||||
|
self.assertEqual("Label_4_4", metric_data.x_axis_label)
|
||||||
|
|
||||||
self._assert_log_events(task=task, expected_total=1)
|
self._assert_log_events(task=task, expected_total=1)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user