diff --git a/apiserver/apimodels/projects.py b/apiserver/apimodels/projects.py index f37cf62..339201b 100644 --- a/apiserver/apimodels/projects.py +++ b/apiserver/apimodels/projects.py @@ -29,6 +29,10 @@ class ProjectOrNoneRequest(models.Base): include_subprojects = fields.BoolField(default=True) +class GetUniqueMetricsRequest(ProjectOrNoneRequest): + model_metrics = fields.BoolField(default=False) + + class GetParamsRequest(ProjectOrNoneRequest): page = fields.IntField(default=0) page_size = fields.IntField(default=500) diff --git a/apiserver/apimodels/reports.py b/apiserver/apimodels/reports.py index b5f1c4d..7c06a0a 100644 --- a/apiserver/apimodels/reports.py +++ b/apiserver/apimodels/reports.py @@ -66,6 +66,7 @@ class GetTasksDataRequest(Base): plots: EventsRequest = EmbeddedField(EventsRequest) scalar_metrics_iter_histogram: ScalarMetricsIterHistogram = EmbeddedField(ScalarMetricsIterHistogram) allow_public = BoolField(default=True) + model_events: bool = BoolField(default=False) class GetAllRequest(Base): diff --git a/apiserver/bll/event/event_bll.py b/apiserver/bll/event/event_bll.py index 2f1ebaa..a6b775b 100644 --- a/apiserver/bll/event/event_bll.py +++ b/apiserver/bll/event/event_bll.py @@ -30,6 +30,7 @@ from apiserver.bll.event.history_debug_image_iterator import HistoryDebugImageIt from apiserver.bll.event.history_plots_iterator import HistoryPlotsIterator from apiserver.bll.event.metric_debug_images_iterator import MetricDebugImagesIterator from apiserver.bll.event.metric_plots_iterator import MetricPlotsIterator +from apiserver.bll.model import ModelBLL from apiserver.bll.util import parallel_chunked_decorator from apiserver.database import utils as dbutils from apiserver.database.model.model import Model @@ -250,7 +251,6 @@ class EventBLL(object): task_or_model_ids.add(task_or_model_id) if ( iter is not None - and not model_events and event.get("metric") not in self._skip_iteration_for_metric ): task_iteration[task_or_model_id] = max( @@ -261,11 +261,10 @@ class EventBLL(object): self._update_last_metric_events_for_task( last_events=task_last_events[task_or_model_id], event=event, ) - if event_type == EventType.metrics_scalar.value: - self._update_last_scalar_events_for_task( - last_events=task_last_scalar_events[task_or_model_id], - event=event, - ) + if event_type == EventType.metrics_scalar.value: + self._update_last_scalar_events_for_task( + last_events=task_last_scalar_events[task_or_model_id], event=event, + ) actions.append(es_action) @@ -303,12 +302,21 @@ class EventBLL(object): else: errors_per_type["Error when indexing events batch"] += 1 - if not model_events: - remaining_tasks = set() - now = datetime.utcnow() - for task_or_model_id in task_or_model_ids: - # Update related tasks. For reasons of performance, we prefer to update - # all of them and not only those who's events were successful + remaining_tasks = set() + now = datetime.utcnow() + for task_or_model_id in task_or_model_ids: + # Update related tasks. For reasons of performance, we prefer to update + # all of them and not only those who's events were successful + if model_events: + ModelBLL.update_statistics( + company_id=company_id, + model_id=task_or_model_id, + last_iteration_max=task_iteration.get(task_or_model_id), + last_scalar_events=task_last_scalar_events.get( + task_or_model_id + ), + ) + else: updated = self._update_task( company_id=company_id, task_id=task_or_model_id, @@ -319,15 +327,14 @@ class EventBLL(object): ), last_events=task_last_events.get(task_or_model_id), ) - if not updated: remaining_tasks.add(task_or_model_id) continue - if remaining_tasks: - TaskBLL.set_last_update( - remaining_tasks, company_id, last_update=now - ) + if remaining_tasks: + TaskBLL.set_last_update( + remaining_tasks, company_id, last_update=now + ) # this is for backwards compatibility with streaming bulk throwing exception on those invalid_iterations_count = errors_per_type.get(invalid_iteration_error) @@ -484,7 +491,9 @@ class EventBLL(object): ) def _get_event_id(self, event): - id_values = (str(event[field]) for field in self.event_id_fields if field in event) + id_values = ( + str(event[field]) for field in self.event_id_fields if field in event + ) return hashlib.md5("-".join(id_values).encode()).hexdigest() def scroll_task_events( @@ -556,9 +565,7 @@ class EventBLL(object): must.append(get_metric_variants_condition(metric_variants)) query = {"bool": {"must": must}} - search_args = dict( - es=self.es, company_id=company_id, event_type=event_type, - ) + search_args = dict(es=self.es, company_id=company_id, event_type=event_type) max_metrics, max_variants = get_max_metric_and_variant_counts( query=query, **search_args, ) @@ -586,7 +593,7 @@ class EventBLL(object): "events": { "top_hits": { "sort": {"iter": {"order": "desc"}}, - "size": last_iterations_per_plot + "size": last_iterations_per_plot, } } }, @@ -597,11 +604,7 @@ class EventBLL(object): } with translate_errors_context(): - es_response = search_company_events( - body=es_req, - ignore=404, - **search_args, - ) + es_response = search_company_events(body=es_req, ignore=404, **search_args) aggs_result = es_response.get("aggregations") if not aggs_result: @@ -614,9 +617,7 @@ class EventBLL(object): for hit in variants_bucket["events"]["hits"]["hits"] ] self.uncompress_plots(events) - return TaskEventsResult( - events=events, total_events=len(events) - ) + return TaskEventsResult(events=events, total_events=len(events)) def _get_events_from_es_res(self, es_res: dict) -> Tuple[list, int, Optional[str]]: """ @@ -731,12 +732,7 @@ class EventBLL(object): if not company_ids: return TaskEventsResult() - task_ids = ( - [task_id] - if isinstance(task_id, str) - else task_id - ) - + task_ids = [task_id] if isinstance(task_id, str) else task_id must = [] if metrics: @@ -967,7 +963,7 @@ class EventBLL(object): event_type: EventType, task_id: Union[str, Sequence[str]], iters: int, - metrics: MetricVariants = None + metrics: MetricVariants = None, ) -> Mapping[str, Sequence]: company_ids = [company_id] if isinstance(company_id, str) else company_id company_ids = [ diff --git a/apiserver/bll/model/__init__.py b/apiserver/bll/model/__init__.py index 85f129a..f661445 100644 --- a/apiserver/bll/model/__init__.py +++ b/apiserver/bll/model/__init__.py @@ -5,7 +5,7 @@ from mongoengine import Q from apiserver.apierrors import errors from apiserver.apimodels.models import ModelTaskPublishResponse -from apiserver.bll.task.utils import deleted_prefix +from apiserver.bll.task.utils import deleted_prefix, get_last_metric_updates from apiserver.database.model import EntityVisibility from apiserver.database.model.model import Model from apiserver.database.model.task.task import Task, TaskStatus @@ -28,11 +28,7 @@ class ModelBLL: @staticmethod def assert_exists( - company_id, - model_ids, - only=None, - allow_public=False, - return_models=True, + company_id, model_ids, only=None, allow_public=False, return_models=True, ) -> Optional[Sequence[Model]]: model_ids = [model_ids] if isinstance(model_ids, str) else model_ids ids = set(model_ids) @@ -179,12 +175,36 @@ class ModelBLL: "labels_count": {"$size": {"$objectToArray": "$labels"}} } }, - { - "$project": {"labels_count": 1}, - }, + {"$project": {"labels_count": 1}}, ] ) - return { - r.pop("_id"): r - for r in result - } + return {r.pop("_id"): r for r in result} + + @staticmethod + def update_statistics( + company_id: str, + model_id: str, + last_iteration_max: int = None, + last_scalar_events: Dict[str, Dict[str, dict]] = None, + ): + updates = {"last_update": datetime.utcnow()} + if last_iteration_max is not None: + updates.update(max__last_iteration=last_iteration_max) + + raw_updates = {} + if last_scalar_events is not None: + raw_updates = {} + if last_scalar_events is not None: + get_last_metric_updates( + task_id=model_id, + last_scalar_events=last_scalar_events, + raw_updates=raw_updates, + extra_updates=updates, + model_events=True, + ) + + ret = Model.objects(id=model_id).update_one(**updates) + if ret and raw_updates: + Model.objects(id=model_id).update_one(__raw__=[{"$set": raw_updates}]) + + return ret diff --git a/apiserver/bll/project/project_queries.py b/apiserver/bll/project/project_queries.py index e7a3a95..76a2af4 100644 --- a/apiserver/bll/project/project_queries.py +++ b/apiserver/bll/project/project_queries.py @@ -209,7 +209,11 @@ class ProjectQueries: @classmethod def get_unique_metric_variants( - cls, company_id, project_ids: Sequence[str], include_subprojects: bool + cls, + company_id, + project_ids: Sequence[str], + include_subprojects: bool, + model_metrics: bool = False, ): pipeline = [ { @@ -246,7 +250,8 @@ class ProjectQueries: {"$sort": OrderedDict({"_id.metric": 1, "_id.variant": 1})}, ] - result = Task.aggregate(pipeline) + entity_cls = Model if model_metrics else Task + result = entity_cls.aggregate(pipeline) return [r["metrics"][0] for r in result] @classmethod diff --git a/apiserver/bll/task/task_bll.py b/apiserver/bll/task/task_bll.py index 100988e..48ff930 100644 --- a/apiserver/bll/task/task_bll.py +++ b/apiserver/bll/task/task_bll.py @@ -40,6 +40,7 @@ from .utils import ( ChangeStatusRequest, update_project_time, deleted_prefix, + get_last_metric_updates, ) log = config.logger(__file__) @@ -412,77 +413,12 @@ class TaskBLL: raw_updates = {} if last_scalar_events is not None: - max_values = config.get("services.tasks.max_last_metrics", 2000) - total_metrics = set() - if max_values: - query = dict(id=task_id) - to_add = sum(len(v) for m, v in last_scalar_events.items()) - if to_add <= max_values: - query[f"unique_metrics__{max_values-to_add}__exists"] = True - task = Task.objects(**query).only("unique_metrics").first() - if task and task.unique_metrics: - total_metrics = set(task.unique_metrics) - - new_metrics = [] - - def add_last_metric_conditional_update( - metric_path: str, metric_value, iter_value: int, is_min: bool - ): - """ - Build an aggregation for an atomic update of the min or max value and the corresponding iteration - """ - if is_min: - field_prefix = "min" - op = "$gt" - else: - field_prefix = "max" - op = "$lt" - - value_field = f"{metric_path}__{field_prefix}_value".replace("__", ".") - condition = { - "$or": [ - {"$lte": [f"${value_field}", None]}, - {op: [f"${value_field}", metric_value]}, - ] - } - raw_updates[value_field] = { - "$cond": [condition, metric_value, f"${value_field}"] - } - - value_iteration_field = f"{metric_path}__{field_prefix}_value_iteration".replace( - "__", "." - ) - raw_updates[value_iteration_field] = { - "$cond": [condition, iter_value, f"${value_iteration_field}",] - } - - for metric_key, metric_data in last_scalar_events.items(): - for variant_key, variant_data in metric_data.items(): - metric = ( - f"{variant_data.get('metric')}/{variant_data.get('variant')}" - ) - if max_values: - if ( - len(total_metrics) >= max_values - and metric not in total_metrics - ): - continue - total_metrics.add(metric) - - new_metrics.append(metric) - path = f"last_metrics__{metric_key}__{variant_key}" - for key, value in variant_data.items(): - if key in ("min_value", "max_value"): - add_last_metric_conditional_update( - metric_path=path, - metric_value=value, - iter_value=variant_data.get(f"{key}_iter", 0), - is_min=(key == "min_value"), - ) - elif key in ("metric", "variant", "value"): - extra_updates[f"set__{path}__{key}"] = value - if new_metrics: - extra_updates["add_to_set__unique_metrics"] = new_metrics + get_last_metric_updates( + task_id=task_id, + last_scalar_events=last_scalar_events, + raw_updates=raw_updates, + extra_updates=extra_updates, + ) if last_events is not None: diff --git a/apiserver/bll/task/utils.py b/apiserver/bll/task/utils.py index 418960c..f579cab 100644 --- a/apiserver/bll/task/utils.py +++ b/apiserver/bll/task/utils.py @@ -5,7 +5,9 @@ import attr import six from apiserver.apierrors import errors +from apiserver.config_repo import config from apiserver.database.errors import translate_errors_context +from apiserver.database.model.model import Model from apiserver.database.model.project import Project from apiserver.database.model.task.task import Task, TaskStatus, TaskSystemTags from apiserver.database.utils import get_options @@ -167,7 +169,7 @@ def update_project_time(project_ids: Union[str, Sequence[str]]): def get_task_for_update( - company_id: str, task_id: str, allow_all_statuses: bool = False, force: bool = False + company_id: str, task_id: str, allow_all_statuses: bool = False, force: bool = False ) -> Task: """ Loads only task id and return the task only if it is updatable (status == 'created') @@ -189,9 +191,88 @@ def get_task_for_update( return task -def update_task(task: Task, user_id: str, update_cmds: dict, set_last_update: bool = True): +def update_task( + task: Task, user_id: str, update_cmds: dict, set_last_update: bool = True +): now = datetime.utcnow() last_updates = dict(last_change=now, last_changed_by=user_id) if set_last_update: last_updates.update(last_update=now) return task.update(**update_cmds, **last_updates) + + +def get_last_metric_updates( + task_id: str, + last_scalar_events: dict, + raw_updates: dict, + extra_updates: dict, + model_events: bool = False, +): + max_values = config.get("services.tasks.max_last_metrics", 2000) + total_metrics = set() + if max_values: + query = dict(id=task_id) + to_add = sum(len(v) for m, v in last_scalar_events.items()) + if to_add <= max_values: + query[f"unique_metrics__{max_values - to_add}__exists"] = True + db_cls = Model if model_events else Task + task = db_cls.objects(**query).only("unique_metrics").first() + if task and task.unique_metrics: + total_metrics = set(task.unique_metrics) + + new_metrics = [] + + def add_last_metric_conditional_update( + metric_path: str, metric_value, iter_value: int, is_min: bool + ): + """ + Build an aggregation for an atomic update of the min or max value and the corresponding iteration + """ + if is_min: + field_prefix = "min" + op = "$gt" + else: + field_prefix = "max" + op = "$lt" + + value_field = f"{metric_path}__{field_prefix}_value".replace("__", ".") + condition = { + "$or": [ + {"$lte": [f"${value_field}", None]}, + {op: [f"${value_field}", metric_value]}, + ] + } + raw_updates[value_field] = { + "$cond": [condition, metric_value, f"${value_field}"] + } + + value_iteration_field = f"{metric_path}__{field_prefix}_value_iteration".replace( + "__", "." + ) + raw_updates[value_iteration_field] = { + "$cond": [condition, iter_value, f"${value_iteration_field}"] + } + + for metric_key, metric_data in last_scalar_events.items(): + for variant_key, variant_data in metric_data.items(): + metric = f"{variant_data.get('metric')}/{variant_data.get('variant')}" + if max_values: + if len(total_metrics) >= max_values and metric not in total_metrics: + continue + total_metrics.add(metric) + + new_metrics.append(metric) + path = f"last_metrics__{metric_key}__{variant_key}" + for key, value in variant_data.items(): + if key in ("min_value", "max_value"): + add_last_metric_conditional_update( + metric_path=path, + metric_value=value, + iter_value=variant_data.get(f"{key}_iter", 0), + is_min=(key == "min_value"), + ) + elif key in ("metric", "variant", "value"): + extra_updates[f"set__{path}__{key}"] = value + + if new_metrics: + extra_updates["add_to_set__unique_metrics"] = new_metrics diff --git a/apiserver/database/model/model.py b/apiserver/database/model/model.py index 8ce4d20..adf0836 100644 --- a/apiserver/database/model/model.py +++ b/apiserver/database/model/model.py @@ -3,6 +3,8 @@ from mongoengine import ( DateTimeField, BooleanField, EmbeddedDocumentField, + IntField, + ListField, ) from apiserver.database import Database, strict @@ -17,12 +19,14 @@ from apiserver.database.model.base import GetMixin from apiserver.database.model.metadata import MetadataItem from apiserver.database.model.model_labels import ModelLabels from apiserver.database.model.project import Project +from apiserver.database.model.task.metrics import MetricEvent from apiserver.database.model.task.task import Task class Model(AttributedDocument): _field_collation_overrides = { "metadata.": AttributedDocument._numeric_locale, + "last_metrics.": AttributedDocument._numeric_locale, } meta = { @@ -67,6 +71,7 @@ class Model(AttributedDocument): "parent", "metadata.*", ), + range_fields=("last_metrics.*", "last_iteration"), datetime_fields=("last_update",), ) @@ -92,6 +97,9 @@ class Model(AttributedDocument): metadata = SafeMapField( field=EmbeddedDocumentField(MetadataItem), user_set_allowed=True ) + last_iteration = IntField(default=0) + last_metrics = SafeMapField(field=SafeMapField(EmbeddedDocumentField(MetricEvent))) + unique_metrics = ListField(StringField(required=True), exclude_by_default=True) def get_index_company(self) -> str: return self.company or self.company_origin or "" diff --git a/apiserver/schema/services/models.conf b/apiserver/schema/services/models.conf index 2cc1c0f..3d5a713 100644 --- a/apiserver/schema/services/models.conf +++ b/apiserver/schema/services/models.conf @@ -1,6 +1,6 @@ _description: """This service provides a management interface for models (results of training tasks) stored in the system.""" _definitions { - include "_common.conf" + include "_tasks_common.conf" multi_field_pattern_data { type: object properties { @@ -104,6 +104,17 @@ _definitions { "$ref": "#/definitions/metadata_item" } } + last_iteration { + description: "Last iteration reported for this model" + type: integer + } + last_metrics { + description: "Last metric variants (hash to events), one for each metric hash" + type: object + additionalProperties { + "$ref": "#/definitions/last_metrics_variants" + } + } stats { description: "Model statistics" type: object diff --git a/apiserver/schema/services/projects.conf b/apiserver/schema/services/projects.conf index 990a547..a4c4e46 100644 --- a/apiserver/schema/services/projects.conf +++ b/apiserver/schema/services/projects.conf @@ -898,6 +898,13 @@ get_unique_metric_variants { } } } + "999.0": ${get_unique_metric_variants."2.13"} { + request.properties.model_metrics { + description: If set to true then bring unique metric and variant names from the project models otherwise from the project tasks + type: boolean + default: false + } + } } get_hyperparam_values { "2.13" { diff --git a/apiserver/schema/services/reports.conf b/apiserver/schema/services/reports.conf index f549bf6..0a91939 100644 --- a/apiserver/schema/services/reports.conf +++ b/apiserver/schema/services/reports.conf @@ -568,6 +568,13 @@ get_task_data { } } } + "999.0": ${get_task_data."2.23"} { + request.properties.model_events { + type: boolean + description: If set then the retrieving model events. Otherwise task events + default: false + } + } } get_all_ex { "2.23" { diff --git a/apiserver/services/events.py b/apiserver/services/events.py index 7ce64bf..9cb8d89 100644 --- a/apiserver/services/events.py +++ b/apiserver/services/events.py @@ -561,7 +561,6 @@ def _get_multitask_plots( metrics: MetricVariants = None, scroll_id=None, no_scroll=True, - model_events=False, ) -> Tuple[dict, int, str]: task_names = { t.id: t.name for t in itertools.chain.from_iterable(companies.values()) @@ -598,7 +597,6 @@ def get_multi_task_plots(call, company_id, _): last_iters=iters, scroll_id=scroll_id, no_scroll=no_scroll, - model_events=model_events, ) call.result.data = dict( plots=return_events, diff --git a/apiserver/services/models.py b/apiserver/services/models.py index 1a3983a..f570419 100644 --- a/apiserver/services/models.py +++ b/apiserver/services/models.py @@ -1,6 +1,6 @@ from datetime import datetime from functools import partial -from typing import Sequence +from typing import Sequence, Union from mongoengine import Q, EmbeddedDocument @@ -59,6 +59,11 @@ org_bll = OrgBLL() project_bll = ProjectBLL() +def conform_model_data(call: APICall, model_data: Union[Sequence[dict], dict]): + conform_output_tags(call, model_data) + unescape_metadata(call, model_data) + + @endpoint("models.get_by_id", required_fields=["model"]) def get_by_id(call: APICall, company_id, _): model_id = call.data["model"] @@ -74,8 +79,7 @@ def get_by_id(call: APICall, company_id, _): raise errors.bad_request.InvalidModelId( "no such public or company model", id=model_id, company=company_id, ) - conform_output_tags(call, models[0]) - unescape_metadata(call, models[0]) + conform_model_data(call, models[0]) call.result.data = {"model": models[0]} @@ -102,8 +106,7 @@ def get_by_task_id(call: APICall, company_id, _): "no such public or company model", id=model_id, company=company_id, ) model_dict = model.to_proper_dict() - conform_output_tags(call, model_dict) - unescape_metadata(call, model_dict) + conform_model_data(call, model_dict) call.result.data = {"model": model_dict} @@ -119,8 +122,7 @@ def get_all_ex(call: APICall, company_id, request: ModelsGetRequest): allow_public=request.allow_public, ret_params=ret_params, ) - conform_output_tags(call, models) - unescape_metadata(call, models) + conform_model_data(call, models) if not request.include_stats: call.result.data = {"models": models, **ret_params} @@ -142,8 +144,7 @@ def get_by_id_ex(call: APICall, company_id, _): models = Model.get_many_with_join( company=company_id, query_dict=call.data, allow_public=True ) - conform_output_tags(call, models) - unescape_metadata(call, models) + conform_model_data(call, models) call.result.data = {"models": models} @@ -159,8 +160,7 @@ def get_all(call: APICall, company_id, _): allow_public=True, ret_params=ret_params, ) - conform_output_tags(call, models) - unescape_metadata(call, models) + conform_model_data(call, models) call.result.data = {"models": models, **ret_params} @@ -428,8 +428,7 @@ def edit(call: APICall, company_id, _): _reset_cached_tags(company_id, projects=[new_project, model.project]) else: _update_cached_tags(company_id, project=model.project, fields=fields) - conform_output_tags(call, fields) - unescape_metadata(call, fields) + conform_model_data(call, fields) call.result.data_model = UpdateResponse(updated=updated, fields=fields) else: call.result.data_model = UpdateResponse(updated=0) @@ -461,8 +460,7 @@ def _update_model(call: APICall, company_id, model_id=None): _update_cached_tags( company_id, project=model.project, fields=updated_fields ) - conform_output_tags(call, updated_fields) - unescape_metadata(call, updated_fields) + conform_model_data(call, updated_fields) return UpdateResponse(updated=updated_count, fields=updated_fields) diff --git a/apiserver/services/projects.py b/apiserver/services/projects.py index d09359b..1a3161f 100644 --- a/apiserver/services/projects.py +++ b/apiserver/services/projects.py @@ -15,10 +15,10 @@ from apiserver.apimodels.projects import ( DeleteRequest, MoveRequest, MergeRequest, - ProjectOrNoneRequest, ProjectRequest, ProjectModelMetadataValuesRequest, ProjectChildrenType, + GetUniqueMetricsRequest, ) from apiserver.bll.organization import OrgBLL, Tags from apiserver.bll.project import ProjectBLL, ProjectQueries @@ -345,16 +345,17 @@ def delete(call: APICall, company_id: str, request: DeleteRequest): @endpoint( - "projects.get_unique_metric_variants", request_data_model=ProjectOrNoneRequest + "projects.get_unique_metric_variants", request_data_model=GetUniqueMetricsRequest ) def get_unique_metric_variants( - call: APICall, company_id: str, request: ProjectOrNoneRequest + call: APICall, company_id: str, request: GetUniqueMetricsRequest, ): metrics = project_queries.get_unique_metric_variants( company_id, [request.project] if request.project else None, include_subprojects=request.include_subprojects, + model_metrics=request.model_metrics, ) call.result.data = {"metrics": metrics} diff --git a/apiserver/services/queues.py b/apiserver/services/queues.py index e6c95b5..77c9ba6 100644 --- a/apiserver/services/queues.py +++ b/apiserver/services/queues.py @@ -1,3 +1,5 @@ +from typing import Union, Sequence + from mongoengine import Q from apiserver.apimodels.base import UpdateResponse @@ -39,14 +41,18 @@ worker_bll = WorkerBLL() queue_bll = QueueBLL(worker_bll) +def conform_queue_data(call: APICall, queue_data: Union[Sequence[dict], dict]): + conform_output_tags(call, queue_data) + unescape_metadata(call, queue_data) + + @endpoint("queues.get_by_id", min_version="2.4", request_data_model=GetByIdRequest) def get_by_id(call: APICall, company_id, request: GetByIdRequest): queue = queue_bll.get_by_id( company_id, request.queue, max_task_entries=request.max_task_entries ) queue_dict = queue.to_proper_dict() - conform_output_tags(call, queue_dict) - unescape_metadata(call, queue_dict) + conform_queue_data(call, queue_dict) call.result.data = {"queue": queue_dict} @@ -85,8 +91,7 @@ def get_all_ex(call: APICall, company: str, request: GetAllRequest): max_task_entries=request.max_task_entries, ret_params=ret_params, ) - conform_output_tags(call, queues) - unescape_metadata(call, queues) + conform_queue_data(call, queues) call.result.data = {"queues": queues, **ret_params} @@ -102,8 +107,7 @@ def get_all(call: APICall, company: str, request: GetAllRequest): max_task_entries=request.max_task_entries, ret_params=ret_params, ) - conform_output_tags(call, queues) - unescape_metadata(call, queues) + conform_queue_data(call, queues) call.result.data = {"queues": queues, **ret_params} @@ -135,8 +139,7 @@ def update(call: APICall, company_id, req_model: UpdateRequest): updated, fields = queue_bll.update( company_id=company_id, queue_id=req_model.queue, **data ) - conform_output_tags(call, fields) - unescape_metadata(call, fields) + conform_queue_data(call, fields) call.result.data_model = UpdateResponse(updated=updated, fields=fields) diff --git a/apiserver/services/reports.py b/apiserver/services/reports.py index 8615a75..f4ae4da 100644 --- a/apiserver/services/reports.py +++ b/apiserver/services/reports.py @@ -17,6 +17,8 @@ from apiserver.apimodels.reports import ( from apiserver.apierrors import errors from apiserver.apimodels.base import UpdateResponse from apiserver.bll.project.project_bll import reports_project_name, reports_tag +from apiserver.database.model.model import Model +from apiserver.services.models import conform_model_data from apiserver.services.utils import process_include_subprojects, sort_tags_response from apiserver.bll.organization import OrgBLL from apiserver.bll.project import ProjectBLL @@ -35,7 +37,7 @@ from apiserver.services.events import ( from apiserver.services.tasks import ( escape_execution_parameters, _hidden_query, - unprepare_from_saved, + conform_task_data, ) org_bll = OrgBLL() @@ -178,7 +180,7 @@ def get_all_ex(call: APICall, company_id, request: GetAllRequest): allow_public=request.allow_public, ret_params=ret_params, ) - unprepare_from_saved(call, tasks) + conform_task_data(call, tasks) call.result.data = {"tasks": tasks, **ret_params} @@ -198,18 +200,25 @@ def _get_task_metrics_from_request( @endpoint("reports.get_task_data") def get_task_data(call: APICall, company_id, request: GetTasksDataRequest): + if request.model_events: + entity_cls = Model + conform_data = conform_model_data + else: + entity_cls = Task + conform_data = conform_task_data + call_data = escape_execution_parameters(call) process_include_subprojects(call_data) ret_params = {} - tasks = Task.get_many_with_join( + tasks = entity_cls.get_many_with_join( company=company_id, query_dict=call_data, query=_hidden_query(call_data), allow_public=request.allow_public, ret_params=ret_params, ) - unprepare_from_saved(call, tasks) + conform_data(call, tasks) res = {"tasks": tasks, **ret_params} if not ( request.debug_images or request.plots or request.scalar_metrics_iter_histogram @@ -217,7 +226,9 @@ def get_task_data(call: APICall, company_id, request: GetTasksDataRequest): return res task_ids = [task["id"] for task in tasks] - companies = _get_task_or_model_index_companies(company_id, task_ids=task_ids) + companies = _get_task_or_model_index_companies( + company_id, task_ids=task_ids, model_events=request.model_events + ) if request.debug_images: result = event_bll.debug_images_iterator.get_task_events( companies={ diff --git a/apiserver/services/tasks.py b/apiserver/services/tasks.py index 45fa6f2..d814966 100644 --- a/apiserver/services/tasks.py +++ b/apiserver/services/tasks.py @@ -182,7 +182,7 @@ def get_by_id(call: APICall, company_id, req_model: TaskRequest): req_model.task, company_id=company_id, allow_public=True ) task_dict = task.to_proper_dict() - unprepare_from_saved(call, task_dict) + conform_task_data(call, task_dict) call.result.data = {"task": task_dict} @@ -231,7 +231,7 @@ def get_all_ex(call: APICall, company_id, request: GetAllReq): allow_public=request.allow_public, ret_params=ret_params, ) - unprepare_from_saved(call, tasks) + conform_task_data(call, tasks) call.result.data = {"tasks": tasks, **ret_params} @@ -245,7 +245,7 @@ def get_by_id_ex(call: APICall, company_id, _): company=company_id, query_dict=call_data, allow_public=True, ) - unprepare_from_saved(call, tasks) + conform_task_data(call, tasks) call.result.data = {"tasks": tasks} @@ -264,7 +264,7 @@ def get_all(call: APICall, company_id, _): allow_public=True, ret_params=ret_params, ) - unprepare_from_saved(call, tasks) + conform_task_data(call, tasks) call.result.data = {"tasks": tasks, **ret_params} @@ -430,7 +430,7 @@ def prepare_for_save(call: APICall, fields: dict, previous_task: Task = None): return fields -def unprepare_from_saved(call: APICall, tasks_data: Union[Sequence[dict], dict]): +def conform_task_data(call: APICall, tasks_data: Union[Sequence[dict], dict]): if isinstance(tasks_data, dict): tasks_data = [tasks_data] @@ -608,7 +608,7 @@ def update(call: APICall, company_id, req_model: UpdateRequest): company_id, project=task.project, fields=updated_fields ) update_project_time(updated_fields.get("project")) - unprepare_from_saved(call, updated_fields) + conform_task_data(call, updated_fields) return UpdateResponse(updated=updated_count, fields=updated_fields) @@ -763,7 +763,7 @@ def edit(call: APICall, company_id, req_model: UpdateRequest): company_id, project=task.project, fields=fixed_fields ) update_project_time(fields.get("project")) - unprepare_from_saved(call, fields) + conform_task_data(call, fields) call.result.data_model = UpdateResponse(updated=updated, fields=fields) else: call.result.data_model = UpdateResponse(updated=0) diff --git a/apiserver/tests/automated/test_reports.py b/apiserver/tests/automated/test_reports.py index 00b562a..b56c5f0 100644 --- a/apiserver/tests/automated/test_reports.py +++ b/apiserver/tests/automated/test_reports.py @@ -113,66 +113,75 @@ class TestReports(TestService): def test_reports_task_data(self): report_task = self._temp_report(name="Rep1") - non_report_task = self._temp_task(name="hello") - debug_image_events = [ - self._create_task_event( - task=non_report_task, - type_="training_debug_image", - iteration=1, - metric=f"Metric_{m}", - variant=f"Variant_{v}", - url=f"{m}_{v}", + for model_events in (False, True): + if model_events: + non_report_task = self._temp_model(name="hello") + event_args = {"model_event": True} + else: + non_report_task = self._temp_task(name="hello") + event_args = {} + debug_image_events = [ + self._create_task_event( + task=non_report_task, + type_="training_debug_image", + iteration=1, + metric=f"Metric_{m}", + variant=f"Variant_{v}", + url=f"{m}_{v}", + **event_args, + ) + for m in range(2) + for v in range(2) + ] + plot_events = [ + self._create_task_event( + task=non_report_task, + type_="plot", + iteration=1, + metric=f"Metric_{m}", + variant=f"Variant_{v}", + plot_str=f"Hello plot", + **event_args, + ) + for m in range(2) + for v in range(2) + ] + self.send_batch([*debug_image_events, *plot_events]) + + res = self.api.reports.get_task_data( + id=[non_report_task], only_fields=["name"], model_events=model_events ) - for m in range(2) - for v in range(2) - ] - plot_events = [ - self._create_task_event( - task=non_report_task, - type_="plot", - iteration=1, - metric=f"Metric_{m}", - variant=f"Variant_{v}", - plot_str=f"Hello plot", + self.assertEqual(len(res.tasks), 1) + self.assertEqual(res.tasks[0].id, non_report_task) + self.assertFalse(any(field in res for field in ("plots", "debug_images"))) + + res = self.api.reports.get_task_data( + id=[non_report_task], + only_fields=["name"], + debug_images={"metrics": []}, + plots={"metrics": [{"metric": "Metric_1"}]}, + model_events=model_events, ) - for m in range(2) - for v in range(2) - ] - self.send_batch([*debug_image_events, *plot_events]) + self.assertEqual(len(res.debug_images), 1) + task_events = res.debug_images[0] + self.assertEqual(task_events.task, non_report_task) + self.assertEqual(len(task_events.iterations), 1) + self.assertEqual(len(task_events.iterations[0].events), 4) - res = self.api.reports.get_task_data( - id=[non_report_task], only_fields=["name"], - ) - self.assertEqual(len(res.tasks), 1) - self.assertEqual(res.tasks[0].id, non_report_task) - self.assertFalse(any(field in res for field in ("plots", "debug_images"))) - - res = self.api.reports.get_task_data( - id=[non_report_task], - only_fields=["name"], - debug_images={"metrics": []}, - plots={"metrics": [{"metric": "Metric_1"}]}, - ) - self.assertEqual(len(res.debug_images), 1) - task_events = res.debug_images[0] - self.assertEqual(task_events.task, non_report_task) - self.assertEqual(len(task_events.iterations), 1) - self.assertEqual(len(task_events.iterations[0].events), 4) - - self.assertEqual(len(res.plots), 1) - for m, v in (("Metric_1", "Variant_0"), ("Metric_1", "Variant_1")): - tasks = nested_get(res.plots, (m, v)) - self.assertEqual(len(tasks), 1) - task_plots = tasks[non_report_task] - self.assertEqual(len(task_plots), 1) - iter_plots = task_plots["1"] - self.assertEqual(iter_plots.name, "hello") - self.assertEqual(len(iter_plots.plots), 1) - ev = iter_plots.plots[0] - self.assertEqual(ev["metric"], m) - self.assertEqual(ev["variant"], v) - self.assertEqual(ev["task"], non_report_task) - self.assertEqual(ev["iter"], 1) + self.assertEqual(len(res.plots), 1) + for m, v in (("Metric_1", "Variant_0"), ("Metric_1", "Variant_1")): + tasks = nested_get(res.plots, (m, v)) + self.assertEqual(len(tasks), 1) + task_plots = tasks[non_report_task] + self.assertEqual(len(task_plots), 1) + iter_plots = task_plots["1"] + self.assertEqual(iter_plots.name, "hello") + self.assertEqual(len(iter_plots.plots), 1) + ev = iter_plots.plots[0] + self.assertEqual(ev["metric"], m) + self.assertEqual(ev["variant"], v) + self.assertEqual(ev["task"], non_report_task) + self.assertEqual(ev["iter"], 1) @staticmethod def _create_task_event(type_, task, iteration, **kwargs): @@ -185,12 +194,14 @@ class TestReports(TestService): **kwargs, } + delete_params = {"force": True} + def _temp_report(self, name, **kwargs): return self.create_temp( "reports", name=name, object_name="task", - delete_params={"force": True}, + delete_params=self.delete_params, **kwargs, ) @@ -199,10 +210,16 @@ class TestReports(TestService): "tasks", name=name, type="training", - delete_params={"force": True}, + delete_params=self.delete_params, **kwargs, ) + def _temp_model(self, name="test model events", **kwargs): + self.update_missing( + kwargs, name=name, uri="file:///a/b", labels={}, ready=False + ) + return self.create_temp("models", delete_params=self.delete_params, **kwargs) + def send_batch(self, events): _, data = self.api.send_batch("events.add_batch", events) return data diff --git a/apiserver/tests/automated/test_task_events.py b/apiserver/tests/automated/test_task_events.py index 126ca49..8dbdc7f 100644 --- a/apiserver/tests/automated/test_task_events.py +++ b/apiserver/tests/automated/test_task_events.py @@ -16,9 +16,7 @@ class TestTaskEvents(TestService): delete_params = dict(can_fail=True, force=True) def _temp_task(self, name="test task events"): - task_input = dict( - name=name, type="training", - ) + task_input = dict(name=name, type="training",) return self.create_temp( "tasks", delete_paramse=self.delete_params, **task_input ) @@ -220,6 +218,17 @@ class TestTaskEvents(TestService): self.assertEqual(variant_data.x, [0, 1]) self.assertEqual(variant_data.y, [0.0, 1.0]) + model_data = self.api.models.get_all_ex( + id=[model], only_fields=["last_metrics", "last_iteration"] + ).models[0] + metric_data = first(first(model_data.last_metrics.values()).values()) + self.assertEqual(1, model_data.last_iteration) + self.assertEqual(1, metric_data.value) + self.assertEqual(1, metric_data.max_value) + self.assertEqual(1, metric_data.max_value_iteration) + self.assertEqual(0, metric_data.min_value) + self.assertEqual(0, metric_data.min_value_iteration) + def test_error_events(self): task = self._temp_task() events = [