diff --git a/apiserver/apimodels/events.py b/apiserver/apimodels/events.py index 0f45d59..a359f15 100644 --- a/apiserver/apimodels/events.py +++ b/apiserver/apimodels/events.py @@ -57,12 +57,15 @@ class TaskMetricVariant(Base): variant: str = StringField(required=True) -class DebugImageIterationsRequest(TaskMetricVariant): - pass - - -class DebugImageEventRequest(TaskMetricVariant): +class GetDebugImageSampleRequest(TaskMetricVariant): iteration: Optional[int] = IntField() + scroll_id: Optional[str] = StringField() + + +class NextDebugImageSampleRequest(Base): + task: str = StringField(required=True) + scroll_id: Optional[str] = StringField() + navigate_earlier: bool = BoolField(default=True) class LogOrderEnum(StringEnum): diff --git a/apiserver/bll/event/debug_images_iterator.py b/apiserver/bll/event/debug_images_iterator.py index b157830..032d04f 100644 --- a/apiserver/bll/event/debug_images_iterator.py +++ b/apiserver/bll/event/debug_images_iterator.py @@ -22,7 +22,6 @@ from apiserver.database.errors import translate_errors_context from apiserver.database.model.task.metrics import MetricEventStats from apiserver.database.model.task.task import Task from apiserver.timing_context import TimingContext -from apiserver.utilities.dicts import nested_get class VariantScrollState(Base): @@ -466,103 +465,3 @@ class DebugImagesIterator: if events_to_remove: it["events"] = [ev for ev in it["events"] if ev not in events_to_remove] return [it for it in iterations if it["events"]] - - def get_debug_image_event( - self, - company_id: str, - task: str, - metric: str, - variant: str, - iteration: Optional[int] = None, - ) -> Optional[dict]: - """ - Get the debug image for the requested iteration or the latest before it - If the iteration is not passed then get the latest event - """ - es_index = EventMetrics.get_index_name(company_id, self.EVENT_TYPE) - if not self.es.indices.exists(es_index): - return None - - must_conditions = [ - {"term": {"task": task}}, - {"term": 
{"metric": metric}}, - {"term": {"variant": variant}}, - {"exists": {"field": "url"}}, - ] - if iteration is not None: - must_conditions.append({"range": {"iter": {"lte": iteration}}}) - - es_req = { - "size": 1, - "sort": {"iter": "desc"}, - "query": {"bool": {"must": must_conditions}} - } - - with translate_errors_context(), TimingContext("es", "get_debug_image_event"): - es_res = self.es.search(index=es_index, body=es_req, routing=task) - - hits = nested_get(es_res, ("hits", "hits")) - if not hits: - return None - - return hits[0]["_source"] - - def get_debug_image_iterations( - self, company_id: str, task: str, metric: str, variant: str - ) -> Tuple[int, int]: - """ - Return valid min and max iterations that the task reported images - The min iteration is the lowest iteration that contains non-recycled image url - """ - es_index = EventMetrics.get_index_name(company_id, self.EVENT_TYPE) - if not self.es.indices.exists(es_index): - return 0, 0 - - es_req: dict = { - "size": 1, - "sort": {"iter": "desc"}, - "query": { - "bool": { - "must": [ - {"term": {"task": task}}, - {"term": {"metric": metric}}, - {"term": {"variant": variant}}, - {"exists": {"field": "url"}}, - ] - } - }, - "aggs": { - "url_min": { - "terms": { - "field": "url", - "order": {"max_iter": "asc"}, - "size": 1, # we need only one url from the least recent iteration - }, - "aggs": { - "max_iter": {"max": {"field": "iter"}}, - "iters": { - "top_hits": { - "sort": {"iter": {"order": "desc"}}, - "size": 1, - "_source": "iter", - } - }, - } - } - }, - } - - with translate_errors_context(), TimingContext("es", "get_debug_image_iterations"): - es_res = self.es.search(index=es_index, body=es_req, routing=task) - - hits = nested_get(es_res, ("hits", "hits")) - if not hits: - return 0, 0 - - max_iter = hits[0]["_source"]["iter"] - url_min_buckets = nested_get(es_res, ("aggregations", "url_min", "buckets")) - if not url_min_buckets: - return 0, max_iter - - min_iter = 
url_min_buckets[0]["max_iter"]["value"] - return int(min_iter), max_iter diff --git a/apiserver/bll/event/debug_sample_history.py b/apiserver/bll/event/debug_sample_history.py new file mode 100644 index 0000000..d27fa1f --- /dev/null +++ b/apiserver/bll/event/debug_sample_history.py @@ -0,0 +1,359 @@ +import operator +from typing import Sequence, Tuple, Optional + +import attr +from boltons.iterutils import first +from elasticsearch import Elasticsearch +from jsonmodels.fields import StringField, ListField, IntField, BoolField +from jsonmodels.models import Base +from redis import StrictRedis + +from apiserver.apierrors import errors +from apiserver.apimodels import JsonSerializableMixin +from apiserver.bll.event.event_metrics import EventMetrics +from apiserver.bll.redis_cache_manager import RedisCacheManager +from apiserver.config_repo import config +from apiserver.database.errors import translate_errors_context +from apiserver.timing_context import TimingContext +from apiserver.utilities.dicts import nested_get + + +class VariantState(Base): + name: str = StringField(required=True) + min_iteration: int = IntField() + max_iteration: int = IntField() + + +class DebugSampleHistoryState(Base, JsonSerializableMixin): + id: str = StringField(required=True) + iteration: int = IntField() + variant: str = StringField() + task: str = StringField() + metric: str = StringField() + reached_first: bool = BoolField() + reached_last: bool = BoolField() + variant_states: Sequence[VariantState] = ListField([VariantState]) + + +@attr.s(auto_attribs=True) +class DebugSampleHistoryResult(object): + scroll_id: str = None + event: dict = None + min_iteration: int = None + max_iteration: int = None + + +class DebugSampleHistory: + EVENT_TYPE = "training_debug_image" + + @property + def state_expiration_sec(self): + return config.get( + f"services.events.events_retrieval.state_expiration_sec", 3600 + ) + + def __init__(self, redis: StrictRedis, es: Elasticsearch): + self.es = es + 
self.cache_manager = RedisCacheManager( + state_class=DebugSampleHistoryState, + redis=redis, + expiration_interval=self.state_expiration_sec, + ) + + def get_next_debug_image( + self, company_id: str, task: str, state_id: str, navigate_earlier: bool + ) -> DebugSampleHistoryResult: + """ + Get the debug image for next/prev variant on the current iteration + If does not exist then try getting image for the first/last variant from next/prev iteration + """ + res = DebugSampleHistoryResult(scroll_id=state_id) + state = self.cache_manager.get_state(state_id) + if not state or state.task != task: + raise errors.bad_request.InvalidScrollId(scroll_id=state_id) + + es_index = EventMetrics.get_index_name(company_id, self.EVENT_TYPE) + if not self.es.indices.exists(es_index): + return res + + image = self._get_next_for_current_iteration( + es_index=es_index, navigate_earlier=navigate_earlier, state=state + ) or self._get_next_for_another_iteration( + es_index=es_index, navigate_earlier=navigate_earlier, state=state + ) + if not image: + return res + + self._fill_res_and_update_state(image=image, res=res, state=state) + self.cache_manager.set_state(state=state) + return res + + def _fill_res_and_update_state( + self, image: dict, res: DebugSampleHistoryResult, state: DebugSampleHistoryState + ): + state.variant = image["variant"] + state.iteration = image["iter"] + res.event = image + var_state = first(s for s in state.variant_states if s.name == state.variant) + if var_state: + res.min_iteration = var_state.min_iteration + res.max_iteration = var_state.max_iteration + + def _get_next_for_current_iteration( + self, es_index: str, navigate_earlier: bool, state: DebugSampleHistoryState + ) -> Optional[dict]: + """ + Get the image for next (if navigated earlier is False) or previous variant sorted by name for the same iteration + Only variants for which the iteration falls into their valid range are considered + Return None if no such variant or image is found + """ + cmp = 
operator.lt if navigate_earlier else operator.gt + variants = [ + var_state + for var_state in state.variant_states + if cmp(var_state.name, state.variant) + and var_state.min_iteration <= state.iteration + ] + if not variants: + return + + must_conditions = [ + {"term": {"task": state.task}}, + {"term": {"metric": state.metric}}, + {"terms": {"variant": [v.name for v in variants]}}, + {"term": {"iter": state.iteration}}, + {"exists": {"field": "url"}}, + ] + es_req = { + "size": 1, + "sort": {"variant": "desc" if navigate_earlier else "asc"}, + "query": {"bool": {"must": must_conditions}}, + } + + with translate_errors_context(), TimingContext( + "es", "get_next_for_current_iteration" + ): + es_res = self.es.search(index=es_index, body=es_req, routing=state.task) + + hits = nested_get(es_res, ("hits", "hits")) + if not hits: + return + + return hits[0]["_source"] + + def _get_next_for_another_iteration( + self, es_index: str, navigate_earlier: bool, state: DebugSampleHistoryState + ) -> Optional[dict]: + """ + Get the image for the first variant for the next iteration (if navigate_earlier is set to False) + or from the last variant for the previous iteration (otherwise) + The variants for which the image falls in invalid range are discarded + If no suitable image is found then None is returned + """ + + must_conditions = [ + {"term": {"task": state.task}}, + {"term": {"metric": state.metric}}, + {"exists": {"field": "url"}}, + ] + + if navigate_earlier: + range_operator = "lt" + order = "desc" + variants = [ + var_state + for var_state in state.variant_states + if var_state.min_iteration < state.iteration + ] + else: + range_operator = "gt" + order = "asc" + variants = state.variant_states + + if not variants: + return + + variants_conditions = [ + { + "bool": { + "must": [ + {"term": {"variant": v.name}}, + {"range": {"iter": {"gte": v.min_iteration}}}, + ] + } + } + for v in variants + ] + must_conditions.append({"bool": {"should": variants_conditions}}) + 
must_conditions.append({"range": {"iter": {range_operator: state.iteration}}},) + + es_req = { + "size": 1, + "sort": [{"iter": order}, {"variant": order}], + "query": {"bool": {"must": must_conditions}}, + } + with translate_errors_context(), TimingContext( + "es", "get_next_for_another_iteration" + ): + es_res = self.es.search(index=es_index, body=es_req, routing=state.task) + + hits = nested_get(es_res, ("hits", "hits")) + if not hits: + return + + return hits[0]["_source"] + + def get_debug_image_for_variant( + self, + company_id: str, + task: str, + metric: str, + variant: str, + iteration: Optional[int] = None, + state_id: str = None, + ) -> DebugSampleHistoryResult: + """ + Get the debug image for the requested iteration or the latest before it + If the iteration is not passed then get the latest event + """ + res = DebugSampleHistoryResult() + es_index = EventMetrics.get_index_name(company_id, self.EVENT_TYPE) + if not self.es.indices.exists(es_index): + return res + + def init_state(state_: DebugSampleHistoryState): + state_.task = task + state_.metric = metric + variant_iterations = self._get_variant_iterations( + es_index=es_index, task=task, metric=metric + ) + state_.variant_states = [ + VariantState( + name=var_name, min_iteration=min_iter, max_iteration=max_iter + ) + for var_name, min_iter, max_iter in variant_iterations + ] + + def validate_state(state_: DebugSampleHistoryState): + if state_.task != task or state_.metric != metric: + raise errors.bad_request.InvalidScrollId( + "Task and metric stored in the state do not match the passed ones", + scroll_id=state_.id, + ) + + state: DebugSampleHistoryState + with self.cache_manager.get_or_create_state( + state_id=state_id, init_state=init_state, validate_state=validate_state, + ) as state: + res.scroll_id = state.id + + var_state = first(s for s in state.variant_states if s.name == variant) + if not var_state: + return res + + must_conditions = [ + {"term": {"task": task}}, + {"term": {"metric": 
metric}}, + {"term": {"variant": variant}}, + {"exists": {"field": "url"}}, + ] + if iteration is not None: + must_conditions.append( + { + "range": { + "iter": {"lte": iteration, "gte": var_state.min_iteration} + } + } + ) + else: + must_conditions.append( + {"range": {"iter": {"gte": var_state.min_iteration}}} + ) + + es_req = { + "size": 1, + "sort": {"iter": "desc"}, + "query": {"bool": {"must": must_conditions}}, + } + + with translate_errors_context(), TimingContext( + "es", "get_debug_image_for_variant" + ): + es_res = self.es.search(index=es_index, body=es_req, routing=task) + + hits = nested_get(es_res, ("hits", "hits")) + if not hits: + return res + + self._fill_res_and_update_state( + image=hits[0]["_source"], res=res, state=state + ) + return res + + def _get_variant_iterations( + self, + es_index: str, + task: str, + metric: str, + variants: Optional[Sequence[str]] = None, + ) -> Sequence[Tuple[str, int, int]]: + """ + Return valid min and max iterations that the task reported images + The min iteration is the lowest iteration that contains non-recycled image url + """ + must = [ + {"term": {"task": task}}, + {"term": {"metric": metric}}, + {"exists": {"field": "url"}}, + ] + if variants: + must.append({"terms": {"variant": variants}}) + + es_req: dict = { + "size": 0, + "query": {"bool": {"must": must}}, + "aggs": { + "variants": { + # all variants that sent debug images + "terms": { + "field": "variant", + "size": EventMetrics.MAX_VARIANTS_COUNT, + }, + "aggs": { + "last_iter": {"max": {"field": "iter"}}, + "urls": { + # group by urls and choose the minimal iteration + # from all the maximal iterations per url + "terms": { + "field": "url", + "order": {"max_iter": "asc"}, + "size": 1, + }, + "aggs": { + # find max iteration for each url + "max_iter": {"max": {"field": "iter"}} + }, + }, + }, + } + }, + } + + with translate_errors_context(), TimingContext( + "es", "get_debug_image_iterations" + ): + es_res = self.es.search(index=es_index, body=es_req, 
routing=task) + + def get_variant_data(variant_bucket: dict) -> Tuple[str, int, int]: + variant = variant_bucket["key"] + urls = nested_get(variant_bucket, ("urls", "buckets")) + min_iter = int(urls[0]["max_iter"]["value"]) + max_iter = int(variant_bucket["last_iter"]["value"]) + return variant, min_iter, max_iter + + return [ + get_variant_data(variant_bucket) + for variant_bucket in nested_get( + es_res, ("aggregations", "variants", "buckets") + ) + ] diff --git a/apiserver/bll/event/event_bll.py b/apiserver/bll/event/event_bll.py index f868c33..45b890a 100644 --- a/apiserver/bll/event/event_bll.py +++ b/apiserver/bll/event/event_bll.py @@ -12,6 +12,7 @@ from elasticsearch import helpers from mongoengine import Q from nested_dict import nested_dict +from apiserver.bll.event.debug_sample_history import DebugSampleHistory from apiserver.bll.util import parallel_chunked_decorator from apiserver.database import utils as dbutils from apiserver.es_factory import es_factory @@ -54,6 +55,7 @@ class EventBLL(object): ) self.redis = redis or redman.connection("apiserver") self.debug_images_iterator = DebugImagesIterator(es=self.es, redis=self.redis) + self.debug_sample_history = DebugSampleHistory(es=self.es, redis=self.redis) self.log_events_iterator = LogEventsIterator(es=self.es) @property diff --git a/apiserver/bll/task/artifacts.py b/apiserver/bll/task/artifacts.py index 7197754..c9f5da7 100644 --- a/apiserver/bll/task/artifacts.py +++ b/apiserver/bll/task/artifacts.py @@ -8,6 +8,7 @@ from apiserver.bll.task.utils import get_task_for_update from apiserver.database.model.task.task import DEFAULT_ARTIFACT_MODE, Artifact from apiserver.timing_context import TimingContext from apiserver.utilities.dicts import nested_get, nested_set +from apiserver.utilities.parameter_key_escaper import mongoengine_safe def get_artifact_id(artifact: dict): @@ -60,7 +61,7 @@ class Artifacts: } update_cmds = { - f"set__execution__artifacts__{name}": value + 
f"set__execution__artifacts__{mongoengine_safe(name)}": value for name, value in artifacts.items() } return task.update(**update_cmds, last_update=datetime.utcnow()) diff --git a/apiserver/bll/task/hyperparams.py b/apiserver/bll/task/hyperparams.py index ca0a8c9..a21181a 100644 --- a/apiserver/bll/task/hyperparams.py +++ b/apiserver/bll/task/hyperparams.py @@ -17,7 +17,10 @@ from apiserver.bll.task.utils import get_task_for_update from apiserver.config_repo import config from apiserver.database.model.task.task import ParamsItem, Task, ConfigurationItem from apiserver.timing_context import TimingContext -from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper +from apiserver.utilities.parameter_key_escaper import ( + ParameterKeyEscaper, + mongoengine_safe, +) log = config.logger(__file__) task_bll = TaskBLL() @@ -65,7 +68,9 @@ class HyperParams: with TimingContext("mongo", "delete_hyperparams"): properties_only = cls._normalize_params(hyperparams) task = get_task_for_update( - company_id=company_id, task_id=task_id, allow_all_statuses=properties_only + company_id=company_id, + task_id=task_id, + allow_all_statuses=properties_only, ) with_param, without_param = iterutils.partition( @@ -99,7 +104,9 @@ class HyperParams: with TimingContext("mongo", "edit_hyperparams"): properties_only = cls._normalize_params(hyperparams) task = get_task_for_update( - company_id=company_id, task_id=task_id, allow_all_statuses=properties_only + company_id=company_id, + task_id=task_id, + allow_all_statuses=properties_only, ) update_cmds = dict() @@ -108,11 +115,15 @@ class HyperParams: update_cmds["set__hyperparams"] = hyperparams elif replace_hyperparams == ReplaceHyperparams.section: for section, value in hyperparams.items(): - update_cmds[f"set__hyperparams__{section}"] = value + update_cmds[ + f"set__hyperparams__{mongoengine_safe(section)}" + ] = value else: for section, section_params in hyperparams.items(): for name, value in section_params.items(): - 
update_cmds[f"set__hyperparams__{section}__{name}"] = value + update_cmds[ + f"set__hyperparams__{section}__{mongoengine_safe(name)}" + ] = value return task.update(**update_cmds, last_update=datetime.utcnow()) @@ -200,7 +211,7 @@ class HyperParams: update_cmds["set__configuration"] = configuration else: for name, value in configuration.items(): - update_cmds[f"set__configuration__{name}"] = value + update_cmds[f"set__configuration__{mongoengine_safe(name)}"] = value return task.update(**update_cmds, last_update=datetime.utcnow()) diff --git a/apiserver/schema/services/events.conf b/apiserver/schema/services/events.conf index 5374306..83ed130 100644 --- a/apiserver/schema/services/events.conf +++ b/apiserver/schema/services/events.conf @@ -230,6 +230,27 @@ } } } + debug_image_sample_response { + type: object + properties { + scroll_id { + type: string + description: "Scroll ID to pass to the next calls to get_debug_image_sample or next_debug_image_sample" + } + event { + type: object + description: "Debug image event" + } + min_iteration { + type: integer + description: "minimal valid iteration for the variant" + } + max_iteration { + type: integer + description: "maximal valid iteration for the variant" + } + } + } } add { "2.1" { @@ -395,9 +416,9 @@ } } } - get_debug_image_event { + get_debug_image_sample { "2.11": { - description: "Return the last debug image per metric and variant for the provided iteration" + description: "Return the debug image per metric and variant for the provided iteration" request { type: object required: [task, metric, variant] @@ -415,56 +436,36 @@ type: string } iteration { - description: "The latest iteration to bring debug image from. If not specified then the latest reported image is retrieved" + description: "The iteration to bring debug image from. 
If not specified then the latest reported image is retrieved" type: integer } - } - } - response { - type: object - properties { - event { - description: "The latest debug image for the specifed iteration" - type: object + scroll_id { + type: string + description: "Scroll ID from the previous call to get_debug_image_sample or empty" } } } + response {"$ref": "#/definitions/debug_image_sample_response"} } } - get_debug_image_iterations { + next_debug_image_sample { "2.11": { - description: "Return the min and max iterations for which valid urls are present" + description: "Get the image for the next variant for the same iteration or for the next iteration" request { type: object - required: [task, metric, variant] + required: [task, scroll_id] properties { task { description: "Task ID" type: string } - metric { - description: "Metric name" - type: string - } - variant { - description: "Metric variant" + scroll_id { type: string + description: "Scroll ID from the previous call to get_debug_image_sample" } } } - response { - type: object - properties { - min_iteration { - description: "Mininal iteration for which a non recycled image exists" - type: integer - } - max_iteration { - description: "Maximal iteration for which an image was reported" - type: integer - } - } - } + response {"$ref": "#/definitions/debug_image_sample_response"} } } get_task_metrics{ diff --git a/apiserver/services/events.py b/apiserver/services/events.py index 902d24a..da08ac9 100644 --- a/apiserver/services/events.py +++ b/apiserver/services/events.py @@ -2,6 +2,8 @@ import itertools from collections import defaultdict from operator import itemgetter +import attr + from apiserver.apierrors import errors from apiserver.apimodels.events import ( MultiTaskScalarMetricsIterHistogramRequest, @@ -13,8 +15,8 @@ from apiserver.apimodels.events import ( TaskMetricsRequest, LogEventsRequest, LogOrderEnum, - DebugImageIterationsRequest, - DebugImageEventRequest, + GetDebugImageSampleRequest, + NextDebugImageSampleRequest, ) from 
apiserver.bll.event import EventBLL from apiserver.bll.event.event_metrics import EventMetrics @@ -627,46 +629,41 @@ def get_debug_images(call, company_id, request: DebugImagesRequest): @endpoint( - "events.get_debug_image_event", + "events.get_debug_image_sample", min_version="2.11", - request_data_model=DebugImageEventRequest, + request_data_model=GetDebugImageSampleRequest, ) -def get_debug_image(call, company_id, request: DebugImageEventRequest): +def get_debug_image_sample(call, company_id, request: GetDebugImageSampleRequest): task = task_bll.assert_exists( company_id, task_ids=[request.task], allow_public=True, only=("company",) )[0] - call.result.data = { - "event": event_bll.debug_images_iterator.get_debug_image_event( - company_id=task.company, - task=request.task, - metric=request.metric, - variant=request.variant, - iteration=request.iteration, - ) - } + res = event_bll.debug_sample_history.get_debug_image_for_variant( + company_id=task.company, + task=request.task, + metric=request.metric, + variant=request.variant, + iteration=request.iteration, + state_id=request.scroll_id, + ) + call.result.data = attr.asdict(res, recurse=False) @endpoint( - "events.get_debug_image_iterations", + "events.next_debug_image_sample", min_version="2.11", - request_data_model=DebugImageIterationsRequest, + request_data_model=NextDebugImageSampleRequest, ) -def get_debug_image_iterations(call, company_id, request: DebugImageIterationsRequest): +def next_debug_image_sample(call, company_id, request: NextDebugImageSampleRequest): task = task_bll.assert_exists( company_id, task_ids=[request.task], allow_public=True, only=("company",) )[0] - - min_iter, max_iter = event_bll.debug_images_iterator.get_debug_image_iterations( - company_id=task.company, - task=request.task, - metric=request.metric, - variant=request.variant, - ) - - call.result.data = { - "max_iteration": max_iter, - "min_iteration": min_iter, - } + res = event_bll.debug_sample_history.get_next_debug_image( + 
company_id=task.company, + task=request.task, + state_id=request.scroll_id, + navigate_earlier=request.navigate_earlier + ) + call.result.data = attr.asdict(res, recurse=False) @endpoint("events.get_task_metrics", request_data_model=TaskMetricsRequest) diff --git a/apiserver/tests/automated/test_task_debug_images.py b/apiserver/tests/automated/test_task_debug_images.py index 87bd2b4..eebc97c 100644 --- a/apiserver/tests/automated/test_task_debug_images.py +++ b/apiserver/tests/automated/test_task_debug_images.py @@ -27,20 +27,17 @@ class TestTaskDebugImages(TestService): **kwargs, } - def test_get_debug_image(self): + def test_get_debug_image_sample(self): task = self._temp_task() metric = "Metric1" variant = "Variant1" # test empty - res = self.api.events.get_debug_image_iterations( - task=task, metric=metric, variant=variant - ) - self.assertEqual(res.min_iteration, 0) - self.assertEqual(res.max_iteration, 0) - res = self.api.events.get_debug_image_event( + res = self.api.events.get_debug_image_sample( task=task, metric=metric, variant=variant ) + self.assertEqual(res.min_iteration, None) + self.assertEqual(res.max_iteration, None) self.assertEqual(res.event, None) # test existing events @@ -57,25 +54,74 @@ class TestTaskDebugImages(TestService): for n in range(iterations) ] self.send_batch(events) - res = self.api.events.get_debug_image_iterations( - task=task, metric=metric, variant=variant - ) - self.assertEqual(res.max_iteration, iterations-1) - self.assertEqual(res.min_iteration, max(0, iterations - unique_images)) # if iteration is not specified then return the event from the last one - res = self.api.events.get_debug_image_event( + res = self.api.events.get_debug_image_sample( task=task, metric=metric, variant=variant ) self._assertEqualEvent(res.event, events[-1]) + self.assertEqual(res.max_iteration, iterations - 1) + self.assertEqual(res.min_iteration, max(0, iterations - unique_images)) + self.assertTrue(res.scroll_id) # else from the specific 
iteration iteration = 8 - res = self.api.events.get_debug_image_event( - task=task, metric=metric, variant=variant, iteration=iteration + res = self.api.events.get_debug_image_sample( + task=task, + metric=metric, + variant=variant, + iteration=iteration, + scroll_id=res.scroll_id, ) self._assertEqualEvent(res.event, events[iteration]) + def test_next_debug_image_sample(self): + task = self._temp_task() + metric = "Metric1" + variant1 = "Variant1" + variant2 = "Variant2" + + # test existing events + events = [ + self._create_task_event( + task=task, + iteration=n, + metric=metric, + variant=v, + url=f"{metric}_{v}_{n}", + ) + for n in range(2) + for v in (variant1, variant2) + ] + self.send_batch(events) + + # init scroll + res = self.api.events.get_debug_image_sample( + task=task, metric=metric, variant=variant1 + ) + self._assertEqualEvent(res.event, events[-2]) + + # navigate forwards + res = self.api.events.next_debug_image_sample( + task=task, scroll_id=res.scroll_id, navigate_earlier=False + ) + self._assertEqualEvent(res.event, events[-1]) + res = self.api.events.next_debug_image_sample( + task=task, scroll_id=res.scroll_id, navigate_earlier=False + ) + self.assertEqual(res.event, None) + + # navigate backwards + for i in range(3): + res = self.api.events.next_debug_image_sample( + task=task, scroll_id=res.scroll_id + ) + self._assertEqualEvent(res.event, events[-2 - i]) + res = self.api.events.next_debug_image_sample( + task=task, scroll_id=res.scroll_id + ) + self.assertEqual(res.event, None) + def _assertEqualEvent(self, ev1: dict, ev2: dict): self.assertEqual(ev1["iter"], ev2["iter"]) self.assertEqual(ev1["url"], ev2["url"]) diff --git a/apiserver/utilities/parameter_key_escaper.py b/apiserver/utilities/parameter_key_escaper.py index c08bc80..1609c8a 100644 --- a/apiserver/utilities/parameter_key_escaper.py +++ b/apiserver/utilities/parameter_key_escaper.py @@ -1,4 +1,5 @@ from boltons.dictutils import OneToOne +from mongoengine.queryset.transform import 
MATCH_OPERATORS class ParameterKeyEscaper: @@ -39,3 +40,9 @@ class ParameterKeyEscaper: value = "_" + value[2:] return value + + +def mongoengine_safe(field_name): + if field_name in MATCH_OPERATORS: + return field_name + "__" + return field_name