mirror of
https://github.com/clearml/clearml-server
synced 2025-06-26 23:15:47 +00:00
Add scroll support to *.get_*
This commit is contained in:
@@ -1,16 +1,18 @@
|
||||
import re
|
||||
from collections import namedtuple
|
||||
from functools import reduce
|
||||
from typing import Collection, Sequence, Union, Optional, Type, Tuple, Mapping, Any
|
||||
from functools import reduce, partial
|
||||
from typing import Collection, Sequence, Union, Optional, Type, Tuple, Mapping, Any, Callable
|
||||
|
||||
from boltons.iterutils import first, bucketize, partition
|
||||
from dateutil.parser import parse as parse_datetime
|
||||
from mongoengine import Q, Document, ListField, StringField
|
||||
from mongoengine import Q, Document, ListField, StringField, IntField
|
||||
from pymongo.command_cursor import CommandCursor
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.apierrors.base import BaseError
|
||||
from apiserver.bll.redis_cache_manager import RedisCacheManager
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database import Database
|
||||
from apiserver.database.errors import MakeGetAllQueryError
|
||||
from apiserver.database.projection import project_dict, ProjectionHelper
|
||||
from apiserver.database.props import PropsMixin
|
||||
@@ -21,6 +23,7 @@ from apiserver.database.utils import (
|
||||
field_does_not_exist,
|
||||
field_exists,
|
||||
)
|
||||
from apiserver.redis_manager import redman
|
||||
|
||||
log = config.logger("dbmodel")
|
||||
|
||||
@@ -70,6 +73,9 @@ class GetMixin(PropsMixin):
|
||||
_ordering_key = "order_by"
|
||||
_search_text_key = "search_text"
|
||||
|
||||
_start_key = "start"
|
||||
_size_key = "size"
|
||||
|
||||
_multi_field_param_sep = "__"
|
||||
_multi_field_param_prefix = {
|
||||
("_any_", "_or_"): lambda a, b: a | b,
|
||||
@@ -143,6 +149,20 @@ class GetMixin(PropsMixin):
|
||||
|
||||
get_all_query_options = QueryParameterOptions()
|
||||
|
||||
class GetManyScrollState(ProperDictMixin, Document):
    """Scroll state kept between consecutive scrolled get_* calls: the scroll id and the reached position"""

    meta = dict(db_alias=Database.backend, strict=False)

    # the scroll id handed back to the caller
    id = StringField(primary_key=True)
    # offset of the next document to return for this scroll
    position = IntField(default=0)

# Redis backed store of the scroll states. The expiration is taken from the
# 'services._mongo.scroll_state_expiration_seconds' setting (600 if not configured)
cache_manager = RedisCacheManager(
    redis=redman.connection("apiserver"),
    state_class=GetManyScrollState,
    expiration_interval=config.get(
        "services._mongo.scroll_state_expiration_seconds", 600
    ),
)
|
||||
|
||||
@classmethod
|
||||
def get(
|
||||
cls: Union["GetMixin", Document],
|
||||
@@ -421,27 +441,41 @@ class GetMixin(PropsMixin):
|
||||
return order_by
|
||||
|
||||
@classmethod
def validate_paging(cls, parameters=None, default_page=0, default_page_size=None):
    """
    Validate and extract paging info from the provided dictionary. Supports default values.

    Two modes are supported:
    - Scroll: if the start key (cls._start_key) is present then its value is used as the offset
      and the size is validated by validate_scroll_size
    - Paging: 'page' and 'page_size' are translated into an offset and a limit

    If page is specified then it should be non-negative, if page size is specified then it should be positive
    If page size is specified and page is not then 0 page is assumed
    If page is specified then page size should be specified too
    The page size is capped by the 'services._mongo.max_page_size' setting. The same setting is used
    as the page size when neither 'page_size' nor default_page_size are provided

    :param parameters: query parameters dictionary, may be None
    :param default_page: the page to use when parameters do not contain 'page'
    :param default_page_size: the page size to use when parameters do not contain 'page_size'
    :return: (start, size) tuple: the offset of the first document and the maximal amount of documents
        to return. (None, None) if no paging should be applied
    :raise errors.bad_request.ValidationError: invalid start, page or page_size
    :raise errors.bad_request.MissingRequiredFields: page is passed but page_size is explicitly None
    """
    parameters = parameters or {}

    start = parameters.get(cls._start_key)
    if start is not None:
        # the offset goes as is to the queryset skip() so reject the values that it cannot handle
        if not isinstance(start, int) or start < 0:
            raise errors.bad_request.ValidationError(
                "start must be a non-negative integer", field=cls._start_key
            )
        return start, cls.validate_scroll_size(parameters)

    max_page_size = config.get("services._mongo.max_page_size", 500)
    page = parameters.get("page", default_page)
    if page is not None and page < 0:
        raise errors.bad_request.ValidationError("page must be >=0", field="page")

    page_size = parameters.get("page_size", default_page_size or max_page_size)
    if page_size is not None and page_size < 1:
        raise errors.bad_request.ValidationError(
            "page_size must be >0", field="page_size"
        )

    if page_size is not None:
        page = page or 0
        page_size = min(page_size, max_page_size)
        return page * page_size, page_size

    # page_size can be None here only if it was explicitly passed as None in the parameters
    if page is not None:
        raise errors.bad_request.MissingRequiredFields(
            "page_size is required when page is requested", field="page_size"
        )

    return None, None
|
||||
|
||||
@classmethod
|
||||
def get_projection(cls, parameters, override_projection=None, **__):
|
||||
@@ -485,6 +519,57 @@ class GetMixin(PropsMixin):
|
||||
def set_default_ordering(cls, parameters: dict, value: Sequence[str]) -> None:
|
||||
cls.set_ordering(parameters, cls.get_ordering(parameters) or value)
|
||||
|
||||
@classmethod
def validate_scroll_size(cls, query_dict: dict) -> int:
    """
    Extract the scroll size (the cls._size_key entry) from the query and make sure that it is usable

    :param query_dict: the query parameters
    :return: the requested size, an integer that is at least 1
    :raise errors.bad_request.ValidationError: the size is missing, not an integer or less than 1
    """
    size = query_dict.get(cls._size_key)
    # the isinstance check also covers a missing size (None)
    if not isinstance(size, int) or size < 1:
        raise errors.bad_request.ValidationError(
            "Integer size parameter greater than 0 should be provided when working with scroll"
        )
    return size
|
||||
|
||||
@classmethod
def get_data_with_scroll_and_filter_support(
    cls,
    query_dict: dict,
    data_getter: Callable[[], Sequence[dict]],
    ret_params: dict,
) -> Sequence[dict]:
    """
    Retrieves the data by calling the provided data_getter api
    If 'scroll_id' is present in the query_dict then the query_dict start parameter is set to the last
    scroll position so that the retrieval continues from that position
    If refresh_scroll is requested then the data is brought once more from the beginning
    till the current scroll position (or the requested size if it is bigger)
    In the end the scroll position is updated and the retrieved data is returned

    Note that query_dict is updated in place: data_getter is called without arguments
    and is expected to see the start and size values that are set here

    :param query_dict: the query parameters. The start and size entries may be modified
    :param data_getter: callable without parameters that performs the actual retrieval
    :param ret_params: if not None and scroll is used then receives the 'scroll_id' of the scroll state
    :return: whatever data_getter returned
    """
    query_dict = query_dict or {}
    state = None  # stays None unless the caller works with scroll
    if "scroll_id" in query_dict:
        size = cls.validate_scroll_size(query_dict)
        state = cls.cache_manager.get_or_create_state_core(
            query_dict.get("scroll_id")
        )
        if query_dict.get("refresh_scroll"):
            # re-read everything that was already returned for this scroll in one call
            query_dict[cls._size_key] = max(state.position, size)
            state.position = 0
        query_dict[cls._start_key] = state.position

    data = data_getter()
    if cls._start_key in query_dict:
        # advance the offset past the returned documents
        query_dict[cls._start_key] += len(data)

    if state is not None:
        state.position = query_dict[cls._start_key]
        cls.cache_manager.set_state(state)
        if ret_params is not None:
            ret_params["scroll_id"] = state.id

    return data
|
||||
|
||||
@classmethod
|
||||
def get_many_with_join(
|
||||
cls,
|
||||
@@ -495,6 +580,7 @@ class GetMixin(PropsMixin):
|
||||
allow_public=False,
|
||||
override_projection=None,
|
||||
expand_reference_ids=True,
|
||||
ret_params: dict = None,
|
||||
):
|
||||
"""
|
||||
Fetch all documents matching a provided query with support for joining referenced documents according to the
|
||||
@@ -530,6 +616,7 @@ class GetMixin(PropsMixin):
|
||||
query=query,
|
||||
query_options=query_options,
|
||||
allow_public=allow_public,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
|
||||
def projection_func(doc_type, projection, ids):
|
||||
@@ -560,6 +647,7 @@ class GetMixin(PropsMixin):
|
||||
allow_public=False,
|
||||
override_projection: Collection[str] = None,
|
||||
return_dicts=True,
|
||||
ret_params: dict = None,
|
||||
):
|
||||
"""
|
||||
Fetch all documents matching a provided query. Supported several built-in options
|
||||
@@ -605,12 +693,18 @@ class GetMixin(PropsMixin):
|
||||
_query = (q & query) if query else q
|
||||
|
||||
if return_dicts:
|
||||
return cls._get_many_override_none_ordering(
|
||||
data_getter = partial(
|
||||
cls._get_many_override_none_ordering,
|
||||
query=_query,
|
||||
parameters=parameters,
|
||||
override_projection=override_projection,
|
||||
override_collation=override_collation,
|
||||
)
|
||||
return cls.get_data_with_scroll_and_filter_support(
|
||||
query_dict=query_dict,
|
||||
data_getter=data_getter,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
|
||||
return cls._get_many_no_company(
|
||||
query=_query,
|
||||
@@ -662,7 +756,7 @@ class GetMixin(PropsMixin):
|
||||
order_by = cls.validate_order_by(parameters=parameters, search_text=search_text)
|
||||
if order_by and not override_collation:
|
||||
override_collation = cls._get_collation_override(order_by[0])
|
||||
page, page_size = cls.validate_paging(parameters=parameters)
|
||||
start, size = cls.validate_paging(parameters=parameters)
|
||||
include, exclude = cls.split_projection(
|
||||
cls.get_projection(parameters, override_projection)
|
||||
)
|
||||
@@ -683,9 +777,9 @@ class GetMixin(PropsMixin):
|
||||
if exclude:
|
||||
qs = qs.exclude(*exclude)
|
||||
|
||||
if page is not None and page_size:
|
||||
if start is not None and size:
|
||||
# add paging
|
||||
qs = qs.skip(page * page_size).limit(page_size)
|
||||
qs = qs.skip(start).limit(size)
|
||||
|
||||
return qs
|
||||
|
||||
@@ -746,7 +840,7 @@ class GetMixin(PropsMixin):
|
||||
parameters = parameters or {}
|
||||
search_text = parameters.get(cls._search_text_key)
|
||||
order_by = cls.validate_order_by(parameters=parameters, search_text=search_text)
|
||||
page, page_size = cls.validate_paging(parameters=parameters)
|
||||
start, size = cls.validate_paging(parameters=parameters)
|
||||
include, exclude = cls.split_projection(
|
||||
cls.get_projection(parameters, override_projection)
|
||||
)
|
||||
@@ -778,25 +872,23 @@ class GetMixin(PropsMixin):
|
||||
if exclude:
|
||||
query_sets = [qs.exclude(*exclude) for qs in query_sets]
|
||||
|
||||
if page is None or not page_size:
|
||||
if start is None or not size:
|
||||
return [obj.to_proper_dict(only=include) for qs in query_sets for obj in qs]
|
||||
|
||||
# add paging
|
||||
ret = []
|
||||
start = page * page_size
|
||||
for qs in query_sets:
|
||||
qs_size = qs.count()
|
||||
if qs_size < start:
|
||||
start -= qs_size
|
||||
continue
|
||||
ret.extend(
|
||||
obj.to_proper_dict(only=include)
|
||||
for obj in qs.skip(start).limit(page_size)
|
||||
obj.to_proper_dict(only=include) for obj in qs.skip(start).limit(size)
|
||||
)
|
||||
if len(ret) >= page_size:
|
||||
if len(ret) >= size:
|
||||
break
|
||||
start = 0
|
||||
page_size -= len(ret)
|
||||
size -= len(ret)
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ class Project(AttributedDocument):
|
||||
get_all_query_options = GetMixin.QueryParameterOptions(
|
||||
pattern_fields=("name", "description"),
|
||||
list_fields=("tags", "system_tags", "id", "parent", "path"),
|
||||
range_fields=("last_update",),
|
||||
)
|
||||
|
||||
meta = {
|
||||
|
||||
Reference in New Issue
Block a user