mirror of
https://github.com/clearml/clearml-server
synced 2025-06-26 23:15:47 +00:00
Add scroll support to *.get_*
This commit is contained in:
parent
446bd35006
commit
f20cd6536e
@ -126,14 +126,27 @@ class QueueBLL(object):
|
||||
)
|
||||
queue.delete()
|
||||
|
||||
def get_all(self, company_id: str, query_dict: dict) -> Sequence[dict]:
|
||||
def get_all(
|
||||
self,
|
||||
company_id: str,
|
||||
query_dict: dict,
|
||||
ret_params: dict = None,
|
||||
) -> Sequence[dict]:
|
||||
"""Get all the queues according to the query"""
|
||||
with translate_errors_context():
|
||||
return Queue.get_many(
|
||||
company=company_id, parameters=query_dict, query_dict=query_dict
|
||||
company=company_id,
|
||||
parameters=query_dict,
|
||||
query_dict=query_dict,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
|
||||
def get_queue_infos(self, company_id: str, query_dict: dict) -> Sequence[dict]:
|
||||
def get_queue_infos(
|
||||
self,
|
||||
company_id: str,
|
||||
query_dict: dict,
|
||||
ret_params: dict = None,
|
||||
) -> Sequence[dict]:
|
||||
"""
|
||||
Get infos on all the company queues, including queue tasks and workers
|
||||
"""
|
||||
@ -143,6 +156,7 @@ class QueueBLL(object):
|
||||
company=company_id,
|
||||
query_dict=query_dict,
|
||||
override_projection=projection,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
|
||||
queue_workers = defaultdict(list)
|
||||
|
@ -49,6 +49,21 @@ class RedisCacheManager(Generic[T]):
|
||||
def _get_redis_key(self, state_id):
|
||||
return f"{self.state_class}/{state_id}"
|
||||
|
||||
def get_or_create_state_core(
|
||||
self,
|
||||
state_id=None,
|
||||
init_state: Callable[[T], None] = _do_nothing,
|
||||
validate_state: Callable[[T], None] = _do_nothing,
|
||||
) -> T:
|
||||
state = self.get_state(state_id) if state_id else None
|
||||
if state:
|
||||
validate_state(state)
|
||||
else:
|
||||
state = self.state_class(id=database.utils.id())
|
||||
init_state(state)
|
||||
|
||||
return state
|
||||
|
||||
@contextmanager
|
||||
def get_or_create_state(
|
||||
self,
|
||||
@ -66,12 +81,9 @@ class RedisCacheManager(Generic[T]):
|
||||
:param validate_state: user callback to validate the state if retrieved from cache
|
||||
Should throw an exception if the state is not valid. If not passed then no validation is done
|
||||
"""
|
||||
state = self.get_state(state_id) if state_id else None
|
||||
if state:
|
||||
validate_state(state)
|
||||
else:
|
||||
state = self.state_class(id=database.utils.id())
|
||||
init_state(state)
|
||||
state = self.get_or_create_state_core(
|
||||
state_id=state_id, init_state=init_state, validate_state=validate_state
|
||||
)
|
||||
|
||||
try:
|
||||
yield state
|
||||
|
4
apiserver/config/default/services/_mongo.conf
Normal file
4
apiserver/config/default/services/_mongo.conf
Normal file
@ -0,0 +1,4 @@
|
||||
max_page_size: 500
|
||||
|
||||
# expiration time in seconds for the redis scroll states in get_many family of apis
|
||||
scroll_state_expiration_seconds: 600
|
@ -1,16 +1,18 @@
|
||||
import re
|
||||
from collections import namedtuple
|
||||
from functools import reduce
|
||||
from typing import Collection, Sequence, Union, Optional, Type, Tuple, Mapping, Any
|
||||
from functools import reduce, partial
|
||||
from typing import Collection, Sequence, Union, Optional, Type, Tuple, Mapping, Any, Callable
|
||||
|
||||
from boltons.iterutils import first, bucketize, partition
|
||||
from dateutil.parser import parse as parse_datetime
|
||||
from mongoengine import Q, Document, ListField, StringField
|
||||
from mongoengine import Q, Document, ListField, StringField, IntField
|
||||
from pymongo.command_cursor import CommandCursor
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.apierrors.base import BaseError
|
||||
from apiserver.bll.redis_cache_manager import RedisCacheManager
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database import Database
|
||||
from apiserver.database.errors import MakeGetAllQueryError
|
||||
from apiserver.database.projection import project_dict, ProjectionHelper
|
||||
from apiserver.database.props import PropsMixin
|
||||
@ -21,6 +23,7 @@ from apiserver.database.utils import (
|
||||
field_does_not_exist,
|
||||
field_exists,
|
||||
)
|
||||
from apiserver.redis_manager import redman
|
||||
|
||||
log = config.logger("dbmodel")
|
||||
|
||||
@ -70,6 +73,9 @@ class GetMixin(PropsMixin):
|
||||
_ordering_key = "order_by"
|
||||
_search_text_key = "search_text"
|
||||
|
||||
_start_key = "start"
|
||||
_size_key = "size"
|
||||
|
||||
_multi_field_param_sep = "__"
|
||||
_multi_field_param_prefix = {
|
||||
("_any_", "_or_"): lambda a, b: a | b,
|
||||
@ -143,6 +149,20 @@ class GetMixin(PropsMixin):
|
||||
|
||||
get_all_query_options = QueryParameterOptions()
|
||||
|
||||
class GetManyScrollState(ProperDictMixin, Document):
|
||||
meta = {"db_alias": Database.backend, "strict": False}
|
||||
|
||||
id = StringField(primary_key=True)
|
||||
position = IntField(default=0)
|
||||
|
||||
cache_manager = RedisCacheManager(
|
||||
state_class=GetManyScrollState,
|
||||
redis=redman.connection("apiserver"),
|
||||
expiration_interval=config.get(
|
||||
"services._mongo.scroll_state_expiration_seconds", 600
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get(
|
||||
cls: Union["GetMixin", Document],
|
||||
@ -421,27 +441,41 @@ class GetMixin(PropsMixin):
|
||||
return order_by
|
||||
|
||||
@classmethod
|
||||
def validate_paging(
|
||||
cls, parameters=None, default_page=None, default_page_size=None
|
||||
):
|
||||
""" Validate and extract paging info from from the provided dictionary. Supports default values. """
|
||||
if parameters is None:
|
||||
parameters = {}
|
||||
default_page = parameters.get("page", default_page)
|
||||
if default_page is None:
|
||||
return None, None
|
||||
default_page_size = parameters.get("page_size", default_page_size)
|
||||
if not default_page_size:
|
||||
raise errors.bad_request.MissingRequiredFields(
|
||||
"page_size is required when page is requested", field="page_size"
|
||||
)
|
||||
elif default_page < 0:
|
||||
def validate_paging(cls, parameters=None, default_page=0, default_page_size=None):
|
||||
"""
|
||||
Validate and extract paging info from from the provided dictionary. Supports default values.
|
||||
If page is specified then it should be non-negative, if page size is specified then it should be positive
|
||||
If page size is specified and page is not then 0 page is assumed
|
||||
If page is specified then page size should be specified too
|
||||
"""
|
||||
parameters = parameters or {}
|
||||
|
||||
start = parameters.get(cls._start_key)
|
||||
if start is not None:
|
||||
return start, cls.validate_scroll_size(parameters)
|
||||
|
||||
max_page_size = config.get("services._mongo.max_page_size", 500)
|
||||
page = parameters.get("page", default_page)
|
||||
if page is not None and page < 0:
|
||||
raise errors.bad_request.ValidationError("page must be >=0", field="page")
|
||||
elif default_page_size < 1:
|
||||
|
||||
page_size = parameters.get("page_size", default_page_size or max_page_size)
|
||||
if page_size is not None and page_size < 1:
|
||||
raise errors.bad_request.ValidationError(
|
||||
"page_size must be >0", field="page_size"
|
||||
)
|
||||
return default_page, default_page_size
|
||||
|
||||
if page_size is not None:
|
||||
page = page or 0
|
||||
page_size = min(page_size, max_page_size)
|
||||
return page * page_size, page_size
|
||||
|
||||
if page is not None:
|
||||
raise errors.bad_request.MissingRequiredFields(
|
||||
"page_size is required when page is requested", field="page_size"
|
||||
)
|
||||
|
||||
return None, None
|
||||
|
||||
@classmethod
|
||||
def get_projection(cls, parameters, override_projection=None, **__):
|
||||
@ -485,6 +519,57 @@ class GetMixin(PropsMixin):
|
||||
def set_default_ordering(cls, parameters: dict, value: Sequence[str]) -> None:
|
||||
cls.set_ordering(parameters, cls.get_ordering(parameters) or value)
|
||||
|
||||
@classmethod
|
||||
def validate_scroll_size(cls, query_dict: dict) -> int:
|
||||
size = query_dict.get(cls._size_key)
|
||||
if not size or not isinstance(size, int) or size < 1:
|
||||
raise errors.bad_request.ValidationError(
|
||||
"Integer size parameter greater than 1 should be provided when working with scroll"
|
||||
)
|
||||
return size
|
||||
|
||||
@classmethod
|
||||
def get_data_with_scroll_and_filter_support(
|
||||
cls,
|
||||
query_dict: dict,
|
||||
data_getter: Callable[[], Sequence[dict]],
|
||||
ret_params: dict,
|
||||
) -> Sequence[dict]:
|
||||
"""
|
||||
Retrieves the data by calling the provided data_getter api
|
||||
If scroll parameters are specified then put the query_dict 'start' parameter to the last
|
||||
scroll position and continue retrievals from that position
|
||||
If refresh_scroll is requested then bring once more the data from the beginning
|
||||
till the current scroll position
|
||||
In the end the scroll position is updated and accumulated frames are returned
|
||||
"""
|
||||
query_dict = query_dict or {}
|
||||
state: Optional[cls.GetManyScrollState] = None
|
||||
if "scroll_id" in query_dict:
|
||||
size = cls.validate_scroll_size(query_dict)
|
||||
state = cls.cache_manager.get_or_create_state_core(
|
||||
query_dict.get("scroll_id")
|
||||
)
|
||||
if query_dict.get("refresh_scroll"):
|
||||
query_dict[cls._size_key] = max(state.position, size)
|
||||
state.position = 0
|
||||
query_dict[cls._start_key] = state.position
|
||||
|
||||
data = data_getter()
|
||||
if cls._start_key in query_dict:
|
||||
query_dict[cls._start_key] = query_dict[cls._start_key] + len(data)
|
||||
|
||||
def update_state(returned_len: int):
|
||||
if not state:
|
||||
return
|
||||
state.position = query_dict[cls._start_key]
|
||||
cls.cache_manager.set_state(state)
|
||||
if ret_params is not None:
|
||||
ret_params["scroll_id"] = state.id
|
||||
|
||||
update_state(len(data))
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
def get_many_with_join(
|
||||
cls,
|
||||
@ -495,6 +580,7 @@ class GetMixin(PropsMixin):
|
||||
allow_public=False,
|
||||
override_projection=None,
|
||||
expand_reference_ids=True,
|
||||
ret_params: dict = None,
|
||||
):
|
||||
"""
|
||||
Fetch all documents matching a provided query with support for joining referenced documents according to the
|
||||
@ -530,6 +616,7 @@ class GetMixin(PropsMixin):
|
||||
query=query,
|
||||
query_options=query_options,
|
||||
allow_public=allow_public,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
|
||||
def projection_func(doc_type, projection, ids):
|
||||
@ -560,6 +647,7 @@ class GetMixin(PropsMixin):
|
||||
allow_public=False,
|
||||
override_projection: Collection[str] = None,
|
||||
return_dicts=True,
|
||||
ret_params: dict = None,
|
||||
):
|
||||
"""
|
||||
Fetch all documents matching a provided query. Supported several built-in options
|
||||
@ -605,12 +693,18 @@ class GetMixin(PropsMixin):
|
||||
_query = (q & query) if query else q
|
||||
|
||||
if return_dicts:
|
||||
return cls._get_many_override_none_ordering(
|
||||
data_getter = partial(
|
||||
cls._get_many_override_none_ordering,
|
||||
query=_query,
|
||||
parameters=parameters,
|
||||
override_projection=override_projection,
|
||||
override_collation=override_collation,
|
||||
)
|
||||
return cls.get_data_with_scroll_and_filter_support(
|
||||
query_dict=query_dict,
|
||||
data_getter=data_getter,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
|
||||
return cls._get_many_no_company(
|
||||
query=_query,
|
||||
@ -662,7 +756,7 @@ class GetMixin(PropsMixin):
|
||||
order_by = cls.validate_order_by(parameters=parameters, search_text=search_text)
|
||||
if order_by and not override_collation:
|
||||
override_collation = cls._get_collation_override(order_by[0])
|
||||
page, page_size = cls.validate_paging(parameters=parameters)
|
||||
start, size = cls.validate_paging(parameters=parameters)
|
||||
include, exclude = cls.split_projection(
|
||||
cls.get_projection(parameters, override_projection)
|
||||
)
|
||||
@ -683,9 +777,9 @@ class GetMixin(PropsMixin):
|
||||
if exclude:
|
||||
qs = qs.exclude(*exclude)
|
||||
|
||||
if page is not None and page_size:
|
||||
if start is not None and size:
|
||||
# add paging
|
||||
qs = qs.skip(page * page_size).limit(page_size)
|
||||
qs = qs.skip(start).limit(size)
|
||||
|
||||
return qs
|
||||
|
||||
@ -746,7 +840,7 @@ class GetMixin(PropsMixin):
|
||||
parameters = parameters or {}
|
||||
search_text = parameters.get(cls._search_text_key)
|
||||
order_by = cls.validate_order_by(parameters=parameters, search_text=search_text)
|
||||
page, page_size = cls.validate_paging(parameters=parameters)
|
||||
start, size = cls.validate_paging(parameters=parameters)
|
||||
include, exclude = cls.split_projection(
|
||||
cls.get_projection(parameters, override_projection)
|
||||
)
|
||||
@ -778,25 +872,23 @@ class GetMixin(PropsMixin):
|
||||
if exclude:
|
||||
query_sets = [qs.exclude(*exclude) for qs in query_sets]
|
||||
|
||||
if page is None or not page_size:
|
||||
if start is None or not size:
|
||||
return [obj.to_proper_dict(only=include) for qs in query_sets for obj in qs]
|
||||
|
||||
# add paging
|
||||
ret = []
|
||||
start = page * page_size
|
||||
for qs in query_sets:
|
||||
qs_size = qs.count()
|
||||
if qs_size < start:
|
||||
start -= qs_size
|
||||
continue
|
||||
ret.extend(
|
||||
obj.to_proper_dict(only=include)
|
||||
for obj in qs.skip(start).limit(page_size)
|
||||
obj.to_proper_dict(only=include) for obj in qs.skip(start).limit(size)
|
||||
)
|
||||
if len(ret) >= page_size:
|
||||
if len(ret) >= size:
|
||||
break
|
||||
start = 0
|
||||
page_size -= len(ret)
|
||||
size -= len(ret)
|
||||
|
||||
return ret
|
||||
|
||||
|
@ -11,6 +11,7 @@ class Project(AttributedDocument):
|
||||
get_all_query_options = GetMixin.QueryParameterOptions(
|
||||
pattern_fields=("name", "description"),
|
||||
list_fields=("tags", "system_tags", "id", "parent", "path"),
|
||||
range_fields=("last_update",),
|
||||
)
|
||||
|
||||
meta = {
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -199,6 +199,29 @@ get_all_ex {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.15": ${get_all_ex."2.13"} {
|
||||
request {
|
||||
properties {
|
||||
scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID returned from the previos calls to get_all_ex"
|
||||
}
|
||||
refresh_scroll {
|
||||
type: boolean
|
||||
description: "If set then all the data received with this scroll will be requeried"
|
||||
}
|
||||
size {
|
||||
type: integer
|
||||
minimum: 1
|
||||
description: "The number of models to retrieve"
|
||||
}
|
||||
}
|
||||
}
|
||||
response.properties.scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID that can be used with the next calls to get_all_ex to retrieve more data"
|
||||
}
|
||||
}
|
||||
}
|
||||
get_all {
|
||||
"2.1" {
|
||||
@ -302,6 +325,29 @@ get_all {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.15": ${get_all."2.1"} {
|
||||
request {
|
||||
properties {
|
||||
scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID returned from the previos calls to get_all"
|
||||
}
|
||||
refresh_scroll {
|
||||
type: boolean
|
||||
description: "If set then all the data received with this scroll will be requeried"
|
||||
}
|
||||
size {
|
||||
type: integer
|
||||
minimum: 1
|
||||
description: "The number of models to retrieve"
|
||||
}
|
||||
}
|
||||
}
|
||||
response.properties.scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID that can be used with the next calls to get_all to retrieve more data"
|
||||
}
|
||||
}
|
||||
}
|
||||
get_frameworks {
|
||||
"2.8" {
|
||||
|
@ -152,6 +152,11 @@ _definitions {
|
||||
type: string
|
||||
format: "date-time"
|
||||
}
|
||||
last_update {
|
||||
description: "Last update time"
|
||||
type: string
|
||||
format: "date-time"
|
||||
}
|
||||
tags {
|
||||
type: array
|
||||
description: "User-defined tags"
|
||||
@ -430,6 +435,29 @@ get_all {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.15": ${get_all."2.13"} {
|
||||
request {
|
||||
properties {
|
||||
scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID returned from the previos calls to get_all_ex"
|
||||
}
|
||||
refresh_scroll {
|
||||
type: boolean
|
||||
description: "If set then all the data received with this scroll will be requeried"
|
||||
}
|
||||
size {
|
||||
type: integer
|
||||
minimum: 1
|
||||
description: "The number of projects to retrieve"
|
||||
}
|
||||
}
|
||||
}
|
||||
response.properties.scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID that can be used with the next calls to get_all_ex to retrieve more data"
|
||||
}
|
||||
}
|
||||
}
|
||||
get_all_ex {
|
||||
internal: true
|
||||
@ -488,6 +516,29 @@ get_all_ex {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.15": ${get_all_ex."2.13"} {
|
||||
request {
|
||||
properties {
|
||||
scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID returned from the previos calls to get_all"
|
||||
}
|
||||
refresh_scroll {
|
||||
type: boolean
|
||||
description: "If set then all the data received with this scroll will be requeried"
|
||||
}
|
||||
size {
|
||||
type: integer
|
||||
minimum: 1
|
||||
description: "The number of projects to retrieve"
|
||||
}
|
||||
}
|
||||
}
|
||||
response.properties.scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID that can be used with the next calls to get_all to retrieve more data"
|
||||
}
|
||||
}
|
||||
}
|
||||
update {
|
||||
"2.1" {
|
||||
|
@ -115,6 +115,29 @@ get_by_id {
|
||||
get_all_ex {
|
||||
internal: true
|
||||
"2.4": ${get_all."2.4"}
|
||||
"2.15": ${get_all_ex."2.4"} {
|
||||
request {
|
||||
properties {
|
||||
scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID returned from the previos calls to get_all_ex"
|
||||
}
|
||||
refresh_scroll {
|
||||
type: boolean
|
||||
description: "If set then all the data received with this scroll will be requeried"
|
||||
}
|
||||
size {
|
||||
type: integer
|
||||
minimum: 1
|
||||
description: "The number of queues to retrieve"
|
||||
}
|
||||
}
|
||||
}
|
||||
response.properties.scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID that can be used with the next calls to get_all_ex to retrieve more data"
|
||||
}
|
||||
}
|
||||
}
|
||||
get_all {
|
||||
"2.4" {
|
||||
@ -178,6 +201,29 @@ get_all {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.15": ${get_all."2.4"} {
|
||||
request {
|
||||
properties {
|
||||
scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID returned from the previos calls to get_all"
|
||||
}
|
||||
refresh_scroll {
|
||||
type: boolean
|
||||
description: "If set then all the data received with this scroll will be requeried"
|
||||
}
|
||||
size {
|
||||
type: integer
|
||||
minimum: 1
|
||||
description: "The number of queues to retrieve"
|
||||
}
|
||||
}
|
||||
}
|
||||
response.properties.scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID that can be used with the next calls to get_all to retrieve more data"
|
||||
}
|
||||
}
|
||||
}
|
||||
get_default {
|
||||
"2.4" {
|
||||
|
@ -685,6 +685,29 @@ get_all_ex {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.15": ${get_all_ex."2.13"} {
|
||||
request {
|
||||
properties {
|
||||
scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID returned from the previos calls to get_all_ex"
|
||||
}
|
||||
refresh_scroll {
|
||||
type: boolean
|
||||
description: "If set then all the data received with this scroll will be requeried"
|
||||
}
|
||||
size {
|
||||
type: integer
|
||||
minimum: 1
|
||||
description: "The number of tasks to retrieve"
|
||||
}
|
||||
}
|
||||
}
|
||||
response.properties.scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID that can be used with the next calls to get_all_ex to retrieve more data"
|
||||
}
|
||||
}
|
||||
}
|
||||
get_all {
|
||||
"2.1" {
|
||||
@ -799,6 +822,29 @@ get_all {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.15": ${get_all."2.1"} {
|
||||
request {
|
||||
properties {
|
||||
scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID returned from the previos calls to get_all"
|
||||
}
|
||||
refresh_scroll {
|
||||
type: boolean
|
||||
description: "If set then all the data received with this scroll will be requeried"
|
||||
}
|
||||
size {
|
||||
type: integer
|
||||
minimum: 1
|
||||
description: "The number of tasks to retrieve"
|
||||
}
|
||||
}
|
||||
}
|
||||
response.properties.scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID that can be used with the next calls to get_all to retrieve more data"
|
||||
}
|
||||
}
|
||||
}
|
||||
get_types {
|
||||
"2.8" {
|
||||
|
@ -37,7 +37,7 @@ class ServiceRepo(object):
|
||||
"""If the check is set, parsing will fail for endpoint request with the version that is grater than the current
|
||||
maximum """
|
||||
|
||||
_max_version = PartialVersion("2.14")
|
||||
_max_version = PartialVersion("2.15")
|
||||
""" Maximum version number (the highest min_version value across all endpoints) """
|
||||
|
||||
_endpoint_exp = (
|
||||
|
@ -124,11 +124,15 @@ def get_all_ex(call: APICall, company_id, _):
|
||||
with translate_errors_context():
|
||||
_process_include_subprojects(call.data)
|
||||
with TimingContext("mongo", "models_get_all_ex"):
|
||||
ret_params = {}
|
||||
models = Model.get_many_with_join(
|
||||
company=company_id, query_dict=call.data, allow_public=True
|
||||
company=company_id,
|
||||
query_dict=call.data,
|
||||
allow_public=True,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
conform_output_tags(call, models)
|
||||
call.result.data = {"models": models}
|
||||
call.result.data = {"models": models, **ret_params}
|
||||
|
||||
|
||||
@endpoint("models.get_by_id_ex", required_fields=["id"])
|
||||
@ -148,14 +152,16 @@ def get_all(call: APICall, company_id, _):
|
||||
conform_tag_fields(call, call.data)
|
||||
with translate_errors_context():
|
||||
with TimingContext("mongo", "models_get_all"):
|
||||
ret_params = {}
|
||||
models = Model.get_many(
|
||||
company=company_id,
|
||||
parameters=call.data,
|
||||
query_dict=call.data,
|
||||
allow_public=True,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
conform_output_tags(call, models)
|
||||
call.result.data = {"models": models}
|
||||
call.result.data = {"models": models, **ret_params}
|
||||
|
||||
|
||||
@endpoint("models.get_frameworks", request_data_model=GetFrameworksRequest)
|
||||
|
@ -111,8 +111,12 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
|
||||
|
||||
_adjust_search_parameters(data, shallow_search=request.shallow_search)
|
||||
|
||||
ret_params = {}
|
||||
projects = Project.get_many_with_join(
|
||||
company=company_id, query_dict=data, allow_public=allow_public,
|
||||
company=company_id,
|
||||
query_dict=data,
|
||||
allow_public=allow_public,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
|
||||
if request.check_own_contents and requested_ids:
|
||||
@ -128,7 +132,7 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
|
||||
|
||||
conform_output_tags(call, projects)
|
||||
if not request.include_stats:
|
||||
call.result.data = {"projects": projects}
|
||||
call.result.data = {"projects": projects, **ret_params}
|
||||
return
|
||||
|
||||
project_ids = {project["id"] for project in projects}
|
||||
@ -142,7 +146,7 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
|
||||
project["stats"] = stats[project["id"]]
|
||||
project["sub_projects"] = children[project["id"]]
|
||||
|
||||
call.result.data = {"projects": projects}
|
||||
call.result.data = {"projects": projects, **ret_params}
|
||||
|
||||
|
||||
@endpoint("projects.get_all")
|
||||
@ -151,15 +155,17 @@ def get_all(call: APICall):
|
||||
data = call.data
|
||||
_adjust_search_parameters(data, shallow_search=data.get("shallow_search", False))
|
||||
with translate_errors_context(), TimingContext("mongo", "projects_get_all"):
|
||||
ret_params = {}
|
||||
projects = Project.get_many(
|
||||
company=call.identity.company,
|
||||
query_dict=data,
|
||||
parameters=data,
|
||||
allow_public=True,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
conform_output_tags(call, projects)
|
||||
|
||||
call.result.data = {"projects": projects}
|
||||
call.result.data = {"projects": projects, **ret_params}
|
||||
|
||||
|
||||
@endpoint(
|
||||
|
@ -48,21 +48,29 @@ def get_by_id(call: APICall):
|
||||
@endpoint("queues.get_all_ex", min_version="2.4")
|
||||
def get_all_ex(call: APICall):
|
||||
conform_tag_fields(call, call.data)
|
||||
ret_params = {}
|
||||
queues = queue_bll.get_queue_infos(
|
||||
company_id=call.identity.company, query_dict=call.data
|
||||
company_id=call.identity.company,
|
||||
query_dict=call.data,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
conform_output_tags(call, queues)
|
||||
|
||||
call.result.data = {"queues": queues}
|
||||
call.result.data = {"queues": queues, **ret_params}
|
||||
|
||||
|
||||
@endpoint("queues.get_all", min_version="2.4")
|
||||
def get_all(call: APICall):
|
||||
conform_tag_fields(call, call.data)
|
||||
queues = queue_bll.get_all(company_id=call.identity.company, query_dict=call.data)
|
||||
ret_params = {}
|
||||
queues = queue_bll.get_all(
|
||||
company_id=call.identity.company,
|
||||
query_dict=call.data,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
conform_output_tags(call, queues)
|
||||
|
||||
call.result.data = {"queues": queues}
|
||||
call.result.data = {"queues": queues, **ret_params}
|
||||
|
||||
|
||||
@endpoint("queues.create", min_version="2.4", request_data_model=CreateRequest)
|
||||
|
@ -221,11 +221,15 @@ def get_all_ex(call: APICall, company_id, _):
|
||||
|
||||
with TimingContext("mongo", "task_get_all_ex"):
|
||||
_process_include_subprojects(call_data)
|
||||
ret_params = {}
|
||||
tasks = Task.get_many_with_join(
|
||||
company=company_id, query_dict=call_data, allow_public=True,
|
||||
company=company_id,
|
||||
query_dict=call_data,
|
||||
allow_public=True,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
unprepare_from_saved(call, tasks)
|
||||
call.result.data = {"tasks": tasks}
|
||||
call.result.data = {"tasks": tasks, **ret_params}
|
||||
|
||||
|
||||
@endpoint("tasks.get_by_id_ex", required_fields=["id"])
|
||||
@ -250,14 +254,16 @@ def get_all(call: APICall, company_id, _):
|
||||
call_data = escape_execution_parameters(call)
|
||||
|
||||
with TimingContext("mongo", "task_get_all"):
|
||||
ret_params = {}
|
||||
tasks = Task.get_many(
|
||||
company=company_id,
|
||||
parameters=call_data,
|
||||
query_dict=call_data,
|
||||
allow_public=True,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
unprepare_from_saved(call, tasks)
|
||||
call.result.data = {"tasks": tasks}
|
||||
call.result.data = {"tasks": tasks, **ret_params}
|
||||
|
||||
|
||||
@endpoint("tasks.get_types", request_data_model=GetTypesRequest)
|
||||
|
@ -71,7 +71,7 @@ class TestService(TestCase, TestServiceInterface):
|
||||
delete_params=delete_params,
|
||||
)
|
||||
|
||||
def setUp(self, version="1.7"):
|
||||
def setUp(self, version="999.0"):
|
||||
self._api = APIClient(base_url=f"http://localhost:8008/v{version}")
|
||||
self._deferred = []
|
||||
self._version = parse(version)
|
||||
|
82
apiserver/tests/automated/test_paging_and_scrolling.py
Normal file
82
apiserver/tests/automated/test_paging_and_scrolling.py
Normal file
@ -0,0 +1,82 @@
|
||||
import math
|
||||
from apiserver.tests.automated import TestService
|
||||
|
||||
|
||||
class TestEntityOrdering(TestService):
|
||||
name_prefix = f"Test paging "
|
||||
|
||||
def setUp(self, **kwargs):
|
||||
super().setUp(**kwargs)
|
||||
self.task_ids = self._create_tasks()
|
||||
|
||||
def _create_tasks(self):
|
||||
tasks = [
|
||||
self._temp_task(
|
||||
name=f"{self.name_prefix}{i}",
|
||||
hyperparams={"test": {"param": {"section": "test", "name": "param", "type": "str", "value": str(i)}}},
|
||||
)
|
||||
for i in range(18)
|
||||
]
|
||||
return tasks
|
||||
|
||||
def test_paging(self):
|
||||
page_size = 10
|
||||
for page in range(0, math.ceil(len(self.task_ids) / page_size)):
|
||||
start = page * page_size
|
||||
expected_size = min(page_size, len(self.task_ids) - start)
|
||||
tasks = self._get_tasks(
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
).tasks
|
||||
self.assertEqual(len(tasks), expected_size)
|
||||
for i, t in enumerate(tasks):
|
||||
self.assertEqual(t.name, f"{self.name_prefix}{start + i}")
|
||||
|
||||
def test_scrolling(self):
|
||||
page_size = 10
|
||||
scroll_id = None
|
||||
for page in range(0, math.ceil(len(self.task_ids) / page_size)):
|
||||
start = page * page_size
|
||||
expected_size = min(page_size, len(self.task_ids) - start)
|
||||
res = self._get_tasks(
|
||||
size=page_size,
|
||||
scroll_id=scroll_id,
|
||||
)
|
||||
self.assertTrue(res.scroll_id)
|
||||
scroll_id = res.scroll_id
|
||||
tasks = res.tasks
|
||||
self.assertEqual(len(tasks), expected_size)
|
||||
for i, t in enumerate(tasks):
|
||||
self.assertEqual(t.name, f"{self.name_prefix}{start + i}")
|
||||
|
||||
# no more data in this scroll
|
||||
tasks = self._get_tasks(
|
||||
size=page_size,
|
||||
scroll_id=scroll_id,
|
||||
).tasks
|
||||
self.assertFalse(tasks)
|
||||
|
||||
# refresh brings all
|
||||
tasks = self._get_tasks(
|
||||
size=page_size,
|
||||
scroll_id=scroll_id,
|
||||
refresh_scroll=True,
|
||||
).tasks
|
||||
self.assertEqual([t.id for t in tasks], self.task_ids)
|
||||
|
||||
def _get_tasks(self, **page_params):
|
||||
return self.api.tasks.get_all_ex(
|
||||
name="^Test paging ",
|
||||
order_by=["hyperparams.param"],
|
||||
**page_params,
|
||||
)
|
||||
|
||||
def _temp_task(self, name, **kwargs):
|
||||
return self.create_temp(
|
||||
"tasks",
|
||||
name=name,
|
||||
comment="Test task",
|
||||
type="testing",
|
||||
input=dict(view=dict()),
|
||||
**kwargs,
|
||||
)
|
Loading…
Reference in New Issue
Block a user