mirror of
https://github.com/clearml/clearml-server
synced 2025-01-31 10:56:48 +00:00
Add support for model statistics
This commit is contained in:
parent
0c15169668
commit
adc1825843
@ -75,3 +75,7 @@ class DeleteMetadataRequest(DeleteMetadata):
|
||||
|
||||
class AddOrUpdateMetadataRequest(AddOrUpdateMetadata):
|
||||
model = fields.StringField(required=True)
|
||||
|
||||
|
||||
class ModelsGetRequest(models.Base):
|
||||
include_stats = fields.BoolField(default=False)
|
||||
|
@ -1,5 +1,5 @@
|
||||
from datetime import datetime
|
||||
from typing import Callable, Tuple
|
||||
from typing import Callable, Tuple, Sequence, Dict
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.apimodels.models import ModelTaskPublishResponse
|
||||
@ -128,3 +128,33 @@ class ModelBLL:
|
||||
)
|
||||
|
||||
return unarchived
|
||||
|
||||
@classmethod
|
||||
def get_model_stats(
|
||||
cls, company: str, model_ids: Sequence[str],
|
||||
) -> Dict[str, dict]:
|
||||
if not model_ids:
|
||||
return {}
|
||||
|
||||
result = Model.aggregate(
|
||||
[
|
||||
{
|
||||
"$match": {
|
||||
"company": {"$in": [None, "", company]},
|
||||
"_id": {"$in": model_ids},
|
||||
}
|
||||
},
|
||||
{
|
||||
"$addFields": {
|
||||
"labels_count": {"$size": {"$objectToArray": "$labels"}}
|
||||
}
|
||||
},
|
||||
{
|
||||
"$project": {"labels_count": 1},
|
||||
},
|
||||
]
|
||||
)
|
||||
return {
|
||||
r.pop("_id"): r
|
||||
for r in result
|
||||
}
|
||||
|
@ -104,6 +104,16 @@ _definitions {
|
||||
"$ref": "#/definitions/metadata_item"
|
||||
}
|
||||
}
|
||||
stats {
|
||||
description: "Model statistics"
|
||||
type: object
|
||||
properties {
|
||||
labels_count {
|
||||
description: Number of the model labels
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
published_task_item {
|
||||
@ -224,6 +234,13 @@ get_all_ex {
|
||||
description: "Scroll ID that can be used with the next calls to get_all_ex to retrieve more data"
|
||||
}
|
||||
}
|
||||
"999.0": ${get_all_ex."2.15"} {
|
||||
request.properties.include_stats {
|
||||
description: "If true, include models statistic in response"
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
}
|
||||
}
|
||||
get_all {
|
||||
"2.1" {
|
||||
|
@ -20,6 +20,7 @@ from apiserver.apimodels.models import (
|
||||
AddOrUpdateMetadataRequest,
|
||||
ModelsPublishManyRequest,
|
||||
ModelsDeleteManyRequest,
|
||||
ModelsGetRequest,
|
||||
)
|
||||
from apiserver.bll.model import ModelBLL, Metadata
|
||||
from apiserver.bll.organization import OrgBLL, Tags
|
||||
@ -117,8 +118,8 @@ def _process_include_subprojects(call_data: dict):
|
||||
call_data["project"] = project_ids_with_children(project_ids)
|
||||
|
||||
|
||||
@endpoint("models.get_all_ex", required_fields=[])
|
||||
def get_all_ex(call: APICall, company_id, _):
|
||||
@endpoint("models.get_all_ex", request_data_model=ModelsGetRequest)
|
||||
def get_all_ex(call: APICall, company_id, request: ModelsGetRequest):
|
||||
conform_tag_fields(call, call.data)
|
||||
_process_include_subprojects(call.data)
|
||||
Metadata.escape_query_parameters(call)
|
||||
@ -132,6 +133,20 @@ def get_all_ex(call: APICall, company_id, _):
|
||||
)
|
||||
conform_output_tags(call, models)
|
||||
unescape_metadata(call, models)
|
||||
|
||||
if not request.include_stats:
|
||||
call.result.data = {"models": models, **ret_params}
|
||||
return
|
||||
|
||||
model_ids = {model["id"] for model in models}
|
||||
stats = ModelBLL.get_model_stats(
|
||||
company=company_id,
|
||||
model_ids=list(model_ids),
|
||||
)
|
||||
|
||||
for model in models:
|
||||
model["stats"] = stats.get(model["id"])
|
||||
|
||||
call.result.data = {"models": models, **ret_params}
|
||||
|
||||
|
||||
|
@ -77,7 +77,7 @@ class TestModelsService(TestService):
|
||||
|
||||
def test_publish_output_model_no_task(self):
|
||||
model_id = self.create_temp(
|
||||
service="models", name='test', uri='file:///a', labels={}, ready=False
|
||||
service="models", name="test", uri="file:///a", labels={}, ready=False
|
||||
)
|
||||
self._assert_model_ready(model_id, False)
|
||||
|
||||
@ -109,7 +109,7 @@ class TestModelsService(TestService):
|
||||
|
||||
def test_publish_task_no_output_model(self):
|
||||
task_id = self.create_temp(
|
||||
service="tasks", type='testing', name='server-test', input=dict(view={})
|
||||
service="tasks", type="testing", name="server-test", input=dict(view={})
|
||||
)
|
||||
self.api.tasks.started(task=task_id)
|
||||
self.api.tasks.stopped(task=task_id)
|
||||
@ -118,31 +118,48 @@ class TestModelsService(TestService):
|
||||
assert res.updated == 1 # model updated
|
||||
self._assert_task_status(task_id, PUBLISHED)
|
||||
|
||||
def test_get_models_stats(self):
|
||||
model1 = self._create_model(labels={"hello": 1, "world": 2})
|
||||
model2 = self._create_model(labels={"foo": 1})
|
||||
model3 = self._create_model()
|
||||
|
||||
# no stats
|
||||
res = self.api.models.get_all_ex(id=[model1, model2, model3]).models
|
||||
self.assertEqual(len(res), 3)
|
||||
self.assertTrue(all("stats" not in m for m in res))
|
||||
|
||||
# stats
|
||||
res = self.api.models.get_all_ex(
|
||||
id=[model1, model2, model3], include_stats=True
|
||||
).models
|
||||
self.assertEqual(len(res), 3)
|
||||
stats = {m.id: m.stats.labels_count for m in res}
|
||||
self.assertEqual(stats[model1], 2)
|
||||
self.assertEqual(stats[model2], 1)
|
||||
self.assertEqual(stats[model3], 0)
|
||||
|
||||
def test_update_model_iteration_with_task(self):
|
||||
task_id = self._create_task()
|
||||
model_id = self._create_model()
|
||||
self.api.models.update(model=model_id, task=task_id, iteration=1000, labels={"foo": 1})
|
||||
self.api.models.update(
|
||||
model=model_id, task=task_id, iteration=1000, labels={"foo": 1}
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
self.api.tasks.get_by_id(task=task_id).task.last_iteration,
|
||||
1000
|
||||
self.api.tasks.get_by_id(task=task_id).task.last_iteration, 1000
|
||||
)
|
||||
|
||||
self.api.models.update(model=model_id, task=task_id, iteration=500)
|
||||
|
||||
self.assertEqual(
|
||||
self.api.tasks.get_by_id(task=task_id).task.last_iteration,
|
||||
1000
|
||||
self.api.tasks.get_by_id(task=task_id).task.last_iteration, 1000
|
||||
)
|
||||
|
||||
def test_update_model_for_task_iteration(self):
|
||||
task_id = self._create_task()
|
||||
|
||||
res = self.api.models.update_for_task(
|
||||
task=task_id,
|
||||
name="test model",
|
||||
uri="file:///b",
|
||||
iteration=999,
|
||||
task=task_id, name="test model", uri="file:///b", iteration=999,
|
||||
)
|
||||
|
||||
model_id = res.id
|
||||
@ -150,22 +167,19 @@ class TestModelsService(TestService):
|
||||
self.defer(self.api.models.delete, can_fail=True, model=model_id, force=True)
|
||||
|
||||
self.assertEqual(
|
||||
self.api.tasks.get_by_id(task=task_id).task.last_iteration,
|
||||
999
|
||||
self.api.tasks.get_by_id(task=task_id).task.last_iteration, 999
|
||||
)
|
||||
|
||||
self.api.models.update_for_task(task=task_id, uri="file:///c", iteration=1000)
|
||||
|
||||
self.assertEqual(
|
||||
self.api.tasks.get_by_id(task=task_id).task.last_iteration,
|
||||
1000
|
||||
self.api.tasks.get_by_id(task=task_id).task.last_iteration, 1000
|
||||
)
|
||||
|
||||
self.api.models.update_for_task(task=task_id, uri="file:///d", iteration=888)
|
||||
|
||||
self.assertEqual(
|
||||
self.api.tasks.get_by_id(task=task_id).task.last_iteration,
|
||||
1000
|
||||
self.api.tasks.get_by_id(task=task_id).task.last_iteration, 1000
|
||||
)
|
||||
|
||||
def test_get_frameworks(self):
|
||||
@ -238,8 +252,8 @@ class TestModelsService(TestService):
|
||||
return self.create_temp(
|
||||
service="models",
|
||||
delete_params=dict(can_fail=True, force=True),
|
||||
name=kwargs.pop("name", 'test'),
|
||||
uri=kwargs.pop("name", 'file:///a'),
|
||||
name=kwargs.pop("name", "test"),
|
||||
uri=kwargs.pop("name", "file:///a"),
|
||||
labels=kwargs.pop("labels", {}),
|
||||
**kwargs,
|
||||
)
|
||||
@ -247,8 +261,8 @@ class TestModelsService(TestService):
|
||||
def _create_task(self, **kwargs):
|
||||
task_id = self.create_temp(
|
||||
service="tasks",
|
||||
type=kwargs.pop("type", 'testing'),
|
||||
name=kwargs.pop("name", 'server-test'),
|
||||
type=kwargs.pop("type", "testing"),
|
||||
name=kwargs.pop("name", "server-test"),
|
||||
input=kwargs.pop("input", dict(view={})),
|
||||
**kwargs,
|
||||
)
|
||||
@ -257,21 +271,18 @@ class TestModelsService(TestService):
|
||||
|
||||
def _create_task_and_model(self):
|
||||
execution_model_id = self.create_temp(
|
||||
service="models",
|
||||
name='test',
|
||||
uri='file:///a',
|
||||
labels={}
|
||||
service="models", name="test", uri="file:///a", labels={}
|
||||
)
|
||||
task_id = self.create_temp(
|
||||
service="tasks",
|
||||
type='testing',
|
||||
name='server-test',
|
||||
type="testing",
|
||||
name="server-test",
|
||||
input=dict(view={}),
|
||||
execution=dict(model=execution_model_id)
|
||||
execution=dict(model=execution_model_id),
|
||||
)
|
||||
self.api.tasks.started(task=task_id)
|
||||
output_model_id = self.api.models.update_for_task(
|
||||
task=task_id, uri='file:///b'
|
||||
task=task_id, uri="file:///b"
|
||||
)["id"]
|
||||
|
||||
return task_id, output_model_id
|
||||
|
Loading…
Reference in New Issue
Block a user