Support model events

allegroai 2022-11-29 17:34:06 +02:00
parent fa5b28ca0e
commit c23e8a90d0
11 changed files with 609 additions and 220 deletions

View File

@@ -26,6 +26,7 @@ class MetricVariants(Base):
 class ScalarMetricsIterHistogramRequest(HistogramRequestBase):
     task: str = StringField(required=True)
     metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
+    model_events: bool = BoolField(default=False)


 class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
@@ -40,6 +41,7 @@ class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
             )
         ],
     )
+    model_events: bool = BoolField(default=False)


 class TaskMetric(Base):
@@ -56,6 +58,7 @@ class MetricEventsRequest(Base):
     navigate_earlier: bool = BoolField(default=True)
     refresh: bool = BoolField(default=False)
     scroll_id: str = StringField()
+    model_events: bool = BoolField()


 class TaskMetricVariant(Base):
@@ -69,12 +72,14 @@ class GetHistorySampleRequest(TaskMetricVariant):
     refresh: bool = BoolField(default=False)
     scroll_id: Optional[str] = StringField()
     navigate_current_metric: bool = BoolField(default=True)
+    model_events: bool = BoolField(default=False)


 class NextHistorySampleRequest(Base):
     task: str = StringField(required=True)
     scroll_id: Optional[str] = StringField()
     navigate_earlier: bool = BoolField(default=True)
+    model_events: bool = BoolField(default=False)


 class LogOrderEnum(StringEnum):
@@ -93,6 +98,7 @@ class TaskEventsRequest(TaskEventsRequestBase):
     order: Optional[str] = ActualEnumField(LogOrderEnum, default=LogOrderEnum.asc)
     scroll_id: str = StringField()
     count_total: bool = BoolField(default=True)
+    model_events: bool = BoolField(default=False)


 class LogEventsRequest(TaskEventsRequestBase):
@@ -108,6 +114,7 @@ class ScalarMetricsIterRawRequest(TaskEventsRequestBase):
     metric: MetricVariants = EmbeddedField(MetricVariants, required=True)
     count_total: bool = BoolField(default=False)
     scroll_id: str = StringField()
+    model_events: bool = BoolField(default=False)


 class IterationEvents(Base):
@@ -129,6 +136,7 @@ class MultiTasksRequestBase(Base):
     tasks: Sequence[str] = ListField(
         items_types=str, validators=[Length(minimum_value=1)]
     )
+    model_events: bool = BoolField(default=False)


 class SingleValueMetricsRequest(MultiTasksRequestBase):
@@ -145,6 +153,7 @@ class TaskPlotsRequest(Base):
     scroll_id: str = StringField()
     no_scroll: bool = BoolField(default=False)
     metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
+    model_events: bool = BoolField(default=False)


 class ClearScrollRequest(Base):
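Note: a minimal sketch of how the new flag surfaces on these jsonmodels-based request classes. The construction style and the placeholder ID are illustrative assumptions, not part of this commit.

# Hypothetical: build a histogram request that targets a model instead of a task.
req = ScalarMetricsIterHistogramRequest(task="<task-or-model-id>", model_events=True)
assert req.model_events is True  # BoolField(default=False) when omitted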

View File

@@ -31,6 +31,7 @@ from apiserver.bll.event.metric_debug_images_iterator import MetricDebugImagesIterator
 from apiserver.bll.event.metric_plots_iterator import MetricPlotsIterator
 from apiserver.bll.util import parallel_chunked_decorator
 from apiserver.database import utils as dbutils
+from apiserver.database.model.model import Model
 from apiserver.es_factory import es_factory
 from apiserver.apierrors import errors
 from apiserver.bll.event.event_metrics import EventMetrics
@@ -68,6 +69,15 @@ class EventBLL(object):
         r"['\"]source['\"]:\s?['\"]([a-z][a-z0-9+\-.]*://.*?)['\"]",
         flags=re.IGNORECASE,
     )
+    _task_event_query = {
+        "bool": {
+            "should": [
+                {"term": {"model_event": False}},
+                {"bool": {"must_not": [{"exists": {"field": "model_event"}}]}},
+            ]
+        }
+    }
+    _model_event_query = {"term": {"model_event": True}}

     def __init__(self, events_es=None, redis=None):
         self.es = events_es or es_factory.connect("events")
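Note: a sketch of how these class-level filters would compose into an Elasticsearch query. The must_not/exists branch presumably keeps events indexed before this change, which lack the model_event field, matching as task events. The helper name and the terms clause below are illustrative, not from this commit.

from typing import Sequence

TASK_EVENT_QUERY = {
    "bool": {
        "should": [
            {"term": {"model_event": False}},
            # pre-existing documents have no model_event field at all
            {"bool": {"must_not": [{"exists": {"field": "model_event"}}]}},
        ]
    }
}
MODEL_EVENT_QUERY = {"term": {"model_event": True}}

def events_scope_query(ids: Sequence[str], model_events: bool) -> dict:
    # Hypothetical helper: restrict an events search to tasks or to models.
    scope = MODEL_EVENT_QUERY if model_events else TASK_EVENT_QUERY
    return {"bool": {"must": [{"terms": {"task": list(ids)}}, scope]}}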
@@ -103,11 +113,35 @@ class EventBLL(object):
             res = Task.objects(query).only("id")
             return {r.id for r in res}

+    @staticmethod
+    def _get_valid_models(company_id, model_ids: Set, allow_locked_models=False) -> Set:
+        """Verify that the models exist and can be updated"""
+        if not model_ids:
+            return set()
+        with translate_errors_context():
+            query = Q(id__in=model_ids, company=company_id)
+            if not allow_locked_models:
+                query &= Q(ready__ne=True)
+            res = Model.objects(query).only("id")
+            return {r.id for r in res}
+
     def add_events(
-        self, company_id, events, worker, allow_locked_tasks=False
+        self, company_id, events, worker, allow_locked=False
     ) -> Tuple[int, int, dict]:
+        model_events = events[0].get("model_event", False)
+        for event in events:
+            if event.get("model_event", model_events) != model_events:
+                raise errors.bad_request.ValidationError(
+                    "Inconsistent model_event setting in the passed events"
+                )
+            if event.pop("allow_locked", allow_locked) != allow_locked:
+                raise errors.bad_request.ValidationError(
+                    "Inconsistent allow_locked setting in the passed events"
+                )
+
         actions: List[dict] = []
-        task_ids = set()
+        task_or_model_ids = set()
         task_iteration = defaultdict(lambda: 0)
         task_last_scalar_events = nested_dict(
             3, dict
@@ -117,13 +151,28 @@
         )  # task_id -> metric_hash -> event_type -> MetricEvent
         errors_per_type = defaultdict(int)
         invalid_iteration_error = f"Iteration number should not exceed {MAX_LONG}"
-        valid_tasks = self._get_valid_tasks(
-            company_id,
-            task_ids={
-                event["task"] for event in events if event.get("task") is not None
-            },
-            allow_locked_tasks=allow_locked_tasks,
-        )
+        if model_events:
+            for event in events:
+                model = event.pop("model", None)
+                if model is not None:
+                    event["task"] = model
+            valid_entities = self._get_valid_models(
+                company_id,
+                model_ids={
+                    event["task"] for event in events if event.get("task") is not None
+                },
+                allow_locked_models=allow_locked,
+            )
+            entity_name = "model"
+        else:
+            valid_entities = self._get_valid_tasks(
+                company_id,
+                task_ids={
+                    event["task"] for event in events if event.get("task") is not None
+                },
+                allow_locked_tasks=allow_locked,
+            )
+            entity_name = "task"

         for event in events:
             # remove spaces from event type
@@ -137,13 +186,17 @@
                 errors_per_type[f"Invalid event type {event_type}"] += 1
                 continue

-            task_id = event.get("task")
-            if task_id is None:
+            if model_events and event_type == EventType.task_log.value:
+                errors_per_type["Task log events are not supported for models"] += 1
+                continue
+
+            task_or_model_id = event.get("task")
+            if task_or_model_id is None:
                 errors_per_type["Event must have a 'task' field"] += 1
                 continue

-            if task_id not in valid_tasks:
-                errors_per_type["Invalid task id"] += 1
+            if task_or_model_id not in valid_entities:
+                errors_per_type[f"Invalid {entity_name} id {task_or_model_id}"] += 1
                 continue

             event["type"] = event_type
@@ -165,6 +218,9 @@
             # force iter to be a long int
             iter = event.get("iter")
             if iter is not None:
-                iter = int(iter)
-                if iter > MAX_LONG or iter < MIN_LONG:
-                    errors_per_type[invalid_iteration_error] += 1
+                if model_events:
+                    iter = 0
+                else:
+                    iter = int(iter)
+                    if iter > MAX_LONG or iter < MIN_LONG:
+                        errors_per_type[invalid_iteration_error] += 1
@@ -178,6 +234,7 @@
             event["metric"] = event.get("metric") or ""
             event["variant"] = event.get("variant") or ""
+            event["model_event"] = model_events

             index_name = get_index_name(company_id, event_type)
             es_action = {
@@ -192,19 +249,24 @@
             else:
                 es_action["_id"] = dbutils.id()

-            task_ids.add(task_id)
+            task_or_model_ids.add(task_or_model_id)
             if (
                 iter is not None
+                and not model_events
                 and event.get("metric") not in self._skip_iteration_for_metric
             ):
-                task_iteration[task_id] = max(iter, task_iteration[task_id])
+                task_iteration[task_or_model_id] = max(
+                    iter, task_iteration[task_or_model_id]
+                )

-            self._update_last_metric_events_for_task(
-                last_events=task_last_events[task_id], event=event,
-            )
-            if event_type == EventType.metrics_scalar.value:
-                self._update_last_scalar_events_for_task(
-                    last_events=task_last_scalar_events[task_id], event=event
-                )
+            if not model_events:
+                self._update_last_metric_events_for_task(
+                    last_events=task_last_events[task_or_model_id], event=event,
+                )
+                if event_type == EventType.metrics_scalar.value:
+                    self._update_last_scalar_events_for_task(
+                        last_events=task_last_scalar_events[task_or_model_id],
+                        event=event,
+                    )

             actions.append(es_action)
@@ -243,22 +305,25 @@
             else:
                 errors_per_type["Error when indexing events batch"] += 1

-        remaining_tasks = set()
-        now = datetime.utcnow()
-        for task_id in task_ids:
-            # Update related tasks. For reasons of performance, we prefer to update
-            # all of them and not only those whose events were successful
-            updated = self._update_task(
-                company_id=company_id,
-                task_id=task_id,
-                now=now,
-                iter_max=task_iteration.get(task_id),
-                last_scalar_events=task_last_scalar_events.get(task_id),
-                last_events=task_last_events.get(task_id),
-            )
-            if not updated:
-                remaining_tasks.add(task_id)
-                continue
+        if not model_events:
+            remaining_tasks = set()
+            now = datetime.utcnow()
+            for task_or_model_id in task_or_model_ids:
+                # Update related tasks. For reasons of performance, we prefer to update
+                # all of them and not only those whose events were successful
+                updated = self._update_task(
+                    company_id=company_id,
+                    task_id=task_or_model_id,
+                    now=now,
+                    iter_max=task_iteration.get(task_or_model_id),
+                    last_scalar_events=task_last_scalar_events.get(
+                        task_or_model_id
+                    ),
+                    last_events=task_last_events.get(task_or_model_id),
+                )
+                if not updated:
+                    remaining_tasks.add(task_or_model_id)
+                    continue

-        if remaining_tasks:
+            if remaining_tasks:
@@ -527,6 +592,7 @@
         scroll_id: str = None,
         no_scroll: bool = False,
         metric_variants: MetricVariants = None,
+        model_events: bool = False,
     ):
         if scroll_id == self.empty_scroll:
             return TaskEventsResult()
@@ -553,7 +619,7 @@
         }

         must = [plot_valid_condition]
-        if last_iterations_per_plot is None:
+        if last_iterations_per_plot is None or model_events:
             must.append({"terms": {"task": tasks}})
             if metric_variants:
                 must.append(get_metric_variants_condition(metric_variants))
@@ -709,6 +775,7 @@
         size=500,
         scroll_id=None,
         no_scroll=False,
+        model_events=False,
     ) -> TaskEventsResult:
         if scroll_id == self.empty_scroll:
             return TaskEventsResult()
@@ -728,7 +795,7 @@
         if variant:
             must.append({"term": {"variant": variant}})

-        if last_iter_count is None:
+        if last_iter_count is None or model_events:
             must.append({"terms": {"task": task_ids}})
         else:
             tasks_iters = self.get_last_iters(
@@ -989,6 +1056,21 @@
             for tb in es_res["aggregations"]["tasks"]["buckets"]
         }

+    @staticmethod
+    def _validate_model_state(
+        company_id: str, model_id: str, allow_locked: bool = False
+    ):
+        extra_msg = None
+        query = Q(id=model_id, company=company_id)
+        if not allow_locked:
+            query &= Q(ready__ne=True)
+            extra_msg = "or model published"
+        res = Model.objects(query).only("id").first()
+        if not res:
+            raise errors.bad_request.InvalidModelId(
+                extra_msg, company=company_id, id=model_id
+            )
+
     @staticmethod
     def _validate_task_state(company_id: str, task_id: str, allow_locked: bool = False):
         extra_msg = None
@@ -1002,7 +1084,12 @@
             extra_msg, company=company_id, id=task_id
         )

-    def delete_task_events(self, company_id, task_id, allow_locked=False):
-        self._validate_task_state(
-            company_id=company_id, task_id=task_id, allow_locked=allow_locked
-        )
+    def delete_task_events(self, company_id, task_id, allow_locked=False, model=False):
+        if model:
+            self._validate_model_state(
+                company_id=company_id, model_id=task_id, allow_locked=allow_locked,
+            )
+        else:
+            self._validate_task_state(
+                company_id=company_id, task_id=task_id, allow_locked=allow_locked
+            )
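Note: a hedged sketch of a batch that add_events() above would accept for a model: every event must agree on model_event, allow_locked is popped from each event and must match the call-level setting, the model field is rewritten to task internally, and iterations are forced to 0. The ID is a placeholder.

events = [
    {
        "model": "<model-id>",            # rewritten to event["task"] internally
        "model_event": True,              # must be identical across the whole batch
        "allow_locked": False,            # popped and checked for consistency
        "type": "training_stats_scalar",
        "metric": "accuracy",
        "variant": "val",
        "value": 0.93,
        "iter": 7,                        # ignored: forced to 0 for model events
    },
]
# added, errors, errors_info = event_bll.add_events(company_id, events, worker="")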

View File

@@ -8,9 +8,7 @@ from typing import Sequence, Tuple, Mapping

 from boltons.iterutils import bucketize
 from elasticsearch import Elasticsearch
-from mongoengine import Q

-from apiserver.apierrors import errors
 from apiserver.bll.event.event_common import (
     EventType,
     EventSettings,
@@ -111,40 +109,19 @@ class EventMetrics:
     def compare_scalar_metrics_average_per_iter(
         self,
         company_id,
-        task_ids: Sequence[str],
+        tasks: Sequence[Task],
         samples,
         key: ScalarKeyEnum,
-        allow_public=True,
     ):
         """
         Compare scalar metrics for different tasks per metric and variant
         The amount of points in each histogram should not exceed the requested samples
         """
-        task_name_by_id = {}
-        with translate_errors_context():
-            task_objs = Task.get_many(
-                company=company_id,
-                query=Q(id__in=task_ids),
-                allow_public=allow_public,
-                override_projection=("id", "name", "company", "company_origin"),
-                return_dicts=False,
-            )
-            if len(task_objs) < len(task_ids):
-                invalid = tuple(set(task_ids) - set(r.id for r in task_objs))
-                raise errors.bad_request.InvalidTaskId(company=company_id, ids=invalid)
-            task_name_by_id = {t.id: t.name for t in task_objs}
-
-        companies = {t.get_index_company() for t in task_objs}
-        if len(companies) > 1:
-            raise errors.bad_request.InvalidTaskId(
-                "only tasks from the same company are supported"
-            )
-
         event_type = EventType.metrics_scalar
-        company_id = next(iter(companies))
         if check_empty_data(self.es, company_id=company_id, event_type=event_type):
             return {}

+        task_name_by_id = {t.id: t.name for t in tasks}
         get_scalar_average_per_iter = partial(
             self._get_scalar_average_per_iter_core,
             company_id=company_id,
@@ -153,6 +130,7 @@ class EventMetrics:
             key=ScalarKey.resolve(key),
             run_parallel=False,
         )
+        task_ids = [t.id for t in tasks]
         with ThreadPoolExecutor(max_workers=EventSettings.max_workers) as pool:
             task_metrics = zip(
                 task_ids, pool.map(get_scalar_average_per_iter, task_ids)
@@ -169,7 +147,7 @@ class EventMetrics:
         return res

     def get_task_single_value_metrics(
-        self, company_id: str, task_ids: Sequence[str]
+        self, company_id: str, tasks: Sequence[Task]
     ) -> Mapping[str, dict]:
         """
         For the requested tasks return all the events delivered for the single iteration (-2**31)
@@ -179,6 +157,7 @@ class EventMetrics:
         ):
             return {}

+        task_ids = [t.id for t in tasks]
         task_events = self._get_task_single_value_metrics(company_id, task_ids)

         def _get_value(event: dict):

View File

@@ -1,5 +1,7 @@
 from datetime import datetime
-from typing import Callable, Tuple, Sequence, Dict
+from typing import Callable, Tuple, Sequence, Dict, Optional
+
+from mongoengine import Q

 from apiserver.apierrors import errors
 from apiserver.apimodels.models import ModelTaskPublishResponse
@@ -24,6 +26,33 @@ class ModelBLL:
             raise errors.bad_request.InvalidModelId(**query)
         return model

+    @staticmethod
+    def assert_exists(
+        company_id,
+        model_ids,
+        only=None,
+        allow_public=False,
+        return_models=True,
+    ) -> Optional[Sequence[Model]]:
+        model_ids = [model_ids] if isinstance(model_ids, str) else model_ids
+        ids = set(model_ids)
+        query = Q(id__in=ids)
+        q = Model.get_many(
+            company=company_id,
+            query=query,
+            allow_public=allow_public,
+            return_dicts=False,
+        )
+        if only:
+            q = q.only(*only)
+        if q.count() != len(ids):
+            raise errors.bad_request.InvalidModelId(ids=model_ids)
+        if return_models:
+            return list(q)
+
     @classmethod
     def publish_model(
         cls,
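Note: a short usage sketch for the helper above; the IDs are placeholders.

# Fetch and validate models in one call; raises InvalidModelId if any ID is unknown.
models = ModelBLL.assert_exists(
    "<company-id>",
    ["<model-id>"],
    only=("id", "name", "company", "company_origin"),
    allow_public=True,
)
# With return_models=False the call only validates existence and returns None.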

View File

@@ -92,3 +92,6 @@ class Model(AttributedDocument):
     metadata = SafeMapField(
         field=EmbeddedDocumentField(MetadataItem), user_set_allowed=True
     )
+
+    def get_index_company(self) -> str:
+        return self.company or self.company_origin or ""

View File

@@ -35,6 +35,12 @@
         },
         "value": {
             "type": "float"
-        }
+        },
+        "company_id": {
+            "type": "keyword"
+        },
+        "model_event": {
+            "type": "boolean"
+        }
       }
     }
 }

View File

@@ -968,5 +968,5 @@ class PrePopulate:
             ev["task"] = task_id
             ev["company_id"] = company_id
         cls.event_bll.add_events(
-            company_id, events=events, worker="", allow_locked_tasks=True
+            company_id, events=events, worker="", allow_locked=True
         )

View File

@@ -405,13 +405,27 @@ add {
             additionalProperties: true
         }
     }
+    "999.0": ${add."2.1"} {
+        request.properties {
+            model_event {
+                type: boolean
+                description: If set then the event is for a model. Otherwise for a task. Cannot be used with task log events. If used in a batch then all the events should be marked the same
+                default: false
+            }
+            allow_locked {
+                type: boolean
+                description: Allow adding events to published tasks or models
+                default: false
+            }
+        }
+    }
 }
 add_batch {
     "2.1" {
         description: "Adds a batch of events in a single call (json-lines format, stream-friendly)"
         batch_request: {
             action: add
-            version: 1.5
+            version: 2.1
         }
         response {
             type: object
@@ -422,10 +436,16 @@ add_batch {
             }
         }
     }
+    "999.0": ${add_batch."2.1"} {
+        batch_request: {
+            action: add
+            version: 999.0
+        }
+    }
 }
 delete_for_task {
     "2.1" {
-        description: "Delete all task event. *This cannot be undone!*"
+        description: "Delete all task events. *This cannot be undone!*"
         request {
             type: object
             required: [
@@ -454,6 +474,37 @@ delete_for_task {
             }
         }
     }
 }
+delete_for_model {
+    "999.0" {
+        description: "Delete all model events. *This cannot be undone!*"
+        request {
+            type: object
+            required: [
+                model
+            ]
+            properties {
+                model {
+                    type: string
+                    description: "Model ID"
+                }
+                allow_locked {
+                    type: boolean
+                    description: "Allow deleting events even if the model is locked"
+                    default: false
+                }
+            }
+        }
+        response {
+            type: object
+            properties {
+                deleted {
+                    type: boolean
+                    description: "Number of deleted events"
+                }
+            }
+        }
+    }
+}
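Note: a hedged example call for the new endpoint, assuming a local apiserver; the URL, credentials, and ID are placeholders.

import requests

resp = requests.post(
    "http://localhost:8008/events.delete_for_model",
    json={"model": "<model-id>", "allow_locked": False},
    auth=("<access-key>", "<secret-key>"),
)
print(resp.json())  # response carries the "deleted" field described above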
 debug_images {
     "2.1" {
         description: "Get all debug images of a task"
@@ -548,6 +599,13 @@ debug_images {
             }
         }
     }
+    "999.0": ${debug_images."2.14"} {
+        request.properties.model_events {
+            type: boolean
+            description: If set then retrieve model events. Otherwise task events
+            default: false
+        }
+    }
 }
 plots {
     "2.20" {
@@ -583,6 +641,13 @@ plots {
         }
         response {"$ref": "#/definitions/plots_response"}
     }
+    "999.0": ${plots."2.20"} {
+        request.properties.model_events {
+            type: boolean
+            description: If set then retrieve model plots. Otherwise task plots
+            default: false
+        }
+    }
 }
 get_debug_image_sample {
     "2.12": {
@@ -626,6 +691,13 @@ get_debug_image_sample {
                 default: true
             }
         }
+    "999.0": ${get_debug_image_sample."2.20"} {
+        request.properties.model_events {
+            type: boolean
+            description: If set then retrieve model debug images. Otherwise task debug images
+            default: false
+        }
+    }
 }
 next_debug_image_sample {
     "2.12": {
@@ -651,6 +723,18 @@ next_debug_image_sample {
         }
         response {"$ref": "#/definitions/debug_image_sample_response"}
     }
+    "999.0": ${next_debug_image_sample."2.12"} {
+        request.properties.next_iteration {
+            type: boolean
+            default: false
+            description: If set then navigate to the next/previous iteration
+        }
+        request.properties.model_events {
+            type: boolean
+            description: If set then retrieve model debug images. Otherwise task debug images
+            default: false
+        }
+    }
 }
 get_plot_sample {
     "2.20": {
@@ -692,6 +776,13 @@ get_plot_sample {
             }
         }
         response {"$ref": "#/definitions/plot_sample_response"}
     }
+    "999.0": ${get_plot_sample."2.20"} {
+        request.properties.model_events {
+            type: boolean
+            description: If set then retrieve model plots. Otherwise task plots
+            default: false
+        }
+    }
 }
 next_plot_sample {
     "2.20": {
@@ -717,6 +808,18 @@ next_plot_sample {
         }
         response {"$ref": "#/definitions/plot_sample_response"}
     }
+    "999.0": ${next_plot_sample."2.20"} {
+        request.properties.next_iteration {
+            type: boolean
+            default: false
+            description: If set then navigate to the next/previous iteration
+        }
+        request.properties.model_events {
+            type: boolean
+            description: If set then retrieve model plots. Otherwise task plots
+            default: false
+        }
+    }
 }
 get_task_metrics {
     "2.7": {
@@ -749,6 +852,13 @@ get_task_metrics {
             }
         }
     }
+    "999.0": ${get_task_metrics."2.7"} {
+        request.properties.model_events {
+            type: boolean
+            description: If set then get metrics from model events. Otherwise from task events
+            default: false
+        }
+    }
 }
 get_task_log {
     "1.5" {
@@ -966,6 +1076,13 @@ get_task_events {
             }
         }
     }
+    "999.0": ${get_task_events."2.1"} {
+        request.properties.model_events {
+            type: boolean
+            description: If set then retrieve model events. Otherwise task events
+            default: false
+        }
+    }
 }

 download_task_log {
@@ -1067,6 +1184,13 @@ get_task_plots {
                 default: false
             }
         }
+    "999.0": ${get_task_plots."2.16"} {
+        request.properties.model_events {
+            type: boolean
+            description: If set then retrieve model events. Otherwise task events
+            default: false
+        }
+    }
 }
 get_multi_task_plots {
     "2.1" {
@@ -1124,6 +1248,13 @@ get_multi_task_plots {
                 default: false
             }
         }
+    "999.0": ${get_multi_task_plots."2.16"} {
+        request.properties.model_events {
+            type: boolean
+            description: If set then retrieve model events. Otherwise task events
+            default: false
+        }
+    }
 }
 get_vector_metrics_and_variants {
     "2.1" {
@@ -1152,6 +1283,13 @@ get_vector_metrics_and_variants {
             }
         }
     }
+    "999.0": ${get_vector_metrics_and_variants."2.1"} {
+        request.properties.model_events {
+            type: boolean
+            description: If set then retrieve model events. Otherwise task events
+            default: false
+        }
+    }
 }
 vector_metrics_iter_histogram {
     "2.1" {
@@ -1190,6 +1328,13 @@ vector_metrics_iter_histogram {
             }
         }
     }
+    "999.0": ${vector_metrics_iter_histogram."2.1"} {
+        request.properties.model_events {
+            type: boolean
+            description: If set then retrieve model events. Otherwise task events
+            default: false
+        }
+    }
 }
 scalar_metrics_iter_histogram {
     "2.1" {
@@ -1243,6 +1388,13 @@ scalar_metrics_iter_histogram {
             }
         }
     }
+    "999.0": ${scalar_metrics_iter_histogram."2.14"} {
+        request.properties.model_events {
+            type: boolean
+            description: If set then retrieve model events. Otherwise task events
+            default: false
+        }
+    }
 }
 multi_task_scalar_metrics_iter_histogram {
     "2.1" {
@@ -1282,6 +1434,13 @@ multi_task_scalar_metrics_iter_histogram {
             additionalProperties: true
         }
     }
+    "999.0": ${multi_task_scalar_metrics_iter_histogram."2.1"} {
+        request.properties.model_events {
+            type: boolean
+            description: If set then retrieve model events. Otherwise task events
+            default: false
+        }
+    }
 }
 get_task_single_value_metrics {
     "2.20" {
@@ -1331,6 +1490,13 @@ get_task_single_value_metrics {
             }
         }
     }
+    "999.0": ${get_task_single_value_metrics."2.20"} {
+        request.properties.model_events {
+            type: boolean
+            description: If set then retrieve model events. Otherwise task events
+            default: false
+        }
+    }
 }
 get_task_latest_scalar_values {
     "2.1" {
@@ -1410,6 +1576,13 @@ get_scalar_metrics_and_variants {
             }
         }
     }
+    "999.0": ${get_scalar_metrics_and_variants."2.1"} {
+        request.properties.model_events {
+            type: boolean
+            description: If set then retrieve model events. Otherwise task events
+            default: false
+        }
+    }
 }
 get_scalar_metric_data {
     "2.1" {
@@ -1459,6 +1632,13 @@ get_scalar_metric_data {
                 default: false
             }
         }
+    "999.0": ${get_scalar_metric_data."2.16"} {
+        request.properties.model_events {
+            type: boolean
+            description: If set then retrieve model events. Otherwise task events
+            default: false
+        }
+    }
 }
 scalar_metrics_iter_raw {
     "2.16" {
@@ -1523,6 +1703,13 @@ scalar_metrics_iter_raw {
             }
         }
     }
+    "999.0": ${scalar_metrics_iter_raw."2.16"} {
+        request.properties.model_events {
+            type: boolean
+            description: If set then retrieve model events. Otherwise task events
+            default: false
+        }
+    }
 }
 clear_scroll {
     "2.18" {

View File

@@ -2,7 +2,7 @@ import itertools
 import math
 from collections import defaultdict
 from operator import itemgetter
-from typing import Sequence, Optional
+from typing import Sequence, Optional, Union, Tuple

 import attr
 import jsonmodels.fields
@@ -33,13 +33,36 @@ from apiserver.bll.event import EventBLL
 from apiserver.bll.event.event_common import EventType, MetricVariants
 from apiserver.bll.event.events_iterator import Scroll
 from apiserver.bll.event.scalar_key import ScalarKeyEnum, ScalarKey
+from apiserver.bll.model import ModelBLL
 from apiserver.bll.task import TaskBLL
 from apiserver.config_repo import config
+from apiserver.database.model.model import Model
+from apiserver.database.model.task.task import Task
 from apiserver.service_repo import APICall, endpoint
 from apiserver.utilities import json, extract_properties_to_lists

 task_bll = TaskBLL()
 event_bll = EventBLL()
+model_bll = ModelBLL()
+
+
+def _assert_task_or_model_exists(
+    company_id: str, task_ids: Union[str, Sequence[str]], model_events: bool
+) -> Union[Sequence[Model], Sequence[Task]]:
+    if model_events:
+        return model_bll.assert_exists(
+            company_id,
+            task_ids,
+            allow_public=True,
+            only=("id", "name", "company", "company_origin"),
+        )
+
+    return task_bll.assert_exists(
+        company_id,
+        task_ids,
+        allow_public=True,
+        only=("id", "name", "company", "company_origin"),
+    )


 @endpoint("events.add")
@@ -47,7 +70,7 @@ def add(call: APICall, company_id, _):
     data = call.data.copy()
     allow_locked = data.pop("allow_locked", False)
     added, err_count, err_info = event_bll.add_events(
-        company_id, [data], call.worker, allow_locked_tasks=allow_locked
+        company_id, [data], call.worker, allow_locked=allow_locked
     )
     call.result.data = dict(added=added, errors=err_count, errors_info=err_info)

@@ -58,7 +81,12 @@ def add_batch(call: APICall, company_id, _):
     if events is None or len(events) == 0:
         raise errors.bad_request.BatchContainsNoItems()

-    added, err_count, err_info = event_bll.add_events(company_id, events, call.worker)
+    added, err_count, err_info = event_bll.add_events(
+        company_id,
+        events,
+        call.worker,
+        allow_locked=events[0].get("allow_locked", False),
+    )
     call.result.data = dict(added=added, errors=err_count, errors_info=err_info)
@@ -225,12 +253,13 @@ def download_task_log(call, company_id, _):
 @endpoint("events.get_vector_metrics_and_variants", required_fields=["task"])
 def get_vector_metrics_and_variants(call, company_id, _):
     task_id = call.data["task"]
-    task = task_bll.assert_exists(
-        company_id, task_id, allow_public=True, only=("company", "company_origin")
+    model_events = call.data["model_events"]
+    task_or_model = _assert_task_or_model_exists(
+        company_id, task_id, model_events=model_events,
     )[0]
     call.result.data = dict(
         metrics=event_bll.get_metrics_and_variants(
-            task.get_index_company(), task_id, EventType.metrics_vector
+            task_or_model.get_index_company(), task_id, EventType.metrics_vector
         )
     )

@@ -238,12 +267,13 @@ def get_vector_metrics_and_variants(call, company_id, _):
 @endpoint("events.get_scalar_metrics_and_variants", required_fields=["task"])
 def get_scalar_metrics_and_variants(call, company_id, _):
     task_id = call.data["task"]
-    task = task_bll.assert_exists(
-        company_id, task_id, allow_public=True, only=("company", "company_origin")
+    model_events = call.data["model_events"]
+    task_or_model = _assert_task_or_model_exists(
+        company_id, task_id, model_events=model_events,
     )[0]
     call.result.data = dict(
         metrics=event_bll.get_metrics_and_variants(
-            task.get_index_company(), task_id, EventType.metrics_scalar
+            task_or_model.get_index_company(), task_id, EventType.metrics_scalar
         )
     )

@@ -255,13 +285,14 @@ def get_scalar_metrics_and_variants(call, company_id, _):
 )
 def vector_metrics_iter_histogram(call, company_id, _):
     task_id = call.data["task"]
-    task = task_bll.assert_exists(
-        company_id, task_id, allow_public=True, only=("company", "company_origin")
+    model_events = call.data["model_events"]
+    task_or_model = _assert_task_or_model_exists(
+        company_id, task_id, model_events=model_events,
     )[0]
     metric = call.data["metric"]
     variant = call.data["variant"]
     iterations, vectors = event_bll.get_vector_metrics_per_iter(
-        task.get_index_company(), task_id, metric, variant
+        task_or_model.get_index_company(), task_id, metric, variant
     )
     call.result.data = dict(
         metric=metric, variant=variant, vectors=vectors, iterations=iterations
@@ -286,11 +317,10 @@ def make_response(
 @endpoint("events.get_task_events", request_data_model=TaskEventsRequest)
-def get_task_events(call, company_id, request: TaskEventsRequest):
+def get_task_events(_, company_id, request: TaskEventsRequest):
     task_id = request.task
-
-    task = task_bll.assert_exists(
-        company_id, task_id, allow_public=True, only=("company",),
+    task_or_model = _assert_task_or_model_exists(
+        company_id, task_id, model_events=request.model_events,
     )[0]

     key = ScalarKeyEnum.iter
@@ -322,7 +352,7 @@ def get_task_events(call, company_id, request: TaskEventsRequest):
     if request.count_total and total is None:
         total = event_bll.events_iterator.count_task_events(
             event_type=request.event_type,
-            company_id=task.company,
+            company_id=task_or_model.get_index_company(),
             task_id=task_id,
             metric_variants=metric_variants,
         )
@@ -336,7 +366,7 @@ def get_task_events(call, company_id, request: TaskEventsRequest):
     res = event_bll.events_iterator.get_task_events(
         event_type=request.event_type,
-        company_id=task.company,
+        company_id=task_or_model.get_index_company(),
         task_id=task_id,
         batch_size=batch_size,
         key=ScalarKeyEnum.iter,
@@ -365,18 +395,20 @@ def get_scalar_metric_data(call, company_id, _):
     metric = call.data["metric"]
     scroll_id = call.data.get("scroll_id")
     no_scroll = call.data.get("no_scroll", False)
+    model_events = call.data.get("model_events", False)

-    task = task_bll.assert_exists(
-        company_id, task_id, allow_public=True, only=("company", "company_origin")
+    task_or_model = _assert_task_or_model_exists(
+        company_id, task_id, model_events=model_events,
     )[0]
     result = event_bll.get_task_events(
-        task.get_index_company(),
+        task_or_model.get_index_company(),
         task_id,
         event_type=EventType.metrics_scalar,
         sort=[{"iter": {"order": "desc"}}],
         metric=metric,
         scroll_id=scroll_id,
         no_scroll=no_scroll,
+        model_events=model_events,
     )

     call.result.data = dict(
@@ -398,7 +430,7 @@ def get_task_latest_scalar_values(call, company_id, _):
         index_company, task_id
     )
     last_iters = event_bll.get_last_iters(
-        company_id=company_id, event_type=EventType.all, task_id=task_id, iters=1
+        company_id=index_company, event_type=EventType.all, task_id=task_id, iters=1
    ).get(task_id)
     call.result.data = dict(
         metrics=metrics,
@@ -417,11 +449,11 @@
 def scalar_metrics_iter_histogram(
     call, company_id, request: ScalarMetricsIterHistogramRequest
 ):
-    task = task_bll.assert_exists(
-        company_id, request.task, allow_public=True, only=("company", "company_origin")
+    task_or_model = _assert_task_or_model_exists(
+        company_id, request.task, model_events=request.model_events
     )[0]
     metrics = event_bll.metrics.get_scalar_metrics_average_per_iter(
-        task.get_index_company(),
+        company_id=task_or_model.get_index_company(),
         task_id=request.task,
         samples=request.samples,
         key=request.key,
@@ -429,24 +461,55 @@
     call.result.data = metrics


+def _get_task_or_model_index_company(
+    company_id: str, task_ids: Sequence[str], model_events=False,
+) -> Tuple[str, Sequence[Task]]:
+    """
+    Verify that all the tasks or models exist and store their data in the same company index.
+    Return the index company and the task or model objects
+    """
+    tasks_or_models = _assert_task_or_model_exists(
+        company_id, task_ids, model_events=model_events,
+    )
+
+    unique_ids = set(task_ids)
+    if len(tasks_or_models) < len(unique_ids):
+        invalid = tuple(unique_ids - {t.id for t in tasks_or_models})
+        error_cls = (
+            errors.bad_request.InvalidModelId
+            if model_events
+            else errors.bad_request.InvalidTaskId
+        )
+        raise error_cls(company=company_id, ids=invalid)
+
+    companies = {t.get_index_company() for t in tasks_or_models}
+    if len(companies) > 1:
+        raise errors.bad_request.InvalidTaskId(
+            "only tasks from the same company are supported"
+        )
+
+    return companies.pop(), tasks_or_models
+
+
 @endpoint(
     "events.multi_task_scalar_metrics_iter_histogram",
     request_data_model=MultiTaskScalarMetricsIterHistogramRequest,
 )
 def multi_task_scalar_metrics_iter_histogram(
-    call, company_id, req_model: MultiTaskScalarMetricsIterHistogramRequest
+    call, company_id, request: MultiTaskScalarMetricsIterHistogramRequest
 ):
-    task_ids = req_model.tasks
+    task_ids = request.tasks
     if isinstance(task_ids, str):
         task_ids = [s.strip() for s in task_ids.split(",")]
-    # Note, bll already validates task ids as it needs their names
+    company, tasks_or_models = _get_task_or_model_index_company(
+        company_id, task_ids, request.model_events
+    )
     call.result.data = dict(
         metrics=event_bll.metrics.compare_scalar_metrics_average_per_iter(
-            company_id,
-            task_ids=task_ids,
-            samples=req_model.samples,
-            allow_public=True,
-            key=req_model.key,
+            company_id=company,
+            tasks=tasks_or_models,
+            samples=request.samples,
+            key=request.key,
         )
     )
@@ -455,21 +518,11 @@
 def get_task_single_value_metrics(
     call, company_id: str, request: SingleValueMetricsRequest
 ):
-    task_ids = call.data["tasks"]
-    tasks = task_bll.assert_exists(
-        company_id=call.identity.company,
-        only=("id", "name", "company", "company_origin"),
-        task_ids=task_ids,
-        allow_public=True,
+    company, tasks_or_models = _get_task_or_model_index_company(
+        company_id, request.tasks, request.model_events
     )

-    companies = {t.get_index_company() for t in tasks}
-    if len(companies) > 1:
-        raise errors.bad_request.InvalidTaskId(
-            "only tasks from the same company are supported"
-        )
-
-    res = event_bll.metrics.get_task_single_value_metrics(company_id, task_ids)
+    res = event_bll.metrics.get_task_single_value_metrics(company, tasks_or_models)
     call.result.data = dict(
         tasks=[{"task": task, "values": values} for task, values in res.items()]
     )

@@ -481,22 +534,11 @@ def get_multi_task_plots_v1_7(call, company_id, _):
     iters = call.data.get("iters", 1)
     scroll_id = call.data.get("scroll_id")

-    tasks = task_bll.assert_exists(
-        company_id=company_id,
-        only=("id", "name", "company", "company_origin"),
-        task_ids=task_ids,
-        allow_public=True,
-    )
-
-    companies = {t.get_index_company() for t in tasks}
-    if len(companies) > 1:
-        raise errors.bad_request.InvalidTaskId(
-            "only tasks from the same company are supported"
-        )
+    company, tasks_or_models = _get_task_or_model_index_company(company_id, task_ids)

     # Get last 10K events by iteration and group them by unique metric+variant, returning top events for combination
     result = event_bll.get_task_events(
-        next(iter(companies)),
+        company,
         task_ids,
         event_type=EventType.metrics_plot,
         sort=[{"iter": {"order": "desc"}}],
@@ -504,7 +546,7 @@ def get_multi_task_plots_v1_7(call, company_id, _):
         scroll_id=scroll_id,
     )

-    tasks = {t.id: t.name for t in tasks}
+    tasks = {t.id: t.name for t in tasks_or_models}

     return_events = _get_top_iter_unique_events_per_task(
         result.events, max_iters=iters, tasks=tasks
@@ -519,36 +561,29 @@
 @endpoint("events.get_multi_task_plots", min_version="1.8", required_fields=["tasks"])
-def get_multi_task_plots(call, company_id, req_model):
+def get_multi_task_plots(call, company_id, _):
     task_ids = call.data["tasks"]
     iters = call.data.get("iters", 1)
     scroll_id = call.data.get("scroll_id")
     no_scroll = call.data.get("no_scroll", False)
+    model_events = call.data.get("model_events", False)

-    tasks = task_bll.assert_exists(
-        company_id=call.identity.company,
-        only=("id", "name", "company", "company_origin"),
-        task_ids=task_ids,
-        allow_public=True,
-    )
-
-    companies = {t.get_index_company() for t in tasks}
-    if len(companies) > 1:
-        raise errors.bad_request.InvalidTaskId(
-            "only tasks from the same company are supported"
-        )
+    company, tasks_or_models = _get_task_or_model_index_company(
+        company_id, task_ids, model_events
+    )

     result = event_bll.get_task_events(
-        next(iter(companies)),
+        company,
         task_ids,
         event_type=EventType.metrics_plot,
         sort=[{"iter": {"order": "desc"}}],
         last_iter_count=iters,
         scroll_id=scroll_id,
         no_scroll=no_scroll,
+        model_events=model_events,
     )

-    tasks = {t.id: t.name for t in tasks}
+    tasks = {t.id: t.name for t in tasks_or_models}

     return_events = _get_top_iter_unique_events_per_task(
         result.events, max_iters=iters, tasks=tasks
@@ -615,17 +650,18 @@ def get_task_plots(call, company_id, request: TaskPlotsRequest):
     iters = request.iters
     scroll_id = request.scroll_id

-    task = task_bll.assert_exists(
-        company_id, task_id, allow_public=True, only=("company", "company_origin")
+    task_or_model = _assert_task_or_model_exists(
+        company_id, task_id, model_events=request.model_events
     )[0]
     result = event_bll.get_task_plots(
-        task.get_index_company(),
+        task_or_model.get_index_company(),
         tasks=[task_id],
         sort=[{"iter": {"order": "desc"}}],
         last_iterations_per_plot=iters,
         scroll_id=scroll_id,
         no_scroll=request.no_scroll,
         metric_variants=_get_metric_variants_from_request(request.metrics),
+        model_events=request.model_events,
     )

     return_events = result.events
@@ -651,21 +687,11 @@ def task_plots(call, company_id, request: MetricEventsRequest):
     if None in metrics:
         metrics.clear()

-    tasks = task_bll.assert_exists(
-        company_id,
-        task_ids=list(task_metrics),
-        allow_public=True,
-        only=("company", "company_origin"),
-    )
-
-    companies = {t.get_index_company() for t in tasks}
-    if len(companies) > 1:
-        raise errors.bad_request.InvalidTaskId(
-            "only tasks from the same company are supported"
-        )
+    company, _ = _get_task_or_model_index_company(
+        company_id, task_ids=list(task_metrics), model_events=request.model_events
+    )

     result = event_bll.plots_iterator.get_task_events(
-        company_id=next(iter(companies)),
+        company_id=company,
         task_metrics=task_metrics,
         iter_count=request.iters,
         navigate_earlier=request.navigate_earlier,
@@ -730,17 +756,19 @@ def get_debug_images_v1_8(call, company_id, _):
     task_id = call.data["task"]
     iters = call.data.get("iters") or 1
     scroll_id = call.data.get("scroll_id")
+    model_events = call.data.get("model_events", False)

-    task = task_bll.assert_exists(
-        company_id, task_id, allow_public=True, only=("company", "company_origin")
+    task_or_model = _assert_task_or_model_exists(
+        company_id, task_id, model_events=model_events,
     )[0]
     result = event_bll.get_task_events(
-        task.get_index_company(),
+        task_or_model.get_index_company(),
         task_id,
         event_type=EventType.metrics_image,
         sort=[{"iter": {"order": "desc"}}],
         last_iter_count=iters,
         scroll_id=scroll_id,
+        model_events=model_events,
     )

     return_events = result.events
@@ -768,21 +796,12 @@ def get_debug_images(call, company_id, request: MetricEventsRequest):
     if None in metrics:
         metrics.clear()

-    tasks = task_bll.assert_exists(
-        company_id,
-        task_ids=list(task_metrics),
-        allow_public=True,
-        only=("company", "company_origin"),
-    )
-
-    companies = {t.get_index_company() for t in tasks}
-    if len(companies) > 1:
-        raise errors.bad_request.InvalidTaskId(
-            "only tasks from the same company are supported"
-        )
+    company, _ = _get_task_or_model_index_company(
+        company_id, task_ids=list(task_metrics), model_events=request.model_events
+    )

     result = event_bll.debug_images_iterator.get_task_events(
-        company_id=next(iter(companies)),
+        company_id=company,
         task_metrics=task_metrics,
         iter_count=request.iters,
         navigate_earlier=request.navigate_earlier,
@@ -811,11 +830,11 @@ def get_debug_images(call, company_id, request: MetricEventsRequest):
     request_data_model=GetHistorySampleRequest,
 )
 def get_debug_image_sample(call, company_id, request: GetHistorySampleRequest):
-    task = task_bll.assert_exists(
-        company_id, task_ids=[request.task], allow_public=True, only=("company",)
+    task_or_model = _assert_task_or_model_exists(
+        company_id, request.task, model_events=request.model_events,
     )[0]
     res = event_bll.debug_image_sample_history.get_sample_for_variant(
-        company_id=task.company,
+        company_id=task_or_model.get_index_company(),
         task=request.task,
         metric=request.metric,
         variant=request.variant,
@@ -833,11 +852,11 @@ def get_debug_image_sample(call, company_id, request: GetHistorySampleRequest):
     request_data_model=NextHistorySampleRequest,
 )
 def next_debug_image_sample(call, company_id, request: NextHistorySampleRequest):
-    task = task_bll.assert_exists(
-        company_id, task_ids=[request.task], allow_public=True, only=("company",)
+    task_or_model = _assert_task_or_model_exists(
+        company_id, request.task, model_events=request.model_events,
     )[0]
     res = event_bll.debug_image_sample_history.get_next_sample(
-        company_id=task.company,
+        company_id=task_or_model.get_index_company(),
         task=request.task,
         state_id=request.scroll_id,
         navigate_earlier=request.navigate_earlier,
@@ -849,11 +868,11 @@ def next_debug_image_sample(call, company_id, request: NextHistorySampleRequest):
     "events.get_plot_sample", request_data_model=GetHistorySampleRequest,
 )
 def get_plot_sample(call, company_id, request: GetHistorySampleRequest):
-    task = task_bll.assert_exists(
-        company_id, task_ids=[request.task], allow_public=True, only=("company",)
+    task_or_model = _assert_task_or_model_exists(
+        company_id, request.task, model_events=request.model_events,
     )[0]
     res = event_bll.plot_sample_history.get_sample_for_variant(
-        company_id=task.company,
+        company_id=task_or_model.get_index_company(),
         task=request.task,
         metric=request.metric,
         variant=request.variant,
@@ -869,11 +888,11 @@ def get_plot_sample(call, company_id, request: GetHistorySampleRequest):
     "events.next_plot_sample", request_data_model=NextHistorySampleRequest,
 )
 def next_plot_sample(call, company_id, request: NextHistorySampleRequest):
-    task = task_bll.assert_exists(
-        company_id, task_ids=[request.task], allow_public=True, only=("company",)
+    task_or_model = _assert_task_or_model_exists(
+        company_id, request.task, model_events=request.model_events,
     )[0]
     res = event_bll.plot_sample_history.get_next_sample(
-        company_id=task.company,
+        company_id=task_or_model.get_index_company(),
         task=request.task,
         state_id=request.scroll_id,
         navigate_earlier=request.navigate_earlier,
@@ -883,14 +902,11 @@ def next_plot_sample(call, company_id, request: NextHistorySampleRequest):
 @endpoint("events.get_task_metrics", request_data_model=TaskMetricsRequest)
 def get_task_metrics(call: APICall, company_id, request: TaskMetricsRequest):
-    task = task_bll.assert_exists(
-        company_id,
-        task_ids=request.tasks,
-        allow_public=True,
-        only=("company", "company_origin"),
-    )[0]
+    company, _ = _get_task_or_model_index_company(
+        company_id, request.tasks, model_events=request.model_events
+    )
     res = event_bll.metrics.get_task_metrics(
-        task.get_index_company(), task_ids=request.tasks, event_type=request.event_type
+        company, task_ids=request.tasks, event_type=request.event_type
     )
     call.result.data = {
         "metrics": [{"task": task, "metrics": metrics} for (task, metrics) in res]
@@ -898,7 +914,7 @@ def get_task_metrics(call: APICall, company_id, request: TaskMetricsRequest):
 @endpoint("events.delete_for_task", required_fields=["task"])
-def delete_for_task(call, company_id, req_model):
+def delete_for_task(call, company_id, _):
     task_id = call.data["task"]
     allow_locked = call.data.get("allow_locked", False)
@@ -910,6 +926,19 @@ def delete_for_task(call, company_id, req_model):
     )


+@endpoint("events.delete_for_model", required_fields=["model"])
+def delete_for_model(call: APICall, company_id: str, _):
+    model_id = call.data["model"]
+    allow_locked = call.data.get("allow_locked", False)
+
+    model_bll.assert_exists(company_id, model_id, return_models=False)
+    call.result.data = dict(
+        deleted=event_bll.delete_task_events(
+            company_id, model_id, allow_locked=allow_locked, model=True
+        )
+    )
+
+
 @endpoint("events.clear_task_log")
 def clear_task_log(call: APICall, company_id: str, request: ClearTaskLogRequest):
     task_id = request.task
@@ -1004,17 +1033,13 @@
         request.batch_size = request.batch_size or scroll.request.batch_size

     task_id = request.task
-
-    task = task_bll.assert_exists(
-        company_id, task_id, allow_public=True, only=("company",),
+    task_or_model = _assert_task_or_model_exists(
+        company_id, task_id, model_events=request.model_events
     )[0]

     metric_variants = _get_metric_variants_from_request([request.metric])

     if request.count_total and total is None:
         total = event_bll.events_iterator.count_task_events(
             event_type=EventType.metrics_scalar,
-            company_id=task.company,
+            company_id=task_or_model.get_index_company(),
             task_id=task_id,
             metric_variants=metric_variants,
         )
@@ -1030,7 +1055,7 @@
     for iteration in range(0, math.ceil(batch_size / 10_000)):
         res = event_bll.events_iterator.get_task_events(
             event_type=EventType.metrics_scalar,
-            company_id=task.company,
+            company_id=task_or_model.get_index_company(),
             task_id=task_id,
             batch_size=min(batch_size, 10_000),
             navigate_earlier=False,

View File

@@ -6,17 +6,24 @@ from typing import Sequence, Optional, Tuple

 from boltons.iterutils import first

+from apiserver.apierrors import errors
 from apiserver.es_factory import es_factory
 from apiserver.apierrors.errors.bad_request import EventsNotAdded
 from apiserver.tests.automated import TestService


 class TestTaskEvents(TestService):
+    delete_params = dict(can_fail=True, force=True)
+
     def _temp_task(self, name="test task events"):
         task_input = dict(
             name=name, type="training", input=dict(mapping={}, view=dict(entries=[])),
         )
-        return self.create_temp("tasks", **task_input)
+        return self.create_temp("tasks", delete_params=self.delete_params, **task_input)
+
+    def _temp_model(self, name="test model events", **kwargs):
+        self.update_missing(kwargs, name=name, uri="file:///a/b", labels={})
+        return self.create_temp("models", delete_params=self.delete_params, **kwargs)

     @staticmethod
     def _create_task_event(type_, task, iteration, **kwargs):
@@ -172,6 +179,42 @@ class TestTaskEvents(TestService):
         self.assertEqual(iter_count - 1, metric_data.max_value)
         self.assertEqual(0, metric_data.min_value)

+    def test_model_events(self):
+        model = self._temp_model(ready=False)
+
+        # task log events are not allowed
+        log_event = self._create_task_event(
+            "log",
+            task=model,
+            iteration=0,
+            msg="This is a log message",
+            model_event=True,
+        )
+        with self.api.raises(errors.bad_request.EventsNotAdded):
+            self.send(log_event)
+
+        # send metric events and check that model data always has iteration 0 and only the last data is saved
+        events = [
+            {
+                **self._create_task_event("training_stats_scalar", model, iteration),
+                "metric": f"Metric{metric_idx}",
+                "variant": f"Variant{variant_idx}",
+                "value": iteration,
+                "model_event": True,
+            }
+            for iteration in range(2)
+            for metric_idx in range(5)
+            for variant_idx in range(5)
+        ]
+        self.send_batch(events)
+
+        data = self.api.events.scalar_metrics_iter_histogram(
+            task=model, model_events=True
+        )
+        self.assertEqual(list(data), [f"Metric{idx}" for idx in range(5)])
+        metric_data = data.Metric0
+        self.assertEqual(list(metric_data), [f"Variant{idx}" for idx in range(5)])
+        variant_data = metric_data.Variant0
+        self.assertEqual(variant_data.x, [0])
+        self.assertEqual(variant_data.y, [1.0])
+
     def test_error_events(self):
         task = self._temp_task()
         events = [
@@ -555,7 +598,8 @@ class TestTaskEvents(TestService):
         return data

     def send(self, event):
-        self.api.send("events.add", event)
+        _, data = self.api.send("events.add", event)
+        return data


 if __name__ == "__main__":

View File

@@ -50,10 +50,12 @@ class TestTasksResetDelete(TestService):
         self.assertEqual(res.urls.artifact_urls, [])

         task = self.new_task()
-        published_model_urls, draft_model_urls = self.create_task_models(task)
+        (_, published_model_urls), (model, draft_model_urls) = self.create_task_models(task)
         artifact_urls = self.send_artifacts(task)
         event_urls = self.send_debug_image_events(task)
         event_urls.update(self.send_plot_events(task))
+        event_urls.update(self.send_model_events(model))
         res = self.assert_delete_task(task, force=True, return_file_urls=True)
         self.assertEqual(set(res.urls.model_urls), draft_model_urls)
         self.assertEqual(set(res.urls.event_urls), event_urls)
@@ -120,10 +122,12 @@ class TestTasksResetDelete(TestService):
         self, **kwargs
     ) -> Tuple[str, Tuple[Set[str], Set[str]], Set[str], Set[str]]:
         task = self.new_task(**kwargs)
-        published_model_urls, draft_model_urls = self.create_task_models(task, **kwargs)
+        (_, published_model_urls), (model, draft_model_urls) = self.create_task_models(task, **kwargs)
         artifact_urls = self.send_artifacts(task)
         event_urls = self.send_debug_image_events(task)
         event_urls.update(self.send_plot_events(task))
+        event_urls.update(self.send_model_events(model))
         return task, (published_model_urls, draft_model_urls), artifact_urls, event_urls

     def assert_delete_task(self, task_id, force=False, return_file_urls=False):
@@ -137,15 +141,17 @@ class TestTasksResetDelete(TestService):
         self.assertEqual(tasks, [])
         return res

-    def create_task_models(self, task, **kwargs) -> Tuple[Set[str], Set[str]]:
+    def create_task_models(self, task, **kwargs) -> Tuple:
         """
         Update models from task and return only non public models
         """
-        model_ready = self.new_model(uri="ready", **kwargs)
-        model_not_ready = self.new_model(uri="not_ready", ready=False, **kwargs)
+        ready_uri = "ready"
+        not_ready_uri = "not_ready"
+        model_ready = self.new_model(uri=ready_uri, **kwargs)
+        model_not_ready = self.new_model(uri=not_ready_uri, ready=False, **kwargs)
         self.api.models.edit(model=model_not_ready, task=task)
         self.api.models.edit(model=model_ready, task=task)
-        return {"ready"}, {"not_ready"}
+        return (model_ready, {ready_uri}), (model_not_ready, {not_ready_uri})

     def send_artifacts(self, task) -> Set[str]:
         """
@@ -159,6 +165,20 @@ class TestTasksResetDelete(TestService):
         self.api.tasks.add_or_update_artifacts(task=task, artifacts=artifacts)
         return {"test2"}

+    def send_model_events(self, model) -> Set[str]:
+        url1 = "http://link1"
+        url2 = "http://link2"
+        events = [
+            self.create_event(
+                model, "training_debug_image", 0, url=url1, model_event=True
+            ),
+            self.create_event(
+                model, "plot", 0, plot_str=f'{{"source": "{url2}"}}', model_event=True
+            ),
+        ]
+        self.send_batch(events)
+        return {url1, url2}
+
     def send_debug_image_events(self, task) -> Set[str]:
         events = [
             self.create_event(