Add reference field to serving models

This commit is contained in:
clearml 2024-12-05 22:24:18 +02:00
parent 0b61ec2a56
commit 77e7fb5c13
4 changed files with 65 additions and 13 deletions

View File

@ -1,4 +1,5 @@
from enum import Enum from enum import Enum
from typing import Sequence
from jsonmodels.models import Base from jsonmodels.models import Base
from jsonmodels.fields import ( from jsonmodels.fields import (
@ -6,8 +7,10 @@ from jsonmodels.fields import (
EmbeddedField, EmbeddedField,
DateTimeField, DateTimeField,
IntField, IntField,
FloatField, BoolField, FloatField,
BoolField,
) )
from jsonmodels import validators
from jsonmodels.validators import Min from jsonmodels.validators import Min
from apiserver.apimodels import ListField, JsonSerializableMixin from apiserver.apimodels import ListField, JsonSerializableMixin
@ -16,6 +19,14 @@ from apiserver.config_repo import config
from .workers import MachineStats from .workers import MachineStats
class ReferenceItem(Base):
    """One typed reference entry carried in a serving model's ``reference`` list.

    Both fields are mandatory. ``type`` is restricted to the fixed set of
    kinds listed in its validator; ``value`` is a plain string.
    """

    # Kind of the referenced entity; only the values listed here pass validation
    type = StringField(
        validators=validators.Enum(
            "app_id",
            "app_instance",
            "model",
            "task",
            "url",
        ),
        required=True,
    )
    # The reference value for the given kind
    value = StringField(required=True)
class ServingModel(Base): class ServingModel(Base):
container_id = StringField(required=True) container_id = StringField(required=True)
endpoint_name = StringField(required=True) endpoint_name = StringField(required=True)
@ -28,12 +39,15 @@ class ServingModel(Base):
input_size = IntField() input_size = IntField()
tags = ListField(str) tags = ListField(str)
system_tags = ListField(str) system_tags = ListField(str)
reference: Sequence[ReferenceItem] = ListField(ReferenceItem)
class RegisterRequest(ServingModel): class RegisterRequest(ServingModel):
timeout = IntField( timeout = IntField(
default=int(config.get("services.serving.default_container_timeout_sec", 10 * 60)), default=int(
validators=[Min(1)] config.get("services.serving.default_container_timeout_sec", 10 * 60)
),
validators=[Min(1)],
) )
""" registration timeout in seconds (default is 10min) """ """ registration timeout in seconds (default is 10min) """
@ -84,7 +98,5 @@ class GetEndpointMetricsHistoryRequest(Base):
to_date = FloatField(required=True, validators=Min(0)) to_date = FloatField(required=True, validators=Min(0))
interval = IntField(required=True, validators=Min(1)) interval = IntField(required=True, validators=Min(1))
endpoint_url = StringField(required=True) endpoint_url = StringField(required=True)
metric_type = ActualEnumField( metric_type = ActualEnumField(MetricType, default=MetricType.requests)
MetricType, default=MetricType.requests
)
instance_charts = BoolField(default=True) instance_charts = BoolField(default=True)

View File

@ -207,7 +207,9 @@ class ServingBLL:
if not self._count: if not self._count:
return None return None
avg = self._total / self._count avg = self._total / self._count
return round(avg, self.float_precision) if self.float_precision else round(avg) return (
round(avg, self.float_precision) if self.float_precision else round(avg)
)
def _get_summary(self, entries: Sequence[ServingContainerEntry]) -> dict: def _get_summary(self, entries: Sequence[ServingContainerEntry]) -> dict:
counters = [ counters = [
@ -263,7 +265,9 @@ class ServingBLL:
by_url.pop(None, None) by_url.pop(None, None)
return [self._get_summary(url_entries) for url_entries in by_url.values()] return [self._get_summary(url_entries) for url_entries in by_url.values()]
def _get_endpoint_entries(self, company_id, endpoint_url: Union[str, None]) -> Sequence[ServingContainerEntry]: def _get_endpoint_entries(
self, company_id, endpoint_url: Union[str, None]
) -> Sequence[ServingContainerEntry]:
url_key = self._get_url_key(company_id, endpoint_url) url_key = self._get_url_key(company_id, endpoint_url)
timestamp = int(time()) timestamp = int(time())
self.redis.zremrangebyscore(url_key, min=0, max=timestamp) self.redis.zremrangebyscore(url_key, min=0, max=timestamp)
@ -328,7 +332,6 @@ class ServingBLL:
"endpoint": entry.endpoint_name, "endpoint": entry.endpoint_name,
"model": entry.model_name, "model": entry.model_name,
"url": entry.endpoint_url, "url": entry.endpoint_url,
} }
) )
@ -352,7 +355,10 @@ class ServingBLL:
"requests_min": entry.requests_min, "requests_min": entry.requests_min,
"latency_ms": entry.latency_ms, "latency_ms": entry.latency_ms,
"last_update": self._naive_time(entry.last_activity_time), "last_update": self._naive_time(entry.last_activity_time),
"reference": [ref.to_struct() for ref in entry.reference]
if isinstance(entry.reference, list)
else entry.reference,
} }
for entry in entries for entry in entries
] ],
} }

View File

@ -1,13 +1,33 @@
_description: "Serving apis" _description: "Serving apis"
_definitions { _definitions {
include "_workers_common.conf" include "_workers_common.conf"
# Schema of a single reference entry: an object whose `type` (one of the
# enum values listed below) and string `value` are both required.
reference_item {
type: object
required = [type, value]
properties {
type {
description: The type of the reference item
type: string
enum: [app_id, app_instance, model, task, url]
}
value {
description: The reference item value
type: string
}
}
}
# Shared array-of-reference_item schema. Other definitions in this file pull
# it in through the ${_definitions.reference} substitution.
reference {
description: Array of reference items provided by the container instance. Can contain multiple reference items with the same type
type: array
items: ${_definitions.reference_item}
}
serving_model_report { serving_model_report {
type: object type: object
required: [container_id, endpoint_name, model_name] required: [container_id, endpoint_name, model_name]
properties { properties {
container_id { container_id {
type: string type: string
description: Container ID description: Container ID. Should uniquely identify a specific container instance
} }
endpoint_name { endpoint_name {
type: string type: string
@ -41,6 +61,7 @@ _definitions {
type: integer type: integer
description: Input size in bytes description: Input size in bytes
} }
reference: ${_definitions.reference}
} }
} }
endpoint_stats { endpoint_stats {
@ -113,6 +134,8 @@ _definitions {
format: "date-time" format: "date-time"
description: The latest time when the container instance sent update description: The latest time when the container instance sent update
} }
reference: ${_definitions.reference}
} }
} }
serving_model_info { serving_model_info {

View File

@ -9,6 +9,12 @@ class TestServing(TestService):
container_id1 = "container_1" container_id1 = "container_1"
container_id2 = "container_2" container_id2 = "container_2"
url = "http://test_url" url = "http://test_url"
reference = [
{"type": "app_id", "value": "test"},
{"type": "app_instance", "value": "abd478c8"},
{"type": "model", "value": "262829d3"},
{"type": "model", "value": "7ea29c04"},
]
container_infos = [ container_infos = [
{ {
"container_id": container_id, # required "container_id": container_id, # required
@ -22,6 +28,7 @@ class TestServing(TestService):
"input_size": 9_000_000, # optional right now, bytes "input_size": 9_000_000, # optional right now, bytes
"tags": ["tag1", "tag2"], # optional "tags": ["tag1", "tag2"], # optional
"system_tags": None, # optional "system_tags": None, # optional
**({"reference": reference} if container_id == container_id1 else {}),
} }
for container_id in (container_id1, container_id2) for container_id in (container_id1, container_id2)
] ]
@ -61,11 +68,15 @@ class TestServing(TestService):
"requests", "requests",
"requests_min", "requests_min",
"latency_ms", "latency_ms",
"reference",
) )
] ]
for inst in details.instances for inst in details.instances
}, },
{"container_1": [1000, 1000, 5, 100], "container_2": [2000, 2000, 10, 200]}, {
"container_1": [1000, 1000, 5, 100, reference],
"container_2": [2000, 2000, 10, 200, []],
},
) )
# make sure that the first call did not invalidate anything # make sure that the first call did not invalidate anything
new_details = self.api.serving.get_endpoint_details(endpoint_url=url) new_details = self.api.serving.get_endpoint_details(endpoint_url=url)
@ -93,7 +104,7 @@ class TestServing(TestService):
self.assertEqual(res.computed_interval, 40) self.assertEqual(res.computed_interval, 40)
self.assertEqual(res.total.title, title) self.assertEqual(res.total.title, title)
length = len(res.total.dates) length = len(res.total.dates)
self.assertTrue(3>=length>=1) self.assertTrue(3 >= length >= 1)
self.assertEqual(len(res.total["values"]), length) self.assertEqual(len(res.total["values"]), length)
self.assertIn(value, res.total["values"]) self.assertIn(value, res.total["values"])
self.assertEqual(set(res.instances), {container_id1, container_id2}) self.assertEqual(set(res.instances), {container_id1, container_id2})