Add support for model statistics

This commit is contained in:
allegroai 2022-07-08 17:39:41 +03:00
parent 0c15169668
commit adc1825843
5 changed files with 109 additions and 32 deletions

View File

@ -75,3 +75,7 @@ class DeleteMetadataRequest(DeleteMetadata):
class AddOrUpdateMetadataRequest(AddOrUpdateMetadata):
model = fields.StringField(required=True)
class ModelsGetRequest(models.Base):
include_stats = fields.BoolField(default=False)

View File

@ -1,5 +1,5 @@
from datetime import datetime
from typing import Callable, Tuple
from typing import Callable, Tuple, Sequence, Dict
from apiserver.apierrors import errors
from apiserver.apimodels.models import ModelTaskPublishResponse
@ -128,3 +128,33 @@ class ModelBLL:
)
return unarchived
@classmethod
def get_model_stats(
cls, company: str, model_ids: Sequence[str],
) -> Dict[str, dict]:
if not model_ids:
return {}
result = Model.aggregate(
[
{
"$match": {
"company": {"$in": [None, "", company]},
"_id": {"$in": model_ids},
}
},
{
"$addFields": {
"labels_count": {"$size": {"$objectToArray": "$labels"}}
}
},
{
"$project": {"labels_count": 1},
},
]
)
return {
r.pop("_id"): r
for r in result
}

View File

@ -104,6 +104,16 @@ _definitions {
"$ref": "#/definitions/metadata_item"
}
}
stats {
description: "Model statistics"
type: object
properties {
labels_count {
description: Number of the model labels
type: integer
}
}
}
}
}
published_task_item {
@ -224,6 +234,13 @@ get_all_ex {
description: "Scroll ID that can be used with the next calls to get_all_ex to retrieve more data"
}
}
"999.0": ${get_all_ex."2.15"} {
request.properties.include_stats {
description: "If true, include models statistic in response"
type: boolean
default: false
}
}
}
get_all {
"2.1" {

View File

@ -20,6 +20,7 @@ from apiserver.apimodels.models import (
AddOrUpdateMetadataRequest,
ModelsPublishManyRequest,
ModelsDeleteManyRequest,
ModelsGetRequest,
)
from apiserver.bll.model import ModelBLL, Metadata
from apiserver.bll.organization import OrgBLL, Tags
@ -117,8 +118,8 @@ def _process_include_subprojects(call_data: dict):
call_data["project"] = project_ids_with_children(project_ids)
@endpoint("models.get_all_ex", required_fields=[])
def get_all_ex(call: APICall, company_id, _):
@endpoint("models.get_all_ex", request_data_model=ModelsGetRequest)
def get_all_ex(call: APICall, company_id, request: ModelsGetRequest):
conform_tag_fields(call, call.data)
_process_include_subprojects(call.data)
Metadata.escape_query_parameters(call)
@ -132,6 +133,20 @@ def get_all_ex(call: APICall, company_id, _):
)
conform_output_tags(call, models)
unescape_metadata(call, models)
if not request.include_stats:
call.result.data = {"models": models, **ret_params}
return
model_ids = {model["id"] for model in models}
stats = ModelBLL.get_model_stats(
company=company_id,
model_ids=list(model_ids),
)
for model in models:
model["stats"] = stats.get(model["id"])
call.result.data = {"models": models, **ret_params}

View File

@ -77,7 +77,7 @@ class TestModelsService(TestService):
def test_publish_output_model_no_task(self):
model_id = self.create_temp(
service="models", name='test', uri='file:///a', labels={}, ready=False
service="models", name="test", uri="file:///a", labels={}, ready=False
)
self._assert_model_ready(model_id, False)
@ -109,7 +109,7 @@ class TestModelsService(TestService):
def test_publish_task_no_output_model(self):
task_id = self.create_temp(
service="tasks", type='testing', name='server-test', input=dict(view={})
service="tasks", type="testing", name="server-test", input=dict(view={})
)
self.api.tasks.started(task=task_id)
self.api.tasks.stopped(task=task_id)
@ -118,31 +118,48 @@ class TestModelsService(TestService):
assert res.updated == 1 # model updated
self._assert_task_status(task_id, PUBLISHED)
def test_get_models_stats(self):
model1 = self._create_model(labels={"hello": 1, "world": 2})
model2 = self._create_model(labels={"foo": 1})
model3 = self._create_model()
# no stats
res = self.api.models.get_all_ex(id=[model1, model2, model3]).models
self.assertEqual(len(res), 3)
self.assertTrue(all("stats" not in m for m in res))
# stats
res = self.api.models.get_all_ex(
id=[model1, model2, model3], include_stats=True
).models
self.assertEqual(len(res), 3)
stats = {m.id: m.stats.labels_count for m in res}
self.assertEqual(stats[model1], 2)
self.assertEqual(stats[model2], 1)
self.assertEqual(stats[model3], 0)
def test_update_model_iteration_with_task(self):
task_id = self._create_task()
model_id = self._create_model()
self.api.models.update(model=model_id, task=task_id, iteration=1000, labels={"foo": 1})
self.api.models.update(
model=model_id, task=task_id, iteration=1000, labels={"foo": 1}
)
self.assertEqual(
self.api.tasks.get_by_id(task=task_id).task.last_iteration,
1000
self.api.tasks.get_by_id(task=task_id).task.last_iteration, 1000
)
self.api.models.update(model=model_id, task=task_id, iteration=500)
self.assertEqual(
self.api.tasks.get_by_id(task=task_id).task.last_iteration,
1000
self.api.tasks.get_by_id(task=task_id).task.last_iteration, 1000
)
def test_update_model_for_task_iteration(self):
task_id = self._create_task()
res = self.api.models.update_for_task(
task=task_id,
name="test model",
uri="file:///b",
iteration=999,
task=task_id, name="test model", uri="file:///b", iteration=999,
)
model_id = res.id
@ -150,22 +167,19 @@ class TestModelsService(TestService):
self.defer(self.api.models.delete, can_fail=True, model=model_id, force=True)
self.assertEqual(
self.api.tasks.get_by_id(task=task_id).task.last_iteration,
999
self.api.tasks.get_by_id(task=task_id).task.last_iteration, 999
)
self.api.models.update_for_task(task=task_id, uri="file:///c", iteration=1000)
self.assertEqual(
self.api.tasks.get_by_id(task=task_id).task.last_iteration,
1000
self.api.tasks.get_by_id(task=task_id).task.last_iteration, 1000
)
self.api.models.update_for_task(task=task_id, uri="file:///d", iteration=888)
self.assertEqual(
self.api.tasks.get_by_id(task=task_id).task.last_iteration,
1000
self.api.tasks.get_by_id(task=task_id).task.last_iteration, 1000
)
def test_get_frameworks(self):
@ -238,8 +252,8 @@ class TestModelsService(TestService):
return self.create_temp(
service="models",
delete_params=dict(can_fail=True, force=True),
name=kwargs.pop("name", 'test'),
uri=kwargs.pop("name", 'file:///a'),
name=kwargs.pop("name", "test"),
uri=kwargs.pop("name", "file:///a"),
labels=kwargs.pop("labels", {}),
**kwargs,
)
@ -247,8 +261,8 @@ class TestModelsService(TestService):
def _create_task(self, **kwargs):
task_id = self.create_temp(
service="tasks",
type=kwargs.pop("type", 'testing'),
name=kwargs.pop("name", 'server-test'),
type=kwargs.pop("type", "testing"),
name=kwargs.pop("name", "server-test"),
input=kwargs.pop("input", dict(view={})),
**kwargs,
)
@ -257,21 +271,18 @@ class TestModelsService(TestService):
def _create_task_and_model(self):
execution_model_id = self.create_temp(
service="models",
name='test',
uri='file:///a',
labels={}
service="models", name="test", uri="file:///a", labels={}
)
task_id = self.create_temp(
service="tasks",
type='testing',
name='server-test',
type="testing",
name="server-test",
input=dict(view={}),
execution=dict(model=execution_model_id)
execution=dict(model=execution_model_id),
)
self.api.tasks.started(task=task_id)
output_model_id = self.api.models.update_for_task(
task=task_id, uri='file:///b'
task=task_id, uri="file:///b"
)["id"]
return task_id, output_model_id