mirror of
https://github.com/clearml/clearml-server
synced 2025-05-09 22:31:14 +00:00
Add support for model statistics
This commit is contained in:
parent
0c15169668
commit
adc1825843
@ -75,3 +75,7 @@ class DeleteMetadataRequest(DeleteMetadata):
|
|||||||
|
|
||||||
class AddOrUpdateMetadataRequest(AddOrUpdateMetadata):
|
class AddOrUpdateMetadataRequest(AddOrUpdateMetadata):
|
||||||
model = fields.StringField(required=True)
|
model = fields.StringField(required=True)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelsGetRequest(models.Base):
|
||||||
|
include_stats = fields.BoolField(default=False)
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Callable, Tuple
|
from typing import Callable, Tuple, Sequence, Dict
|
||||||
|
|
||||||
from apiserver.apierrors import errors
|
from apiserver.apierrors import errors
|
||||||
from apiserver.apimodels.models import ModelTaskPublishResponse
|
from apiserver.apimodels.models import ModelTaskPublishResponse
|
||||||
@ -128,3 +128,33 @@ class ModelBLL:
|
|||||||
)
|
)
|
||||||
|
|
||||||
return unarchived
|
return unarchived
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_model_stats(
|
||||||
|
cls, company: str, model_ids: Sequence[str],
|
||||||
|
) -> Dict[str, dict]:
|
||||||
|
if not model_ids:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
result = Model.aggregate(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"$match": {
|
||||||
|
"company": {"$in": [None, "", company]},
|
||||||
|
"_id": {"$in": model_ids},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"$addFields": {
|
||||||
|
"labels_count": {"$size": {"$objectToArray": "$labels"}}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"$project": {"labels_count": 1},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
r.pop("_id"): r
|
||||||
|
for r in result
|
||||||
|
}
|
||||||
|
@ -104,6 +104,16 @@ _definitions {
|
|||||||
"$ref": "#/definitions/metadata_item"
|
"$ref": "#/definitions/metadata_item"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
stats {
|
||||||
|
description: "Model statistics"
|
||||||
|
type: object
|
||||||
|
properties {
|
||||||
|
labels_count {
|
||||||
|
description: Number of the model labels
|
||||||
|
type: integer
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
published_task_item {
|
published_task_item {
|
||||||
@ -224,6 +234,13 @@ get_all_ex {
|
|||||||
description: "Scroll ID that can be used with the next calls to get_all_ex to retrieve more data"
|
description: "Scroll ID that can be used with the next calls to get_all_ex to retrieve more data"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
"999.0": ${get_all_ex."2.15"} {
|
||||||
|
request.properties.include_stats {
|
||||||
|
description: "If true, include models statistic in response"
|
||||||
|
type: boolean
|
||||||
|
default: false
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
get_all {
|
get_all {
|
||||||
"2.1" {
|
"2.1" {
|
||||||
|
@ -20,6 +20,7 @@ from apiserver.apimodels.models import (
|
|||||||
AddOrUpdateMetadataRequest,
|
AddOrUpdateMetadataRequest,
|
||||||
ModelsPublishManyRequest,
|
ModelsPublishManyRequest,
|
||||||
ModelsDeleteManyRequest,
|
ModelsDeleteManyRequest,
|
||||||
|
ModelsGetRequest,
|
||||||
)
|
)
|
||||||
from apiserver.bll.model import ModelBLL, Metadata
|
from apiserver.bll.model import ModelBLL, Metadata
|
||||||
from apiserver.bll.organization import OrgBLL, Tags
|
from apiserver.bll.organization import OrgBLL, Tags
|
||||||
@ -117,8 +118,8 @@ def _process_include_subprojects(call_data: dict):
|
|||||||
call_data["project"] = project_ids_with_children(project_ids)
|
call_data["project"] = project_ids_with_children(project_ids)
|
||||||
|
|
||||||
|
|
||||||
@endpoint("models.get_all_ex", required_fields=[])
|
@endpoint("models.get_all_ex", request_data_model=ModelsGetRequest)
|
||||||
def get_all_ex(call: APICall, company_id, _):
|
def get_all_ex(call: APICall, company_id, request: ModelsGetRequest):
|
||||||
conform_tag_fields(call, call.data)
|
conform_tag_fields(call, call.data)
|
||||||
_process_include_subprojects(call.data)
|
_process_include_subprojects(call.data)
|
||||||
Metadata.escape_query_parameters(call)
|
Metadata.escape_query_parameters(call)
|
||||||
@ -132,6 +133,20 @@ def get_all_ex(call: APICall, company_id, _):
|
|||||||
)
|
)
|
||||||
conform_output_tags(call, models)
|
conform_output_tags(call, models)
|
||||||
unescape_metadata(call, models)
|
unescape_metadata(call, models)
|
||||||
|
|
||||||
|
if not request.include_stats:
|
||||||
|
call.result.data = {"models": models, **ret_params}
|
||||||
|
return
|
||||||
|
|
||||||
|
model_ids = {model["id"] for model in models}
|
||||||
|
stats = ModelBLL.get_model_stats(
|
||||||
|
company=company_id,
|
||||||
|
model_ids=list(model_ids),
|
||||||
|
)
|
||||||
|
|
||||||
|
for model in models:
|
||||||
|
model["stats"] = stats.get(model["id"])
|
||||||
|
|
||||||
call.result.data = {"models": models, **ret_params}
|
call.result.data = {"models": models, **ret_params}
|
||||||
|
|
||||||
|
|
||||||
|
@ -77,7 +77,7 @@ class TestModelsService(TestService):
|
|||||||
|
|
||||||
def test_publish_output_model_no_task(self):
|
def test_publish_output_model_no_task(self):
|
||||||
model_id = self.create_temp(
|
model_id = self.create_temp(
|
||||||
service="models", name='test', uri='file:///a', labels={}, ready=False
|
service="models", name="test", uri="file:///a", labels={}, ready=False
|
||||||
)
|
)
|
||||||
self._assert_model_ready(model_id, False)
|
self._assert_model_ready(model_id, False)
|
||||||
|
|
||||||
@ -109,7 +109,7 @@ class TestModelsService(TestService):
|
|||||||
|
|
||||||
def test_publish_task_no_output_model(self):
|
def test_publish_task_no_output_model(self):
|
||||||
task_id = self.create_temp(
|
task_id = self.create_temp(
|
||||||
service="tasks", type='testing', name='server-test', input=dict(view={})
|
service="tasks", type="testing", name="server-test", input=dict(view={})
|
||||||
)
|
)
|
||||||
self.api.tasks.started(task=task_id)
|
self.api.tasks.started(task=task_id)
|
||||||
self.api.tasks.stopped(task=task_id)
|
self.api.tasks.stopped(task=task_id)
|
||||||
@ -118,31 +118,48 @@ class TestModelsService(TestService):
|
|||||||
assert res.updated == 1 # model updated
|
assert res.updated == 1 # model updated
|
||||||
self._assert_task_status(task_id, PUBLISHED)
|
self._assert_task_status(task_id, PUBLISHED)
|
||||||
|
|
||||||
|
def test_get_models_stats(self):
|
||||||
|
model1 = self._create_model(labels={"hello": 1, "world": 2})
|
||||||
|
model2 = self._create_model(labels={"foo": 1})
|
||||||
|
model3 = self._create_model()
|
||||||
|
|
||||||
|
# no stats
|
||||||
|
res = self.api.models.get_all_ex(id=[model1, model2, model3]).models
|
||||||
|
self.assertEqual(len(res), 3)
|
||||||
|
self.assertTrue(all("stats" not in m for m in res))
|
||||||
|
|
||||||
|
# stats
|
||||||
|
res = self.api.models.get_all_ex(
|
||||||
|
id=[model1, model2, model3], include_stats=True
|
||||||
|
).models
|
||||||
|
self.assertEqual(len(res), 3)
|
||||||
|
stats = {m.id: m.stats.labels_count for m in res}
|
||||||
|
self.assertEqual(stats[model1], 2)
|
||||||
|
self.assertEqual(stats[model2], 1)
|
||||||
|
self.assertEqual(stats[model3], 0)
|
||||||
|
|
||||||
def test_update_model_iteration_with_task(self):
|
def test_update_model_iteration_with_task(self):
|
||||||
task_id = self._create_task()
|
task_id = self._create_task()
|
||||||
model_id = self._create_model()
|
model_id = self._create_model()
|
||||||
self.api.models.update(model=model_id, task=task_id, iteration=1000, labels={"foo": 1})
|
self.api.models.update(
|
||||||
|
model=model_id, task=task_id, iteration=1000, labels={"foo": 1}
|
||||||
|
)
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
self.api.tasks.get_by_id(task=task_id).task.last_iteration,
|
self.api.tasks.get_by_id(task=task_id).task.last_iteration, 1000
|
||||||
1000
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.api.models.update(model=model_id, task=task_id, iteration=500)
|
self.api.models.update(model=model_id, task=task_id, iteration=500)
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
self.api.tasks.get_by_id(task=task_id).task.last_iteration,
|
self.api.tasks.get_by_id(task=task_id).task.last_iteration, 1000
|
||||||
1000
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_update_model_for_task_iteration(self):
|
def test_update_model_for_task_iteration(self):
|
||||||
task_id = self._create_task()
|
task_id = self._create_task()
|
||||||
|
|
||||||
res = self.api.models.update_for_task(
|
res = self.api.models.update_for_task(
|
||||||
task=task_id,
|
task=task_id, name="test model", uri="file:///b", iteration=999,
|
||||||
name="test model",
|
|
||||||
uri="file:///b",
|
|
||||||
iteration=999,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
model_id = res.id
|
model_id = res.id
|
||||||
@ -150,22 +167,19 @@ class TestModelsService(TestService):
|
|||||||
self.defer(self.api.models.delete, can_fail=True, model=model_id, force=True)
|
self.defer(self.api.models.delete, can_fail=True, model=model_id, force=True)
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
self.api.tasks.get_by_id(task=task_id).task.last_iteration,
|
self.api.tasks.get_by_id(task=task_id).task.last_iteration, 999
|
||||||
999
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.api.models.update_for_task(task=task_id, uri="file:///c", iteration=1000)
|
self.api.models.update_for_task(task=task_id, uri="file:///c", iteration=1000)
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
self.api.tasks.get_by_id(task=task_id).task.last_iteration,
|
self.api.tasks.get_by_id(task=task_id).task.last_iteration, 1000
|
||||||
1000
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.api.models.update_for_task(task=task_id, uri="file:///d", iteration=888)
|
self.api.models.update_for_task(task=task_id, uri="file:///d", iteration=888)
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
self.api.tasks.get_by_id(task=task_id).task.last_iteration,
|
self.api.tasks.get_by_id(task=task_id).task.last_iteration, 1000
|
||||||
1000
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_get_frameworks(self):
|
def test_get_frameworks(self):
|
||||||
@ -238,8 +252,8 @@ class TestModelsService(TestService):
|
|||||||
return self.create_temp(
|
return self.create_temp(
|
||||||
service="models",
|
service="models",
|
||||||
delete_params=dict(can_fail=True, force=True),
|
delete_params=dict(can_fail=True, force=True),
|
||||||
name=kwargs.pop("name", 'test'),
|
name=kwargs.pop("name", "test"),
|
||||||
uri=kwargs.pop("name", 'file:///a'),
|
uri=kwargs.pop("name", "file:///a"),
|
||||||
labels=kwargs.pop("labels", {}),
|
labels=kwargs.pop("labels", {}),
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
@ -247,8 +261,8 @@ class TestModelsService(TestService):
|
|||||||
def _create_task(self, **kwargs):
|
def _create_task(self, **kwargs):
|
||||||
task_id = self.create_temp(
|
task_id = self.create_temp(
|
||||||
service="tasks",
|
service="tasks",
|
||||||
type=kwargs.pop("type", 'testing'),
|
type=kwargs.pop("type", "testing"),
|
||||||
name=kwargs.pop("name", 'server-test'),
|
name=kwargs.pop("name", "server-test"),
|
||||||
input=kwargs.pop("input", dict(view={})),
|
input=kwargs.pop("input", dict(view={})),
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
@ -257,21 +271,18 @@ class TestModelsService(TestService):
|
|||||||
|
|
||||||
def _create_task_and_model(self):
|
def _create_task_and_model(self):
|
||||||
execution_model_id = self.create_temp(
|
execution_model_id = self.create_temp(
|
||||||
service="models",
|
service="models", name="test", uri="file:///a", labels={}
|
||||||
name='test',
|
|
||||||
uri='file:///a',
|
|
||||||
labels={}
|
|
||||||
)
|
)
|
||||||
task_id = self.create_temp(
|
task_id = self.create_temp(
|
||||||
service="tasks",
|
service="tasks",
|
||||||
type='testing',
|
type="testing",
|
||||||
name='server-test',
|
name="server-test",
|
||||||
input=dict(view={}),
|
input=dict(view={}),
|
||||||
execution=dict(model=execution_model_id)
|
execution=dict(model=execution_model_id),
|
||||||
)
|
)
|
||||||
self.api.tasks.started(task=task_id)
|
self.api.tasks.started(task=task_id)
|
||||||
output_model_id = self.api.models.update_for_task(
|
output_model_id = self.api.models.update_for_task(
|
||||||
task=task_id, uri='file:///b'
|
task=task_id, uri="file:///b"
|
||||||
)["id"]
|
)["id"]
|
||||||
|
|
||||||
return task_id, output_model_id
|
return task_id, output_model_id
|
||||||
|
Loading…
Reference in New Issue
Block a user