mirror of
https://github.com/clearml/clearml-server
synced 2025-04-26 17:10:34 +00:00
Support returning multiple plots during history navigation
This commit is contained in:
parent
bc23f1b0cf
commit
97992b0d9e
@ -61,13 +61,20 @@ class MetricEventsRequest(Base):
|
|||||||
model_events: bool = BoolField()
|
model_events: bool = BoolField()
|
||||||
|
|
||||||
|
|
||||||
class TaskMetricVariant(Base):
|
class GetVariantSampleRequest(Base):
|
||||||
task: str = StringField(required=True)
|
task: str = StringField(required=True)
|
||||||
metric: str = StringField(required=True)
|
metric: str = StringField(required=True)
|
||||||
variant: str = StringField(required=True)
|
variant: str = StringField(required=True)
|
||||||
|
iteration: Optional[int] = IntField()
|
||||||
|
refresh: bool = BoolField(default=False)
|
||||||
|
scroll_id: Optional[str] = StringField()
|
||||||
|
navigate_current_metric: bool = BoolField(default=True)
|
||||||
|
model_events: bool = BoolField(default=False)
|
||||||
|
|
||||||
|
|
||||||
class GetHistorySampleRequest(TaskMetricVariant):
|
class GetMetricSamplesRequest(Base):
|
||||||
|
task: str = StringField(required=True)
|
||||||
|
metric: str = StringField(required=True)
|
||||||
iteration: Optional[int] = IntField()
|
iteration: Optional[int] = IntField()
|
||||||
refresh: bool = BoolField(default=False)
|
refresh: bool = BoolField(default=False)
|
||||||
scroll_id: Optional[str] = StringField()
|
scroll_id: Optional[str] = StringField()
|
||||||
|
@ -26,7 +26,7 @@ from apiserver.bll.event.event_common import (
|
|||||||
)
|
)
|
||||||
from apiserver.bll.event.events_iterator import EventsIterator, TaskEventsResult
|
from apiserver.bll.event.events_iterator import EventsIterator, TaskEventsResult
|
||||||
from apiserver.bll.event.history_debug_image_iterator import HistoryDebugImageIterator
|
from apiserver.bll.event.history_debug_image_iterator import HistoryDebugImageIterator
|
||||||
from apiserver.bll.event.history_plot_iterator import HistoryPlotIterator
|
from apiserver.bll.event.history_plots_iterator import HistoryPlotsIterator
|
||||||
from apiserver.bll.event.metric_debug_images_iterator import MetricDebugImagesIterator
|
from apiserver.bll.event.metric_debug_images_iterator import MetricDebugImagesIterator
|
||||||
from apiserver.bll.event.metric_plots_iterator import MetricPlotsIterator
|
from apiserver.bll.event.metric_plots_iterator import MetricPlotsIterator
|
||||||
from apiserver.bll.util import parallel_chunked_decorator
|
from apiserver.bll.util import parallel_chunked_decorator
|
||||||
@ -93,7 +93,7 @@ class EventBLL(object):
|
|||||||
es=self.es, redis=self.redis
|
es=self.es, redis=self.redis
|
||||||
)
|
)
|
||||||
self.plots_iterator = MetricPlotsIterator(es=self.es, redis=self.redis)
|
self.plots_iterator = MetricPlotsIterator(es=self.es, redis=self.redis)
|
||||||
self.plot_sample_history = HistoryPlotIterator(es=self.es, redis=self.redis)
|
self.plot_sample_history = HistoryPlotsIterator(es=self.es, redis=self.redis)
|
||||||
self.events_iterator = EventsIterator(es=self.es)
|
self.events_iterator = EventsIterator(es=self.es)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -1,21 +1,126 @@
|
|||||||
from typing import Sequence, Tuple, Callable
|
import operator
|
||||||
|
from operator import attrgetter
|
||||||
|
from typing import Sequence, Tuple, Optional, Mapping
|
||||||
|
|
||||||
|
import attr
|
||||||
|
from boltons.iterutils import first, bucketize
|
||||||
from elasticsearch import Elasticsearch
|
from elasticsearch import Elasticsearch
|
||||||
|
from jsonmodels.fields import StringField, IntField, BoolField, ListField
|
||||||
|
from jsonmodels.models import Base
|
||||||
from redis.client import StrictRedis
|
from redis.client import StrictRedis
|
||||||
|
|
||||||
from apiserver.utilities.dicts import nested_get
|
from apiserver.utilities.dicts import nested_get
|
||||||
from .event_common import EventType
|
from .event_common import (
|
||||||
from .history_sample_iterator import HistorySampleIterator, VariantState
|
EventType,
|
||||||
|
EventSettings,
|
||||||
|
check_empty_data,
|
||||||
|
search_company_events,
|
||||||
|
get_max_metric_and_variant_counts,
|
||||||
|
)
|
||||||
|
from apiserver.apimodels import JsonSerializableMixin
|
||||||
|
from apiserver.bll.redis_cache_manager import RedisCacheManager
|
||||||
|
from apiserver.apierrors import errors
|
||||||
|
|
||||||
|
|
||||||
class HistoryDebugImageIterator(HistorySampleIterator):
|
class VariantState(Base):
|
||||||
|
name: str = StringField(required=True)
|
||||||
|
metric: str = StringField(default=None)
|
||||||
|
min_iteration: int = IntField()
|
||||||
|
max_iteration: int = IntField()
|
||||||
|
|
||||||
|
|
||||||
|
class DebugImageSampleState(Base, JsonSerializableMixin):
|
||||||
|
id: str = StringField(required=True)
|
||||||
|
iteration: int = IntField()
|
||||||
|
variant: str = StringField()
|
||||||
|
task: str = StringField()
|
||||||
|
metric: str = StringField()
|
||||||
|
variant_states: Sequence[VariantState] = ListField([VariantState])
|
||||||
|
warning: str = StringField()
|
||||||
|
navigate_current_metric = BoolField(default=True)
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s(auto_attribs=True)
|
||||||
|
class VariantSampleResult(object):
|
||||||
|
scroll_id: str = None
|
||||||
|
event: dict = None
|
||||||
|
min_iteration: int = None
|
||||||
|
max_iteration: int = None
|
||||||
|
|
||||||
|
|
||||||
|
class HistoryDebugImageIterator:
|
||||||
|
event_type = EventType.metrics_image
|
||||||
|
|
||||||
def __init__(self, redis: StrictRedis, es: Elasticsearch):
|
def __init__(self, redis: StrictRedis, es: Elasticsearch):
|
||||||
super().__init__(redis, es, EventType.metrics_image)
|
self.es = es
|
||||||
|
self.cache_manager = RedisCacheManager(
|
||||||
|
state_class=DebugImageSampleState,
|
||||||
|
redis=redis,
|
||||||
|
expiration_interval=EventSettings.state_expiration_sec,
|
||||||
|
)
|
||||||
|
|
||||||
def _get_extra_conditions(self) -> Sequence[dict]:
|
def get_next_sample(
|
||||||
return [{"exists": {"field": "url"}}]
|
self,
|
||||||
|
company_id: str,
|
||||||
|
task: str,
|
||||||
|
state_id: str,
|
||||||
|
navigate_earlier: bool,
|
||||||
|
next_iteration: bool,
|
||||||
|
) -> VariantSampleResult:
|
||||||
|
"""
|
||||||
|
Get the sample for next/prev variant on the current iteration
|
||||||
|
If does not exist then try getting sample for the first/last variant from next/prev iteration
|
||||||
|
"""
|
||||||
|
res = VariantSampleResult(scroll_id=state_id)
|
||||||
|
state = self.cache_manager.get_state(state_id)
|
||||||
|
if not state or state.task != task:
|
||||||
|
raise errors.bad_request.InvalidScrollId(scroll_id=state_id)
|
||||||
|
|
||||||
def _get_variants_conditions(self, variants: Sequence[VariantState]) -> dict:
|
if check_empty_data(self.es, company_id=company_id, event_type=self.event_type):
|
||||||
|
return res
|
||||||
|
|
||||||
|
if next_iteration:
|
||||||
|
event = self._get_next_for_another_iteration(
|
||||||
|
company_id=company_id, navigate_earlier=navigate_earlier, state=state
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# noinspection PyArgumentList
|
||||||
|
event = first(
|
||||||
|
f(company_id=company_id, navigate_earlier=navigate_earlier, state=state)
|
||||||
|
for f in (
|
||||||
|
self._get_next_for_current_iteration,
|
||||||
|
self._get_next_for_another_iteration,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if not event:
|
||||||
|
return res
|
||||||
|
|
||||||
|
self._fill_res_and_update_state(event=event, res=res, state=state)
|
||||||
|
self.cache_manager.set_state(state=state)
|
||||||
|
return res
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _fill_res_and_update_state(
|
||||||
|
event: dict, res: VariantSampleResult, state: DebugImageSampleState
|
||||||
|
):
|
||||||
|
state.variant = event["variant"]
|
||||||
|
state.metric = event["metric"]
|
||||||
|
state.iteration = event["iter"]
|
||||||
|
res.event = event
|
||||||
|
var_state = first(
|
||||||
|
vs
|
||||||
|
for vs in state.variant_states
|
||||||
|
if vs.name == state.variant and vs.metric == state.metric
|
||||||
|
)
|
||||||
|
if var_state:
|
||||||
|
res.min_iteration = var_state.min_iteration
|
||||||
|
res.max_iteration = var_state.max_iteration
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_metric_conditions(variants: Sequence[VariantState]) -> dict:
|
||||||
|
metrics = bucketize(variants, key=attrgetter("metric"))
|
||||||
|
|
||||||
|
def _get_variants_conditions(metric_variants: Sequence[VariantState]) -> dict:
|
||||||
variants_conditions = [
|
variants_conditions = [
|
||||||
{
|
{
|
||||||
"bool": {
|
"bool": {
|
||||||
@ -25,32 +130,326 @@ class HistoryDebugImageIterator(HistorySampleIterator):
|
|||||||
]
|
]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for v in variants
|
for v in metric_variants
|
||||||
]
|
]
|
||||||
return {"bool": {"should": variants_conditions}}
|
return {"bool": {"should": variants_conditions}}
|
||||||
|
|
||||||
def _process_event(self, event: dict) -> dict:
|
metrics_conditions = [
|
||||||
return event
|
{
|
||||||
|
"bool": {
|
||||||
|
"must": [
|
||||||
|
{"term": {"metric": metric}},
|
||||||
|
_get_variants_conditions(metric_variants),
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for metric, metric_variants in metrics.items()
|
||||||
|
]
|
||||||
|
return {"bool": {"should": metrics_conditions}}
|
||||||
|
|
||||||
def _get_min_max_aggs(self) -> Tuple[dict, Callable[[dict], Tuple[int, int]]]:
|
def _get_next_for_current_iteration(
|
||||||
# The min iteration is the lowest iteration that contains non-recycled image url
|
self, company_id: str, navigate_earlier: bool, state: DebugImageSampleState
|
||||||
aggs = {
|
) -> Optional[dict]:
|
||||||
|
"""
|
||||||
|
Get the sample for next (if navigate_earlier is False) or previous variant sorted by name for the same iteration
|
||||||
|
Only variants for which the iteration falls into their valid range are considered
|
||||||
|
Return None if no such variant or sample is found
|
||||||
|
"""
|
||||||
|
if state.navigate_current_metric:
|
||||||
|
variants = [
|
||||||
|
var_state
|
||||||
|
for var_state in state.variant_states
|
||||||
|
if var_state.metric == state.metric
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
variants = state.variant_states
|
||||||
|
|
||||||
|
cmp = operator.lt if navigate_earlier else operator.gt
|
||||||
|
variants = [
|
||||||
|
var_state
|
||||||
|
for var_state in variants
|
||||||
|
if cmp((var_state.metric, var_state.name), (state.metric, state.variant))
|
||||||
|
and var_state.min_iteration <= state.iteration
|
||||||
|
]
|
||||||
|
if not variants:
|
||||||
|
return
|
||||||
|
|
||||||
|
must_conditions = [
|
||||||
|
{"term": {"task": state.task}},
|
||||||
|
{"term": {"iter": state.iteration}},
|
||||||
|
self._get_metric_conditions(variants),
|
||||||
|
{"exists": {"field": "url"}},
|
||||||
|
]
|
||||||
|
order = "desc" if navigate_earlier else "asc"
|
||||||
|
es_req = {
|
||||||
|
"size": 1,
|
||||||
|
"sort": [{"metric": order}, {"variant": order}],
|
||||||
|
"query": {"bool": {"must": must_conditions}},
|
||||||
|
}
|
||||||
|
|
||||||
|
es_res = search_company_events(
|
||||||
|
self.es,
|
||||||
|
company_id=company_id,
|
||||||
|
event_type=self.event_type,
|
||||||
|
body=es_req,
|
||||||
|
)
|
||||||
|
|
||||||
|
hits = nested_get(es_res, ("hits", "hits"))
|
||||||
|
if not hits:
|
||||||
|
return
|
||||||
|
|
||||||
|
return hits[0]["_source"]
|
||||||
|
|
||||||
|
def _get_next_for_another_iteration(
|
||||||
|
self, company_id: str, navigate_earlier: bool, state: DebugImageSampleState
|
||||||
|
) -> Optional[dict]:
|
||||||
|
"""
|
||||||
|
Get the sample for the first variant for the next iteration (if navigate_earlier is set to False)
|
||||||
|
or from the last variant for the previous iteration (otherwise)
|
||||||
|
The variants for which the sample falls in invalid range are discarded
|
||||||
|
If no suitable sample is found then None is returned
|
||||||
|
"""
|
||||||
|
if state.navigate_current_metric:
|
||||||
|
variants = [
|
||||||
|
var_state
|
||||||
|
for var_state in state.variant_states
|
||||||
|
if var_state.metric == state.metric
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
variants = state.variant_states
|
||||||
|
|
||||||
|
if navigate_earlier:
|
||||||
|
range_operator = "lt"
|
||||||
|
order = "desc"
|
||||||
|
variants = [
|
||||||
|
var_state
|
||||||
|
for var_state in variants
|
||||||
|
if var_state.min_iteration < state.iteration
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
range_operator = "gt"
|
||||||
|
order = "asc"
|
||||||
|
variants = variants
|
||||||
|
|
||||||
|
if not variants:
|
||||||
|
return
|
||||||
|
|
||||||
|
must_conditions = [
|
||||||
|
{"term": {"task": state.task}},
|
||||||
|
self._get_metric_conditions(variants),
|
||||||
|
{"range": {"iter": {range_operator: state.iteration}}},
|
||||||
|
{"exists": {"field": "url"}},
|
||||||
|
]
|
||||||
|
es_req = {
|
||||||
|
"size": 1,
|
||||||
|
"sort": [{"iter": order}, {"metric": order}, {"variant": order}],
|
||||||
|
"query": {"bool": {"must": must_conditions}},
|
||||||
|
}
|
||||||
|
es_res = search_company_events(
|
||||||
|
self.es,
|
||||||
|
company_id=company_id,
|
||||||
|
event_type=self.event_type,
|
||||||
|
body=es_req,
|
||||||
|
)
|
||||||
|
|
||||||
|
hits = nested_get(es_res, ("hits", "hits"))
|
||||||
|
if not hits:
|
||||||
|
return
|
||||||
|
|
||||||
|
return hits[0]["_source"]
|
||||||
|
|
||||||
|
def get_sample_for_variant(
|
||||||
|
self,
|
||||||
|
company_id: str,
|
||||||
|
task: str,
|
||||||
|
metric: str,
|
||||||
|
variant: str,
|
||||||
|
iteration: Optional[int] = None,
|
||||||
|
refresh: bool = False,
|
||||||
|
state_id: str = None,
|
||||||
|
navigate_current_metric: bool = True,
|
||||||
|
) -> VariantSampleResult:
|
||||||
|
"""
|
||||||
|
Get the sample for the requested iteration or the latest before it
|
||||||
|
If the iteration is not passed then get the latest event
|
||||||
|
"""
|
||||||
|
res = VariantSampleResult()
|
||||||
|
if check_empty_data(self.es, company_id=company_id, event_type=self.event_type):
|
||||||
|
return res
|
||||||
|
|
||||||
|
def init_state(state_: DebugImageSampleState):
|
||||||
|
state_.task = task
|
||||||
|
state_.metric = metric
|
||||||
|
state_.navigate_current_metric = navigate_current_metric
|
||||||
|
self._reset_variant_states(company_id=company_id, state=state_)
|
||||||
|
|
||||||
|
def validate_state(state_: DebugImageSampleState):
|
||||||
|
if (
|
||||||
|
state_.task != task
|
||||||
|
or state_.navigate_current_metric != navigate_current_metric
|
||||||
|
or (state_.navigate_current_metric and state_.metric != metric)
|
||||||
|
):
|
||||||
|
raise errors.bad_request.InvalidScrollId(
|
||||||
|
"Task and metric stored in the state do not match the passed ones",
|
||||||
|
scroll_id=state_.id,
|
||||||
|
)
|
||||||
|
# fix old variant states:
|
||||||
|
for vs in state_.variant_states:
|
||||||
|
if vs.metric is None:
|
||||||
|
vs.metric = metric
|
||||||
|
if refresh:
|
||||||
|
self._reset_variant_states(company_id=company_id, state=state_)
|
||||||
|
|
||||||
|
state: DebugImageSampleState
|
||||||
|
with self.cache_manager.get_or_create_state(
|
||||||
|
state_id=state_id, init_state=init_state, validate_state=validate_state,
|
||||||
|
) as state:
|
||||||
|
res.scroll_id = state.id
|
||||||
|
|
||||||
|
var_state = first(
|
||||||
|
vs
|
||||||
|
for vs in state.variant_states
|
||||||
|
if vs.name == variant and vs.metric == metric
|
||||||
|
)
|
||||||
|
if not var_state:
|
||||||
|
return res
|
||||||
|
|
||||||
|
res.min_iteration = var_state.min_iteration
|
||||||
|
res.max_iteration = var_state.max_iteration
|
||||||
|
|
||||||
|
must_conditions = [
|
||||||
|
{"term": {"task": task}},
|
||||||
|
{"term": {"metric": metric}},
|
||||||
|
{"term": {"variant": variant}},
|
||||||
|
{"exists": {"field": "url"}},
|
||||||
|
]
|
||||||
|
if iteration is not None:
|
||||||
|
must_conditions.append(
|
||||||
|
{
|
||||||
|
"range": {
|
||||||
|
"iter": {"lte": iteration, "gte": var_state.min_iteration}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
must_conditions.append(
|
||||||
|
{"range": {"iter": {"gte": var_state.min_iteration}}}
|
||||||
|
)
|
||||||
|
|
||||||
|
es_req = {
|
||||||
|
"size": 1,
|
||||||
|
"sort": {"iter": "desc"},
|
||||||
|
"query": {"bool": {"must": must_conditions}},
|
||||||
|
}
|
||||||
|
|
||||||
|
es_res = search_company_events(
|
||||||
|
self.es,
|
||||||
|
company_id=company_id,
|
||||||
|
event_type=self.event_type,
|
||||||
|
body=es_req,
|
||||||
|
)
|
||||||
|
|
||||||
|
hits = nested_get(es_res, ("hits", "hits"))
|
||||||
|
if not hits:
|
||||||
|
return res
|
||||||
|
|
||||||
|
self._fill_res_and_update_state(
|
||||||
|
event=hits[0]["_source"], res=res, state=state
|
||||||
|
)
|
||||||
|
return res
|
||||||
|
|
||||||
|
def _reset_variant_states(self, company_id: str, state: DebugImageSampleState):
|
||||||
|
metrics = self._get_metric_variant_iterations(
|
||||||
|
company_id=company_id,
|
||||||
|
task=state.task,
|
||||||
|
metric=state.metric if state.navigate_current_metric else None,
|
||||||
|
)
|
||||||
|
state.variant_states = [
|
||||||
|
VariantState(
|
||||||
|
metric=metric,
|
||||||
|
name=var_name,
|
||||||
|
min_iteration=min_iter,
|
||||||
|
max_iteration=max_iter,
|
||||||
|
)
|
||||||
|
for metric, variants in metrics.items()
|
||||||
|
for var_name, min_iter, max_iter in variants
|
||||||
|
]
|
||||||
|
|
||||||
|
def _get_metric_variant_iterations(
|
||||||
|
self, company_id: str, task: str, metric: str,
|
||||||
|
) -> Mapping[str, Sequence[Tuple[str, int, int]]]:
|
||||||
|
"""
|
||||||
|
Return valid min and max iterations that the task reported events of the required type
|
||||||
|
"""
|
||||||
|
must = [
|
||||||
|
{"term": {"task": task}},
|
||||||
|
{"exists": {"field": "url"}},
|
||||||
|
]
|
||||||
|
if metric is not None:
|
||||||
|
must.append({"term": {"metric": metric}})
|
||||||
|
query = {"bool": {"must": must}}
|
||||||
|
|
||||||
|
search_args = dict(
|
||||||
|
es=self.es, company_id=company_id, event_type=self.event_type,
|
||||||
|
)
|
||||||
|
max_metrics, max_variants = get_max_metric_and_variant_counts(
|
||||||
|
query=query, **search_args
|
||||||
|
)
|
||||||
|
max_variants = int(max_variants // 2)
|
||||||
|
es_req: dict = {
|
||||||
|
"size": 0,
|
||||||
|
"query": query,
|
||||||
|
"aggs": {
|
||||||
|
"metrics": {
|
||||||
|
"terms": {
|
||||||
|
"field": "metric",
|
||||||
|
"size": max_metrics,
|
||||||
|
"order": {"_key": "asc"},
|
||||||
|
},
|
||||||
|
"aggs": {
|
||||||
|
"variants": {
|
||||||
|
"terms": {
|
||||||
|
"field": "variant",
|
||||||
|
"size": max_variants,
|
||||||
|
"order": {"_key": "asc"},
|
||||||
|
},
|
||||||
|
"aggs": {
|
||||||
"last_iter": {"max": {"field": "iter"}},
|
"last_iter": {"max": {"field": "iter"}},
|
||||||
"urls": {
|
"urls": {
|
||||||
# group by urls and choose the minimal iteration
|
# group by urls and choose the minimal iteration
|
||||||
# from all the maximal iterations per url
|
# from all the maximal iterations per url
|
||||||
"terms": {"field": "url", "order": {"max_iter": "asc"}, "size": 1},
|
"terms": {
|
||||||
|
"field": "url",
|
||||||
|
"order": {"max_iter": "asc"},
|
||||||
|
"size": 1,
|
||||||
|
},
|
||||||
"aggs": {
|
"aggs": {
|
||||||
# find max iteration for each url
|
# find max iteration for each url
|
||||||
"max_iter": {"max": {"field": "iter"}}
|
"max_iter": {"max": {"field": "iter"}}
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_min_max_data(variant_bucket: dict) -> Tuple[int, int]:
|
es_res = search_company_events(body=es_req, **search_args)
|
||||||
|
|
||||||
|
def get_variant_data(variant_bucket: dict) -> Tuple[str, int, int]:
|
||||||
|
variant = variant_bucket["key"]
|
||||||
urls = nested_get(variant_bucket, ("urls", "buckets"))
|
urls = nested_get(variant_bucket, ("urls", "buckets"))
|
||||||
min_iter = int(urls[0]["max_iter"]["value"])
|
min_iter = int(urls[0]["max_iter"]["value"])
|
||||||
max_iter = int(variant_bucket["last_iter"]["value"])
|
max_iter = int(variant_bucket["last_iter"]["value"])
|
||||||
return min_iter, max_iter
|
return variant, min_iter, max_iter
|
||||||
|
|
||||||
return aggs, get_min_max_data
|
return {
|
||||||
|
metric_bucket["key"]: [
|
||||||
|
get_variant_data(variant_bucket)
|
||||||
|
for variant_bucket in nested_get(metric_bucket, ("variants", "buckets"))
|
||||||
|
]
|
||||||
|
for metric_bucket in nested_get(
|
||||||
|
es_res, ("aggregations", "metrics", "buckets")
|
||||||
|
)
|
||||||
|
}
|
||||||
|
@ -1,36 +0,0 @@
|
|||||||
from typing import Sequence, Tuple, Callable
|
|
||||||
|
|
||||||
from elasticsearch import Elasticsearch
|
|
||||||
from redis.client import StrictRedis
|
|
||||||
|
|
||||||
from .event_common import EventType, uncompress_plot
|
|
||||||
from .history_sample_iterator import HistorySampleIterator, VariantState
|
|
||||||
|
|
||||||
|
|
||||||
class HistoryPlotIterator(HistorySampleIterator):
|
|
||||||
def __init__(self, redis: StrictRedis, es: Elasticsearch):
|
|
||||||
super().__init__(redis, es, EventType.metrics_plot)
|
|
||||||
|
|
||||||
def _get_extra_conditions(self) -> Sequence[dict]:
|
|
||||||
return []
|
|
||||||
|
|
||||||
def _get_variants_conditions(self, variants: Sequence[VariantState]) -> dict:
|
|
||||||
return {"terms": {"variant": [v.name for v in variants]}}
|
|
||||||
|
|
||||||
def _process_event(self, event: dict) -> dict:
|
|
||||||
uncompress_plot(event)
|
|
||||||
return event
|
|
||||||
|
|
||||||
def _get_min_max_aggs(self) -> Tuple[dict, Callable[[dict], Tuple[int, int]]]:
|
|
||||||
# The min iteration is the lowest iteration that contains non-recycled image url
|
|
||||||
aggs = {
|
|
||||||
"last_iter": {"max": {"field": "iter"}},
|
|
||||||
"first_iter": {"min": {"field": "iter"}},
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_min_max_data(variant_bucket: dict) -> Tuple[int, int]:
|
|
||||||
min_iter = int(variant_bucket["first_iter"]["value"])
|
|
||||||
max_iter = int(variant_bucket["last_iter"]["value"])
|
|
||||||
return min_iter, max_iter
|
|
||||||
|
|
||||||
return aggs, get_min_max_data
|
|
316
apiserver/bll/event/history_plots_iterator.py
Normal file
316
apiserver/bll/event/history_plots_iterator.py
Normal file
@ -0,0 +1,316 @@
|
|||||||
|
from typing import Sequence, Tuple, Optional, Mapping
|
||||||
|
|
||||||
|
import attr
|
||||||
|
from boltons.iterutils import first
|
||||||
|
from elasticsearch import Elasticsearch
|
||||||
|
from jsonmodels.fields import StringField, IntField, ListField, BoolField
|
||||||
|
from jsonmodels.models import Base
|
||||||
|
from redis.client import StrictRedis
|
||||||
|
|
||||||
|
from .event_common import (
|
||||||
|
EventType,
|
||||||
|
uncompress_plot,
|
||||||
|
EventSettings,
|
||||||
|
check_empty_data,
|
||||||
|
search_company_events,
|
||||||
|
)
|
||||||
|
from apiserver.apimodels import JsonSerializableMixin
|
||||||
|
from apiserver.utilities.dicts import nested_get
|
||||||
|
from apiserver.bll.redis_cache_manager import RedisCacheManager
|
||||||
|
from apiserver.apierrors import errors
|
||||||
|
|
||||||
|
|
||||||
|
class MetricState(Base):
|
||||||
|
name: str = StringField(default=None)
|
||||||
|
min_iteration: int = IntField()
|
||||||
|
max_iteration: int = IntField()
|
||||||
|
|
||||||
|
|
||||||
|
class PlotsSampleState(Base, JsonSerializableMixin):
|
||||||
|
id: str = StringField(required=True)
|
||||||
|
iteration: int = IntField()
|
||||||
|
task: str = StringField()
|
||||||
|
metric: str = StringField()
|
||||||
|
metric_states: Sequence[MetricState] = ListField([MetricState])
|
||||||
|
warning: str = StringField()
|
||||||
|
navigate_current_metric = BoolField(default=True)
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s(auto_attribs=True)
|
||||||
|
class MetricSamplesResult(object):
|
||||||
|
scroll_id: str = None
|
||||||
|
events: list = []
|
||||||
|
min_iteration: int = None
|
||||||
|
max_iteration: int = None
|
||||||
|
|
||||||
|
|
||||||
|
class HistoryPlotsIterator:
|
||||||
|
event_type = EventType.metrics_plot
|
||||||
|
|
||||||
|
def __init__(self, redis: StrictRedis, es: Elasticsearch):
|
||||||
|
self.es = es
|
||||||
|
self.cache_manager = RedisCacheManager(
|
||||||
|
state_class=PlotsSampleState,
|
||||||
|
redis=redis,
|
||||||
|
expiration_interval=EventSettings.state_expiration_sec,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_next_sample(
|
||||||
|
self,
|
||||||
|
company_id: str,
|
||||||
|
task: str,
|
||||||
|
state_id: str,
|
||||||
|
navigate_earlier: bool,
|
||||||
|
next_iteration: bool,
|
||||||
|
) -> MetricSamplesResult:
|
||||||
|
"""
|
||||||
|
Get the samples for next/prev metric on the current iteration
|
||||||
|
If does not exist then try getting sample for the first/last metric from next/prev iteration
|
||||||
|
"""
|
||||||
|
res = MetricSamplesResult(scroll_id=state_id)
|
||||||
|
state = self.cache_manager.get_state(state_id)
|
||||||
|
if not state or state.task != task:
|
||||||
|
raise errors.bad_request.InvalidScrollId(scroll_id=state_id)
|
||||||
|
|
||||||
|
if check_empty_data(self.es, company_id=company_id, event_type=self.event_type):
|
||||||
|
return res
|
||||||
|
|
||||||
|
if navigate_earlier:
|
||||||
|
range_operator = "lt"
|
||||||
|
order = "desc"
|
||||||
|
else:
|
||||||
|
range_operator = "gt"
|
||||||
|
order = "asc"
|
||||||
|
|
||||||
|
must_conditions = [
|
||||||
|
{"term": {"task": state.task}},
|
||||||
|
]
|
||||||
|
if state.navigate_current_metric:
|
||||||
|
must_conditions.append({"term": {"metric": state.metric}})
|
||||||
|
|
||||||
|
next_iteration_condition = {
|
||||||
|
"range": {"iter": {range_operator: state.iteration}}
|
||||||
|
}
|
||||||
|
if next_iteration or state.navigate_current_metric:
|
||||||
|
must_conditions.append(next_iteration_condition)
|
||||||
|
else:
|
||||||
|
next_metric_condition = {
|
||||||
|
"bool": {
|
||||||
|
"must": [
|
||||||
|
{"term": {"iter": state.iteration}},
|
||||||
|
{"range": {"metric": {range_operator: state.metric}}},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
must_conditions.append(
|
||||||
|
{"bool": {"should": [next_metric_condition, next_iteration_condition]}}
|
||||||
|
)
|
||||||
|
|
||||||
|
events = self._get_metric_events_for_condition(
|
||||||
|
company_id=company_id,
|
||||||
|
task=state.task,
|
||||||
|
order=order,
|
||||||
|
must_conditions=must_conditions,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not events:
|
||||||
|
return res
|
||||||
|
|
||||||
|
self._fill_res_and_update_state(events=events, res=res, state=state)
|
||||||
|
self.cache_manager.set_state(state=state)
|
||||||
|
return res
|
||||||
|
|
||||||
|
def get_samples_for_metric(
|
||||||
|
self,
|
||||||
|
company_id: str,
|
||||||
|
task: str,
|
||||||
|
metric: str,
|
||||||
|
iteration: Optional[int] = None,
|
||||||
|
refresh: bool = False,
|
||||||
|
state_id: str = None,
|
||||||
|
navigate_current_metric: bool = True,
|
||||||
|
) -> MetricSamplesResult:
|
||||||
|
"""
|
||||||
|
Get the sample for the requested iteration or the latest before it
|
||||||
|
If the iteration is not passed then get the latest event
|
||||||
|
"""
|
||||||
|
res = MetricSamplesResult()
|
||||||
|
if check_empty_data(self.es, company_id=company_id, event_type=self.event_type):
|
||||||
|
return res
|
||||||
|
|
||||||
|
def init_state(state_: PlotsSampleState):
|
||||||
|
state_.task = task
|
||||||
|
state_.metric = metric
|
||||||
|
state_.navigate_current_metric = navigate_current_metric
|
||||||
|
self._reset_metric_states(company_id=company_id, state=state_)
|
||||||
|
|
||||||
|
def validate_state(state_: PlotsSampleState):
|
||||||
|
if (
|
||||||
|
state_.task != task
|
||||||
|
or state_.navigate_current_metric != navigate_current_metric
|
||||||
|
or (state_.navigate_current_metric and state_.metric != metric)
|
||||||
|
):
|
||||||
|
raise errors.bad_request.InvalidScrollId(
|
||||||
|
"Task and metric stored in the state do not match the passed ones",
|
||||||
|
scroll_id=state_.id,
|
||||||
|
)
|
||||||
|
if refresh:
|
||||||
|
self._reset_metric_states(company_id=company_id, state=state_)
|
||||||
|
|
||||||
|
state: PlotsSampleState
|
||||||
|
with self.cache_manager.get_or_create_state(
|
||||||
|
state_id=state_id, init_state=init_state, validate_state=validate_state,
|
||||||
|
) as state:
|
||||||
|
res.scroll_id = state.id
|
||||||
|
|
||||||
|
metric_state = first(ms for ms in state.metric_states if ms.name == metric)
|
||||||
|
if not metric_state:
|
||||||
|
return res
|
||||||
|
|
||||||
|
res.min_iteration = metric_state.min_iteration
|
||||||
|
res.max_iteration = metric_state.max_iteration
|
||||||
|
|
||||||
|
must_conditions = [
|
||||||
|
{"term": {"task": task}},
|
||||||
|
{"term": {"metric": metric}},
|
||||||
|
]
|
||||||
|
if iteration is not None:
|
||||||
|
must_conditions.append({"range": {"iter": {"lte": iteration}}})
|
||||||
|
|
||||||
|
events = self._get_metric_events_for_condition(
|
||||||
|
company_id=company_id,
|
||||||
|
task=state.task,
|
||||||
|
order="desc",
|
||||||
|
must_conditions=must_conditions,
|
||||||
|
)
|
||||||
|
if not events:
|
||||||
|
return res
|
||||||
|
|
||||||
|
self._fill_res_and_update_state(events=events, res=res, state=state)
|
||||||
|
return res
|
||||||
|
|
||||||
|
def _reset_metric_states(self, company_id: str, state: PlotsSampleState):
|
||||||
|
metrics = self._get_metric_iterations(
|
||||||
|
company_id=company_id,
|
||||||
|
task=state.task,
|
||||||
|
metric=state.metric if state.navigate_current_metric else None,
|
||||||
|
)
|
||||||
|
state.metric_states = [
|
||||||
|
MetricState(name=metric, min_iteration=min_iter, max_iteration=max_iter)
|
||||||
|
for metric, (min_iter, max_iter) in metrics.items()
|
||||||
|
]
|
||||||
|
|
||||||
|
def _get_metric_iterations(
|
||||||
|
self, company_id: str, task: str, metric: str,
|
||||||
|
) -> Mapping[str, Tuple[int, int]]:
|
||||||
|
"""
|
||||||
|
Return valid min and max iterations that the task reported events of the required type
|
||||||
|
"""
|
||||||
|
must = [
|
||||||
|
{"term": {"task": task}},
|
||||||
|
]
|
||||||
|
if metric is not None:
|
||||||
|
must.append({"term": {"metric": metric}})
|
||||||
|
query = {"bool": {"must": must}}
|
||||||
|
|
||||||
|
es_req: dict = {
|
||||||
|
"size": 0,
|
||||||
|
"query": query,
|
||||||
|
"aggs": {
|
||||||
|
"metrics": {
|
||||||
|
"terms": {
|
||||||
|
"field": "metric",
|
||||||
|
"size": 5000,
|
||||||
|
"order": {"_key": "asc"},
|
||||||
|
},
|
||||||
|
"aggs": {
|
||||||
|
"last_iter": {"max": {"field": "iter"}},
|
||||||
|
"first_iter": {"min": {"field": "iter"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
es_res = search_company_events(
|
||||||
|
body=es_req,
|
||||||
|
es=self.es,
|
||||||
|
company_id=company_id,
|
||||||
|
event_type=self.event_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
metric_bucket["key"]: (
|
||||||
|
int(metric_bucket["first_iter"]["value"]),
|
||||||
|
int(metric_bucket["last_iter"]["value"]),
|
||||||
|
)
|
||||||
|
for metric_bucket in nested_get(
|
||||||
|
es_res, ("aggregations", "metrics", "buckets")
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _fill_res_and_update_state(
|
||||||
|
events: Sequence[dict], res: MetricSamplesResult, state: PlotsSampleState
|
||||||
|
):
|
||||||
|
for event in events:
|
||||||
|
uncompress_plot(event)
|
||||||
|
state.metric = events[0]["metric"]
|
||||||
|
state.iteration = events[0]["iter"]
|
||||||
|
res.events = events
|
||||||
|
metric_state = first(
|
||||||
|
ms for ms in state.metric_states if ms.name == state.metric
|
||||||
|
)
|
||||||
|
if metric_state:
|
||||||
|
res.min_iteration = metric_state.min_iteration
|
||||||
|
res.max_iteration = metric_state.max_iteration
|
||||||
|
|
||||||
|
def _get_metric_events_for_condition(
|
||||||
|
self, company_id: str, task: str, order: str, must_conditions: Sequence
|
||||||
|
) -> Sequence:
|
||||||
|
es_req = {
|
||||||
|
"size": 0,
|
||||||
|
"query": {"bool": {"must": must_conditions}},
|
||||||
|
"aggs": {
|
||||||
|
"iters": {
|
||||||
|
"terms": {"field": "iter", "size": 1, "order": {"_key": order}},
|
||||||
|
"aggs": {
|
||||||
|
"metrics": {
|
||||||
|
"terms": {
|
||||||
|
"field": "metric",
|
||||||
|
"size": 1,
|
||||||
|
"order": {"_key": order},
|
||||||
|
},
|
||||||
|
"aggs": {
|
||||||
|
"events": {
|
||||||
|
"top_hits": {
|
||||||
|
"sort": {"variant": {"order": "asc"}},
|
||||||
|
"size": 100,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
es_res = search_company_events(
|
||||||
|
self.es,
|
||||||
|
company_id=company_id,
|
||||||
|
event_type=self.event_type,
|
||||||
|
body=es_req,
|
||||||
|
)
|
||||||
|
|
||||||
|
aggs_result = es_res.get("aggregations")
|
||||||
|
if not aggs_result:
|
||||||
|
return []
|
||||||
|
|
||||||
|
for level in ("iters", "metrics"):
|
||||||
|
level_data = aggs_result[level]["buckets"]
|
||||||
|
if not level_data:
|
||||||
|
return []
|
||||||
|
aggs_result = level_data[0]
|
||||||
|
|
||||||
|
return [
|
||||||
|
hit["_source"]
|
||||||
|
for hit in nested_get(aggs_result, ("events", "hits", "hits"))
|
||||||
|
]
|
@ -1,440 +0,0 @@
|
|||||||
import abc
|
|
||||||
import operator
|
|
||||||
from operator import attrgetter
|
|
||||||
from typing import Sequence, Tuple, Optional, Callable, Mapping
|
|
||||||
|
|
||||||
import attr
|
|
||||||
from boltons.iterutils import first, bucketize
|
|
||||||
from elasticsearch import Elasticsearch
|
|
||||||
from jsonmodels.fields import StringField, ListField, IntField, BoolField
|
|
||||||
from jsonmodels.models import Base
|
|
||||||
from redis import StrictRedis
|
|
||||||
|
|
||||||
from apiserver.apierrors import errors
|
|
||||||
from apiserver.apimodels import JsonSerializableMixin
|
|
||||||
from apiserver.bll.event.event_common import (
|
|
||||||
EventSettings,
|
|
||||||
EventType,
|
|
||||||
check_empty_data,
|
|
||||||
search_company_events,
|
|
||||||
get_max_metric_and_variant_counts,
|
|
||||||
)
|
|
||||||
from apiserver.bll.redis_cache_manager import RedisCacheManager
|
|
||||||
from apiserver.database.errors import translate_errors_context
|
|
||||||
from apiserver.utilities.dicts import nested_get
|
|
||||||
|
|
||||||
|
|
||||||
class VariantState(Base):
|
|
||||||
name: str = StringField(required=True)
|
|
||||||
metric: str = StringField(default=None)
|
|
||||||
min_iteration: int = IntField()
|
|
||||||
max_iteration: int = IntField()
|
|
||||||
|
|
||||||
|
|
||||||
class HistorySampleState(Base, JsonSerializableMixin):
|
|
||||||
id: str = StringField(required=True)
|
|
||||||
iteration: int = IntField()
|
|
||||||
variant: str = StringField()
|
|
||||||
task: str = StringField()
|
|
||||||
metric: str = StringField()
|
|
||||||
reached_first: bool = BoolField()
|
|
||||||
reached_last: bool = BoolField()
|
|
||||||
variant_states: Sequence[VariantState] = ListField([VariantState])
|
|
||||||
warning: str = StringField()
|
|
||||||
navigate_current_metric = BoolField(default=True)
|
|
||||||
|
|
||||||
|
|
||||||
@attr.s(auto_attribs=True)
|
|
||||||
class HistorySampleResult(object):
|
|
||||||
scroll_id: str = None
|
|
||||||
event: dict = None
|
|
||||||
min_iteration: int = None
|
|
||||||
max_iteration: int = None
|
|
||||||
|
|
||||||
|
|
||||||
class HistorySampleIterator(abc.ABC):
|
|
||||||
def __init__(self, redis: StrictRedis, es: Elasticsearch, event_type: EventType):
|
|
||||||
self.es = es
|
|
||||||
self.event_type = event_type
|
|
||||||
self.cache_manager = RedisCacheManager(
|
|
||||||
state_class=HistorySampleState,
|
|
||||||
redis=redis,
|
|
||||||
expiration_interval=EventSettings.state_expiration_sec,
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_next_sample(
|
|
||||||
self,
|
|
||||||
company_id: str,
|
|
||||||
task: str,
|
|
||||||
state_id: str,
|
|
||||||
navigate_earlier: bool,
|
|
||||||
next_iteration: bool,
|
|
||||||
) -> HistorySampleResult:
|
|
||||||
"""
|
|
||||||
Get the sample for next/prev variant on the current iteration
|
|
||||||
If does not exist then try getting sample for the first/last variant from next/prev iteration
|
|
||||||
"""
|
|
||||||
res = HistorySampleResult(scroll_id=state_id)
|
|
||||||
state = self.cache_manager.get_state(state_id)
|
|
||||||
if not state or state.task != task:
|
|
||||||
raise errors.bad_request.InvalidScrollId(scroll_id=state_id)
|
|
||||||
|
|
||||||
if check_empty_data(self.es, company_id=company_id, event_type=self.event_type):
|
|
||||||
return res
|
|
||||||
|
|
||||||
if next_iteration:
|
|
||||||
event = self._get_next_for_another_iteration(
|
|
||||||
company_id=company_id, navigate_earlier=navigate_earlier, state=state
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# noinspection PyArgumentList
|
|
||||||
event = first(
|
|
||||||
f(company_id=company_id, navigate_earlier=navigate_earlier, state=state)
|
|
||||||
for f in (
|
|
||||||
self._get_next_for_current_iteration,
|
|
||||||
self._get_next_for_another_iteration,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
if not event:
|
|
||||||
return res
|
|
||||||
|
|
||||||
self._fill_res_and_update_state(event=event, res=res, state=state)
|
|
||||||
self.cache_manager.set_state(state=state)
|
|
||||||
return res
|
|
||||||
|
|
||||||
def _fill_res_and_update_state(
|
|
||||||
self, event: dict, res: HistorySampleResult, state: HistorySampleState
|
|
||||||
):
|
|
||||||
self._process_event(event)
|
|
||||||
state.variant = event["variant"]
|
|
||||||
state.metric = event["metric"]
|
|
||||||
state.iteration = event["iter"]
|
|
||||||
res.event = event
|
|
||||||
var_state = first(
|
|
||||||
vs
|
|
||||||
for vs in state.variant_states
|
|
||||||
if vs.name == state.variant and vs.metric == state.metric
|
|
||||||
)
|
|
||||||
if var_state:
|
|
||||||
res.min_iteration = var_state.min_iteration
|
|
||||||
res.max_iteration = var_state.max_iteration
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def _get_extra_conditions(self) -> Sequence[dict]:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def _process_event(self, event: dict) -> dict:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def _get_variants_conditions(self, variants: Sequence[VariantState]) -> dict:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def _get_metric_variants_condition(self, variants: Sequence[VariantState]) -> dict:
|
|
||||||
metrics = bucketize(variants, key=attrgetter("metric"))
|
|
||||||
metrics_conditions = [
|
|
||||||
{
|
|
||||||
"bool": {
|
|
||||||
"must": [
|
|
||||||
{"term": {"metric": metric}},
|
|
||||||
self._get_variants_conditions(vs),
|
|
||||||
]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for metric, vs in metrics.items()
|
|
||||||
]
|
|
||||||
return {"bool": {"should": metrics_conditions}}
|
|
||||||
|
|
||||||
def _get_next_for_current_iteration(
|
|
||||||
self, company_id: str, navigate_earlier: bool, state: HistorySampleState
|
|
||||||
) -> Optional[dict]:
|
|
||||||
"""
|
|
||||||
Get the sample for next (if navigate_earlier is False) or previous variant sorted by name for the same iteration
|
|
||||||
Only variants for which the iteration falls into their valid range are considered
|
|
||||||
Return None if no such variant or sample is found
|
|
||||||
"""
|
|
||||||
if state.navigate_current_metric:
|
|
||||||
variants = [
|
|
||||||
var_state
|
|
||||||
for var_state in state.variant_states
|
|
||||||
if var_state.metric == state.metric
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
variants = state.variant_states
|
|
||||||
|
|
||||||
cmp = operator.lt if navigate_earlier else operator.gt
|
|
||||||
variants = [
|
|
||||||
var_state
|
|
||||||
for var_state in variants
|
|
||||||
if cmp((var_state.metric, var_state.name), (state.metric, state.variant))
|
|
||||||
and var_state.min_iteration <= state.iteration
|
|
||||||
]
|
|
||||||
if not variants:
|
|
||||||
return
|
|
||||||
|
|
||||||
must_conditions = [
|
|
||||||
{"term": {"task": state.task}},
|
|
||||||
{"term": {"iter": state.iteration}},
|
|
||||||
self._get_metric_variants_condition(variants),
|
|
||||||
*self._get_extra_conditions(),
|
|
||||||
]
|
|
||||||
order = "desc" if navigate_earlier else "asc"
|
|
||||||
es_req = {
|
|
||||||
"size": 1,
|
|
||||||
"sort": [{"metric": order}, {"variant": order}],
|
|
||||||
"query": {"bool": {"must": must_conditions}},
|
|
||||||
}
|
|
||||||
|
|
||||||
with translate_errors_context():
|
|
||||||
es_res = search_company_events(
|
|
||||||
self.es, company_id=company_id, event_type=self.event_type, body=es_req,
|
|
||||||
)
|
|
||||||
|
|
||||||
hits = nested_get(es_res, ("hits", "hits"))
|
|
||||||
if not hits:
|
|
||||||
return
|
|
||||||
|
|
||||||
return hits[0]["_source"]
|
|
||||||
|
|
||||||
def _get_next_for_another_iteration(
|
|
||||||
self, company_id: str, navigate_earlier: bool, state: HistorySampleState
|
|
||||||
) -> Optional[dict]:
|
|
||||||
"""
|
|
||||||
Get the sample for the first variant for the next iteration (if navigate_earlier is set to False)
|
|
||||||
or from the last variant for the previous iteration (otherwise)
|
|
||||||
The variants for which the sample falls in invalid range are discarded
|
|
||||||
If no suitable sample is found then None is returned
|
|
||||||
"""
|
|
||||||
if state.navigate_current_metric:
|
|
||||||
variants = [
|
|
||||||
var_state
|
|
||||||
for var_state in state.variant_states
|
|
||||||
if var_state.metric == state.metric
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
variants = state.variant_states
|
|
||||||
|
|
||||||
if navigate_earlier:
|
|
||||||
range_operator = "lt"
|
|
||||||
order = "desc"
|
|
||||||
variants = [
|
|
||||||
var_state
|
|
||||||
for var_state in variants
|
|
||||||
if var_state.min_iteration < state.iteration
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
range_operator = "gt"
|
|
||||||
order = "asc"
|
|
||||||
variants = variants
|
|
||||||
|
|
||||||
if not variants:
|
|
||||||
return
|
|
||||||
|
|
||||||
must_conditions = [
|
|
||||||
{"term": {"task": state.task}},
|
|
||||||
self._get_metric_variants_condition(variants),
|
|
||||||
{"range": {"iter": {range_operator: state.iteration}}},
|
|
||||||
*self._get_extra_conditions(),
|
|
||||||
]
|
|
||||||
es_req = {
|
|
||||||
"size": 1,
|
|
||||||
"sort": [{"iter": order}, {"metric": order}, {"variant": order}],
|
|
||||||
"query": {"bool": {"must": must_conditions}},
|
|
||||||
}
|
|
||||||
with translate_errors_context():
|
|
||||||
es_res = search_company_events(
|
|
||||||
self.es, company_id=company_id, event_type=self.event_type, body=es_req,
|
|
||||||
)
|
|
||||||
|
|
||||||
hits = nested_get(es_res, ("hits", "hits"))
|
|
||||||
if not hits:
|
|
||||||
return
|
|
||||||
|
|
||||||
return hits[0]["_source"]
|
|
||||||
|
|
||||||
def get_sample_for_variant(
|
|
||||||
self,
|
|
||||||
company_id: str,
|
|
||||||
task: str,
|
|
||||||
metric: str,
|
|
||||||
variant: str,
|
|
||||||
iteration: Optional[int] = None,
|
|
||||||
refresh: bool = False,
|
|
||||||
state_id: str = None,
|
|
||||||
navigate_current_metric: bool = True,
|
|
||||||
) -> HistorySampleResult:
|
|
||||||
"""
|
|
||||||
Get the sample for the requested iteration or the latest before it
|
|
||||||
If the iteration is not passed then get the latest event
|
|
||||||
"""
|
|
||||||
res = HistorySampleResult()
|
|
||||||
if check_empty_data(self.es, company_id=company_id, event_type=self.event_type):
|
|
||||||
return res
|
|
||||||
|
|
||||||
def init_state(state_: HistorySampleState):
|
|
||||||
state_.task = task
|
|
||||||
state_.metric = metric
|
|
||||||
state_.navigate_current_metric = navigate_current_metric
|
|
||||||
self._reset_variant_states(company_id=company_id, state=state_)
|
|
||||||
|
|
||||||
def validate_state(state_: HistorySampleState):
|
|
||||||
if (
|
|
||||||
state_.task != task
|
|
||||||
or state_.navigate_current_metric != navigate_current_metric
|
|
||||||
or (state_.navigate_current_metric and state_.metric != metric)
|
|
||||||
):
|
|
||||||
raise errors.bad_request.InvalidScrollId(
|
|
||||||
"Task and metric stored in the state do not match the passed ones",
|
|
||||||
scroll_id=state_.id,
|
|
||||||
)
|
|
||||||
# fix old variant states:
|
|
||||||
for vs in state_.variant_states:
|
|
||||||
if vs.metric is None:
|
|
||||||
vs.metric = metric
|
|
||||||
if refresh:
|
|
||||||
self._reset_variant_states(company_id=company_id, state=state_)
|
|
||||||
|
|
||||||
state: HistorySampleState
|
|
||||||
with self.cache_manager.get_or_create_state(
|
|
||||||
state_id=state_id, init_state=init_state, validate_state=validate_state,
|
|
||||||
) as state:
|
|
||||||
res.scroll_id = state.id
|
|
||||||
|
|
||||||
var_state = first(
|
|
||||||
vs
|
|
||||||
for vs in state.variant_states
|
|
||||||
if vs.name == variant and vs.metric == metric
|
|
||||||
)
|
|
||||||
if not var_state:
|
|
||||||
return res
|
|
||||||
|
|
||||||
res.min_iteration = var_state.min_iteration
|
|
||||||
res.max_iteration = var_state.max_iteration
|
|
||||||
|
|
||||||
must_conditions = [
|
|
||||||
{"term": {"task": task}},
|
|
||||||
{"term": {"metric": metric}},
|
|
||||||
{"term": {"variant": variant}},
|
|
||||||
*self._get_extra_conditions(),
|
|
||||||
]
|
|
||||||
if iteration is not None:
|
|
||||||
must_conditions.append(
|
|
||||||
{
|
|
||||||
"range": {
|
|
||||||
"iter": {"lte": iteration, "gte": var_state.min_iteration}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
must_conditions.append(
|
|
||||||
{"range": {"iter": {"gte": var_state.min_iteration}}}
|
|
||||||
)
|
|
||||||
|
|
||||||
es_req = {
|
|
||||||
"size": 1,
|
|
||||||
"sort": {"iter": "desc"},
|
|
||||||
"query": {"bool": {"must": must_conditions}},
|
|
||||||
}
|
|
||||||
|
|
||||||
with translate_errors_context():
|
|
||||||
es_res = search_company_events(
|
|
||||||
self.es,
|
|
||||||
company_id=company_id,
|
|
||||||
event_type=self.event_type,
|
|
||||||
body=es_req,
|
|
||||||
)
|
|
||||||
|
|
||||||
hits = nested_get(es_res, ("hits", "hits"))
|
|
||||||
if not hits:
|
|
||||||
return res
|
|
||||||
|
|
||||||
self._fill_res_and_update_state(
|
|
||||||
event=hits[0]["_source"], res=res, state=state
|
|
||||||
)
|
|
||||||
return res
|
|
||||||
|
|
||||||
def _reset_variant_states(self, company_id: str, state: HistorySampleState):
|
|
||||||
metrics = self._get_metric_variant_iterations(
|
|
||||||
company_id=company_id,
|
|
||||||
task=state.task,
|
|
||||||
metric=state.metric if state.navigate_current_metric else None,
|
|
||||||
)
|
|
||||||
state.variant_states = [
|
|
||||||
VariantState(
|
|
||||||
metric=metric,
|
|
||||||
name=var_name,
|
|
||||||
min_iteration=min_iter,
|
|
||||||
max_iteration=max_iter,
|
|
||||||
)
|
|
||||||
for metric, variants in metrics.items()
|
|
||||||
for var_name, min_iter, max_iter in variants
|
|
||||||
]
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def _get_min_max_aggs(self) -> Tuple[dict, Callable[[dict], Tuple[int, int]]]:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def _get_metric_variant_iterations(
|
|
||||||
self, company_id: str, task: str, metric: str,
|
|
||||||
) -> Mapping[str, Tuple[str, str, int, int]]:
|
|
||||||
"""
|
|
||||||
Return valid min and max iterations that the task reported events of the required type
|
|
||||||
"""
|
|
||||||
must = [
|
|
||||||
{"term": {"task": task}},
|
|
||||||
*self._get_extra_conditions(),
|
|
||||||
]
|
|
||||||
if metric is not None:
|
|
||||||
must.append({"term": {"metric": metric}})
|
|
||||||
query = {"bool": {"must": must}}
|
|
||||||
|
|
||||||
search_args = dict(
|
|
||||||
es=self.es, company_id=company_id, event_type=self.event_type
|
|
||||||
)
|
|
||||||
max_metrics, max_variants = get_max_metric_and_variant_counts(
|
|
||||||
query=query, **search_args
|
|
||||||
)
|
|
||||||
max_variants = int(max_variants // 2)
|
|
||||||
min_max_aggs, get_min_max_data = self._get_min_max_aggs()
|
|
||||||
es_req: dict = {
|
|
||||||
"size": 0,
|
|
||||||
"query": query,
|
|
||||||
"aggs": {
|
|
||||||
"metrics": {
|
|
||||||
"terms": {
|
|
||||||
"field": "metric",
|
|
||||||
"size": max_metrics,
|
|
||||||
"order": {"_key": "asc"},
|
|
||||||
},
|
|
||||||
"aggs": {
|
|
||||||
"variants": {
|
|
||||||
"terms": {
|
|
||||||
"field": "variant",
|
|
||||||
"size": max_variants,
|
|
||||||
"order": {"_key": "asc"},
|
|
||||||
},
|
|
||||||
"aggs": min_max_aggs,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
with translate_errors_context():
|
|
||||||
es_res = search_company_events(body=es_req, **search_args,)
|
|
||||||
|
|
||||||
def get_variant_data(variant_bucket: dict) -> Tuple[str, int, int]:
|
|
||||||
variant = variant_bucket["key"]
|
|
||||||
min_iter, max_iter = get_min_max_data(variant_bucket)
|
|
||||||
return variant, min_iter, max_iter
|
|
||||||
|
|
||||||
return {
|
|
||||||
metric_bucket["key"]: [
|
|
||||||
get_variant_data(variant_bucket)
|
|
||||||
for variant_bucket in nested_get(metric_bucket, ("variants", "buckets"))
|
|
||||||
]
|
|
||||||
for metric_bucket in nested_get(
|
|
||||||
es_res, ("aggregations", "metrics", "buckets")
|
|
||||||
)
|
|
||||||
}
|
|
@ -372,17 +372,18 @@ _definitions {
|
|||||||
type: string
|
type: string
|
||||||
description: "Scroll ID to pass to the next calls to get_plot_sample or next_plot_sample"
|
description: "Scroll ID to pass to the next calls to get_plot_sample or next_plot_sample"
|
||||||
}
|
}
|
||||||
event {
|
events {
|
||||||
type: object
|
description: "Plot events"
|
||||||
description: "Plot event"
|
type: array
|
||||||
|
items { type: object}
|
||||||
}
|
}
|
||||||
min_iteration {
|
min_iteration {
|
||||||
type: integer
|
type: integer
|
||||||
description: "minimal valid iteration for the variant"
|
description: "minimal valid iteration for the metric"
|
||||||
}
|
}
|
||||||
max_iteration {
|
max_iteration {
|
||||||
type: integer
|
type: integer
|
||||||
description: "maximal valid iteration for the variant"
|
description: "maximal valid iteration for the metric"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -738,10 +739,10 @@ next_debug_image_sample {
|
|||||||
}
|
}
|
||||||
get_plot_sample {
|
get_plot_sample {
|
||||||
"2.20": {
|
"2.20": {
|
||||||
description: "Return the plot per metric and variant for the provided iteration"
|
description: "Return plots for the provided iteration"
|
||||||
request {
|
request {
|
||||||
type: object
|
type: object
|
||||||
required: [task, metric, variant]
|
required: [task, metric]
|
||||||
properties {
|
properties {
|
||||||
task {
|
task {
|
||||||
description: "Task ID"
|
description: "Task ID"
|
||||||
@ -751,10 +752,6 @@ get_plot_sample {
|
|||||||
description: "Metric name"
|
description: "Metric name"
|
||||||
type: string
|
type: string
|
||||||
}
|
}
|
||||||
variant {
|
|
||||||
description: "Metric variant"
|
|
||||||
type: string
|
|
||||||
}
|
|
||||||
iteration {
|
iteration {
|
||||||
description: "The iteration to bring plot from. If not specified then the latest reported plot is retrieved"
|
description: "The iteration to bring plot from. If not specified then the latest reported plot is retrieved"
|
||||||
type: integer
|
type: integer
|
||||||
@ -786,7 +783,7 @@ get_plot_sample {
|
|||||||
}
|
}
|
||||||
next_plot_sample {
|
next_plot_sample {
|
||||||
"2.20": {
|
"2.20": {
|
||||||
description: "Get the plot for the next variant for the same iteration or for the next iteration"
|
description: "Get the plot for the next metric for the same iteration or for the next iteration"
|
||||||
request {
|
request {
|
||||||
type: object
|
type: object
|
||||||
required: [task, scroll_id]
|
required: [task, scroll_id]
|
||||||
@ -801,8 +798,8 @@ next_plot_sample {
|
|||||||
}
|
}
|
||||||
navigate_earlier {
|
navigate_earlier {
|
||||||
type: boolean
|
type: boolean
|
||||||
description: """If set then get the either previous variant event from the current iteration or (if does not exist) the last variant event from the previous iteration.
|
description: """If set then get the either previous metric events from the current iteration or (if does not exist) the last metric events from the previous iteration.
|
||||||
Otherwise next variant event from the current iteration or first variant event from the next iteration"""
|
Otherwise next metric events from the current iteration or first metric events from the next iteration"""
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -19,7 +19,6 @@ from apiserver.apimodels.events import (
|
|||||||
TaskMetricsRequest,
|
TaskMetricsRequest,
|
||||||
LogEventsRequest,
|
LogEventsRequest,
|
||||||
LogOrderEnum,
|
LogOrderEnum,
|
||||||
GetHistorySampleRequest,
|
|
||||||
NextHistorySampleRequest,
|
NextHistorySampleRequest,
|
||||||
MetricVariants as ApiMetrics,
|
MetricVariants as ApiMetrics,
|
||||||
TaskPlotsRequest,
|
TaskPlotsRequest,
|
||||||
@ -28,6 +27,7 @@ from apiserver.apimodels.events import (
|
|||||||
ClearScrollRequest,
|
ClearScrollRequest,
|
||||||
ClearTaskLogRequest,
|
ClearTaskLogRequest,
|
||||||
SingleValueMetricsRequest,
|
SingleValueMetricsRequest,
|
||||||
|
GetVariantSampleRequest, GetMetricSamplesRequest,
|
||||||
)
|
)
|
||||||
from apiserver.bll.event import EventBLL
|
from apiserver.bll.event import EventBLL
|
||||||
from apiserver.bll.event.event_common import EventType, MetricVariants
|
from apiserver.bll.event.event_common import EventType, MetricVariants
|
||||||
@ -827,9 +827,9 @@ def get_debug_images(call, company_id, request: MetricEventsRequest):
|
|||||||
@endpoint(
|
@endpoint(
|
||||||
"events.get_debug_image_sample",
|
"events.get_debug_image_sample",
|
||||||
min_version="2.12",
|
min_version="2.12",
|
||||||
request_data_model=GetHistorySampleRequest,
|
request_data_model=GetVariantSampleRequest,
|
||||||
)
|
)
|
||||||
def get_debug_image_sample(call, company_id, request: GetHistorySampleRequest):
|
def get_debug_image_sample(call, company_id, request: GetVariantSampleRequest):
|
||||||
task_or_model = _assert_task_or_model_exists(
|
task_or_model = _assert_task_or_model_exists(
|
||||||
company_id, request.task, model_events=request.model_events,
|
company_id, request.task, model_events=request.model_events,
|
||||||
)[0]
|
)[0]
|
||||||
@ -866,17 +866,16 @@ def next_debug_image_sample(call, company_id, request: NextHistorySampleRequest)
|
|||||||
|
|
||||||
|
|
||||||
@endpoint(
|
@endpoint(
|
||||||
"events.get_plot_sample", request_data_model=GetHistorySampleRequest,
|
"events.get_plot_sample", request_data_model=GetMetricSamplesRequest,
|
||||||
)
|
)
|
||||||
def get_plot_sample(call, company_id, request: GetHistorySampleRequest):
|
def get_plot_sample(call, company_id, request: GetMetricSamplesRequest):
|
||||||
task_or_model = _assert_task_or_model_exists(
|
task_or_model = _assert_task_or_model_exists(
|
||||||
company_id, request.task, model_events=request.model_events,
|
company_id, request.task, model_events=request.model_events,
|
||||||
)[0]
|
)[0]
|
||||||
res = event_bll.plot_sample_history.get_sample_for_variant(
|
res = event_bll.plot_sample_history.get_samples_for_metric(
|
||||||
company_id=task_or_model.get_index_company(),
|
company_id=task_or_model.get_index_company(),
|
||||||
task=request.task,
|
task=request.task,
|
||||||
metric=request.metric,
|
metric=request.metric,
|
||||||
variant=request.variant,
|
|
||||||
iteration=request.iteration,
|
iteration=request.iteration,
|
||||||
refresh=request.refresh,
|
refresh=request.refresh,
|
||||||
state_id=request.scroll_id,
|
state_id=request.scroll_id,
|
||||||
|
@ -26,57 +26,55 @@ class TestTaskPlots(TestService):
|
|||||||
def test_get_plot_sample(self):
|
def test_get_plot_sample(self):
|
||||||
task = self._temp_task()
|
task = self._temp_task()
|
||||||
metric = "Metric1"
|
metric = "Metric1"
|
||||||
variant = "Variant1"
|
variants = ["Variant1", "Variant2"]
|
||||||
|
|
||||||
# test empty
|
# test empty
|
||||||
res = self.api.events.get_plot_sample(
|
res = self.api.events.get_plot_sample(task=task, metric=metric)
|
||||||
task=task, metric=metric, variant=variant
|
|
||||||
)
|
|
||||||
self.assertEqual(res.min_iteration, None)
|
self.assertEqual(res.min_iteration, None)
|
||||||
self.assertEqual(res.max_iteration, None)
|
self.assertEqual(res.max_iteration, None)
|
||||||
self.assertEqual(res.event, None)
|
self.assertEqual(res.events, [])
|
||||||
|
|
||||||
# test existing events
|
# test existing events
|
||||||
iterations = 10
|
iterations = 5
|
||||||
events = [
|
events = [
|
||||||
self._create_task_event(
|
self._create_task_event(
|
||||||
task=task,
|
task=task,
|
||||||
iteration=n,
|
iteration=n // len(variants),
|
||||||
metric=metric,
|
metric=metric,
|
||||||
variant=variant,
|
variant=variants[n % len(variants)],
|
||||||
plot_str=f"Test plot str {n}",
|
plot_str=f"Test plot str {n}",
|
||||||
)
|
)
|
||||||
for n in range(iterations)
|
for n in range(iterations * len(variants))
|
||||||
]
|
]
|
||||||
self.send_batch(events)
|
self.send_batch(events)
|
||||||
|
|
||||||
# if iteration is not specified then return the event from the last one
|
# if iteration is not specified then return the event from the last one
|
||||||
res = self.api.events.get_plot_sample(
|
res = self.api.events.get_plot_sample(task=task, metric=metric)
|
||||||
task=task, metric=metric, variant=variant
|
self._assertEqualEvents(res.events, events[-len(variants) :])
|
||||||
)
|
|
||||||
self._assertEqualEvent(res.event, events[-1])
|
|
||||||
self.assertEqual(res.max_iteration, iterations - 1)
|
self.assertEqual(res.max_iteration, iterations - 1)
|
||||||
self.assertEqual(res.min_iteration, 0)
|
self.assertEqual(res.min_iteration, 0)
|
||||||
self.assertTrue(res.scroll_id)
|
self.assertTrue(res.scroll_id)
|
||||||
|
|
||||||
# else from the specific iteration
|
# else from the specific iteration
|
||||||
iteration = 8
|
iteration = 3
|
||||||
res = self.api.events.get_plot_sample(
|
res = self.api.events.get_plot_sample(
|
||||||
task=task,
|
task=task, metric=metric, iteration=iteration, scroll_id=res.scroll_id,
|
||||||
metric=metric,
|
)
|
||||||
variant=variant,
|
self._assertEqualEvents(
|
||||||
iteration=iteration,
|
res.events,
|
||||||
scroll_id=res.scroll_id,
|
events[iteration * len(variants) : (iteration + 1) * len(variants)],
|
||||||
)
|
)
|
||||||
self._assertEqualEvent(res.event, events[iteration])
|
|
||||||
|
|
||||||
def test_next_plot_sample(self):
|
def test_next_plot_sample(self):
|
||||||
task = self._temp_task()
|
task = self._temp_task()
|
||||||
metric1 = "Metric1"
|
metric1 = "Metric1"
|
||||||
variant1 = "Variant1"
|
|
||||||
metric2 = "Metric2"
|
metric2 = "Metric2"
|
||||||
variant2 = "Variant2"
|
metrics = [
|
||||||
metrics = [(metric1, variant1), (metric2, variant2)]
|
(metric1, "variant1"),
|
||||||
|
(metric1, "variant2"),
|
||||||
|
(metric2, "variant3"),
|
||||||
|
(metric2, "variant4"),
|
||||||
|
]
|
||||||
# test existing events
|
# test existing events
|
||||||
events = [
|
events = [
|
||||||
self._create_task_event(
|
self._create_task_event(
|
||||||
@ -93,73 +91,73 @@ class TestTaskPlots(TestService):
|
|||||||
|
|
||||||
# single metric navigation
|
# single metric navigation
|
||||||
# init scroll
|
# init scroll
|
||||||
res = self.api.events.get_plot_sample(
|
res = self.api.events.get_plot_sample(task=task, metric=metric1)
|
||||||
task=task, metric=metric1, variant=variant1
|
self._assertEqualEvents(res.events, events[-4:-2])
|
||||||
)
|
|
||||||
self._assertEqualEvent(res.event, events[-2])
|
|
||||||
|
|
||||||
# navigate forwards
|
# navigate forwards
|
||||||
res = self.api.events.next_plot_sample(
|
res = self.api.events.next_plot_sample(
|
||||||
task=task, scroll_id=res.scroll_id, navigate_earlier=False
|
task=task, scroll_id=res.scroll_id, navigate_earlier=False
|
||||||
)
|
)
|
||||||
self.assertEqual(res.event, None)
|
self.assertEqual(res.events, [])
|
||||||
|
|
||||||
# navigate backwards
|
# navigate backwards
|
||||||
res = self.api.events.next_plot_sample(
|
res = self.api.events.next_plot_sample(task=task, scroll_id=res.scroll_id)
|
||||||
task=task, scroll_id=res.scroll_id
|
self._assertEqualEvents(res.events, events[-8:-6])
|
||||||
)
|
res = self.api.events.next_plot_sample(task=task, scroll_id=res.scroll_id)
|
||||||
self._assertEqualEvent(res.event, events[-4])
|
self._assertEqualEvents(res.events, [])
|
||||||
res = self.api.events.next_plot_sample(
|
|
||||||
task=task, scroll_id=res.scroll_id
|
|
||||||
)
|
|
||||||
self._assertEqualEvent(res.event, None)
|
|
||||||
|
|
||||||
# all metrics navigation
|
# all metrics navigation
|
||||||
# init scroll
|
# init scroll
|
||||||
res = self.api.events.get_plot_sample(
|
res = self.api.events.get_plot_sample(
|
||||||
task=task, metric=metric1, variant=variant1, navigate_current_metric=False
|
task=task, metric=metric1, navigate_current_metric=False
|
||||||
)
|
)
|
||||||
self._assertEqualEvent(res.event, events[-2])
|
self._assertEqualEvents(res.events, events[-4:-2])
|
||||||
|
|
||||||
# navigate forwards
|
# navigate forwards
|
||||||
res = self.api.events.next_plot_sample(
|
res = self.api.events.next_plot_sample(
|
||||||
task=task, scroll_id=res.scroll_id, navigate_earlier=False
|
task=task, scroll_id=res.scroll_id, navigate_earlier=False
|
||||||
)
|
)
|
||||||
self._assertEqualEvent(res.event, events[-1])
|
self._assertEqualEvents(res.events, events[-2:])
|
||||||
|
|
||||||
# navigate backwards
|
# navigate backwards
|
||||||
res = self.api.events.next_plot_sample(
|
res = self.api.events.next_plot_sample(task=task, scroll_id=res.scroll_id)
|
||||||
task=task, scroll_id=res.scroll_id
|
self._assertEqualEvents(res.events, events[-4:-2])
|
||||||
)
|
res = self.api.events.next_plot_sample(task=task, scroll_id=res.scroll_id)
|
||||||
self._assertEqualEvent(res.event, events[-2])
|
self._assertEqualEvents(res.events, events[-6:-4])
|
||||||
res = self.api.events.next_plot_sample(
|
|
||||||
task=task, scroll_id=res.scroll_id
|
|
||||||
)
|
|
||||||
self._assertEqualEvent(res.event, events[-3])
|
|
||||||
|
|
||||||
# next_iteration
|
# next_iteration
|
||||||
res = self.api.events.next_plot_sample(
|
res = self.api.events.next_plot_sample(
|
||||||
task=task, scroll_id=res.scroll_id, next_iteration=True
|
task=task, scroll_id=res.scroll_id, next_iteration=True
|
||||||
)
|
)
|
||||||
self._assertEqualEvent(res.event, None)
|
self._assertEqualEvents(res.events, [])
|
||||||
res = self.api.events.next_plot_sample(
|
res = self.api.events.next_plot_sample(
|
||||||
task=task, scroll_id=res.scroll_id, next_iteration=True, navigate_earlier=False
|
task=task,
|
||||||
|
scroll_id=res.scroll_id,
|
||||||
|
next_iteration=True,
|
||||||
|
navigate_earlier=False,
|
||||||
)
|
)
|
||||||
self._assertEqualEvent(res.event, events[-2])
|
self._assertEqualEvents(res.events, events[-4:-2])
|
||||||
self.assertEqual(res.event.iter, 1)
|
self.assertTrue(all(ev.iter == 1 for ev in res.events))
|
||||||
res = self.api.events.next_plot_sample(
|
res = self.api.events.next_plot_sample(
|
||||||
task=task, scroll_id=res.scroll_id, next_iteration=True, navigate_earlier=False
|
task=task,
|
||||||
|
scroll_id=res.scroll_id,
|
||||||
|
next_iteration=True,
|
||||||
|
navigate_earlier=False,
|
||||||
)
|
)
|
||||||
self._assertEqualEvent(res.event, None)
|
self._assertEqualEvents(res.events, [])
|
||||||
|
|
||||||
def _assertEqualEvent(self, ev1: dict, ev2: Optional[dict]):
|
def _assertEqualEvents(
|
||||||
if ev2 is None:
|
self, ev_source: Sequence[dict], ev_target: Sequence[Optional[dict]]
|
||||||
self.assertIsNone(ev1)
|
):
|
||||||
return
|
self.assertEqual(len(ev_source), len(ev_target))
|
||||||
self.assertIsNotNone(ev1)
|
|
||||||
|
def compare_event(ev1, ev2):
|
||||||
for field in ("iter", "timestamp", "metric", "variant", "plot_str", "task"):
|
for field in ("iter", "timestamp", "metric", "variant", "plot_str", "task"):
|
||||||
self.assertEqual(ev1[field], ev2[field])
|
self.assertEqual(ev1[field], ev2[field])
|
||||||
|
|
||||||
|
for e1, e2 in zip(ev_source, ev_target):
|
||||||
|
compare_event(e1, e2)
|
||||||
|
|
||||||
def test_task_plots(self):
|
def test_task_plots(self):
|
||||||
task = self._temp_task()
|
task = self._temp_task()
|
||||||
|
|
||||||
@ -238,12 +236,15 @@ class TestTaskPlots(TestService):
|
|||||||
self.assertTrue(all(m.iterations == [] for m in res.metrics))
|
self.assertTrue(all(m.iterations == [] for m in res.metrics))
|
||||||
return res.scroll_id
|
return res.scroll_id
|
||||||
|
|
||||||
expected_variants = set((m, var) for m, vars_ in expected_metrics.items() for var in vars_)
|
expected_variants = set(
|
||||||
|
(m, var) for m, vars_ in expected_metrics.items() for var in vars_
|
||||||
|
)
|
||||||
for metric_data in res.metrics:
|
for metric_data in res.metrics:
|
||||||
self.assertEqual(len(metric_data.iterations), iterations)
|
self.assertEqual(len(metric_data.iterations), iterations)
|
||||||
for it_data in metric_data.iterations:
|
for it_data in metric_data.iterations:
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
set((e.metric, e.variant) for e in it_data.events), expected_variants
|
set((e.metric, e.variant) for e in it_data.events),
|
||||||
|
expected_variants,
|
||||||
)
|
)
|
||||||
|
|
||||||
return res.scroll_id
|
return res.scroll_id
|
||||||
@ -281,7 +282,7 @@ class TestTaskPlots(TestService):
|
|||||||
task=task,
|
task=task,
|
||||||
metric=metric,
|
metric=metric,
|
||||||
iterations=iterations,
|
iterations=iterations,
|
||||||
variants=len(variants)
|
variants=len(variants),
|
||||||
)
|
)
|
||||||
|
|
||||||
# test forward navigation
|
# test forward navigation
|
||||||
|
@ -40,14 +40,20 @@ class TestWorkersService(TestService):
|
|||||||
def test_system_tags(self):
|
def test_system_tags(self):
|
||||||
test_worker = f"test_{uuid4().hex}"
|
test_worker = f"test_{uuid4().hex}"
|
||||||
tag = uuid4().hex
|
tag = uuid4().hex
|
||||||
|
system_tag = uuid4().hex
|
||||||
|
self.api.workers.register(
|
||||||
|
worker=test_worker, tags=[tag], system_tags=[system_tag], timeout=5
|
||||||
|
)
|
||||||
|
|
||||||
# system_tags support
|
# system_tags support
|
||||||
worker = self.api.workers.get_all(tags=[tag], system_tags=["Application"]).workers[0]
|
worker = self.api.workers.get_all(tags=[tag], system_tags=[system_tag]).workers[
|
||||||
|
0
|
||||||
|
]
|
||||||
self.assertEqual(worker.id, test_worker)
|
self.assertEqual(worker.id, test_worker)
|
||||||
self.assertEqual(worker.tags, [tag])
|
self.assertEqual(worker.tags, [tag])
|
||||||
self.assertEqual(worker.system_tags, ["Application"])
|
self.assertEqual(worker.system_tags, [system_tag])
|
||||||
|
|
||||||
workers = self.api.workers.get_all(tags=[tag], system_tags=["-Application"]).workers
|
workers = self.api.workers.get_all(tags=[tag], system_tags=[f"-{system_tag}"]).workers
|
||||||
self.assertFalse(workers)
|
self.assertFalse(workers)
|
||||||
|
|
||||||
def test_filters(self):
|
def test_filters(self):
|
||||||
|
Loading…
Reference in New Issue
Block a user