Support returning multiple plots during history navigation

allegroai 2022-11-29 17:37:30 +02:00
parent bc23f1b0cf
commit 97992b0d9e
10 changed files with 846 additions and 597 deletions
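In effect, the events.get_plot_sample and events.next_plot_sample endpoints now return every plot reported for a metric at the selected iteration in an events array instead of a single event field, and the request no longer takes a variant. A minimal usage sketch based on the changes below, assuming a pre-configured API client object named client and an existing task_id (both placeholder names, not part of this commit):

# Request plot samples per metric; `variant` is no longer part of the request.
res = client.events.get_plot_sample(task=task_id, metric="Metric1")
# `res.events` holds one plot event per variant reported for the metric at the
# returned iteration; previously a single plot came back in `res.event`.
for plot_event in res.events:
    print(plot_event["iter"], plot_event["variant"])
# min/max_iteration now describe the valid iteration range for the whole metric.
print(res.min_iteration, res.max_iteration)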

View File

@ -61,13 +61,20 @@ class MetricEventsRequest(Base):
model_events: bool = BoolField()
class GetVariantSampleRequest(Base):
task: str = StringField(required=True)
metric: str = StringField(required=True)
variant: str = StringField(required=True)
iteration: Optional[int] = IntField()
refresh: bool = BoolField(default=False)
scroll_id: Optional[str] = StringField()
navigate_current_metric: bool = BoolField(default=True)
model_events: bool = BoolField(default=False)
class GetMetricSamplesRequest(Base):
task: str = StringField(required=True)
metric: str = StringField(required=True)
iteration: Optional[int] = IntField()
refresh: bool = BoolField(default=False)
scroll_id: Optional[str] = StringField()

View File

@ -26,7 +26,7 @@ from apiserver.bll.event.event_common import (
)
from apiserver.bll.event.events_iterator import EventsIterator, TaskEventsResult
from apiserver.bll.event.history_debug_image_iterator import HistoryDebugImageIterator
from apiserver.bll.event.history_plots_iterator import HistoryPlotsIterator
from apiserver.bll.event.metric_debug_images_iterator import MetricDebugImagesIterator
from apiserver.bll.event.metric_plots_iterator import MetricPlotsIterator
from apiserver.bll.util import parallel_chunked_decorator
@ -93,7 +93,7 @@ class EventBLL(object):
es=self.es, redis=self.redis
)
self.plots_iterator = MetricPlotsIterator(es=self.es, redis=self.redis)
self.plot_sample_history = HistoryPlotsIterator(es=self.es, redis=self.redis)
self.events_iterator = EventsIterator(es=self.es)
@property

View File

@ -1,21 +1,126 @@
import operator
from operator import attrgetter
from typing import Sequence, Tuple, Optional, Mapping
import attr
from boltons.iterutils import first, bucketize
from elasticsearch import Elasticsearch
from jsonmodels.fields import StringField, IntField, BoolField, ListField
from jsonmodels.models import Base
from redis.client import StrictRedis
from apiserver.utilities.dicts import nested_get
from .event_common import (
EventType,
EventSettings,
check_empty_data,
search_company_events,
get_max_metric_and_variant_counts,
)
from apiserver.apimodels import JsonSerializableMixin
from apiserver.bll.redis_cache_manager import RedisCacheManager
from apiserver.apierrors import errors
class VariantState(Base):
name: str = StringField(required=True)
metric: str = StringField(default=None)
min_iteration: int = IntField()
max_iteration: int = IntField()
class DebugImageSampleState(Base, JsonSerializableMixin):
id: str = StringField(required=True)
iteration: int = IntField()
variant: str = StringField()
task: str = StringField()
metric: str = StringField()
variant_states: Sequence[VariantState] = ListField([VariantState])
warning: str = StringField()
navigate_current_metric = BoolField(default=True)
@attr.s(auto_attribs=True)
class VariantSampleResult(object):
scroll_id: str = None
event: dict = None
min_iteration: int = None
max_iteration: int = None
class HistoryDebugImageIterator:
event_type = EventType.metrics_image
def __init__(self, redis: StrictRedis, es: Elasticsearch):
self.es = es
self.cache_manager = RedisCacheManager(
state_class=DebugImageSampleState,
redis=redis,
expiration_interval=EventSettings.state_expiration_sec,
)
def get_next_sample(
self,
company_id: str,
task: str,
state_id: str,
navigate_earlier: bool,
next_iteration: bool,
) -> VariantSampleResult:
"""
Get the sample for the next/prev variant on the current iteration
If it does not exist then try getting the sample for the first/last variant from the next/prev iteration
"""
res = VariantSampleResult(scroll_id=state_id)
state = self.cache_manager.get_state(state_id)
if not state or state.task != task:
raise errors.bad_request.InvalidScrollId(scroll_id=state_id)
if check_empty_data(self.es, company_id=company_id, event_type=self.event_type):
return res
if next_iteration:
event = self._get_next_for_another_iteration(
company_id=company_id, navigate_earlier=navigate_earlier, state=state
)
else:
# noinspection PyArgumentList
event = first(
f(company_id=company_id, navigate_earlier=navigate_earlier, state=state)
for f in (
self._get_next_for_current_iteration,
self._get_next_for_another_iteration,
)
)
if not event:
return res
self._fill_res_and_update_state(event=event, res=res, state=state)
self.cache_manager.set_state(state=state)
return res
@staticmethod
def _fill_res_and_update_state(
event: dict, res: VariantSampleResult, state: DebugImageSampleState
):
state.variant = event["variant"]
state.metric = event["metric"]
state.iteration = event["iter"]
res.event = event
var_state = first(
vs
for vs in state.variant_states
if vs.name == state.variant and vs.metric == state.metric
)
if var_state:
res.min_iteration = var_state.min_iteration
res.max_iteration = var_state.max_iteration
@staticmethod
def _get_metric_conditions(variants: Sequence[VariantState]) -> dict:
metrics = bucketize(variants, key=attrgetter("metric"))
def _get_variants_conditions(metric_variants: Sequence[VariantState]) -> dict:
variants_conditions = [
{
"bool": {
@ -25,32 +130,326 @@ class HistoryDebugImageIterator(HistorySampleIterator):
]
}
}
for v in metric_variants
]
return {"bool": {"should": variants_conditions}}
metrics_conditions = [
{
"bool": {
"must": [
{"term": {"metric": metric}},
_get_variants_conditions(metric_variants),
]
}
}
for metric, metric_variants in metrics.items()
]
return {"bool": {"should": metrics_conditions}}
def _get_next_for_current_iteration(
self, company_id: str, navigate_earlier: bool, state: DebugImageSampleState
) -> Optional[dict]:
"""
Get the sample for next (if navigate_earlier is False) or previous variant sorted by name for the same iteration
Only variants for which the iteration falls into their valid range are considered
Return None if no such variant or sample is found
"""
if state.navigate_current_metric:
variants = [
var_state
for var_state in state.variant_states
if var_state.metric == state.metric
]
else:
variants = state.variant_states
cmp = operator.lt if navigate_earlier else operator.gt
variants = [
var_state
for var_state in variants
if cmp((var_state.metric, var_state.name), (state.metric, state.variant))
and var_state.min_iteration <= state.iteration
]
if not variants:
return
must_conditions = [
{"term": {"task": state.task}},
{"term": {"iter": state.iteration}},
self._get_metric_conditions(variants),
{"exists": {"field": "url"}},
]
order = "desc" if navigate_earlier else "asc"
es_req = {
"size": 1,
"sort": [{"metric": order}, {"variant": order}],
"query": {"bool": {"must": must_conditions}},
}
es_res = search_company_events(
self.es,
company_id=company_id,
event_type=self.event_type,
body=es_req,
)
hits = nested_get(es_res, ("hits", "hits"))
if not hits:
return
return hits[0]["_source"]
def _get_next_for_another_iteration(
self, company_id: str, navigate_earlier: bool, state: DebugImageSampleState
) -> Optional[dict]:
"""
Get the sample for the first variant for the next iteration (if navigate_earlier is set to False)
or from the last variant for the previous iteration (otherwise)
The variants for which the sample falls in an invalid range are discarded
If no suitable sample is found then None is returned
"""
if state.navigate_current_metric:
variants = [
var_state
for var_state in state.variant_states
if var_state.metric == state.metric
]
else:
variants = state.variant_states
if navigate_earlier:
range_operator = "lt"
order = "desc"
variants = [
var_state
for var_state in variants
if var_state.min_iteration < state.iteration
]
else:
range_operator = "gt"
order = "asc"
variants = variants
if not variants:
return
must_conditions = [
{"term": {"task": state.task}},
self._get_metric_conditions(variants),
{"range": {"iter": {range_operator: state.iteration}}},
{"exists": {"field": "url"}},
]
es_req = {
"size": 1,
"sort": [{"iter": order}, {"metric": order}, {"variant": order}],
"query": {"bool": {"must": must_conditions}},
}
es_res = search_company_events(
self.es,
company_id=company_id,
event_type=self.event_type,
body=es_req,
)
hits = nested_get(es_res, ("hits", "hits"))
if not hits:
return
return hits[0]["_source"]
def get_sample_for_variant(
self,
company_id: str,
task: str,
metric: str,
variant: str,
iteration: Optional[int] = None,
refresh: bool = False,
state_id: str = None,
navigate_current_metric: bool = True,
) -> VariantSampleResult:
"""
Get the sample for the requested iteration or the latest before it
If the iteration is not passed then get the latest event
"""
res = VariantSampleResult()
if check_empty_data(self.es, company_id=company_id, event_type=self.event_type):
return res
def init_state(state_: DebugImageSampleState):
state_.task = task
state_.metric = metric
state_.navigate_current_metric = navigate_current_metric
self._reset_variant_states(company_id=company_id, state=state_)
def validate_state(state_: DebugImageSampleState):
if (
state_.task != task
or state_.navigate_current_metric != navigate_current_metric
or (state_.navigate_current_metric and state_.metric != metric)
):
raise errors.bad_request.InvalidScrollId(
"Task and metric stored in the state do not match the passed ones",
scroll_id=state_.id,
)
# fix old variant states:
for vs in state_.variant_states:
if vs.metric is None:
vs.metric = metric
if refresh:
self._reset_variant_states(company_id=company_id, state=state_)
state: DebugImageSampleState
with self.cache_manager.get_or_create_state(
state_id=state_id, init_state=init_state, validate_state=validate_state,
) as state:
res.scroll_id = state.id
var_state = first(
vs
for vs in state.variant_states
if vs.name == variant and vs.metric == metric
)
if not var_state:
return res
res.min_iteration = var_state.min_iteration
res.max_iteration = var_state.max_iteration
must_conditions = [
{"term": {"task": task}},
{"term": {"metric": metric}},
{"term": {"variant": variant}},
{"exists": {"field": "url"}},
]
if iteration is not None:
must_conditions.append(
{
"range": {
"iter": {"lte": iteration, "gte": var_state.min_iteration}
}
}
)
else:
must_conditions.append(
{"range": {"iter": {"gte": var_state.min_iteration}}}
)
es_req = {
"size": 1,
"sort": {"iter": "desc"},
"query": {"bool": {"must": must_conditions}},
}
es_res = search_company_events(
self.es,
company_id=company_id,
event_type=self.event_type,
body=es_req,
)
hits = nested_get(es_res, ("hits", "hits"))
if not hits:
return res
self._fill_res_and_update_state(
event=hits[0]["_source"], res=res, state=state
)
return res
def _reset_variant_states(self, company_id: str, state: DebugImageSampleState):
metrics = self._get_metric_variant_iterations(
company_id=company_id,
task=state.task,
metric=state.metric if state.navigate_current_metric else None,
)
state.variant_states = [
VariantState(
metric=metric,
name=var_name,
min_iteration=min_iter,
max_iteration=max_iter,
)
for metric, variants in metrics.items()
for var_name, min_iter, max_iter in variants
]
def _get_metric_variant_iterations(
self, company_id: str, task: str, metric: str,
) -> Mapping[str, Sequence[Tuple[str, int, int]]]:
"""
Return the valid min and max iterations for which the task reported events of the required type
"""
must = [
{"term": {"task": task}},
{"exists": {"field": "url"}},
]
if metric is not None:
must.append({"term": {"metric": metric}})
query = {"bool": {"must": must}}
search_args = dict(
es=self.es, company_id=company_id, event_type=self.event_type,
)
max_metrics, max_variants = get_max_metric_and_variant_counts(
query=query, **search_args
)
max_variants = int(max_variants // 2)
es_req: dict = {
"size": 0,
"query": query,
"aggs": {
"metrics": {
"terms": {
"field": "metric",
"size": max_metrics,
"order": {"_key": "asc"},
},
"aggs": {
"variants": {
"terms": {
"field": "variant",
"size": max_variants,
"order": {"_key": "asc"},
},
"aggs": {
"last_iter": {"max": {"field": "iter"}}, "last_iter": {"max": {"field": "iter"}},
"urls": { "urls": {
# group by urls and choose the minimal iteration # group by urls and choose the minimal iteration
# from all the maximal iterations per url # from all the maximal iterations per url
"terms": {"field": "url", "order": {"max_iter": "asc"}, "size": 1}, "terms": {
"field": "url",
"order": {"max_iter": "asc"},
"size": 1,
},
"aggs": { "aggs": {
# find max iteration for each url # find max iteration for each url
"max_iter": {"max": {"field": "iter"}} "max_iter": {"max": {"field": "iter"}}
}, },
}, },
},
}
},
}
},
}
es_res = search_company_events(body=es_req, **search_args)
def get_variant_data(variant_bucket: dict) -> Tuple[str, int, int]:
variant = variant_bucket["key"]
urls = nested_get(variant_bucket, ("urls", "buckets"))
min_iter = int(urls[0]["max_iter"]["value"])
max_iter = int(variant_bucket["last_iter"]["value"])
return variant, min_iter, max_iter
return {
metric_bucket["key"]: [
get_variant_data(variant_bucket)
for variant_bucket in nested_get(metric_bucket, ("variants", "buckets"))
]
for metric_bucket in nested_get(
es_res, ("aggregations", "metrics", "buckets")
)
}
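The nested urls aggregation above picks, per variant, the earliest iteration whose debug image URL has not been recycled by a later report. A rough pure-Python illustration of that computation, with invented sample data (not taken from this commit):

# Each URL is reduced to the last iteration it was reported at; the smallest of
# those values is the variant's minimal valid iteration, and the plain maximum
# over all events is its maximal iteration.
events = [
    {"url": "img_0.jpg", "iter": 0},
    {"url": "img_0.jpg", "iter": 7},  # img_0.jpg was recycled at iteration 7
    {"url": "img_1.jpg", "iter": 3},
    {"url": "img_1.jpg", "iter": 9},
]
last_iter_per_url = {}
for ev in events:
    last_iter_per_url[ev["url"]] = max(last_iter_per_url.get(ev["url"], -1), ev["iter"])
min_iteration = min(last_iter_per_url.values())   # 7
max_iteration = max(ev["iter"] for ev in events)  # 9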

View File

@ -1,36 +0,0 @@
from typing import Sequence, Tuple, Callable
from elasticsearch import Elasticsearch
from redis.client import StrictRedis
from .event_common import EventType, uncompress_plot
from .history_sample_iterator import HistorySampleIterator, VariantState
class HistoryPlotIterator(HistorySampleIterator):
def __init__(self, redis: StrictRedis, es: Elasticsearch):
super().__init__(redis, es, EventType.metrics_plot)
def _get_extra_conditions(self) -> Sequence[dict]:
return []
def _get_variants_conditions(self, variants: Sequence[VariantState]) -> dict:
return {"terms": {"variant": [v.name for v in variants]}}
def _process_event(self, event: dict) -> dict:
uncompress_plot(event)
return event
def _get_min_max_aggs(self) -> Tuple[dict, Callable[[dict], Tuple[int, int]]]:
# The min iteration is the lowest iteration that contains non-recycled image url
aggs = {
"last_iter": {"max": {"field": "iter"}},
"first_iter": {"min": {"field": "iter"}},
}
def get_min_max_data(variant_bucket: dict) -> Tuple[int, int]:
min_iter = int(variant_bucket["first_iter"]["value"])
max_iter = int(variant_bucket["last_iter"]["value"])
return min_iter, max_iter
return aggs, get_min_max_data

View File

@ -0,0 +1,316 @@
from typing import Sequence, Tuple, Optional, Mapping
import attr
from boltons.iterutils import first
from elasticsearch import Elasticsearch
from jsonmodels.fields import StringField, IntField, ListField, BoolField
from jsonmodels.models import Base
from redis.client import StrictRedis
from .event_common import (
EventType,
uncompress_plot,
EventSettings,
check_empty_data,
search_company_events,
)
from apiserver.apimodels import JsonSerializableMixin
from apiserver.utilities.dicts import nested_get
from apiserver.bll.redis_cache_manager import RedisCacheManager
from apiserver.apierrors import errors
class MetricState(Base):
name: str = StringField(default=None)
min_iteration: int = IntField()
max_iteration: int = IntField()
class PlotsSampleState(Base, JsonSerializableMixin):
id: str = StringField(required=True)
iteration: int = IntField()
task: str = StringField()
metric: str = StringField()
metric_states: Sequence[MetricState] = ListField([MetricState])
warning: str = StringField()
navigate_current_metric = BoolField(default=True)
@attr.s(auto_attribs=True)
class MetricSamplesResult(object):
scroll_id: str = None
events: list = []
min_iteration: int = None
max_iteration: int = None
class HistoryPlotsIterator:
event_type = EventType.metrics_plot
def __init__(self, redis: StrictRedis, es: Elasticsearch):
self.es = es
self.cache_manager = RedisCacheManager(
state_class=PlotsSampleState,
redis=redis,
expiration_interval=EventSettings.state_expiration_sec,
)
def get_next_sample(
self,
company_id: str,
task: str,
state_id: str,
navigate_earlier: bool,
next_iteration: bool,
) -> MetricSamplesResult:
"""
Get the samples for the next/prev metric on the current iteration
If they do not exist then try getting the samples for the first/last metric from the next/prev iteration
"""
res = MetricSamplesResult(scroll_id=state_id)
state = self.cache_manager.get_state(state_id)
if not state or state.task != task:
raise errors.bad_request.InvalidScrollId(scroll_id=state_id)
if check_empty_data(self.es, company_id=company_id, event_type=self.event_type):
return res
if navigate_earlier:
range_operator = "lt"
order = "desc"
else:
range_operator = "gt"
order = "asc"
must_conditions = [
{"term": {"task": state.task}},
]
if state.navigate_current_metric:
must_conditions.append({"term": {"metric": state.metric}})
next_iteration_condition = {
"range": {"iter": {range_operator: state.iteration}}
}
if next_iteration or state.navigate_current_metric:
must_conditions.append(next_iteration_condition)
else:
next_metric_condition = {
"bool": {
"must": [
{"term": {"iter": state.iteration}},
{"range": {"metric": {range_operator: state.metric}}},
]
}
}
must_conditions.append(
{"bool": {"should": [next_metric_condition, next_iteration_condition]}}
)
events = self._get_metric_events_for_condition(
company_id=company_id,
task=state.task,
order=order,
must_conditions=must_conditions,
)
if not events:
return res
self._fill_res_and_update_state(events=events, res=res, state=state)
self.cache_manager.set_state(state=state)
return res
def get_samples_for_metric(
self,
company_id: str,
task: str,
metric: str,
iteration: Optional[int] = None,
refresh: bool = False,
state_id: str = None,
navigate_current_metric: bool = True,
) -> MetricSamplesResult:
"""
Get the sample for the requested iteration or the latest before it
If the iteration is not passed then get the latest event
"""
res = MetricSamplesResult()
if check_empty_data(self.es, company_id=company_id, event_type=self.event_type):
return res
def init_state(state_: PlotsSampleState):
state_.task = task
state_.metric = metric
state_.navigate_current_metric = navigate_current_metric
self._reset_metric_states(company_id=company_id, state=state_)
def validate_state(state_: PlotsSampleState):
if (
state_.task != task
or state_.navigate_current_metric != navigate_current_metric
or (state_.navigate_current_metric and state_.metric != metric)
):
raise errors.bad_request.InvalidScrollId(
"Task and metric stored in the state do not match the passed ones",
scroll_id=state_.id,
)
if refresh:
self._reset_metric_states(company_id=company_id, state=state_)
state: PlotsSampleState
with self.cache_manager.get_or_create_state(
state_id=state_id, init_state=init_state, validate_state=validate_state,
) as state:
res.scroll_id = state.id
metric_state = first(ms for ms in state.metric_states if ms.name == metric)
if not metric_state:
return res
res.min_iteration = metric_state.min_iteration
res.max_iteration = metric_state.max_iteration
must_conditions = [
{"term": {"task": task}},
{"term": {"metric": metric}},
]
if iteration is not None:
must_conditions.append({"range": {"iter": {"lte": iteration}}})
events = self._get_metric_events_for_condition(
company_id=company_id,
task=state.task,
order="desc",
must_conditions=must_conditions,
)
if not events:
return res
self._fill_res_and_update_state(events=events, res=res, state=state)
return res
def _reset_metric_states(self, company_id: str, state: PlotsSampleState):
metrics = self._get_metric_iterations(
company_id=company_id,
task=state.task,
metric=state.metric if state.navigate_current_metric else None,
)
state.metric_states = [
MetricState(name=metric, min_iteration=min_iter, max_iteration=max_iter)
for metric, (min_iter, max_iter) in metrics.items()
]
def _get_metric_iterations(
self, company_id: str, task: str, metric: str,
) -> Mapping[str, Tuple[int, int]]:
"""
Return the valid min and max iterations for which the task reported events of the required type
"""
must = [
{"term": {"task": task}},
]
if metric is not None:
must.append({"term": {"metric": metric}})
query = {"bool": {"must": must}}
es_req: dict = {
"size": 0,
"query": query,
"aggs": {
"metrics": {
"terms": {
"field": "metric",
"size": 5000,
"order": {"_key": "asc"},
},
"aggs": {
"last_iter": {"max": {"field": "iter"}},
"first_iter": {"min": {"field": "iter"}},
},
}
},
}
es_res = search_company_events(
body=es_req,
es=self.es,
company_id=company_id,
event_type=self.event_type,
)
return {
metric_bucket["key"]: (
int(metric_bucket["first_iter"]["value"]),
int(metric_bucket["last_iter"]["value"]),
)
for metric_bucket in nested_get(
es_res, ("aggregations", "metrics", "buckets")
)
}
@staticmethod
def _fill_res_and_update_state(
events: Sequence[dict], res: MetricSamplesResult, state: PlotsSampleState
):
for event in events:
uncompress_plot(event)
state.metric = events[0]["metric"]
state.iteration = events[0]["iter"]
res.events = events
metric_state = first(
ms for ms in state.metric_states if ms.name == state.metric
)
if metric_state:
res.min_iteration = metric_state.min_iteration
res.max_iteration = metric_state.max_iteration
def _get_metric_events_for_condition(
self, company_id: str, task: str, order: str, must_conditions: Sequence
) -> Sequence:
es_req = {
"size": 0,
"query": {"bool": {"must": must_conditions}},
"aggs": {
"iters": {
"terms": {"field": "iter", "size": 1, "order": {"_key": order}},
"aggs": {
"metrics": {
"terms": {
"field": "metric",
"size": 1,
"order": {"_key": order},
},
"aggs": {
"events": {
"top_hits": {
"sort": {"variant": {"order": "asc"}},
"size": 100,
}
}
},
},
},
}
},
}
es_res = search_company_events(
self.es,
company_id=company_id,
event_type=self.event_type,
body=es_req,
)
aggs_result = es_res.get("aggregations")
if not aggs_result:
return []
for level in ("iters", "metrics"):
level_data = aggs_result[level]["buckets"]
if not level_data:
return []
aggs_result = level_data[0]
return [
hit["_source"]
for hit in nested_get(aggs_result, ("events", "hits", "hits"))
]

View File

@ -1,440 +0,0 @@
import abc
import operator
from operator import attrgetter
from typing import Sequence, Tuple, Optional, Callable, Mapping
import attr
from boltons.iterutils import first, bucketize
from elasticsearch import Elasticsearch
from jsonmodels.fields import StringField, ListField, IntField, BoolField
from jsonmodels.models import Base
from redis import StrictRedis
from apiserver.apierrors import errors
from apiserver.apimodels import JsonSerializableMixin
from apiserver.bll.event.event_common import (
EventSettings,
EventType,
check_empty_data,
search_company_events,
get_max_metric_and_variant_counts,
)
from apiserver.bll.redis_cache_manager import RedisCacheManager
from apiserver.database.errors import translate_errors_context
from apiserver.utilities.dicts import nested_get
class VariantState(Base):
name: str = StringField(required=True)
metric: str = StringField(default=None)
min_iteration: int = IntField()
max_iteration: int = IntField()
class HistorySampleState(Base, JsonSerializableMixin):
id: str = StringField(required=True)
iteration: int = IntField()
variant: str = StringField()
task: str = StringField()
metric: str = StringField()
reached_first: bool = BoolField()
reached_last: bool = BoolField()
variant_states: Sequence[VariantState] = ListField([VariantState])
warning: str = StringField()
navigate_current_metric = BoolField(default=True)
@attr.s(auto_attribs=True)
class HistorySampleResult(object):
scroll_id: str = None
event: dict = None
min_iteration: int = None
max_iteration: int = None
class HistorySampleIterator(abc.ABC):
def __init__(self, redis: StrictRedis, es: Elasticsearch, event_type: EventType):
self.es = es
self.event_type = event_type
self.cache_manager = RedisCacheManager(
state_class=HistorySampleState,
redis=redis,
expiration_interval=EventSettings.state_expiration_sec,
)
def get_next_sample(
self,
company_id: str,
task: str,
state_id: str,
navigate_earlier: bool,
next_iteration: bool,
) -> HistorySampleResult:
"""
Get the sample for next/prev variant on the current iteration
If does not exist then try getting sample for the first/last variant from next/prev iteration
"""
res = HistorySampleResult(scroll_id=state_id)
state = self.cache_manager.get_state(state_id)
if not state or state.task != task:
raise errors.bad_request.InvalidScrollId(scroll_id=state_id)
if check_empty_data(self.es, company_id=company_id, event_type=self.event_type):
return res
if next_iteration:
event = self._get_next_for_another_iteration(
company_id=company_id, navigate_earlier=navigate_earlier, state=state
)
else:
# noinspection PyArgumentList
event = first(
f(company_id=company_id, navigate_earlier=navigate_earlier, state=state)
for f in (
self._get_next_for_current_iteration,
self._get_next_for_another_iteration,
)
)
if not event:
return res
self._fill_res_and_update_state(event=event, res=res, state=state)
self.cache_manager.set_state(state=state)
return res
def _fill_res_and_update_state(
self, event: dict, res: HistorySampleResult, state: HistorySampleState
):
self._process_event(event)
state.variant = event["variant"]
state.metric = event["metric"]
state.iteration = event["iter"]
res.event = event
var_state = first(
vs
for vs in state.variant_states
if vs.name == state.variant and vs.metric == state.metric
)
if var_state:
res.min_iteration = var_state.min_iteration
res.max_iteration = var_state.max_iteration
@abc.abstractmethod
def _get_extra_conditions(self) -> Sequence[dict]:
pass
@abc.abstractmethod
def _process_event(self, event: dict) -> dict:
pass
@abc.abstractmethod
def _get_variants_conditions(self, variants: Sequence[VariantState]) -> dict:
pass
def _get_metric_variants_condition(self, variants: Sequence[VariantState]) -> dict:
metrics = bucketize(variants, key=attrgetter("metric"))
metrics_conditions = [
{
"bool": {
"must": [
{"term": {"metric": metric}},
self._get_variants_conditions(vs),
]
}
}
for metric, vs in metrics.items()
]
return {"bool": {"should": metrics_conditions}}
def _get_next_for_current_iteration(
self, company_id: str, navigate_earlier: bool, state: HistorySampleState
) -> Optional[dict]:
"""
Get the sample for next (if navigate_earlier is False) or previous variant sorted by name for the same iteration
Only variants for which the iteration falls into their valid range are considered
Return None if no such variant or sample is found
"""
if state.navigate_current_metric:
variants = [
var_state
for var_state in state.variant_states
if var_state.metric == state.metric
]
else:
variants = state.variant_states
cmp = operator.lt if navigate_earlier else operator.gt
variants = [
var_state
for var_state in variants
if cmp((var_state.metric, var_state.name), (state.metric, state.variant))
and var_state.min_iteration <= state.iteration
]
if not variants:
return
must_conditions = [
{"term": {"task": state.task}},
{"term": {"iter": state.iteration}},
self._get_metric_variants_condition(variants),
*self._get_extra_conditions(),
]
order = "desc" if navigate_earlier else "asc"
es_req = {
"size": 1,
"sort": [{"metric": order}, {"variant": order}],
"query": {"bool": {"must": must_conditions}},
}
with translate_errors_context():
es_res = search_company_events(
self.es, company_id=company_id, event_type=self.event_type, body=es_req,
)
hits = nested_get(es_res, ("hits", "hits"))
if not hits:
return
return hits[0]["_source"]
def _get_next_for_another_iteration(
self, company_id: str, navigate_earlier: bool, state: HistorySampleState
) -> Optional[dict]:
"""
Get the sample for the first variant for the next iteration (if navigate_earlier is set to False)
or from the last variant for the previous iteration (otherwise)
The variants for which the sample falls in invalid range are discarded
If no suitable sample is found then None is returned
"""
if state.navigate_current_metric:
variants = [
var_state
for var_state in state.variant_states
if var_state.metric == state.metric
]
else:
variants = state.variant_states
if navigate_earlier:
range_operator = "lt"
order = "desc"
variants = [
var_state
for var_state in variants
if var_state.min_iteration < state.iteration
]
else:
range_operator = "gt"
order = "asc"
variants = variants
if not variants:
return
must_conditions = [
{"term": {"task": state.task}},
self._get_metric_variants_condition(variants),
{"range": {"iter": {range_operator: state.iteration}}},
*self._get_extra_conditions(),
]
es_req = {
"size": 1,
"sort": [{"iter": order}, {"metric": order}, {"variant": order}],
"query": {"bool": {"must": must_conditions}},
}
with translate_errors_context():
es_res = search_company_events(
self.es, company_id=company_id, event_type=self.event_type, body=es_req,
)
hits = nested_get(es_res, ("hits", "hits"))
if not hits:
return
return hits[0]["_source"]
def get_sample_for_variant(
self,
company_id: str,
task: str,
metric: str,
variant: str,
iteration: Optional[int] = None,
refresh: bool = False,
state_id: str = None,
navigate_current_metric: bool = True,
) -> HistorySampleResult:
"""
Get the sample for the requested iteration or the latest before it
If the iteration is not passed then get the latest event
"""
res = HistorySampleResult()
if check_empty_data(self.es, company_id=company_id, event_type=self.event_type):
return res
def init_state(state_: HistorySampleState):
state_.task = task
state_.metric = metric
state_.navigate_current_metric = navigate_current_metric
self._reset_variant_states(company_id=company_id, state=state_)
def validate_state(state_: HistorySampleState):
if (
state_.task != task
or state_.navigate_current_metric != navigate_current_metric
or (state_.navigate_current_metric and state_.metric != metric)
):
raise errors.bad_request.InvalidScrollId(
"Task and metric stored in the state do not match the passed ones",
scroll_id=state_.id,
)
# fix old variant states:
for vs in state_.variant_states:
if vs.metric is None:
vs.metric = metric
if refresh:
self._reset_variant_states(company_id=company_id, state=state_)
state: HistorySampleState
with self.cache_manager.get_or_create_state(
state_id=state_id, init_state=init_state, validate_state=validate_state,
) as state:
res.scroll_id = state.id
var_state = first(
vs
for vs in state.variant_states
if vs.name == variant and vs.metric == metric
)
if not var_state:
return res
res.min_iteration = var_state.min_iteration
res.max_iteration = var_state.max_iteration
must_conditions = [
{"term": {"task": task}},
{"term": {"metric": metric}},
{"term": {"variant": variant}},
*self._get_extra_conditions(),
]
if iteration is not None:
must_conditions.append(
{
"range": {
"iter": {"lte": iteration, "gte": var_state.min_iteration}
}
}
)
else:
must_conditions.append(
{"range": {"iter": {"gte": var_state.min_iteration}}}
)
es_req = {
"size": 1,
"sort": {"iter": "desc"},
"query": {"bool": {"must": must_conditions}},
}
with translate_errors_context():
es_res = search_company_events(
self.es,
company_id=company_id,
event_type=self.event_type,
body=es_req,
)
hits = nested_get(es_res, ("hits", "hits"))
if not hits:
return res
self._fill_res_and_update_state(
event=hits[0]["_source"], res=res, state=state
)
return res
def _reset_variant_states(self, company_id: str, state: HistorySampleState):
metrics = self._get_metric_variant_iterations(
company_id=company_id,
task=state.task,
metric=state.metric if state.navigate_current_metric else None,
)
state.variant_states = [
VariantState(
metric=metric,
name=var_name,
min_iteration=min_iter,
max_iteration=max_iter,
)
for metric, variants in metrics.items()
for var_name, min_iter, max_iter in variants
]
@abc.abstractmethod
def _get_min_max_aggs(self) -> Tuple[dict, Callable[[dict], Tuple[int, int]]]:
pass
def _get_metric_variant_iterations(
self, company_id: str, task: str, metric: str,
) -> Mapping[str, Tuple[str, str, int, int]]:
"""
Return valid min and max iterations that the task reported events of the required type
"""
must = [
{"term": {"task": task}},
*self._get_extra_conditions(),
]
if metric is not None:
must.append({"term": {"metric": metric}})
query = {"bool": {"must": must}}
search_args = dict(
es=self.es, company_id=company_id, event_type=self.event_type
)
max_metrics, max_variants = get_max_metric_and_variant_counts(
query=query, **search_args
)
max_variants = int(max_variants // 2)
min_max_aggs, get_min_max_data = self._get_min_max_aggs()
es_req: dict = {
"size": 0,
"query": query,
"aggs": {
"metrics": {
"terms": {
"field": "metric",
"size": max_metrics,
"order": {"_key": "asc"},
},
"aggs": {
"variants": {
"terms": {
"field": "variant",
"size": max_variants,
"order": {"_key": "asc"},
},
"aggs": min_max_aggs,
}
},
}
},
}
with translate_errors_context():
es_res = search_company_events(body=es_req, **search_args,)
def get_variant_data(variant_bucket: dict) -> Tuple[str, int, int]:
variant = variant_bucket["key"]
min_iter, max_iter = get_min_max_data(variant_bucket)
return variant, min_iter, max_iter
return {
metric_bucket["key"]: [
get_variant_data(variant_bucket)
for variant_bucket in nested_get(metric_bucket, ("variants", "buckets"))
]
for metric_bucket in nested_get(
es_res, ("aggregations", "metrics", "buckets")
)
}

View File

@ -372,17 +372,18 @@ _definitions {
type: string
description: "Scroll ID to pass to the next calls to get_plot_sample or next_plot_sample"
}
events {
description: "Plot events"
type: array
items { type: object}
}
min_iteration {
type: integer
description: "minimal valid iteration for the metric"
}
max_iteration {
type: integer
description: "maximal valid iteration for the metric"
}
}
}
@ -738,10 +739,10 @@ next_debug_image_sample {
}
get_plot_sample {
"2.20": {
description: "Return plots for the provided iteration"
request {
type: object
required: [task, metric]
properties {
task {
description: "Task ID"
@ -751,10 +752,6 @@ get_plot_sample {
description: "Metric name" description: "Metric name"
type: string type: string
} }
variant {
description: "Metric variant"
type: string
}
iteration {
description: "The iteration to bring plot from. If not specified then the latest reported plot is retrieved"
type: integer
@ -786,7 +783,7 @@ get_plot_sample {
}
next_plot_sample {
"2.20": {
description: "Get the plot for the next metric for the same iteration or for the next iteration"
request {
type: object
required: [task, scroll_id]
@ -801,8 +798,8 @@ next_plot_sample {
}
navigate_earlier {
type: boolean
description: """If set then get the either previous variant event from the current iteration or (if does not exist) the last variant event from the previous iteration. description: """If set then get the either previous metric events from the current iteration or (if does not exist) the last metric events from the previous iteration.
Otherwise next variant event from the current iteration or first variant event from the next iteration""" Otherwise next metric events from the current iteration or first metric events from the next iteration"""
}
}
}
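Taken together, plot history is now scrolled per metric rather than per variant. A hedged navigation sketch mirroring the test flow further below, again with client and task_id as placeholder names:

# Initialize the scroll state on one metric.
res = client.events.get_plot_sample(task=task_id, metric="Metric1")
scroll_id = res.scroll_id
# The default direction is backwards: the previous metric on the same iteration,
# or the last metric of the previous iteration if none is left.
prev_res = client.events.next_plot_sample(task=task_id, scroll_id=scroll_id)
# navigate_earlier=False walks forwards instead.
next_res = client.events.next_plot_sample(
    task=task_id, scroll_id=scroll_id, navigate_earlier=False
)
# next_iteration=True skips the remaining metrics and jumps straight to another iteration.
iter_res = client.events.next_plot_sample(
    task=task_id, scroll_id=scroll_id, next_iteration=True, navigate_earlier=False
)
# Each response carries a (possibly empty) `events` list for the reached metric.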

View File

@ -19,7 +19,6 @@ from apiserver.apimodels.events import (
TaskMetricsRequest,
LogEventsRequest,
LogOrderEnum,
GetHistorySampleRequest,
NextHistorySampleRequest,
MetricVariants as ApiMetrics,
TaskPlotsRequest,
@ -28,6 +27,7 @@ from apiserver.apimodels.events import (
ClearScrollRequest,
ClearTaskLogRequest,
SingleValueMetricsRequest,
GetVariantSampleRequest, GetMetricSamplesRequest,
)
from apiserver.bll.event import EventBLL
from apiserver.bll.event.event_common import EventType, MetricVariants
@ -827,9 +827,9 @@ def get_debug_images(call, company_id, request: MetricEventsRequest):
@endpoint(
"events.get_debug_image_sample",
min_version="2.12",
request_data_model=GetVariantSampleRequest,
)
def get_debug_image_sample(call, company_id, request: GetVariantSampleRequest):
task_or_model = _assert_task_or_model_exists(
company_id, request.task, model_events=request.model_events,
)[0]
@ -866,17 +866,16 @@ def next_debug_image_sample(call, company_id, request: NextHistorySampleRequest)
@endpoint(
"events.get_plot_sample", request_data_model=GetMetricSamplesRequest,
)
def get_plot_sample(call, company_id, request: GetMetricSamplesRequest):
task_or_model = _assert_task_or_model_exists(
company_id, request.task, model_events=request.model_events,
)[0]
res = event_bll.plot_sample_history.get_samples_for_metric(
company_id=task_or_model.get_index_company(),
task=request.task,
metric=request.metric,
variant=request.variant,
iteration=request.iteration,
refresh=request.refresh,
state_id=request.scroll_id,

View File

@ -26,57 +26,55 @@ class TestTaskPlots(TestService):
def test_get_plot_sample(self):
task = self._temp_task()
metric = "Metric1"
variants = ["Variant1", "Variant2"]
# test empty
res = self.api.events.get_plot_sample(task=task, metric=metric)
task=task, metric=metric, variant=variant
)
self.assertEqual(res.min_iteration, None)
self.assertEqual(res.max_iteration, None)
self.assertEqual(res.events, [])
# test existing events
iterations = 5
events = [
self._create_task_event(
task=task,
iteration=n // len(variants),
metric=metric,
variant=variants[n % len(variants)],
plot_str=f"Test plot str {n}",
)
for n in range(iterations * len(variants))
]
self.send_batch(events)
# if iteration is not specified then return the event from the last one
res = self.api.events.get_plot_sample(task=task, metric=metric)
self._assertEqualEvents(res.events, events[-len(variants) :])
)
self._assertEqualEvent(res.event, events[-1])
self.assertEqual(res.max_iteration, iterations - 1)
self.assertEqual(res.min_iteration, 0)
self.assertTrue(res.scroll_id)
# else from the specific iteration
iteration = 3
res = self.api.events.get_plot_sample(
task=task, metric=metric, iteration=iteration, scroll_id=res.scroll_id,
)
self._assertEqualEvents(
res.events,
events[iteration * len(variants) : (iteration + 1) * len(variants)],
)
self._assertEqualEvent(res.event, events[iteration])
def test_next_plot_sample(self):
task = self._temp_task()
metric1 = "Metric1"
variant1 = "Variant1"
metric2 = "Metric2" metric2 = "Metric2"
variant2 = "Variant2" metrics = [
metrics = [(metric1, variant1), (metric2, variant2)] (metric1, "variant1"),
(metric1, "variant2"),
(metric2, "variant3"),
(metric2, "variant4"),
]
# test existing events
events = [
self._create_task_event(
@ -93,73 +91,73 @@ class TestTaskPlots(TestService):
# single metric navigation
# init scroll
res = self.api.events.get_plot_sample(task=task, metric=metric1)
self._assertEqualEvents(res.events, events[-4:-2])
)
self._assertEqualEvent(res.event, events[-2])
# navigate forwards
res = self.api.events.next_plot_sample(
task=task, scroll_id=res.scroll_id, navigate_earlier=False
)
self.assertEqual(res.events, [])
# navigate backwards
res = self.api.events.next_plot_sample(task=task, scroll_id=res.scroll_id)
self._assertEqualEvents(res.events, events[-8:-6])
res = self.api.events.next_plot_sample(task=task, scroll_id=res.scroll_id)
self._assertEqualEvents(res.events, [])
res = self.api.events.next_plot_sample(
task=task, scroll_id=res.scroll_id
)
self._assertEqualEvent(res.event, None)
# all metrics navigation
# init scroll
res = self.api.events.get_plot_sample(
task=task, metric=metric1, navigate_current_metric=False
)
self._assertEqualEvents(res.events, events[-4:-2])
# navigate forwards
res = self.api.events.next_plot_sample(
task=task, scroll_id=res.scroll_id, navigate_earlier=False
)
self._assertEqualEvents(res.events, events[-2:])
# navigate backwards
res = self.api.events.next_plot_sample(task=task, scroll_id=res.scroll_id)
self._assertEqualEvents(res.events, events[-4:-2])
res = self.api.events.next_plot_sample(task=task, scroll_id=res.scroll_id)
self._assertEqualEvents(res.events, events[-6:-4])
res = self.api.events.next_plot_sample(
task=task, scroll_id=res.scroll_id
)
self._assertEqualEvent(res.event, events[-3])
# next_iteration
res = self.api.events.next_plot_sample(
task=task, scroll_id=res.scroll_id, next_iteration=True
)
self._assertEqualEvents(res.events, [])
res = self.api.events.next_plot_sample(
task,
scroll_id=res.scroll_id,
next_iteration=True,
navigate_earlier=False,
)
self._assertEqualEvents(res.events, events[-4:-2])
self.assertTrue(all(ev.iter == 1 for ev in res.events))
res = self.api.events.next_plot_sample(
task,
scroll_id=res.scroll_id,
next_iteration=True,
navigate_earlier=False,
)
self._assertEqualEvents(res.events, [])
def _assertEqualEvents(
self, ev_source: Sequence[dict], ev_target: Sequence[Optional[dict]]
):
self.assertEqual(len(ev_source), len(ev_target))
self.assertIsNotNone(ev1)
def compare_event(ev1, ev2):
for field in ("iter", "timestamp", "metric", "variant", "plot_str", "task"): for field in ("iter", "timestamp", "metric", "variant", "plot_str", "task"):
self.assertEqual(ev1[field], ev2[field]) self.assertEqual(ev1[field], ev2[field])
for e1, e2 in zip(ev_source, ev_target):
compare_event(e1, e2)
def test_task_plots(self):
task = self._temp_task()
@ -238,12 +236,15 @@ class TestTaskPlots(TestService):
self.assertTrue(all(m.iterations == [] for m in res.metrics))
return res.scroll_id
expected_variants = set(
(m, var) for m, vars_ in expected_metrics.items() for var in vars_
)
for metric_data in res.metrics:
self.assertEqual(len(metric_data.iterations), iterations)
for it_data in metric_data.iterations:
self.assertEqual(
set((e.metric, e.variant) for e in it_data.events),
expected_variants,
)
return res.scroll_id
@ -281,7 +282,7 @@ class TestTaskPlots(TestService):
task=task,
metric=metric,
iterations=iterations,
variants=len(variants),
)
# test forward navigation

View File

@ -40,14 +40,20 @@ class TestWorkersService(TestService):
def test_system_tags(self):
test_worker = f"test_{uuid4().hex}"
tag = uuid4().hex
system_tag = uuid4().hex
self.api.workers.register(
worker=test_worker, tags=[tag], system_tags=[system_tag], timeout=5
)
# system_tags support
worker = self.api.workers.get_all(tags=[tag], system_tags=[system_tag]).workers[
0
]
self.assertEqual(worker.id, test_worker)
self.assertEqual(worker.tags, [tag])
self.assertEqual(worker.system_tags, [system_tag])
workers = self.api.workers.get_all(tags=[tag], system_tags=[f"-{system_tag}"]).workers
self.assertFalse(workers)
def test_filters(self):