Fix using reserved keywords as atrifact/hyperparams/configuration names

Replace events.get_debug_image_event and event.get_debug_image_iterations with events.get_debug_image_sample and events.next_debug_image_sample
This commit is contained in:
allegroai
2021-01-05 17:47:27 +02:00
parent d198138c5b
commit be788965e0
10 changed files with 516 additions and 190 deletions

View File

@@ -22,7 +22,6 @@ from apiserver.database.errors import translate_errors_context
from apiserver.database.model.task.metrics import MetricEventStats
from apiserver.database.model.task.task import Task
from apiserver.timing_context import TimingContext
from apiserver.utilities.dicts import nested_get
class VariantScrollState(Base):
@@ -466,103 +465,3 @@ class DebugImagesIterator:
if events_to_remove:
it["events"] = [ev for ev in it["events"] if ev not in events_to_remove]
return [it for it in iterations if it["events"]]
def get_debug_image_event(
self,
company_id: str,
task: str,
metric: str,
variant: str,
iteration: Optional[int] = None,
) -> Optional[dict]:
"""
Get the debug image for the requested iteration or the latest before it
If the iteration is not passed then get the latest event
"""
es_index = EventMetrics.get_index_name(company_id, self.EVENT_TYPE)
if not self.es.indices.exists(es_index):
return None
must_conditions = [
{"term": {"task": task}},
{"term": {"metric": metric}},
{"term": {"variant": variant}},
{"exists": {"field": "url"}},
]
if iteration is not None:
must_conditions.append({"range": {"iter": {"lte": iteration}}})
es_req = {
"size": 1,
"sort": {"iter": "desc"},
"query": {"bool": {"must": must_conditions}}
}
with translate_errors_context(), TimingContext("es", "get_debug_image_event"):
es_res = self.es.search(index=es_index, body=es_req, routing=task)
hits = nested_get(es_res, ("hits", "hits"))
if not hits:
return None
return hits[0]["_source"]
def get_debug_image_iterations(
self, company_id: str, task: str, metric: str, variant: str
) -> Tuple[int, int]:
"""
Return valid min and max iterations that the task reported images
The min iteration is the lowest iteration that contains non-recycled image url
"""
es_index = EventMetrics.get_index_name(company_id, self.EVENT_TYPE)
if not self.es.indices.exists(es_index):
return 0, 0
es_req: dict = {
"size": 1,
"sort": {"iter": "desc"},
"query": {
"bool": {
"must": [
{"term": {"task": task}},
{"term": {"metric": metric}},
{"term": {"variant": variant}},
{"exists": {"field": "url"}},
]
}
},
"aggs": {
"url_min": {
"terms": {
"field": "url",
"order": {"max_iter": "asc"},
"size": 1, # we need only one url from the least recent iteration
},
"aggs": {
"max_iter": {"max": {"field": "iter"}},
"iters": {
"top_hits": {
"sort": {"iter": {"order": "desc"}},
"size": 1,
"_source": "iter",
}
},
}
}
},
}
with translate_errors_context(), TimingContext("es", "get_debug_image_iterations"):
es_res = self.es.search(index=es_index, body=es_req, routing=task)
hits = nested_get(es_res, ("hits", "hits"))
if not hits:
return 0, 0
max_iter = hits[0]["_source"]["iter"]
url_min_buckets = nested_get(es_res, ("aggregations", "url_min", "buckets"))
if not url_min_buckets:
return 0, max_iter
min_iter = url_min_buckets[0]["max_iter"]["value"]
return int(min_iter), max_iter

View File

@@ -0,0 +1,359 @@
import operator
from typing import Sequence, Tuple, Optional
import attr
from boltons.iterutils import first
from elasticsearch import Elasticsearch
from jsonmodels.fields import StringField, ListField, IntField, BoolField
from jsonmodels.models import Base
from redis import StrictRedis
from apiserver.apierrors import errors
from apiserver.apimodels import JsonSerializableMixin
from apiserver.bll.event.event_metrics import EventMetrics
from apiserver.bll.redis_cache_manager import RedisCacheManager
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.timing_context import TimingContext
from apiserver.utilities.dicts import nested_get
class VariantState(Base):
name: str = StringField(required=True)
min_iteration: int = IntField()
max_iteration: int = IntField()
class DebugSampleHistoryState(Base, JsonSerializableMixin):
id: str = StringField(required=True)
iteration: int = IntField()
variant: str = StringField()
task: str = StringField()
metric: str = StringField()
reached_first: bool = BoolField()
reached_last: bool = BoolField()
variant_states: Sequence[VariantState] = ListField([VariantState])
@attr.s(auto_attribs=True)
class DebugSampleHistoryResult(object):
scroll_id: str = None
event: dict = None
min_iteration: int = None
max_iteration: int = None
class DebugSampleHistory:
EVENT_TYPE = "training_debug_image"
@property
def state_expiration_sec(self):
return config.get(
f"services.events.events_retrieval.state_expiration_sec", 3600
)
def __init__(self, redis: StrictRedis, es: Elasticsearch):
self.es = es
self.cache_manager = RedisCacheManager(
state_class=DebugSampleHistoryState,
redis=redis,
expiration_interval=self.state_expiration_sec,
)
def get_next_debug_image(
self, company_id: str, task: str, state_id: str, navigate_earlier: bool
) -> DebugSampleHistoryResult:
"""
Get the debug image for next/prev variant on the current iteration
If does not exist then try getting image for the first/last variant from next/prev iteration
"""
res = DebugSampleHistoryResult(scroll_id=state_id)
state = self.cache_manager.get_state(state_id)
if not state or state.task != task:
raise errors.bad_request.InvalidScrollId(scroll_id=state_id)
es_index = EventMetrics.get_index_name(company_id, self.EVENT_TYPE)
if not self.es.indices.exists(es_index):
return res
image = self._get_next_for_current_iteration(
es_index=es_index, navigate_earlier=navigate_earlier, state=state
) or self._get_next_for_another_iteration(
es_index=es_index, navigate_earlier=navigate_earlier, state=state
)
if not image:
return res
self._fill_res_and_update_state(image=image, res=res, state=state)
self.cache_manager.set_state(state=state)
return res
def _fill_res_and_update_state(
self, image: dict, res: DebugSampleHistoryResult, state: DebugSampleHistoryState
):
state.variant = image["variant"]
state.iteration = image["iter"]
res.event = image
var_state = first(s for s in state.variant_states if s.name == state.variant)
if var_state:
res.min_iteration = var_state.min_iteration
res.max_iteration = var_state.max_iteration
def _get_next_for_current_iteration(
self, es_index: str, navigate_earlier: bool, state: DebugSampleHistoryState
) -> Optional[dict]:
"""
Get the image for next (if navigated earlier is False) or previous variant sorted by name for the same iteration
Only variants for which the iteration falls into their valid range are considered
Return None if no such variant or image is found
"""
cmp = operator.lt if navigate_earlier else operator.gt
variants = [
var_state
for var_state in state.variant_states
if cmp(var_state.name, state.variant)
and var_state.min_iteration <= state.iteration
]
if not variants:
return
must_conditions = [
{"term": {"task": state.task}},
{"term": {"metric": state.metric}},
{"terms": {"variant": [v.name for v in variants]}},
{"term": {"iter": state.iteration}},
{"exists": {"field": "url"}},
]
es_req = {
"size": 1,
"sort": {"variant": "desc" if navigate_earlier else "asc"},
"query": {"bool": {"must": must_conditions}},
}
with translate_errors_context(), TimingContext(
"es", "get_next_for_current_iteration"
):
es_res = self.es.search(index=es_index, body=es_req, routing=state.task)
hits = nested_get(es_res, ("hits", "hits"))
if not hits:
return
return hits[0]["_source"]
def _get_next_for_another_iteration(
self, es_index: str, navigate_earlier: bool, state: DebugSampleHistoryState
) -> Optional[dict]:
"""
Get the image for the first variant for the next iteration (if navigate_earlier is set to False)
or from the last variant for the previous iteration (otherwise)
The variants for which the image falls in invalid range are discarded
If no suitable image is found then None is returned
"""
must_conditions = [
{"term": {"task": state.task}},
{"term": {"metric": state.metric}},
{"exists": {"field": "url"}},
]
if navigate_earlier:
range_operator = "lt"
order = "desc"
variants = [
var_state
for var_state in state.variant_states
if var_state.min_iteration < state.iteration
]
else:
range_operator = "gt"
order = "asc"
variants = state.variant_states
if not variants:
return
variants_conditions = [
{
"bool": {
"must": [
{"term": {"variant": v.name}},
{"range": {"iter": {"gte": v.min_iteration}}},
]
}
}
for v in variants
]
must_conditions.append({"bool": {"should": variants_conditions}})
must_conditions.append({"range": {"iter": {range_operator: state.iteration}}},)
es_req = {
"size": 1,
"sort": [{"iter": order}, {"variant": order}],
"query": {"bool": {"must": must_conditions}},
}
with translate_errors_context(), TimingContext(
"es", "get_next_for_another_iteration"
):
es_res = self.es.search(index=es_index, body=es_req, routing=state.task)
hits = nested_get(es_res, ("hits", "hits"))
if not hits:
return
return hits[0]["_source"]
def get_debug_image_for_variant(
self,
company_id: str,
task: str,
metric: str,
variant: str,
iteration: Optional[int] = None,
state_id: str = None,
) -> DebugSampleHistoryResult:
"""
Get the debug image for the requested iteration or the latest before it
If the iteration is not passed then get the latest event
"""
res = DebugSampleHistoryResult()
es_index = EventMetrics.get_index_name(company_id, self.EVENT_TYPE)
if not self.es.indices.exists(es_index):
return res
def init_state(state_: DebugSampleHistoryState):
state_.task = task
state_.metric = metric
variant_iterations = self._get_variant_iterations(
es_index=es_index, task=task, metric=metric
)
state_.variant_states = [
VariantState(
name=var_name, min_iteration=min_iter, max_iteration=max_iter
)
for var_name, min_iter, max_iter in variant_iterations
]
def validate_state(state_: DebugSampleHistoryState):
if state_.task != task or state_.metric != metric:
raise errors.bad_request.InvalidScrollId(
"Task and metric stored in the state do not match the passed ones",
scroll_id=state_.id,
)
state: DebugSampleHistoryState
with self.cache_manager.get_or_create_state(
state_id=state_id, init_state=init_state, validate_state=validate_state,
) as state:
res.scroll_id = state.id
var_state = first(s for s in state.variant_states if s.name == variant)
if not var_state:
return res
must_conditions = [
{"term": {"task": task}},
{"term": {"metric": metric}},
{"term": {"variant": variant}},
{"exists": {"field": "url"}},
]
if iteration is not None:
must_conditions.append(
{
"range": {
"iter": {"lte": iteration, "gte": var_state.min_iteration}
}
}
)
else:
must_conditions.append(
{"range": {"iter": {"gte": var_state.min_iteration}}}
)
es_req = {
"size": 1,
"sort": {"iter": "desc"},
"query": {"bool": {"must": must_conditions}},
}
with translate_errors_context(), TimingContext(
"es", "get_debug_image_for_variant"
):
es_res = self.es.search(index=es_index, body=es_req, routing=task)
hits = nested_get(es_res, ("hits", "hits"))
if not hits:
return res
self._fill_res_and_update_state(
image=hits[0]["_source"], res=res, state=state
)
return res
def _get_variant_iterations(
self,
es_index: str,
task: str,
metric: str,
variants: Optional[Sequence[str]] = None,
) -> Sequence[Tuple[str, int, int]]:
"""
Return valid min and max iterations that the task reported images
The min iteration is the lowest iteration that contains non-recycled image url
"""
must = [
{"term": {"task": task}},
{"term": {"metric": metric}},
{"exists": {"field": "url"}},
]
if variants:
must.append({"terms": {"variant": variants}})
es_req: dict = {
"size": 0,
"query": {"bool": {"must": must}},
"aggs": {
"variants": {
# all variants that sent debug images
"terms": {
"field": "variant",
"size": EventMetrics.MAX_VARIANTS_COUNT,
},
"aggs": {
"last_iter": {"max": {"field": "iter"}},
"urls": {
# group by urls and choose the minimal iteration
# from all the maximal iterations per url
"terms": {
"field": "url",
"order": {"max_iter": "asc"},
"size": 1,
},
"aggs": {
# find max iteration for each url
"max_iter": {"max": {"field": "iter"}}
},
},
},
}
},
}
with translate_errors_context(), TimingContext(
"es", "get_debug_image_iterations"
):
es_res = self.es.search(index=es_index, body=es_req, routing=task)
def get_variant_data(variant_bucket: dict) -> Tuple[str, int, int]:
variant = variant_bucket["key"]
urls = nested_get(variant_bucket, ("urls", "buckets"))
min_iter = int(urls[0]["max_iter"]["value"])
max_iter = int(variant_bucket["last_iter"]["value"])
return variant, min_iter, max_iter
return [
get_variant_data(variant_bucket)
for variant_bucket in nested_get(
es_res, ("aggregations", "variants", "buckets")
)
]

View File

@@ -12,6 +12,7 @@ from elasticsearch import helpers
from mongoengine import Q
from nested_dict import nested_dict
from apiserver.bll.event.debug_sample_history import DebugSampleHistory
from apiserver.bll.util import parallel_chunked_decorator
from apiserver.database import utils as dbutils
from apiserver.es_factory import es_factory
@@ -54,6 +55,7 @@ class EventBLL(object):
)
self.redis = redis or redman.connection("apiserver")
self.debug_images_iterator = DebugImagesIterator(es=self.es, redis=self.redis)
self.debug_sample_history = DebugSampleHistory(es=self.es, redis=self.redis)
self.log_events_iterator = LogEventsIterator(es=self.es)
@property