Add organization.update_company_name

Fix unit-tests
This commit is contained in:
allegroai 2022-02-13 19:29:46 +02:00
parent cae38a365b
commit 604a38035b
3 changed files with 63 additions and 37 deletions

View File

@ -28,7 +28,7 @@ from apiserver.database import Database
from apiserver.database.errors import MakeGetAllQueryError from apiserver.database.errors import MakeGetAllQueryError
from apiserver.database.projection import project_dict, ProjectionHelper from apiserver.database.projection import project_dict, ProjectionHelper
from apiserver.database.props import PropsMixin from apiserver.database.props import PropsMixin
from apiserver.database.query import RegexQ, RegexWrapper from apiserver.database.query import RegexQ, RegexWrapper, RegexQCombination
from apiserver.database.utils import ( from apiserver.database.utils import (
get_company_or_none_constraint, get_company_or_none_constraint,
get_fields_choices, get_fields_choices,
@ -131,6 +131,8 @@ class GetMixin(PropsMixin):
"nop": (default_mongo_op, False), "nop": (default_mongo_op, False),
"all": ("all", True), "all": ("all", True),
"and": ("all", True), "and": ("all", True),
"any": (default_mongo_op, True),
"or": (default_mongo_op, True),
} }
def __init__(self, legacy=False): def __init__(self, legacy=False):
@ -138,6 +140,13 @@ class GetMixin(PropsMixin):
self._sticky = False self._sticky = False
self._support_legacy = legacy self._support_legacy = legacy
def _get_op(self, v: str, translate: bool = False) -> Optional[str]:
op = v[len(self.op_prefix):] if v and v.startswith(self.op_prefix) else None
if translate:
tup = self._ops.get(op, None)
return tup[0] if tup else None
return op
def _key(self, v) -> Optional[Union[str, bool]]: def _key(self, v) -> Optional[Union[str, bool]]:
if v is None: if v is None:
self._current_op = None self._current_op = None
@ -151,16 +160,27 @@ class GetMixin(PropsMixin):
elif self._support_legacy and v.startswith(self._legacy_exclude_prefix): elif self._support_legacy and v.startswith(self._legacy_exclude_prefix):
self._current_op = None self._current_op = None
return False return False
elif v.startswith(self.op_prefix): else:
op = self._get_op(v)
if op is not None:
self._current_op, self._sticky = self._ops.get( self._current_op, self._sticky = self._ops.get(
v[len(self.op_prefix):], (self.default_mongo_op, self._sticky) op, (self.default_mongo_op, self._sticky)
) )
return None return None
return self.default_mongo_op return self.default_mongo_op
def get_actions(self, data: Sequence[str]) -> Dict[str, List[Union[str, None]]]: def get_actions(self, data: Sequence[str]) -> Tuple[Dict[str, List[Union[str, None]]], Optional[str]]:
actions = {} actions = {}
if not data:
return actions, None
global_op = self._get_op(data[0], translate=True)
if global_op in ("in", "all"):
data = data[1:]
else:
global_op = None
for val in data: for val in data:
key = self._key(val) key = self._key(val)
if key is None: if key is None:
@ -169,7 +189,8 @@ class GetMixin(PropsMixin):
key = self._legacy_exclude_mongo_op key = self._legacy_exclude_mongo_op
val = val[len(self._legacy_exclude_prefix) :] val = val[len(self._legacy_exclude_prefix) :]
actions.setdefault(key, []).append(val) actions.setdefault(key, []).append(val)
return actions
return actions, global_op or self.default_mongo_op
get_all_query_options = QueryParameterOptions() get_all_query_options = QueryParameterOptions()
@ -285,7 +306,7 @@ class GetMixin(PropsMixin):
Prepare a query object based on the provided query dictionary and various fields. Prepare a query object based on the provided query dictionary and various fields.
NOTE: BE VERY CAREFUL WITH THIS CALL, as it allows creating queries that span across companies. NOTE: BE VERY CAREFUL WITH THIS CALL, as it allows creating queries that span across companies.
IMPLEMENTATION NOTE: Make sure that inside this function or the function it depends on RegexQ is always IMPLEMENTATION NOTE: Make sure that inside this function or the functions it depends on RegexQ is always
used instead of Q. Otherwise we can and up with some combination that is not processed according to used instead of Q. Otherwise we can and up with some combination that is not processed according to
RegexQ rules RegexQ rules
:param parameters_options: Specifies options for parsing the parameters (see ParametersOptions) :param parameters_options: Specifies options for parsing the parameters (see ParametersOptions)
@ -425,11 +446,14 @@ class GetMixin(PropsMixin):
""" """
if not isinstance(data, (list, tuple)): if not isinstance(data, (list, tuple)):
data = [data] data = [data]
# raise MakeGetAllQueryError("expected list", field)
actions = cls.ListFieldBucketHelper(legacy=True).get_actions(data) actions, global_op = cls.ListFieldBucketHelper(legacy=True).get_actions(data)
default_op = cls.ListFieldBucketHelper.default_mongo_op
# Handle `allow_empty` hack: controlled using `None` as a specific value in the default "in" action
allow_empty = False allow_empty = False
default_op_actions = actions.get(cls.ListFieldBucketHelper.default_mongo_op) default_op_actions = actions.get(default_op)
if default_op_actions and None in default_op_actions: if default_op_actions and None in default_op_actions:
allow_empty = True allow_empty = True
default_op_actions.remove(None) default_op_actions.remove(None)
@ -438,10 +462,17 @@ class GetMixin(PropsMixin):
mongoengine_field = field.replace(".", "__") mongoengine_field = field.replace(".", "__")
queries = [
RegexQ(**{f"{mongoengine_field}__{action}": list(set(actions[action]))})
for action in filter(None, actions)
]
if not queries:
q = RegexQ() q = RegexQ()
for action in filter(None, actions): else:
q &= RegexQ( q = RegexQCombination(
**{f"{mongoengine_field}__{action}": list(set(actions[action]))} operation=RegexQ.AND if global_op is not default_op else RegexQ.OR,
children=queries
) )
if not allow_empty: if not allow_empty:

View File

@ -2,9 +2,6 @@ from apiserver.tests.automated import TestService
class TestOrganization(TestService): class TestOrganization(TestService):
def setUp(self, version="2.12"):
super().setUp(version=version)
def test_get_user_companies(self): def test_get_user_companies(self):
company = self.api.organization.get_user_companies().companies[0] company = self.api.organization.get_user_companies().companies[0]
self.assertEqual(len(company.owners), company.allocated) self.assertEqual(len(company.owners), company.allocated)

View File

@ -2,7 +2,7 @@ import math
from apiserver.tests.automated import TestService from apiserver.tests.automated import TestService
class TestEntityOrdering(TestService): class TestPagingAndScrolling(TestService):
name_prefix = f"Test paging " name_prefix = f"Test paging "
def setUp(self, **kwargs): def setUp(self, **kwargs):
@ -13,7 +13,16 @@ class TestEntityOrdering(TestService):
tasks = [ tasks = [
self._temp_task( self._temp_task(
name=f"{self.name_prefix}{i}", name=f"{self.name_prefix}{i}",
hyperparams={"test": {"param": {"section": "test", "name": "param", "type": "str", "value": str(i)}}}, hyperparams={
"test": {
"param": {
"section": "test",
"name": "param",
"type": "str",
"value": str(i),
}
}
},
) )
for i in range(18) for i in range(18)
] ]
@ -24,10 +33,7 @@ class TestEntityOrdering(TestService):
for page in range(0, math.ceil(len(self.task_ids) / page_size)): for page in range(0, math.ceil(len(self.task_ids) / page_size)):
start = page * page_size start = page * page_size
expected_size = min(page_size, len(self.task_ids) - start) expected_size = min(page_size, len(self.task_ids) - start)
tasks = self._get_tasks( tasks = self._get_tasks(page=page, page_size=page_size,).tasks
page=page,
page_size=page_size,
).tasks
self.assertEqual(len(tasks), expected_size) self.assertEqual(len(tasks), expected_size)
for i, t in enumerate(tasks): for i, t in enumerate(tasks):
self.assertEqual(t.name, f"{self.name_prefix}{start + i}") self.assertEqual(t.name, f"{self.name_prefix}{start + i}")
@ -38,10 +44,7 @@ class TestEntityOrdering(TestService):
for page in range(0, math.ceil(len(self.task_ids) / page_size)): for page in range(0, math.ceil(len(self.task_ids) / page_size)):
start = page * page_size start = page * page_size
expected_size = min(page_size, len(self.task_ids) - start) expected_size = min(page_size, len(self.task_ids) - start)
res = self._get_tasks( res = self._get_tasks(size=page_size, scroll_id=scroll_id,)
size=page_size,
scroll_id=scroll_id,
)
self.assertTrue(res.scroll_id) self.assertTrue(res.scroll_id)
scroll_id = res.scroll_id scroll_id = res.scroll_id
tasks = res.tasks tasks = res.tasks
@ -50,24 +53,19 @@ class TestEntityOrdering(TestService):
self.assertEqual(t.name, f"{self.name_prefix}{start + i}") self.assertEqual(t.name, f"{self.name_prefix}{start + i}")
# no more data in this scroll # no more data in this scroll
tasks = self._get_tasks( tasks = self._get_tasks(size=page_size, scroll_id=scroll_id,).tasks
size=page_size,
scroll_id=scroll_id,
).tasks
self.assertFalse(tasks) self.assertFalse(tasks)
# refresh brings all # refresh brings all
tasks = self._get_tasks( tasks = self._get_tasks(
size=page_size, size=page_size, scroll_id=scroll_id, refresh_scroll=True,
scroll_id=scroll_id,
refresh_scroll=True,
).tasks ).tasks
self.assertEqual([t.id for t in tasks], self.task_ids) self.assertEqual([t.id for t in tasks], self.task_ids)
def _get_tasks(self, **page_params): def _get_tasks(self, **page_params):
return self.api.tasks.get_all_ex( return self.api.tasks.get_all_ex(
name="^Test paging ", name="^Test paging ",
order_by=["hyperparams.param"], order_by=["hyperparams.test.param.value"],
**page_params, **page_params,
) )