Add support for events.scalar_metrics_iter_raw

allegroai 2022-02-13 19:26:03 +02:00
parent f20cd6536e
commit 36e013b40c
9 changed files with 501 additions and 156 deletions

View File

@@ -2,7 +2,7 @@ from enum import auto
from typing import Sequence, Optional
from jsonmodels import validators
from jsonmodels.fields import StringField, BoolField
from jsonmodels.fields import StringField, BoolField, EmbeddedField
from jsonmodels.models import Base
from jsonmodels.validators import Length, Min, Max
@@ -81,14 +81,33 @@ class LogOrderEnum(StringEnum):
desc = auto()
class LogEventsRequest(Base):
class TaskEventsRequestBase(Base):
task: str = StringField(required=True)
batch_size: int = IntField(default=500)
class TaskEventsRequest(TaskEventsRequestBase):
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
event_type: EventType = ActualEnumField(EventType, default=EventType.all)
order: Optional[str] = ActualEnumField(LogOrderEnum, default=LogOrderEnum.asc)
scroll_id: str = StringField()
count_total: bool = BoolField(default=True)
class LogEventsRequest(TaskEventsRequestBase):
batch_size: int = IntField(default=5000)
navigate_earlier: bool = BoolField(default=True)
from_timestamp: Optional[int] = IntField()
order: Optional[str] = ActualEnumField(LogOrderEnum)
class ScalarMetricsIterRawRequest(TaskEventsRequestBase):
key: ScalarKeyEnum = ActualEnumField(ScalarKeyEnum, default=ScalarKeyEnum.iter)
metric: MetricVariants = EmbeddedField(MetricVariants, required=True)
count_total: bool = BoolField(default=False)
scroll_id: str = StringField()
class IterationEvents(Base):
iter: int = IntField()
events: Sequence[dict] = ListField(items_types=dict)
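
For reference, a minimal sketch (not part of this commit) of how the new request model above might be populated directly. The field names on MetricVariants (metric, variants) are assumed from how the metric_variants schema definition is referenced elsewhere in this commit.

```python
# Illustrative only: building the new raw-scalars request by hand.
# MetricVariants field names (metric, variants) are assumed.
from apiserver.apimodels.events import MetricVariants, ScalarMetricsIterRawRequest
from apiserver.bll.event.scalar_key import ScalarKeyEnum

request = ScalarMetricsIterRawRequest(
    task="<task-id>",
    metric=MetricVariants(metric="loss", variants=["train", "validation"]),
    key=ScalarKeyEnum.iter,      # x axis: iteration number (or ScalarKeyEnum.timestamp)
    batch_size=5000,
    count_total=False,           # skip the extra count query
)
request.validate()               # jsonmodels validation of the required fields
```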

View File

@@ -24,13 +24,13 @@ from apiserver.bll.event.event_common import (
MetricVariants,
get_metric_variants_condition,
)
from apiserver.bll.event.events_iterator import EventsIterator, TaskEventsResult
from apiserver.bll.util import parallel_chunked_decorator
from apiserver.database import utils as dbutils
from apiserver.es_factory import es_factory
from apiserver.apierrors import errors
from apiserver.bll.event.debug_images_iterator import DebugImagesIterator
from apiserver.bll.event.event_metrics import EventMetrics
from apiserver.bll.event.log_events_iterator import LogEventsIterator, TaskEventsResult
from apiserver.bll.task import TaskBLL
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
@@ -73,7 +73,7 @@ class EventBLL(object):
self.redis = redis or redman.connection("apiserver")
self.debug_images_iterator = DebugImagesIterator(es=self.es, redis=self.redis)
self.debug_sample_history = DebugSampleHistory(es=self.es, redis=self.redis)
self.log_events_iterator = LogEventsIterator(es=self.es)
self.events_iterator = EventsIterator(es=self.es)
@property
def metrics(self) -> EventMetrics:

View File

@@ -69,6 +69,13 @@ def delete_company_events(
return es.delete_by_query(index=es_index, body=body, **kwargs)
def count_company_events(
es: Elasticsearch, company_id: str, event_type: EventType, body: dict, **kwargs
) -> dict:
es_index = get_index_name(company_id, event_type.value)
return es.count(index=es_index, body=body, **kwargs)
def get_metric_variants_condition(
metric_variants: MetricVariants,
) -> Sequence:
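
As an aside, a hedged sketch of how the new count helper could be called; the wrapper function below is illustrative and not part of the commit (the real caller is EventsIterator.count_task_events, shown later).

```python
# Hypothetical helper (not in the commit) showing count_company_events in use:
# count all scalar events of a single task, routed by task id as the iterator does.
from elasticsearch import Elasticsearch
from apiserver.bll.event.event_common import EventType, count_company_events

def count_task_scalar_events(es: Elasticsearch, company_id: str, task_id: str) -> int:
    body = {"query": {"term": {"task": task_id}}}
    result = count_company_events(
        es, company_id=company_id, event_type=EventType.metrics_scalar,
        body=body, routing=task_id,
    )
    return result["count"]   # es.count() returns {"count": <n>, ...}
```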

View File

@@ -0,0 +1,205 @@
from typing import Optional, Tuple, Sequence, Any
import attr
import jsonmodels.models
import jwt
from elasticsearch import Elasticsearch
from apiserver.bll.event.event_common import (
check_empty_data,
search_company_events,
EventType,
MetricVariants,
get_metric_variants_condition,
count_company_events,
)
from apiserver.bll.event.scalar_key import ScalarKeyEnum, ScalarKey
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.timing_context import TimingContext
@attr.s(auto_attribs=True)
class TaskEventsResult:
total_events: int = 0
next_scroll_id: str = None
events: list = attr.Factory(list)
class EventsIterator:
def __init__(self, es: Elasticsearch):
self.es = es
def get_task_events(
self,
event_type: EventType,
company_id: str,
task_id: str,
batch_size: int,
navigate_earlier: bool = True,
from_key_value: Optional[Any] = None,
metric_variants: MetricVariants = None,
key: ScalarKeyEnum = ScalarKeyEnum.timestamp,
**kwargs,
) -> TaskEventsResult:
if check_empty_data(self.es, company_id, event_type):
return TaskEventsResult()
from_key_value = kwargs.pop("from_timestamp", from_key_value)
res = TaskEventsResult()
res.events, res.total_events = self._get_events(
event_type=event_type,
company_id=company_id,
task_id=task_id,
batch_size=batch_size,
navigate_earlier=navigate_earlier,
from_key_value=from_key_value,
metric_variants=metric_variants,
key=ScalarKey.resolve(key),
)
return res
def count_task_events(
self,
event_type: EventType,
company_id: str,
task_id: str,
metric_variants: MetricVariants = None,
) -> int:
query, _ = self._get_initial_query_and_must(task_id, metric_variants)
es_req = {
"query": query,
}
with translate_errors_context(), TimingContext("es", "count_task_events"):
es_result = count_company_events(
self.es,
company_id=company_id,
event_type=event_type,
body=es_req,
routing=task_id,
)
return es_result["count"]
def _get_events(
self,
event_type: EventType,
company_id: str,
task_id: str,
batch_size: int,
navigate_earlier: bool,
key: ScalarKey,
from_key_value: Optional[Any],
metric_variants: MetricVariants = None,
) -> Tuple[Sequence[dict], int]:
"""
Return up to 'batch size' events starting from the previous key-field value (timestamp or iter) either in the
direction of earlier events (navigate_earlier=True) or in the direction of later events.
If from_key_field is not set then start either from latest or earliest.
For the last key-field value all the events are brought (even if the resulting size exceeds batch_size)
so that events with this value will not be lost between the calls.
"""
query, must = self._get_initial_query_and_must(task_id, metric_variants)
# retrieve the next batch of events
es_req = {
"size": batch_size,
"query": query,
"sort": {key.field: "desc" if navigate_earlier else "asc"},
}
if from_key_value:
es_req["search_after"] = [from_key_value]
with translate_errors_context(), TimingContext("es", "get_task_events"):
es_result = search_company_events(
self.es,
company_id=company_id,
event_type=event_type,
body=es_req,
routing=task_id,
)
hits = es_result["hits"]["hits"]
hits_total = es_result["hits"]["total"]["value"]
if not hits:
return [], hits_total
events = [hit["_source"] for hit in hits]
# retrieve the events that match the last key-field value
# but did not make it into the previous batch due to the batch_size limit
es_req = {
"size": 10000,
"query": {
"bool": {
"must": must + [{"term": {key.field: events[-1][key.field]}}]
}
},
}
es_result = search_company_events(
self.es,
company_id=company_id,
event_type=event_type,
body=es_req,
routing=task_id,
)
last_second_hits = es_result["hits"]["hits"]
if not last_second_hits or len(last_second_hits) < 2:
# if only one element is returned for the last key-field value
# then it is already present in the events
return events, hits_total
already_present_ids = set(hit["_id"] for hit in hits)
last_second_events = [
hit["_source"]
for hit in last_second_hits
if hit["_id"] not in already_present_ids
]
# return the original query results merged with
# the leftovers from the last key-field value
return (
[*events, *last_second_events],
hits_total,
)
@staticmethod
def _get_initial_query_and_must(
task_id: str, metric_variants: MetricVariants = None
) -> Tuple[dict, list]:
if not metric_variants:
must = [{"term": {"task": task_id}}]
query = {"term": {"task": task_id}}
else:
must = [
{"term": {"task": task_id}},
get_metric_variants_condition(metric_variants),
]
query = {"bool": {"must": must}}
return query, must
class Scroll(jsonmodels.models.Base):
def get_scroll_id(self) -> str:
return jwt.encode(
self.to_struct(),
key=config.get(
"services.events.events_retrieval.scroll_id_key", "1234567890"
),
).decode()
@classmethod
def from_scroll_id(cls, scroll_id: str):
try:
return cls(
**jwt.decode(
scroll_id,
key=config.get(
"services.events.events_retrieval.scroll_id_key", "1234567890"
),
)
)
except jwt.PyJWTError:
raise ValueError("Invalid Scroll ID")
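
The Scroll base class above turns arbitrary scroll state into an opaque, signed token. A small usage sketch, assuming the apiserver config is loaded so that the scroll_id_key lookup succeeds; the ExampleScroll fields are illustrative.

```python
# Sketch only: round-tripping scroll state through the JWT-backed scroll id.
import jsonmodels.fields

class ExampleScroll(Scroll):                 # Scroll is defined above in this file
    from_key_value = jsonmodels.fields.StringField()
    total = jsonmodels.fields.IntField()

scroll = ExampleScroll(from_key_value="1500", total=4200)
scroll_id = scroll.get_scroll_id()           # opaque string handed to the client
restored = ExampleScroll.from_scroll_id(scroll_id)
assert restored.total == 4200
# A tampered or malformed id fails jwt.decode and raises ValueError("Invalid Scroll ID")
```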

View File

@@ -1,127 +0,0 @@
from typing import Optional, Tuple, Sequence
import attr
from elasticsearch import Elasticsearch
from apiserver.bll.event.event_common import (
check_empty_data,
search_company_events,
EventType,
)
from apiserver.database.errors import translate_errors_context
from apiserver.timing_context import TimingContext
@attr.s(auto_attribs=True)
class TaskEventsResult:
total_events: int = 0
next_scroll_id: str = None
events: list = attr.Factory(list)
class LogEventsIterator:
EVENT_TYPE = EventType.task_log
def __init__(self, es: Elasticsearch):
self.es = es
def get_task_events(
self,
company_id: str,
task_id: str,
batch_size: int,
navigate_earlier: bool = True,
from_timestamp: Optional[int] = None,
) -> TaskEventsResult:
if check_empty_data(self.es, company_id, self.EVENT_TYPE):
return TaskEventsResult()
res = TaskEventsResult()
res.events, res.total_events = self._get_events(
company_id=company_id,
task_id=task_id,
batch_size=batch_size,
navigate_earlier=navigate_earlier,
from_timestamp=from_timestamp,
)
return res
def _get_events(
self,
company_id: str,
task_id: str,
batch_size: int,
navigate_earlier: bool,
from_timestamp: Optional[int],
) -> Tuple[Sequence[dict], int]:
"""
Return up to 'batch size' events starting from the previous timestamp either in the
direction of earlier events (navigate_earlier=True) or in the direction of later events.
If last_min_timestamp and last_max_timestamp are not set then start either from latest or earliest.
For the last timestamp all the events are brought (even if the resulting size
exceeds batch_size) so that this timestamp events will not be lost between the calls.
In case any events were received update 'last_min_timestamp' and 'last_max_timestamp'
"""
# retrieve the next batch of events
es_req = {
"size": batch_size,
"query": {"term": {"task": task_id}},
"sort": {"timestamp": "desc" if navigate_earlier else "asc"},
}
if from_timestamp:
es_req["search_after"] = [from_timestamp]
with translate_errors_context(), TimingContext("es", "get_task_events"):
es_result = search_company_events(
self.es,
company_id=company_id,
event_type=self.EVENT_TYPE,
body=es_req,
)
hits = es_result["hits"]["hits"]
hits_total = es_result["hits"]["total"]["value"]
if not hits:
return [], hits_total
events = [hit["_source"] for hit in hits]
# retrieve the events that match the last event timestamp
# but did not make it into the previous call due to batch_size limitation
es_req = {
"size": 10000,
"query": {
"bool": {
"must": [
{"term": {"task": task_id}},
{"term": {"timestamp": events[-1]["timestamp"]}},
]
}
},
}
es_result = search_company_events(
self.es,
company_id=company_id,
event_type=self.EVENT_TYPE,
body=es_req,
)
last_second_hits = es_result["hits"]["hits"]
if not last_second_hits or len(last_second_hits) < 2:
# if only one element is returned for the last timestamp
# then it is already present in the events
return events, hits_total
already_present_ids = set(hit["_id"] for hit in hits)
last_second_events = [
hit["_source"]
for hit in last_second_hits
if hit["_id"] not in already_present_ids
]
# return the list merged from original query results +
# leftovers from the last timestamp
return (
[*events, *last_second_events],
hits_total,
)

View File

@@ -4,6 +4,8 @@ Module for polymorphism over different types of X axes in scalar aggregations
from abc import ABC, abstractmethod
from enum import auto
from typing import Any
from apiserver.utilities import extract_properties_to_lists
from apiserver.utilities.stringenum import StringEnum
from apiserver.config_repo import config
@@ -96,6 +98,10 @@ class ScalarKey(ABC):
"""
return int(iter_data[self.bucket_key_key]), iter_data["avg_val"]["value"]
def cast_value(self, value: Any) -> Any:
"""Cast value to appropriate type"""
return value
class TimestampKey(ScalarKey):
"""
@@ -117,6 +123,9 @@ class TimestampKey(ScalarKey):
}
}
def cast_value(self, value: Any) -> int:
return int(value)
class IterKey(ScalarKey):
"""
@@ -134,6 +143,9 @@ class IterKey(ScalarKey):
}
}
def cast_value(self, value: Any) -> int:
return int(value)
class ISOTimeKey(ScalarKey):
"""

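The new cast_value hooks let the endpoints store the last key value as a string inside the scroll and restore its native type before passing it back as search_after. A minimal sketch, assuming ScalarKey.resolve maps ScalarKeyEnum.iter and ScalarKeyEnum.timestamp to IterKey and TimestampKey respectively, as the class names suggest:

```python
# Illustrative round trip for cast_value (assumes the resolve() mapping above).
from apiserver.bll.event.scalar_key import ScalarKey, ScalarKeyEnum

iter_key = ScalarKey.resolve(ScalarKeyEnum.iter)
assert iter_key.cast_value("1500") == 1500                    # back to int for search_after

ts_key = ScalarKey.resolve(ScalarKeyEnum.timestamp)
assert ts_key.cast_value("1644772800000") == 1644772800000   # milliseconds since epoch
```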
View File

@@ -17,6 +17,10 @@ events_retrieval {
# the max amount of variants to aggregate on
max_variants_count: 100
max_raw_scalars_size: 10000
scroll_id_key: "cTN5VEtWEC6QrHvUl0FTx9kNyO0CcCK1p57akxma"
}
# if set then plot str will be checked for the valid json on plot add

View File

@@ -1219,3 +1219,67 @@ get_scalar_metric_data {
}
}
}
scalar_metrics_iter_raw {
"999.0" {
description: "Get raw data for a specific metric variants in the task"
request {
type: object
required: [
task, metric
]
properties {
task {
type: string
description: "Task ID"
}
metric {
description: "Metric and variants for which to return data points"
"$ref": "#/definitions/metric_variants"
}
key {
description: """Array of x axis to return. Supported values:
iter - iteration number
timestamp - event timestamp as milliseconds since epoch
"""
"$ref": "#/definitions/scalar_key_enum"
}
batch_size {
description: "The number of data points to return for this call. Optional, the default value is 5000"
type: integer
default: 5000
}
count_total {
description: "Count the total number of data points. If false, total number of data points is not counted and null is returned"
type: boolean
default: false
}
scroll_id {
description: "Optional Scroll ID. Use to get more data points following a previous call"
type: string
}
}
}
response {
type: object
properties {
variants {
description: "Raw data points for each variant"
type: object
additionalProperties: true
}
total {
description: "Total data points count. If count_total is false, null is returned"
type: integer
}
returned {
description: "Number of data points returned in this call. If 0 results were returned, no more results are avilable"
type: integer
}
scroll_id {
description: "Scroll ID. Use to get more data points when calling this endpoint again"
type: string
}
}
}
}
}
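
For orientation, a hypothetical request/response pair for the new endpoint, written as Python dicts; the concrete values are invented, and the per-variant key name assumes str(ScalarKeyEnum.iter) renders as "iter".

```python
# Hypothetical events.scalar_metrics_iter_raw payloads (illustrative values only).
request_body = {
    "task": "<task-id>",
    "metric": {"metric": "loss", "variants": ["train"]},
    "key": "iter",
    "batch_size": 5000,
    "count_total": True,
}

response_body = {
    "variants": {
        "train": {"iter": [0, 1, 2], "y": [0.91, 0.77, 0.64]},
    },
    "total": 3,          # null when count_total is false
    "returned": 3,       # 0 means no more data points are available
    "scroll_id": "<opaque signed scroll id>",
}
```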

View File

@@ -1,9 +1,11 @@
import itertools
from collections import defaultdict
from operator import itemgetter
from typing import Sequence, Optional
import attr
from typing import Sequence, Optional
import jsonmodels.fields
from boltons.iterutils import bucketize
from apiserver.apierrors import errors
from apiserver.apimodels.events import (
@@ -20,12 +22,17 @@ from apiserver.apimodels.events import (
NextDebugImageSampleRequest,
MetricVariants as ApiMetrics,
TaskPlotsRequest,
TaskEventsRequest,
ScalarMetricsIterRawRequest,
)
from apiserver.bll.event import EventBLL
from apiserver.bll.event.event_common import EventType, MetricVariants
from apiserver.bll.event.events_iterator import Scroll
from apiserver.bll.event.scalar_key import ScalarKeyEnum, ScalarKey
from apiserver.bll.task import TaskBLL
from apiserver.config_repo import config
from apiserver.service_repo import APICall, endpoint
from apiserver.utilities import json
from apiserver.utilities import json, extract_properties_to_lists
task_bll = TaskBLL()
event_bll = EventBLL()
@@ -39,7 +46,6 @@ def add(call: APICall, company_id, _):
company_id, [data], call.worker, allow_locked_tasks=allow_locked
)
call.result.data = dict(added=added, errors=err_count, errors_info=err_info)
call.kpis["events"] = 1
@endpoint("events.add_batch")
@@ -50,7 +56,6 @@ def add_batch(call: APICall, company_id, _):
added, err_count, err_info = event_bll.add_events(company_id, events, call.worker)
call.result.data = dict(added=added, errors=err_count, errors_info=err_info)
call.kpis["events"] = len(events)
@endpoint("events.get_task_log", required_fields=["task"])
@@ -113,7 +118,8 @@ def get_task_log(call, company_id, request: LogEventsRequest):
company_id, task_id, allow_public=True, only=("company", "company_origin")
)[0]
res = event_bll.log_events_iterator.get_task_events(
res = event_bll.events_iterator.get_task_events(
event_type=EventType.task_log,
company_id=task.get_index_company(),
task_id=task_id,
batch_size=request.batch_size,
@@ -258,31 +264,94 @@ def vector_metrics_iter_histogram(call, company_id, _):
)
@endpoint("events.get_task_events", required_fields=["task"])
def get_task_events(call, company_id, _):
task_id = call.data["task"]
batch_size = call.data.get("batch_size", 500)
event_type = call.data.get("event_type")
scroll_id = call.data.get("scroll_id")
order = call.data.get("order") or "asc"
class GetTaskEventsScroll(Scroll):
from_key_value = jsonmodels.fields.StringField()
total = jsonmodels.fields.IntField()
request: TaskEventsRequest = jsonmodels.fields.EmbeddedField(TaskEventsRequest)
def make_response(
total: int, returned: int = 0, scroll_id: str = None, **kwargs
) -> dict:
return {
"returned": returned,
"total": total,
"scroll_id": scroll_id,
**kwargs,
}
@endpoint("events.get_task_events", request_data_model=TaskEventsRequest)
def get_task_events(call, company_id, request: TaskEventsRequest):
task_id = request.task
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company", "company_origin")
company_id, task_id, allow_public=True, only=("company",),
)[0]
result = event_bll.get_task_events(
task.get_index_company(),
task_id,
sort=[{"timestamp": {"order": order}}],
event_type=EventType(event_type) if event_type else EventType.all,
scroll_id=scroll_id,
size=batch_size,
key = ScalarKeyEnum.iter
scalar_key = ScalarKey.resolve(key)
if not request.scroll_id:
from_key_value = None if (request.order == LogOrderEnum.desc) else 0
total = None
else:
try:
scroll = GetTaskEventsScroll.from_scroll_id(request.scroll_id)
except ValueError:
raise errors.bad_request.InvalidScrollId(scroll_id=request.scroll_id)
if scroll.from_key_value is None:
return make_response(
scroll_id=request.scroll_id, total=scroll.total, events=[]
)
from_key_value = scalar_key.cast_value(scroll.from_key_value)
total = scroll.total
scroll.request.batch_size = request.batch_size or scroll.request.batch_size
request = scroll.request
navigate_earlier = request.order == LogOrderEnum.desc
metric_variants = _get_metric_variants_from_request(request.metrics)
if request.count_total and total is None:
total = event_bll.events_iterator.count_task_events(
event_type=request.event_type,
company_id=task.company,
task_id=task_id,
metric_variants=metric_variants,
)
batch_size = min(
request.batch_size,
int(
config.get("services.events.events_retrieval.max_raw_scalars_size", 10_000)
),
)
call.result.data = dict(
events=result.events,
returned=len(result.events),
total=result.total_events,
scroll_id=result.next_scroll_id,
res = event_bll.events_iterator.get_task_events(
event_type=request.event_type,
company_id=task.company,
task_id=task_id,
batch_size=batch_size,
key=ScalarKeyEnum.iter,
navigate_earlier=navigate_earlier,
from_key_value=from_key_value,
metric_variants=metric_variants,
)
scroll = GetTaskEventsScroll(
from_key_value=str(res.events[-1][scalar_key.field]) if res.events else None,
total=total,
request=request,
)
return make_response(
returned=len(res.events),
total=total,
scroll_id=scroll.get_scroll_id(),
events=res.events,
)
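
Seen from a caller, the reworked events.get_task_events pages with a key-based scroll instead of a native ES scroll. A rough, hypothetical client loop follows; the client object, the response-dict access and the process callback are assumptions, shown only to illustrate the scroll contract.

```python
# Hypothetical pagination loop over the reworked events.get_task_events.
def fetch_all_task_events(client, task_id: str, process) -> None:
    scroll_id = None
    while True:
        res = client.events.get_task_events(
            task=task_id, batch_size=500, order="asc", scroll_id=scroll_id,
        )
        if not res["returned"]:
            break                      # an exhausted scroll returns zero events
        process(res["events"])
        scroll_id = res["scroll_id"]   # carries the last key value and the cached total
```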
@@ -759,3 +828,95 @@ def _get_top_iter_unique_events(events, max_iters):
)
unique_events.sort(key=lambda e: e["iter"], reverse=True)
return unique_events
class ScalarMetricsIterRawScroll(Scroll):
from_key_value = jsonmodels.fields.StringField()
total = jsonmodels.fields.IntField()
request: ScalarMetricsIterRawRequest = jsonmodels.fields.EmbeddedField(
ScalarMetricsIterRawRequest
)
@endpoint("events.scalar_metrics_iter_raw", min_version="999.0")
def scalar_metrics_iter_raw(
call: APICall, company_id: str, request: ScalarMetricsIterRawRequest
):
key = request.key or ScalarKeyEnum.iter
scalar_key = ScalarKey.resolve(key)
if not request.scroll_id:
from_key_value = None
total = None
else:
try:
scroll = ScalarMetricsIterRawScroll.from_scroll_id(request.scroll_id)
except ValueError:
raise errors.bad_request.InvalidScrollId(scroll_id=request.scroll_id)
if scroll.from_key_value is None:
return make_response(
scroll_id=request.scroll_id, total=scroll.total, variants={}
)
from_key_value = scalar_key.cast_value(scroll.from_key_value)
total = scroll.total
scroll.request.batch_size = request.batch_size or scroll.request.batch_size
request = scroll.request
task_id = request.task
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company",),
)[0]
metric_variants = _get_metric_variants_from_request([request.metric])
if request.count_total and total is None:
total = event_bll.events_iterator.count_task_events(
event_type=EventType.metrics_scalar,
company_id=task.company,
task_id=task_id,
metric_variants=metric_variants,
)
batch_size = min(
request.batch_size,
int(
config.get("services.events.events_retrieval.max_raw_scalars_size", 10_000)
),
)
res = event_bll.events_iterator.get_task_events(
event_type=EventType.metrics_scalar,
company_id=task.company,
task_id=task_id,
batch_size=batch_size,
navigate_earlier=False,
from_key_value=from_key_value,
metric_variants=metric_variants,
key=key,
)
key = str(key)
variants = {
variant: extract_properties_to_lists(
["value", scalar_key.field], events, target_keys=["y", key]
)
for variant, events in bucketize(res.events, key=itemgetter("variant")).items()
}
scroll = ScalarMetricsIterRawScroll(
from_key_value=str(res.events[-1][scalar_key.field]) if res.events else None,
total=total,
request=request,
)
return make_response(
returned=len(res.events),
total=total,
scroll_id=scroll.get_scroll_id(),
variants=variants,
)
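
Finally, an illustration of the bucketize/extract_properties_to_lists step above; the exact behavior of extract_properties_to_lists is assumed here (regrouping the listed event properties into per-variant lists under the given target keys).

```python
# Sketch of the per-variant regrouping performed by scalar_metrics_iter_raw.
from operator import itemgetter
from boltons.iterutils import bucketize

events = [
    {"variant": "train", "iter": 0, "value": 0.91},
    {"variant": "train", "iter": 1, "value": 0.77},
    {"variant": "val",   "iter": 0, "value": 0.95},
]
by_variant = bucketize(events, key=itemgetter("variant"))
# -> {"train": [two events], "val": [one event]}
# extract_properties_to_lists(["value", "iter"], <events>, target_keys=["y", "iter"])
# is then assumed to yield, per variant:
# {"train": {"y": [0.91, 0.77], "iter": [0, 1]}, "val": {"y": [0.95], "iter": [0]}}
```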