Add support for model statistics

This commit is contained in:
allegroai 2022-07-08 17:39:41 +03:00
parent 0c15169668
commit adc1825843
5 changed files with 109 additions and 32 deletions

View File

@ -75,3 +75,7 @@ class DeleteMetadataRequest(DeleteMetadata):
class AddOrUpdateMetadataRequest(AddOrUpdateMetadata): class AddOrUpdateMetadataRequest(AddOrUpdateMetadata):
model = fields.StringField(required=True) model = fields.StringField(required=True)
class ModelsGetRequest(models.Base):
include_stats = fields.BoolField(default=False)

View File

@ -1,5 +1,5 @@
from datetime import datetime from datetime import datetime
from typing import Callable, Tuple from typing import Callable, Tuple, Sequence, Dict
from apiserver.apierrors import errors from apiserver.apierrors import errors
from apiserver.apimodels.models import ModelTaskPublishResponse from apiserver.apimodels.models import ModelTaskPublishResponse
@ -128,3 +128,33 @@ class ModelBLL:
) )
return unarchived return unarchived
@classmethod
def get_model_stats(
cls, company: str, model_ids: Sequence[str],
) -> Dict[str, dict]:
if not model_ids:
return {}
result = Model.aggregate(
[
{
"$match": {
"company": {"$in": [None, "", company]},
"_id": {"$in": model_ids},
}
},
{
"$addFields": {
"labels_count": {"$size": {"$objectToArray": "$labels"}}
}
},
{
"$project": {"labels_count": 1},
},
]
)
return {
r.pop("_id"): r
for r in result
}

View File

@ -104,6 +104,16 @@ _definitions {
"$ref": "#/definitions/metadata_item" "$ref": "#/definitions/metadata_item"
} }
} }
stats {
description: "Model statistics"
type: object
properties {
labels_count {
description: Number of the model labels
type: integer
}
}
}
} }
} }
published_task_item { published_task_item {
@ -224,6 +234,13 @@ get_all_ex {
description: "Scroll ID that can be used with the next calls to get_all_ex to retrieve more data" description: "Scroll ID that can be used with the next calls to get_all_ex to retrieve more data"
} }
} }
"999.0": ${get_all_ex."2.15"} {
request.properties.include_stats {
description: "If true, include models statistic in response"
type: boolean
default: false
}
}
} }
get_all { get_all {
"2.1" { "2.1" {

View File

@ -20,6 +20,7 @@ from apiserver.apimodels.models import (
AddOrUpdateMetadataRequest, AddOrUpdateMetadataRequest,
ModelsPublishManyRequest, ModelsPublishManyRequest,
ModelsDeleteManyRequest, ModelsDeleteManyRequest,
ModelsGetRequest,
) )
from apiserver.bll.model import ModelBLL, Metadata from apiserver.bll.model import ModelBLL, Metadata
from apiserver.bll.organization import OrgBLL, Tags from apiserver.bll.organization import OrgBLL, Tags
@ -117,8 +118,8 @@ def _process_include_subprojects(call_data: dict):
call_data["project"] = project_ids_with_children(project_ids) call_data["project"] = project_ids_with_children(project_ids)
@endpoint("models.get_all_ex", required_fields=[]) @endpoint("models.get_all_ex", request_data_model=ModelsGetRequest)
def get_all_ex(call: APICall, company_id, _): def get_all_ex(call: APICall, company_id, request: ModelsGetRequest):
conform_tag_fields(call, call.data) conform_tag_fields(call, call.data)
_process_include_subprojects(call.data) _process_include_subprojects(call.data)
Metadata.escape_query_parameters(call) Metadata.escape_query_parameters(call)
@ -132,6 +133,20 @@ def get_all_ex(call: APICall, company_id, _):
) )
conform_output_tags(call, models) conform_output_tags(call, models)
unescape_metadata(call, models) unescape_metadata(call, models)
if not request.include_stats:
call.result.data = {"models": models, **ret_params}
return
model_ids = {model["id"] for model in models}
stats = ModelBLL.get_model_stats(
company=company_id,
model_ids=list(model_ids),
)
for model in models:
model["stats"] = stats.get(model["id"])
call.result.data = {"models": models, **ret_params} call.result.data = {"models": models, **ret_params}

View File

@ -77,7 +77,7 @@ class TestModelsService(TestService):
def test_publish_output_model_no_task(self): def test_publish_output_model_no_task(self):
model_id = self.create_temp( model_id = self.create_temp(
service="models", name='test', uri='file:///a', labels={}, ready=False service="models", name="test", uri="file:///a", labels={}, ready=False
) )
self._assert_model_ready(model_id, False) self._assert_model_ready(model_id, False)
@ -109,7 +109,7 @@ class TestModelsService(TestService):
def test_publish_task_no_output_model(self): def test_publish_task_no_output_model(self):
task_id = self.create_temp( task_id = self.create_temp(
service="tasks", type='testing', name='server-test', input=dict(view={}) service="tasks", type="testing", name="server-test", input=dict(view={})
) )
self.api.tasks.started(task=task_id) self.api.tasks.started(task=task_id)
self.api.tasks.stopped(task=task_id) self.api.tasks.stopped(task=task_id)
@ -118,31 +118,48 @@ class TestModelsService(TestService):
assert res.updated == 1 # model updated assert res.updated == 1 # model updated
self._assert_task_status(task_id, PUBLISHED) self._assert_task_status(task_id, PUBLISHED)
def test_get_models_stats(self):
model1 = self._create_model(labels={"hello": 1, "world": 2})
model2 = self._create_model(labels={"foo": 1})
model3 = self._create_model()
# no stats
res = self.api.models.get_all_ex(id=[model1, model2, model3]).models
self.assertEqual(len(res), 3)
self.assertTrue(all("stats" not in m for m in res))
# stats
res = self.api.models.get_all_ex(
id=[model1, model2, model3], include_stats=True
).models
self.assertEqual(len(res), 3)
stats = {m.id: m.stats.labels_count for m in res}
self.assertEqual(stats[model1], 2)
self.assertEqual(stats[model2], 1)
self.assertEqual(stats[model3], 0)
def test_update_model_iteration_with_task(self): def test_update_model_iteration_with_task(self):
task_id = self._create_task() task_id = self._create_task()
model_id = self._create_model() model_id = self._create_model()
self.api.models.update(model=model_id, task=task_id, iteration=1000, labels={"foo": 1}) self.api.models.update(
model=model_id, task=task_id, iteration=1000, labels={"foo": 1}
)
self.assertEqual( self.assertEqual(
self.api.tasks.get_by_id(task=task_id).task.last_iteration, self.api.tasks.get_by_id(task=task_id).task.last_iteration, 1000
1000
) )
self.api.models.update(model=model_id, task=task_id, iteration=500) self.api.models.update(model=model_id, task=task_id, iteration=500)
self.assertEqual( self.assertEqual(
self.api.tasks.get_by_id(task=task_id).task.last_iteration, self.api.tasks.get_by_id(task=task_id).task.last_iteration, 1000
1000
) )
def test_update_model_for_task_iteration(self): def test_update_model_for_task_iteration(self):
task_id = self._create_task() task_id = self._create_task()
res = self.api.models.update_for_task( res = self.api.models.update_for_task(
task=task_id, task=task_id, name="test model", uri="file:///b", iteration=999,
name="test model",
uri="file:///b",
iteration=999,
) )
model_id = res.id model_id = res.id
@ -150,22 +167,19 @@ class TestModelsService(TestService):
self.defer(self.api.models.delete, can_fail=True, model=model_id, force=True) self.defer(self.api.models.delete, can_fail=True, model=model_id, force=True)
self.assertEqual( self.assertEqual(
self.api.tasks.get_by_id(task=task_id).task.last_iteration, self.api.tasks.get_by_id(task=task_id).task.last_iteration, 999
999
) )
self.api.models.update_for_task(task=task_id, uri="file:///c", iteration=1000) self.api.models.update_for_task(task=task_id, uri="file:///c", iteration=1000)
self.assertEqual( self.assertEqual(
self.api.tasks.get_by_id(task=task_id).task.last_iteration, self.api.tasks.get_by_id(task=task_id).task.last_iteration, 1000
1000
) )
self.api.models.update_for_task(task=task_id, uri="file:///d", iteration=888) self.api.models.update_for_task(task=task_id, uri="file:///d", iteration=888)
self.assertEqual( self.assertEqual(
self.api.tasks.get_by_id(task=task_id).task.last_iteration, self.api.tasks.get_by_id(task=task_id).task.last_iteration, 1000
1000
) )
def test_get_frameworks(self): def test_get_frameworks(self):
@ -238,8 +252,8 @@ class TestModelsService(TestService):
return self.create_temp( return self.create_temp(
service="models", service="models",
delete_params=dict(can_fail=True, force=True), delete_params=dict(can_fail=True, force=True),
name=kwargs.pop("name", 'test'), name=kwargs.pop("name", "test"),
uri=kwargs.pop("name", 'file:///a'), uri=kwargs.pop("name", "file:///a"),
labels=kwargs.pop("labels", {}), labels=kwargs.pop("labels", {}),
**kwargs, **kwargs,
) )
@ -247,8 +261,8 @@ class TestModelsService(TestService):
def _create_task(self, **kwargs): def _create_task(self, **kwargs):
task_id = self.create_temp( task_id = self.create_temp(
service="tasks", service="tasks",
type=kwargs.pop("type", 'testing'), type=kwargs.pop("type", "testing"),
name=kwargs.pop("name", 'server-test'), name=kwargs.pop("name", "server-test"),
input=kwargs.pop("input", dict(view={})), input=kwargs.pop("input", dict(view={})),
**kwargs, **kwargs,
) )
@ -257,21 +271,18 @@ class TestModelsService(TestService):
def _create_task_and_model(self): def _create_task_and_model(self):
execution_model_id = self.create_temp( execution_model_id = self.create_temp(
service="models", service="models", name="test", uri="file:///a", labels={}
name='test',
uri='file:///a',
labels={}
) )
task_id = self.create_temp( task_id = self.create_temp(
service="tasks", service="tasks",
type='testing', type="testing",
name='server-test', name="server-test",
input=dict(view={}), input=dict(view={}),
execution=dict(model=execution_model_id) execution=dict(model=execution_model_id),
) )
self.api.tasks.started(task=task_id) self.api.tasks.started(task=task_id)
output_model_id = self.api.models.update_for_task( output_model_id = self.api.models.update_for_task(
task=task_id, uri='file:///b' task=task_id, uri="file:///b"
)["id"] )["id"]
return task_id, output_model_id return task_id, output_model_id