Add organization.update_company_name

Fix unit-tests
This commit is contained in:
allegroai 2022-02-13 19:29:46 +02:00
parent cae38a365b
commit 604a38035b
3 changed files with 63 additions and 37 deletions

View File

@ -28,7 +28,7 @@ from apiserver.database import Database
from apiserver.database.errors import MakeGetAllQueryError
from apiserver.database.projection import project_dict, ProjectionHelper
from apiserver.database.props import PropsMixin
from apiserver.database.query import RegexQ, RegexWrapper
from apiserver.database.query import RegexQ, RegexWrapper, RegexQCombination
from apiserver.database.utils import (
get_company_or_none_constraint,
get_fields_choices,
@ -131,6 +131,8 @@ class GetMixin(PropsMixin):
"nop": (default_mongo_op, False),
"all": ("all", True),
"and": ("all", True),
"any": (default_mongo_op, True),
"or": (default_mongo_op, True),
}
def __init__(self, legacy=False):
@ -138,6 +140,13 @@ class GetMixin(PropsMixin):
self._sticky = False
self._support_legacy = legacy
def _get_op(self, v: str, translate: bool = False) -> Optional[str]:
op = v[len(self.op_prefix):] if v and v.startswith(self.op_prefix) else None
if translate:
tup = self._ops.get(op, None)
return tup[0] if tup else None
return op
def _key(self, v) -> Optional[Union[str, bool]]:
if v is None:
self._current_op = None
@ -151,16 +160,27 @@ class GetMixin(PropsMixin):
elif self._support_legacy and v.startswith(self._legacy_exclude_prefix):
self._current_op = None
return False
elif v.startswith(self.op_prefix):
else:
op = self._get_op(v)
if op is not None:
self._current_op, self._sticky = self._ops.get(
v[len(self.op_prefix):], (self.default_mongo_op, self._sticky)
op, (self.default_mongo_op, self._sticky)
)
return None
return self.default_mongo_op
def get_actions(self, data: Sequence[str]) -> Dict[str, List[Union[str, None]]]:
def get_actions(self, data: Sequence[str]) -> Tuple[Dict[str, List[Union[str, None]]], Optional[str]]:
actions = {}
if not data:
return actions, None
global_op = self._get_op(data[0], translate=True)
if global_op in ("in", "all"):
data = data[1:]
else:
global_op = None
for val in data:
key = self._key(val)
if key is None:
@ -169,7 +189,8 @@ class GetMixin(PropsMixin):
key = self._legacy_exclude_mongo_op
val = val[len(self._legacy_exclude_prefix) :]
actions.setdefault(key, []).append(val)
return actions
return actions, global_op or self.default_mongo_op
get_all_query_options = QueryParameterOptions()
@ -285,7 +306,7 @@ class GetMixin(PropsMixin):
Prepare a query object based on the provided query dictionary and various fields.
NOTE: BE VERY CAREFUL WITH THIS CALL, as it allows creating queries that span across companies.
IMPLEMENTATION NOTE: Make sure that inside this function or the function it depends on RegexQ is always
IMPLEMENTATION NOTE: Make sure that inside this function or the functions it depends on RegexQ is always
used instead of Q. Otherwise we can and up with some combination that is not processed according to
RegexQ rules
:param parameters_options: Specifies options for parsing the parameters (see ParametersOptions)
@ -425,11 +446,14 @@ class GetMixin(PropsMixin):
"""
if not isinstance(data, (list, tuple)):
data = [data]
# raise MakeGetAllQueryError("expected list", field)
actions = cls.ListFieldBucketHelper(legacy=True).get_actions(data)
actions, global_op = cls.ListFieldBucketHelper(legacy=True).get_actions(data)
default_op = cls.ListFieldBucketHelper.default_mongo_op
# Handle `allow_empty` hack: controlled using `None` as a specific value in the default "in" action
allow_empty = False
default_op_actions = actions.get(cls.ListFieldBucketHelper.default_mongo_op)
default_op_actions = actions.get(default_op)
if default_op_actions and None in default_op_actions:
allow_empty = True
default_op_actions.remove(None)
@ -438,10 +462,17 @@ class GetMixin(PropsMixin):
mongoengine_field = field.replace(".", "__")
queries = [
RegexQ(**{f"{mongoengine_field}__{action}": list(set(actions[action]))})
for action in filter(None, actions)
]
if not queries:
q = RegexQ()
for action in filter(None, actions):
q &= RegexQ(
**{f"{mongoengine_field}__{action}": list(set(actions[action]))}
else:
q = RegexQCombination(
operation=RegexQ.AND if global_op is not default_op else RegexQ.OR,
children=queries
)
if not allow_empty:

View File

@ -2,9 +2,6 @@ from apiserver.tests.automated import TestService
class TestOrganization(TestService):
def setUp(self, version="2.12"):
super().setUp(version=version)
def test_get_user_companies(self):
company = self.api.organization.get_user_companies().companies[0]
self.assertEqual(len(company.owners), company.allocated)

View File

@ -2,7 +2,7 @@ import math
from apiserver.tests.automated import TestService
class TestEntityOrdering(TestService):
class TestPagingAndScrolling(TestService):
name_prefix = f"Test paging "
def setUp(self, **kwargs):
@ -13,7 +13,16 @@ class TestEntityOrdering(TestService):
tasks = [
self._temp_task(
name=f"{self.name_prefix}{i}",
hyperparams={"test": {"param": {"section": "test", "name": "param", "type": "str", "value": str(i)}}},
hyperparams={
"test": {
"param": {
"section": "test",
"name": "param",
"type": "str",
"value": str(i),
}
}
},
)
for i in range(18)
]
@ -24,10 +33,7 @@ class TestEntityOrdering(TestService):
for page in range(0, math.ceil(len(self.task_ids) / page_size)):
start = page * page_size
expected_size = min(page_size, len(self.task_ids) - start)
tasks = self._get_tasks(
page=page,
page_size=page_size,
).tasks
tasks = self._get_tasks(page=page, page_size=page_size,).tasks
self.assertEqual(len(tasks), expected_size)
for i, t in enumerate(tasks):
self.assertEqual(t.name, f"{self.name_prefix}{start + i}")
@ -38,10 +44,7 @@ class TestEntityOrdering(TestService):
for page in range(0, math.ceil(len(self.task_ids) / page_size)):
start = page * page_size
expected_size = min(page_size, len(self.task_ids) - start)
res = self._get_tasks(
size=page_size,
scroll_id=scroll_id,
)
res = self._get_tasks(size=page_size, scroll_id=scroll_id,)
self.assertTrue(res.scroll_id)
scroll_id = res.scroll_id
tasks = res.tasks
@ -50,24 +53,19 @@ class TestEntityOrdering(TestService):
self.assertEqual(t.name, f"{self.name_prefix}{start + i}")
# no more data in this scroll
tasks = self._get_tasks(
size=page_size,
scroll_id=scroll_id,
).tasks
tasks = self._get_tasks(size=page_size, scroll_id=scroll_id,).tasks
self.assertFalse(tasks)
# refresh brings all
tasks = self._get_tasks(
size=page_size,
scroll_id=scroll_id,
refresh_scroll=True,
size=page_size, scroll_id=scroll_id, refresh_scroll=True,
).tasks
self.assertEqual([t.id for t in tasks], self.task_ids)
def _get_tasks(self, **page_params):
return self.api.tasks.get_all_ex(
name="^Test paging ",
order_by=["hyperparams.param"],
order_by=["hyperparams.test.param.value"],
**page_params,
)