Model events are fully supported

allegroai 2023-05-25 19:17:40 +03:00
parent 2e4e060a82
commit 58465fbc17
19 changed files with 341 additions and 228 deletions

View File

@@ -29,6 +29,10 @@ class ProjectOrNoneRequest(models.Base):
     include_subprojects = fields.BoolField(default=True)
 
 
+class GetUniqueMetricsRequest(ProjectOrNoneRequest):
+    model_metrics = fields.BoolField(default=False)
+
+
 class GetParamsRequest(ProjectOrNoneRequest):
     page = fields.IntField(default=0)
     page_size = fields.IntField(default=500)

View File

@@ -66,6 +66,7 @@ class GetTasksDataRequest(Base):
     plots: EventsRequest = EmbeddedField(EventsRequest)
     scalar_metrics_iter_histogram: ScalarMetricsIterHistogram = EmbeddedField(ScalarMetricsIterHistogram)
     allow_public = BoolField(default=True)
+    model_events: bool = BoolField(default=False)
 
 
 class GetAllRequest(Base):

View File

@@ -30,6 +30,7 @@ from apiserver.bll.event.history_debug_image_iterator import HistoryDebugImageIt
 from apiserver.bll.event.history_plots_iterator import HistoryPlotsIterator
 from apiserver.bll.event.metric_debug_images_iterator import MetricDebugImagesIterator
 from apiserver.bll.event.metric_plots_iterator import MetricPlotsIterator
+from apiserver.bll.model import ModelBLL
 from apiserver.bll.util import parallel_chunked_decorator
 from apiserver.database import utils as dbutils
 from apiserver.database.model.model import Model
@@ -250,7 +251,6 @@ class EventBLL(object):
                 task_or_model_ids.add(task_or_model_id)
                 if (
                     iter is not None
-                    and not model_events
                     and event.get("metric") not in self._skip_iteration_for_metric
                 ):
                     task_iteration[task_or_model_id] = max(
@@ -261,11 +261,10 @@ class EventBLL(object):
                     self._update_last_metric_events_for_task(
                         last_events=task_last_events[task_or_model_id], event=event,
                     )
                 if event_type == EventType.metrics_scalar.value:
                     self._update_last_scalar_events_for_task(
-                        last_events=task_last_scalar_events[task_or_model_id],
-                        event=event,
+                        last_events=task_last_scalar_events[task_or_model_id], event=event,
                     )
 
                 actions.append(es_action)
@@ -303,12 +302,21 @@ class EventBLL(object):
             else:
                 errors_per_type["Error when indexing events batch"] += 1
 
-            if not model_events:
-                remaining_tasks = set()
-                now = datetime.utcnow()
-                for task_or_model_id in task_or_model_ids:
-                    # Update related tasks. For reasons of performance, we prefer to update
-                    # all of them and not only those who's events were successful
+            remaining_tasks = set()
+            now = datetime.utcnow()
+            for task_or_model_id in task_or_model_ids:
+                # Update related tasks. For reasons of performance, we prefer to update
+                # all of them and not only those who's events were successful
+                if model_events:
+                    ModelBLL.update_statistics(
+                        company_id=company_id,
+                        model_id=task_or_model_id,
+                        last_iteration_max=task_iteration.get(task_or_model_id),
+                        last_scalar_events=task_last_scalar_events.get(
+                            task_or_model_id
+                        ),
+                    )
+                else:
                     updated = self._update_task(
                         company_id=company_id,
                         task_id=task_or_model_id,
@@ -319,15 +327,14 @@ class EventBLL(object):
                         ),
                         last_events=task_last_events.get(task_or_model_id),
                     )
                     if not updated:
                         remaining_tasks.add(task_or_model_id)
                         continue
 
-                if remaining_tasks:
-                    TaskBLL.set_last_update(
-                        remaining_tasks, company_id, last_update=now
-                    )
+            if remaining_tasks:
+                TaskBLL.set_last_update(
+                    remaining_tasks, company_id, last_update=now
+                )
 
         # this is for backwards compatibility with streaming bulk throwing exception on those
         invalid_iterations_count = errors_per_type.get(invalid_iteration_error)
@@ -484,7 +491,9 @@ class EventBLL(object):
         )
 
     def _get_event_id(self, event):
-        id_values = (str(event[field]) for field in self.event_id_fields if field in event)
+        id_values = (
+            str(event[field]) for field in self.event_id_fields if field in event
+        )
         return hashlib.md5("-".join(id_values).encode()).hexdigest()
 
     def scroll_task_events(
@@ -556,9 +565,7 @@ class EventBLL(object):
             must.append(get_metric_variants_condition(metric_variants))
         query = {"bool": {"must": must}}
 
-        search_args = dict(
-            es=self.es, company_id=company_id, event_type=event_type,
-        )
+        search_args = dict(es=self.es, company_id=company_id, event_type=event_type)
         max_metrics, max_variants = get_max_metric_and_variant_counts(
             query=query, **search_args,
         )
@@ -586,7 +593,7 @@ class EventBLL(object):
                         "events": {
                             "top_hits": {
                                 "sort": {"iter": {"order": "desc"}},
-                                "size": last_iterations_per_plot
+                                "size": last_iterations_per_plot,
                             }
                         }
                     },
@@ -597,11 +604,7 @@ class EventBLL(object):
         }
 
         with translate_errors_context():
-            es_response = search_company_events(
-                body=es_req,
-                ignore=404,
-                **search_args,
-            )
+            es_response = search_company_events(body=es_req, ignore=404, **search_args)
 
         aggs_result = es_response.get("aggregations")
         if not aggs_result:
@@ -614,9 +617,7 @@ class EventBLL(object):
             for hit in variants_bucket["events"]["hits"]["hits"]
         ]
         self.uncompress_plots(events)
-        return TaskEventsResult(
-            events=events, total_events=len(events)
-        )
+        return TaskEventsResult(events=events, total_events=len(events))
 
     def _get_events_from_es_res(self, es_res: dict) -> Tuple[list, int, Optional[str]]:
         """
@@ -731,12 +732,7 @@ class EventBLL(object):
         if not company_ids:
             return TaskEventsResult()
 
-        task_ids = (
-            [task_id]
-            if isinstance(task_id, str)
-            else task_id
-        )
+        task_ids = [task_id] if isinstance(task_id, str) else task_id
 
         must = []
         if metrics:
@@ -967,7 +963,7 @@ class EventBLL(object):
         event_type: EventType,
         task_id: Union[str, Sequence[str]],
         iters: int,
-        metrics: MetricVariants = None
+        metrics: MetricVariants = None,
     ) -> Mapping[str, Sequence]:
         company_ids = [company_id] if isinstance(company_id, str) else company_id
         company_ids = [
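
Taken together, the event_bll.py changes above mean that events.add_batch can now carry events that belong to a model rather than a task: iteration tracking applies to model events too (the `and not model_events` guard is gone), and the post-indexing statistics update is routed to ModelBLL.update_statistics instead of the task update path. A rough sketch of such a batch item; the field values are illustrative only, and the model_event flag plus the model id in the "task" field follow the usage visible in the test changes at the bottom of this commit:

# Illustrative scalar event reported against a model; values are made up.
model_scalar_event = {
    "type": "training_stats_scalar",  # assumed scalar event type name
    "model_event": True,              # route the event to a model
    "task": "<model-id>",             # for model events this field carries the model id
    "metric": "Metric_0",
    "variant": "Variant_0",
    "iter": 1,
    "value": 0.5,
}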

View File

@@ -5,7 +5,7 @@ from mongoengine import Q
 
 from apiserver.apierrors import errors
 from apiserver.apimodels.models import ModelTaskPublishResponse
-from apiserver.bll.task.utils import deleted_prefix
+from apiserver.bll.task.utils import deleted_prefix, get_last_metric_updates
 from apiserver.database.model import EntityVisibility
 from apiserver.database.model.model import Model
 from apiserver.database.model.task.task import Task, TaskStatus
@@ -28,11 +28,7 @@ class ModelBLL:
     @staticmethod
     def assert_exists(
-        company_id,
-        model_ids,
-        only=None,
-        allow_public=False,
-        return_models=True,
+        company_id, model_ids, only=None, allow_public=False, return_models=True,
     ) -> Optional[Sequence[Model]]:
         model_ids = [model_ids] if isinstance(model_ids, str) else model_ids
         ids = set(model_ids)
@@ -179,12 +175,36 @@ class ModelBLL:
                         "labels_count": {"$size": {"$objectToArray": "$labels"}}
                     }
                 },
-                {
-                    "$project": {"labels_count": 1},
-                },
+                {"$project": {"labels_count": 1}},
             ]
         )
-        return {
-            r.pop("_id"): r
-            for r in result
-        }
+        return {r.pop("_id"): r for r in result}
+
+    @staticmethod
+    def update_statistics(
+        company_id: str,
+        model_id: str,
+        last_iteration_max: int = None,
+        last_scalar_events: Dict[str, Dict[str, dict]] = None,
+    ):
+        updates = {"last_update": datetime.utcnow()}
+        if last_iteration_max is not None:
+            updates.update(max__last_iteration=last_iteration_max)
+
+        raw_updates = {}
+        if last_scalar_events is not None:
+            get_last_metric_updates(
+                task_id=model_id,
+                last_scalar_events=last_scalar_events,
+                raw_updates=raw_updates,
+                extra_updates=updates,
+                model_events=True,
+            )
+
+        ret = Model.objects(id=model_id).update_one(**updates)
+        if ret and raw_updates:
+            Model.objects(id=model_id).update_one(__raw__=[{"$set": raw_updates}])
+        return ret
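
Note the two-step write in update_statistics: plain fields (last_update, max__last_iteration and the set__last_metrics__* keys collected into updates) go through a regular mongoengine update_one, while the min/max bookkeeping goes through a second update_one with __raw__ aggregation-pipeline syntax, because it must compare against the currently stored value atomically on the server. A minimal sketch of one entry of raw_updates, with "m"/"v" standing in for real metric/variant hashes and 0.5 as a made-up value:

raw_updates = {
    "last_metrics.m.v.max_value": {
        "$cond": [
            {
                "$or": [
                    {"$lte": ["$last_metrics.m.v.max_value", None]},
                    {"$lt": ["$last_metrics.m.v.max_value", 0.5]},
                ]
            },
            0.5,                            # take the new value if it is larger
            "$last_metrics.m.v.max_value",  # otherwise keep the stored one
        ]
    }
}
# Applied as a pipeline update so the compare-and-set happens in one round trip:
# Model.objects(id=model_id).update_one(__raw__=[{"$set": raw_updates}])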

View File

@@ -209,7 +209,11 @@ class ProjectQueries:
     @classmethod
     def get_unique_metric_variants(
-        cls, company_id, project_ids: Sequence[str], include_subprojects: bool
+        cls,
+        company_id,
+        project_ids: Sequence[str],
+        include_subprojects: bool,
+        model_metrics: bool = False,
     ):
         pipeline = [
             {
@@ -246,7 +250,8 @@ class ProjectQueries:
             {"$sort": OrderedDict({"_id.metric": 1, "_id.variant": 1})},
         ]
 
-        result = Task.aggregate(pipeline)
+        entity_cls = Model if model_metrics else Task
+        result = entity_cls.aggregate(pipeline)
         return [r["metrics"][0] for r in result]
 
     @classmethod

View File

@@ -40,6 +40,7 @@ from .utils import (
     ChangeStatusRequest,
     update_project_time,
     deleted_prefix,
+    get_last_metric_updates,
 )
 
 log = config.logger(__file__)
@@ -412,77 +413,12 @@ class TaskBLL:
         raw_updates = {}
         if last_scalar_events is not None:
-            max_values = config.get("services.tasks.max_last_metrics", 2000)
-            total_metrics = set()
-            if max_values:
-                query = dict(id=task_id)
-                to_add = sum(len(v) for m, v in last_scalar_events.items())
-                if to_add <= max_values:
-                    query[f"unique_metrics__{max_values-to_add}__exists"] = True
-                task = Task.objects(**query).only("unique_metrics").first()
-                if task and task.unique_metrics:
-                    total_metrics = set(task.unique_metrics)
-            new_metrics = []
-
-            def add_last_metric_conditional_update(
-                metric_path: str, metric_value, iter_value: int, is_min: bool
-            ):
-                """
-                Build an aggregation for an atomic update of the min or max value and the corresponding iteration
-                """
-                if is_min:
-                    field_prefix = "min"
-                    op = "$gt"
-                else:
-                    field_prefix = "max"
-                    op = "$lt"
-
-                value_field = f"{metric_path}__{field_prefix}_value".replace("__", ".")
-                condition = {
-                    "$or": [
-                        {"$lte": [f"${value_field}", None]},
-                        {op: [f"${value_field}", metric_value]},
-                    ]
-                }
-                raw_updates[value_field] = {
-                    "$cond": [condition, metric_value, f"${value_field}"]
-                }
-
-                value_iteration_field = f"{metric_path}__{field_prefix}_value_iteration".replace(
-                    "__", "."
-                )
-                raw_updates[value_iteration_field] = {
-                    "$cond": [condition, iter_value, f"${value_iteration_field}",]
-                }
-
-            for metric_key, metric_data in last_scalar_events.items():
-                for variant_key, variant_data in metric_data.items():
-                    metric = (
-                        f"{variant_data.get('metric')}/{variant_data.get('variant')}"
-                    )
-                    if max_values:
-                        if (
-                            len(total_metrics) >= max_values
-                            and metric not in total_metrics
-                        ):
-                            continue
-                        total_metrics.add(metric)
-                        new_metrics.append(metric)
-                    path = f"last_metrics__{metric_key}__{variant_key}"
-                    for key, value in variant_data.items():
-                        if key in ("min_value", "max_value"):
-                            add_last_metric_conditional_update(
-                                metric_path=path,
-                                metric_value=value,
-                                iter_value=variant_data.get(f"{key}_iter", 0),
-                                is_min=(key == "min_value"),
-                            )
-                        elif key in ("metric", "variant", "value"):
-                            extra_updates[f"set__{path}__{key}"] = value
-            if new_metrics:
-                extra_updates["add_to_set__unique_metrics"] = new_metrics
+            get_last_metric_updates(
+                task_id=task_id,
+                last_scalar_events=last_scalar_events,
+                raw_updates=raw_updates,
+                extra_updates=extra_updates,
+            )
 
         if last_events is not None:

View File

@@ -5,7 +5,9 @@ import attr
 import six
 
 from apiserver.apierrors import errors
+from apiserver.config_repo import config
 from apiserver.database.errors import translate_errors_context
+from apiserver.database.model.model import Model
 from apiserver.database.model.project import Project
 from apiserver.database.model.task.task import Task, TaskStatus, TaskSystemTags
 from apiserver.database.utils import get_options
@@ -167,7 +169,7 @@ def update_project_time(project_ids: Union[str, Sequence[str]]):
 def get_task_for_update(
     company_id: str, task_id: str, allow_all_statuses: bool = False, force: bool = False
 ) -> Task:
     """
     Loads only task id and return the task only if it is updatable (status == 'created')
@@ -189,9 +191,88 @@ def get_task_for_update(
     return task
 
 
-def update_task(task: Task, user_id: str, update_cmds: dict, set_last_update: bool = True):
+def update_task(
+    task: Task, user_id: str, update_cmds: dict, set_last_update: bool = True
+):
     now = datetime.utcnow()
     last_updates = dict(last_change=now, last_changed_by=user_id)
     if set_last_update:
         last_updates.update(last_update=now)
     return task.update(**update_cmds, **last_updates)
+
+
+def get_last_metric_updates(
+    task_id: str,
+    last_scalar_events: dict,
+    raw_updates: dict,
+    extra_updates: dict,
+    model_events: bool = False,
+):
+    max_values = config.get("services.tasks.max_last_metrics", 2000)
+    total_metrics = set()
+    if max_values:
+        query = dict(id=task_id)
+        to_add = sum(len(v) for m, v in last_scalar_events.items())
+        if to_add <= max_values:
+            query[f"unique_metrics__{max_values - to_add}__exists"] = True
+        db_cls = Model if model_events else Task
+        task = db_cls.objects(**query).only("unique_metrics").first()
+        if task and task.unique_metrics:
+            total_metrics = set(task.unique_metrics)
+
+    new_metrics = []
+
+    def add_last_metric_conditional_update(
+        metric_path: str, metric_value, iter_value: int, is_min: bool
+    ):
+        """
+        Build an aggregation for an atomic update of the min or max value and the corresponding iteration
+        """
+        if is_min:
+            field_prefix = "min"
+            op = "$gt"
+        else:
+            field_prefix = "max"
+            op = "$lt"
+
+        value_field = f"{metric_path}__{field_prefix}_value".replace("__", ".")
+        condition = {
+            "$or": [
+                {"$lte": [f"${value_field}", None]},
+                {op: [f"${value_field}", metric_value]},
+            ]
+        }
+        raw_updates[value_field] = {
+            "$cond": [condition, metric_value, f"${value_field}"]
+        }
+
+        value_iteration_field = f"{metric_path}__{field_prefix}_value_iteration".replace(
+            "__", "."
+        )
+        raw_updates[value_iteration_field] = {
+            "$cond": [condition, iter_value, f"${value_iteration_field}"]
+        }
+
+    for metric_key, metric_data in last_scalar_events.items():
+        for variant_key, variant_data in metric_data.items():
+            metric = f"{variant_data.get('metric')}/{variant_data.get('variant')}"
+            if max_values:
+                if len(total_metrics) >= max_values and metric not in total_metrics:
+                    continue
+                total_metrics.add(metric)
+                new_metrics.append(metric)
+            path = f"last_metrics__{metric_key}__{variant_key}"
+            for key, value in variant_data.items():
+                if key in ("min_value", "max_value"):
+                    add_last_metric_conditional_update(
+                        metric_path=path,
+                        metric_value=value,
+                        iter_value=variant_data.get(f"{key}_iter", 0),
+                        is_min=(key == "min_value"),
+                    )
                elif key in ("metric", "variant", "value"):
+                    extra_updates[f"set__{path}__{key}"] = value
+    if new_metrics:
+        extra_updates["add_to_set__unique_metrics"] = new_metrics

View File

@@ -3,6 +3,8 @@ from mongoengine import (
     DateTimeField,
     BooleanField,
     EmbeddedDocumentField,
+    IntField,
+    ListField,
 )
 
 from apiserver.database import Database, strict
@@ -17,12 +19,14 @@ from apiserver.database.model.base import GetMixin
 from apiserver.database.model.metadata import MetadataItem
 from apiserver.database.model.model_labels import ModelLabels
 from apiserver.database.model.project import Project
+from apiserver.database.model.task.metrics import MetricEvent
 from apiserver.database.model.task.task import Task
 
 
 class Model(AttributedDocument):
     _field_collation_overrides = {
         "metadata.": AttributedDocument._numeric_locale,
+        "last_metrics.": AttributedDocument._numeric_locale,
     }
 
     meta = {
@@ -67,6 +71,7 @@ class Model(AttributedDocument):
             "parent",
             "metadata.*",
         ),
+        range_fields=("last_metrics.*", "last_iteration"),
         datetime_fields=("last_update",),
     )
@@ -92,6 +97,9 @@ class Model(AttributedDocument):
     metadata = SafeMapField(
         field=EmbeddedDocumentField(MetadataItem), user_set_allowed=True
     )
+    last_iteration = IntField(default=0)
+    last_metrics = SafeMapField(field=SafeMapField(EmbeddedDocumentField(MetricEvent)))
+    unique_metrics = ListField(StringField(required=True), exclude_by_default=True)
 
     def get_index_company(self) -> str:
         return self.company or self.company_origin or ""
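
The new fields mirror the scalar bookkeeping tasks already have: last_metrics is a two-level map (metric hash, then variant hash) of MetricEvent documents, and unique_metrics caps how many metric/variant pairs are tracked. A sketch of the resulting Mongo document fragment for a model after a couple of scalar events, with invented hashes and numbers (the field names match the assertions in the test changes below):

{
    "last_iteration": 7,
    "unique_metrics": ["loss/total"],
    "last_metrics": {
        "mhash": {
            "vhash": {
                "metric": "loss",
                "variant": "total",
                "value": 0.31,
                "min_value": 0.28,
                "min_value_iteration": 7,
                "max_value": 0.9,
                "max_value_iteration": 1,
            }
        }
    },
}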

View File

@ -1,6 +1,6 @@
_description: """This service provides a management interface for models (results of training tasks) stored in the system.""" _description: """This service provides a management interface for models (results of training tasks) stored in the system."""
_definitions { _definitions {
include "_common.conf" include "_tasks_common.conf"
multi_field_pattern_data { multi_field_pattern_data {
type: object type: object
properties { properties {
@ -104,6 +104,17 @@ _definitions {
"$ref": "#/definitions/metadata_item" "$ref": "#/definitions/metadata_item"
} }
} }
last_iteration {
description: "Last iteration reported for this model"
type: integer
}
last_metrics {
description: "Last metric variants (hash to events), one for each metric hash"
type: object
additionalProperties {
"$ref": "#/definitions/last_metrics_variants"
}
}
stats { stats {
description: "Model statistics" description: "Model statistics"
type: object type: object

View File

@@ -898,6 +898,13 @@ get_unique_metric_variants {
             }
         }
     }
+    "999.0": ${get_unique_metric_variants."2.13"} {
+        request.properties.model_metrics {
+            description: "If set to true then the unique metric and variant names are taken from the project models, otherwise from the project tasks"
+            type: boolean
+            default: false
+        }
+    }
 }
 get_hyperparam_values {
     "2.13" {

View File

@@ -568,6 +568,13 @@ get_task_data {
             }
         }
     }
+    "999.0": ${get_task_data."2.23"} {
+        request.properties.model_events {
+            type: boolean
+            description: "If set then model events are retrieved, otherwise task events"
+            default: false
+        }
+    }
 }
 get_all_ex {
     "2.23" {

View File

@@ -561,7 +561,6 @@ def _get_multitask_plots(
     metrics: MetricVariants = None,
     scroll_id=None,
     no_scroll=True,
-    model_events=False,
 ) -> Tuple[dict, int, str]:
     task_names = {
         t.id: t.name for t in itertools.chain.from_iterable(companies.values())
@@ -598,7 +597,6 @@ def get_multi_task_plots(call, company_id, _):
         last_iters=iters,
         scroll_id=scroll_id,
         no_scroll=no_scroll,
-        model_events=model_events,
     )
     call.result.data = dict(
         plots=return_events,

View File

@@ -1,6 +1,6 @@
 from datetime import datetime
 from functools import partial
-from typing import Sequence
+from typing import Sequence, Union
 
 from mongoengine import Q, EmbeddedDocument
@@ -59,6 +59,11 @@ org_bll = OrgBLL()
 project_bll = ProjectBLL()
 
 
+def conform_model_data(call: APICall, model_data: Union[Sequence[dict], dict]):
+    conform_output_tags(call, model_data)
+    unescape_metadata(call, model_data)
+
+
 @endpoint("models.get_by_id", required_fields=["model"])
 def get_by_id(call: APICall, company_id, _):
     model_id = call.data["model"]
@@ -74,8 +79,7 @@ def get_by_id(call: APICall, company_id, _):
         raise errors.bad_request.InvalidModelId(
             "no such public or company model", id=model_id, company=company_id,
         )
-    conform_output_tags(call, models[0])
-    unescape_metadata(call, models[0])
+    conform_model_data(call, models[0])
     call.result.data = {"model": models[0]}
@@ -102,8 +106,7 @@ def get_by_task_id(call: APICall, company_id, _):
             "no such public or company model", id=model_id, company=company_id,
         )
     model_dict = model.to_proper_dict()
-    conform_output_tags(call, model_dict)
-    unescape_metadata(call, model_dict)
+    conform_model_data(call, model_dict)
     call.result.data = {"model": model_dict}
@@ -119,8 +122,7 @@ def get_all_ex(call: APICall, company_id, request: ModelsGetRequest):
         allow_public=request.allow_public,
         ret_params=ret_params,
     )
-    conform_output_tags(call, models)
-    unescape_metadata(call, models)
+    conform_model_data(call, models)
     if not request.include_stats:
         call.result.data = {"models": models, **ret_params}
@@ -142,8 +144,7 @@ def get_by_id_ex(call: APICall, company_id, _):
     models = Model.get_many_with_join(
         company=company_id, query_dict=call.data, allow_public=True
     )
-    conform_output_tags(call, models)
-    unescape_metadata(call, models)
+    conform_model_data(call, models)
     call.result.data = {"models": models}
@@ -159,8 +160,7 @@ def get_all(call: APICall, company_id, _):
         allow_public=True,
         ret_params=ret_params,
     )
-    conform_output_tags(call, models)
-    unescape_metadata(call, models)
+    conform_model_data(call, models)
     call.result.data = {"models": models, **ret_params}
@@ -428,8 +428,7 @@ def edit(call: APICall, company_id, _):
             _reset_cached_tags(company_id, projects=[new_project, model.project])
         else:
             _update_cached_tags(company_id, project=model.project, fields=fields)
-        conform_output_tags(call, fields)
-        unescape_metadata(call, fields)
+        conform_model_data(call, fields)
         call.result.data_model = UpdateResponse(updated=updated, fields=fields)
     else:
         call.result.data_model = UpdateResponse(updated=0)
@@ -461,8 +460,7 @@ def _update_model(call: APICall, company_id, model_id=None):
             _update_cached_tags(
                 company_id, project=model.project, fields=updated_fields
            )
-        conform_output_tags(call, updated_fields)
-        unescape_metadata(call, updated_fields)
+        conform_model_data(call, updated_fields)
         return UpdateResponse(updated=updated_count, fields=updated_fields)

View File

@@ -15,10 +15,10 @@ from apiserver.apimodels.projects import (
     DeleteRequest,
     MoveRequest,
     MergeRequest,
-    ProjectOrNoneRequest,
     ProjectRequest,
     ProjectModelMetadataValuesRequest,
     ProjectChildrenType,
+    GetUniqueMetricsRequest,
 )
 from apiserver.bll.organization import OrgBLL, Tags
 from apiserver.bll.project import ProjectBLL, ProjectQueries
@@ -345,16 +345,17 @@ def delete(call: APICall, company_id: str, request: DeleteRequest):
 @endpoint(
-    "projects.get_unique_metric_variants", request_data_model=ProjectOrNoneRequest
+    "projects.get_unique_metric_variants", request_data_model=GetUniqueMetricsRequest
 )
 def get_unique_metric_variants(
-    call: APICall, company_id: str, request: ProjectOrNoneRequest
+    call: APICall, company_id: str, request: GetUniqueMetricsRequest,
 ):
     metrics = project_queries.get_unique_metric_variants(
         company_id,
         [request.project] if request.project else None,
         include_subprojects=request.include_subprojects,
+        model_metrics=request.model_metrics,
     )
 
     call.result.data = {"metrics": metrics}

View File

@@ -1,3 +1,5 @@
+from typing import Union, Sequence
+
 from mongoengine import Q
 
 from apiserver.apimodels.base import UpdateResponse
@@ -39,14 +41,18 @@ worker_bll = WorkerBLL()
 queue_bll = QueueBLL(worker_bll)
 
 
+def conform_queue_data(call: APICall, queue_data: Union[Sequence[dict], dict]):
+    conform_output_tags(call, queue_data)
+    unescape_metadata(call, queue_data)
+
+
 @endpoint("queues.get_by_id", min_version="2.4", request_data_model=GetByIdRequest)
 def get_by_id(call: APICall, company_id, request: GetByIdRequest):
     queue = queue_bll.get_by_id(
         company_id, request.queue, max_task_entries=request.max_task_entries
     )
     queue_dict = queue.to_proper_dict()
-    conform_output_tags(call, queue_dict)
-    unescape_metadata(call, queue_dict)
+    conform_queue_data(call, queue_dict)
     call.result.data = {"queue": queue_dict}
@@ -85,8 +91,7 @@ def get_all_ex(call: APICall, company: str, request: GetAllRequest):
         max_task_entries=request.max_task_entries,
         ret_params=ret_params,
     )
-    conform_output_tags(call, queues)
-    unescape_metadata(call, queues)
+    conform_queue_data(call, queues)
     call.result.data = {"queues": queues, **ret_params}
@@ -102,8 +107,7 @@ def get_all(call: APICall, company: str, request: GetAllRequest):
         max_task_entries=request.max_task_entries,
         ret_params=ret_params,
     )
-    conform_output_tags(call, queues)
-    unescape_metadata(call, queues)
+    conform_queue_data(call, queues)
     call.result.data = {"queues": queues, **ret_params}
@@ -135,8 +139,7 @@ def update(call: APICall, company_id, req_model: UpdateRequest):
     updated, fields = queue_bll.update(
         company_id=company_id, queue_id=req_model.queue, **data
     )
-    conform_output_tags(call, fields)
-    unescape_metadata(call, fields)
+    conform_queue_data(call, fields)
     call.result.data_model = UpdateResponse(updated=updated, fields=fields)

View File

@@ -17,6 +17,8 @@ from apiserver.apimodels.reports import (
 from apiserver.apierrors import errors
 from apiserver.apimodels.base import UpdateResponse
 from apiserver.bll.project.project_bll import reports_project_name, reports_tag
+from apiserver.database.model.model import Model
+from apiserver.services.models import conform_model_data
 from apiserver.services.utils import process_include_subprojects, sort_tags_response
 from apiserver.bll.organization import OrgBLL
 from apiserver.bll.project import ProjectBLL
@@ -35,7 +37,7 @@ from apiserver.services.events import (
 from apiserver.services.tasks import (
     escape_execution_parameters,
     _hidden_query,
-    unprepare_from_saved,
+    conform_task_data,
 )
 
 org_bll = OrgBLL()
@@ -178,7 +180,7 @@ def get_all_ex(call: APICall, company_id, request: GetAllRequest):
         allow_public=request.allow_public,
         ret_params=ret_params,
     )
-    unprepare_from_saved(call, tasks)
+    conform_task_data(call, tasks)
     call.result.data = {"tasks": tasks, **ret_params}
@@ -198,18 +200,25 @@ def _get_task_metrics_from_request(
 @endpoint("reports.get_task_data")
 def get_task_data(call: APICall, company_id, request: GetTasksDataRequest):
+    if request.model_events:
+        entity_cls = Model
+        conform_data = conform_model_data
+    else:
+        entity_cls = Task
+        conform_data = conform_task_data
+
     call_data = escape_execution_parameters(call)
     process_include_subprojects(call_data)
 
     ret_params = {}
-    tasks = Task.get_many_with_join(
+    tasks = entity_cls.get_many_with_join(
         company=company_id,
         query_dict=call_data,
         query=_hidden_query(call_data),
         allow_public=request.allow_public,
         ret_params=ret_params,
     )
-    unprepare_from_saved(call, tasks)
+    conform_data(call, tasks)
     res = {"tasks": tasks, **ret_params}
     if not (
         request.debug_images or request.plots or request.scalar_metrics_iter_histogram
@@ -217,7 +226,9 @@ def get_task_data(call: APICall, company_id, request: GetTasksDataRequest):
         return res
 
     task_ids = [task["id"] for task in tasks]
-    companies = _get_task_or_model_index_companies(company_id, task_ids=task_ids)
+    companies = _get_task_or_model_index_companies(
+        company_id, task_ids=task_ids, model_events=request.model_events
+    )
     if request.debug_images:
         result = event_bll.debug_images_iterator.get_task_events(
             companies={

View File

@@ -182,7 +182,7 @@ def get_by_id(call: APICall, company_id, req_model: TaskRequest):
         req_model.task, company_id=company_id, allow_public=True
     )
     task_dict = task.to_proper_dict()
-    unprepare_from_saved(call, task_dict)
+    conform_task_data(call, task_dict)
     call.result.data = {"task": task_dict}
@@ -231,7 +231,7 @@ def get_all_ex(call: APICall, company_id, request: GetAllReq):
         allow_public=request.allow_public,
         ret_params=ret_params,
    )
-    unprepare_from_saved(call, tasks)
+    conform_task_data(call, tasks)
     call.result.data = {"tasks": tasks, **ret_params}
@@ -245,7 +245,7 @@ def get_by_id_ex(call: APICall, company_id, _):
         company=company_id, query_dict=call_data, allow_public=True,
     )
-    unprepare_from_saved(call, tasks)
+    conform_task_data(call, tasks)
     call.result.data = {"tasks": tasks}
@@ -264,7 +264,7 @@ def get_all(call: APICall, company_id, _):
         allow_public=True,
         ret_params=ret_params,
     )
-    unprepare_from_saved(call, tasks)
+    conform_task_data(call, tasks)
     call.result.data = {"tasks": tasks, **ret_params}
@@ -430,7 +430,7 @@ def prepare_for_save(call: APICall, fields: dict, previous_task: Task = None):
     return fields
 
 
-def unprepare_from_saved(call: APICall, tasks_data: Union[Sequence[dict], dict]):
+def conform_task_data(call: APICall, tasks_data: Union[Sequence[dict], dict]):
     if isinstance(tasks_data, dict):
         tasks_data = [tasks_data]
@@ -608,7 +608,7 @@ def update(call: APICall, company_id, req_model: UpdateRequest):
             company_id, project=task.project, fields=updated_fields
         )
     update_project_time(updated_fields.get("project"))
-    unprepare_from_saved(call, updated_fields)
+    conform_task_data(call, updated_fields)
     return UpdateResponse(updated=updated_count, fields=updated_fields)
@@ -763,7 +763,7 @@ def edit(call: APICall, company_id, req_model: UpdateRequest):
             company_id, project=task.project, fields=fixed_fields
        )
         update_project_time(fields.get("project"))
-        unprepare_from_saved(call, fields)
+        conform_task_data(call, fields)
         call.result.data_model = UpdateResponse(updated=updated, fields=fields)
     else:
         call.result.data_model = UpdateResponse(updated=0)

View File

@@ -113,66 +113,75 @@ class TestReports(TestService):
     def test_reports_task_data(self):
         report_task = self._temp_report(name="Rep1")
-        non_report_task = self._temp_task(name="hello")
-        debug_image_events = [
-            self._create_task_event(
-                task=non_report_task,
-                type_="training_debug_image",
-                iteration=1,
-                metric=f"Metric_{m}",
-                variant=f"Variant_{v}",
-                url=f"{m}_{v}",
-            )
-            for m in range(2)
-            for v in range(2)
-        ]
-        plot_events = [
-            self._create_task_event(
-                task=non_report_task,
-                type_="plot",
-                iteration=1,
-                metric=f"Metric_{m}",
-                variant=f"Variant_{v}",
-                plot_str=f"Hello plot",
-            )
-            for m in range(2)
-            for v in range(2)
-        ]
-        self.send_batch([*debug_image_events, *plot_events])
-
-        res = self.api.reports.get_task_data(
-            id=[non_report_task], only_fields=["name"],
-        )
-        self.assertEqual(len(res.tasks), 1)
-        self.assertEqual(res.tasks[0].id, non_report_task)
-        self.assertFalse(any(field in res for field in ("plots", "debug_images")))
-
-        res = self.api.reports.get_task_data(
-            id=[non_report_task],
-            only_fields=["name"],
-            debug_images={"metrics": []},
-            plots={"metrics": [{"metric": "Metric_1"}]},
-        )
-        self.assertEqual(len(res.debug_images), 1)
-        task_events = res.debug_images[0]
-        self.assertEqual(task_events.task, non_report_task)
-        self.assertEqual(len(task_events.iterations), 1)
-        self.assertEqual(len(task_events.iterations[0].events), 4)
-        self.assertEqual(len(res.plots), 1)
-        for m, v in (("Metric_1", "Variant_0"), ("Metric_1", "Variant_1")):
-            tasks = nested_get(res.plots, (m, v))
-            self.assertEqual(len(tasks), 1)
-            task_plots = tasks[non_report_task]
-            self.assertEqual(len(task_plots), 1)
-            iter_plots = task_plots["1"]
-            self.assertEqual(iter_plots.name, "hello")
-            self.assertEqual(len(iter_plots.plots), 1)
-            ev = iter_plots.plots[0]
-            self.assertEqual(ev["metric"], m)
-            self.assertEqual(ev["variant"], v)
-            self.assertEqual(ev["task"], non_report_task)
-            self.assertEqual(ev["iter"], 1)
+        for model_events in (False, True):
+            if model_events:
+                non_report_task = self._temp_model(name="hello")
+                event_args = {"model_event": True}
+            else:
+                non_report_task = self._temp_task(name="hello")
+                event_args = {}
+            debug_image_events = [
+                self._create_task_event(
+                    task=non_report_task,
+                    type_="training_debug_image",
+                    iteration=1,
+                    metric=f"Metric_{m}",
+                    variant=f"Variant_{v}",
+                    url=f"{m}_{v}",
+                    **event_args,
+                )
+                for m in range(2)
+                for v in range(2)
+            ]
+            plot_events = [
+                self._create_task_event(
+                    task=non_report_task,
+                    type_="plot",
+                    iteration=1,
+                    metric=f"Metric_{m}",
+                    variant=f"Variant_{v}",
+                    plot_str=f"Hello plot",
+                    **event_args,
+                )
+                for m in range(2)
+                for v in range(2)
+            ]
+            self.send_batch([*debug_image_events, *plot_events])
+
+            res = self.api.reports.get_task_data(
+                id=[non_report_task], only_fields=["name"], model_events=model_events
+            )
+            self.assertEqual(len(res.tasks), 1)
+            self.assertEqual(res.tasks[0].id, non_report_task)
+            self.assertFalse(any(field in res for field in ("plots", "debug_images")))
+
+            res = self.api.reports.get_task_data(
+                id=[non_report_task],
+                only_fields=["name"],
+                debug_images={"metrics": []},
+                plots={"metrics": [{"metric": "Metric_1"}]},
+                model_events=model_events,
+            )
+            self.assertEqual(len(res.debug_images), 1)
+            task_events = res.debug_images[0]
+            self.assertEqual(task_events.task, non_report_task)
+            self.assertEqual(len(task_events.iterations), 1)
+            self.assertEqual(len(task_events.iterations[0].events), 4)
+            self.assertEqual(len(res.plots), 1)
+            for m, v in (("Metric_1", "Variant_0"), ("Metric_1", "Variant_1")):
+                tasks = nested_get(res.plots, (m, v))
+                self.assertEqual(len(tasks), 1)
+                task_plots = tasks[non_report_task]
+                self.assertEqual(len(task_plots), 1)
+                iter_plots = task_plots["1"]
+                self.assertEqual(iter_plots.name, "hello")
+                self.assertEqual(len(iter_plots.plots), 1)
+                ev = iter_plots.plots[0]
+                self.assertEqual(ev["metric"], m)
+                self.assertEqual(ev["variant"], v)
+                self.assertEqual(ev["task"], non_report_task)
+                self.assertEqual(ev["iter"], 1)
 
     @staticmethod
     def _create_task_event(type_, task, iteration, **kwargs):
@@ -185,12 +194,14 @@ class TestReports(TestService):
             **kwargs,
         }
 
+    delete_params = {"force": True}
+
     def _temp_report(self, name, **kwargs):
         return self.create_temp(
             "reports",
             name=name,
             object_name="task",
-            delete_params={"force": True},
+            delete_params=self.delete_params,
             **kwargs,
         )
@@ -199,10 +210,16 @@ class TestReports(TestService):
             "tasks",
             name=name,
             type="training",
-            delete_params={"force": True},
+            delete_params=self.delete_params,
             **kwargs,
         )
 
+    def _temp_model(self, name="test model events", **kwargs):
+        self.update_missing(
+            kwargs, name=name, uri="file:///a/b", labels={}, ready=False
+        )
+        return self.create_temp("models", delete_params=self.delete_params, **kwargs)
+
     def send_batch(self, events):
         _, data = self.api.send_batch("events.add_batch", events)
         return data

View File

@@ -16,9 +16,7 @@ class TestTaskEvents(TestService):
     delete_params = dict(can_fail=True, force=True)
 
     def _temp_task(self, name="test task events"):
-        task_input = dict(
-            name=name, type="training",
-        )
+        task_input = dict(name=name, type="training",)
         return self.create_temp(
             "tasks", delete_paramse=self.delete_params, **task_input
         )
@@ -220,6 +218,17 @@ class TestTaskEvents(TestService):
         self.assertEqual(variant_data.x, [0, 1])
         self.assertEqual(variant_data.y, [0.0, 1.0])
 
+        model_data = self.api.models.get_all_ex(
+            id=[model], only_fields=["last_metrics", "last_iteration"]
+        ).models[0]
+        metric_data = first(first(model_data.last_metrics.values()).values())
+        self.assertEqual(1, model_data.last_iteration)
+        self.assertEqual(1, metric_data.value)
+        self.assertEqual(1, metric_data.max_value)
+        self.assertEqual(1, metric_data.max_value_iteration)
+        self.assertEqual(0, metric_data.min_value)
+        self.assertEqual(0, metric_data.min_value_iteration)
+
     def test_error_events(self):
         task = self._temp_task()
         events = [