From f8d8fc40a6bfa01a5d90291c73788031ee1e234d Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Mon, 1 Jun 2020 11:55:40 +0300 Subject: [PATCH] Support filtering users by activity in projects --- server/api_version.py | 2 +- server/config/default/secure.conf | 6 +- server/database/model/auth.py | 1 + server/database/model/model.py | 17 ++++++ server/database/model/task/task.py | 9 ++- server/database/model/user.py | 6 +- server/mongo/initialize/__init__.py | 33 ++++------- server/schema/services/models.conf | 5 ++ server/schema/services/users.conf | 13 ++++ server/services/models.py | 20 +------ server/services/tasks.py | 8 --- server/services/users.py | 87 ++++++++++++++------------- server/tests/automated/test_users.py | 89 ++++++++++++++++++++++++++++ 13 files changed, 201 insertions(+), 95 deletions(-) create mode 100644 server/tests/automated/test_users.py diff --git a/server/api_version.py b/server/api_version.py index 2614ce9..892994a 100644 --- a/server/api_version.py +++ b/server/api_version.py @@ -1 +1 @@ -__version__ = "2.7.0" +__version__ = "2.8.0" diff --git a/server/config/default/secure.conf b/server/config/default/secure.conf index e50d339..b6d05be 100644 --- a/server/config/default/secure.conf +++ b/server/config/default/secure.conf @@ -13,17 +13,21 @@ credentials { # system credentials as they appear in the auth DB, used for intra-service communications apiserver { + role: "system" user_key: "62T8CP7HGBC6647XF9314C2VY67RJO" user_secret: "FhS8VZv_I4%6Mo$8S1BWc$n$=o1dMYSivuiWU-Vguq7qGOKskG-d+b@tn_Iq" } webserver { + role: "system" user_key: "EYVQ385RW7Y2QQUH88CZ7DWIQ1WUHP" user_secret: "yfc8KQo*GMXb*9p((qcYC7ByFIpF7I&4VH3BfUYXH%o9vX1ZUZQEEw1Inc)S" + revoke_in_fixed_mode: true } tests { + role: "user" + display_name: "Default User" user_key: "EGRTCO8JMSIGI6S39GTP43NFWXDQOW" user_secret: "x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8" - } } } \ No newline at end of file diff --git a/server/database/model/auth.py b/server/database/model/auth.py index c548b28..9dd0b39 100644 --- a/server/database/model/auth.py +++ b/server/database/model/auth.py @@ -43,6 +43,7 @@ class Role(object): class Credentials(EmbeddedDocument): + meta = {"strict": False} key = StringField(required=True) secret = StringField(required=True) last_used = DateTimeField() diff --git a/server/database/model/model.py b/server/database/model/model.py index 7aa7219..7d8f52b 100644 --- a/server/database/model/model.py +++ b/server/database/model/model.py @@ -3,6 +3,7 @@ from mongoengine import Document, StringField, DateTimeField, ListField, Boolean from database import Database, strict from database.fields import StrippedStringField, SafeDictField from database.model import DbModelMixin +from database.model.base import GetMixin from database.model.model_labels import ModelLabels from database.model.company import Company from database.model.project import Project @@ -19,6 +20,7 @@ class Model(DbModelMixin, Document): "project", "task", ("company", "name"), + ("company", "user"), { "name": "%s.model.main_text_index" % Database.backend, "fields": ["$name", "$id", "$comment", "$parent", "$task", "$project"], @@ -34,6 +36,21 @@ class Model(DbModelMixin, Document): }, ], } + get_all_query_options = GetMixin.QueryParameterOptions( + pattern_fields=("name", "comment"), + fields=("ready",), + list_fields=( + "tags", + "system_tags", + "framework", + "uri", + "id", + "user", + "project", + "task", + "parent", + ), + ) id = StringField(primary_key=True) name = StrippedStringField(user_set_allowed=True, min_length=3) diff --git a/server/database/model/task/task.py b/server/database/model/task/task.py index b159233..b87feed 100644 --- a/server/database/model/task/task.py +++ b/server/database/model/task/task.py @@ -18,7 +18,7 @@ from database.fields import ( SafeSortedListField, ) from database.model import AttributedDocument -from database.model.base import ProperDictMixin +from database.model.base import ProperDictMixin, GetMixin from database.model.model_labels import ModelLabels from database.model.project import Project from database.utils import get_options @@ -113,6 +113,7 @@ class Task(AttributedDocument): "parent", "project", ("company", "name"), + ("company", "user"), ("company", "type", "system_tags", "status"), ("company", "project", "type", "system_tags", "status"), ("status", "last_update"), # for maintenance tasks @@ -140,6 +141,12 @@ class Task(AttributedDocument): }, ], } + get_all_query_options = GetMixin.QueryParameterOptions( + list_fields=("id", "user", "tags", "system_tags", "type", "status", "project"), + datetime_fields=("status_changed",), + pattern_fields=("name", "comment"), + fields=("parent",), + ) id = StringField(primary_key=True) name = StrippedStringField( diff --git a/server/database/model/user.py b/server/database/model/user.py index 7981e02..e6031f5 100644 --- a/server/database/model/user.py +++ b/server/database/model/user.py @@ -2,14 +2,16 @@ from mongoengine import Document, StringField, DynamicField from database import Database, strict from database.model import DbModelMixin +from database.model.base import GetMixin from database.model.company import Company class User(DbModelMixin, Document): meta = { - 'db_alias': Database.backend, - 'strict': strict, + "db_alias": Database.backend, + "strict": strict, } + get_all_query_options = GetMixin.QueryParameterOptions(list_fields=("id",)) id = StringField(primary_key=True) company = StringField(required=True, reference_field=Company) diff --git a/server/mongo/initialize/__init__.py b/server/mongo/initialize/__init__.py index 7f4069b..e506aaf 100644 --- a/server/mongo/initialize/__init__.py +++ b/server/mongo/initialize/__init__.py @@ -38,29 +38,20 @@ def init_mongo_data(): PrePopulate.import_from_zip(zip_file, user_id=user_id) - users = [ - { - "name": "apiserver", - "role": Role.system, - "email": "apiserver@example.com", - }, - { - "name": "webserver", - "role": Role.system, - "email": "webserver@example.com", - "revoke_in_fixed_mode": True, - }, - {"name": "tests", "role": Role.user, "email": "tests@example.com"}, - ] - fixed_mode = FixedUser.enabled() - for user in users: - revoke = fixed_mode and user.pop("revoke_in_fixed_mode", False) - credentials = config.get(f"secure.credentials.{user['name']}") - user["key"] = credentials.user_key - user["secret"] = credentials.user_secret - _ensure_auth_user(user, company_id, log=log, revoke=revoke) + for user, credentials in config.get("secure.credentials", {}).items(): + user_data = { + "name": user, + "role": credentials.role, + "email": f"{user}@example.com", + "key": credentials.user_key, + "secret": credentials.user_secret, + } + revoke = fixed_mode and credentials.get("revoke_in_fixed_mode", False) + user_id = _ensure_auth_user(user_data, company_id, log=log, revoke=revoke) + if credentials.role == Role.user: + _ensure_backend_user(user_id, company_id, credentials.display_name) if fixed_mode: log.info("Fixed users mode is enabled") diff --git a/server/schema/services/models.conf b/server/schema/services/models.conf index 3932b17..2ed2db8 100644 --- a/server/schema/services/models.conf +++ b/server/schema/services/models.conf @@ -159,6 +159,11 @@ description: "Get only models whose name matches this pattern (python regular expression syntax)" type: string } + user { + description: "List of user IDs used to filter results by the model's creating user" + type: array + items { type: string } + } ready { description: "Indication whether to retrieve only models that are marked ready If not supplied returns both ready and not-ready projects." type: boolean diff --git a/server/schema/services/users.conf b/server/schema/services/users.conf index d0bc130..9f254ca 100644 --- a/server/schema/services/users.conf +++ b/server/schema/services/users.conf @@ -145,6 +145,19 @@ get_all_ex { internal: true "2.1": ${get_all."2.1"} { } + "2.8": ${get_all."2.1"} { + request { + type: object + properties { + active_in_projects { + description: "List of project IDs. If provided, return only users that were active in these projects. If empty list is provided, return users that were active in all projects" + type: array + items { type: string } + } + } + } + } + } get_all { diff --git a/server/services/models.py b/server/services/models.py index b9adc9d..2f20fde 100644 --- a/server/services/models.py +++ b/server/services/models.py @@ -29,20 +29,6 @@ from services.utils import conform_tag_fields, conform_output_tags from timing_context import TimingContext log = config.logger(__file__) -get_all_query_options = Model.QueryParameterOptions( - pattern_fields=("name", "comment"), - fields=("ready",), - list_fields=( - "tags", - "system_tags", - "framework", - "uri", - "id", - "project", - "task", - "parent", - ), -) @endpoint("models.get_by_id", required_fields=["model"]) @@ -103,10 +89,7 @@ def get_all_ex(call: APICall): with translate_errors_context(): with TimingContext("mongo", "models_get_all_ex"): models = Model.get_many_with_join( - company=call.identity.company, - query_dict=call.data, - allow_public=True, - query_options=get_all_query_options, + company=call.identity.company, query_dict=call.data, allow_public=True ) conform_output_tags(call, models) call.result.data = {"models": models} @@ -122,7 +105,6 @@ def get_all(call: APICall): parameters=call.data, query_dict=call.data, allow_public=True, - query_options=get_all_query_options, ) conform_output_tags(call, models) call.result.data = {"models": models} diff --git a/server/services/tasks.py b/server/services/tasks.py index be9ac5c..f99af1e 100644 --- a/server/services/tasks.py +++ b/server/services/tasks.py @@ -59,12 +59,6 @@ from utilities import safe_get task_fields = set(Task.get_fields()) task_script_fields = set(get_fields(Script)) -get_all_query_options = Task.QueryParameterOptions( - list_fields=("id", "user", "tags", "system_tags", "type", "status", "project"), - datetime_fields=("status_changed",), - pattern_fields=("name", "comment"), - fields=("parent",), -) task_bll = TaskBLL() event_bll = EventBLL() @@ -145,7 +139,6 @@ def get_all_ex(call: APICall): tasks = Task.get_many_with_join( company=call.identity.company, query_dict=call.data, - query_options=get_all_query_options, allow_public=True, # required in case projection is requested for public dataset/versions ) unprepare_from_saved(call, tasks) @@ -164,7 +157,6 @@ def get_all(call: APICall): company=call.identity.company, parameters=call.data, query_dict=call.data, - query_options=get_all_query_options, allow_public=True, # required in case projection is requested for public dataset/versions ) unprepare_from_saved(call, tasks) diff --git a/server/services/users.py b/server/services/users.py index 51b31ec..d240b8d 100644 --- a/server/services/users.py +++ b/server/services/users.py @@ -1,5 +1,5 @@ from copy import deepcopy -from typing import Dict, Tuple +from typing import Tuple import dpath from boltons.iterutils import remap @@ -8,6 +8,7 @@ from mongoengine import Q from apierrors import errors from apimodels.base import UpdateResponse from apimodels.users import CreateRequest, SetPreferencesRequest +from bll.project import ProjectBLL from bll.user import UserBLL from config import config from database.errors import translate_errors_context @@ -19,10 +20,10 @@ from service_repo import APICall, endpoint from utilities.json import loads, dumps log = config.logger(__file__) -get_all_query_options = User.QueryParameterOptions(list_fields=("id",)) +project_bll = ProjectBLL() -def get_user(call, user_id, only=None): +def get_user(call, company_id, user_id, only=None): """ Get user object by the user's ID :param call: API call @@ -34,7 +35,7 @@ def get_user(call, user_id, only=None): # allow system users to get info for all users query = dict(id=user_id) else: - query = dict(id=user_id, company=call.identity.company) + query = dict(id=user_id, company=company_id) with translate_errors_context("retrieving user"): user = User.objects(**query) @@ -48,47 +49,53 @@ def get_user(call, user_id, only=None): @endpoint("users.get_by_id", required_fields=["user"]) -def get_by_id(call): - assert isinstance(call, APICall) +def get_by_id(call: APICall, company_id, _): user_id = call.data["user"] - call.result.data = {"user": get_user(call, user_id)} + call.result.data = {"user": get_user(call, company_id, user_id)} @endpoint("users.get_all_ex", required_fields=[]) -def get_all_ex(call): - assert isinstance(call, APICall) - +def get_all_ex(call: APICall, company_id, _): with translate_errors_context("retrieving users"): - res = User.get_many_with_join( - company=call.identity.company, - query_dict=call.data, - query_options=get_all_query_options, - ) + res = User.get_many_with_join(company=company_id, query_dict=call.data) + + call.result.data = {"users": res} + + +@endpoint("users.get_all_ex", min_version="2.8", required_fields=[]) +def get_all_ex2_8(call: APICall, company_id, _): + with translate_errors_context("retrieving users"): + data = call.data + active_in_projects = call.data.get("active_in_projects", None) + if active_in_projects is not None: + active_users = project_bll.get_active_users( + company_id, active_in_projects, call.data.get("id") + ) + active_users.discard(None) + if not active_users: + call.result.data = {"users": []} + return + data = data.copy() + data["id"] = list(active_users) + + res = User.get_many_with_join(company=company_id, query_dict=data) call.result.data = {"users": res} @endpoint("users.get_all", required_fields=[]) -def get_all(call): - assert isinstance(call, APICall) - +def get_all(call: APICall, company_id, _): with translate_errors_context("retrieving users"): res = User.get_many( - company=call.identity.company, - parameters=call.data, - query_dict=call.data, - query_options=get_all_query_options, + company=company_id, parameters=call.data, query_dict=call.data ) call.result.data = {"users": res} @endpoint("users.get_current_user") -def get_current_user(call): - assert isinstance(call, APICall) - +def get_current_user(call: APICall, company_id, _): with translate_errors_context("retrieving users"): - projection = ( {"company.name"} .union(User.get_fields()) @@ -96,7 +103,7 @@ def get_current_user(call): ) res = User.get_many_with_join( query=Q(id=call.identity.user), - company=call.identity.company, + company=company_id, override_projection=projection, ) @@ -126,13 +133,11 @@ def create(call: APICall): @endpoint("users.delete", required_fields=["user"]) -def delete(call): - assert isinstance(call, APICall) +def delete(call: APICall): UserBLL.delete(call.data["user"]) -def update_user(user_id, company_id, data): - # type: (str, str, Dict) -> Tuple[int, Dict] +def update_user(user_id, company_id, data: dict) -> Tuple[int, dict]: """ Update user. :param user_id: user ID to update @@ -150,31 +155,29 @@ def update_user(user_id, company_id, data): @endpoint("users.update", required_fields=["user"], response_data_model=UpdateResponse) def update(call, company_id, _): - assert isinstance(call, APICall) user_id = call.data["user"] update_count, updated_fields = update_user(user_id, company_id, call.data) call.result.data_model = UpdateResponse(updated=update_count, fields=updated_fields) -def get_user_preferences(call): +def get_user_preferences(call: APICall, company_id): user_id = call.identity.user - preferences = get_user(call, user_id, ["preferences"]).get("preferences") + preferences = get_user(call, company_id, user_id, only=["preferences"]).get( + "preferences" + ) if preferences and isinstance(preferences, str): preferences = loads(preferences) return preferences or {} @endpoint("users.get_preferences") -def get_preferences(call): - assert isinstance(call, APICall) - return {"preferences": get_user_preferences(call)} +def get_preferences(call: APICall, company_id, _): + return {"preferences": get_user_preferences(call, company_id)} @endpoint("users.set_preferences", request_data_model=SetPreferencesRequest) -def set_preferences(call, company_id, req_model): - # type: (APICall, str, SetPreferencesRequest) -> Dict - assert isinstance(call, APICall) - changes = req_model.preferences +def set_preferences(call: APICall, company_id, request: SetPreferencesRequest): + changes = request.preferences def invalid_key(_, key, __): if not isinstance(key, str): @@ -187,7 +190,7 @@ def set_preferences(call, company_id, req_model): remap(changes, visit=invalid_key) - base_preferences = get_user_preferences(call) + base_preferences = get_user_preferences(call, company_id) new_preferences = deepcopy(base_preferences) for key, value in changes.items(): try: diff --git a/server/tests/automated/test_users.py b/server/tests/automated/test_users.py new file mode 100644 index 0000000..ffa55ff --- /dev/null +++ b/server/tests/automated/test_users.py @@ -0,0 +1,89 @@ +from typing import Sequence +from uuid import uuid4 + +from apierrors import errors +from config import config +from tests.automated import TestService + +log = config.logger(__file__) + + +class TestUsersService(TestService): + def setUp(self, version="2.8"): + super(TestUsersService, self).setUp(version=version) + self.company = self.api.users.get_current_user().user.company.id + + def new_user(self): + user_name = uuid4().hex + user_id = self.api.auth.create_user( + company=self.company, name=user_name, email="{0}@{0}.com".format(user_name) + ).id + self.defer(self.api.users.delete, user=user_id) + return user_id + + def test_active_users(self): + user_1 = self.new_user() + user_2 = self.new_user() + user_3 = self.new_user() + + model = ( + self.api.impersonate(user_2) + .models.create(name="test", uri="file:///a", labels={}) + .id + ) + self.defer(self.api.models.delete, model=model) + project = self.create_temp("projects", name="users test", description="") + task = ( + self.api.impersonate(user_3) + .tasks.create( + name="test", type="testing", input=dict(view={}), project=project + ) + .id + ) + self.defer(self.api.tasks.delete, task=task, move_to_trash=False) + + user_ids = [user_1, user_2, user_3] + # no projects filtering + users = self.api.users.get_all_ex(id=user_ids).users + self._assertUsers((user_1, user_2, user_3), users) + + # all projects + users = self.api.users.get_all_ex(id=user_ids, active_in_projects=[]).users + self._assertUsers((user_2, user_3), users) + + # specific project + users = self.api.users.get_all_ex(active_in_projects=[project]).users + self._assertUsers((user_3,), users) + + def _assertUsers(self, expected: Sequence, users: Sequence): + self.assertEqual(set(expected), set(u.id for u in users)) + + def test_no_preferences(self): + user = self.new_user() + assert self.api.impersonate(user).users.get_preferences().preferences == {} + + def _test_update(self, user, tests): + """ + Check that all for each (updates, expected_result) pair, ``updates`` yield ``result``. + """ + new_user_client = self.api.impersonate(user) + for update, expected in tests: + new_user_client.users.set_preferences(user=user, preferences=update) + preferences = new_user_client.users.get_preferences(user=user).preferences + self.assertEqual(preferences, expected) + + def test_nested_update(self): + tests = [ + ({"a": 0}, {"a": 0}), + ({"b": 1}, {"a": 0, "b": 1}), + ({"section": {"a": 2}}, {"a": 0, "b": 1, "section": {"a": 2}}), + ] + self._test_update(self.new_user(), tests) + + def test_delete(self): + tests = [ + ({"section": {"a": 0, "b": 1}},) * 2, + ({"section": {"a": None}}, {"section": {"a": None}}), + ({"section": None}, {"section": None}), + ] + self._test_update(self.new_user(), tests)