mirror of
https://github.com/clearml/clearml-server
synced 2025-06-26 23:15:47 +00:00
Support model events
This commit is contained in:
parent
fa5b28ca0e
commit
c23e8a90d0
@ -26,6 +26,7 @@ class MetricVariants(Base):
|
||||
class ScalarMetricsIterHistogramRequest(HistogramRequestBase):
    """
    Histogram request for a single entity, extending HistogramRequestBase.

    NOTE(review): presumably the request model of the
    events.scalar_metrics_iter_histogram endpoint - confirm at the endpoint definition.
    """

    # Mandatory ID of the entity whose events are requested
    task: str = StringField(required=True)
    # Optional list of MetricVariants selectors
    metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
    # Off by default. NOTE(review): presumably switches the lookup from task events
    # to model events, with `task` then carrying a model ID - confirm against the endpoint
    model_events: bool = BoolField(default=False)
|
||||
|
||||
|
||||
class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
|
||||
@ -40,6 +41,7 @@ class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
|
||||
)
|
||||
],
|
||||
)
|
||||
model_events: bool = BoolField(default=False)
|
||||
|
||||
|
||||
class TaskMetric(Base):
|
||||
@ -56,6 +58,7 @@ class MetricEventsRequest(Base):
|
||||
navigate_earlier: bool = BoolField(default=True)
|
||||
refresh: bool = BoolField(default=False)
|
||||
scroll_id: str = StringField()
|
||||
model_events: bool = BoolField()
|
||||
|
||||
|
||||
class TaskMetricVariant(Base):
|
||||
@ -69,12 +72,14 @@ class GetHistorySampleRequest(TaskMetricVariant):
|
||||
refresh: bool = BoolField(default=False)
|
||||
scroll_id: Optional[str] = StringField()
|
||||
navigate_current_metric: bool = BoolField(default=True)
|
||||
model_events: bool = BoolField(default=False)
|
||||
|
||||
|
||||
class NextHistorySampleRequest(Base):
    """
    Request model for stepping to an adjacent history sample.

    NOTE(review): presumably shared by the next_debug_image_sample and
    next_plot_sample endpoints - confirm at the endpoint definitions.
    """

    # Mandatory ID of the entity whose samples are navigated
    task: str = StringField(required=True)
    # Optional scroll identifier; not required by this model
    scroll_id: Optional[str] = StringField()
    # Navigation direction flag, on by default
    navigate_earlier: bool = BoolField(default=True)
    # Off by default. NOTE(review): presumably selects model events instead of
    # task events, with `task` then carrying a model ID - confirm against the endpoint
    model_events: bool = BoolField(default=False)
|
||||
|
||||
|
||||
class LogOrderEnum(StringEnum):
|
||||
@ -93,6 +98,7 @@ class TaskEventsRequest(TaskEventsRequestBase):
|
||||
order: Optional[str] = ActualEnumField(LogOrderEnum, default=LogOrderEnum.asc)
|
||||
scroll_id: str = StringField()
|
||||
count_total: bool = BoolField(default=True)
|
||||
model_events: bool = BoolField(default=False)
|
||||
|
||||
|
||||
class LogEventsRequest(TaskEventsRequestBase):
|
||||
@ -108,6 +114,7 @@ class ScalarMetricsIterRawRequest(TaskEventsRequestBase):
|
||||
metric: MetricVariants = EmbeddedField(MetricVariants, required=True)
|
||||
count_total: bool = BoolField(default=False)
|
||||
scroll_id: str = StringField()
|
||||
model_events: bool = BoolField(default=False)
|
||||
|
||||
|
||||
class IterationEvents(Base):
|
||||
@ -129,6 +136,7 @@ class MultiTasksRequestBase(Base):
|
||||
tasks: Sequence[str] = ListField(
|
||||
items_types=str, validators=[Length(minimum_value=1)]
|
||||
)
|
||||
model_events: bool = BoolField(default=False)
|
||||
|
||||
|
||||
class SingleValueMetricsRequest(MultiTasksRequestBase):
|
||||
@ -145,6 +153,7 @@ class TaskPlotsRequest(Base):
|
||||
scroll_id: str = StringField()
|
||||
no_scroll: bool = BoolField(default=False)
|
||||
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
|
||||
model_events: bool = BoolField(default=False)
|
||||
|
||||
|
||||
class ClearScrollRequest(Base):
|
||||
|
@ -31,6 +31,7 @@ from apiserver.bll.event.metric_debug_images_iterator import MetricDebugImagesIt
|
||||
from apiserver.bll.event.metric_plots_iterator import MetricPlotsIterator
|
||||
from apiserver.bll.util import parallel_chunked_decorator
|
||||
from apiserver.database import utils as dbutils
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.es_factory import es_factory
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.bll.event.event_metrics import EventMetrics
|
||||
@ -68,6 +69,15 @@ class EventBLL(object):
|
||||
r"['\"]source['\"]:\s?['\"]([a-z][a-z0-9+\-.]*://.*?)['\"]",
|
||||
flags=re.IGNORECASE,
|
||||
)
|
||||
_task_event_query = {
|
||||
"bool": {
|
||||
"should": [
|
||||
{"term": {"model_event": False}},
|
||||
{"bool": {"must_not": [{"exists": {"field": "model_event"}}]}},
|
||||
]
|
||||
}
|
||||
}
|
||||
_model_event_query = {"term": {"model_event": True}}
|
||||
|
||||
def __init__(self, events_es=None, redis=None):
|
||||
self.es = events_es or es_factory.connect("events")
|
||||
@ -103,11 +113,35 @@ class EventBLL(object):
|
||||
res = Task.objects(query).only("id")
|
||||
return {r.id for r in res}
|
||||
|
||||
@staticmethod
def _get_valid_models(company_id, model_ids: Set, allow_locked_models=False) -> Set:
    """
    Return the subset of the passed model ids that belong to the company
    and may receive events.

    Models whose `ready` flag is set are excluded unless allow_locked_models
    is passed. An empty or missing id collection yields an empty set without
    touching the database.
    """
    if not model_ids:
        return set()

    with translate_errors_context():
        conditions = Q(id__in=model_ids, company=company_id)
        if not allow_locked_models:
            # only models that are not marked as ready can be updated
            conditions &= Q(ready__ne=True)
        return {model.id for model in Model.objects(conditions).only("id")}
|
||||
|
||||
def add_events(
|
||||
self, company_id, events, worker, allow_locked_tasks=False
|
||||
self, company_id, events, worker, allow_locked=False
|
||||
) -> Tuple[int, int, dict]:
|
||||
model_events = events[0].get("model_event", False)
|
||||
for event in events:
|
||||
if event.get("model_event", model_events) != model_events:
|
||||
raise errors.bad_request.ValidationError(
|
||||
"Inconsistent model_event setting in the passed events"
|
||||
)
|
||||
if event.pop("allow_locked", allow_locked) != allow_locked:
|
||||
raise errors.bad_request.ValidationError(
|
||||
"Inconsistent allow_locked setting in the passed events"
|
||||
)
|
||||
|
||||
actions: List[dict] = []
|
||||
task_ids = set()
|
||||
task_or_model_ids = set()
|
||||
task_iteration = defaultdict(lambda: 0)
|
||||
task_last_scalar_events = nested_dict(
|
||||
3, dict
|
||||
@ -117,13 +151,28 @@ class EventBLL(object):
|
||||
) # task_id -> metric_hash -> event_type -> MetricEvent
|
||||
errors_per_type = defaultdict(int)
|
||||
invalid_iteration_error = f"Iteration number should not exceed {MAX_LONG}"
|
||||
valid_tasks = self._get_valid_tasks(
|
||||
company_id,
|
||||
task_ids={
|
||||
event["task"] for event in events if event.get("task") is not None
|
||||
},
|
||||
allow_locked_tasks=allow_locked_tasks,
|
||||
)
|
||||
if model_events:
|
||||
for event in events:
|
||||
model = event.pop("model", None)
|
||||
if model is not None:
|
||||
event["task"] = model
|
||||
valid_entities = self._get_valid_models(
|
||||
company_id,
|
||||
model_ids={
|
||||
event["task"] for event in events if event.get("task") is not None
|
||||
},
|
||||
allow_locked_models=allow_locked,
|
||||
)
|
||||
entity_name = "model"
|
||||
else:
|
||||
valid_entities = self._get_valid_tasks(
|
||||
company_id,
|
||||
task_ids={
|
||||
event["task"] for event in events if event.get("task") is not None
|
||||
},
|
||||
allow_locked_tasks=allow_locked,
|
||||
)
|
||||
entity_name = "task"
|
||||
|
||||
for event in events:
|
||||
# remove spaces from event type
|
||||
@ -137,13 +186,17 @@ class EventBLL(object):
|
||||
errors_per_type[f"Invalid event type {event_type}"] += 1
|
||||
continue
|
||||
|
||||
task_id = event.get("task")
|
||||
if task_id is None:
|
||||
if model_events and event_type == EventType.task_log.value:
|
||||
errors_per_type[f"Task log events are not supported for models"] += 1
|
||||
continue
|
||||
|
||||
task_or_model_id = event.get("task")
|
||||
if task_or_model_id is None:
|
||||
errors_per_type["Event must have a 'task' field"] += 1
|
||||
continue
|
||||
|
||||
if task_id not in valid_tasks:
|
||||
errors_per_type["Invalid task id"] += 1
|
||||
if task_or_model_id not in valid_entities:
|
||||
errors_per_type[f"Invalid {entity_name} id {task_or_model_id}"] += 1
|
||||
continue
|
||||
|
||||
event["type"] = event_type
|
||||
@ -165,10 +218,13 @@ class EventBLL(object):
|
||||
# force iter to be a long int
|
||||
iter = event.get("iter")
|
||||
if iter is not None:
|
||||
iter = int(iter)
|
||||
if iter > MAX_LONG or iter < MIN_LONG:
|
||||
errors_per_type[invalid_iteration_error] += 1
|
||||
continue
|
||||
if model_events:
|
||||
iter = 0
|
||||
else:
|
||||
iter = int(iter)
|
||||
if iter > MAX_LONG or iter < MIN_LONG:
|
||||
errors_per_type[invalid_iteration_error] += 1
|
||||
continue
|
||||
event["iter"] = iter
|
||||
|
||||
# used to have "values" to indicate array. no need anymore
|
||||
@ -178,6 +234,7 @@ class EventBLL(object):
|
||||
|
||||
event["metric"] = event.get("metric") or ""
|
||||
event["variant"] = event.get("variant") or ""
|
||||
event["model_event"] = model_events
|
||||
|
||||
index_name = get_index_name(company_id, event_type)
|
||||
es_action = {
|
||||
@ -192,21 +249,26 @@ class EventBLL(object):
|
||||
else:
|
||||
es_action["_id"] = dbutils.id()
|
||||
|
||||
task_ids.add(task_id)
|
||||
task_or_model_ids.add(task_or_model_id)
|
||||
if (
|
||||
iter is not None
|
||||
and not model_events
|
||||
and event.get("metric") not in self._skip_iteration_for_metric
|
||||
):
|
||||
task_iteration[task_id] = max(iter, task_iteration[task_id])
|
||||
|
||||
self._update_last_metric_events_for_task(
|
||||
last_events=task_last_events[task_id], event=event,
|
||||
)
|
||||
if event_type == EventType.metrics_scalar.value:
|
||||
self._update_last_scalar_events_for_task(
|
||||
last_events=task_last_scalar_events[task_id], event=event
|
||||
task_iteration[task_or_model_id] = max(
|
||||
iter, task_iteration[task_or_model_id]
|
||||
)
|
||||
|
||||
if not model_events:
|
||||
self._update_last_metric_events_for_task(
|
||||
last_events=task_last_events[task_or_model_id], event=event,
|
||||
)
|
||||
if event_type == EventType.metrics_scalar.value:
|
||||
self._update_last_scalar_events_for_task(
|
||||
last_events=task_last_scalar_events[task_or_model_id],
|
||||
event=event,
|
||||
)
|
||||
|
||||
actions.append(es_action)
|
||||
|
||||
plot_actions = [
|
||||
@ -243,28 +305,31 @@ class EventBLL(object):
|
||||
else:
|
||||
errors_per_type["Error when indexing events batch"] += 1
|
||||
|
||||
remaining_tasks = set()
|
||||
now = datetime.utcnow()
|
||||
for task_id in task_ids:
|
||||
# Update related tasks. For reasons of performance, we prefer to update
|
||||
# all of them and not only those who's events were successful
|
||||
updated = self._update_task(
|
||||
company_id=company_id,
|
||||
task_id=task_id,
|
||||
now=now,
|
||||
iter_max=task_iteration.get(task_id),
|
||||
last_scalar_events=task_last_scalar_events.get(task_id),
|
||||
last_events=task_last_events.get(task_id),
|
||||
)
|
||||
if not model_events:
|
||||
remaining_tasks = set()
|
||||
now = datetime.utcnow()
|
||||
for task_or_model_id in task_or_model_ids:
|
||||
# Update related tasks. For reasons of performance, we prefer to update
|
||||
# all of them and not only those who's events were successful
|
||||
updated = self._update_task(
|
||||
company_id=company_id,
|
||||
task_id=task_or_model_id,
|
||||
now=now,
|
||||
iter_max=task_iteration.get(task_or_model_id),
|
||||
last_scalar_events=task_last_scalar_events.get(
|
||||
task_or_model_id
|
||||
),
|
||||
last_events=task_last_events.get(task_or_model_id),
|
||||
)
|
||||
|
||||
if not updated:
|
||||
remaining_tasks.add(task_id)
|
||||
continue
|
||||
if not updated:
|
||||
remaining_tasks.add(task_or_model_id)
|
||||
continue
|
||||
|
||||
if remaining_tasks:
|
||||
TaskBLL.set_last_update(
|
||||
remaining_tasks, company_id, last_update=now
|
||||
)
|
||||
if remaining_tasks:
|
||||
TaskBLL.set_last_update(
|
||||
remaining_tasks, company_id, last_update=now
|
||||
)
|
||||
|
||||
# this is for backwards compatibility with streaming bulk throwing exception on those
|
||||
invalid_iterations_count = errors_per_type.get(invalid_iteration_error)
|
||||
@ -527,6 +592,7 @@ class EventBLL(object):
|
||||
scroll_id: str = None,
|
||||
no_scroll: bool = False,
|
||||
metric_variants: MetricVariants = None,
|
||||
model_events: bool = False,
|
||||
):
|
||||
if scroll_id == self.empty_scroll:
|
||||
return TaskEventsResult()
|
||||
@ -553,7 +619,7 @@ class EventBLL(object):
|
||||
}
|
||||
must = [plot_valid_condition]
|
||||
|
||||
if last_iterations_per_plot is None:
|
||||
if last_iterations_per_plot is None or model_events:
|
||||
must.append({"terms": {"task": tasks}})
|
||||
if metric_variants:
|
||||
must.append(get_metric_variants_condition(metric_variants))
|
||||
@ -709,6 +775,7 @@ class EventBLL(object):
|
||||
size=500,
|
||||
scroll_id=None,
|
||||
no_scroll=False,
|
||||
model_events=False,
|
||||
) -> TaskEventsResult:
|
||||
if scroll_id == self.empty_scroll:
|
||||
return TaskEventsResult()
|
||||
@ -728,7 +795,7 @@ class EventBLL(object):
|
||||
if variant:
|
||||
must.append({"term": {"variant": variant}})
|
||||
|
||||
if last_iter_count is None:
|
||||
if last_iter_count is None or model_events:
|
||||
must.append({"terms": {"task": task_ids}})
|
||||
else:
|
||||
tasks_iters = self.get_last_iters(
|
||||
@ -989,6 +1056,21 @@ class EventBLL(object):
|
||||
for tb in es_res["aggregations"]["tasks"]["buckets"]
|
||||
}
|
||||
|
||||
@staticmethod
def _validate_model_state(
    company_id: str, model_id: str, allow_locked: bool = False
):
    """
    Make sure that the model exists in the company and can be modified.

    Unless allow_locked is set, a model whose `ready` flag is set is
    rejected as well.

    Raises errors.bad_request.InvalidModelId if no matching model is found.
    """
    conditions = Q(id=model_id, company=company_id)
    extra_msg = None
    if not allow_locked:
        conditions &= Q(ready__ne=True)
        # the lookup may now also fail because the model is published
        extra_msg = "or model published"

    if Model.objects(conditions).only("id").first():
        return

    raise errors.bad_request.InvalidModelId(
        extra_msg, company=company_id, id=model_id
    )
|
||||
|
||||
@staticmethod
|
||||
def _validate_task_state(company_id: str, task_id: str, allow_locked: bool = False):
|
||||
extra_msg = None
|
||||
@ -1002,10 +1084,15 @@ class EventBLL(object):
|
||||
extra_msg, company=company_id, id=task_id
|
||||
)
|
||||
|
||||
def delete_task_events(self, company_id, task_id, allow_locked=False):
|
||||
self._validate_task_state(
|
||||
company_id=company_id, task_id=task_id, allow_locked=allow_locked
|
||||
)
|
||||
def delete_task_events(self, company_id, task_id, allow_locked=False, model=False):
|
||||
if model:
|
||||
self._validate_model_state(
|
||||
company_id=company_id, model_id=task_id, allow_locked=allow_locked,
|
||||
)
|
||||
else:
|
||||
self._validate_task_state(
|
||||
company_id=company_id, task_id=task_id, allow_locked=allow_locked
|
||||
)
|
||||
|
||||
es_req = {"query": {"term": {"task": task_id}}}
|
||||
with translate_errors_context():
|
||||
|
@ -8,9 +8,7 @@ from typing import Sequence, Tuple, Mapping
|
||||
|
||||
from boltons.iterutils import bucketize
|
||||
from elasticsearch import Elasticsearch
|
||||
from mongoengine import Q
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.bll.event.event_common import (
|
||||
EventType,
|
||||
EventSettings,
|
||||
@ -111,40 +109,19 @@ class EventMetrics:
|
||||
def compare_scalar_metrics_average_per_iter(
|
||||
self,
|
||||
company_id,
|
||||
task_ids: Sequence[str],
|
||||
tasks: Sequence[Task],
|
||||
samples,
|
||||
key: ScalarKeyEnum,
|
||||
allow_public=True,
|
||||
):
|
||||
"""
|
||||
Compare scalar metrics for different tasks per metric and variant
|
||||
The amount of points in each histogram should not exceed the requested samples
|
||||
"""
|
||||
task_name_by_id = {}
|
||||
with translate_errors_context():
|
||||
task_objs = Task.get_many(
|
||||
company=company_id,
|
||||
query=Q(id__in=task_ids),
|
||||
allow_public=allow_public,
|
||||
override_projection=("id", "name", "company", "company_origin"),
|
||||
return_dicts=False,
|
||||
)
|
||||
if len(task_objs) < len(task_ids):
|
||||
invalid = tuple(set(task_ids) - set(r.id for r in task_objs))
|
||||
raise errors.bad_request.InvalidTaskId(company=company_id, ids=invalid)
|
||||
task_name_by_id = {t.id: t.name for t in task_objs}
|
||||
|
||||
companies = {t.get_index_company() for t in task_objs}
|
||||
if len(companies) > 1:
|
||||
raise errors.bad_request.InvalidTaskId(
|
||||
"only tasks from the same company are supported"
|
||||
)
|
||||
|
||||
event_type = EventType.metrics_scalar
|
||||
company_id = next(iter(companies))
|
||||
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
|
||||
return {}
|
||||
|
||||
task_name_by_id = {t.id: t.name for t in tasks}
|
||||
get_scalar_average_per_iter = partial(
|
||||
self._get_scalar_average_per_iter_core,
|
||||
company_id=company_id,
|
||||
@ -153,6 +130,7 @@ class EventMetrics:
|
||||
key=ScalarKey.resolve(key),
|
||||
run_parallel=False,
|
||||
)
|
||||
task_ids = [t.id for t in tasks]
|
||||
with ThreadPoolExecutor(max_workers=EventSettings.max_workers) as pool:
|
||||
task_metrics = zip(
|
||||
task_ids, pool.map(get_scalar_average_per_iter, task_ids)
|
||||
@ -169,7 +147,7 @@ class EventMetrics:
|
||||
return res
|
||||
|
||||
def get_task_single_value_metrics(
|
||||
self, company_id: str, task_ids: Sequence[str]
|
||||
self, company_id: str, tasks: Sequence[Task]
|
||||
) -> Mapping[str, dict]:
|
||||
"""
|
||||
For the requested tasks return all the events delivered for the single iteration (-2**31)
|
||||
@ -179,6 +157,7 @@ class EventMetrics:
|
||||
):
|
||||
return {}
|
||||
|
||||
task_ids = [t.id for t in tasks]
|
||||
task_events = self._get_task_single_value_metrics(company_id, task_ids)
|
||||
|
||||
def _get_value(event: dict):
|
||||
|
@ -1,5 +1,7 @@
|
||||
from datetime import datetime
|
||||
from typing import Callable, Tuple, Sequence, Dict
|
||||
from typing import Callable, Tuple, Sequence, Dict, Optional
|
||||
|
||||
from mongoengine import Q
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.apimodels.models import ModelTaskPublishResponse
|
||||
@ -24,6 +26,33 @@ class ModelBLL:
|
||||
raise errors.bad_request.InvalidModelId(**query)
|
||||
return model
|
||||
|
||||
@staticmethod
def assert_exists(
    company_id,
    model_ids,
    only=None,
    allow_public=False,
    return_models=True,
) -> Optional[Sequence[Model]]:
    """
    Verify that all the requested models can be retrieved for the company.

    :param company_id: company on behalf of which the models are queried
    :param model_ids: a single model id or a sequence of model ids
    :param only: optional field names to limit the fetched projection
    :param allow_public: passed through to Model.get_many
    :param return_models: if set then the found models are returned as a list,
        otherwise nothing is returned
    :raise errors.bad_request.InvalidModelId: if the amount of models found
        differs from the amount of distinct ids requested
    """
    if isinstance(model_ids, str):
        model_ids = [model_ids]
    unique_ids = set(model_ids)

    models = Model.get_many(
        company=company_id,
        query=Q(id__in=unique_ids),
        allow_public=allow_public,
        return_dicts=False,
    )
    if only:
        models = models.only(*only)

    # duplicates in the request are counted once
    if models.count() != len(unique_ids):
        raise errors.bad_request.InvalidModelId(ids=model_ids)

    return list(models) if return_models else None
|
||||
|
||||
@classmethod
|
||||
def publish_model(
|
||||
cls,
|
||||
|
@ -92,3 +92,6 @@ class Model(AttributedDocument):
|
||||
metadata = SafeMapField(
|
||||
field=EmbeddedDocumentField(MetadataItem), user_set_allowed=True
|
||||
)
|
||||
|
||||
def get_index_company(self) -> str:
    """
    Return the company id that the events endpoints use when querying
    events for this model.

    The model's own company wins; otherwise company_origin is used,
    and an empty string is returned when neither is set.
    """
    own_company = self.company
    if own_company:
        return own_company

    origin = self.company_origin
    return origin if origin else ""
|
||||
|
@ -35,6 +35,12 @@
|
||||
},
|
||||
"value": {
|
||||
"type": "float"
|
||||
},
|
||||
"company_id": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"model_event": {
|
||||
"type": "boolean"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -968,5 +968,5 @@ class PrePopulate:
|
||||
ev["task"] = task_id
|
||||
ev["company_id"] = company_id
|
||||
cls.event_bll.add_events(
|
||||
company_id, events=events, worker="", allow_locked_tasks=True
|
||||
company_id, events=events, worker="", allow_locked=True
|
||||
)
|
||||
|
@ -405,13 +405,27 @@ add {
|
||||
additionalProperties: true
|
||||
}
|
||||
}
|
||||
"999.0": ${add."2.1"} {
|
||||
request.properties {
|
||||
model_event {
|
||||
type: boolean
|
||||
description: If set then the event is for a model. Otherwise for a task. Cannot be used with task log events. If used in batch then all the events should be marked the same
|
||||
default: false
|
||||
}
|
||||
allow_locked {
|
||||
type: boolean
|
||||
description: Allow adding events to published tasks or models
|
||||
default: false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
add_batch {
|
||||
"2.1" {
|
||||
description: "Adds a batch of events in a single call (json-lines format, stream-friendly)"
|
||||
batch_request: {
|
||||
action: add
|
||||
version: 1.5
|
||||
version: 2.1
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
@ -422,10 +436,16 @@ add_batch {
|
||||
}
|
||||
}
|
||||
}
|
||||
"999.0": ${add_batch."2.1"} {
|
||||
batch_request: {
|
||||
action: add
|
||||
version: 999.0
|
||||
}
|
||||
}
|
||||
}
|
||||
delete_for_task {
|
||||
"2.1" {
|
||||
description: "Delete all task event. *This cannot be undone!*"
|
||||
description: "Delete all task events. *This cannot be undone!*"
|
||||
request {
|
||||
type: object
|
||||
required: [
|
||||
@ -454,6 +474,37 @@ delete_for_task {
|
||||
}
|
||||
}
|
||||
}
|
||||
# Schema of the events.delete_for_model endpoint: removes all the events
# stored for a single model. Only API version 999.0 is defined.
delete_for_model {
    "999.0" {
        description: "Delete all model events. *This cannot be undone!*"
        request {
            type: object
            required: [
                model
            ]
            properties {
                model {
                    type: string
                    description: "Model ID"
                }
                # Optional; when omitted the default (false) applies
                allow_locked {
                    type: boolean
                    description: "Allow deleting events even if the model is locked"
                    default: false
                }
            }
        }
        response {
            type: object
            properties {
                # NOTE(review): declared as boolean but described as a count -
                # confirm which one the endpoint actually returns
                deleted {
                    type: boolean
                    description: "Number of deleted events"
                }
            }
        }
    }
}
|
||||
debug_images {
|
||||
"2.1" {
|
||||
description: "Get all debug images of a task"
|
||||
@ -548,6 +599,13 @@ debug_images {
|
||||
}
|
||||
}
|
||||
}
|
||||
"999.0": ${debug_images."2.14"} {
|
||||
model_events {
|
||||
type: boolean
|
||||
description: If set then the retrieving model events. Otherwise task events
|
||||
default: false
|
||||
}
|
||||
}
|
||||
}
|
||||
plots {
|
||||
"2.20" {
|
||||
@ -583,6 +641,13 @@ plots {
|
||||
}
|
||||
response {"$ref": "#/definitions/plots_response"}
|
||||
}
|
||||
"999.0": ${plots."2.20"} {
|
||||
model_events {
|
||||
type: boolean
|
||||
description: If set then the retrieving model plots. Otherwise task plots
|
||||
default: false
|
||||
}
|
||||
}
|
||||
}
|
||||
get_debug_image_sample {
|
||||
"2.12": {
|
||||
@ -626,6 +691,13 @@ get_debug_image_sample {
|
||||
default: true
|
||||
}
|
||||
}
|
||||
"999.0": ${get_debug_image_sample."2.20"} {
|
||||
model_events {
|
||||
type: boolean
|
||||
description: If set then the retrieving model debug images. Otherwise task debug images
|
||||
default: false
|
||||
}
|
||||
}
|
||||
}
|
||||
next_debug_image_sample {
|
||||
"2.12": {
|
||||
@ -651,6 +723,18 @@ next_debug_image_sample {
|
||||
}
|
||||
response {"$ref": "#/definitions/debug_image_sample_response"}
|
||||
}
|
||||
"999.0": ${next_debug_image_sample."2.12"} {
|
||||
request.properties.next_iteration {
|
||||
type: boolean
|
||||
default: false
|
||||
description: If set then navigate to the next/previous iteration
|
||||
}
|
||||
model_events {
|
||||
type: boolean
|
||||
description: If set then the retrieving model debug images. Otherwise task debug images
|
||||
default: false
|
||||
}
|
||||
}
|
||||
}
|
||||
get_plot_sample {
|
||||
"2.20": {
|
||||
@ -692,6 +776,13 @@ get_plot_sample {
|
||||
}
|
||||
response {"$ref": "#/definitions/plot_sample_response"}
|
||||
}
|
||||
# Extends this endpoint's own 2.20 definition. The previous base,
# get_debug_image_sample."2.20", was copied from the sibling endpoint and
# made the plot sample call inherit the debug image sample schema.
# NOTE(review): model_events sits at the version level here, while other
# entries add properties under request.properties - confirm the intended placement.
"999.0": ${get_plot_sample."2.20"} {
    model_events {
        type: boolean
        description: If set then the retrieving model plots. Otherwise task plots
        default: false
    }
}
|
||||
}
|
||||
next_plot_sample {
|
||||
"2.20": {
|
||||
@ -717,6 +808,18 @@ next_plot_sample {
|
||||
}
|
||||
response {"$ref": "#/definitions/plot_sample_response"}
|
||||
}
|
||||
"999.0": ${next_plot_sample."2.20"} {
|
||||
request.properties.next_iteration {
|
||||
type: boolean
|
||||
default: false
|
||||
description: If set then navigate to the next/previous iteration
|
||||
}
|
||||
model_events {
|
||||
type: boolean
|
||||
description: If set then the retrieving model plots. Otherwise task plots
|
||||
default: false
|
||||
}
|
||||
}
|
||||
}
|
||||
get_task_metrics{
|
||||
"2.7": {
|
||||
@ -749,6 +852,13 @@ get_task_metrics{
|
||||
}
|
||||
}
|
||||
}
|
||||
"999.0": ${get_task_metrics."2.7"} {
|
||||
model_events {
|
||||
type: boolean
|
||||
description: If set then get metrics from model events. Otherwise from task events
|
||||
default: false
|
||||
}
|
||||
}
|
||||
}
|
||||
get_task_log {
|
||||
"1.5" {
|
||||
@ -966,6 +1076,13 @@ get_task_events {
|
||||
}
|
||||
}
|
||||
}
|
||||
"999.0": ${get_task_events."2.1"} {
|
||||
model_events {
|
||||
type: boolean
|
||||
description: If set then get retrieving model events. Otherwise task events
|
||||
default: false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
download_task_log {
|
||||
@ -1067,6 +1184,13 @@ get_task_plots {
|
||||
default: false
|
||||
}
|
||||
}
|
||||
"999.0": ${get_task_plots."2.16"} {
|
||||
model_events {
|
||||
type: boolean
|
||||
description: If set then the retrieving model events. Otherwise task events
|
||||
default: false
|
||||
}
|
||||
}
|
||||
}
|
||||
get_multi_task_plots {
|
||||
"2.1" {
|
||||
@ -1124,6 +1248,13 @@ get_multi_task_plots {
|
||||
default: false
|
||||
}
|
||||
}
|
||||
"999.0": ${get_multi_task_plots."2.16"} {
|
||||
model_events {
|
||||
type: boolean
|
||||
description: If set then the retrieving model events. Otherwise task events
|
||||
default: false
|
||||
}
|
||||
}
|
||||
}
|
||||
get_vector_metrics_and_variants {
|
||||
"2.1" {
|
||||
@ -1152,6 +1283,13 @@ get_vector_metrics_and_variants {
|
||||
}
|
||||
}
|
||||
}
|
||||
"999.0": ${get_vector_metrics_and_variants."2.1"} {
|
||||
model_events {
|
||||
type: boolean
|
||||
description: If set then the retrieving model events. Otherwise task events
|
||||
default: false
|
||||
}
|
||||
}
|
||||
}
|
||||
vector_metrics_iter_histogram {
|
||||
"2.1" {
|
||||
@ -1190,6 +1328,13 @@ vector_metrics_iter_histogram {
|
||||
}
|
||||
}
|
||||
}
|
||||
"999.0": ${vector_metrics_iter_histogram."2.1"} {
|
||||
model_events {
|
||||
type: boolean
|
||||
description: If set then the retrieving model events. Otherwise task events
|
||||
default: false
|
||||
}
|
||||
}
|
||||
}
|
||||
scalar_metrics_iter_histogram {
|
||||
"2.1" {
|
||||
@ -1243,6 +1388,13 @@ scalar_metrics_iter_histogram {
|
||||
}
|
||||
}
|
||||
}
|
||||
"999.0": ${scalar_metrics_iter_histogram."2.14"} {
|
||||
model_events {
|
||||
type: boolean
|
||||
description: If set then the retrieving model events. Otherwise task events
|
||||
default: false
|
||||
}
|
||||
}
|
||||
}
|
||||
multi_task_scalar_metrics_iter_histogram {
|
||||
"2.1" {
|
||||
@ -1282,6 +1434,13 @@ multi_task_scalar_metrics_iter_histogram {
|
||||
additionalProperties: true
|
||||
}
|
||||
}
|
||||
"999.0": ${multi_task_scalar_metrics_iter_histogram."2.1"} {
|
||||
model_events {
|
||||
type: boolean
|
||||
description: If set then the retrieving model events. Otherwise task events
|
||||
default: false
|
||||
}
|
||||
}
|
||||
}
|
||||
get_task_single_value_metrics {
|
||||
"2.20" {
|
||||
@ -1331,6 +1490,13 @@ get_task_single_value_metrics {
|
||||
}
|
||||
}
|
||||
}
|
||||
"999.0": ${get_task_single_value_metrics."2.20"} {
|
||||
model_events {
|
||||
type: boolean
|
||||
description: If set then the retrieving model events. Otherwise task events
|
||||
default: false
|
||||
}
|
||||
}
|
||||
}
|
||||
get_task_latest_scalar_values {
|
||||
"2.1" {
|
||||
@ -1410,6 +1576,13 @@ get_scalar_metrics_and_variants {
|
||||
}
|
||||
}
|
||||
}
|
||||
"999.0": ${get_scalar_metrics_and_variants."2.1"} {
|
||||
model_events {
|
||||
type: boolean
|
||||
description: If set then the retrieving model events. Otherwise task events
|
||||
default: false
|
||||
}
|
||||
}
|
||||
}
|
||||
get_scalar_metric_data {
|
||||
"2.1" {
|
||||
@ -1459,6 +1632,13 @@ get_scalar_metric_data {
|
||||
default: false
|
||||
}
|
||||
}
|
||||
"999.0": ${get_scalar_metric_data."2.16"} {
|
||||
model_events {
|
||||
type: boolean
|
||||
description: If set then the retrieving model events. Otherwise task events
|
||||
default: false
|
||||
}
|
||||
}
|
||||
}
|
||||
scalar_metrics_iter_raw {
|
||||
"2.16" {
|
||||
@ -1523,6 +1703,13 @@ scalar_metrics_iter_raw {
|
||||
}
|
||||
}
|
||||
}
|
||||
"999.0": ${scalar_metrics_iter_raw."2.16"} {
|
||||
model_events {
|
||||
type: boolean
|
||||
description: If set then the retrieving model events. Otherwise task events
|
||||
default: false
|
||||
}
|
||||
}
|
||||
}
|
||||
clear_scroll {
|
||||
"2.18" {
|
||||
|
@ -2,7 +2,7 @@ import itertools
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from operator import itemgetter
|
||||
from typing import Sequence, Optional
|
||||
from typing import Sequence, Optional, Union, Tuple
|
||||
|
||||
import attr
|
||||
import jsonmodels.fields
|
||||
@ -33,13 +33,36 @@ from apiserver.bll.event import EventBLL
|
||||
from apiserver.bll.event.event_common import EventType, MetricVariants
|
||||
from apiserver.bll.event.events_iterator import Scroll
|
||||
from apiserver.bll.event.scalar_key import ScalarKeyEnum, ScalarKey
|
||||
from apiserver.bll.model import ModelBLL
|
||||
from apiserver.bll.task import TaskBLL
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.task.task import Task
|
||||
from apiserver.service_repo import APICall, endpoint
|
||||
from apiserver.utilities import json, extract_properties_to_lists
|
||||
|
||||
task_bll = TaskBLL()
|
||||
event_bll = EventBLL()
|
||||
model_bll = ModelBLL()
|
||||
|
||||
|
||||
def _assert_task_or_model_exists(
    company_id: str, task_ids: Union[str, Sequence[str]], model_events: bool
) -> Union[Sequence[Model], Sequence[Task]]:
    """
    Verify that the passed ids exist and return the matching objects.

    The ids are looked up as models when model_events is set and as tasks
    otherwise. Public entities are allowed, and only the id, name, company and
    company_origin fields are requested for the returned objects.
    """
    # both BLLs expose assert_exists and are called with identical arguments
    bll = model_bll if model_events else task_bll
    return bll.assert_exists(
        company_id,
        task_ids,
        allow_public=True,
        only=("id", "name", "company", "company_origin"),
    )
|
||||
|
||||
|
||||
@endpoint("events.add")
|
||||
@ -47,7 +70,7 @@ def add(call: APICall, company_id, _):
|
||||
data = call.data.copy()
|
||||
allow_locked = data.pop("allow_locked", False)
|
||||
added, err_count, err_info = event_bll.add_events(
|
||||
company_id, [data], call.worker, allow_locked_tasks=allow_locked
|
||||
company_id, [data], call.worker, allow_locked=allow_locked
|
||||
)
|
||||
call.result.data = dict(added=added, errors=err_count, errors_info=err_info)
|
||||
|
||||
@ -58,7 +81,12 @@ def add_batch(call: APICall, company_id, _):
|
||||
if events is None or len(events) == 0:
|
||||
raise errors.bad_request.BatchContainsNoItems()
|
||||
|
||||
added, err_count, err_info = event_bll.add_events(company_id, events, call.worker)
|
||||
added, err_count, err_info = event_bll.add_events(
|
||||
company_id,
|
||||
events,
|
||||
call.worker,
|
||||
allow_locked=events[0].get("allow_locked", False),
|
||||
)
|
||||
call.result.data = dict(added=added, errors=err_count, errors_info=err_info)
|
||||
|
||||
|
||||
@ -225,12 +253,13 @@ def download_task_log(call, company_id, _):
|
||||
@endpoint("events.get_vector_metrics_and_variants", required_fields=["task"])
|
||||
def get_vector_metrics_and_variants(call, company_id, _):
|
||||
task_id = call.data["task"]
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_id, allow_public=True, only=("company", "company_origin")
|
||||
model_events = call.data["model_events"]
|
||||
task_or_model = _assert_task_or_model_exists(
|
||||
company_id, task_id, model_events=model_events,
|
||||
)[0]
|
||||
call.result.data = dict(
|
||||
metrics=event_bll.get_metrics_and_variants(
|
||||
task.get_index_company(), task_id, EventType.metrics_vector
|
||||
task_or_model.get_index_company(), task_id, EventType.metrics_vector
|
||||
)
|
||||
)
|
||||
|
||||
@ -238,12 +267,13 @@ def get_vector_metrics_and_variants(call, company_id, _):
|
||||
@endpoint("events.get_scalar_metrics_and_variants", required_fields=["task"])
def get_scalar_metrics_and_variants(call, company_id, _):
    """
    Endpoint handler for events.get_scalar_metrics_and_variants

    Reads the required "task" id and the optional "model_events" flag from the
    call data, passes both to _assert_task_or_model_exists and uses the first
    returned entity to obtain the index company for the query.
    The outcome of event_bll.get_metrics_and_variants for
    EventType.metrics_scalar is stored under "metrics" in the call result.
    """
    task_id = call.data["task"]
    # Only "task" is a required field, so "model_events" may be absent:
    # default it to False (as the other endpoints do) instead of raising KeyError
    model_events = call.data.get("model_events", False)
    task_or_model = _assert_task_or_model_exists(
        company_id, task_id, model_events=model_events,
    )[0]
    call.result.data = dict(
        metrics=event_bll.get_metrics_and_variants(
            task_or_model.get_index_company(), task_id, EventType.metrics_scalar
        )
    )
|
||||
|
||||
@ -255,13 +285,14 @@ def get_scalar_metrics_and_variants(call, company_id, _):
|
||||
)
|
||||
def vector_metrics_iter_histogram(call, company_id, _):
|
||||
task_id = call.data["task"]
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_id, allow_public=True, only=("company", "company_origin")
|
||||
model_events = call.data["model_events"]
|
||||
task_or_model = _assert_task_or_model_exists(
|
||||
company_id, task_id, model_events=model_events,
|
||||
)[0]
|
||||
metric = call.data["metric"]
|
||||
variant = call.data["variant"]
|
||||
iterations, vectors = event_bll.get_vector_metrics_per_iter(
|
||||
task.get_index_company(), task_id, metric, variant
|
||||
task_or_model.get_index_company(), task_id, metric, variant
|
||||
)
|
||||
call.result.data = dict(
|
||||
metric=metric, variant=variant, vectors=vectors, iterations=iterations
|
||||
@ -286,11 +317,10 @@ def make_response(
|
||||
|
||||
|
||||
@endpoint("events.get_task_events", request_data_model=TaskEventsRequest)
|
||||
def get_task_events(call, company_id, request: TaskEventsRequest):
|
||||
def get_task_events(_, company_id, request: TaskEventsRequest):
|
||||
task_id = request.task
|
||||
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_id, allow_public=True, only=("company",),
|
||||
task_or_model = _assert_task_or_model_exists(
|
||||
company_id, task_id, model_events=request.model_events,
|
||||
)[0]
|
||||
|
||||
key = ScalarKeyEnum.iter
|
||||
@ -322,7 +352,7 @@ def get_task_events(call, company_id, request: TaskEventsRequest):
|
||||
if request.count_total and total is None:
|
||||
total = event_bll.events_iterator.count_task_events(
|
||||
event_type=request.event_type,
|
||||
company_id=task.company,
|
||||
company_id=task_or_model.get_index_company(),
|
||||
task_id=task_id,
|
||||
metric_variants=metric_variants,
|
||||
)
|
||||
@ -336,7 +366,7 @@ def get_task_events(call, company_id, request: TaskEventsRequest):
|
||||
|
||||
res = event_bll.events_iterator.get_task_events(
|
||||
event_type=request.event_type,
|
||||
company_id=task.company,
|
||||
company_id=task_or_model.get_index_company(),
|
||||
task_id=task_id,
|
||||
batch_size=batch_size,
|
||||
key=ScalarKeyEnum.iter,
|
||||
@ -365,18 +395,20 @@ def get_scalar_metric_data(call, company_id, _):
|
||||
metric = call.data["metric"]
|
||||
scroll_id = call.data.get("scroll_id")
|
||||
no_scroll = call.data.get("no_scroll", False)
|
||||
model_events = call.data.get("model_events", False)
|
||||
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_id, allow_public=True, only=("company", "company_origin")
|
||||
task_or_model = _assert_task_or_model_exists(
|
||||
company_id, task_id, model_events=model_events,
|
||||
)[0]
|
||||
result = event_bll.get_task_events(
|
||||
task.get_index_company(),
|
||||
task_or_model.get_index_company(),
|
||||
task_id,
|
||||
event_type=EventType.metrics_scalar,
|
||||
sort=[{"iter": {"order": "desc"}}],
|
||||
metric=metric,
|
||||
scroll_id=scroll_id,
|
||||
no_scroll=no_scroll,
|
||||
model_events=model_events,
|
||||
)
|
||||
|
||||
call.result.data = dict(
|
||||
@ -398,7 +430,7 @@ def get_task_latest_scalar_values(call, company_id, _):
|
||||
index_company, task_id
|
||||
)
|
||||
last_iters = event_bll.get_last_iters(
|
||||
company_id=company_id, event_type=EventType.all, task_id=task_id, iters=1
|
||||
company_id=index_company, event_type=EventType.all, task_id=task_id, iters=1
|
||||
).get(task_id)
|
||||
call.result.data = dict(
|
||||
metrics=metrics,
|
||||
@ -417,11 +449,11 @@ def get_task_latest_scalar_values(call, company_id, _):
|
||||
def scalar_metrics_iter_histogram(
|
||||
call, company_id, request: ScalarMetricsIterHistogramRequest
|
||||
):
|
||||
task = task_bll.assert_exists(
|
||||
company_id, request.task, allow_public=True, only=("company", "company_origin")
|
||||
task_or_model = _assert_task_or_model_exists(
|
||||
company_id, request.task, model_events=request.model_events
|
||||
)[0]
|
||||
metrics = event_bll.metrics.get_scalar_metrics_average_per_iter(
|
||||
task.get_index_company(),
|
||||
company_id=task_or_model.get_index_company(),
|
||||
task_id=request.task,
|
||||
samples=request.samples,
|
||||
key=request.key,
|
||||
@ -429,24 +461,55 @@ def scalar_metrics_iter_histogram(
|
||||
call.result.data = metrics
|
||||
|
||||
|
||||
def _get_task_or_model_index_company(
    company_id: str, task_ids: Sequence[str], model_events=False,
) -> Tuple[str, Sequence[Task]]:
    """
    Resolve the single index company shared by the requested tasks
    (the ids and the model_events flag are handed to _assert_task_or_model_exists)

    Raises InvalidModelId (when model_events is set) or InvalidTaskId listing the
    requested ids that are missing from the returned entities, if fewer entities
    than distinct ids came back.
    Raises InvalidTaskId if the entities map to more than one index company.
    Return the index company and the entities that were found
    """
    found = _assert_task_or_model_exists(
        company_id, task_ids, model_events=model_events,
    )

    requested_ids = set(task_ids)
    if len(found) < len(requested_ids):
        missing = tuple(requested_ids - {entity.id for entity in found})
        if model_events:
            raise errors.bad_request.InvalidModelId(company=company_id, ids=missing)
        raise errors.bad_request.InvalidTaskId(company=company_id, ids=missing)

    index_companies = {entity.get_index_company() for entity in found}
    if len(index_companies) > 1:
        raise errors.bad_request.InvalidTaskId(
            "only tasks from the same company are supported"
        )

    # exactly one company is left in the set at this point
    return index_companies.pop(), found
|
||||
|
||||
|
||||
@endpoint(
    "events.multi_task_scalar_metrics_iter_histogram",
    request_data_model=MultiTaskScalarMetricsIterHistogramRequest,
)
def multi_task_scalar_metrics_iter_histogram(
    call, company_id, request: MultiTaskScalarMetricsIterHistogramRequest
):
    """
    Endpoint handler for events.multi_task_scalar_metrics_iter_histogram

    request.tasks may be a list of ids or a single comma separated string.
    The ids are resolved to one index company through
    _get_task_or_model_index_company and the outcome of
    compare_scalar_metrics_average_per_iter is stored under "metrics"
    in the call result.
    """
    ids = request.tasks
    if isinstance(ids, str):
        # a comma separated string is split into stripped ids
        ids = [part.strip() for part in ids.split(",")]

    index_company, entities = _get_task_or_model_index_company(
        company_id, ids, request.model_events
    )
    histogram = event_bll.metrics.compare_scalar_metrics_average_per_iter(
        company_id=index_company,
        tasks=entities,
        samples=request.samples,
        key=request.key,
    )
    call.result.data = dict(metrics=histogram)
|
||||
|
||||
@ -455,21 +518,11 @@ def multi_task_scalar_metrics_iter_histogram(
|
||||
def get_task_single_value_metrics(
    call, company_id: str, request: SingleValueMetricsRequest
):
    """
    Store in the call result, under "tasks", one {"task", "values"} entry per
    item returned by event_bll.metrics.get_task_single_value_metrics for the
    requested ids.

    request.tasks and request.model_events are resolved to a single index
    company by _get_task_or_model_index_company before the query.
    """
    index_company, entities = _get_task_or_model_index_company(
        company_id, request.tasks, request.model_events
    )

    values_by_task = event_bll.metrics.get_task_single_value_metrics(
        index_company, entities
    )
    call.result.data = dict(
        tasks=[
            {"task": task, "values": values}
            for task, values in values_by_task.items()
        ]
    )
|
||||
@ -481,22 +534,11 @@ def get_multi_task_plots_v1_7(call, company_id, _):
|
||||
iters = call.data.get("iters", 1)
|
||||
scroll_id = call.data.get("scroll_id")
|
||||
|
||||
tasks = task_bll.assert_exists(
|
||||
company_id=company_id,
|
||||
only=("id", "name", "company", "company_origin"),
|
||||
task_ids=task_ids,
|
||||
allow_public=True,
|
||||
)
|
||||
|
||||
companies = {t.get_index_company() for t in tasks}
|
||||
if len(companies) > 1:
|
||||
raise errors.bad_request.InvalidTaskId(
|
||||
"only tasks from the same company are supported"
|
||||
)
|
||||
company, tasks_or_models = _get_task_or_model_index_company(company_id, task_ids)
|
||||
|
||||
# Get last 10K events by iteration and group them by unique metric+variant, returning top events for combination
|
||||
result = event_bll.get_task_events(
|
||||
next(iter(companies)),
|
||||
company,
|
||||
task_ids,
|
||||
event_type=EventType.metrics_plot,
|
||||
sort=[{"iter": {"order": "desc"}}],
|
||||
@ -504,7 +546,7 @@ def get_multi_task_plots_v1_7(call, company_id, _):
|
||||
scroll_id=scroll_id,
|
||||
)
|
||||
|
||||
tasks = {t.id: t.name for t in tasks}
|
||||
tasks = {t.id: t.name for t in tasks_or_models}
|
||||
|
||||
return_events = _get_top_iter_unique_events_per_task(
|
||||
result.events, max_iters=iters, tasks=tasks
|
||||
@ -519,36 +561,29 @@ def get_multi_task_plots_v1_7(call, company_id, _):
|
||||
|
||||
|
||||
@endpoint("events.get_multi_task_plots", min_version="1.8", required_fields=["tasks"])
|
||||
def get_multi_task_plots(call, company_id, req_model):
|
||||
def get_multi_task_plots(call, company_id, _):
|
||||
task_ids = call.data["tasks"]
|
||||
iters = call.data.get("iters", 1)
|
||||
scroll_id = call.data.get("scroll_id")
|
||||
no_scroll = call.data.get("no_scroll", False)
|
||||
model_events = call.data.get("model_events", False)
|
||||
|
||||
tasks = task_bll.assert_exists(
|
||||
company_id=call.identity.company,
|
||||
only=("id", "name", "company", "company_origin"),
|
||||
task_ids=task_ids,
|
||||
allow_public=True,
|
||||
company, tasks_or_models = _get_task_or_model_index_company(
|
||||
company_id, task_ids, model_events
|
||||
)
|
||||
|
||||
companies = {t.get_index_company() for t in tasks}
|
||||
if len(companies) > 1:
|
||||
raise errors.bad_request.InvalidTaskId(
|
||||
"only tasks from the same company are supported"
|
||||
)
|
||||
|
||||
result = event_bll.get_task_events(
|
||||
next(iter(companies)),
|
||||
company,
|
||||
task_ids,
|
||||
event_type=EventType.metrics_plot,
|
||||
sort=[{"iter": {"order": "desc"}}],
|
||||
last_iter_count=iters,
|
||||
scroll_id=scroll_id,
|
||||
no_scroll=no_scroll,
|
||||
model_events=model_events,
|
||||
)
|
||||
|
||||
tasks = {t.id: t.name for t in tasks}
|
||||
tasks = {t.id: t.name for t in tasks_or_models}
|
||||
|
||||
return_events = _get_top_iter_unique_events_per_task(
|
||||
result.events, max_iters=iters, tasks=tasks
|
||||
@ -615,17 +650,18 @@ def get_task_plots(call, company_id, request: TaskPlotsRequest):
|
||||
iters = request.iters
|
||||
scroll_id = request.scroll_id
|
||||
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_id, allow_public=True, only=("company", "company_origin")
|
||||
task_or_model = _assert_task_or_model_exists(
|
||||
company_id, task_id, model_events=request.model_events
|
||||
)[0]
|
||||
result = event_bll.get_task_plots(
|
||||
task.get_index_company(),
|
||||
task_or_model.get_index_company(),
|
||||
tasks=[task_id],
|
||||
sort=[{"iter": {"order": "desc"}}],
|
||||
last_iterations_per_plot=iters,
|
||||
scroll_id=scroll_id,
|
||||
no_scroll=request.no_scroll,
|
||||
metric_variants=_get_metric_variants_from_request(request.metrics),
|
||||
model_events=request.model_events,
|
||||
)
|
||||
|
||||
return_events = result.events
|
||||
@ -651,21 +687,11 @@ def task_plots(call, company_id, request: MetricEventsRequest):
|
||||
if None in metrics:
|
||||
metrics.clear()
|
||||
|
||||
tasks = task_bll.assert_exists(
|
||||
company_id,
|
||||
task_ids=list(task_metrics),
|
||||
allow_public=True,
|
||||
only=("company", "company_origin"),
|
||||
company, _ = _get_task_or_model_index_company(
|
||||
company_id, task_ids=list(task_metrics), model_events=request.model_events
|
||||
)
|
||||
|
||||
companies = {t.get_index_company() for t in tasks}
|
||||
if len(companies) > 1:
|
||||
raise errors.bad_request.InvalidTaskId(
|
||||
"only tasks from the same company are supported"
|
||||
)
|
||||
|
||||
result = event_bll.plots_iterator.get_task_events(
|
||||
company_id=next(iter(companies)),
|
||||
company_id=company,
|
||||
task_metrics=task_metrics,
|
||||
iter_count=request.iters,
|
||||
navigate_earlier=request.navigate_earlier,
|
||||
@ -730,17 +756,19 @@ def get_debug_images_v1_8(call, company_id, _):
|
||||
task_id = call.data["task"]
|
||||
iters = call.data.get("iters") or 1
|
||||
scroll_id = call.data.get("scroll_id")
|
||||
model_events = call.data.get("model_events", False)
|
||||
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_id, allow_public=True, only=("company", "company_origin")
|
||||
tasks_or_model = _assert_task_or_model_exists(
|
||||
company_id, task_id, model_events=model_events,
|
||||
)[0]
|
||||
result = event_bll.get_task_events(
|
||||
task.get_index_company(),
|
||||
tasks_or_model.get_index_company(),
|
||||
task_id,
|
||||
event_type=EventType.metrics_image,
|
||||
sort=[{"iter": {"order": "desc"}}],
|
||||
last_iter_count=iters,
|
||||
scroll_id=scroll_id,
|
||||
model_events=model_events,
|
||||
)
|
||||
|
||||
return_events = result.events
|
||||
@ -768,21 +796,12 @@ def get_debug_images(call, company_id, request: MetricEventsRequest):
|
||||
if None in metrics:
|
||||
metrics.clear()
|
||||
|
||||
tasks = task_bll.assert_exists(
|
||||
company_id,
|
||||
task_ids=list(task_metrics),
|
||||
allow_public=True,
|
||||
only=("company", "company_origin"),
|
||||
company, _ = _get_task_or_model_index_company(
|
||||
company_id, task_ids=list(task_metrics), model_events=request.model_events
|
||||
)
|
||||
|
||||
companies = {t.get_index_company() for t in tasks}
|
||||
if len(companies) > 1:
|
||||
raise errors.bad_request.InvalidTaskId(
|
||||
"only tasks from the same company are supported"
|
||||
)
|
||||
|
||||
result = event_bll.debug_images_iterator.get_task_events(
|
||||
company_id=next(iter(companies)),
|
||||
company_id=company,
|
||||
task_metrics=task_metrics,
|
||||
iter_count=request.iters,
|
||||
navigate_earlier=request.navigate_earlier,
|
||||
@ -811,11 +830,11 @@ def get_debug_images(call, company_id, request: MetricEventsRequest):
|
||||
request_data_model=GetHistorySampleRequest,
|
||||
)
|
||||
def get_debug_image_sample(call, company_id, request: GetHistorySampleRequest):
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_ids=[request.task], allow_public=True, only=("company",)
|
||||
task_or_model = _assert_task_or_model_exists(
|
||||
company_id, request.task, model_events=request.model_events,
|
||||
)[0]
|
||||
res = event_bll.debug_image_sample_history.get_sample_for_variant(
|
||||
company_id=task.company,
|
||||
company_id=task_or_model.get_index_company(),
|
||||
task=request.task,
|
||||
metric=request.metric,
|
||||
variant=request.variant,
|
||||
@ -833,11 +852,11 @@ def get_debug_image_sample(call, company_id, request: GetHistorySampleRequest):
|
||||
request_data_model=NextHistorySampleRequest,
|
||||
)
|
||||
def next_debug_image_sample(call, company_id, request: NextHistorySampleRequest):
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_ids=[request.task], allow_public=True, only=("company",)
|
||||
task_or_model = _assert_task_or_model_exists(
|
||||
company_id, request.task, model_events=request.model_events,
|
||||
)[0]
|
||||
res = event_bll.debug_image_sample_history.get_next_sample(
|
||||
company_id=task.company,
|
||||
company_id=task_or_model.get_index_company(),
|
||||
task=request.task,
|
||||
state_id=request.scroll_id,
|
||||
navigate_earlier=request.navigate_earlier,
|
||||
@ -849,11 +868,11 @@ def next_debug_image_sample(call, company_id, request: NextHistorySampleRequest)
|
||||
"events.get_plot_sample", request_data_model=GetHistorySampleRequest,
|
||||
)
|
||||
def get_plot_sample(call, company_id, request: GetHistorySampleRequest):
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_ids=[request.task], allow_public=True, only=("company",)
|
||||
task_or_model = _assert_task_or_model_exists(
|
||||
company_id, request.task, model_events=request.model_events,
|
||||
)[0]
|
||||
res = event_bll.plot_sample_history.get_sample_for_variant(
|
||||
company_id=task.company,
|
||||
company_id=task_or_model.get_index_company(),
|
||||
task=request.task,
|
||||
metric=request.metric,
|
||||
variant=request.variant,
|
||||
@ -869,11 +888,11 @@ def get_plot_sample(call, company_id, request: GetHistorySampleRequest):
|
||||
"events.next_plot_sample", request_data_model=NextHistorySampleRequest,
|
||||
)
|
||||
def next_plot_sample(call, company_id, request: NextHistorySampleRequest):
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_ids=[request.task], allow_public=True, only=("company",)
|
||||
task_or_model = _assert_task_or_model_exists(
|
||||
company_id, request.task, model_events=request.model_events,
|
||||
)[0]
|
||||
res = event_bll.plot_sample_history.get_next_sample(
|
||||
company_id=task.company,
|
||||
company_id=task_or_model.get_index_company(),
|
||||
task=request.task,
|
||||
state_id=request.scroll_id,
|
||||
navigate_earlier=request.navigate_earlier,
|
||||
@ -883,14 +902,11 @@ def next_plot_sample(call, company_id, request: NextHistorySampleRequest):
|
||||
|
||||
@endpoint("events.get_task_metrics", request_data_model=TaskMetricsRequest)
|
||||
def get_task_metrics(call: APICall, company_id, request: TaskMetricsRequest):
|
||||
task = task_bll.assert_exists(
|
||||
company_id,
|
||||
task_ids=request.tasks,
|
||||
allow_public=True,
|
||||
only=("company", "company_origin"),
|
||||
)[0]
|
||||
company, _ = _get_task_or_model_index_company(
|
||||
company_id, request.tasks, model_events=request.model_events
|
||||
)
|
||||
res = event_bll.metrics.get_task_metrics(
|
||||
task.get_index_company(), task_ids=request.tasks, event_type=request.event_type
|
||||
company, task_ids=request.tasks, event_type=request.event_type
|
||||
)
|
||||
call.result.data = {
|
||||
"metrics": [{"task": task, "metrics": metrics} for (task, metrics) in res]
|
||||
@ -898,7 +914,7 @@ def get_task_metrics(call: APICall, company_id, request: TaskMetricsRequest):
|
||||
|
||||
|
||||
@endpoint("events.delete_for_task", required_fields=["task"])
|
||||
def delete_for_task(call, company_id, req_model):
|
||||
def delete_for_task(call, company_id, _):
|
||||
task_id = call.data["task"]
|
||||
allow_locked = call.data.get("allow_locked", False)
|
||||
|
||||
@ -910,6 +926,19 @@ def delete_for_task(call, company_id, req_model):
|
||||
)
|
||||
|
||||
|
||||
@endpoint("events.delete_for_model", required_fields=["model"])
def delete_for_model(call: APICall, company_id: str, _):
    """
    Endpoint handler for events.delete_for_model

    Reads the required "model" id and the optional "allow_locked" flag
    (False when omitted) from the call data, calls model_bll.assert_exists for
    the model and stores the outcome of event_bll.delete_task_events
    (invoked with model=True) under "deleted" in the call result.
    """
    model_id = call.data["model"]
    allow_locked = call.data.get("allow_locked", False)

    model_bll.assert_exists(company_id, model_id, return_models=False)
    deleted = event_bll.delete_task_events(
        company_id, model_id, allow_locked=allow_locked, model=True
    )
    call.result.data = dict(deleted=deleted)
|
||||
|
||||
|
||||
@endpoint("events.clear_task_log")
|
||||
def clear_task_log(call: APICall, company_id: str, request: ClearTaskLogRequest):
|
||||
task_id = request.task
|
||||
@ -1004,17 +1033,13 @@ def scalar_metrics_iter_raw(
|
||||
request.batch_size = request.batch_size or scroll.request.batch_size
|
||||
|
||||
task_id = request.task
|
||||
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_id, allow_public=True, only=("company",),
|
||||
)[0]
|
||||
|
||||
task_or_model = _assert_task_or_model_exists(company_id, task_id, model_events=request.model_events)[0]
|
||||
metric_variants = _get_metric_variants_from_request([request.metric])
|
||||
|
||||
if request.count_total and total is None:
|
||||
total = event_bll.events_iterator.count_task_events(
|
||||
event_type=EventType.metrics_scalar,
|
||||
company_id=task.company,
|
||||
company_id=task_or_model.get_index_company(),
|
||||
task_id=task_id,
|
||||
metric_variants=metric_variants,
|
||||
)
|
||||
@ -1030,7 +1055,7 @@ def scalar_metrics_iter_raw(
|
||||
for iteration in range(0, math.ceil(batch_size / 10_000)):
|
||||
res = event_bll.events_iterator.get_task_events(
|
||||
event_type=EventType.metrics_scalar,
|
||||
company_id=task.company,
|
||||
company_id=task_or_model.get_index_company(),
|
||||
task_id=task_id,
|
||||
batch_size=min(batch_size, 10_000),
|
||||
navigate_earlier=False,
|
||||
|
@ -6,17 +6,24 @@ from typing import Sequence, Optional, Tuple
|
||||
|
||||
from boltons.iterutils import first
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.es_factory import es_factory
|
||||
from apiserver.apierrors.errors.bad_request import EventsNotAdded
|
||||
from apiserver.tests.automated import TestService
|
||||
|
||||
|
||||
class TestTaskEvents(TestService):
|
||||
delete_params = dict(can_fail=True, force=True)
|
||||
|
||||
def _temp_task(self, name="test task events"):
    """
    Create a temporary training task through self.create_temp and return
    whatever create_temp returns.

    :param name: name given to the created task
    """
    task_input = dict(
        name=name, type="training", input=dict(mapping={}, view=dict(entries=[])),
    )
    # The keyword has to be spelled "delete_params" (as in _temp_model);
    # the former "delete_paramse" spelling never delivered the class level
    # delete parameters to create_temp under the expected name
    return self.create_temp("tasks", delete_params=self.delete_params, **task_input)
|
||||
|
||||
def _temp_model(self, name="test model events", **kwargs):
    """
    Create a temporary model through self.create_temp and return
    whatever create_temp returns.

    name, uri and labels are handed to self.update_missing together with
    kwargs before the creation call; the class level delete_params are
    passed on as delete_params.
    """
    defaults = dict(name=name, uri="file:///a/b", labels={})
    self.update_missing(kwargs, **defaults)
    model = self.create_temp("models", delete_params=self.delete_params, **kwargs)
    return model
|
||||
|
||||
@staticmethod
|
||||
def _create_task_event(type_, task, iteration, **kwargs):
|
||||
@ -172,6 +179,42 @@ class TestTaskEvents(TestService):
|
||||
self.assertEqual(iter_count - 1, metric_data.max_value)
|
||||
self.assertEqual(0, metric_data.min_value)
|
||||
|
||||
def test_model_events(self):
    """
    Check event reporting for a model:
    a "log" event flagged as model event is expected to be rejected with
    EventsNotAdded, while scalar events sent for iterations 0 and 1 are
    expected to come back from scalar_metrics_iter_histogram as a single
    point at x=0 holding the value of the last iteration.
    """
    model = self._temp_model(ready=False)

    # task log events are not allowed
    log_event = self._create_task_event(
        "log",
        task=model,
        iteration=0,
        # plain literal: the message has no placeholders so no f-string is needed
        msg="This is a log message",
        model_event=True,
    )
    with self.api.raises(errors.bad_request.EventsNotAdded):
        self.send(log_event)

    # send metric events and check that model data always have iteration 0
    # and only last data is saved
    events = [
        {
            **self._create_task_event("training_stats_scalar", model, iteration),
            "metric": f"Metric{metric_idx}",
            "variant": f"Variant{variant_idx}",
            "value": iteration,
            "model_event": True,
        }
        for iteration in range(2)
        for metric_idx in range(5)
        for variant_idx in range(5)
    ]
    self.send_batch(events)
    data = self.api.events.scalar_metrics_iter_histogram(task=model, model_events=True)
    self.assertEqual(list(data), [f"Metric{idx}" for idx in range(5)])
    metric_data = data.Metric0
    self.assertEqual(list(metric_data), [f"Variant{idx}" for idx in range(5)])
    variant_data = metric_data.Variant0
    self.assertEqual(variant_data.x, [0])
    self.assertEqual(variant_data.y, [1.0])
|
||||
|
||||
def test_error_events(self):
|
||||
task = self._temp_task()
|
||||
events = [
|
||||
@ -555,7 +598,8 @@ class TestTaskEvents(TestService):
|
||||
return data
|
||||
|
||||
def send(self, event):
    """
    Send a single event to the "events.add" endpoint and return the second
    item of the pair returned by self.api.send
    """
    response = self.api.send("events.add", event)
    _, data = response
    return data
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -50,10 +50,12 @@ class TestTasksResetDelete(TestService):
|
||||
self.assertEqual(res.urls.artifact_urls, [])
|
||||
|
||||
task = self.new_task()
|
||||
(_, published_model_urls), (model, draft_model_urls) = self.create_task_models(task)
|
||||
published_model_urls, draft_model_urls = self.create_task_models(task)
|
||||
artifact_urls = self.send_artifacts(task)
|
||||
event_urls = self.send_debug_image_events(task)
|
||||
event_urls.update(self.send_plot_events(task))
|
||||
event_urls.update(self.send_model_events(model))
|
||||
res = self.assert_delete_task(task, force=True, return_file_urls=True)
|
||||
self.assertEqual(set(res.urls.model_urls), draft_model_urls)
|
||||
self.assertEqual(set(res.urls.event_urls), event_urls)
|
||||
@ -120,10 +122,12 @@ class TestTasksResetDelete(TestService):
|
||||
self, **kwargs
|
||||
) -> Tuple[str, Tuple[Set[str], Set[str]], Set[str], Set[str]]:
|
||||
task = self.new_task(**kwargs)
|
||||
(_, published_model_urls), (model, draft_model_urls) = self.create_task_models(task, **kwargs)
|
||||
published_model_urls, draft_model_urls = self.create_task_models(task, **kwargs)
|
||||
artifact_urls = self.send_artifacts(task)
|
||||
event_urls = self.send_debug_image_events(task)
|
||||
event_urls.update(self.send_plot_events(task))
|
||||
event_urls.update(self.send_model_events(model))
|
||||
return task, (published_model_urls, draft_model_urls), artifact_urls, event_urls
|
||||
|
||||
def assert_delete_task(self, task_id, force=False, return_file_urls=False):
|
||||
@ -137,15 +141,17 @@ class TestTasksResetDelete(TestService):
|
||||
self.assertEqual(tasks, [])
|
||||
return res
|
||||
|
||||
def create_task_models(self, task, **kwargs) -> Tuple:
    """
    Create two models (uri "ready" and uri "not_ready", the latter with
    ready=False) and edit both to reference the task.
    Return a pair of (model, {uri}) tuples: first for the "ready" model,
    then for the "not_ready" one
    """
    ready_uri, not_ready_uri = "ready", "not_ready"

    model_ready = self.new_model(uri=ready_uri, **kwargs)
    model_not_ready = self.new_model(uri=not_ready_uri, ready=False, **kwargs)

    # the not ready model is attached to the task first, then the ready one
    for model in (model_not_ready, model_ready):
        self.api.models.edit(model=model, task=task)

    ready_entry = (model_ready, {ready_uri})
    not_ready_entry = (model_not_ready, {not_ready_uri})
    return ready_entry, not_ready_entry
|
||||
|
||||
def send_artifacts(self, task) -> Set[str]:
|
||||
"""
|
||||
@ -159,6 +165,20 @@ class TestTasksResetDelete(TestService):
|
||||
self.api.tasks.add_or_update_artifacts(task=task, artifacts=artifacts)
|
||||
return {"test2"}
|
||||
|
||||
def send_model_events(self, model) -> Set[str]:
    """
    Send one debug image event and one plot event, both flagged as model
    events, for the passed model.
    Return the set of the two urls that the events refer to
    """
    image_url = "http://link1"
    plot_url = "http://link2"

    image_event = self.create_event(
        model, "training_debug_image", 0, url=image_url, model_event=True
    )
    plot_event = self.create_event(
        model, "plot", 0, plot_str=f'{{"source": "{plot_url}"}}', model_event=True
    )
    self.send_batch([image_event, plot_event])

    return {image_url, plot_url}
|
||||
|
||||
def send_debug_image_events(self, task) -> Set[str]:
|
||||
events = [
|
||||
self.create_event(
|
||||
|
Loading…
Reference in New Issue
Block a user