Add scroll support to *.get_*

allegroai 2022-02-13 19:23:29 +02:00
parent 446bd35006
commit f20cd6536e
17 changed files with 1623 additions and 1204 deletions

View File

@@ -126,14 +126,27 @@ class QueueBLL(object):
)
queue.delete()
def get_all(self, company_id: str, query_dict: dict) -> Sequence[dict]:
def get_all(
self,
company_id: str,
query_dict: dict,
ret_params: dict = None,
) -> Sequence[dict]:
"""Get all the queues according to the query"""
with translate_errors_context():
return Queue.get_many(
company=company_id, parameters=query_dict, query_dict=query_dict
company=company_id,
parameters=query_dict,
query_dict=query_dict,
ret_params=ret_params,
)
def get_queue_infos(self, company_id: str, query_dict: dict) -> Sequence[dict]:
def get_queue_infos(
self,
company_id: str,
query_dict: dict,
ret_params: dict = None,
) -> Sequence[dict]:
"""
Get infos on all the company queues, including queue tasks and workers
"""
@@ -143,6 +156,7 @@ class QueueBLL(object):
company=company_id,
query_dict=query_dict,
override_projection=projection,
ret_params=ret_params,
)
queue_workers = defaultdict(list)

View File

@@ -49,6 +49,21 @@ class RedisCacheManager(Generic[T]):
def _get_redis_key(self, state_id):
return f"{self.state_class}/{state_id}"
def get_or_create_state_core(
self,
state_id=None,
init_state: Callable[[T], None] = _do_nothing,
validate_state: Callable[[T], None] = _do_nothing,
) -> T:
state = self.get_state(state_id) if state_id else None
if state:
validate_state(state)
else:
state = self.state_class(id=database.utils.id())
init_state(state)
return state
@contextmanager
def get_or_create_state(
self,
@@ -66,12 +81,9 @@ class RedisCacheManager(Generic[T]):
:param validate_state: user callback to validate the state if retrieved from cache
Should throw an exception if the state is not valid. If not passed then no validation is done
"""
state = self.get_state(state_id) if state_id else None
if state:
validate_state(state)
else:
state = self.state_class(id=database.utils.id())
init_state(state)
state = self.get_or_create_state_core(
state_id=state_id, init_state=init_state, validate_state=validate_state
)
try:
yield state
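For illustration, a minimal usage sketch of the refactored manager. The ExampleState document below is hypothetical (modeled on the GetManyScrollState defined later in this commit), and the context manager is assumed to persist the state on exit:

from mongoengine import Document, IntField, StringField

from apiserver.bll.redis_cache_manager import RedisCacheManager
from apiserver.redis_manager import redman


class ExampleState(Document):
    # Hypothetical state document, modeled on GetManyScrollState
    id = StringField(primary_key=True)
    position = IntField(default=0)


manager = RedisCacheManager(
    state_class=ExampleState,
    redis=redman.connection("apiserver"),
    expiration_interval=600,
)

# One-shot access: returns an existing state or creates a new one;
# the caller is responsible for persisting any changes
state = manager.get_or_create_state_core()
state.position += 10
manager.set_state(state)

# Context-manager access: assumed to persist the state on exit
with manager.get_or_create_state(state_id=state.id) as state:
    state.position += 10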

View File

@@ -0,0 +1,4 @@
max_page_size: 500
# Expiration time in seconds for the Redis scroll states used by the get_many family of APIs
scroll_state_expiration_seconds: 600

View File

@@ -1,16 +1,18 @@
import re
from collections import namedtuple
from functools import reduce
from typing import Collection, Sequence, Union, Optional, Type, Tuple, Mapping, Any
from functools import reduce, partial
from typing import Collection, Sequence, Union, Optional, Type, Tuple, Mapping, Any, Callable
from boltons.iterutils import first, bucketize, partition
from dateutil.parser import parse as parse_datetime
from mongoengine import Q, Document, ListField, StringField
from mongoengine import Q, Document, ListField, StringField, IntField
from pymongo.command_cursor import CommandCursor
from apiserver.apierrors import errors
from apiserver.apierrors.base import BaseError
from apiserver.bll.redis_cache_manager import RedisCacheManager
from apiserver.config_repo import config
from apiserver.database import Database
from apiserver.database.errors import MakeGetAllQueryError
from apiserver.database.projection import project_dict, ProjectionHelper
from apiserver.database.props import PropsMixin
@@ -21,6 +23,7 @@ from apiserver.database.utils import (
field_does_not_exist,
field_exists,
)
from apiserver.redis_manager import redman
log = config.logger("dbmodel")
@@ -70,6 +73,9 @@ class GetMixin(PropsMixin):
_ordering_key = "order_by"
_search_text_key = "search_text"
_start_key = "start"
_size_key = "size"
_multi_field_param_sep = "__"
_multi_field_param_prefix = {
("_any_", "_or_"): lambda a, b: a | b,
@@ -143,6 +149,20 @@ class GetMixin(PropsMixin):
get_all_query_options = QueryParameterOptions()
class GetManyScrollState(ProperDictMixin, Document):
meta = {"db_alias": Database.backend, "strict": False}
id = StringField(primary_key=True)
position = IntField(default=0)
cache_manager = RedisCacheManager(
state_class=GetManyScrollState,
redis=redman.connection("apiserver"),
expiration_interval=config.get(
"services._mongo.scroll_state_expiration_seconds", 600
),
)
@classmethod
def get(
cls: Union["GetMixin", Document],
@@ -421,27 +441,41 @@ class GetMixin(PropsMixin):
return order_by
@classmethod
def validate_paging(
cls, parameters=None, default_page=None, default_page_size=None
):
""" Validate and extract paging info from from the provided dictionary. Supports default values. """
if parameters is None:
parameters = {}
default_page = parameters.get("page", default_page)
if default_page is None:
return None, None
default_page_size = parameters.get("page_size", default_page_size)
if not default_page_size:
raise errors.bad_request.MissingRequiredFields(
"page_size is required when page is requested", field="page_size"
)
elif default_page < 0:
def validate_paging(cls, parameters=None, default_page=0, default_page_size=None):
"""
Validate and extract paging info from the provided dictionary. Supports default values.
If page is specified then it should be non-negative; if page_size is specified then it should be positive.
If page_size is specified and page is not then page 0 is assumed.
If page is specified then page_size should be specified too.
"""
parameters = parameters or {}
start = parameters.get(cls._start_key)
if start is not None:
return start, cls.validate_scroll_size(parameters)
max_page_size = config.get("services._mongo.max_page_size", 500)
page = parameters.get("page", default_page)
if page is not None and page < 0:
raise errors.bad_request.ValidationError("page must be >=0", field="page")
elif default_page_size < 1:
page_size = parameters.get("page_size", default_page_size or max_page_size)
if page_size is not None and page_size < 1:
raise errors.bad_request.ValidationError(
"page_size must be >0", field="page_size"
)
return default_page, default_page_size
if page_size is not None:
page = page or 0
page_size = min(page_size, max_page_size)
return page * page_size, page_size
if page is not None:
raise errors.bad_request.MissingRequiredFields(
"page_size is required when page is requested", field="page_size"
)
return None, None
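To make the new semantics concrete, here are a few hand-traced inputs and the (start, size) pairs the logic above yields, assuming max_page_size is left at its default of 500 (illustrative expectations, not tests from this commit):

# Hand-traced expectations for the new validate_paging semantics,
# assuming GetMixin is imported from its defining module:
GetMixin.validate_paging({"page": 2, "page_size": 10})  # -> (20, 10): start = page * page_size
GetMixin.validate_paging({"page_size": 10})             # -> (0, 10): page defaults to 0
GetMixin.validate_paging({"page_size": 1000})           # -> (0, 500): capped at max_page_size
GetMixin.validate_paging({})                            # -> (0, 500): falls back to max_page_size
GetMixin.validate_paging({"start": 15, "size": 5})      # -> (15, 5): scroll-style addressing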
@classmethod
def get_projection(cls, parameters, override_projection=None, **__):
@@ -485,6 +519,57 @@ class GetMixin(PropsMixin):
def set_default_ordering(cls, parameters: dict, value: Sequence[str]) -> None:
cls.set_ordering(parameters, cls.get_ordering(parameters) or value)
@classmethod
def validate_scroll_size(cls, query_dict: dict) -> int:
size = query_dict.get(cls._size_key)
if not size or not isinstance(size, int) or size < 1:
raise errors.bad_request.ValidationError(
"Integer size parameter greater than 1 should be provided when working with scroll"
)
return size
@classmethod
def get_data_with_scroll_and_filter_support(
cls,
query_dict: dict,
data_getter: Callable[[], Sequence[dict]],
ret_params: dict,
) -> Sequence[dict]:
"""
Retrieves the data by calling the provided data_getter callable.
If scroll parameters are specified then set the query_dict 'start' parameter to the last
scroll position and continue retrieval from that position.
If refresh_scroll is requested then fetch the data once more from the beginning
up to the current scroll position.
At the end the scroll position is updated and the accumulated data is returned.
"""
query_dict = query_dict or {}
state: Optional[cls.GetManyScrollState] = None
if "scroll_id" in query_dict:
size = cls.validate_scroll_size(query_dict)
state = cls.cache_manager.get_or_create_state_core(
query_dict.get("scroll_id")
)
if query_dict.get("refresh_scroll"):
query_dict[cls._size_key] = max(state.position, size)
state.position = 0
query_dict[cls._start_key] = state.position
data = data_getter()
if cls._start_key in query_dict:
query_dict[cls._start_key] = query_dict[cls._start_key] + len(data)
def update_state(returned_len: int):
if not state:
return
state.position = query_dict[cls._start_key]
cls.cache_manager.set_state(state)
if ret_params is not None:
ret_params["scroll_id"] = state.id
update_state(len(data))
return data
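Putting the pieces together, a hedged sketch of how a scroll progresses through get_many. Task and company_id are placeholders for any GetMixin document class and a valid company id; note that callers pass the same dict as both query_dict and parameters, which is how the 'start' value injected above reaches the actual query:

# Illustrative scroll round-trip (Task and company_id are placeholders)
query = {"scroll_id": None, "size": 100}
ret_params = {}

first = Task.get_many(
    company=company_id, query_dict=query, parameters=query, ret_params=ret_params
)
scroll_id = ret_params["scroll_id"]  # server-side position is now len(first)

query["scroll_id"] = scroll_id
second = Task.get_many(
    company=company_id, query_dict=query, parameters=query, ret_params=ret_params
)  # continues from the stored position

query["refresh_scroll"] = True
everything = Task.get_many(
    company=company_id, query_dict=query, parameters=query, ret_params=ret_params
)  # re-queries from position 0 up to the current scroll position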
@classmethod
def get_many_with_join(
cls,
@@ -495,6 +580,7 @@
allow_public=False,
override_projection=None,
expand_reference_ids=True,
ret_params: dict = None,
):
"""
Fetch all documents matching a provided query with support for joining referenced documents according to the
@@ -530,6 +616,7 @@
query=query,
query_options=query_options,
allow_public=allow_public,
ret_params=ret_params,
)
def projection_func(doc_type, projection, ids):
@@ -560,6 +647,7 @@
allow_public=False,
override_projection: Collection[str] = None,
return_dicts=True,
ret_params: dict = None,
):
"""
Fetch all documents matching a provided query. Supported several built-in options
@@ -605,12 +693,18 @@
_query = (q & query) if query else q
if return_dicts:
return cls._get_many_override_none_ordering(
data_getter = partial(
cls._get_many_override_none_ordering,
query=_query,
parameters=parameters,
override_projection=override_projection,
override_collation=override_collation,
)
return cls.get_data_with_scroll_and_filter_support(
query_dict=query_dict,
data_getter=data_getter,
ret_params=ret_params,
)
return cls._get_many_no_company(
query=_query,
@@ -662,7 +756,7 @@
order_by = cls.validate_order_by(parameters=parameters, search_text=search_text)
if order_by and not override_collation:
override_collation = cls._get_collation_override(order_by[0])
page, page_size = cls.validate_paging(parameters=parameters)
start, size = cls.validate_paging(parameters=parameters)
include, exclude = cls.split_projection(
cls.get_projection(parameters, override_projection)
)
@@ -683,9 +777,9 @@
if exclude:
qs = qs.exclude(*exclude)
if page is not None and page_size:
if start is not None and size:
# add paging
qs = qs.skip(page * page_size).limit(page_size)
qs = qs.skip(start).limit(size)
return qs
@@ -746,7 +840,7 @@
parameters = parameters or {}
search_text = parameters.get(cls._search_text_key)
order_by = cls.validate_order_by(parameters=parameters, search_text=search_text)
page, page_size = cls.validate_paging(parameters=parameters)
start, size = cls.validate_paging(parameters=parameters)
include, exclude = cls.split_projection(
cls.get_projection(parameters, override_projection)
)
@@ -778,25 +872,23 @@
if exclude:
query_sets = [qs.exclude(*exclude) for qs in query_sets]
if page is None or not page_size:
if start is None or not size:
return [obj.to_proper_dict(only=include) for qs in query_sets for obj in qs]
# add paging
ret = []
start = page * page_size
for qs in query_sets:
qs_size = qs.count()
if qs_size < start:
start -= qs_size
continue
ret.extend(
obj.to_proper_dict(only=include)
for obj in qs.skip(start).limit(page_size)
obj.to_proper_dict(only=include) for obj in qs.skip(start).limit(size)
)
if len(ret) >= page_size:
if len(ret) >= size:
break
start = 0
page_size -= len(ret)
size -= len(ret)
return ret
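A plain-Python rendition of this merge may make the arithmetic concrete; list slicing stands in for the skip/limit calls (illustrative, not part of the commit):

def paged_merge(query_sets, start, size):
    # Same control flow as above, over plain lists instead of query sets
    ret = []
    for qs in query_sets:
        if len(qs) < start:
            start -= len(qs)  # this whole set falls before the requested window
            continue
        ret.extend(qs[start:start + size])  # skip(start).limit(size)
        if len(ret) >= size:
            break
        start = 0
        size -= len(ret)
    return ret

# start=12 skips the first set (8 items) and takes items 4..13 of the second
assert paged_merge([list(range(8)), list(range(20))], 12, 10) == list(range(4, 14))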

View File

@@ -11,6 +11,7 @@ class Project(AttributedDocument):
get_all_query_options = GetMixin.QueryParameterOptions(
pattern_fields=("name", "description"),
list_fields=("tags", "system_tags", "id", "parent", "path"),
range_fields=("last_update",),
)
meta = {

File diff suppressed because it is too large

View File

@@ -199,6 +199,29 @@ get_all_ex {
}
}
}
"2.15": ${get_all_ex."2.13"} {
request {
properties {
scroll_id {
type: string
description: "Scroll ID returned from the previos calls to get_all_ex"
}
refresh_scroll {
type: boolean
description: "If set then all the data received with this scroll will be requeried"
}
size {
type: integer
minimum: 1
description: "The number of models to retrieve"
}
}
}
response.properties.scroll_id {
type: string
description: "Scroll ID that can be used with the next calls to get_all_ex to retrieve more data"
}
}
}
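Mirroring the test added at the end of this commit, a usage sketch of the new fields, assuming api is an initialized APIClient against a server that supports version 2.15:

res = api.models.get_all_ex(scroll_id=None, size=100)  # opens a new scroll
scroll_id = res.scroll_id

res = api.models.get_all_ex(scroll_id=scroll_id, size=100)  # next batch

# refresh_scroll re-returns everything retrieved with this scroll so far
res = api.models.get_all_ex(scroll_id=scroll_id, size=100, refresh_scroll=True)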
get_all {
"2.1" {
@@ -302,6 +325,29 @@ get_all {
}
}
}
"2.15": ${get_all."2.1"} {
request {
properties {
scroll_id {
type: string
description: "Scroll ID returned from the previos calls to get_all"
}
refresh_scroll {
type: boolean
description: "If set then all the data received with this scroll will be requeried"
}
size {
type: integer
minimum: 1
description: "The number of models to retrieve"
}
}
}
response.properties.scroll_id {
type: string
description: "Scroll ID that can be used with the next calls to get_all to retrieve more data"
}
}
}
get_frameworks {
"2.8" {

View File

@@ -152,6 +152,11 @@ _definitions {
type: string
format: "date-time"
}
last_update {
description: "Last update time"
type: string
format: "date-time"
}
tags {
type: array
description: "User-defined tags"
@@ -430,6 +435,29 @@ get_all {
}
}
}
"2.15": ${get_all."2.13"} {
request {
properties {
scroll_id {
type: string
description: "Scroll ID returned from the previos calls to get_all_ex"
}
refresh_scroll {
type: boolean
description: "If set then all the data received with this scroll will be requeried"
}
size {
type: integer
minimum: 1
description: "The number of projects to retrieve"
}
}
}
response.properties.scroll_id {
type: string
description: "Scroll ID that can be used with the next calls to get_all_ex to retrieve more data"
}
}
}
get_all_ex {
internal: true
@@ -488,6 +516,29 @@ get_all_ex {
}
}
}
"2.15": ${get_all_ex."2.13"} {
request {
properties {
scroll_id {
type: string
description: "Scroll ID returned from the previos calls to get_all"
}
refresh_scroll {
type: boolean
description: "If set then all the data received with this scroll will be requeried"
}
size {
type: integer
minimum: 1
description: "The number of projects to retrieve"
}
}
}
response.properties.scroll_id {
type: string
description: "Scroll ID that can be used with the next calls to get_all to retrieve more data"
}
}
}
update {
"2.1" {

View File

@@ -115,6 +115,29 @@ get_by_id {
get_all_ex {
internal: true
"2.4": ${get_all."2.4"}
"2.15": ${get_all_ex."2.4"} {
request {
properties {
scroll_id {
type: string
description: "Scroll ID returned from the previos calls to get_all_ex"
}
refresh_scroll {
type: boolean
description: "If set then all the data received with this scroll will be requeried"
}
size {
type: integer
minimum: 1
description: "The number of queues to retrieve"
}
}
}
response.properties.scroll_id {
type: string
description: "Scroll ID that can be used with the next calls to get_all_ex to retrieve more data"
}
}
}
get_all {
"2.4" {
@@ -178,6 +201,29 @@ get_all {
}
}
}
"2.15": ${get_all."2.4"} {
request {
properties {
scroll_id {
type: string
description: "Scroll ID returned from the previos calls to get_all"
}
refresh_scroll {
type: boolean
description: "If set then all the data received with this scroll will be requeried"
}
size {
type: integer
minimum: 1
description: "The number of queues to retrieve"
}
}
}
response.properties.scroll_id {
type: string
description: "Scroll ID that can be used with the next calls to get_all to retrieve more data"
}
}
}
get_default {
"2.4" {

View File

@@ -685,6 +685,29 @@ get_all_ex {
}
}
}
"2.15": ${get_all_ex."2.13"} {
request {
properties {
scroll_id {
type: string
description: "Scroll ID returned from the previos calls to get_all_ex"
}
refresh_scroll {
type: boolean
description: "If set then all the data received with this scroll will be requeried"
}
size {
type: integer
minimum: 1
description: "The number of tasks to retrieve"
}
}
}
response.properties.scroll_id {
type: string
description: "Scroll ID that can be used with the next calls to get_all_ex to retrieve more data"
}
}
}
get_all {
"2.1" {
@@ -799,6 +822,29 @@ get_all {
}
}
}
"2.15": ${get_all."2.1"} {
request {
properties {
scroll_id {
type: string
description: "Scroll ID returned from the previos calls to get_all"
}
refresh_scroll {
type: boolean
description: "If set then all the data received with this scroll will be requeried"
}
size {
type: integer
minimum: 1
description: "The number of tasks to retrieve"
}
}
}
response.properties.scroll_id {
type: string
description: "Scroll ID that can be used with the next calls to get_all to retrieve more data"
}
}
}
get_types {
"2.8" {

View File

@@ -37,7 +37,7 @@ class ServiceRepo(object):
"""If the check is set, parsing will fail for endpoint request with the version that is grater than the current
maximum """
_max_version = PartialVersion("2.14")
_max_version = PartialVersion("2.15")
""" Maximum version number (the highest min_version value across all endpoints) """
_endpoint_exp = (

View File

@@ -124,11 +124,15 @@ def get_all_ex(call: APICall, company_id, _):
with translate_errors_context():
_process_include_subprojects(call.data)
with TimingContext("mongo", "models_get_all_ex"):
ret_params = {}
models = Model.get_many_with_join(
company=company_id, query_dict=call.data, allow_public=True
company=company_id,
query_dict=call.data,
allow_public=True,
ret_params=ret_params,
)
conform_output_tags(call, models)
call.result.data = {"models": models}
call.result.data = {"models": models, **ret_params}
@endpoint("models.get_by_id_ex", required_fields=["id"])
@@ -148,14 +152,16 @@ def get_all(call: APICall, company_id, _):
conform_tag_fields(call, call.data)
with translate_errors_context():
with TimingContext("mongo", "models_get_all"):
ret_params = {}
models = Model.get_many(
company=company_id,
parameters=call.data,
query_dict=call.data,
allow_public=True,
ret_params=ret_params,
)
conform_output_tags(call, models)
call.result.data = {"models": models}
call.result.data = {"models": models, **ret_params}
@endpoint("models.get_frameworks", request_data_model=GetFrameworksRequest)

View File

@@ -111,8 +111,12 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
_adjust_search_parameters(data, shallow_search=request.shallow_search)
ret_params = {}
projects = Project.get_many_with_join(
company=company_id, query_dict=data, allow_public=allow_public,
company=company_id,
query_dict=data,
allow_public=allow_public,
ret_params=ret_params,
)
if request.check_own_contents and requested_ids:
@@ -128,7 +132,7 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
conform_output_tags(call, projects)
if not request.include_stats:
call.result.data = {"projects": projects}
call.result.data = {"projects": projects, **ret_params}
return
project_ids = {project["id"] for project in projects}
@@ -142,7 +146,7 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
project["stats"] = stats[project["id"]]
project["sub_projects"] = children[project["id"]]
call.result.data = {"projects": projects}
call.result.data = {"projects": projects, **ret_params}
@endpoint("projects.get_all")
@@ -151,15 +155,17 @@ def get_all(call: APICall):
data = call.data
_adjust_search_parameters(data, shallow_search=data.get("shallow_search", False))
with translate_errors_context(), TimingContext("mongo", "projects_get_all"):
ret_params = {}
projects = Project.get_many(
company=call.identity.company,
query_dict=data,
parameters=data,
allow_public=True,
ret_params=ret_params,
)
conform_output_tags(call, projects)
call.result.data = {"projects": projects}
call.result.data = {"projects": projects, **ret_params}
@endpoint(

View File

@@ -48,21 +48,29 @@ def get_by_id(call: APICall):
@endpoint("queues.get_all_ex", min_version="2.4")
def get_all_ex(call: APICall):
conform_tag_fields(call, call.data)
ret_params = {}
queues = queue_bll.get_queue_infos(
company_id=call.identity.company, query_dict=call.data
company_id=call.identity.company,
query_dict=call.data,
ret_params=ret_params,
)
conform_output_tags(call, queues)
call.result.data = {"queues": queues}
call.result.data = {"queues": queues, **ret_params}
@endpoint("queues.get_all", min_version="2.4")
def get_all(call: APICall):
conform_tag_fields(call, call.data)
queues = queue_bll.get_all(company_id=call.identity.company, query_dict=call.data)
ret_params = {}
queues = queue_bll.get_all(
company_id=call.identity.company,
query_dict=call.data,
ret_params=ret_params,
)
conform_output_tags(call, queues)
call.result.data = {"queues": queues}
call.result.data = {"queues": queues, **ret_params}
@endpoint("queues.create", min_version="2.4", request_data_model=CreateRequest)

View File

@@ -221,11 +221,15 @@ def get_all_ex(call: APICall, company_id, _):
with TimingContext("mongo", "task_get_all_ex"):
_process_include_subprojects(call_data)
ret_params = {}
tasks = Task.get_many_with_join(
company=company_id, query_dict=call_data, allow_public=True,
company=company_id,
query_dict=call_data,
allow_public=True,
ret_params=ret_params,
)
unprepare_from_saved(call, tasks)
call.result.data = {"tasks": tasks}
call.result.data = {"tasks": tasks, **ret_params}
@endpoint("tasks.get_by_id_ex", required_fields=["id"])
@@ -250,14 +254,16 @@ def get_all(call: APICall, company_id, _):
call_data = escape_execution_parameters(call)
with TimingContext("mongo", "task_get_all"):
ret_params = {}
tasks = Task.get_many(
company=company_id,
parameters=call_data,
query_dict=call_data,
allow_public=True,
ret_params=ret_params,
)
unprepare_from_saved(call, tasks)
call.result.data = {"tasks": tasks}
call.result.data = {"tasks": tasks, **ret_params}
@endpoint("tasks.get_types", request_data_model=GetTypesRequest)

View File

@@ -71,7 +71,7 @@ class TestService(TestCase, TestServiceInterface):
delete_params=delete_params,
)
def setUp(self, version="1.7"):
def setUp(self, version="999.0"):
self._api = APIClient(base_url=f"http://localhost:8008/v{version}")
self._deferred = []
self._version = parse(version)

View File

@@ -0,0 +1,82 @@
import math
from apiserver.tests.automated import TestService
class TestEntityOrdering(TestService):
name_prefix = "Test paging "
def setUp(self, **kwargs):
super().setUp(**kwargs)
self.task_ids = self._create_tasks()
def _create_tasks(self):
tasks = [
self._temp_task(
name=f"{self.name_prefix}{i}",
hyperparams={"test": {"param": {"section": "test", "name": "param", "type": "str", "value": str(i)}}},
)
for i in range(18)
]
return tasks
def test_paging(self):
page_size = 10
for page in range(0, math.ceil(len(self.task_ids) / page_size)):
start = page * page_size
expected_size = min(page_size, len(self.task_ids) - start)
tasks = self._get_tasks(
page=page,
page_size=page_size,
).tasks
self.assertEqual(len(tasks), expected_size)
for i, t in enumerate(tasks):
self.assertEqual(t.name, f"{self.name_prefix}{start + i}")
def test_scrolling(self):
page_size = 10
scroll_id = None
for page in range(0, math.ceil(len(self.task_ids) / page_size)):
start = page * page_size
expected_size = min(page_size, len(self.task_ids) - start)
res = self._get_tasks(
size=page_size,
scroll_id=scroll_id,
)
self.assertTrue(res.scroll_id)
scroll_id = res.scroll_id
tasks = res.tasks
self.assertEqual(len(tasks), expected_size)
for i, t in enumerate(tasks):
self.assertEqual(t.name, f"{self.name_prefix}{start + i}")
# no more data in this scroll
tasks = self._get_tasks(
size=page_size,
scroll_id=scroll_id,
).tasks
self.assertFalse(tasks)
# refresh brings all
tasks = self._get_tasks(
size=page_size,
scroll_id=scroll_id,
refresh_scroll=True,
).tasks
self.assertEqual([t.id for t in tasks], self.task_ids)
def _get_tasks(self, **page_params):
return self.api.tasks.get_all_ex(
name="^Test paging ",
order_by=["hyperparams.param"],
**page_params,
)
def _temp_task(self, name, **kwargs):
return self.create_temp(
"tasks",
name=name,
comment="Test task",
type="testing",
input=dict(view=dict()),
**kwargs,
)