from datetime import datetime, timedelta from enum import Enum, auto from operator import attrgetter from time import time from typing import Optional, Sequence, Union import attr from boltons.iterutils import chunked_iter, bucketize from pyhocon import ConfigTree from apiserver.apimodels.serving import ( ServingContainerEntry, RegisterRequest, StatusReportRequest, ) from apiserver.apierrors import errors from apiserver.config_repo import config from apiserver.redis_manager import redman from .stats import ServingStats log = config.logger(__file__) class ServingBLL: def __init__(self, redis=None): self.conf = config.get("services.serving", ConfigTree()) self.redis = redis or redman.connection("workers") @staticmethod def _get_url_key(company: str, url: str): return f"serving_url_{company}_{url}" @staticmethod def _get_container_key(company: str, container_id: str) -> str: """Build redis key from company and container_id""" return f"serving_container_{company}_{container_id}" def _save_serving_container_entry(self, entry: ServingContainerEntry): self.redis.setex( entry.key, timedelta(seconds=entry.register_timeout), entry.to_json() ) url_key = self._get_url_key(entry.company_id, entry.endpoint_url) expiration = int(time()) + entry.register_timeout container_item = {entry.key: expiration} self.redis.zadd(url_key, container_item) # make sure that url set will not get stuck in redis # indefinitely in case no more containers report to it self.redis.expire(url_key, max(3600, entry.register_timeout)) def _get_serving_container_entry( self, company_id: str, container_id: str ) -> Optional[ServingContainerEntry]: """ Get a container entry for the provided container ID. """ key = self._get_container_key(company_id, container_id) data = self.redis.get(key) if not data: return try: entry = ServingContainerEntry.from_json(data) return entry except Exception as e: msg = "Failed parsing container entry" log.exception(f"{msg}: {str(e)}") def register_serving_container( self, company_id: str, request: RegisterRequest, ip: str = "", ) -> ServingContainerEntry: """ Register a serving container """ now = datetime.utcnow() key = self._get_container_key(company_id, request.container_id) entry = ServingContainerEntry( **request.to_struct(), key=key, company_id=company_id, ip=ip, register_time=now, register_timeout=request.timeout, last_activity_time=now, ) self._save_serving_container_entry(entry) return entry def unregister_serving_container( self, company_id: str, container_id: str, ) -> None: """ Unregister a serving container """ entry = self._get_serving_container_entry(company_id, container_id) if entry: url_key = self._get_url_key(entry.company_id, entry.endpoint_url) self.redis.zrem(url_key, entry.key) key = self._get_container_key(company_id, container_id) res = self.redis.delete(key) if res: return if not self.conf.get("container_auto_unregister", True): raise errors.bad_request.ContainerNotRegistered(container=container_id) def container_status_report( self, company_id: str, report: StatusReportRequest, ip: str = "", ) -> None: """ Serving container status report """ container_id = report.container_id now = datetime.utcnow() entry = self._get_serving_container_entry(company_id, container_id) if entry: ip = ip or entry.ip register_time = entry.register_time register_timeout = entry.register_timeout else: if not self.conf.get("container_auto_register", True): raise errors.bad_request.ContainerNotRegistered(container=container_id) ip = ip register_time = now register_timeout = int( self.conf.get("default_container_timeout_sec", 10 * 60) ) key = self._get_container_key(company_id, container_id) entry = ServingContainerEntry( **report.to_struct(), key=key, company_id=company_id, ip=ip, register_time=register_time, register_timeout=register_timeout, last_activity_time=now, ) self._save_serving_container_entry(entry) ServingStats.log_stats_to_es(entry) def _get_all( self, company_id: str, ) -> Sequence[ServingContainerEntry]: keys = list(self.redis.scan_iter(self._get_container_key(company_id, "*"))) entries = [] for keys in chunked_iter(keys, 1000): data = self.redis.mget(keys) if not data: continue for d in data: try: entries.append(ServingContainerEntry.from_json(d)) except Exception as ex: log.error(f"Failed parsing container entry {str(ex)}") return entries @attr.s(auto_attribs=True) class Counter: class AggType(Enum): avg = auto() max = auto() total = auto() count = auto() name: str field: str agg_type: AggType float_precision: int = None _max: Union[int, float, datetime] = attr.field(init=False, default=None) _total: Union[int, float] = attr.field(init=False, default=0) _count: int = attr.field(init=False, default=0) def add(self, entry: ServingContainerEntry): value = getattr(entry, self.field, None) if value is None: return self._count += 1 if self.agg_type == self.AggType.max: self._max = value if self._max is None else max(self._max, value) else: self._total += value def __call__(self): if self.agg_type == self.AggType.count: return self._count if self.agg_type == self.AggType.max: return self._max if self.agg_type == self.AggType.total: return self._total if not self._count: return None avg = self._total / self._count return ( round(avg, self.float_precision) if self.float_precision else round(avg) ) def _get_summary(self, entries: Sequence[ServingContainerEntry]) -> dict: counters = [ self.Counter( name="uptime_sec", field="uptime_sec", agg_type=self.Counter.AggType.max, ), self.Counter( name="requests", field="requests_num", agg_type=self.Counter.AggType.total, ), self.Counter( name="requests_min", field="requests_min", agg_type=self.Counter.AggType.avg, float_precision=2, ), self.Counter( name="latency_ms", field="latency_ms", agg_type=self.Counter.AggType.avg, ), self.Counter( name="last_update", field="last_activity_time", agg_type=self.Counter.AggType.max, ), ] for entry in entries: for counter in counters: counter.add(entry) first_entry = entries[0] ret = { "endpoint": first_entry.endpoint_name, "model": first_entry.model_name, "url": first_entry.endpoint_url, "instances": len(entries), **{counter.name: counter() for counter in counters}, } ret["last_update"] = self._naive_time(ret.get("last_update")) return ret def get_endpoints(self, company_id: str): """ Group instances by urls and return a summary for each url Do not return data for "loading" instances that have no url """ entries = self._get_all(company_id) by_url = bucketize(entries, key=attrgetter("endpoint_url")) by_url.pop(None, None) return [self._get_summary(url_entries) for url_entries in by_url.values()] def _get_endpoint_entries( self, company_id, endpoint_url: Union[str, None] ) -> Sequence[ServingContainerEntry]: url_key = self._get_url_key(company_id, endpoint_url) timestamp = int(time()) self.redis.zremrangebyscore(url_key, min=0, max=timestamp) container_keys = {key.decode() for key in self.redis.zrange(url_key, 0, -1)} if not container_keys: return [] entries = [] found_keys = set() data = self.redis.mget(container_keys) or [] for d in data: try: entry = ServingContainerEntry.from_json(d) if entry.endpoint_url == endpoint_url: entries.append(entry) found_keys.add(entry.key) except Exception as ex: log.error(f"Failed parsing container entry {str(ex)}") missing_keys = container_keys - found_keys if missing_keys: self.redis.zrem(url_key, *missing_keys) return entries def get_loading_instances(self, company_id: str): entries = self._get_endpoint_entries(company_id, None) return [ { "id": entry.container_id, "endpoint": entry.endpoint_name, "url": entry.endpoint_url, "model": entry.model_name, "model_source": entry.model_source, "model_version": entry.model_version, "preprocess_artifact": entry.preprocess_artifact, "input_type": entry.input_type, "input_size": entry.input_size, "uptime_sec": entry.uptime_sec, "last_update": self._naive_time(entry.last_activity_time), } for entry in entries ] @staticmethod def _naive_time(input_: datetime) -> datetime: if not isinstance(input_, datetime): return input_ return input_.replace(tzinfo=None) def get_endpoint_details(self, company_id, endpoint_url: str) -> dict: entries = self._get_endpoint_entries(company_id, endpoint_url) if not entries: raise errors.bad_request.NoContainersForUrl(url=endpoint_url) instances = [] entry: ServingContainerEntry for entry in entries: instances.append( { "endpoint": entry.endpoint_name, "model": entry.model_name, "url": entry.endpoint_url, } ) first_entry = entries[0] return { "endpoint": first_entry.endpoint_name, "model": first_entry.model_name, "url": first_entry.endpoint_url, "preprocess_artifact": first_entry.preprocess_artifact, "input_type": first_entry.input_type, "input_size": first_entry.input_size, "model_source": first_entry.model_source, "model_version": first_entry.model_version, "uptime_sec": max(e.uptime_sec for e in entries), "last_update": self._naive_time(max(e.last_activity_time for e in entries)), "instances": [ { "id": entry.container_id, "uptime_sec": entry.uptime_sec, "requests": entry.requests_num, "requests_min": entry.requests_min, "latency_ms": entry.latency_ms, "last_update": self._naive_time(entry.last_activity_time), "reference": [ref.to_struct() for ref in entry.reference] if isinstance(entry.reference, list) else entry.reference, } for entry in entries ], }