Fix using reserved keywords as atrifact/hyperparams/configuration names

Replace events.get_debug_image_event and event.get_debug_image_iterations with events.get_debug_image_sample and events.next_debug_image_sample
This commit is contained in:
allegroai 2021-01-05 17:47:27 +02:00
parent d198138c5b
commit be788965e0
10 changed files with 516 additions and 190 deletions

View File

@ -57,12 +57,15 @@ class TaskMetricVariant(Base):
variant: str = StringField(required=True)
class DebugImageIterationsRequest(TaskMetricVariant):
pass
class DebugImageEventRequest(TaskMetricVariant):
class GetDebugImageSampleRequest(TaskMetricVariant):
iteration: Optional[int] = IntField()
scroll_id: Optional[str] = StringField()
class NextDebugImageSampleRequest(Base):
task: str = StringField(required=True)
scroll_id: Optional[str] = StringField()
navigate_earlier: bool = BoolField(default=True)
class LogOrderEnum(StringEnum):

View File

@ -22,7 +22,6 @@ from apiserver.database.errors import translate_errors_context
from apiserver.database.model.task.metrics import MetricEventStats
from apiserver.database.model.task.task import Task
from apiserver.timing_context import TimingContext
from apiserver.utilities.dicts import nested_get
class VariantScrollState(Base):
@ -466,103 +465,3 @@ class DebugImagesIterator:
if events_to_remove:
it["events"] = [ev for ev in it["events"] if ev not in events_to_remove]
return [it for it in iterations if it["events"]]
def get_debug_image_event(
self,
company_id: str,
task: str,
metric: str,
variant: str,
iteration: Optional[int] = None,
) -> Optional[dict]:
"""
Get the debug image for the requested iteration or the latest before it
If the iteration is not passed then get the latest event
"""
es_index = EventMetrics.get_index_name(company_id, self.EVENT_TYPE)
if not self.es.indices.exists(es_index):
return None
must_conditions = [
{"term": {"task": task}},
{"term": {"metric": metric}},
{"term": {"variant": variant}},
{"exists": {"field": "url"}},
]
if iteration is not None:
must_conditions.append({"range": {"iter": {"lte": iteration}}})
es_req = {
"size": 1,
"sort": {"iter": "desc"},
"query": {"bool": {"must": must_conditions}}
}
with translate_errors_context(), TimingContext("es", "get_debug_image_event"):
es_res = self.es.search(index=es_index, body=es_req, routing=task)
hits = nested_get(es_res, ("hits", "hits"))
if not hits:
return None
return hits[0]["_source"]
def get_debug_image_iterations(
self, company_id: str, task: str, metric: str, variant: str
) -> Tuple[int, int]:
"""
Return valid min and max iterations that the task reported images
The min iteration is the lowest iteration that contains non-recycled image url
"""
es_index = EventMetrics.get_index_name(company_id, self.EVENT_TYPE)
if not self.es.indices.exists(es_index):
return 0, 0
es_req: dict = {
"size": 1,
"sort": {"iter": "desc"},
"query": {
"bool": {
"must": [
{"term": {"task": task}},
{"term": {"metric": metric}},
{"term": {"variant": variant}},
{"exists": {"field": "url"}},
]
}
},
"aggs": {
"url_min": {
"terms": {
"field": "url",
"order": {"max_iter": "asc"},
"size": 1, # we need only one url from the least recent iteration
},
"aggs": {
"max_iter": {"max": {"field": "iter"}},
"iters": {
"top_hits": {
"sort": {"iter": {"order": "desc"}},
"size": 1,
"_source": "iter",
}
},
}
}
},
}
with translate_errors_context(), TimingContext("es", "get_debug_image_iterations"):
es_res = self.es.search(index=es_index, body=es_req, routing=task)
hits = nested_get(es_res, ("hits", "hits"))
if not hits:
return 0, 0
max_iter = hits[0]["_source"]["iter"]
url_min_buckets = nested_get(es_res, ("aggregations", "url_min", "buckets"))
if not url_min_buckets:
return 0, max_iter
min_iter = url_min_buckets[0]["max_iter"]["value"]
return int(min_iter), max_iter

View File

@ -0,0 +1,359 @@
import operator
from typing import Sequence, Tuple, Optional
import attr
from boltons.iterutils import first
from elasticsearch import Elasticsearch
from jsonmodels.fields import StringField, ListField, IntField, BoolField
from jsonmodels.models import Base
from redis import StrictRedis
from apiserver.apierrors import errors
from apiserver.apimodels import JsonSerializableMixin
from apiserver.bll.event.event_metrics import EventMetrics
from apiserver.bll.redis_cache_manager import RedisCacheManager
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.timing_context import TimingContext
from apiserver.utilities.dicts import nested_get
class VariantState(Base):
name: str = StringField(required=True)
min_iteration: int = IntField()
max_iteration: int = IntField()
class DebugSampleHistoryState(Base, JsonSerializableMixin):
id: str = StringField(required=True)
iteration: int = IntField()
variant: str = StringField()
task: str = StringField()
metric: str = StringField()
reached_first: bool = BoolField()
reached_last: bool = BoolField()
variant_states: Sequence[VariantState] = ListField([VariantState])
@attr.s(auto_attribs=True)
class DebugSampleHistoryResult(object):
scroll_id: str = None
event: dict = None
min_iteration: int = None
max_iteration: int = None
class DebugSampleHistory:
EVENT_TYPE = "training_debug_image"
@property
def state_expiration_sec(self):
return config.get(
f"services.events.events_retrieval.state_expiration_sec", 3600
)
def __init__(self, redis: StrictRedis, es: Elasticsearch):
self.es = es
self.cache_manager = RedisCacheManager(
state_class=DebugSampleHistoryState,
redis=redis,
expiration_interval=self.state_expiration_sec,
)
def get_next_debug_image(
self, company_id: str, task: str, state_id: str, navigate_earlier: bool
) -> DebugSampleHistoryResult:
"""
Get the debug image for next/prev variant on the current iteration
If does not exist then try getting image for the first/last variant from next/prev iteration
"""
res = DebugSampleHistoryResult(scroll_id=state_id)
state = self.cache_manager.get_state(state_id)
if not state or state.task != task:
raise errors.bad_request.InvalidScrollId(scroll_id=state_id)
es_index = EventMetrics.get_index_name(company_id, self.EVENT_TYPE)
if not self.es.indices.exists(es_index):
return res
image = self._get_next_for_current_iteration(
es_index=es_index, navigate_earlier=navigate_earlier, state=state
) or self._get_next_for_another_iteration(
es_index=es_index, navigate_earlier=navigate_earlier, state=state
)
if not image:
return res
self._fill_res_and_update_state(image=image, res=res, state=state)
self.cache_manager.set_state(state=state)
return res
def _fill_res_and_update_state(
self, image: dict, res: DebugSampleHistoryResult, state: DebugSampleHistoryState
):
state.variant = image["variant"]
state.iteration = image["iter"]
res.event = image
var_state = first(s for s in state.variant_states if s.name == state.variant)
if var_state:
res.min_iteration = var_state.min_iteration
res.max_iteration = var_state.max_iteration
def _get_next_for_current_iteration(
self, es_index: str, navigate_earlier: bool, state: DebugSampleHistoryState
) -> Optional[dict]:
"""
Get the image for next (if navigated earlier is False) or previous variant sorted by name for the same iteration
Only variants for which the iteration falls into their valid range are considered
Return None if no such variant or image is found
"""
cmp = operator.lt if navigate_earlier else operator.gt
variants = [
var_state
for var_state in state.variant_states
if cmp(var_state.name, state.variant)
and var_state.min_iteration <= state.iteration
]
if not variants:
return
must_conditions = [
{"term": {"task": state.task}},
{"term": {"metric": state.metric}},
{"terms": {"variant": [v.name for v in variants]}},
{"term": {"iter": state.iteration}},
{"exists": {"field": "url"}},
]
es_req = {
"size": 1,
"sort": {"variant": "desc" if navigate_earlier else "asc"},
"query": {"bool": {"must": must_conditions}},
}
with translate_errors_context(), TimingContext(
"es", "get_next_for_current_iteration"
):
es_res = self.es.search(index=es_index, body=es_req, routing=state.task)
hits = nested_get(es_res, ("hits", "hits"))
if not hits:
return
return hits[0]["_source"]
def _get_next_for_another_iteration(
self, es_index: str, navigate_earlier: bool, state: DebugSampleHistoryState
) -> Optional[dict]:
"""
Get the image for the first variant for the next iteration (if navigate_earlier is set to False)
or from the last variant for the previous iteration (otherwise)
The variants for which the image falls in invalid range are discarded
If no suitable image is found then None is returned
"""
must_conditions = [
{"term": {"task": state.task}},
{"term": {"metric": state.metric}},
{"exists": {"field": "url"}},
]
if navigate_earlier:
range_operator = "lt"
order = "desc"
variants = [
var_state
for var_state in state.variant_states
if var_state.min_iteration < state.iteration
]
else:
range_operator = "gt"
order = "asc"
variants = state.variant_states
if not variants:
return
variants_conditions = [
{
"bool": {
"must": [
{"term": {"variant": v.name}},
{"range": {"iter": {"gte": v.min_iteration}}},
]
}
}
for v in variants
]
must_conditions.append({"bool": {"should": variants_conditions}})
must_conditions.append({"range": {"iter": {range_operator: state.iteration}}},)
es_req = {
"size": 1,
"sort": [{"iter": order}, {"variant": order}],
"query": {"bool": {"must": must_conditions}},
}
with translate_errors_context(), TimingContext(
"es", "get_next_for_another_iteration"
):
es_res = self.es.search(index=es_index, body=es_req, routing=state.task)
hits = nested_get(es_res, ("hits", "hits"))
if not hits:
return
return hits[0]["_source"]
def get_debug_image_for_variant(
self,
company_id: str,
task: str,
metric: str,
variant: str,
iteration: Optional[int] = None,
state_id: str = None,
) -> DebugSampleHistoryResult:
"""
Get the debug image for the requested iteration or the latest before it
If the iteration is not passed then get the latest event
"""
res = DebugSampleHistoryResult()
es_index = EventMetrics.get_index_name(company_id, self.EVENT_TYPE)
if not self.es.indices.exists(es_index):
return res
def init_state(state_: DebugSampleHistoryState):
state_.task = task
state_.metric = metric
variant_iterations = self._get_variant_iterations(
es_index=es_index, task=task, metric=metric
)
state_.variant_states = [
VariantState(
name=var_name, min_iteration=min_iter, max_iteration=max_iter
)
for var_name, min_iter, max_iter in variant_iterations
]
def validate_state(state_: DebugSampleHistoryState):
if state_.task != task or state_.metric != metric:
raise errors.bad_request.InvalidScrollId(
"Task and metric stored in the state do not match the passed ones",
scroll_id=state_.id,
)
state: DebugSampleHistoryState
with self.cache_manager.get_or_create_state(
state_id=state_id, init_state=init_state, validate_state=validate_state,
) as state:
res.scroll_id = state.id
var_state = first(s for s in state.variant_states if s.name == variant)
if not var_state:
return res
must_conditions = [
{"term": {"task": task}},
{"term": {"metric": metric}},
{"term": {"variant": variant}},
{"exists": {"field": "url"}},
]
if iteration is not None:
must_conditions.append(
{
"range": {
"iter": {"lte": iteration, "gte": var_state.min_iteration}
}
}
)
else:
must_conditions.append(
{"range": {"iter": {"gte": var_state.min_iteration}}}
)
es_req = {
"size": 1,
"sort": {"iter": "desc"},
"query": {"bool": {"must": must_conditions}},
}
with translate_errors_context(), TimingContext(
"es", "get_debug_image_for_variant"
):
es_res = self.es.search(index=es_index, body=es_req, routing=task)
hits = nested_get(es_res, ("hits", "hits"))
if not hits:
return res
self._fill_res_and_update_state(
image=hits[0]["_source"], res=res, state=state
)
return res
def _get_variant_iterations(
self,
es_index: str,
task: str,
metric: str,
variants: Optional[Sequence[str]] = None,
) -> Sequence[Tuple[str, int, int]]:
"""
Return valid min and max iterations that the task reported images
The min iteration is the lowest iteration that contains non-recycled image url
"""
must = [
{"term": {"task": task}},
{"term": {"metric": metric}},
{"exists": {"field": "url"}},
]
if variants:
must.append({"terms": {"variant": variants}})
es_req: dict = {
"size": 0,
"query": {"bool": {"must": must}},
"aggs": {
"variants": {
# all variants that sent debug images
"terms": {
"field": "variant",
"size": EventMetrics.MAX_VARIANTS_COUNT,
},
"aggs": {
"last_iter": {"max": {"field": "iter"}},
"urls": {
# group by urls and choose the minimal iteration
# from all the maximal iterations per url
"terms": {
"field": "url",
"order": {"max_iter": "asc"},
"size": 1,
},
"aggs": {
# find max iteration for each url
"max_iter": {"max": {"field": "iter"}}
},
},
},
}
},
}
with translate_errors_context(), TimingContext(
"es", "get_debug_image_iterations"
):
es_res = self.es.search(index=es_index, body=es_req, routing=task)
def get_variant_data(variant_bucket: dict) -> Tuple[str, int, int]:
variant = variant_bucket["key"]
urls = nested_get(variant_bucket, ("urls", "buckets"))
min_iter = int(urls[0]["max_iter"]["value"])
max_iter = int(variant_bucket["last_iter"]["value"])
return variant, min_iter, max_iter
return [
get_variant_data(variant_bucket)
for variant_bucket in nested_get(
es_res, ("aggregations", "variants", "buckets")
)
]

View File

@ -12,6 +12,7 @@ from elasticsearch import helpers
from mongoengine import Q
from nested_dict import nested_dict
from apiserver.bll.event.debug_sample_history import DebugSampleHistory
from apiserver.bll.util import parallel_chunked_decorator
from apiserver.database import utils as dbutils
from apiserver.es_factory import es_factory
@ -54,6 +55,7 @@ class EventBLL(object):
)
self.redis = redis or redman.connection("apiserver")
self.debug_images_iterator = DebugImagesIterator(es=self.es, redis=self.redis)
self.debug_sample_history = DebugSampleHistory(es=self.es, redis=self.redis)
self.log_events_iterator = LogEventsIterator(es=self.es)
@property

View File

@ -8,6 +8,7 @@ from apiserver.bll.task.utils import get_task_for_update
from apiserver.database.model.task.task import DEFAULT_ARTIFACT_MODE, Artifact
from apiserver.timing_context import TimingContext
from apiserver.utilities.dicts import nested_get, nested_set
from apiserver.utilities.parameter_key_escaper import mongoengine_safe
def get_artifact_id(artifact: dict):
@ -60,7 +61,7 @@ class Artifacts:
}
update_cmds = {
f"set__execution__artifacts__{name}": value
f"set__execution__artifacts__{mongoengine_safe(name)}": value
for name, value in artifacts.items()
}
return task.update(**update_cmds, last_update=datetime.utcnow())

View File

@ -17,7 +17,10 @@ from apiserver.bll.task.utils import get_task_for_update
from apiserver.config_repo import config
from apiserver.database.model.task.task import ParamsItem, Task, ConfigurationItem
from apiserver.timing_context import TimingContext
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
from apiserver.utilities.parameter_key_escaper import (
ParameterKeyEscaper,
mongoengine_safe,
)
log = config.logger(__file__)
task_bll = TaskBLL()
@ -65,7 +68,9 @@ class HyperParams:
with TimingContext("mongo", "delete_hyperparams"):
properties_only = cls._normalize_params(hyperparams)
task = get_task_for_update(
company_id=company_id, task_id=task_id, allow_all_statuses=properties_only
company_id=company_id,
task_id=task_id,
allow_all_statuses=properties_only,
)
with_param, without_param = iterutils.partition(
@ -99,7 +104,9 @@ class HyperParams:
with TimingContext("mongo", "edit_hyperparams"):
properties_only = cls._normalize_params(hyperparams)
task = get_task_for_update(
company_id=company_id, task_id=task_id, allow_all_statuses=properties_only
company_id=company_id,
task_id=task_id,
allow_all_statuses=properties_only,
)
update_cmds = dict()
@ -108,11 +115,15 @@ class HyperParams:
update_cmds["set__hyperparams"] = hyperparams
elif replace_hyperparams == ReplaceHyperparams.section:
for section, value in hyperparams.items():
update_cmds[f"set__hyperparams__{section}"] = value
update_cmds[
f"set__hyperparams__{mongoengine_safe(section)}"
] = value
else:
for section, section_params in hyperparams.items():
for name, value in section_params.items():
update_cmds[f"set__hyperparams__{section}__{name}"] = value
update_cmds[
f"set__hyperparams__{section}__{mongoengine_safe(name)}"
] = value
return task.update(**update_cmds, last_update=datetime.utcnow())
@ -200,7 +211,7 @@ class HyperParams:
update_cmds["set__configuration"] = configuration
else:
for name, value in configuration.items():
update_cmds[f"set__configuration__{name}"] = value
update_cmds[f"set__configuration__{mongoengine_safe(name)}"] = value
return task.update(**update_cmds, last_update=datetime.utcnow())

View File

@ -230,6 +230,27 @@
}
}
}
debug_image_sample_reposnse {
type: object
properties {
scroll_id {
type: string
description: "Scroll ID to pass to the next calls to get_debug_image_sample or next_debug_image_sample"
}
event {
type: object
description: "Debug image event"
}
min_iteration {
type: integer
description: "minimal valid iteration for the variant"
}
max_iteration {
type: integer
description: "maximal valid iteration for the variant"
}
}
}
}
add {
"2.1" {
@ -395,9 +416,9 @@
}
}
}
get_debug_image_event {
get_debug_image_sample {
"2.11": {
description: "Return the last debug image per metric and variant for the provided iteration"
description: "Return the debug image per metric and variant for the provided iteration"
request {
type: object
required: [task, metric, variant]
@ -415,56 +436,36 @@
type: string
}
iteration {
description: "The latest iteration to bring debug image from. If not specified then the latest reported image is retrieved"
description: "The iteration to bring debug image from. If not specified then the latest reported image is retrieved"
type: integer
}
}
}
response {
type: object
properties {
event {
description: "The latest debug image for the specifed iteration"
type: object
scroll_id {
type: string
description: "Scroll ID from the previous call to get_debug_image_sample or empty"
}
}
}
response {"$ref": "#/definitions/task_log_event"}
}
}
get_debug_image_iterations {
next_debug_image_sample {
"2.11": {
description: "Return the min and max iterations for which valid urls are present"
description: "Get the image for the next variant for the same iteration or for the next iteration"
request {
type: object
required: [task, metric, variant]
required: [task, scroll_id]
properties {
task {
description: "Task ID"
type: string
}
metric {
description: "Metric name"
type: string
}
variant {
description: "Metric variant"
scroll_id {
type: string
description: "Scroll ID from the previous call to get_debug_image_sample"
}
}
}
response {
type: object
properties {
min_iteration {
description: "Mininal iteration for which a non recycled image exists"
type: integer
}
max_iteration {
description: "Maximal iteration for which an image was reported"
type: integer
}
}
}
response {"$ref": "#/definitions/task_log_event"}
}
}
get_task_metrics{

View File

@ -2,6 +2,8 @@ import itertools
from collections import defaultdict
from operator import itemgetter
import attr
from apiserver.apierrors import errors
from apiserver.apimodels.events import (
MultiTaskScalarMetricsIterHistogramRequest,
@ -13,8 +15,8 @@ from apiserver.apimodels.events import (
TaskMetricsRequest,
LogEventsRequest,
LogOrderEnum,
DebugImageIterationsRequest,
DebugImageEventRequest,
GetDebugImageSampleRequest,
NextDebugImageSampleRequest,
)
from apiserver.bll.event import EventBLL
from apiserver.bll.event.event_metrics import EventMetrics
@ -627,46 +629,41 @@ def get_debug_images(call, company_id, request: DebugImagesRequest):
@endpoint(
"events.get_debug_image_event",
"events.get_debug_image_sample",
min_version="2.11",
request_data_model=DebugImageEventRequest,
request_data_model=GetDebugImageSampleRequest,
)
def get_debug_image(call, company_id, request: DebugImageEventRequest):
def get_debug_image_sample(call, company_id, request: GetDebugImageSampleRequest):
task = task_bll.assert_exists(
company_id, task_ids=[request.task], allow_public=True, only=("company",)
)[0]
call.result.data = {
"event": event_bll.debug_images_iterator.get_debug_image_event(
res = event_bll.debug_sample_history.get_debug_image_for_variant(
company_id=task.company,
task=request.task,
metric=request.metric,
variant=request.variant,
iteration=request.iteration,
state_id=request.scroll_id,
)
}
call.result.data = attr.asdict(res, recurse=False)
@endpoint(
"events.get_debug_image_iterations",
"events.next_debug_image_sample",
min_version="2.11",
request_data_model=DebugImageIterationsRequest,
request_data_model=NextDebugImageSampleRequest,
)
def get_debug_image_iterations(call, company_id, request: DebugImageIterationsRequest):
def next_debug_image_sample(call, company_id, request: NextDebugImageSampleRequest):
task = task_bll.assert_exists(
company_id, task_ids=[request.task], allow_public=True, only=("company",)
)[0]
min_iter, max_iter = event_bll.debug_images_iterator.get_debug_image_iterations(
res = event_bll.debug_sample_history.get_next_debug_image(
company_id=task.company,
task=request.task,
metric=request.metric,
variant=request.variant,
state_id=request.scroll_id,
navigate_earlier=request.navigate_earlier
)
call.result.data = {
"max_iteration": max_iter,
"min_iteration": min_iter,
}
call.result.data = attr.asdict(res, recurse=False)
@endpoint("events.get_task_metrics", request_data_model=TaskMetricsRequest)

View File

@ -27,20 +27,17 @@ class TestTaskDebugImages(TestService):
**kwargs,
}
def test_get_debug_image(self):
def test_get_debug_image_sample(self):
task = self._temp_task()
metric = "Metric1"
variant = "Variant1"
# test empty
res = self.api.events.get_debug_image_iterations(
task=task, metric=metric, variant=variant
)
self.assertEqual(res.min_iteration, 0)
self.assertEqual(res.max_iteration, 0)
res = self.api.events.get_debug_image_event(
res = self.api.events.get_debug_image_sample(
task=task, metric=metric, variant=variant
)
self.assertEqual(res.min_iteration, None)
self.assertEqual(res.max_iteration, None)
self.assertEqual(res.event, None)
# test existing events
@ -57,25 +54,74 @@ class TestTaskDebugImages(TestService):
for n in range(iterations)
]
self.send_batch(events)
res = self.api.events.get_debug_image_iterations(
task=task, metric=metric, variant=variant
)
self.assertEqual(res.max_iteration, iterations-1)
self.assertEqual(res.min_iteration, max(0, iterations - unique_images))
# if iteration is not specified then return the event from the last one
res = self.api.events.get_debug_image_event(
res = self.api.events.get_debug_image_sample(
task=task, metric=metric, variant=variant
)
self._assertEqualEvent(res.event, events[-1])
self.assertEqual(res.max_iteration, iterations - 1)
self.assertEqual(res.min_iteration, max(0, iterations - unique_images))
self.assertTrue(res.scroll_id)
# else from the specific iteration
iteration = 8
res = self.api.events.get_debug_image_event(
task=task, metric=metric, variant=variant, iteration=iteration
res = self.api.events.get_debug_image_sample(
task=task,
metric=metric,
variant=variant,
iteration=iteration,
scroll_id=res.scroll_id,
)
self._assertEqualEvent(res.event, events[iteration])
def test_next_debug_image_sample(self):
task = self._temp_task()
metric = "Metric1"
variant1 = "Variant1"
variant2 = "Variant2"
# test existing events
events = [
self._create_task_event(
task=task,
iteration=n,
metric=metric,
variant=v,
url=f"{metric}_{v}_{n}",
)
for n in range(2)
for v in (variant1, variant2)
]
self.send_batch(events)
# init scroll
res = self.api.events.get_debug_image_sample(
task=task, metric=metric, variant=variant1
)
self._assertEqualEvent(res.event, events[-2])
# navigate forwards
res = self.api.events.next_debug_image_sample(
task=task, scroll_id=res.scroll_id, navigate_earlier=False
)
self._assertEqualEvent(res.event, events[-1])
res = self.api.events.next_debug_image_sample(
task=task, scroll_id=res.scroll_id, navigate_earlier=False
)
self.assertEqual(res.event, None)
# navigate backwards
for i in range(3):
res = self.api.events.next_debug_image_sample(
task=task, scroll_id=res.scroll_id
)
self._assertEqualEvent(res.event, events[-2 - i])
res = self.api.events.next_debug_image_sample(
task=task, scroll_id=res.scroll_id
)
self.assertEqual(res.event, None)
def _assertEqualEvent(self, ev1: dict, ev2: dict):
self.assertEqual(ev1["iter"], ev2["iter"])
self.assertEqual(ev1["url"], ev2["url"])

View File

@ -1,4 +1,5 @@
from boltons.dictutils import OneToOne
from mongoengine.queryset.transform import MATCH_OPERATORS
class ParameterKeyEscaper:
@ -39,3 +40,9 @@ class ParameterKeyEscaper:
value = "_" + value[2:]
return value
def mongoengine_safe(field_name):
if field_name in MATCH_OPERATORS:
return field_name + "__"
return field_name