mirror of
https://github.com/clearml/clearml-server
synced 2025-03-03 10:43:10 +00:00
Model events are fully supported
This commit is contained in:
parent
2e4e060a82
commit
58465fbc17
@ -29,6 +29,10 @@ class ProjectOrNoneRequest(models.Base):
|
||||
include_subprojects = fields.BoolField(default=True)
|
||||
|
||||
|
||||
class GetUniqueMetricsRequest(ProjectOrNoneRequest):
    """Request model used by the projects.get_unique_metric_variants endpoint."""

    # When true the unique metrics/variants are taken from the project models,
    # otherwise (the default) from the project tasks
    model_metrics = fields.BoolField(default=False)
|
||||
|
||||
|
||||
class GetParamsRequest(ProjectOrNoneRequest):
|
||||
page = fields.IntField(default=0)
|
||||
page_size = fields.IntField(default=500)
|
||||
|
@ -66,6 +66,7 @@ class GetTasksDataRequest(Base):
|
||||
plots: EventsRequest = EmbeddedField(EventsRequest)
|
||||
scalar_metrics_iter_histogram: ScalarMetricsIterHistogram = EmbeddedField(ScalarMetricsIterHistogram)
|
||||
allow_public = BoolField(default=True)
|
||||
model_events: bool = BoolField(default=False)
|
||||
|
||||
|
||||
class GetAllRequest(Base):
|
||||
|
@ -30,6 +30,7 @@ from apiserver.bll.event.history_debug_image_iterator import HistoryDebugImageIt
|
||||
from apiserver.bll.event.history_plots_iterator import HistoryPlotsIterator
|
||||
from apiserver.bll.event.metric_debug_images_iterator import MetricDebugImagesIterator
|
||||
from apiserver.bll.event.metric_plots_iterator import MetricPlotsIterator
|
||||
from apiserver.bll.model import ModelBLL
|
||||
from apiserver.bll.util import parallel_chunked_decorator
|
||||
from apiserver.database import utils as dbutils
|
||||
from apiserver.database.model.model import Model
|
||||
@ -250,7 +251,6 @@ class EventBLL(object):
|
||||
task_or_model_ids.add(task_or_model_id)
|
||||
if (
|
||||
iter is not None
|
||||
and not model_events
|
||||
and event.get("metric") not in self._skip_iteration_for_metric
|
||||
):
|
||||
task_iteration[task_or_model_id] = max(
|
||||
@ -261,11 +261,10 @@ class EventBLL(object):
|
||||
self._update_last_metric_events_for_task(
|
||||
last_events=task_last_events[task_or_model_id], event=event,
|
||||
)
|
||||
if event_type == EventType.metrics_scalar.value:
|
||||
self._update_last_scalar_events_for_task(
|
||||
last_events=task_last_scalar_events[task_or_model_id],
|
||||
event=event,
|
||||
)
|
||||
if event_type == EventType.metrics_scalar.value:
|
||||
self._update_last_scalar_events_for_task(
|
||||
last_events=task_last_scalar_events[task_or_model_id], event=event,
|
||||
)
|
||||
|
||||
actions.append(es_action)
|
||||
|
||||
@ -303,12 +302,21 @@ class EventBLL(object):
|
||||
else:
|
||||
errors_per_type["Error when indexing events batch"] += 1
|
||||
|
||||
if not model_events:
|
||||
remaining_tasks = set()
|
||||
now = datetime.utcnow()
|
||||
for task_or_model_id in task_or_model_ids:
|
||||
# Update related tasks. For reasons of performance, we prefer to update
|
||||
# all of them and not only those whose events were successful
|
||||
remaining_tasks = set()
|
||||
now = datetime.utcnow()
|
||||
for task_or_model_id in task_or_model_ids:
|
||||
# Update related tasks. For reasons of performance, we prefer to update
|
||||
# all of them and not only those whose events were successful
|
||||
if model_events:
|
||||
ModelBLL.update_statistics(
|
||||
company_id=company_id,
|
||||
model_id=task_or_model_id,
|
||||
last_iteration_max=task_iteration.get(task_or_model_id),
|
||||
last_scalar_events=task_last_scalar_events.get(
|
||||
task_or_model_id
|
||||
),
|
||||
)
|
||||
else:
|
||||
updated = self._update_task(
|
||||
company_id=company_id,
|
||||
task_id=task_or_model_id,
|
||||
@ -319,15 +327,14 @@ class EventBLL(object):
|
||||
),
|
||||
last_events=task_last_events.get(task_or_model_id),
|
||||
)
|
||||
|
||||
if not updated:
|
||||
remaining_tasks.add(task_or_model_id)
|
||||
continue
|
||||
|
||||
if remaining_tasks:
|
||||
TaskBLL.set_last_update(
|
||||
remaining_tasks, company_id, last_update=now
|
||||
)
|
||||
if remaining_tasks:
|
||||
TaskBLL.set_last_update(
|
||||
remaining_tasks, company_id, last_update=now
|
||||
)
|
||||
|
||||
# this is for backwards compatibility with streaming bulk throwing exception on those
|
||||
invalid_iterations_count = errors_per_type.get(invalid_iteration_error)
|
||||
@ -484,7 +491,9 @@ class EventBLL(object):
|
||||
)
|
||||
|
||||
def _get_event_id(self, event):
|
||||
id_values = (str(event[field]) for field in self.event_id_fields if field in event)
|
||||
id_values = (
|
||||
str(event[field]) for field in self.event_id_fields if field in event
|
||||
)
|
||||
return hashlib.md5("-".join(id_values).encode()).hexdigest()
|
||||
|
||||
def scroll_task_events(
|
||||
@ -556,9 +565,7 @@ class EventBLL(object):
|
||||
must.append(get_metric_variants_condition(metric_variants))
|
||||
|
||||
query = {"bool": {"must": must}}
|
||||
search_args = dict(
|
||||
es=self.es, company_id=company_id, event_type=event_type,
|
||||
)
|
||||
search_args = dict(es=self.es, company_id=company_id, event_type=event_type)
|
||||
max_metrics, max_variants = get_max_metric_and_variant_counts(
|
||||
query=query, **search_args,
|
||||
)
|
||||
@ -586,7 +593,7 @@ class EventBLL(object):
|
||||
"events": {
|
||||
"top_hits": {
|
||||
"sort": {"iter": {"order": "desc"}},
|
||||
"size": last_iterations_per_plot
|
||||
"size": last_iterations_per_plot,
|
||||
}
|
||||
}
|
||||
},
|
||||
@ -597,11 +604,7 @@ class EventBLL(object):
|
||||
}
|
||||
|
||||
with translate_errors_context():
|
||||
es_response = search_company_events(
|
||||
body=es_req,
|
||||
ignore=404,
|
||||
**search_args,
|
||||
)
|
||||
es_response = search_company_events(body=es_req, ignore=404, **search_args)
|
||||
|
||||
aggs_result = es_response.get("aggregations")
|
||||
if not aggs_result:
|
||||
@ -614,9 +617,7 @@ class EventBLL(object):
|
||||
for hit in variants_bucket["events"]["hits"]["hits"]
|
||||
]
|
||||
self.uncompress_plots(events)
|
||||
return TaskEventsResult(
|
||||
events=events, total_events=len(events)
|
||||
)
|
||||
return TaskEventsResult(events=events, total_events=len(events))
|
||||
|
||||
def _get_events_from_es_res(self, es_res: dict) -> Tuple[list, int, Optional[str]]:
|
||||
"""
|
||||
@ -731,12 +732,7 @@ class EventBLL(object):
|
||||
if not company_ids:
|
||||
return TaskEventsResult()
|
||||
|
||||
task_ids = (
|
||||
[task_id]
|
||||
if isinstance(task_id, str)
|
||||
else task_id
|
||||
)
|
||||
|
||||
task_ids = [task_id] if isinstance(task_id, str) else task_id
|
||||
|
||||
must = []
|
||||
if metrics:
|
||||
@ -967,7 +963,7 @@ class EventBLL(object):
|
||||
event_type: EventType,
|
||||
task_id: Union[str, Sequence[str]],
|
||||
iters: int,
|
||||
metrics: MetricVariants = None
|
||||
metrics: MetricVariants = None,
|
||||
) -> Mapping[str, Sequence]:
|
||||
company_ids = [company_id] if isinstance(company_id, str) else company_id
|
||||
company_ids = [
|
||||
|
@ -5,7 +5,7 @@ from mongoengine import Q
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.apimodels.models import ModelTaskPublishResponse
|
||||
from apiserver.bll.task.utils import deleted_prefix
|
||||
from apiserver.bll.task.utils import deleted_prefix, get_last_metric_updates
|
||||
from apiserver.database.model import EntityVisibility
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.task.task import Task, TaskStatus
|
||||
@ -28,11 +28,7 @@ class ModelBLL:
|
||||
|
||||
@staticmethod
|
||||
def assert_exists(
|
||||
company_id,
|
||||
model_ids,
|
||||
only=None,
|
||||
allow_public=False,
|
||||
return_models=True,
|
||||
company_id, model_ids, only=None, allow_public=False, return_models=True,
|
||||
) -> Optional[Sequence[Model]]:
|
||||
model_ids = [model_ids] if isinstance(model_ids, str) else model_ids
|
||||
ids = set(model_ids)
|
||||
@ -179,12 +175,36 @@ class ModelBLL:
|
||||
"labels_count": {"$size": {"$objectToArray": "$labels"}}
|
||||
}
|
||||
},
|
||||
{
|
||||
"$project": {"labels_count": 1},
|
||||
},
|
||||
{"$project": {"labels_count": 1}},
|
||||
]
|
||||
)
|
||||
return {
|
||||
r.pop("_id"): r
|
||||
for r in result
|
||||
}
|
||||
return {r.pop("_id"): r for r in result}
|
||||
|
||||
@staticmethod
def update_statistics(
    company_id: str,
    model_id: str,
    last_iteration_max: int = None,
    last_scalar_events: Dict[str, Dict[str, dict]] = None,
):
    """
    Update the model document following reported model events.

    Always sets last_update to the current UTC time. When provided, the last
    iteration is updated with a "max" operator and the last scalar metrics
    are stored using the update commands built by get_last_metric_updates.

    :param company_id: company id
        NOTE(review): not used in the queries below, the model is looked up
        by id only - confirm that this is intended
    :param model_id: id of the model to update
    :param last_iteration_max: last reported iteration, None to leave as is
    :param last_scalar_events: metric key -> variant key -> last scalar event
        data, None to leave the last metrics as is
    :return: the result of the first update_one call
    """
    updates = {"last_update": datetime.utcnow()}
    if last_iteration_max is not None:
        updates.update(max__last_iteration=last_iteration_max)

    raw_updates = {}
    if last_scalar_events is not None:
        # fills both "updates" and "raw_updates" in place
        get_last_metric_updates(
            task_id=model_id,
            last_scalar_events=last_scalar_events,
            raw_updates=raw_updates,
            extra_updates=updates,
            model_events=True,
        )

    ret = Model.objects(id=model_id).update_one(**updates)
    if ret and raw_updates:
        # the conditional min/max expressions are applied in a separate raw update
        Model.objects(id=model_id).update_one(__raw__=[{"$set": raw_updates}])

    return ret
|
||||
|
@ -209,7 +209,11 @@ class ProjectQueries:
|
||||
|
||||
@classmethod
|
||||
def get_unique_metric_variants(
|
||||
cls, company_id, project_ids: Sequence[str], include_subprojects: bool
|
||||
cls,
|
||||
company_id,
|
||||
project_ids: Sequence[str],
|
||||
include_subprojects: bool,
|
||||
model_metrics: bool = False,
|
||||
):
|
||||
pipeline = [
|
||||
{
|
||||
@ -246,7 +250,8 @@ class ProjectQueries:
|
||||
{"$sort": OrderedDict({"_id.metric": 1, "_id.variant": 1})},
|
||||
]
|
||||
|
||||
result = Task.aggregate(pipeline)
|
||||
entity_cls = Model if model_metrics else Task
|
||||
result = entity_cls.aggregate(pipeline)
|
||||
return [r["metrics"][0] for r in result]
|
||||
|
||||
@classmethod
|
||||
|
@ -40,6 +40,7 @@ from .utils import (
|
||||
ChangeStatusRequest,
|
||||
update_project_time,
|
||||
deleted_prefix,
|
||||
get_last_metric_updates,
|
||||
)
|
||||
|
||||
log = config.logger(__file__)
|
||||
@ -412,77 +413,12 @@ class TaskBLL:
|
||||
|
||||
raw_updates = {}
|
||||
if last_scalar_events is not None:
|
||||
max_values = config.get("services.tasks.max_last_metrics", 2000)
|
||||
total_metrics = set()
|
||||
if max_values:
|
||||
query = dict(id=task_id)
|
||||
to_add = sum(len(v) for m, v in last_scalar_events.items())
|
||||
if to_add <= max_values:
|
||||
query[f"unique_metrics__{max_values-to_add}__exists"] = True
|
||||
task = Task.objects(**query).only("unique_metrics").first()
|
||||
if task and task.unique_metrics:
|
||||
total_metrics = set(task.unique_metrics)
|
||||
|
||||
new_metrics = []
|
||||
|
||||
def add_last_metric_conditional_update(
|
||||
metric_path: str, metric_value, iter_value: int, is_min: bool
|
||||
):
|
||||
"""
|
||||
Build an aggregation for an atomic update of the min or max value and the corresponding iteration
|
||||
"""
|
||||
if is_min:
|
||||
field_prefix = "min"
|
||||
op = "$gt"
|
||||
else:
|
||||
field_prefix = "max"
|
||||
op = "$lt"
|
||||
|
||||
value_field = f"{metric_path}__{field_prefix}_value".replace("__", ".")
|
||||
condition = {
|
||||
"$or": [
|
||||
{"$lte": [f"${value_field}", None]},
|
||||
{op: [f"${value_field}", metric_value]},
|
||||
]
|
||||
}
|
||||
raw_updates[value_field] = {
|
||||
"$cond": [condition, metric_value, f"${value_field}"]
|
||||
}
|
||||
|
||||
value_iteration_field = f"{metric_path}__{field_prefix}_value_iteration".replace(
|
||||
"__", "."
|
||||
)
|
||||
raw_updates[value_iteration_field] = {
|
||||
"$cond": [condition, iter_value, f"${value_iteration_field}",]
|
||||
}
|
||||
|
||||
for metric_key, metric_data in last_scalar_events.items():
|
||||
for variant_key, variant_data in metric_data.items():
|
||||
metric = (
|
||||
f"{variant_data.get('metric')}/{variant_data.get('variant')}"
|
||||
)
|
||||
if max_values:
|
||||
if (
|
||||
len(total_metrics) >= max_values
|
||||
and metric not in total_metrics
|
||||
):
|
||||
continue
|
||||
total_metrics.add(metric)
|
||||
|
||||
new_metrics.append(metric)
|
||||
path = f"last_metrics__{metric_key}__{variant_key}"
|
||||
for key, value in variant_data.items():
|
||||
if key in ("min_value", "max_value"):
|
||||
add_last_metric_conditional_update(
|
||||
metric_path=path,
|
||||
metric_value=value,
|
||||
iter_value=variant_data.get(f"{key}_iter", 0),
|
||||
is_min=(key == "min_value"),
|
||||
)
|
||||
elif key in ("metric", "variant", "value"):
|
||||
extra_updates[f"set__{path}__{key}"] = value
|
||||
if new_metrics:
|
||||
extra_updates["add_to_set__unique_metrics"] = new_metrics
|
||||
get_last_metric_updates(
|
||||
task_id=task_id,
|
||||
last_scalar_events=last_scalar_events,
|
||||
raw_updates=raw_updates,
|
||||
extra_updates=extra_updates,
|
||||
)
|
||||
|
||||
if last_events is not None:
|
||||
|
||||
|
@ -5,7 +5,9 @@ import attr
|
||||
import six
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.project import Project
|
||||
from apiserver.database.model.task.task import Task, TaskStatus, TaskSystemTags
|
||||
from apiserver.database.utils import get_options
|
||||
@ -167,7 +169,7 @@ def update_project_time(project_ids: Union[str, Sequence[str]]):
|
||||
|
||||
|
||||
def get_task_for_update(
|
||||
company_id: str, task_id: str, allow_all_statuses: bool = False, force: bool = False
|
||||
company_id: str, task_id: str, allow_all_statuses: bool = False, force: bool = False
|
||||
) -> Task:
|
||||
"""
|
||||
Loads only the task id and returns the task only if it is updatable (status == 'created')
|
||||
@ -189,9 +191,88 @@ def get_task_for_update(
|
||||
return task
|
||||
|
||||
|
||||
def update_task(
    task: Task, user_id: str, update_cmds: dict, set_last_update: bool = True
):
    """
    Apply the update commands to the task together with the change tracking fields.

    last_change and last_changed_by are always set. last_update is set to the
    same timestamp unless set_last_update is False.

    :param task: the task to update
    :param user_id: id of the user stored in last_changed_by
    :param update_cmds: update commands passed as keyword arguments to task.update
    :param set_last_update: whether to also set last_update
    :return: whatever task.update returns
    """
    now = datetime.utcnow()
    last_updates = dict(last_change=now, last_changed_by=user_id)
    if set_last_update:
        last_updates.update(last_update=now)
    return task.update(**update_cmds, **last_updates)
|
||||
|
||||
|
||||
def get_last_metric_updates(
    task_id: str,
    last_scalar_events: dict,
    raw_updates: dict,
    extra_updates: dict,
    model_events: bool = False,
):
    """
    Build the update commands for storing the last reported scalar metrics
    on a task or a model document.

    Nothing is written to the database here and nothing is returned: the
    passed raw_updates and extra_updates dicts are filled in place.

    :param task_id: id of the task (or of the model if model_events is set)
    :param last_scalar_events: metric key -> variant key -> last event data.
        The keys "metric", "variant", "value", "min_value" and "max_value"
        of the event data are used, with the matching iterations read from
        "min_value_iter" / "max_value_iter"
    :param raw_updates: receives "$cond" expressions keyed by the dotted
        field path, for the min/max values and their iterations
    :param extra_updates: receives "set__last_metrics__..." commands and
        "add_to_set__unique_metrics"
    :param model_events: if set then the stored unique metrics are read from
        Model, otherwise from Task
    """
    # a falsy limit disables the unique metrics limitation
    max_values = config.get("services.tasks.max_last_metrics", 2000)
    total_metrics = set()
    if max_values:
        query = dict(id=task_id)
        to_add = sum(len(variants) for variants in last_scalar_events.values())
        if to_add <= max_values:
            # Positional "exists" condition: match the document only if it
            # already holds more than (max_values - to_add) unique metrics.
            # Otherwise nothing is loaded and total_metrics stays empty
            query[f"unique_metrics__{max_values - to_add}__exists"] = True
        db_cls = Model if model_events else Task
        task = db_cls.objects(**query).only("unique_metrics").first()
        if task and task.unique_metrics:
            total_metrics = set(task.unique_metrics)

    new_metrics = []

    def add_last_metric_conditional_update(
        metric_path: str, metric_value, iter_value: int, is_min: bool
    ):
        """
        Build an aggregation for an atomic update of the min or max value and the corresponding iteration
        """
        if is_min:
            field_prefix = "min"
            op = "$gt"
        else:
            field_prefix = "max"
            op = "$lt"

        value_field = f"{metric_path}__{field_prefix}_value".replace("__", ".")
        # replace the stored value if "stored <= null" holds (presumably
        # meaning that no value is stored yet - confirm against MongoDB
        # comparison rules) or if the new value improves on the stored one
        condition = {
            "$or": [
                {"$lte": [f"${value_field}", None]},
                {op: [f"${value_field}", metric_value]},
            ]
        }
        raw_updates[value_field] = {
            "$cond": [condition, metric_value, f"${value_field}"]
        }

        value_iteration_field = f"{metric_path}__{field_prefix}_value_iteration".replace(
            "__", "."
        )
        raw_updates[value_iteration_field] = {
            "$cond": [condition, iter_value, f"${value_iteration_field}"]
        }

    for metric_key, metric_data in last_scalar_events.items():
        for variant_key, variant_data in metric_data.items():
            metric = f"{variant_data.get('metric')}/{variant_data.get('variant')}"
            if max_values:
                # over the limit: only metrics that are already known are updated
                if len(total_metrics) >= max_values and metric not in total_metrics:
                    continue
                total_metrics.add(metric)

            new_metrics.append(metric)
            path = f"last_metrics__{metric_key}__{variant_key}"
            for key, value in variant_data.items():
                if key in ("min_value", "max_value"):
                    add_last_metric_conditional_update(
                        metric_path=path,
                        metric_value=value,
                        iter_value=variant_data.get(f"{key}_iter", 0),
                        is_min=(key == "min_value"),
                    )
                elif key in ("metric", "variant", "value"):
                    extra_updates[f"set__{path}__{key}"] = value

    if new_metrics:
        extra_updates["add_to_set__unique_metrics"] = new_metrics
|
||||
|
@ -3,6 +3,8 @@ from mongoengine import (
|
||||
DateTimeField,
|
||||
BooleanField,
|
||||
EmbeddedDocumentField,
|
||||
IntField,
|
||||
ListField,
|
||||
)
|
||||
|
||||
from apiserver.database import Database, strict
|
||||
@ -17,12 +19,14 @@ from apiserver.database.model.base import GetMixin
|
||||
from apiserver.database.model.metadata import MetadataItem
|
||||
from apiserver.database.model.model_labels import ModelLabels
|
||||
from apiserver.database.model.project import Project
|
||||
from apiserver.database.model.task.metrics import MetricEvent
|
||||
from apiserver.database.model.task.task import Task
|
||||
|
||||
|
||||
class Model(AttributedDocument):
|
||||
_field_collation_overrides = {
|
||||
"metadata.": AttributedDocument._numeric_locale,
|
||||
"last_metrics.": AttributedDocument._numeric_locale,
|
||||
}
|
||||
|
||||
meta = {
|
||||
@ -67,6 +71,7 @@ class Model(AttributedDocument):
|
||||
"parent",
|
||||
"metadata.*",
|
||||
),
|
||||
range_fields=("last_metrics.*", "last_iteration"),
|
||||
datetime_fields=("last_update",),
|
||||
)
|
||||
|
||||
@ -92,6 +97,9 @@ class Model(AttributedDocument):
|
||||
metadata = SafeMapField(
|
||||
field=EmbeddedDocumentField(MetadataItem), user_set_allowed=True
|
||||
)
|
||||
last_iteration = IntField(default=0)
|
||||
last_metrics = SafeMapField(field=SafeMapField(EmbeddedDocumentField(MetricEvent)))
|
||||
unique_metrics = ListField(StringField(required=True), exclude_by_default=True)
|
||||
|
||||
def get_index_company(self) -> str:
|
||||
return self.company or self.company_origin or ""
|
||||
|
@ -1,6 +1,6 @@
|
||||
_description: """This service provides a management interface for models (results of training tasks) stored in the system."""
|
||||
_definitions {
|
||||
include "_common.conf"
|
||||
include "_tasks_common.conf"
|
||||
multi_field_pattern_data {
|
||||
type: object
|
||||
properties {
|
||||
@ -104,6 +104,17 @@ _definitions {
|
||||
"$ref": "#/definitions/metadata_item"
|
||||
}
|
||||
}
|
||||
last_iteration {
|
||||
description: "Last iteration reported for this model"
|
||||
type: integer
|
||||
}
|
||||
last_metrics {
|
||||
description: "Last metric variants (hash to events), one for each metric hash"
|
||||
type: object
|
||||
additionalProperties {
|
||||
"$ref": "#/definitions/last_metrics_variants"
|
||||
}
|
||||
}
|
||||
stats {
|
||||
description: "Model statistics"
|
||||
type: object
|
||||
|
@ -898,6 +898,13 @@ get_unique_metric_variants {
|
||||
}
|
||||
}
|
||||
}
|
||||
"999.0": ${get_unique_metric_variants."2.13"} {
|
||||
request.properties.model_metrics {
|
||||
description: If set to true then bring unique metric and variant names from the project models otherwise from the project tasks
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
}
|
||||
}
|
||||
get_hyperparam_values {
|
||||
"2.13" {
|
||||
|
@ -568,6 +568,13 @@ get_task_data {
|
||||
}
|
||||
}
|
||||
}
|
||||
"999.0": ${get_task_data."2.23"} {
|
||||
request.properties.model_events {
|
||||
type: boolean
|
||||
description: If set then retrieve model events. Otherwise retrieve task events
|
||||
default: false
|
||||
}
|
||||
}
|
||||
}
|
||||
get_all_ex {
|
||||
"2.23" {
|
||||
|
@ -561,7 +561,6 @@ def _get_multitask_plots(
|
||||
metrics: MetricVariants = None,
|
||||
scroll_id=None,
|
||||
no_scroll=True,
|
||||
model_events=False,
|
||||
) -> Tuple[dict, int, str]:
|
||||
task_names = {
|
||||
t.id: t.name for t in itertools.chain.from_iterable(companies.values())
|
||||
@ -598,7 +597,6 @@ def get_multi_task_plots(call, company_id, _):
|
||||
last_iters=iters,
|
||||
scroll_id=scroll_id,
|
||||
no_scroll=no_scroll,
|
||||
model_events=model_events,
|
||||
)
|
||||
call.result.data = dict(
|
||||
plots=return_events,
|
||||
|
@ -1,6 +1,6 @@
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
from typing import Sequence
|
||||
from typing import Sequence, Union
|
||||
|
||||
from mongoengine import Q, EmbeddedDocument
|
||||
|
||||
@ -59,6 +59,11 @@ org_bll = OrgBLL()
|
||||
project_bll = ProjectBLL()
|
||||
|
||||
|
||||
def conform_model_data(call: APICall, model_data: Union[Sequence[dict], dict]):
    """
    Prepare model data for returning it in the call result.

    Runs conform_output_tags and then unescape_metadata on the passed data.

    :param call: the API call being served
    :param model_data: a single model dict or a sequence of model dicts
    """
    conform_output_tags(call, model_data)
    unescape_metadata(call, model_data)
|
||||
|
||||
|
||||
@endpoint("models.get_by_id", required_fields=["model"])
|
||||
def get_by_id(call: APICall, company_id, _):
|
||||
model_id = call.data["model"]
|
||||
@ -74,8 +79,7 @@ def get_by_id(call: APICall, company_id, _):
|
||||
raise errors.bad_request.InvalidModelId(
|
||||
"no such public or company model", id=model_id, company=company_id,
|
||||
)
|
||||
conform_output_tags(call, models[0])
|
||||
unescape_metadata(call, models[0])
|
||||
conform_model_data(call, models[0])
|
||||
call.result.data = {"model": models[0]}
|
||||
|
||||
|
||||
@ -102,8 +106,7 @@ def get_by_task_id(call: APICall, company_id, _):
|
||||
"no such public or company model", id=model_id, company=company_id,
|
||||
)
|
||||
model_dict = model.to_proper_dict()
|
||||
conform_output_tags(call, model_dict)
|
||||
unescape_metadata(call, model_dict)
|
||||
conform_model_data(call, model_dict)
|
||||
call.result.data = {"model": model_dict}
|
||||
|
||||
|
||||
@ -119,8 +122,7 @@ def get_all_ex(call: APICall, company_id, request: ModelsGetRequest):
|
||||
allow_public=request.allow_public,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
conform_output_tags(call, models)
|
||||
unescape_metadata(call, models)
|
||||
conform_model_data(call, models)
|
||||
|
||||
if not request.include_stats:
|
||||
call.result.data = {"models": models, **ret_params}
|
||||
@ -142,8 +144,7 @@ def get_by_id_ex(call: APICall, company_id, _):
|
||||
models = Model.get_many_with_join(
|
||||
company=company_id, query_dict=call.data, allow_public=True
|
||||
)
|
||||
conform_output_tags(call, models)
|
||||
unescape_metadata(call, models)
|
||||
conform_model_data(call, models)
|
||||
call.result.data = {"models": models}
|
||||
|
||||
|
||||
@ -159,8 +160,7 @@ def get_all(call: APICall, company_id, _):
|
||||
allow_public=True,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
conform_output_tags(call, models)
|
||||
unescape_metadata(call, models)
|
||||
conform_model_data(call, models)
|
||||
call.result.data = {"models": models, **ret_params}
|
||||
|
||||
|
||||
@ -428,8 +428,7 @@ def edit(call: APICall, company_id, _):
|
||||
_reset_cached_tags(company_id, projects=[new_project, model.project])
|
||||
else:
|
||||
_update_cached_tags(company_id, project=model.project, fields=fields)
|
||||
conform_output_tags(call, fields)
|
||||
unescape_metadata(call, fields)
|
||||
conform_model_data(call, fields)
|
||||
call.result.data_model = UpdateResponse(updated=updated, fields=fields)
|
||||
else:
|
||||
call.result.data_model = UpdateResponse(updated=0)
|
||||
@ -461,8 +460,7 @@ def _update_model(call: APICall, company_id, model_id=None):
|
||||
_update_cached_tags(
|
||||
company_id, project=model.project, fields=updated_fields
|
||||
)
|
||||
conform_output_tags(call, updated_fields)
|
||||
unescape_metadata(call, updated_fields)
|
||||
conform_model_data(call, updated_fields)
|
||||
return UpdateResponse(updated=updated_count, fields=updated_fields)
|
||||
|
||||
|
||||
|
@ -15,10 +15,10 @@ from apiserver.apimodels.projects import (
|
||||
DeleteRequest,
|
||||
MoveRequest,
|
||||
MergeRequest,
|
||||
ProjectOrNoneRequest,
|
||||
ProjectRequest,
|
||||
ProjectModelMetadataValuesRequest,
|
||||
ProjectChildrenType,
|
||||
GetUniqueMetricsRequest,
|
||||
)
|
||||
from apiserver.bll.organization import OrgBLL, Tags
|
||||
from apiserver.bll.project import ProjectBLL, ProjectQueries
|
||||
@ -345,16 +345,17 @@ def delete(call: APICall, company_id: str, request: DeleteRequest):
|
||||
|
||||
|
||||
@endpoint(
    "projects.get_unique_metric_variants", request_data_model=GetUniqueMetricsRequest
)
def get_unique_metric_variants(
    call: APICall, company_id: str, request: GetUniqueMetricsRequest,
):
    """
    Return the unique metrics and variants for the requested project.

    If no project is passed in the request then None is passed as the project
    ids. request.model_metrics selects whether the metrics are taken from
    models or from tasks. The result is returned under the "metrics" key.
    """
    metrics = project_queries.get_unique_metric_variants(
        company_id,
        [request.project] if request.project else None,
        include_subprojects=request.include_subprojects,
        model_metrics=request.model_metrics,
    )

    call.result.data = {"metrics": metrics}
|
||||
|
@ -1,3 +1,5 @@
|
||||
from typing import Union, Sequence
|
||||
|
||||
from mongoengine import Q
|
||||
|
||||
from apiserver.apimodels.base import UpdateResponse
|
||||
@ -39,14 +41,18 @@ worker_bll = WorkerBLL()
|
||||
queue_bll = QueueBLL(worker_bll)
|
||||
|
||||
|
||||
def conform_queue_data(call: APICall, queue_data: Union[Sequence[dict], dict]):
    """
    Prepare queue data for returning it in the call result.

    Runs conform_output_tags and then unescape_metadata on the passed data.

    :param call: the API call being served
    :param queue_data: a single queue dict or a sequence of queue dicts
    """
    conform_output_tags(call, queue_data)
    unescape_metadata(call, queue_data)
|
||||
|
||||
|
||||
@endpoint("queues.get_by_id", min_version="2.4", request_data_model=GetByIdRequest)
def get_by_id(call: APICall, company_id, request: GetByIdRequest):
    """
    Return a single queue by its id.

    The queue is fetched with at most request.max_task_entries task entries,
    converted to a dict, conformed for output and returned under the "queue" key.
    """
    queue = queue_bll.get_by_id(
        company_id, request.queue, max_task_entries=request.max_task_entries
    )
    queue_dict = queue.to_proper_dict()
    conform_queue_data(call, queue_dict)
    call.result.data = {"queue": queue_dict}
|
||||
|
||||
|
||||
@ -85,8 +91,7 @@ def get_all_ex(call: APICall, company: str, request: GetAllRequest):
|
||||
max_task_entries=request.max_task_entries,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
conform_output_tags(call, queues)
|
||||
unescape_metadata(call, queues)
|
||||
conform_queue_data(call, queues)
|
||||
call.result.data = {"queues": queues, **ret_params}
|
||||
|
||||
|
||||
@ -102,8 +107,7 @@ def get_all(call: APICall, company: str, request: GetAllRequest):
|
||||
max_task_entries=request.max_task_entries,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
conform_output_tags(call, queues)
|
||||
unescape_metadata(call, queues)
|
||||
conform_queue_data(call, queues)
|
||||
call.result.data = {"queues": queues, **ret_params}
|
||||
|
||||
|
||||
@ -135,8 +139,7 @@ def update(call: APICall, company_id, req_model: UpdateRequest):
|
||||
updated, fields = queue_bll.update(
|
||||
company_id=company_id, queue_id=req_model.queue, **data
|
||||
)
|
||||
conform_output_tags(call, fields)
|
||||
unescape_metadata(call, fields)
|
||||
conform_queue_data(call, fields)
|
||||
call.result.data_model = UpdateResponse(updated=updated, fields=fields)
|
||||
|
||||
|
||||
|
@ -17,6 +17,8 @@ from apiserver.apimodels.reports import (
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.apimodels.base import UpdateResponse
|
||||
from apiserver.bll.project.project_bll import reports_project_name, reports_tag
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.services.models import conform_model_data
|
||||
from apiserver.services.utils import process_include_subprojects, sort_tags_response
|
||||
from apiserver.bll.organization import OrgBLL
|
||||
from apiserver.bll.project import ProjectBLL
|
||||
@ -35,7 +37,7 @@ from apiserver.services.events import (
|
||||
from apiserver.services.tasks import (
|
||||
escape_execution_parameters,
|
||||
_hidden_query,
|
||||
unprepare_from_saved,
|
||||
conform_task_data,
|
||||
)
|
||||
|
||||
org_bll = OrgBLL()
|
||||
@ -178,7 +180,7 @@ def get_all_ex(call: APICall, company_id, request: GetAllRequest):
|
||||
allow_public=request.allow_public,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
unprepare_from_saved(call, tasks)
|
||||
conform_task_data(call, tasks)
|
||||
|
||||
call.result.data = {"tasks": tasks, **ret_params}
|
||||
|
||||
@ -198,18 +200,25 @@ def _get_task_metrics_from_request(
|
||||
|
||||
@endpoint("reports.get_task_data")
|
||||
def get_task_data(call: APICall, company_id, request: GetTasksDataRequest):
|
||||
if request.model_events:
|
||||
entity_cls = Model
|
||||
conform_data = conform_model_data
|
||||
else:
|
||||
entity_cls = Task
|
||||
conform_data = conform_task_data
|
||||
|
||||
call_data = escape_execution_parameters(call)
|
||||
process_include_subprojects(call_data)
|
||||
|
||||
ret_params = {}
|
||||
tasks = Task.get_many_with_join(
|
||||
tasks = entity_cls.get_many_with_join(
|
||||
company=company_id,
|
||||
query_dict=call_data,
|
||||
query=_hidden_query(call_data),
|
||||
allow_public=request.allow_public,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
unprepare_from_saved(call, tasks)
|
||||
conform_data(call, tasks)
|
||||
res = {"tasks": tasks, **ret_params}
|
||||
if not (
|
||||
request.debug_images or request.plots or request.scalar_metrics_iter_histogram
|
||||
@ -217,7 +226,9 @@ def get_task_data(call: APICall, company_id, request: GetTasksDataRequest):
|
||||
return res
|
||||
|
||||
task_ids = [task["id"] for task in tasks]
|
||||
companies = _get_task_or_model_index_companies(company_id, task_ids=task_ids)
|
||||
companies = _get_task_or_model_index_companies(
|
||||
company_id, task_ids=task_ids, model_events=request.model_events
|
||||
)
|
||||
if request.debug_images:
|
||||
result = event_bll.debug_images_iterator.get_task_events(
|
||||
companies={
|
||||
|
@ -182,7 +182,7 @@ def get_by_id(call: APICall, company_id, req_model: TaskRequest):
|
||||
req_model.task, company_id=company_id, allow_public=True
|
||||
)
|
||||
task_dict = task.to_proper_dict()
|
||||
unprepare_from_saved(call, task_dict)
|
||||
conform_task_data(call, task_dict)
|
||||
call.result.data = {"task": task_dict}
|
||||
|
||||
|
||||
@ -231,7 +231,7 @@ def get_all_ex(call: APICall, company_id, request: GetAllReq):
|
||||
allow_public=request.allow_public,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
unprepare_from_saved(call, tasks)
|
||||
conform_task_data(call, tasks)
|
||||
call.result.data = {"tasks": tasks, **ret_params}
|
||||
|
||||
|
||||
@ -245,7 +245,7 @@ def get_by_id_ex(call: APICall, company_id, _):
|
||||
company=company_id, query_dict=call_data, allow_public=True,
|
||||
)
|
||||
|
||||
unprepare_from_saved(call, tasks)
|
||||
conform_task_data(call, tasks)
|
||||
call.result.data = {"tasks": tasks}
|
||||
|
||||
|
||||
@ -264,7 +264,7 @@ def get_all(call: APICall, company_id, _):
|
||||
allow_public=True,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
unprepare_from_saved(call, tasks)
|
||||
conform_task_data(call, tasks)
|
||||
call.result.data = {"tasks": tasks, **ret_params}
|
||||
|
||||
|
||||
@ -430,7 +430,7 @@ def prepare_for_save(call: APICall, fields: dict, previous_task: Task = None):
|
||||
return fields
|
||||
|
||||
|
||||
def unprepare_from_saved(call: APICall, tasks_data: Union[Sequence[dict], dict]):
|
||||
def conform_task_data(call: APICall, tasks_data: Union[Sequence[dict], dict]):
|
||||
if isinstance(tasks_data, dict):
|
||||
tasks_data = [tasks_data]
|
||||
|
||||
@ -608,7 +608,7 @@ def update(call: APICall, company_id, req_model: UpdateRequest):
|
||||
company_id, project=task.project, fields=updated_fields
|
||||
)
|
||||
update_project_time(updated_fields.get("project"))
|
||||
unprepare_from_saved(call, updated_fields)
|
||||
conform_task_data(call, updated_fields)
|
||||
return UpdateResponse(updated=updated_count, fields=updated_fields)
|
||||
|
||||
|
||||
@ -763,7 +763,7 @@ def edit(call: APICall, company_id, req_model: UpdateRequest):
|
||||
company_id, project=task.project, fields=fixed_fields
|
||||
)
|
||||
update_project_time(fields.get("project"))
|
||||
unprepare_from_saved(call, fields)
|
||||
conform_task_data(call, fields)
|
||||
call.result.data_model = UpdateResponse(updated=updated, fields=fields)
|
||||
else:
|
||||
call.result.data_model = UpdateResponse(updated=0)
|
||||
|
@ -113,66 +113,75 @@ class TestReports(TestService):
|
||||
|
||||
def test_reports_task_data(self):
|
||||
report_task = self._temp_report(name="Rep1")
|
||||
non_report_task = self._temp_task(name="hello")
|
||||
debug_image_events = [
|
||||
self._create_task_event(
|
||||
task=non_report_task,
|
||||
type_="training_debug_image",
|
||||
iteration=1,
|
||||
metric=f"Metric_{m}",
|
||||
variant=f"Variant_{v}",
|
||||
url=f"{m}_{v}",
|
||||
for model_events in (False, True):
|
||||
if model_events:
|
||||
non_report_task = self._temp_model(name="hello")
|
||||
event_args = {"model_event": True}
|
||||
else:
|
||||
non_report_task = self._temp_task(name="hello")
|
||||
event_args = {}
|
||||
debug_image_events = [
|
||||
self._create_task_event(
|
||||
task=non_report_task,
|
||||
type_="training_debug_image",
|
||||
iteration=1,
|
||||
metric=f"Metric_{m}",
|
||||
variant=f"Variant_{v}",
|
||||
url=f"{m}_{v}",
|
||||
**event_args,
|
||||
)
|
||||
for m in range(2)
|
||||
for v in range(2)
|
||||
]
|
||||
plot_events = [
|
||||
self._create_task_event(
|
||||
task=non_report_task,
|
||||
type_="plot",
|
||||
iteration=1,
|
||||
metric=f"Metric_{m}",
|
||||
variant=f"Variant_{v}",
|
||||
plot_str=f"Hello plot",
|
||||
**event_args,
|
||||
)
|
||||
for m in range(2)
|
||||
for v in range(2)
|
||||
]
|
||||
self.send_batch([*debug_image_events, *plot_events])
|
||||
|
||||
res = self.api.reports.get_task_data(
|
||||
id=[non_report_task], only_fields=["name"], model_events=model_events
|
||||
)
|
||||
for m in range(2)
|
||||
for v in range(2)
|
||||
]
|
||||
plot_events = [
|
||||
self._create_task_event(
|
||||
task=non_report_task,
|
||||
type_="plot",
|
||||
iteration=1,
|
||||
metric=f"Metric_{m}",
|
||||
variant=f"Variant_{v}",
|
||||
plot_str=f"Hello plot",
|
||||
self.assertEqual(len(res.tasks), 1)
|
||||
self.assertEqual(res.tasks[0].id, non_report_task)
|
||||
self.assertFalse(any(field in res for field in ("plots", "debug_images")))
|
||||
|
||||
res = self.api.reports.get_task_data(
|
||||
id=[non_report_task],
|
||||
only_fields=["name"],
|
||||
debug_images={"metrics": []},
|
||||
plots={"metrics": [{"metric": "Metric_1"}]},
|
||||
model_events=model_events,
|
||||
)
|
||||
for m in range(2)
|
||||
for v in range(2)
|
||||
]
|
||||
self.send_batch([*debug_image_events, *plot_events])
|
||||
self.assertEqual(len(res.debug_images), 1)
|
||||
task_events = res.debug_images[0]
|
||||
self.assertEqual(task_events.task, non_report_task)
|
||||
self.assertEqual(len(task_events.iterations), 1)
|
||||
self.assertEqual(len(task_events.iterations[0].events), 4)
|
||||
|
||||
res = self.api.reports.get_task_data(
|
||||
id=[non_report_task], only_fields=["name"],
|
||||
)
|
||||
self.assertEqual(len(res.tasks), 1)
|
||||
self.assertEqual(res.tasks[0].id, non_report_task)
|
||||
self.assertFalse(any(field in res for field in ("plots", "debug_images")))
|
||||
|
||||
res = self.api.reports.get_task_data(
|
||||
id=[non_report_task],
|
||||
only_fields=["name"],
|
||||
debug_images={"metrics": []},
|
||||
plots={"metrics": [{"metric": "Metric_1"}]},
|
||||
)
|
||||
self.assertEqual(len(res.debug_images), 1)
|
||||
task_events = res.debug_images[0]
|
||||
self.assertEqual(task_events.task, non_report_task)
|
||||
self.assertEqual(len(task_events.iterations), 1)
|
||||
self.assertEqual(len(task_events.iterations[0].events), 4)
|
||||
|
||||
self.assertEqual(len(res.plots), 1)
|
||||
for m, v in (("Metric_1", "Variant_0"), ("Metric_1", "Variant_1")):
|
||||
tasks = nested_get(res.plots, (m, v))
|
||||
self.assertEqual(len(tasks), 1)
|
||||
task_plots = tasks[non_report_task]
|
||||
self.assertEqual(len(task_plots), 1)
|
||||
iter_plots = task_plots["1"]
|
||||
self.assertEqual(iter_plots.name, "hello")
|
||||
self.assertEqual(len(iter_plots.plots), 1)
|
||||
ev = iter_plots.plots[0]
|
||||
self.assertEqual(ev["metric"], m)
|
||||
self.assertEqual(ev["variant"], v)
|
||||
self.assertEqual(ev["task"], non_report_task)
|
||||
self.assertEqual(ev["iter"], 1)
|
||||
self.assertEqual(len(res.plots), 1)
|
||||
for m, v in (("Metric_1", "Variant_0"), ("Metric_1", "Variant_1")):
|
||||
tasks = nested_get(res.plots, (m, v))
|
||||
self.assertEqual(len(tasks), 1)
|
||||
task_plots = tasks[non_report_task]
|
||||
self.assertEqual(len(task_plots), 1)
|
||||
iter_plots = task_plots["1"]
|
||||
self.assertEqual(iter_plots.name, "hello")
|
||||
self.assertEqual(len(iter_plots.plots), 1)
|
||||
ev = iter_plots.plots[0]
|
||||
self.assertEqual(ev["metric"], m)
|
||||
self.assertEqual(ev["variant"], v)
|
||||
self.assertEqual(ev["task"], non_report_task)
|
||||
self.assertEqual(ev["iter"], 1)
|
||||
|
||||
@staticmethod
|
||||
def _create_task_event(type_, task, iteration, **kwargs):
|
||||
@ -185,12 +194,14 @@ class TestReports(TestService):
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
delete_params = {"force": True}
|
||||
|
||||
def _temp_report(self, name, **kwargs):
|
||||
return self.create_temp(
|
||||
"reports",
|
||||
name=name,
|
||||
object_name="task",
|
||||
delete_params={"force": True},
|
||||
delete_params=self.delete_params,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@ -199,10 +210,16 @@ class TestReports(TestService):
|
||||
"tasks",
|
||||
name=name,
|
||||
type="training",
|
||||
delete_params={"force": True},
|
||||
delete_params=self.delete_params,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _temp_model(self, name="test model events", **kwargs):
|
||||
self.update_missing(
|
||||
kwargs, name=name, uri="file:///a/b", labels={}, ready=False
|
||||
)
|
||||
return self.create_temp("models", delete_params=self.delete_params, **kwargs)
|
||||
|
||||
def send_batch(self, events):
|
||||
_, data = self.api.send_batch("events.add_batch", events)
|
||||
return data
|
||||
|
@ -16,9 +16,7 @@ class TestTaskEvents(TestService):
|
||||
delete_params = dict(can_fail=True, force=True)
|
||||
|
||||
def _temp_task(self, name="test task events"):
|
||||
task_input = dict(
|
||||
name=name, type="training",
|
||||
)
|
||||
task_input = dict(name=name, type="training",)
|
||||
return self.create_temp(
|
||||
"tasks", delete_paramse=self.delete_params, **task_input
|
||||
)
|
||||
@ -220,6 +218,17 @@ class TestTaskEvents(TestService):
|
||||
self.assertEqual(variant_data.x, [0, 1])
|
||||
self.assertEqual(variant_data.y, [0.0, 1.0])
|
||||
|
||||
model_data = self.api.models.get_all_ex(
|
||||
id=[model], only_fields=["last_metrics", "last_iteration"]
|
||||
).models[0]
|
||||
metric_data = first(first(model_data.last_metrics.values()).values())
|
||||
self.assertEqual(1, model_data.last_iteration)
|
||||
self.assertEqual(1, metric_data.value)
|
||||
self.assertEqual(1, metric_data.max_value)
|
||||
self.assertEqual(1, metric_data.max_value_iteration)
|
||||
self.assertEqual(0, metric_data.min_value)
|
||||
self.assertEqual(0, metric_data.min_value_iteration)
|
||||
|
||||
def test_error_events(self):
|
||||
task = self._temp_task()
|
||||
events = [
|
||||
|
Loading…
Reference in New Issue
Block a user