mirror of
https://github.com/clearml/clearml-server
synced 2025-04-26 00:49:45 +00:00
Support returning multiple plots during history navigation
This commit is contained in:
parent
bc23f1b0cf
commit
97992b0d9e
@ -61,13 +61,20 @@ class MetricEventsRequest(Base):
|
||||
model_events: bool = BoolField()
|
||||
|
||||
|
||||
class TaskMetricVariant(Base):
|
||||
class GetVariantSampleRequest(Base):
|
||||
task: str = StringField(required=True)
|
||||
metric: str = StringField(required=True)
|
||||
variant: str = StringField(required=True)
|
||||
iteration: Optional[int] = IntField()
|
||||
refresh: bool = BoolField(default=False)
|
||||
scroll_id: Optional[str] = StringField()
|
||||
navigate_current_metric: bool = BoolField(default=True)
|
||||
model_events: bool = BoolField(default=False)
|
||||
|
||||
|
||||
class GetHistorySampleRequest(TaskMetricVariant):
|
||||
class GetMetricSamplesRequest(Base):
|
||||
task: str = StringField(required=True)
|
||||
metric: str = StringField(required=True)
|
||||
iteration: Optional[int] = IntField()
|
||||
refresh: bool = BoolField(default=False)
|
||||
scroll_id: Optional[str] = StringField()
|
||||
|
@ -26,7 +26,7 @@ from apiserver.bll.event.event_common import (
|
||||
)
|
||||
from apiserver.bll.event.events_iterator import EventsIterator, TaskEventsResult
|
||||
from apiserver.bll.event.history_debug_image_iterator import HistoryDebugImageIterator
|
||||
from apiserver.bll.event.history_plot_iterator import HistoryPlotIterator
|
||||
from apiserver.bll.event.history_plots_iterator import HistoryPlotsIterator
|
||||
from apiserver.bll.event.metric_debug_images_iterator import MetricDebugImagesIterator
|
||||
from apiserver.bll.event.metric_plots_iterator import MetricPlotsIterator
|
||||
from apiserver.bll.util import parallel_chunked_decorator
|
||||
@ -93,7 +93,7 @@ class EventBLL(object):
|
||||
es=self.es, redis=self.redis
|
||||
)
|
||||
self.plots_iterator = MetricPlotsIterator(es=self.es, redis=self.redis)
|
||||
self.plot_sample_history = HistoryPlotIterator(es=self.es, redis=self.redis)
|
||||
self.plot_sample_history = HistoryPlotsIterator(es=self.es, redis=self.redis)
|
||||
self.events_iterator = EventsIterator(es=self.es)
|
||||
|
||||
@property
|
||||
|
@ -1,21 +1,126 @@
|
||||
from typing import Sequence, Tuple, Callable
|
||||
import operator
|
||||
from operator import attrgetter
|
||||
from typing import Sequence, Tuple, Optional, Mapping
|
||||
|
||||
import attr
|
||||
from boltons.iterutils import first, bucketize
|
||||
from elasticsearch import Elasticsearch
|
||||
from jsonmodels.fields import StringField, IntField, BoolField, ListField
|
||||
from jsonmodels.models import Base
|
||||
from redis.client import StrictRedis
|
||||
|
||||
from apiserver.utilities.dicts import nested_get
|
||||
from .event_common import EventType
|
||||
from .history_sample_iterator import HistorySampleIterator, VariantState
|
||||
from .event_common import (
|
||||
EventType,
|
||||
EventSettings,
|
||||
check_empty_data,
|
||||
search_company_events,
|
||||
get_max_metric_and_variant_counts,
|
||||
)
|
||||
from apiserver.apimodels import JsonSerializableMixin
|
||||
from apiserver.bll.redis_cache_manager import RedisCacheManager
|
||||
from apiserver.apierrors import errors
|
||||
|
||||
|
||||
class HistoryDebugImageIterator(HistorySampleIterator):
|
||||
class VariantState(Base):
|
||||
name: str = StringField(required=True)
|
||||
metric: str = StringField(default=None)
|
||||
min_iteration: int = IntField()
|
||||
max_iteration: int = IntField()
|
||||
|
||||
|
||||
class DebugImageSampleState(Base, JsonSerializableMixin):
|
||||
id: str = StringField(required=True)
|
||||
iteration: int = IntField()
|
||||
variant: str = StringField()
|
||||
task: str = StringField()
|
||||
metric: str = StringField()
|
||||
variant_states: Sequence[VariantState] = ListField([VariantState])
|
||||
warning: str = StringField()
|
||||
navigate_current_metric = BoolField(default=True)
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class VariantSampleResult(object):
|
||||
scroll_id: str = None
|
||||
event: dict = None
|
||||
min_iteration: int = None
|
||||
max_iteration: int = None
|
||||
|
||||
|
||||
class HistoryDebugImageIterator:
|
||||
event_type = EventType.metrics_image
|
||||
|
||||
def __init__(self, redis: StrictRedis, es: Elasticsearch):
|
||||
super().__init__(redis, es, EventType.metrics_image)
|
||||
self.es = es
|
||||
self.cache_manager = RedisCacheManager(
|
||||
state_class=DebugImageSampleState,
|
||||
redis=redis,
|
||||
expiration_interval=EventSettings.state_expiration_sec,
|
||||
)
|
||||
|
||||
def _get_extra_conditions(self) -> Sequence[dict]:
|
||||
return [{"exists": {"field": "url"}}]
|
||||
def get_next_sample(
|
||||
self,
|
||||
company_id: str,
|
||||
task: str,
|
||||
state_id: str,
|
||||
navigate_earlier: bool,
|
||||
next_iteration: bool,
|
||||
) -> VariantSampleResult:
|
||||
"""
|
||||
Get the sample for next/prev variant on the current iteration
|
||||
If does not exist then try getting sample for the first/last variant from next/prev iteration
|
||||
"""
|
||||
res = VariantSampleResult(scroll_id=state_id)
|
||||
state = self.cache_manager.get_state(state_id)
|
||||
if not state or state.task != task:
|
||||
raise errors.bad_request.InvalidScrollId(scroll_id=state_id)
|
||||
|
||||
def _get_variants_conditions(self, variants: Sequence[VariantState]) -> dict:
|
||||
if check_empty_data(self.es, company_id=company_id, event_type=self.event_type):
|
||||
return res
|
||||
|
||||
if next_iteration:
|
||||
event = self._get_next_for_another_iteration(
|
||||
company_id=company_id, navigate_earlier=navigate_earlier, state=state
|
||||
)
|
||||
else:
|
||||
# noinspection PyArgumentList
|
||||
event = first(
|
||||
f(company_id=company_id, navigate_earlier=navigate_earlier, state=state)
|
||||
for f in (
|
||||
self._get_next_for_current_iteration,
|
||||
self._get_next_for_another_iteration,
|
||||
)
|
||||
)
|
||||
if not event:
|
||||
return res
|
||||
|
||||
self._fill_res_and_update_state(event=event, res=res, state=state)
|
||||
self.cache_manager.set_state(state=state)
|
||||
return res
|
||||
|
||||
@staticmethod
|
||||
def _fill_res_and_update_state(
|
||||
event: dict, res: VariantSampleResult, state: DebugImageSampleState
|
||||
):
|
||||
state.variant = event["variant"]
|
||||
state.metric = event["metric"]
|
||||
state.iteration = event["iter"]
|
||||
res.event = event
|
||||
var_state = first(
|
||||
vs
|
||||
for vs in state.variant_states
|
||||
if vs.name == state.variant and vs.metric == state.metric
|
||||
)
|
||||
if var_state:
|
||||
res.min_iteration = var_state.min_iteration
|
||||
res.max_iteration = var_state.max_iteration
|
||||
|
||||
@staticmethod
|
||||
def _get_metric_conditions(variants: Sequence[VariantState]) -> dict:
|
||||
metrics = bucketize(variants, key=attrgetter("metric"))
|
||||
|
||||
def _get_variants_conditions(metric_variants: Sequence[VariantState]) -> dict:
|
||||
variants_conditions = [
|
||||
{
|
||||
"bool": {
|
||||
@ -25,32 +130,326 @@ class HistoryDebugImageIterator(HistorySampleIterator):
|
||||
]
|
||||
}
|
||||
}
|
||||
for v in variants
|
||||
for v in metric_variants
|
||||
]
|
||||
return {"bool": {"should": variants_conditions}}
|
||||
|
||||
def _process_event(self, event: dict) -> dict:
|
||||
return event
|
||||
metrics_conditions = [
|
||||
{
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"metric": metric}},
|
||||
_get_variants_conditions(metric_variants),
|
||||
]
|
||||
}
|
||||
}
|
||||
for metric, metric_variants in metrics.items()
|
||||
]
|
||||
return {"bool": {"should": metrics_conditions}}
|
||||
|
||||
def _get_min_max_aggs(self) -> Tuple[dict, Callable[[dict], Tuple[int, int]]]:
|
||||
# The min iteration is the lowest iteration that contains non-recycled image url
|
||||
aggs = {
|
||||
def _get_next_for_current_iteration(
|
||||
self, company_id: str, navigate_earlier: bool, state: DebugImageSampleState
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
Get the sample for next (if navigate_earlier is False) or previous variant sorted by name for the same iteration
|
||||
Only variants for which the iteration falls into their valid range are considered
|
||||
Return None if no such variant or sample is found
|
||||
"""
|
||||
if state.navigate_current_metric:
|
||||
variants = [
|
||||
var_state
|
||||
for var_state in state.variant_states
|
||||
if var_state.metric == state.metric
|
||||
]
|
||||
else:
|
||||
variants = state.variant_states
|
||||
|
||||
cmp = operator.lt if navigate_earlier else operator.gt
|
||||
variants = [
|
||||
var_state
|
||||
for var_state in variants
|
||||
if cmp((var_state.metric, var_state.name), (state.metric, state.variant))
|
||||
and var_state.min_iteration <= state.iteration
|
||||
]
|
||||
if not variants:
|
||||
return
|
||||
|
||||
must_conditions = [
|
||||
{"term": {"task": state.task}},
|
||||
{"term": {"iter": state.iteration}},
|
||||
self._get_metric_conditions(variants),
|
||||
{"exists": {"field": "url"}},
|
||||
]
|
||||
order = "desc" if navigate_earlier else "asc"
|
||||
es_req = {
|
||||
"size": 1,
|
||||
"sort": [{"metric": order}, {"variant": order}],
|
||||
"query": {"bool": {"must": must_conditions}},
|
||||
}
|
||||
|
||||
es_res = search_company_events(
|
||||
self.es,
|
||||
company_id=company_id,
|
||||
event_type=self.event_type,
|
||||
body=es_req,
|
||||
)
|
||||
|
||||
hits = nested_get(es_res, ("hits", "hits"))
|
||||
if not hits:
|
||||
return
|
||||
|
||||
return hits[0]["_source"]
|
||||
|
||||
def _get_next_for_another_iteration(
|
||||
self, company_id: str, navigate_earlier: bool, state: DebugImageSampleState
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
Get the sample for the first variant for the next iteration (if navigate_earlier is set to False)
|
||||
or from the last variant for the previous iteration (otherwise)
|
||||
The variants for which the sample falls in invalid range are discarded
|
||||
If no suitable sample is found then None is returned
|
||||
"""
|
||||
if state.navigate_current_metric:
|
||||
variants = [
|
||||
var_state
|
||||
for var_state in state.variant_states
|
||||
if var_state.metric == state.metric
|
||||
]
|
||||
else:
|
||||
variants = state.variant_states
|
||||
|
||||
if navigate_earlier:
|
||||
range_operator = "lt"
|
||||
order = "desc"
|
||||
variants = [
|
||||
var_state
|
||||
for var_state in variants
|
||||
if var_state.min_iteration < state.iteration
|
||||
]
|
||||
else:
|
||||
range_operator = "gt"
|
||||
order = "asc"
|
||||
variants = variants
|
||||
|
||||
if not variants:
|
||||
return
|
||||
|
||||
must_conditions = [
|
||||
{"term": {"task": state.task}},
|
||||
self._get_metric_conditions(variants),
|
||||
{"range": {"iter": {range_operator: state.iteration}}},
|
||||
{"exists": {"field": "url"}},
|
||||
]
|
||||
es_req = {
|
||||
"size": 1,
|
||||
"sort": [{"iter": order}, {"metric": order}, {"variant": order}],
|
||||
"query": {"bool": {"must": must_conditions}},
|
||||
}
|
||||
es_res = search_company_events(
|
||||
self.es,
|
||||
company_id=company_id,
|
||||
event_type=self.event_type,
|
||||
body=es_req,
|
||||
)
|
||||
|
||||
hits = nested_get(es_res, ("hits", "hits"))
|
||||
if not hits:
|
||||
return
|
||||
|
||||
return hits[0]["_source"]
|
||||
|
||||
def get_sample_for_variant(
|
||||
self,
|
||||
company_id: str,
|
||||
task: str,
|
||||
metric: str,
|
||||
variant: str,
|
||||
iteration: Optional[int] = None,
|
||||
refresh: bool = False,
|
||||
state_id: str = None,
|
||||
navigate_current_metric: bool = True,
|
||||
) -> VariantSampleResult:
|
||||
"""
|
||||
Get the sample for the requested iteration or the latest before it
|
||||
If the iteration is not passed then get the latest event
|
||||
"""
|
||||
res = VariantSampleResult()
|
||||
if check_empty_data(self.es, company_id=company_id, event_type=self.event_type):
|
||||
return res
|
||||
|
||||
def init_state(state_: DebugImageSampleState):
|
||||
state_.task = task
|
||||
state_.metric = metric
|
||||
state_.navigate_current_metric = navigate_current_metric
|
||||
self._reset_variant_states(company_id=company_id, state=state_)
|
||||
|
||||
def validate_state(state_: DebugImageSampleState):
|
||||
if (
|
||||
state_.task != task
|
||||
or state_.navigate_current_metric != navigate_current_metric
|
||||
or (state_.navigate_current_metric and state_.metric != metric)
|
||||
):
|
||||
raise errors.bad_request.InvalidScrollId(
|
||||
"Task and metric stored in the state do not match the passed ones",
|
||||
scroll_id=state_.id,
|
||||
)
|
||||
# fix old variant states:
|
||||
for vs in state_.variant_states:
|
||||
if vs.metric is None:
|
||||
vs.metric = metric
|
||||
if refresh:
|
||||
self._reset_variant_states(company_id=company_id, state=state_)
|
||||
|
||||
state: DebugImageSampleState
|
||||
with self.cache_manager.get_or_create_state(
|
||||
state_id=state_id, init_state=init_state, validate_state=validate_state,
|
||||
) as state:
|
||||
res.scroll_id = state.id
|
||||
|
||||
var_state = first(
|
||||
vs
|
||||
for vs in state.variant_states
|
||||
if vs.name == variant and vs.metric == metric
|
||||
)
|
||||
if not var_state:
|
||||
return res
|
||||
|
||||
res.min_iteration = var_state.min_iteration
|
||||
res.max_iteration = var_state.max_iteration
|
||||
|
||||
must_conditions = [
|
||||
{"term": {"task": task}},
|
||||
{"term": {"metric": metric}},
|
||||
{"term": {"variant": variant}},
|
||||
{"exists": {"field": "url"}},
|
||||
]
|
||||
if iteration is not None:
|
||||
must_conditions.append(
|
||||
{
|
||||
"range": {
|
||||
"iter": {"lte": iteration, "gte": var_state.min_iteration}
|
||||
}
|
||||
}
|
||||
)
|
||||
else:
|
||||
must_conditions.append(
|
||||
{"range": {"iter": {"gte": var_state.min_iteration}}}
|
||||
)
|
||||
|
||||
es_req = {
|
||||
"size": 1,
|
||||
"sort": {"iter": "desc"},
|
||||
"query": {"bool": {"must": must_conditions}},
|
||||
}
|
||||
|
||||
es_res = search_company_events(
|
||||
self.es,
|
||||
company_id=company_id,
|
||||
event_type=self.event_type,
|
||||
body=es_req,
|
||||
)
|
||||
|
||||
hits = nested_get(es_res, ("hits", "hits"))
|
||||
if not hits:
|
||||
return res
|
||||
|
||||
self._fill_res_and_update_state(
|
||||
event=hits[0]["_source"], res=res, state=state
|
||||
)
|
||||
return res
|
||||
|
||||
def _reset_variant_states(self, company_id: str, state: DebugImageSampleState):
|
||||
metrics = self._get_metric_variant_iterations(
|
||||
company_id=company_id,
|
||||
task=state.task,
|
||||
metric=state.metric if state.navigate_current_metric else None,
|
||||
)
|
||||
state.variant_states = [
|
||||
VariantState(
|
||||
metric=metric,
|
||||
name=var_name,
|
||||
min_iteration=min_iter,
|
||||
max_iteration=max_iter,
|
||||
)
|
||||
for metric, variants in metrics.items()
|
||||
for var_name, min_iter, max_iter in variants
|
||||
]
|
||||
|
||||
def _get_metric_variant_iterations(
|
||||
self, company_id: str, task: str, metric: str,
|
||||
) -> Mapping[str, Sequence[Tuple[str, int, int]]]:
|
||||
"""
|
||||
Return valid min and max iterations that the task reported events of the required type
|
||||
"""
|
||||
must = [
|
||||
{"term": {"task": task}},
|
||||
{"exists": {"field": "url"}},
|
||||
]
|
||||
if metric is not None:
|
||||
must.append({"term": {"metric": metric}})
|
||||
query = {"bool": {"must": must}}
|
||||
|
||||
search_args = dict(
|
||||
es=self.es, company_id=company_id, event_type=self.event_type,
|
||||
)
|
||||
max_metrics, max_variants = get_max_metric_and_variant_counts(
|
||||
query=query, **search_args
|
||||
)
|
||||
max_variants = int(max_variants // 2)
|
||||
es_req: dict = {
|
||||
"size": 0,
|
||||
"query": query,
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": max_metrics,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": max_variants,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"last_iter": {"max": {"field": "iter"}},
|
||||
"urls": {
|
||||
# group by urls and choose the minimal iteration
|
||||
# from all the maximal iterations per url
|
||||
"terms": {"field": "url", "order": {"max_iter": "asc"}, "size": 1},
|
||||
"terms": {
|
||||
"field": "url",
|
||||
"order": {"max_iter": "asc"},
|
||||
"size": 1,
|
||||
},
|
||||
"aggs": {
|
||||
# find max iteration for each url
|
||||
"max_iter": {"max": {"field": "iter"}}
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
def get_min_max_data(variant_bucket: dict) -> Tuple[int, int]:
|
||||
es_res = search_company_events(body=es_req, **search_args)
|
||||
|
||||
def get_variant_data(variant_bucket: dict) -> Tuple[str, int, int]:
|
||||
variant = variant_bucket["key"]
|
||||
urls = nested_get(variant_bucket, ("urls", "buckets"))
|
||||
min_iter = int(urls[0]["max_iter"]["value"])
|
||||
max_iter = int(variant_bucket["last_iter"]["value"])
|
||||
return min_iter, max_iter
|
||||
return variant, min_iter, max_iter
|
||||
|
||||
return aggs, get_min_max_data
|
||||
return {
|
||||
metric_bucket["key"]: [
|
||||
get_variant_data(variant_bucket)
|
||||
for variant_bucket in nested_get(metric_bucket, ("variants", "buckets"))
|
||||
]
|
||||
for metric_bucket in nested_get(
|
||||
es_res, ("aggregations", "metrics", "buckets")
|
||||
)
|
||||
}
|
||||
|
@ -1,36 +0,0 @@
|
||||
from typing import Sequence, Tuple, Callable
|
||||
|
||||
from elasticsearch import Elasticsearch
|
||||
from redis.client import StrictRedis
|
||||
|
||||
from .event_common import EventType, uncompress_plot
|
||||
from .history_sample_iterator import HistorySampleIterator, VariantState
|
||||
|
||||
|
||||
class HistoryPlotIterator(HistorySampleIterator):
|
||||
def __init__(self, redis: StrictRedis, es: Elasticsearch):
|
||||
super().__init__(redis, es, EventType.metrics_plot)
|
||||
|
||||
def _get_extra_conditions(self) -> Sequence[dict]:
|
||||
return []
|
||||
|
||||
def _get_variants_conditions(self, variants: Sequence[VariantState]) -> dict:
|
||||
return {"terms": {"variant": [v.name for v in variants]}}
|
||||
|
||||
def _process_event(self, event: dict) -> dict:
|
||||
uncompress_plot(event)
|
||||
return event
|
||||
|
||||
def _get_min_max_aggs(self) -> Tuple[dict, Callable[[dict], Tuple[int, int]]]:
|
||||
# The min iteration is the lowest iteration that contains non-recycled image url
|
||||
aggs = {
|
||||
"last_iter": {"max": {"field": "iter"}},
|
||||
"first_iter": {"min": {"field": "iter"}},
|
||||
}
|
||||
|
||||
def get_min_max_data(variant_bucket: dict) -> Tuple[int, int]:
|
||||
min_iter = int(variant_bucket["first_iter"]["value"])
|
||||
max_iter = int(variant_bucket["last_iter"]["value"])
|
||||
return min_iter, max_iter
|
||||
|
||||
return aggs, get_min_max_data
|
316
apiserver/bll/event/history_plots_iterator.py
Normal file
316
apiserver/bll/event/history_plots_iterator.py
Normal file
@ -0,0 +1,316 @@
|
||||
from typing import Sequence, Tuple, Optional, Mapping
|
||||
|
||||
import attr
|
||||
from boltons.iterutils import first
|
||||
from elasticsearch import Elasticsearch
|
||||
from jsonmodels.fields import StringField, IntField, ListField, BoolField
|
||||
from jsonmodels.models import Base
|
||||
from redis.client import StrictRedis
|
||||
|
||||
from .event_common import (
|
||||
EventType,
|
||||
uncompress_plot,
|
||||
EventSettings,
|
||||
check_empty_data,
|
||||
search_company_events,
|
||||
)
|
||||
from apiserver.apimodels import JsonSerializableMixin
|
||||
from apiserver.utilities.dicts import nested_get
|
||||
from apiserver.bll.redis_cache_manager import RedisCacheManager
|
||||
from apiserver.apierrors import errors
|
||||
|
||||
|
||||
class MetricState(Base):
|
||||
name: str = StringField(default=None)
|
||||
min_iteration: int = IntField()
|
||||
max_iteration: int = IntField()
|
||||
|
||||
|
||||
class PlotsSampleState(Base, JsonSerializableMixin):
|
||||
id: str = StringField(required=True)
|
||||
iteration: int = IntField()
|
||||
task: str = StringField()
|
||||
metric: str = StringField()
|
||||
metric_states: Sequence[MetricState] = ListField([MetricState])
|
||||
warning: str = StringField()
|
||||
navigate_current_metric = BoolField(default=True)
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class MetricSamplesResult(object):
|
||||
scroll_id: str = None
|
||||
events: list = []
|
||||
min_iteration: int = None
|
||||
max_iteration: int = None
|
||||
|
||||
|
||||
class HistoryPlotsIterator:
|
||||
event_type = EventType.metrics_plot
|
||||
|
||||
def __init__(self, redis: StrictRedis, es: Elasticsearch):
|
||||
self.es = es
|
||||
self.cache_manager = RedisCacheManager(
|
||||
state_class=PlotsSampleState,
|
||||
redis=redis,
|
||||
expiration_interval=EventSettings.state_expiration_sec,
|
||||
)
|
||||
|
||||
def get_next_sample(
|
||||
self,
|
||||
company_id: str,
|
||||
task: str,
|
||||
state_id: str,
|
||||
navigate_earlier: bool,
|
||||
next_iteration: bool,
|
||||
) -> MetricSamplesResult:
|
||||
"""
|
||||
Get the samples for next/prev metric on the current iteration
|
||||
If does not exist then try getting sample for the first/last metric from next/prev iteration
|
||||
"""
|
||||
res = MetricSamplesResult(scroll_id=state_id)
|
||||
state = self.cache_manager.get_state(state_id)
|
||||
if not state or state.task != task:
|
||||
raise errors.bad_request.InvalidScrollId(scroll_id=state_id)
|
||||
|
||||
if check_empty_data(self.es, company_id=company_id, event_type=self.event_type):
|
||||
return res
|
||||
|
||||
if navigate_earlier:
|
||||
range_operator = "lt"
|
||||
order = "desc"
|
||||
else:
|
||||
range_operator = "gt"
|
||||
order = "asc"
|
||||
|
||||
must_conditions = [
|
||||
{"term": {"task": state.task}},
|
||||
]
|
||||
if state.navigate_current_metric:
|
||||
must_conditions.append({"term": {"metric": state.metric}})
|
||||
|
||||
next_iteration_condition = {
|
||||
"range": {"iter": {range_operator: state.iteration}}
|
||||
}
|
||||
if next_iteration or state.navigate_current_metric:
|
||||
must_conditions.append(next_iteration_condition)
|
||||
else:
|
||||
next_metric_condition = {
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"iter": state.iteration}},
|
||||
{"range": {"metric": {range_operator: state.metric}}},
|
||||
]
|
||||
}
|
||||
}
|
||||
must_conditions.append(
|
||||
{"bool": {"should": [next_metric_condition, next_iteration_condition]}}
|
||||
)
|
||||
|
||||
events = self._get_metric_events_for_condition(
|
||||
company_id=company_id,
|
||||
task=state.task,
|
||||
order=order,
|
||||
must_conditions=must_conditions,
|
||||
)
|
||||
|
||||
if not events:
|
||||
return res
|
||||
|
||||
self._fill_res_and_update_state(events=events, res=res, state=state)
|
||||
self.cache_manager.set_state(state=state)
|
||||
return res
|
||||
|
||||
def get_samples_for_metric(
|
||||
self,
|
||||
company_id: str,
|
||||
task: str,
|
||||
metric: str,
|
||||
iteration: Optional[int] = None,
|
||||
refresh: bool = False,
|
||||
state_id: str = None,
|
||||
navigate_current_metric: bool = True,
|
||||
) -> MetricSamplesResult:
|
||||
"""
|
||||
Get the sample for the requested iteration or the latest before it
|
||||
If the iteration is not passed then get the latest event
|
||||
"""
|
||||
res = MetricSamplesResult()
|
||||
if check_empty_data(self.es, company_id=company_id, event_type=self.event_type):
|
||||
return res
|
||||
|
||||
def init_state(state_: PlotsSampleState):
|
||||
state_.task = task
|
||||
state_.metric = metric
|
||||
state_.navigate_current_metric = navigate_current_metric
|
||||
self._reset_metric_states(company_id=company_id, state=state_)
|
||||
|
||||
def validate_state(state_: PlotsSampleState):
|
||||
if (
|
||||
state_.task != task
|
||||
or state_.navigate_current_metric != navigate_current_metric
|
||||
or (state_.navigate_current_metric and state_.metric != metric)
|
||||
):
|
||||
raise errors.bad_request.InvalidScrollId(
|
||||
"Task and metric stored in the state do not match the passed ones",
|
||||
scroll_id=state_.id,
|
||||
)
|
||||
if refresh:
|
||||
self._reset_metric_states(company_id=company_id, state=state_)
|
||||
|
||||
state: PlotsSampleState
|
||||
with self.cache_manager.get_or_create_state(
|
||||
state_id=state_id, init_state=init_state, validate_state=validate_state,
|
||||
) as state:
|
||||
res.scroll_id = state.id
|
||||
|
||||
metric_state = first(ms for ms in state.metric_states if ms.name == metric)
|
||||
if not metric_state:
|
||||
return res
|
||||
|
||||
res.min_iteration = metric_state.min_iteration
|
||||
res.max_iteration = metric_state.max_iteration
|
||||
|
||||
must_conditions = [
|
||||
{"term": {"task": task}},
|
||||
{"term": {"metric": metric}},
|
||||
]
|
||||
if iteration is not None:
|
||||
must_conditions.append({"range": {"iter": {"lte": iteration}}})
|
||||
|
||||
events = self._get_metric_events_for_condition(
|
||||
company_id=company_id,
|
||||
task=state.task,
|
||||
order="desc",
|
||||
must_conditions=must_conditions,
|
||||
)
|
||||
if not events:
|
||||
return res
|
||||
|
||||
self._fill_res_and_update_state(events=events, res=res, state=state)
|
||||
return res
|
||||
|
||||
def _reset_metric_states(self, company_id: str, state: PlotsSampleState):
|
||||
metrics = self._get_metric_iterations(
|
||||
company_id=company_id,
|
||||
task=state.task,
|
||||
metric=state.metric if state.navigate_current_metric else None,
|
||||
)
|
||||
state.metric_states = [
|
||||
MetricState(name=metric, min_iteration=min_iter, max_iteration=max_iter)
|
||||
for metric, (min_iter, max_iter) in metrics.items()
|
||||
]
|
||||
|
||||
def _get_metric_iterations(
|
||||
self, company_id: str, task: str, metric: str,
|
||||
) -> Mapping[str, Tuple[int, int]]:
|
||||
"""
|
||||
Return valid min and max iterations that the task reported events of the required type
|
||||
"""
|
||||
must = [
|
||||
{"term": {"task": task}},
|
||||
]
|
||||
if metric is not None:
|
||||
must.append({"term": {"metric": metric}})
|
||||
query = {"bool": {"must": must}}
|
||||
|
||||
es_req: dict = {
|
||||
"size": 0,
|
||||
"query": query,
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": 5000,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"last_iter": {"max": {"field": "iter"}},
|
||||
"first_iter": {"min": {"field": "iter"}},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
es_res = search_company_events(
|
||||
body=es_req,
|
||||
es=self.es,
|
||||
company_id=company_id,
|
||||
event_type=self.event_type,
|
||||
)
|
||||
|
||||
return {
|
||||
metric_bucket["key"]: (
|
||||
int(metric_bucket["first_iter"]["value"]),
|
||||
int(metric_bucket["last_iter"]["value"]),
|
||||
)
|
||||
for metric_bucket in nested_get(
|
||||
es_res, ("aggregations", "metrics", "buckets")
|
||||
)
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _fill_res_and_update_state(
|
||||
events: Sequence[dict], res: MetricSamplesResult, state: PlotsSampleState
|
||||
):
|
||||
for event in events:
|
||||
uncompress_plot(event)
|
||||
state.metric = events[0]["metric"]
|
||||
state.iteration = events[0]["iter"]
|
||||
res.events = events
|
||||
metric_state = first(
|
||||
ms for ms in state.metric_states if ms.name == state.metric
|
||||
)
|
||||
if metric_state:
|
||||
res.min_iteration = metric_state.min_iteration
|
||||
res.max_iteration = metric_state.max_iteration
|
||||
|
||||
def _get_metric_events_for_condition(
|
||||
self, company_id: str, task: str, order: str, must_conditions: Sequence
|
||||
) -> Sequence:
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"query": {"bool": {"must": must_conditions}},
|
||||
"aggs": {
|
||||
"iters": {
|
||||
"terms": {"field": "iter", "size": 1, "order": {"_key": order}},
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": 1,
|
||||
"order": {"_key": order},
|
||||
},
|
||||
"aggs": {
|
||||
"events": {
|
||||
"top_hits": {
|
||||
"sort": {"variant": {"order": "asc"}},
|
||||
"size": 100,
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
es_res = search_company_events(
|
||||
self.es,
|
||||
company_id=company_id,
|
||||
event_type=self.event_type,
|
||||
body=es_req,
|
||||
)
|
||||
|
||||
aggs_result = es_res.get("aggregations")
|
||||
if not aggs_result:
|
||||
return []
|
||||
|
||||
for level in ("iters", "metrics"):
|
||||
level_data = aggs_result[level]["buckets"]
|
||||
if not level_data:
|
||||
return []
|
||||
aggs_result = level_data[0]
|
||||
|
||||
return [
|
||||
hit["_source"]
|
||||
for hit in nested_get(aggs_result, ("events", "hits", "hits"))
|
||||
]
|
@ -1,440 +0,0 @@
|
||||
import abc
|
||||
import operator
|
||||
from operator import attrgetter
|
||||
from typing import Sequence, Tuple, Optional, Callable, Mapping
|
||||
|
||||
import attr
|
||||
from boltons.iterutils import first, bucketize
|
||||
from elasticsearch import Elasticsearch
|
||||
from jsonmodels.fields import StringField, ListField, IntField, BoolField
|
||||
from jsonmodels.models import Base
|
||||
from redis import StrictRedis
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.apimodels import JsonSerializableMixin
|
||||
from apiserver.bll.event.event_common import (
|
||||
EventSettings,
|
||||
EventType,
|
||||
check_empty_data,
|
||||
search_company_events,
|
||||
get_max_metric_and_variant_counts,
|
||||
)
|
||||
from apiserver.bll.redis_cache_manager import RedisCacheManager
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.utilities.dicts import nested_get
|
||||
|
||||
|
||||
class VariantState(Base):
|
||||
name: str = StringField(required=True)
|
||||
metric: str = StringField(default=None)
|
||||
min_iteration: int = IntField()
|
||||
max_iteration: int = IntField()
|
||||
|
||||
|
||||
class HistorySampleState(Base, JsonSerializableMixin):
|
||||
id: str = StringField(required=True)
|
||||
iteration: int = IntField()
|
||||
variant: str = StringField()
|
||||
task: str = StringField()
|
||||
metric: str = StringField()
|
||||
reached_first: bool = BoolField()
|
||||
reached_last: bool = BoolField()
|
||||
variant_states: Sequence[VariantState] = ListField([VariantState])
|
||||
warning: str = StringField()
|
||||
navigate_current_metric = BoolField(default=True)
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class HistorySampleResult(object):
|
||||
scroll_id: str = None
|
||||
event: dict = None
|
||||
min_iteration: int = None
|
||||
max_iteration: int = None
|
||||
|
||||
|
||||
class HistorySampleIterator(abc.ABC):
|
||||
def __init__(self, redis: StrictRedis, es: Elasticsearch, event_type: EventType):
|
||||
self.es = es
|
||||
self.event_type = event_type
|
||||
self.cache_manager = RedisCacheManager(
|
||||
state_class=HistorySampleState,
|
||||
redis=redis,
|
||||
expiration_interval=EventSettings.state_expiration_sec,
|
||||
)
|
||||
|
||||
def get_next_sample(
|
||||
self,
|
||||
company_id: str,
|
||||
task: str,
|
||||
state_id: str,
|
||||
navigate_earlier: bool,
|
||||
next_iteration: bool,
|
||||
) -> HistorySampleResult:
|
||||
"""
|
||||
Get the sample for next/prev variant on the current iteration
|
||||
If does not exist then try getting sample for the first/last variant from next/prev iteration
|
||||
"""
|
||||
res = HistorySampleResult(scroll_id=state_id)
|
||||
state = self.cache_manager.get_state(state_id)
|
||||
if not state or state.task != task:
|
||||
raise errors.bad_request.InvalidScrollId(scroll_id=state_id)
|
||||
|
||||
if check_empty_data(self.es, company_id=company_id, event_type=self.event_type):
|
||||
return res
|
||||
|
||||
if next_iteration:
|
||||
event = self._get_next_for_another_iteration(
|
||||
company_id=company_id, navigate_earlier=navigate_earlier, state=state
|
||||
)
|
||||
else:
|
||||
# noinspection PyArgumentList
|
||||
event = first(
|
||||
f(company_id=company_id, navigate_earlier=navigate_earlier, state=state)
|
||||
for f in (
|
||||
self._get_next_for_current_iteration,
|
||||
self._get_next_for_another_iteration,
|
||||
)
|
||||
)
|
||||
if not event:
|
||||
return res
|
||||
|
||||
self._fill_res_and_update_state(event=event, res=res, state=state)
|
||||
self.cache_manager.set_state(state=state)
|
||||
return res
|
||||
|
||||
def _fill_res_and_update_state(
|
||||
self, event: dict, res: HistorySampleResult, state: HistorySampleState
|
||||
):
|
||||
self._process_event(event)
|
||||
state.variant = event["variant"]
|
||||
state.metric = event["metric"]
|
||||
state.iteration = event["iter"]
|
||||
res.event = event
|
||||
var_state = first(
|
||||
vs
|
||||
for vs in state.variant_states
|
||||
if vs.name == state.variant and vs.metric == state.metric
|
||||
)
|
||||
if var_state:
|
||||
res.min_iteration = var_state.min_iteration
|
||||
res.max_iteration = var_state.max_iteration
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_extra_conditions(self) -> Sequence[dict]:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _process_event(self, event: dict) -> dict:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_variants_conditions(self, variants: Sequence[VariantState]) -> dict:
|
||||
pass
|
||||
|
||||
def _get_metric_variants_condition(self, variants: Sequence[VariantState]) -> dict:
|
||||
metrics = bucketize(variants, key=attrgetter("metric"))
|
||||
metrics_conditions = [
|
||||
{
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"metric": metric}},
|
||||
self._get_variants_conditions(vs),
|
||||
]
|
||||
}
|
||||
}
|
||||
for metric, vs in metrics.items()
|
||||
]
|
||||
return {"bool": {"should": metrics_conditions}}
|
||||
|
||||
def _get_next_for_current_iteration(
|
||||
self, company_id: str, navigate_earlier: bool, state: HistorySampleState
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
Get the sample for next (if navigate_earlier is False) or previous variant sorted by name for the same iteration
|
||||
Only variants for which the iteration falls into their valid range are considered
|
||||
Return None if no such variant or sample is found
|
||||
"""
|
||||
if state.navigate_current_metric:
|
||||
variants = [
|
||||
var_state
|
||||
for var_state in state.variant_states
|
||||
if var_state.metric == state.metric
|
||||
]
|
||||
else:
|
||||
variants = state.variant_states
|
||||
|
||||
cmp = operator.lt if navigate_earlier else operator.gt
|
||||
variants = [
|
||||
var_state
|
||||
for var_state in variants
|
||||
if cmp((var_state.metric, var_state.name), (state.metric, state.variant))
|
||||
and var_state.min_iteration <= state.iteration
|
||||
]
|
||||
if not variants:
|
||||
return
|
||||
|
||||
must_conditions = [
|
||||
{"term": {"task": state.task}},
|
||||
{"term": {"iter": state.iteration}},
|
||||
self._get_metric_variants_condition(variants),
|
||||
*self._get_extra_conditions(),
|
||||
]
|
||||
order = "desc" if navigate_earlier else "asc"
|
||||
es_req = {
|
||||
"size": 1,
|
||||
"sort": [{"metric": order}, {"variant": order}],
|
||||
"query": {"bool": {"must": must_conditions}},
|
||||
}
|
||||
|
||||
with translate_errors_context():
|
||||
es_res = search_company_events(
|
||||
self.es, company_id=company_id, event_type=self.event_type, body=es_req,
|
||||
)
|
||||
|
||||
hits = nested_get(es_res, ("hits", "hits"))
|
||||
if not hits:
|
||||
return
|
||||
|
||||
return hits[0]["_source"]
|
||||
|
||||
def _get_next_for_another_iteration(
|
||||
self, company_id: str, navigate_earlier: bool, state: HistorySampleState
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
Get the sample for the first variant for the next iteration (if navigate_earlier is set to False)
|
||||
or from the last variant for the previous iteration (otherwise)
|
||||
The variants for which the sample falls in invalid range are discarded
|
||||
If no suitable sample is found then None is returned
|
||||
"""
|
||||
if state.navigate_current_metric:
|
||||
variants = [
|
||||
var_state
|
||||
for var_state in state.variant_states
|
||||
if var_state.metric == state.metric
|
||||
]
|
||||
else:
|
||||
variants = state.variant_states
|
||||
|
||||
if navigate_earlier:
|
||||
range_operator = "lt"
|
||||
order = "desc"
|
||||
variants = [
|
||||
var_state
|
||||
for var_state in variants
|
||||
if var_state.min_iteration < state.iteration
|
||||
]
|
||||
else:
|
||||
range_operator = "gt"
|
||||
order = "asc"
|
||||
variants = variants
|
||||
|
||||
if not variants:
|
||||
return
|
||||
|
||||
must_conditions = [
|
||||
{"term": {"task": state.task}},
|
||||
self._get_metric_variants_condition(variants),
|
||||
{"range": {"iter": {range_operator: state.iteration}}},
|
||||
*self._get_extra_conditions(),
|
||||
]
|
||||
es_req = {
|
||||
"size": 1,
|
||||
"sort": [{"iter": order}, {"metric": order}, {"variant": order}],
|
||||
"query": {"bool": {"must": must_conditions}},
|
||||
}
|
||||
with translate_errors_context():
|
||||
es_res = search_company_events(
|
||||
self.es, company_id=company_id, event_type=self.event_type, body=es_req,
|
||||
)
|
||||
|
||||
hits = nested_get(es_res, ("hits", "hits"))
|
||||
if not hits:
|
||||
return
|
||||
|
||||
return hits[0]["_source"]
|
||||
|
||||
def get_sample_for_variant(
|
||||
self,
|
||||
company_id: str,
|
||||
task: str,
|
||||
metric: str,
|
||||
variant: str,
|
||||
iteration: Optional[int] = None,
|
||||
refresh: bool = False,
|
||||
state_id: str = None,
|
||||
navigate_current_metric: bool = True,
|
||||
) -> HistorySampleResult:
|
||||
"""
|
||||
Get the sample for the requested iteration or the latest before it
|
||||
If the iteration is not passed then get the latest event
|
||||
"""
|
||||
res = HistorySampleResult()
|
||||
if check_empty_data(self.es, company_id=company_id, event_type=self.event_type):
|
||||
return res
|
||||
|
||||
def init_state(state_: HistorySampleState):
|
||||
state_.task = task
|
||||
state_.metric = metric
|
||||
state_.navigate_current_metric = navigate_current_metric
|
||||
self._reset_variant_states(company_id=company_id, state=state_)
|
||||
|
||||
def validate_state(state_: HistorySampleState):
|
||||
if (
|
||||
state_.task != task
|
||||
or state_.navigate_current_metric != navigate_current_metric
|
||||
or (state_.navigate_current_metric and state_.metric != metric)
|
||||
):
|
||||
raise errors.bad_request.InvalidScrollId(
|
||||
"Task and metric stored in the state do not match the passed ones",
|
||||
scroll_id=state_.id,
|
||||
)
|
||||
# fix old variant states:
|
||||
for vs in state_.variant_states:
|
||||
if vs.metric is None:
|
||||
vs.metric = metric
|
||||
if refresh:
|
||||
self._reset_variant_states(company_id=company_id, state=state_)
|
||||
|
||||
state: HistorySampleState
|
||||
with self.cache_manager.get_or_create_state(
|
||||
state_id=state_id, init_state=init_state, validate_state=validate_state,
|
||||
) as state:
|
||||
res.scroll_id = state.id
|
||||
|
||||
var_state = first(
|
||||
vs
|
||||
for vs in state.variant_states
|
||||
if vs.name == variant and vs.metric == metric
|
||||
)
|
||||
if not var_state:
|
||||
return res
|
||||
|
||||
res.min_iteration = var_state.min_iteration
|
||||
res.max_iteration = var_state.max_iteration
|
||||
|
||||
must_conditions = [
|
||||
{"term": {"task": task}},
|
||||
{"term": {"metric": metric}},
|
||||
{"term": {"variant": variant}},
|
||||
*self._get_extra_conditions(),
|
||||
]
|
||||
if iteration is not None:
|
||||
must_conditions.append(
|
||||
{
|
||||
"range": {
|
||||
"iter": {"lte": iteration, "gte": var_state.min_iteration}
|
||||
}
|
||||
}
|
||||
)
|
||||
else:
|
||||
must_conditions.append(
|
||||
{"range": {"iter": {"gte": var_state.min_iteration}}}
|
||||
)
|
||||
|
||||
es_req = {
|
||||
"size": 1,
|
||||
"sort": {"iter": "desc"},
|
||||
"query": {"bool": {"must": must_conditions}},
|
||||
}
|
||||
|
||||
with translate_errors_context():
|
||||
es_res = search_company_events(
|
||||
self.es,
|
||||
company_id=company_id,
|
||||
event_type=self.event_type,
|
||||
body=es_req,
|
||||
)
|
||||
|
||||
hits = nested_get(es_res, ("hits", "hits"))
|
||||
if not hits:
|
||||
return res
|
||||
|
||||
self._fill_res_and_update_state(
|
||||
event=hits[0]["_source"], res=res, state=state
|
||||
)
|
||||
return res
|
||||
|
||||
def _reset_variant_states(self, company_id: str, state: HistorySampleState):
|
||||
metrics = self._get_metric_variant_iterations(
|
||||
company_id=company_id,
|
||||
task=state.task,
|
||||
metric=state.metric if state.navigate_current_metric else None,
|
||||
)
|
||||
state.variant_states = [
|
||||
VariantState(
|
||||
metric=metric,
|
||||
name=var_name,
|
||||
min_iteration=min_iter,
|
||||
max_iteration=max_iter,
|
||||
)
|
||||
for metric, variants in metrics.items()
|
||||
for var_name, min_iter, max_iter in variants
|
||||
]
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_min_max_aggs(self) -> Tuple[dict, Callable[[dict], Tuple[int, int]]]:
|
||||
pass
|
||||
|
||||
def _get_metric_variant_iterations(
|
||||
self, company_id: str, task: str, metric: str,
|
||||
) -> Mapping[str, Tuple[str, str, int, int]]:
|
||||
"""
|
||||
Return valid min and max iterations that the task reported events of the required type
|
||||
"""
|
||||
must = [
|
||||
{"term": {"task": task}},
|
||||
*self._get_extra_conditions(),
|
||||
]
|
||||
if metric is not None:
|
||||
must.append({"term": {"metric": metric}})
|
||||
query = {"bool": {"must": must}}
|
||||
|
||||
search_args = dict(
|
||||
es=self.es, company_id=company_id, event_type=self.event_type
|
||||
)
|
||||
max_metrics, max_variants = get_max_metric_and_variant_counts(
|
||||
query=query, **search_args
|
||||
)
|
||||
max_variants = int(max_variants // 2)
|
||||
min_max_aggs, get_min_max_data = self._get_min_max_aggs()
|
||||
es_req: dict = {
|
||||
"size": 0,
|
||||
"query": query,
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": max_metrics,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": max_variants,
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": min_max_aggs,
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
with translate_errors_context():
|
||||
es_res = search_company_events(body=es_req, **search_args,)
|
||||
|
||||
def get_variant_data(variant_bucket: dict) -> Tuple[str, int, int]:
|
||||
variant = variant_bucket["key"]
|
||||
min_iter, max_iter = get_min_max_data(variant_bucket)
|
||||
return variant, min_iter, max_iter
|
||||
|
||||
return {
|
||||
metric_bucket["key"]: [
|
||||
get_variant_data(variant_bucket)
|
||||
for variant_bucket in nested_get(metric_bucket, ("variants", "buckets"))
|
||||
]
|
||||
for metric_bucket in nested_get(
|
||||
es_res, ("aggregations", "metrics", "buckets")
|
||||
)
|
||||
}
|
@ -372,17 +372,18 @@ _definitions {
|
||||
type: string
|
||||
description: "Scroll ID to pass to the next calls to get_plot_sample or next_plot_sample"
|
||||
}
|
||||
event {
|
||||
type: object
|
||||
description: "Plot event"
|
||||
events {
|
||||
description: "Plot events"
|
||||
type: array
|
||||
items { type: object}
|
||||
}
|
||||
min_iteration {
|
||||
type: integer
|
||||
description: "minimal valid iteration for the variant"
|
||||
description: "minimal valid iteration for the metric"
|
||||
}
|
||||
max_iteration {
|
||||
type: integer
|
||||
description: "maximal valid iteration for the variant"
|
||||
description: "maximal valid iteration for the metric"
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -738,10 +739,10 @@ next_debug_image_sample {
|
||||
}
|
||||
get_plot_sample {
|
||||
"2.20": {
|
||||
description: "Return the plot per metric and variant for the provided iteration"
|
||||
description: "Return plots for the provided iteration"
|
||||
request {
|
||||
type: object
|
||||
required: [task, metric, variant]
|
||||
required: [task, metric]
|
||||
properties {
|
||||
task {
|
||||
description: "Task ID"
|
||||
@ -751,10 +752,6 @@ get_plot_sample {
|
||||
description: "Metric name"
|
||||
type: string
|
||||
}
|
||||
variant {
|
||||
description: "Metric variant"
|
||||
type: string
|
||||
}
|
||||
iteration {
|
||||
description: "The iteration to bring plot from. If not specified then the latest reported plot is retrieved"
|
||||
type: integer
|
||||
@ -786,7 +783,7 @@ get_plot_sample {
|
||||
}
|
||||
next_plot_sample {
|
||||
"2.20": {
|
||||
description: "Get the plot for the next variant for the same iteration or for the next iteration"
|
||||
description: "Get the plot for the next metric for the same iteration or for the next iteration"
|
||||
request {
|
||||
type: object
|
||||
required: [task, scroll_id]
|
||||
@ -801,8 +798,8 @@ next_plot_sample {
|
||||
}
|
||||
navigate_earlier {
|
||||
type: boolean
|
||||
description: """If set then get the either previous variant event from the current iteration or (if does not exist) the last variant event from the previous iteration.
|
||||
Otherwise next variant event from the current iteration or first variant event from the next iteration"""
|
||||
description: """If set then get the either previous metric events from the current iteration or (if does not exist) the last metric events from the previous iteration.
|
||||
Otherwise next metric events from the current iteration or first metric events from the next iteration"""
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -19,7 +19,6 @@ from apiserver.apimodels.events import (
|
||||
TaskMetricsRequest,
|
||||
LogEventsRequest,
|
||||
LogOrderEnum,
|
||||
GetHistorySampleRequest,
|
||||
NextHistorySampleRequest,
|
||||
MetricVariants as ApiMetrics,
|
||||
TaskPlotsRequest,
|
||||
@ -28,6 +27,7 @@ from apiserver.apimodels.events import (
|
||||
ClearScrollRequest,
|
||||
ClearTaskLogRequest,
|
||||
SingleValueMetricsRequest,
|
||||
GetVariantSampleRequest, GetMetricSamplesRequest,
|
||||
)
|
||||
from apiserver.bll.event import EventBLL
|
||||
from apiserver.bll.event.event_common import EventType, MetricVariants
|
||||
@ -827,9 +827,9 @@ def get_debug_images(call, company_id, request: MetricEventsRequest):
|
||||
@endpoint(
|
||||
"events.get_debug_image_sample",
|
||||
min_version="2.12",
|
||||
request_data_model=GetHistorySampleRequest,
|
||||
request_data_model=GetVariantSampleRequest,
|
||||
)
|
||||
def get_debug_image_sample(call, company_id, request: GetHistorySampleRequest):
|
||||
def get_debug_image_sample(call, company_id, request: GetVariantSampleRequest):
|
||||
task_or_model = _assert_task_or_model_exists(
|
||||
company_id, request.task, model_events=request.model_events,
|
||||
)[0]
|
||||
@ -866,17 +866,16 @@ def next_debug_image_sample(call, company_id, request: NextHistorySampleRequest)
|
||||
|
||||
|
||||
@endpoint(
|
||||
"events.get_plot_sample", request_data_model=GetHistorySampleRequest,
|
||||
"events.get_plot_sample", request_data_model=GetMetricSamplesRequest,
|
||||
)
|
||||
def get_plot_sample(call, company_id, request: GetHistorySampleRequest):
|
||||
def get_plot_sample(call, company_id, request: GetMetricSamplesRequest):
|
||||
task_or_model = _assert_task_or_model_exists(
|
||||
company_id, request.task, model_events=request.model_events,
|
||||
)[0]
|
||||
res = event_bll.plot_sample_history.get_sample_for_variant(
|
||||
res = event_bll.plot_sample_history.get_samples_for_metric(
|
||||
company_id=task_or_model.get_index_company(),
|
||||
task=request.task,
|
||||
metric=request.metric,
|
||||
variant=request.variant,
|
||||
iteration=request.iteration,
|
||||
refresh=request.refresh,
|
||||
state_id=request.scroll_id,
|
||||
|
@ -26,57 +26,55 @@ class TestTaskPlots(TestService):
|
||||
def test_get_plot_sample(self):
|
||||
task = self._temp_task()
|
||||
metric = "Metric1"
|
||||
variant = "Variant1"
|
||||
variants = ["Variant1", "Variant2"]
|
||||
|
||||
# test empty
|
||||
res = self.api.events.get_plot_sample(
|
||||
task=task, metric=metric, variant=variant
|
||||
)
|
||||
res = self.api.events.get_plot_sample(task=task, metric=metric)
|
||||
self.assertEqual(res.min_iteration, None)
|
||||
self.assertEqual(res.max_iteration, None)
|
||||
self.assertEqual(res.event, None)
|
||||
self.assertEqual(res.events, [])
|
||||
|
||||
# test existing events
|
||||
iterations = 10
|
||||
iterations = 5
|
||||
events = [
|
||||
self._create_task_event(
|
||||
task=task,
|
||||
iteration=n,
|
||||
iteration=n // len(variants),
|
||||
metric=metric,
|
||||
variant=variant,
|
||||
variant=variants[n % len(variants)],
|
||||
plot_str=f"Test plot str {n}",
|
||||
)
|
||||
for n in range(iterations)
|
||||
for n in range(iterations * len(variants))
|
||||
]
|
||||
self.send_batch(events)
|
||||
|
||||
# if iteration is not specified then return the event from the last one
|
||||
res = self.api.events.get_plot_sample(
|
||||
task=task, metric=metric, variant=variant
|
||||
)
|
||||
self._assertEqualEvent(res.event, events[-1])
|
||||
res = self.api.events.get_plot_sample(task=task, metric=metric)
|
||||
self._assertEqualEvents(res.events, events[-len(variants) :])
|
||||
self.assertEqual(res.max_iteration, iterations - 1)
|
||||
self.assertEqual(res.min_iteration, 0)
|
||||
self.assertTrue(res.scroll_id)
|
||||
|
||||
# else from the specific iteration
|
||||
iteration = 8
|
||||
iteration = 3
|
||||
res = self.api.events.get_plot_sample(
|
||||
task=task,
|
||||
metric=metric,
|
||||
variant=variant,
|
||||
iteration=iteration,
|
||||
scroll_id=res.scroll_id,
|
||||
task=task, metric=metric, iteration=iteration, scroll_id=res.scroll_id,
|
||||
)
|
||||
self._assertEqualEvents(
|
||||
res.events,
|
||||
events[iteration * len(variants) : (iteration + 1) * len(variants)],
|
||||
)
|
||||
self._assertEqualEvent(res.event, events[iteration])
|
||||
|
||||
def test_next_plot_sample(self):
|
||||
task = self._temp_task()
|
||||
metric1 = "Metric1"
|
||||
variant1 = "Variant1"
|
||||
metric2 = "Metric2"
|
||||
variant2 = "Variant2"
|
||||
metrics = [(metric1, variant1), (metric2, variant2)]
|
||||
metrics = [
|
||||
(metric1, "variant1"),
|
||||
(metric1, "variant2"),
|
||||
(metric2, "variant3"),
|
||||
(metric2, "variant4"),
|
||||
]
|
||||
# test existing events
|
||||
events = [
|
||||
self._create_task_event(
|
||||
@ -93,73 +91,73 @@ class TestTaskPlots(TestService):
|
||||
|
||||
# single metric navigation
|
||||
# init scroll
|
||||
res = self.api.events.get_plot_sample(
|
||||
task=task, metric=metric1, variant=variant1
|
||||
)
|
||||
self._assertEqualEvent(res.event, events[-2])
|
||||
res = self.api.events.get_plot_sample(task=task, metric=metric1)
|
||||
self._assertEqualEvents(res.events, events[-4:-2])
|
||||
|
||||
# navigate forwards
|
||||
res = self.api.events.next_plot_sample(
|
||||
task=task, scroll_id=res.scroll_id, navigate_earlier=False
|
||||
)
|
||||
self.assertEqual(res.event, None)
|
||||
self.assertEqual(res.events, [])
|
||||
|
||||
# navigate backwards
|
||||
res = self.api.events.next_plot_sample(
|
||||
task=task, scroll_id=res.scroll_id
|
||||
)
|
||||
self._assertEqualEvent(res.event, events[-4])
|
||||
res = self.api.events.next_plot_sample(
|
||||
task=task, scroll_id=res.scroll_id
|
||||
)
|
||||
self._assertEqualEvent(res.event, None)
|
||||
res = self.api.events.next_plot_sample(task=task, scroll_id=res.scroll_id)
|
||||
self._assertEqualEvents(res.events, events[-8:-6])
|
||||
res = self.api.events.next_plot_sample(task=task, scroll_id=res.scroll_id)
|
||||
self._assertEqualEvents(res.events, [])
|
||||
|
||||
# all metrics navigation
|
||||
# init scroll
|
||||
res = self.api.events.get_plot_sample(
|
||||
task=task, metric=metric1, variant=variant1, navigate_current_metric=False
|
||||
task=task, metric=metric1, navigate_current_metric=False
|
||||
)
|
||||
self._assertEqualEvent(res.event, events[-2])
|
||||
self._assertEqualEvents(res.events, events[-4:-2])
|
||||
|
||||
# navigate forwards
|
||||
res = self.api.events.next_plot_sample(
|
||||
task=task, scroll_id=res.scroll_id, navigate_earlier=False
|
||||
)
|
||||
self._assertEqualEvent(res.event, events[-1])
|
||||
self._assertEqualEvents(res.events, events[-2:])
|
||||
|
||||
# navigate backwards
|
||||
res = self.api.events.next_plot_sample(
|
||||
task=task, scroll_id=res.scroll_id
|
||||
)
|
||||
self._assertEqualEvent(res.event, events[-2])
|
||||
res = self.api.events.next_plot_sample(
|
||||
task=task, scroll_id=res.scroll_id
|
||||
)
|
||||
self._assertEqualEvent(res.event, events[-3])
|
||||
res = self.api.events.next_plot_sample(task=task, scroll_id=res.scroll_id)
|
||||
self._assertEqualEvents(res.events, events[-4:-2])
|
||||
res = self.api.events.next_plot_sample(task=task, scroll_id=res.scroll_id)
|
||||
self._assertEqualEvents(res.events, events[-6:-4])
|
||||
|
||||
# next_iteration
|
||||
res = self.api.events.next_plot_sample(
|
||||
task=task, scroll_id=res.scroll_id, next_iteration=True
|
||||
)
|
||||
self._assertEqualEvent(res.event, None)
|
||||
self._assertEqualEvents(res.events, [])
|
||||
res = self.api.events.next_plot_sample(
|
||||
task=task, scroll_id=res.scroll_id, next_iteration=True, navigate_earlier=False
|
||||
task=task,
|
||||
scroll_id=res.scroll_id,
|
||||
next_iteration=True,
|
||||
navigate_earlier=False,
|
||||
)
|
||||
self._assertEqualEvent(res.event, events[-2])
|
||||
self.assertEqual(res.event.iter, 1)
|
||||
self._assertEqualEvents(res.events, events[-4:-2])
|
||||
self.assertTrue(all(ev.iter == 1 for ev in res.events))
|
||||
res = self.api.events.next_plot_sample(
|
||||
task=task, scroll_id=res.scroll_id, next_iteration=True, navigate_earlier=False
|
||||
task=task,
|
||||
scroll_id=res.scroll_id,
|
||||
next_iteration=True,
|
||||
navigate_earlier=False,
|
||||
)
|
||||
self._assertEqualEvent(res.event, None)
|
||||
self._assertEqualEvents(res.events, [])
|
||||
|
||||
def _assertEqualEvent(self, ev1: dict, ev2: Optional[dict]):
|
||||
if ev2 is None:
|
||||
self.assertIsNone(ev1)
|
||||
return
|
||||
self.assertIsNotNone(ev1)
|
||||
def _assertEqualEvents(
|
||||
self, ev_source: Sequence[dict], ev_target: Sequence[Optional[dict]]
|
||||
):
|
||||
self.assertEqual(len(ev_source), len(ev_target))
|
||||
|
||||
def compare_event(ev1, ev2):
|
||||
for field in ("iter", "timestamp", "metric", "variant", "plot_str", "task"):
|
||||
self.assertEqual(ev1[field], ev2[field])
|
||||
|
||||
for e1, e2 in zip(ev_source, ev_target):
|
||||
compare_event(e1, e2)
|
||||
|
||||
def test_task_plots(self):
|
||||
task = self._temp_task()
|
||||
|
||||
@ -238,12 +236,15 @@ class TestTaskPlots(TestService):
|
||||
self.assertTrue(all(m.iterations == [] for m in res.metrics))
|
||||
return res.scroll_id
|
||||
|
||||
expected_variants = set((m, var) for m, vars_ in expected_metrics.items() for var in vars_)
|
||||
expected_variants = set(
|
||||
(m, var) for m, vars_ in expected_metrics.items() for var in vars_
|
||||
)
|
||||
for metric_data in res.metrics:
|
||||
self.assertEqual(len(metric_data.iterations), iterations)
|
||||
for it_data in metric_data.iterations:
|
||||
self.assertEqual(
|
||||
set((e.metric, e.variant) for e in it_data.events), expected_variants
|
||||
set((e.metric, e.variant) for e in it_data.events),
|
||||
expected_variants,
|
||||
)
|
||||
|
||||
return res.scroll_id
|
||||
@ -281,7 +282,7 @@ class TestTaskPlots(TestService):
|
||||
task=task,
|
||||
metric=metric,
|
||||
iterations=iterations,
|
||||
variants=len(variants)
|
||||
variants=len(variants),
|
||||
)
|
||||
|
||||
# test forward navigation
|
||||
|
@ -40,14 +40,20 @@ class TestWorkersService(TestService):
|
||||
def test_system_tags(self):
|
||||
test_worker = f"test_{uuid4().hex}"
|
||||
tag = uuid4().hex
|
||||
system_tag = uuid4().hex
|
||||
self.api.workers.register(
|
||||
worker=test_worker, tags=[tag], system_tags=[system_tag], timeout=5
|
||||
)
|
||||
|
||||
# system_tags support
|
||||
worker = self.api.workers.get_all(tags=[tag], system_tags=["Application"]).workers[0]
|
||||
worker = self.api.workers.get_all(tags=[tag], system_tags=[system_tag]).workers[
|
||||
0
|
||||
]
|
||||
self.assertEqual(worker.id, test_worker)
|
||||
self.assertEqual(worker.tags, [tag])
|
||||
self.assertEqual(worker.system_tags, ["Application"])
|
||||
self.assertEqual(worker.system_tags, [system_tag])
|
||||
|
||||
workers = self.api.workers.get_all(tags=[tag], system_tags=["-Application"]).workers
|
||||
workers = self.api.workers.get_all(tags=[tag], system_tags=[f"-{system_tag}"]).workers
|
||||
self.assertFalse(workers)
|
||||
|
||||
def test_filters(self):
|
||||
|
Loading…
Reference in New Issue
Block a user