Mirror of https://github.com/clearml/clearml-server (synced 2025-06-26 23:15:47 +00:00)
Add support for events.scalar_metrics_iter_raw
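
In brief: events.get_task_events is reimplemented on top of the generic event_bll.events_iterator with a key-based scroll (replacing the previous Elasticsearch scroll_id pagination), the call.kpis counters are dropped from events.add / events.add_batch, and a new events.scalar_metrics_iter_raw endpoint returns raw scalar values per variant as parallel lists, with page sizes capped by services.events.events_retrieval.max_raw_scalars_size. Note that the new endpoint is registered with min_version="999.0", which effectively keeps it hidden until the version gate is lowered.

Below is a minimal client-side sketch of paging through the new endpoint. It is not part of the commit; the server URL, credentials, the {"data": ...} response envelope, and the exact field shapes are assumptions inferred from the diff.

# Hypothetical client for events.scalar_metrics_iter_raw -- NOT part of this
# commit. URL, auth, and the response envelope are assumptions.
import requests

BASE_URL = "http://localhost:8008"   # assumed apiserver address
AUTH = ("access_key", "secret_key")  # assumed basic-auth credentials

payload = {
    "task": "<task-id>",              # task to read scalars from
    "metric": {"metric": "loss"},     # one metric, optionally with {"variants": [...]}
    "key": "iter",                    # x-axis key (ScalarKeyEnum); iter is the default
    "batch_size": 5_000,              # server caps this at max_raw_scalars_size
    "count_total": True,              # ask the server to count matching events
}

scroll_id = None
while True:
    if scroll_id:
        payload["scroll_id"] = scroll_id   # resume from the server-issued scroll
    data = requests.post(
        f"{BASE_URL}/events.scalar_metrics_iter_raw", json=payload, auth=AUTH
    ).json()["data"]
    # "variants" maps each variant name to parallel lists, e.g. {"y": [...], "iter": [...]}
    for variant, series in data["variants"].items():
        print(variant, len(series.get("y", ())))
    if not data["returned"]:               # empty page: iteration is exhausted
        break
    scroll_id = data["scroll_id"]
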
@@ -1,9 +1,11 @@
 import itertools
 from collections import defaultdict
+from operator import itemgetter
+from typing import Sequence, Optional

 import attr
-from typing import Sequence, Optional
 import jsonmodels.fields
+from boltons.iterutils import bucketize

 from apiserver.apierrors import errors
 from apiserver.apimodels.events import (
@@ -20,12 +22,17 @@ from apiserver.apimodels.events import (
     NextDebugImageSampleRequest,
     MetricVariants as ApiMetrics,
     TaskPlotsRequest,
+    TaskEventsRequest,
+    ScalarMetricsIterRawRequest,
 )
 from apiserver.bll.event import EventBLL
 from apiserver.bll.event.event_common import EventType, MetricVariants
+from apiserver.bll.event.events_iterator import Scroll
+from apiserver.bll.event.scalar_key import ScalarKeyEnum, ScalarKey
 from apiserver.bll.task import TaskBLL
+from apiserver.config_repo import config
 from apiserver.service_repo import APICall, endpoint
-from apiserver.utilities import json
+from apiserver.utilities import json, extract_properties_to_lists

 task_bll = TaskBLL()
 event_bll = EventBLL()
@@ -39,7 +46,6 @@ def add(call: APICall, company_id, _):
         company_id, [data], call.worker, allow_locked_tasks=allow_locked
     )
     call.result.data = dict(added=added, errors=err_count, errors_info=err_info)
-    call.kpis["events"] = 1


 @endpoint("events.add_batch")
@@ -50,7 +56,6 @@ def add_batch(call: APICall, company_id, _):

     added, err_count, err_info = event_bll.add_events(company_id, events, call.worker)
     call.result.data = dict(added=added, errors=err_count, errors_info=err_info)
-    call.kpis["events"] = len(events)


 @endpoint("events.get_task_log", required_fields=["task"])
@@ -113,7 +118,8 @@ def get_task_log(call, company_id, request: LogEventsRequest):
         company_id, task_id, allow_public=True, only=("company", "company_origin")
     )[0]

-    res = event_bll.log_events_iterator.get_task_events(
+    res = event_bll.events_iterator.get_task_events(
+        event_type=EventType.task_log,
         company_id=task.get_index_company(),
         task_id=task_id,
         batch_size=request.batch_size,
@@ -258,31 +264,94 @@ def vector_metrics_iter_histogram(call, company_id, _):
     )


-@endpoint("events.get_task_events", required_fields=["task"])
-def get_task_events(call, company_id, _):
-    task_id = call.data["task"]
-    batch_size = call.data.get("batch_size", 500)
-    event_type = call.data.get("event_type")
-    scroll_id = call.data.get("scroll_id")
-    order = call.data.get("order") or "asc"
+class GetTaskEventsScroll(Scroll):
+    from_key_value = jsonmodels.fields.StringField()
+    total = jsonmodels.fields.IntField()
+    request: TaskEventsRequest = jsonmodels.fields.EmbeddedField(TaskEventsRequest)
+
+
+def make_response(
+    total: int, returned: int = 0, scroll_id: str = None, **kwargs
+) -> dict:
+    return {
+        "returned": returned,
+        "total": total,
+        "scroll_id": scroll_id,
+        **kwargs,
+    }
+
+
+@endpoint("events.get_task_events", request_data_model=TaskEventsRequest)
+def get_task_events(call, company_id, request: TaskEventsRequest):
+    task_id = request.task

     task = task_bll.assert_exists(
-        company_id, task_id, allow_public=True, only=("company", "company_origin")
+        company_id, task_id, allow_public=True, only=("company",),
     )[0]
-    result = event_bll.get_task_events(
-        task.get_index_company(),
-        task_id,
-        sort=[{"timestamp": {"order": order}}],
-        event_type=EventType(event_type) if event_type else EventType.all,
-        scroll_id=scroll_id,
-        size=batch_size,
+
+    key = ScalarKeyEnum.iter
+    scalar_key = ScalarKey.resolve(key)
+
+    if not request.scroll_id:
+        from_key_value = None if (request.order == LogOrderEnum.desc) else 0
+        total = None
+    else:
+        try:
+            scroll = GetTaskEventsScroll.from_scroll_id(request.scroll_id)
+        except ValueError:
+            raise errors.bad_request.InvalidScrollId(scroll_id=request.scroll_id)
+
+        if scroll.from_key_value is None:
+            return make_response(
+                scroll_id=request.scroll_id, total=scroll.total, events=[]
+            )
+
+        from_key_value = scalar_key.cast_value(scroll.from_key_value)
+        total = scroll.total
+
+        scroll.request.batch_size = request.batch_size or scroll.request.batch_size
+        request = scroll.request
+
+    navigate_earlier = request.order == LogOrderEnum.desc
+    metric_variants = _get_metric_variants_from_request(request.metrics)
+
+    if request.count_total and total is None:
+        total = event_bll.events_iterator.count_task_events(
+            event_type=request.event_type,
+            company_id=task.company,
+            task_id=task_id,
+            metric_variants=metric_variants,
+        )
+
+    batch_size = min(
+        request.batch_size,
+        int(
+            config.get("services.events.events_retrieval.max_raw_scalars_size", 10_000)
+        ),
     )

-    call.result.data = dict(
-        events=result.events,
-        returned=len(result.events),
-        total=result.total_events,
-        scroll_id=result.next_scroll_id,
+    res = event_bll.events_iterator.get_task_events(
+        event_type=request.event_type,
+        company_id=task.company,
+        task_id=task_id,
+        batch_size=batch_size,
+        key=ScalarKeyEnum.iter,
+        navigate_earlier=navigate_earlier,
+        from_key_value=from_key_value,
+        metric_variants=metric_variants,
     )
+
+    scroll = GetTaskEventsScroll(
+        from_key_value=str(res.events[-1][scalar_key.field]) if res.events else None,
+        total=total,
+        request=request,
+    )
+
+    return make_response(
+        returned=len(res.events),
+        total=total,
+        scroll_id=scroll.get_scroll_id(),
+        events=res.events,
+    )


@@ -759,3 +828,95 @@ def _get_top_iter_unique_events(events, max_iters):
     )
     unique_events.sort(key=lambda e: e["iter"], reverse=True)
     return unique_events
+
+
+class ScalarMetricsIterRawScroll(Scroll):
+    from_key_value = jsonmodels.fields.StringField()
+    total = jsonmodels.fields.IntField()
+    request: ScalarMetricsIterRawRequest = jsonmodels.fields.EmbeddedField(
+        ScalarMetricsIterRawRequest
+    )
+
+
+@endpoint("events.scalar_metrics_iter_raw", min_version="999.0")
+def scalar_metrics_iter_raw(
+    call: APICall, company_id: str, request: ScalarMetricsIterRawRequest
+):
+    key = request.key or ScalarKeyEnum.iter
+    scalar_key = ScalarKey.resolve(key)
+
+    if not request.scroll_id:
+        from_key_value = None
+        total = None
+    else:
+        try:
+            scroll = ScalarMetricsIterRawScroll.from_scroll_id(request.scroll_id)
+        except ValueError:
+            raise errors.bad_request.InvalidScrollId(scroll_id=request.scroll_id)
+
+        if scroll.from_key_value is None:
+            return make_response(
+                scroll_id=request.scroll_id, total=scroll.total, variants={}
+            )
+
+        from_key_value = scalar_key.cast_value(scroll.from_key_value)
+        total = scroll.total
+
+        scroll.request.batch_size = request.batch_size or scroll.request.batch_size
+        request = scroll.request
+
+    task_id = request.task
+
+    task = task_bll.assert_exists(
+        company_id, task_id, allow_public=True, only=("company",),
+    )[0]
+
+    metric_variants = _get_metric_variants_from_request([request.metric])
+
+    if request.count_total and total is None:
+        total = event_bll.events_iterator.count_task_events(
+            event_type=EventType.metrics_scalar,
+            company_id=task.company,
+            task_id=task_id,
+            metric_variants=metric_variants,
+        )
+
+    batch_size = min(
+        request.batch_size,
+        int(
+            config.get("services.events.events_retrieval.max_raw_scalars_size", 10_000)
+        ),
+    )
+
+    res = event_bll.events_iterator.get_task_events(
+        event_type=EventType.metrics_scalar,
+        company_id=task.company,
+        task_id=task_id,
+        batch_size=batch_size,
+        navigate_earlier=False,
+        from_key_value=from_key_value,
+        metric_variants=metric_variants,
+        key=key,
+    )
+
+    key = str(key)
+
+    variants = {
+        variant: extract_properties_to_lists(
+            ["value", scalar_key.field], events, target_keys=["y", key]
+        )
+        for variant, events in bucketize(res.events, key=itemgetter("variant")).items()
+    }
+
+    scroll = ScalarMetricsIterRawScroll(
+        from_key_value=str(res.events[-1][scalar_key.field]) if res.events else None,
+        total=total,
+        request=request,
+    )
+
+    return make_response(
+        returned=len(res.events),
+        total=total,
+        scroll_id=scroll.get_scroll_id(),
+        variants=variants,
+    )
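
For reference, the per-variant reshaping the new endpoint performs can be reproduced standalone. bucketize (a real boltons.iterutils function) groups the fetched events by variant; extract_properties_to_lists is ClearML's own helper, so the columnize function below is only an inferred equivalent of its call site in the diff (source keys renamed to target_keys), not the actual implementation.

# Standalone sketch of the response reshaping in scalar_metrics_iter_raw.
# columnize is an inferred stand-in for apiserver.utilities.extract_properties_to_lists.
from operator import itemgetter
from boltons.iterutils import bucketize

def columnize(keys, events, target_keys):
    # Turn a list of event dicts into one dict of parallel lists,
    # renaming source keys to target keys (e.g. "value" -> "y").
    return {
        target: [e.get(source) for e in events]
        for source, target in zip(keys, target_keys)
    }

events = [
    {"variant": "train", "value": 0.91, "iter": 1},
    {"variant": "train", "value": 0.55, "iter": 2},
    {"variant": "val", "value": 0.98, "iter": 1},
]

variants = {
    variant: columnize(["value", "iter"], evs, target_keys=["y", "iter"])
    for variant, evs in bucketize(events, key=itemgetter("variant")).items()
}
print(variants)
# {'train': {'y': [0.91, 0.55], 'iter': [1, 2]}, 'val': {'y': [0.98], 'iter': [1]}}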