Support filtering users by activity in projects

This commit is contained in:
allegroai 2020-06-01 11:55:40 +03:00
parent 45d434a123
commit f8d8fc40a6
13 changed files with 201 additions and 95 deletions

View File

@ -1 +1 @@
__version__ = "2.7.0"
__version__ = "2.8.0"

View File

@ -13,17 +13,21 @@
credentials {
# system credentials as they appear in the auth DB, used for intra-service communications
apiserver {
role: "system"
user_key: "62T8CP7HGBC6647XF9314C2VY67RJO"
user_secret: "FhS8VZv_I4%6Mo$8S1BWc$n$=o1dMYSivuiWU-Vguq7qGOKskG-d+b@tn_Iq"
}
webserver {
role: "system"
user_key: "EYVQ385RW7Y2QQUH88CZ7DWIQ1WUHP"
user_secret: "yfc8KQo*GMXb*9p((qcYC7ByFIpF7I&4VH3BfUYXH%o9vX1ZUZQEEw1Inc)S"
revoke_in_fixed_mode: true
}
tests {
role: "user"
display_name: "Default User"
user_key: "EGRTCO8JMSIGI6S39GTP43NFWXDQOW"
user_secret: "x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8"
}
}
}

View File

@ -43,6 +43,7 @@ class Role(object):
class Credentials(EmbeddedDocument):
meta = {"strict": False}
key = StringField(required=True)
secret = StringField(required=True)
last_used = DateTimeField()

View File

@ -3,6 +3,7 @@ from mongoengine import Document, StringField, DateTimeField, ListField, Boolean
from database import Database, strict
from database.fields import StrippedStringField, SafeDictField
from database.model import DbModelMixin
from database.model.base import GetMixin
from database.model.model_labels import ModelLabels
from database.model.company import Company
from database.model.project import Project
@ -19,6 +20,7 @@ class Model(DbModelMixin, Document):
"project",
"task",
("company", "name"),
("company", "user"),
{
"name": "%s.model.main_text_index" % Database.backend,
"fields": ["$name", "$id", "$comment", "$parent", "$task", "$project"],
@ -34,6 +36,21 @@ class Model(DbModelMixin, Document):
},
],
}
get_all_query_options = GetMixin.QueryParameterOptions(
pattern_fields=("name", "comment"),
fields=("ready",),
list_fields=(
"tags",
"system_tags",
"framework",
"uri",
"id",
"user",
"project",
"task",
"parent",
),
)
id = StringField(primary_key=True)
name = StrippedStringField(user_set_allowed=True, min_length=3)

View File

@ -18,7 +18,7 @@ from database.fields import (
SafeSortedListField,
)
from database.model import AttributedDocument
from database.model.base import ProperDictMixin
from database.model.base import ProperDictMixin, GetMixin
from database.model.model_labels import ModelLabels
from database.model.project import Project
from database.utils import get_options
@ -113,6 +113,7 @@ class Task(AttributedDocument):
"parent",
"project",
("company", "name"),
("company", "user"),
("company", "type", "system_tags", "status"),
("company", "project", "type", "system_tags", "status"),
("status", "last_update"), # for maintenance tasks
@ -140,6 +141,12 @@ class Task(AttributedDocument):
},
],
}
get_all_query_options = GetMixin.QueryParameterOptions(
list_fields=("id", "user", "tags", "system_tags", "type", "status", "project"),
datetime_fields=("status_changed",),
pattern_fields=("name", "comment"),
fields=("parent",),
)
id = StringField(primary_key=True)
name = StrippedStringField(

View File

@ -2,14 +2,16 @@ from mongoengine import Document, StringField, DynamicField
from database import Database, strict
from database.model import DbModelMixin
from database.model.base import GetMixin
from database.model.company import Company
class User(DbModelMixin, Document):
meta = {
'db_alias': Database.backend,
'strict': strict,
"db_alias": Database.backend,
"strict": strict,
}
get_all_query_options = GetMixin.QueryParameterOptions(list_fields=("id",))
id = StringField(primary_key=True)
company = StringField(required=True, reference_field=Company)

View File

@ -38,29 +38,20 @@ def init_mongo_data():
PrePopulate.import_from_zip(zip_file, user_id=user_id)
users = [
{
"name": "apiserver",
"role": Role.system,
"email": "apiserver@example.com",
},
{
"name": "webserver",
"role": Role.system,
"email": "webserver@example.com",
"revoke_in_fixed_mode": True,
},
{"name": "tests", "role": Role.user, "email": "tests@example.com"},
]
fixed_mode = FixedUser.enabled()
for user in users:
revoke = fixed_mode and user.pop("revoke_in_fixed_mode", False)
credentials = config.get(f"secure.credentials.{user['name']}")
user["key"] = credentials.user_key
user["secret"] = credentials.user_secret
_ensure_auth_user(user, company_id, log=log, revoke=revoke)
for user, credentials in config.get("secure.credentials", {}).items():
user_data = {
"name": user,
"role": credentials.role,
"email": f"{user}@example.com",
"key": credentials.user_key,
"secret": credentials.user_secret,
}
revoke = fixed_mode and credentials.get("revoke_in_fixed_mode", False)
user_id = _ensure_auth_user(user_data, company_id, log=log, revoke=revoke)
if credentials.role == Role.user:
_ensure_backend_user(user_id, company_id, credentials.display_name)
if fixed_mode:
log.info("Fixed users mode is enabled")

View File

@ -159,6 +159,11 @@
description: "Get only models whose name matches this pattern (python regular expression syntax)"
type: string
}
user {
description: "List of user IDs used to filter results by the model's creating user"
type: array
items { type: string }
}
ready {
description: "Indication whether to retrieve only models that are marked ready If not supplied returns both ready and not-ready projects."
type: boolean

View File

@ -145,6 +145,19 @@ get_all_ex {
internal: true
"2.1": ${get_all."2.1"} {
}
"2.8": ${get_all."2.1"} {
request {
type: object
properties {
active_in_projects {
description: "List of project IDs. If provided, return only users that were active in these projects. If empty list is provided, return users that were active in all projects"
type: array
items { type: string }
}
}
}
}
}
get_all {

View File

@ -29,20 +29,6 @@ from services.utils import conform_tag_fields, conform_output_tags
from timing_context import TimingContext
log = config.logger(__file__)
get_all_query_options = Model.QueryParameterOptions(
pattern_fields=("name", "comment"),
fields=("ready",),
list_fields=(
"tags",
"system_tags",
"framework",
"uri",
"id",
"project",
"task",
"parent",
),
)
@endpoint("models.get_by_id", required_fields=["model"])
@ -103,10 +89,7 @@ def get_all_ex(call: APICall):
with translate_errors_context():
with TimingContext("mongo", "models_get_all_ex"):
models = Model.get_many_with_join(
company=call.identity.company,
query_dict=call.data,
allow_public=True,
query_options=get_all_query_options,
company=call.identity.company, query_dict=call.data, allow_public=True
)
conform_output_tags(call, models)
call.result.data = {"models": models}
@ -122,7 +105,6 @@ def get_all(call: APICall):
parameters=call.data,
query_dict=call.data,
allow_public=True,
query_options=get_all_query_options,
)
conform_output_tags(call, models)
call.result.data = {"models": models}

View File

@ -59,12 +59,6 @@ from utilities import safe_get
task_fields = set(Task.get_fields())
task_script_fields = set(get_fields(Script))
get_all_query_options = Task.QueryParameterOptions(
list_fields=("id", "user", "tags", "system_tags", "type", "status", "project"),
datetime_fields=("status_changed",),
pattern_fields=("name", "comment"),
fields=("parent",),
)
task_bll = TaskBLL()
event_bll = EventBLL()
@ -145,7 +139,6 @@ def get_all_ex(call: APICall):
tasks = Task.get_many_with_join(
company=call.identity.company,
query_dict=call.data,
query_options=get_all_query_options,
allow_public=True, # required in case projection is requested for public dataset/versions
)
unprepare_from_saved(call, tasks)
@ -164,7 +157,6 @@ def get_all(call: APICall):
company=call.identity.company,
parameters=call.data,
query_dict=call.data,
query_options=get_all_query_options,
allow_public=True, # required in case projection is requested for public dataset/versions
)
unprepare_from_saved(call, tasks)

View File

@ -1,5 +1,5 @@
from copy import deepcopy
from typing import Dict, Tuple
from typing import Tuple
import dpath
from boltons.iterutils import remap
@ -8,6 +8,7 @@ from mongoengine import Q
from apierrors import errors
from apimodels.base import UpdateResponse
from apimodels.users import CreateRequest, SetPreferencesRequest
from bll.project import ProjectBLL
from bll.user import UserBLL
from config import config
from database.errors import translate_errors_context
@ -19,10 +20,10 @@ from service_repo import APICall, endpoint
from utilities.json import loads, dumps
log = config.logger(__file__)
get_all_query_options = User.QueryParameterOptions(list_fields=("id",))
project_bll = ProjectBLL()
def get_user(call, user_id, only=None):
def get_user(call, company_id, user_id, only=None):
"""
Get user object by the user's ID
:param call: API call
@ -34,7 +35,7 @@ def get_user(call, user_id, only=None):
# allow system users to get info for all users
query = dict(id=user_id)
else:
query = dict(id=user_id, company=call.identity.company)
query = dict(id=user_id, company=company_id)
with translate_errors_context("retrieving user"):
user = User.objects(**query)
@ -48,47 +49,53 @@ def get_user(call, user_id, only=None):
@endpoint("users.get_by_id", required_fields=["user"])
def get_by_id(call):
assert isinstance(call, APICall)
def get_by_id(call: APICall, company_id, _):
user_id = call.data["user"]
call.result.data = {"user": get_user(call, user_id)}
call.result.data = {"user": get_user(call, company_id, user_id)}
@endpoint("users.get_all_ex", required_fields=[])
def get_all_ex(call):
assert isinstance(call, APICall)
def get_all_ex(call: APICall, company_id, _):
with translate_errors_context("retrieving users"):
res = User.get_many_with_join(
company=call.identity.company,
query_dict=call.data,
query_options=get_all_query_options,
)
res = User.get_many_with_join(company=company_id, query_dict=call.data)
call.result.data = {"users": res}
@endpoint("users.get_all_ex", min_version="2.8", required_fields=[])
def get_all_ex2_8(call: APICall, company_id, _):
with translate_errors_context("retrieving users"):
data = call.data
active_in_projects = call.data.get("active_in_projects", None)
if active_in_projects is not None:
active_users = project_bll.get_active_users(
company_id, active_in_projects, call.data.get("id")
)
active_users.discard(None)
if not active_users:
call.result.data = {"users": []}
return
data = data.copy()
data["id"] = list(active_users)
res = User.get_many_with_join(company=company_id, query_dict=data)
call.result.data = {"users": res}
@endpoint("users.get_all", required_fields=[])
def get_all(call):
assert isinstance(call, APICall)
def get_all(call: APICall, company_id, _):
with translate_errors_context("retrieving users"):
res = User.get_many(
company=call.identity.company,
parameters=call.data,
query_dict=call.data,
query_options=get_all_query_options,
company=company_id, parameters=call.data, query_dict=call.data
)
call.result.data = {"users": res}
@endpoint("users.get_current_user")
def get_current_user(call):
assert isinstance(call, APICall)
def get_current_user(call: APICall, company_id, _):
with translate_errors_context("retrieving users"):
projection = (
{"company.name"}
.union(User.get_fields())
@ -96,7 +103,7 @@ def get_current_user(call):
)
res = User.get_many_with_join(
query=Q(id=call.identity.user),
company=call.identity.company,
company=company_id,
override_projection=projection,
)
@ -126,13 +133,11 @@ def create(call: APICall):
@endpoint("users.delete", required_fields=["user"])
def delete(call):
assert isinstance(call, APICall)
def delete(call: APICall):
UserBLL.delete(call.data["user"])
def update_user(user_id, company_id, data):
# type: (str, str, Dict) -> Tuple[int, Dict]
def update_user(user_id, company_id, data: dict) -> Tuple[int, dict]:
"""
Update user.
:param user_id: user ID to update
@ -150,31 +155,29 @@ def update_user(user_id, company_id, data):
@endpoint("users.update", required_fields=["user"], response_data_model=UpdateResponse)
def update(call, company_id, _):
assert isinstance(call, APICall)
user_id = call.data["user"]
update_count, updated_fields = update_user(user_id, company_id, call.data)
call.result.data_model = UpdateResponse(updated=update_count, fields=updated_fields)
def get_user_preferences(call):
def get_user_preferences(call: APICall, company_id):
user_id = call.identity.user
preferences = get_user(call, user_id, ["preferences"]).get("preferences")
preferences = get_user(call, company_id, user_id, only=["preferences"]).get(
"preferences"
)
if preferences and isinstance(preferences, str):
preferences = loads(preferences)
return preferences or {}
@endpoint("users.get_preferences")
def get_preferences(call):
assert isinstance(call, APICall)
return {"preferences": get_user_preferences(call)}
def get_preferences(call: APICall, company_id, _):
return {"preferences": get_user_preferences(call, company_id)}
@endpoint("users.set_preferences", request_data_model=SetPreferencesRequest)
def set_preferences(call, company_id, req_model):
# type: (APICall, str, SetPreferencesRequest) -> Dict
assert isinstance(call, APICall)
changes = req_model.preferences
def set_preferences(call: APICall, company_id, request: SetPreferencesRequest):
changes = request.preferences
def invalid_key(_, key, __):
if not isinstance(key, str):
@ -187,7 +190,7 @@ def set_preferences(call, company_id, req_model):
remap(changes, visit=invalid_key)
base_preferences = get_user_preferences(call)
base_preferences = get_user_preferences(call, company_id)
new_preferences = deepcopy(base_preferences)
for key, value in changes.items():
try:

View File

@ -0,0 +1,89 @@
from typing import Sequence
from uuid import uuid4
from apierrors import errors
from config import config
from tests.automated import TestService
log = config.logger(__file__)
class TestUsersService(TestService):
def setUp(self, version="2.8"):
super(TestUsersService, self).setUp(version=version)
self.company = self.api.users.get_current_user().user.company.id
def new_user(self):
user_name = uuid4().hex
user_id = self.api.auth.create_user(
company=self.company, name=user_name, email="{0}@{0}.com".format(user_name)
).id
self.defer(self.api.users.delete, user=user_id)
return user_id
def test_active_users(self):
user_1 = self.new_user()
user_2 = self.new_user()
user_3 = self.new_user()
model = (
self.api.impersonate(user_2)
.models.create(name="test", uri="file:///a", labels={})
.id
)
self.defer(self.api.models.delete, model=model)
project = self.create_temp("projects", name="users test", description="")
task = (
self.api.impersonate(user_3)
.tasks.create(
name="test", type="testing", input=dict(view={}), project=project
)
.id
)
self.defer(self.api.tasks.delete, task=task, move_to_trash=False)
user_ids = [user_1, user_2, user_3]
# no projects filtering
users = self.api.users.get_all_ex(id=user_ids).users
self._assertUsers((user_1, user_2, user_3), users)
# all projects
users = self.api.users.get_all_ex(id=user_ids, active_in_projects=[]).users
self._assertUsers((user_2, user_3), users)
# specific project
users = self.api.users.get_all_ex(active_in_projects=[project]).users
self._assertUsers((user_3,), users)
def _assertUsers(self, expected: Sequence, users: Sequence):
self.assertEqual(set(expected), set(u.id for u in users))
def test_no_preferences(self):
user = self.new_user()
assert self.api.impersonate(user).users.get_preferences().preferences == {}
def _test_update(self, user, tests):
"""
Check that all for each (updates, expected_result) pair, ``updates`` yield ``result``.
"""
new_user_client = self.api.impersonate(user)
for update, expected in tests:
new_user_client.users.set_preferences(user=user, preferences=update)
preferences = new_user_client.users.get_preferences(user=user).preferences
self.assertEqual(preferences, expected)
def test_nested_update(self):
tests = [
({"a": 0}, {"a": 0}),
({"b": 1}, {"a": 0, "b": 1}),
({"section": {"a": 2}}, {"a": 0, "b": 1, "section": {"a": 2}}),
]
self._test_update(self.new_user(), tests)
def test_delete(self):
tests = [
({"section": {"a": 0, "b": 1}},) * 2,
({"section": {"a": None}}, {"section": {"a": None}}),
({"section": None}, {"section": None}),
]
self._test_update(self.new_user(), tests)