# Mirror of https://github.com/clearml/clearml-server
import abc
import operator
from operator import attrgetter
from typing import Sequence, Tuple, Optional, Callable, Mapping

import attr
from boltons.iterutils import first, bucketize
from elasticsearch import Elasticsearch
from jsonmodels.fields import StringField, ListField, IntField, BoolField
from jsonmodels.models import Base
from redis import StrictRedis

from apiserver.apierrors import errors
from apiserver.apimodels import JsonSerializableMixin
from apiserver.bll.event.event_common import (
    EventSettings,
    EventType,
    check_empty_data,
    search_company_events,
    get_max_metric_and_variant_counts,
)
from apiserver.bll.redis_cache_manager import RedisCacheManager
from apiserver.database.errors import translate_errors_context
from apiserver.timing_context import TimingContext
from apiserver.utilities.dicts import nested_get


class VariantState(Base):
    name: str = StringField(required=True)
    metric: str = StringField(default=None)
    min_iteration: int = IntField()
    max_iteration: int = IntField()


class HistorySampleState(Base, JsonSerializableMixin):
    id: str = StringField(required=True)
    iteration: int = IntField()
    variant: str = StringField()
    task: str = StringField()
    metric: str = StringField()
    reached_first: bool = BoolField()
    reached_last: bool = BoolField()
    variant_states: Sequence[VariantState] = ListField([VariantState])
    warning: str = StringField()
    navigate_current_metric = BoolField(default=True)


@attr.s(auto_attribs=True)
class HistorySampleResult(object):
    scroll_id: str = None
    event: dict = None
    min_iteration: int = None
    max_iteration: int = None


class HistorySampleIterator(abc.ABC):
    def __init__(self, redis: StrictRedis, es: Elasticsearch, event_type: EventType):
        self.es = es
        self.event_type = event_type
        self.cache_manager = RedisCacheManager(
            state_class=HistorySampleState,
            redis=redis,
            expiration_interval=EventSettings.state_expiration_sec,
        )
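
    # Navigation state is kept server-side: each scroll id maps to a HistorySampleState
    # cached in Redis (expiring after EventSettings.state_expiration_sec), so successive
    # calls can resume navigation from the last returned event.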
    def get_next_sample(
        self, company_id: str, task: str, state_id: str, navigate_earlier: bool
    ) -> HistorySampleResult:
        """
        Get the sample for the next/previous variant on the current iteration.
        If none exists, fall back to the sample of the first/last variant
        of the next/previous iteration.
        """
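        # Illustrative order (metric/variant names are made up): with variants
        # "a" < "b" < "c" reported for iteration 7, navigating later from ("m1", "b")
        # first tries ("m1", "c") at iteration 7 and only then falls back to the
        # first variant of the closest iteration > 7.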
        res = HistorySampleResult(scroll_id=state_id)
        state = self.cache_manager.get_state(state_id)
        if not state or state.task != task:
            raise errors.bad_request.InvalidScrollId(scroll_id=state_id)

        if check_empty_data(self.es, company_id=company_id, event_type=self.event_type):
            return res

        event = self._get_next_for_current_iteration(
            company_id=company_id, navigate_earlier=navigate_earlier, state=state
        ) or self._get_next_for_another_iteration(
            company_id=company_id, navigate_earlier=navigate_earlier, state=state
        )
        if not event:
            return res

        self._fill_res_and_update_state(event=event, res=res, state=state)
        self.cache_manager.set_state(state=state)
        return res

    def _fill_res_and_update_state(
        self, event: dict, res: HistorySampleResult, state: HistorySampleState
    ):
        self._process_event(event)
        state.variant = event["variant"]
        state.metric = event["metric"]
        state.iteration = event["iter"]
        res.event = event
        var_state = first(
            vs
            for vs in state.variant_states
            if vs.name == state.variant and vs.metric == state.metric
        )
        if var_state:
            res.min_iteration = var_state.min_iteration
            res.max_iteration = var_state.max_iteration
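
    # Subclass hooks: concrete iterators supply the event-type specific parts of the
    # Elasticsearch queries (extra filter conditions, per-variant conditions) and any
    # post-processing of a returned event document.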
    @abc.abstractmethod
    def _get_extra_conditions(self) -> Sequence[dict]:
        pass

    @abc.abstractmethod
    def _process_event(self, event: dict) -> dict:
        pass

    @abc.abstractmethod
    def _get_variants_conditions(self, variants: Sequence[VariantState]) -> dict:
        pass
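
    # The condition below matches events belonging to any of the requested
    # (metric, variants) groups. For variant states bucketed under metrics "m1" and
    # "m2" (illustrative names; <variants condition> comes from the subclass) the
    # produced filter is roughly:
    #   {"bool": {"should": [
    #       {"bool": {"must": [{"term": {"metric": "m1"}}, <variants condition>]}},
    #       {"bool": {"must": [{"term": {"metric": "m2"}}, <variants condition>]}},
    #   ]}}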
    def _get_metric_variants_condition(self, variants: Sequence[VariantState]) -> dict:
        metrics = bucketize(variants, key=attrgetter("metric"))
        metrics_conditions = [
            {
                "bool": {
                    "must": [
                        {"term": {"metric": metric}},
                        self._get_variants_conditions(vs),
                    ]
                }
            }
            for metric, vs in metrics.items()
        ]
        return {"bool": {"should": metrics_conditions}}

    def _get_next_for_current_iteration(
        self, company_id: str, navigate_earlier: bool, state: HistorySampleState
    ) -> Optional[dict]:
        """
        Get the sample for the next (if navigate_earlier is False) or previous variant,
        ordered by metric and variant name, within the same iteration.
        Only variants whose valid iteration range includes the current iteration are considered.
        Return None if no such variant or sample is found.
        """
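        # Candidates are compared lexicographically as (metric, variant) tuples: only
        # variants strictly after (navigate later) or before (navigate earlier) the
        # current position, and whose range has started by this iteration, are queried.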
        if state.navigate_current_metric:
            variants = [
                var_state
                for var_state in state.variant_states
                if var_state.metric == state.metric
            ]
        else:
            variants = state.variant_states

        cmp = operator.lt if navigate_earlier else operator.gt
        variants = [
            var_state
            for var_state in variants
            if cmp((var_state.metric, var_state.name), (state.metric, state.variant))
            and var_state.min_iteration <= state.iteration
        ]
        if not variants:
            return None

        must_conditions = [
            {"term": {"task": state.task}},
            {"term": {"iter": state.iteration}},
            self._get_metric_variants_condition(variants),
            *self._get_extra_conditions(),
        ]
        order = "desc" if navigate_earlier else "asc"
        es_req = {
            "size": 1,
            "sort": [{"metric": order}, {"variant": order}],
            "query": {"bool": {"must": must_conditions}},
        }

        with translate_errors_context(), TimingContext(
            "es", "get_next_for_current_iteration"
        ):
            es_res = search_company_events(
                self.es,
                company_id=company_id,
                event_type=self.event_type,
                body=es_req,
            )

        hits = nested_get(es_res, ("hits", "hits"))
        if not hits:
            return None

        return hits[0]["_source"]

    def _get_next_for_another_iteration(
        self, company_id: str, navigate_earlier: bool, state: HistorySampleState
    ) -> Optional[dict]:
        """
        Get the sample from the first variant of the next iteration (if navigate_earlier
        is False) or from the last variant of the previous iteration (otherwise).
        Variants whose valid iteration range cannot contain such a sample are discarded.
        Return None if no suitable sample is found.
        """
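        # e.g. when navigating earlier from iteration 12, the query is restricted to
        # iter < 12 and sorted descending, so the hit comes from the closest previous
        # iteration, taking its last (metric, variant) in sort order.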
        if state.navigate_current_metric:
            variants = [
                var_state
                for var_state in state.variant_states
                if var_state.metric == state.metric
            ]
        else:
            variants = state.variant_states

        if navigate_earlier:
            range_operator = "lt"
            order = "desc"
            variants = [
                var_state
                for var_state in variants
                if var_state.min_iteration < state.iteration
            ]
        else:
            range_operator = "gt"
            order = "asc"

        if not variants:
            return None

        must_conditions = [
            {"term": {"task": state.task}},
            self._get_metric_variants_condition(variants),
            {"range": {"iter": {range_operator: state.iteration}}},
            *self._get_extra_conditions(),
        ]
        es_req = {
            "size": 1,
            "sort": [{"iter": order}, {"metric": order}, {"variant": order}],
            "query": {"bool": {"must": must_conditions}},
        }
        with translate_errors_context(), TimingContext(
            "es", "get_next_for_another_iteration"
        ):
            es_res = search_company_events(
                self.es,
                company_id=company_id,
                event_type=self.event_type,
                body=es_req,
            )

        hits = nested_get(es_res, ("hits", "hits"))
        if not hits:
            return None

        return hits[0]["_source"]

    def get_sample_for_variant(
        self,
        company_id: str,
        task: str,
        metric: str,
        variant: str,
        iteration: Optional[int] = None,
        refresh: bool = False,
        state_id: str = None,
        navigate_current_metric: bool = True,
    ) -> HistorySampleResult:
        """
        Get the sample for the requested iteration, or the latest one before it.
        If no iteration is passed, return the latest event.
        """
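        # Typical flow (identifiers such as "loss"/"train" are illustrative): the first
        # call, e.g. get_sample_for_variant(company_id, task_id, metric="loss",
        # variant="train"), returns the latest event together with a scroll_id; later
        # navigation passes that scroll_id to get_next_sample().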
        res = HistorySampleResult()
        if check_empty_data(self.es, company_id=company_id, event_type=self.event_type):
            return res

        def init_state(state_: HistorySampleState):
            state_.task = task
            state_.metric = metric
            state_.navigate_current_metric = navigate_current_metric
            self._reset_variant_states(company_id=company_id, state=state_)

        def validate_state(state_: HistorySampleState):
            if (
                state_.task != task
                or state_.navigate_current_metric != navigate_current_metric
                or (state_.navigate_current_metric and state_.metric != metric)
            ):
                raise errors.bad_request.InvalidScrollId(
                    "Task and metric stored in the state do not match the passed ones",
                    scroll_id=state_.id,
                )
            # fix old variant states:
            for vs in state_.variant_states:
                if vs.metric is None:
                    vs.metric = metric
            if refresh:
                self._reset_variant_states(company_id=company_id, state=state_)

        state: HistorySampleState
        with self.cache_manager.get_or_create_state(
            state_id=state_id, init_state=init_state, validate_state=validate_state
        ) as state:
            res.scroll_id = state.id

            var_state = first(
                vs
                for vs in state.variant_states
                if vs.name == variant and vs.metric == metric
            )
            if not var_state:
                return res

            res.min_iteration = var_state.min_iteration
            res.max_iteration = var_state.max_iteration

            must_conditions = [
                {"term": {"task": task}},
                {"term": {"metric": metric}},
                {"term": {"variant": variant}},
                *self._get_extra_conditions(),
            ]
            if iteration is not None:
                must_conditions.append(
                    {
                        "range": {
                            "iter": {"lte": iteration, "gte": var_state.min_iteration}
                        }
                    }
                )
            else:
                must_conditions.append(
                    {"range": {"iter": {"gte": var_state.min_iteration}}}
                )

            es_req = {
                "size": 1,
                "sort": {"iter": "desc"},
                "query": {"bool": {"must": must_conditions}},
            }

            with translate_errors_context(), TimingContext(
                "es", "get_history_sample_for_variant"
            ):
                es_res = search_company_events(
                    self.es,
                    company_id=company_id,
                    event_type=self.event_type,
                    body=es_req,
                )

            hits = nested_get(es_res, ("hits", "hits"))
            if not hits:
                return res

            self._fill_res_and_update_state(
                event=hits[0]["_source"], res=res, state=state
            )
            return res

    def _reset_variant_states(self, company_id: str, state: HistorySampleState):
        metrics = self._get_metric_variant_iterations(
            company_id=company_id,
            task=state.task,
            metric=state.metric if state.navigate_current_metric else None,
        )
        state.variant_states = [
            VariantState(
                metric=metric,
                name=var_name,
                min_iteration=min_iter,
                max_iteration=max_iter,
            )
            for metric, variants in metrics.items()
            for var_name, min_iter, max_iter in variants
        ]
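
    # The per-variant iteration bounds come from subclass-provided aggregations.
    # A minimal sketch of what an implementation might return (an assumption, not the
    # actual subclasses' code):
    #   aggs = {"min_iter": {"min": {"field": "iter"}}, "max_iter": {"max": {"field": "iter"}}}
    #   extract = lambda bucket: (int(bucket["min_iter"]["value"]), int(bucket["max_iter"]["value"]))
    #   return aggs, extract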
    @abc.abstractmethod
    def _get_min_max_aggs(self) -> Tuple[dict, Callable[[dict], Tuple[int, int]]]:
        pass

    def _get_metric_variant_iterations(
        self, company_id: str, task: str, metric: str,
    ) -> Mapping[str, Sequence[Tuple[str, int, int]]]:
        """
        Return, per metric, the variants and the min/max iterations for which the task
        reported events of the required type.
        """
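        # Returned mapping shape (values are illustrative):
        #   {"metric_a": [("variant_1", 0, 100), ("variant_2", 5, 90)], ...}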
        must = [
            {"term": {"task": task}},
            *self._get_extra_conditions(),
        ]
        if metric is not None:
            must.append({"term": {"metric": metric}})
        query = {"bool": {"must": must}}

        search_args = dict(
            es=self.es, company_id=company_id, event_type=self.event_type
        )
        max_metrics, max_variants = get_max_metric_and_variant_counts(
            query=query, **search_args
        )
        max_variants = int(max_variants // 2)
        min_max_aggs, get_min_max_data = self._get_min_max_aggs()
        es_req: dict = {
            "size": 0,
            "query": query,
            "aggs": {
                "metrics": {
                    "terms": {
                        "field": "metric",
                        "size": max_metrics,
                        "order": {"_key": "asc"},
                    },
                    "aggs": {
                        "variants": {
                            "terms": {
                                "field": "variant",
                                "size": max_variants,
                                "order": {"_key": "asc"},
                            },
                            "aggs": min_max_aggs,
                        }
                    },
                }
            },
        }

        with translate_errors_context(), TimingContext(
            "es", "get_history_sample_iterations"
        ):
            es_res = search_company_events(body=es_req, **search_args)

        def get_variant_data(variant_bucket: dict) -> Tuple[str, int, int]:
            variant = variant_bucket["key"]
            min_iter, max_iter = get_min_max_data(variant_bucket)
            return variant, min_iter, max_iter

        return {
            metric_bucket["key"]: [
                get_variant_data(variant_bucket)
                for variant_bucket in nested_get(metric_bucket, ("variants", "buckets"))
            ]
            for metric_bucket in nested_get(
                es_res, ("aggregations", "metrics", "buckets")
            )
        }