From adc18258432beff42571a4034fdb0c32e374c7f3 Mon Sep 17 00:00:00 2001
From: allegroai <>
Date: Fri, 8 Jul 2022 17:39:41 +0300
Subject: [PATCH] Add support for model statistics

---
Notes (ignored by git-am, not part of the commit message):
- models.get_all_ex accepts include_stats; when it is set, every returned
  model gets a "stats" entry built by ModelBLL.get_model_stats.
- labels_count treats a missing or null "labels" field as an empty document
  instead of failing the aggregation.
- The _create_model test helper reads the uri override from the "uri" kwarg
  (it used to pop "name" a second time).

 apiserver/apimodels/models.py            |  4 ++
 apiserver/bll/model/__init__.py          | 32 ++++++++++-
 apiserver/schema/services/models.conf    | 17 ++++++
 apiserver/services/models.py             | 19 ++++++-
 apiserver/tests/automated/test_models.py | 69 ++++++++++++++----------
 5 files changed, 109 insertions(+), 32 deletions(-)

diff --git a/apiserver/apimodels/models.py b/apiserver/apimodels/models.py
index 187377f..6c5a113 100644
--- a/apiserver/apimodels/models.py
+++ b/apiserver/apimodels/models.py
@@ -75,3 +75,7 @@ class DeleteMetadataRequest(DeleteMetadata):
 
 class AddOrUpdateMetadataRequest(AddOrUpdateMetadata):
     model = fields.StringField(required=True)
+
+
+class ModelsGetRequest(models.Base):
+    include_stats = fields.BoolField(default=False)
diff --git a/apiserver/bll/model/__init__.py b/apiserver/bll/model/__init__.py
index 5b16574..3d47b77 100644
--- a/apiserver/bll/model/__init__.py
+++ b/apiserver/bll/model/__init__.py
@@ -1,5 +1,5 @@
 from datetime import datetime
-from typing import Callable, Tuple
+from typing import Callable, Tuple, Sequence, Dict
 
 from apiserver.apierrors import errors
 from apiserver.apimodels.models import ModelTaskPublishResponse
@@ -128,3 +128,33 @@ class ModelBLL:
         )
 
         return unarchived
+
+    @classmethod
+    def get_model_stats(
+        cls, company: str, model_ids: Sequence[str],
+    ) -> Dict[str, dict]:
+        """Return ``{model_id: {"labels_count": int}}`` for the given model ids."""
+        if not model_ids:
+            return {}
+
+        result = Model.aggregate(
+            [
+                {
+                    "$match": {
+                        "company": {"$in": [None, "", company]},
+                        "_id": {"$in": model_ids},
+                    }
+                },
+                {
+                    "$addFields": {
+                        "labels_count": {
+                            # $size fails on the null that $objectToArray returns for
+                            # absent labels, so fall back to an empty document
+                            "$size": {"$objectToArray": {"$ifNull": ["$labels", {}]}}
+                        }
+                    }
+                },
+                {"$project": {"labels_count": 1}},
+            ]
+        )
+        return {r.pop("_id"): r for r in result}
diff --git a/apiserver/schema/services/models.conf b/apiserver/schema/services/models.conf
index 8bb7268..0e7628c 100644
--- a/apiserver/schema/services/models.conf
+++ b/apiserver/schema/services/models.conf
@@ -104,6 +104,16 @@ _definitions {
                     "$ref": "#/definitions/metadata_item"
                 }
             }
+            stats {
+                description: "Model statistics"
+                type: object
+                properties {
+                    labels_count {
+                        description: Number of the model labels
+                        type: integer
+                    }
+                }
+            }
         }
     }
     published_task_item {
@@ -224,6 +234,13 @@ get_all_ex {
             description: "Scroll ID that can be used with the next calls to get_all_ex to retrieve more data"
         }
     }
+    "999.0": ${get_all_ex."2.15"} {
+        request.properties.include_stats {
+            description: "If true, include models statistic in response"
+            type: boolean
+            default: false
+        }
+    }
 }
 get_all {
     "2.1" {
diff --git a/apiserver/services/models.py b/apiserver/services/models.py
index 10a0ef0..0e8cef1 100644
--- a/apiserver/services/models.py
+++ b/apiserver/services/models.py
@@ -20,6 +20,7 @@ from apiserver.apimodels.models import (
     AddOrUpdateMetadataRequest,
     ModelsPublishManyRequest,
     ModelsDeleteManyRequest,
+    ModelsGetRequest,
 )
 from apiserver.bll.model import ModelBLL, Metadata
 from apiserver.bll.organization import OrgBLL, Tags
@@ -117,8 +118,8 @@ def _process_include_subprojects(call_data: dict):
     call_data["project"] = project_ids_with_children(project_ids)
 
 
-@endpoint("models.get_all_ex", required_fields=[])
-def get_all_ex(call: APICall, company_id, _):
+@endpoint("models.get_all_ex", request_data_model=ModelsGetRequest)
+def get_all_ex(call: APICall, company_id, request: ModelsGetRequest):
     conform_tag_fields(call, call.data)
     _process_include_subprojects(call.data)
     Metadata.escape_query_parameters(call)
@@ -132,6 +133,20 @@ def get_all_ex(call: APICall, company_id, _):
     )
     conform_output_tags(call, models)
     unescape_metadata(call, models)
+
+    if not request.include_stats:
+        call.result.data = {"models": models, **ret_params}
+        return
+
+    model_ids = {model["id"] for model in models}
+    stats = ModelBLL.get_model_stats(
+        company=company_id,
+        model_ids=list(model_ids),
+    )
+
+    for model in models:
+        model["stats"] = stats.get(model["id"])
+
     call.result.data = {"models": models, **ret_params}
 
 
diff --git a/apiserver/tests/automated/test_models.py b/apiserver/tests/automated/test_models.py
index 2140aa6..a6dd2c2 100644
--- a/apiserver/tests/automated/test_models.py
+++ b/apiserver/tests/automated/test_models.py
@@ -77,7 +77,7 @@ class TestModelsService(TestService):
 
     def test_publish_output_model_no_task(self):
         model_id = self.create_temp(
-            service="models", name='test', uri='file:///a', labels={}, ready=False
+            service="models", name="test", uri="file:///a", labels={}, ready=False
         )
         self._assert_model_ready(model_id, False)
 
@@ -109,7 +109,7 @@ class TestModelsService(TestService):
 
     def test_publish_task_no_output_model(self):
         task_id = self.create_temp(
-            service="tasks", type='testing', name='server-test', input=dict(view={})
+            service="tasks", type="testing", name="server-test", input=dict(view={})
         )
         self.api.tasks.started(task=task_id)
         self.api.tasks.stopped(task=task_id)
@@ -118,31 +118,48 @@ class TestModelsService(TestService):
         assert res.updated == 1  # model updated
         self._assert_task_status(task_id, PUBLISHED)
 
+    def test_get_models_stats(self):
+        model1 = self._create_model(labels={"hello": 1, "world": 2})
+        model2 = self._create_model(labels={"foo": 1})
+        model3 = self._create_model()
+
+        # no stats
+        res = self.api.models.get_all_ex(id=[model1, model2, model3]).models
+        self.assertEqual(len(res), 3)
+        self.assertTrue(all("stats" not in m for m in res))
+
+        # stats
+        res = self.api.models.get_all_ex(
+            id=[model1, model2, model3], include_stats=True
+        ).models
+        self.assertEqual(len(res), 3)
+        stats = {m.id: m.stats.labels_count for m in res}
+        self.assertEqual(stats[model1], 2)
+        self.assertEqual(stats[model2], 1)
+        self.assertEqual(stats[model3], 0)
+
     def test_update_model_iteration_with_task(self):
         task_id = self._create_task()
         model_id = self._create_model()
 
-        self.api.models.update(model=model_id, task=task_id, iteration=1000, labels={"foo": 1})
+        self.api.models.update(
+            model=model_id, task=task_id, iteration=1000, labels={"foo": 1}
+        )
 
         self.assertEqual(
-            self.api.tasks.get_by_id(task=task_id).task.last_iteration,
-            1000
+            self.api.tasks.get_by_id(task=task_id).task.last_iteration, 1000
         )
 
         self.api.models.update(model=model_id, task=task_id, iteration=500)
 
         self.assertEqual(
-            self.api.tasks.get_by_id(task=task_id).task.last_iteration,
-            1000
+            self.api.tasks.get_by_id(task=task_id).task.last_iteration, 1000
         )
 
     def test_update_model_for_task_iteration(self):
         task_id = self._create_task()
         res = self.api.models.update_for_task(
-            task=task_id,
-            name="test model",
-            uri="file:///b",
-            iteration=999,
+            task=task_id, name="test model", uri="file:///b", iteration=999,
         )
 
         model_id = res.id
@@ -150,22 +167,19 @@ class TestModelsService(TestService):
         self.defer(self.api.models.delete, can_fail=True, model=model_id, force=True)
 
         self.assertEqual(
-            self.api.tasks.get_by_id(task=task_id).task.last_iteration,
-            999
+            self.api.tasks.get_by_id(task=task_id).task.last_iteration, 999
         )
 
         self.api.models.update_for_task(task=task_id, uri="file:///c", iteration=1000)
 
         self.assertEqual(
-            self.api.tasks.get_by_id(task=task_id).task.last_iteration,
-            1000
+            self.api.tasks.get_by_id(task=task_id).task.last_iteration, 1000
         )
 
         self.api.models.update_for_task(task=task_id, uri="file:///d", iteration=888)
 
         self.assertEqual(
-            self.api.tasks.get_by_id(task=task_id).task.last_iteration,
-            1000
+            self.api.tasks.get_by_id(task=task_id).task.last_iteration, 1000
         )
 
     def test_get_frameworks(self):
@@ -238,8 +252,8 @@ class TestModelsService(TestService):
         return self.create_temp(
             service="models",
             delete_params=dict(can_fail=True, force=True),
-            name=kwargs.pop("name", 'test'),
-            uri=kwargs.pop("name", 'file:///a'),
+            name=kwargs.pop("name", "test"),
+            uri=kwargs.pop("uri", "file:///a"),
             labels=kwargs.pop("labels", {}),
             **kwargs,
         )
@@ -247,8 +261,8 @@ class TestModelsService(TestService):
     def _create_task(self, **kwargs):
         task_id = self.create_temp(
             service="tasks",
-            type=kwargs.pop("type", 'testing'),
-            name=kwargs.pop("name", 'server-test'),
+            type=kwargs.pop("type", "testing"),
+            name=kwargs.pop("name", "server-test"),
             input=kwargs.pop("input", dict(view={})),
             **kwargs,
         )
@@ -257,21 +271,18 @@ class TestModelsService(TestService):
 
     def _create_task_and_model(self):
         execution_model_id = self.create_temp(
-            service="models",
-            name='test',
-            uri='file:///a',
-            labels={}
+            service="models", name="test", uri="file:///a", labels={}
         )
         task_id = self.create_temp(
             service="tasks",
-            type='testing',
-            name='server-test',
+            type="testing",
+            name="server-test",
             input=dict(view={}),
-            execution=dict(model=execution_model_id)
+            execution=dict(model=execution_model_id),
         )
         self.api.tasks.started(task=task_id)
         output_model_id = self.api.models.update_for_task(
-            task=task_id, uri='file:///b'
+            task=task_id, uri="file:///b"
         )["id"]
 
         return task_id, output_model_id