Support filtering users by activity in projects

This commit is contained in:
allegroai 2020-06-01 11:55:40 +03:00
parent 45d434a123
commit f8d8fc40a6
13 changed files with 201 additions and 95 deletions

View File

@ -1 +1 @@
__version__ = "2.7.0" __version__ = "2.8.0"

View File

@ -13,17 +13,21 @@
credentials { credentials {
# system credentials as they appear in the auth DB, used for intra-service communications # system credentials as they appear in the auth DB, used for intra-service communications
apiserver { apiserver {
role: "system"
user_key: "62T8CP7HGBC6647XF9314C2VY67RJO" user_key: "62T8CP7HGBC6647XF9314C2VY67RJO"
user_secret: "FhS8VZv_I4%6Mo$8S1BWc$n$=o1dMYSivuiWU-Vguq7qGOKskG-d+b@tn_Iq" user_secret: "FhS8VZv_I4%6Mo$8S1BWc$n$=o1dMYSivuiWU-Vguq7qGOKskG-d+b@tn_Iq"
} }
webserver { webserver {
role: "system"
user_key: "EYVQ385RW7Y2QQUH88CZ7DWIQ1WUHP" user_key: "EYVQ385RW7Y2QQUH88CZ7DWIQ1WUHP"
user_secret: "yfc8KQo*GMXb*9p((qcYC7ByFIpF7I&4VH3BfUYXH%o9vX1ZUZQEEw1Inc)S" user_secret: "yfc8KQo*GMXb*9p((qcYC7ByFIpF7I&4VH3BfUYXH%o9vX1ZUZQEEw1Inc)S"
revoke_in_fixed_mode: true
} }
tests { tests {
role: "user"
display_name: "Default User"
user_key: "EGRTCO8JMSIGI6S39GTP43NFWXDQOW" user_key: "EGRTCO8JMSIGI6S39GTP43NFWXDQOW"
user_secret: "x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8" user_secret: "x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8"
} }
} }
} }

View File

@ -43,6 +43,7 @@ class Role(object):
class Credentials(EmbeddedDocument): class Credentials(EmbeddedDocument):
meta = {"strict": False}
key = StringField(required=True) key = StringField(required=True)
secret = StringField(required=True) secret = StringField(required=True)
last_used = DateTimeField() last_used = DateTimeField()

View File

@ -3,6 +3,7 @@ from mongoengine import Document, StringField, DateTimeField, ListField, Boolean
from database import Database, strict from database import Database, strict
from database.fields import StrippedStringField, SafeDictField from database.fields import StrippedStringField, SafeDictField
from database.model import DbModelMixin from database.model import DbModelMixin
from database.model.base import GetMixin
from database.model.model_labels import ModelLabels from database.model.model_labels import ModelLabels
from database.model.company import Company from database.model.company import Company
from database.model.project import Project from database.model.project import Project
@ -19,6 +20,7 @@ class Model(DbModelMixin, Document):
"project", "project",
"task", "task",
("company", "name"), ("company", "name"),
("company", "user"),
{ {
"name": "%s.model.main_text_index" % Database.backend, "name": "%s.model.main_text_index" % Database.backend,
"fields": ["$name", "$id", "$comment", "$parent", "$task", "$project"], "fields": ["$name", "$id", "$comment", "$parent", "$task", "$project"],
@ -34,6 +36,21 @@ class Model(DbModelMixin, Document):
}, },
], ],
} }
get_all_query_options = GetMixin.QueryParameterOptions(
pattern_fields=("name", "comment"),
fields=("ready",),
list_fields=(
"tags",
"system_tags",
"framework",
"uri",
"id",
"user",
"project",
"task",
"parent",
),
)
id = StringField(primary_key=True) id = StringField(primary_key=True)
name = StrippedStringField(user_set_allowed=True, min_length=3) name = StrippedStringField(user_set_allowed=True, min_length=3)

View File

@ -18,7 +18,7 @@ from database.fields import (
SafeSortedListField, SafeSortedListField,
) )
from database.model import AttributedDocument from database.model import AttributedDocument
from database.model.base import ProperDictMixin from database.model.base import ProperDictMixin, GetMixin
from database.model.model_labels import ModelLabels from database.model.model_labels import ModelLabels
from database.model.project import Project from database.model.project import Project
from database.utils import get_options from database.utils import get_options
@ -113,6 +113,7 @@ class Task(AttributedDocument):
"parent", "parent",
"project", "project",
("company", "name"), ("company", "name"),
("company", "user"),
("company", "type", "system_tags", "status"), ("company", "type", "system_tags", "status"),
("company", "project", "type", "system_tags", "status"), ("company", "project", "type", "system_tags", "status"),
("status", "last_update"), # for maintenance tasks ("status", "last_update"), # for maintenance tasks
@ -140,6 +141,12 @@ class Task(AttributedDocument):
}, },
], ],
} }
get_all_query_options = GetMixin.QueryParameterOptions(
list_fields=("id", "user", "tags", "system_tags", "type", "status", "project"),
datetime_fields=("status_changed",),
pattern_fields=("name", "comment"),
fields=("parent",),
)
id = StringField(primary_key=True) id = StringField(primary_key=True)
name = StrippedStringField( name = StrippedStringField(

View File

@ -2,14 +2,16 @@ from mongoengine import Document, StringField, DynamicField
from database import Database, strict from database import Database, strict
from database.model import DbModelMixin from database.model import DbModelMixin
from database.model.base import GetMixin
from database.model.company import Company from database.model.company import Company
class User(DbModelMixin, Document): class User(DbModelMixin, Document):
meta = { meta = {
'db_alias': Database.backend, "db_alias": Database.backend,
'strict': strict, "strict": strict,
} }
get_all_query_options = GetMixin.QueryParameterOptions(list_fields=("id",))
id = StringField(primary_key=True) id = StringField(primary_key=True)
company = StringField(required=True, reference_field=Company) company = StringField(required=True, reference_field=Company)

View File

@ -38,29 +38,20 @@ def init_mongo_data():
PrePopulate.import_from_zip(zip_file, user_id=user_id) PrePopulate.import_from_zip(zip_file, user_id=user_id)
users = [
{
"name": "apiserver",
"role": Role.system,
"email": "apiserver@example.com",
},
{
"name": "webserver",
"role": Role.system,
"email": "webserver@example.com",
"revoke_in_fixed_mode": True,
},
{"name": "tests", "role": Role.user, "email": "tests@example.com"},
]
fixed_mode = FixedUser.enabled() fixed_mode = FixedUser.enabled()
for user in users: for user, credentials in config.get("secure.credentials", {}).items():
revoke = fixed_mode and user.pop("revoke_in_fixed_mode", False) user_data = {
credentials = config.get(f"secure.credentials.{user['name']}") "name": user,
user["key"] = credentials.user_key "role": credentials.role,
user["secret"] = credentials.user_secret "email": f"{user}@example.com",
_ensure_auth_user(user, company_id, log=log, revoke=revoke) "key": credentials.user_key,
"secret": credentials.user_secret,
}
revoke = fixed_mode and credentials.get("revoke_in_fixed_mode", False)
user_id = _ensure_auth_user(user_data, company_id, log=log, revoke=revoke)
if credentials.role == Role.user:
_ensure_backend_user(user_id, company_id, credentials.display_name)
if fixed_mode: if fixed_mode:
log.info("Fixed users mode is enabled") log.info("Fixed users mode is enabled")

View File

@ -159,6 +159,11 @@
description: "Get only models whose name matches this pattern (python regular expression syntax)" description: "Get only models whose name matches this pattern (python regular expression syntax)"
type: string type: string
} }
user {
description: "List of user IDs used to filter results by the model's creating user"
type: array
items { type: string }
}
ready { ready {
description: "Indication whether to retrieve only models that are marked ready If not supplied returns both ready and not-ready projects." description: "Indication whether to retrieve only models that are marked ready If not supplied returns both ready and not-ready projects."
type: boolean type: boolean

View File

@ -145,6 +145,19 @@ get_all_ex {
internal: true internal: true
"2.1": ${get_all."2.1"} { "2.1": ${get_all."2.1"} {
} }
"2.8": ${get_all."2.1"} {
request {
type: object
properties {
active_in_projects {
description: "List of project IDs. If provided, return only users that were active in these projects. If empty list is provided, return users that were active in all projects"
type: array
items { type: string }
}
}
}
}
} }
get_all { get_all {

View File

@ -29,20 +29,6 @@ from services.utils import conform_tag_fields, conform_output_tags
from timing_context import TimingContext from timing_context import TimingContext
log = config.logger(__file__) log = config.logger(__file__)
get_all_query_options = Model.QueryParameterOptions(
pattern_fields=("name", "comment"),
fields=("ready",),
list_fields=(
"tags",
"system_tags",
"framework",
"uri",
"id",
"project",
"task",
"parent",
),
)
@endpoint("models.get_by_id", required_fields=["model"]) @endpoint("models.get_by_id", required_fields=["model"])
@ -103,10 +89,7 @@ def get_all_ex(call: APICall):
with translate_errors_context(): with translate_errors_context():
with TimingContext("mongo", "models_get_all_ex"): with TimingContext("mongo", "models_get_all_ex"):
models = Model.get_many_with_join( models = Model.get_many_with_join(
company=call.identity.company, company=call.identity.company, query_dict=call.data, allow_public=True
query_dict=call.data,
allow_public=True,
query_options=get_all_query_options,
) )
conform_output_tags(call, models) conform_output_tags(call, models)
call.result.data = {"models": models} call.result.data = {"models": models}
@ -122,7 +105,6 @@ def get_all(call: APICall):
parameters=call.data, parameters=call.data,
query_dict=call.data, query_dict=call.data,
allow_public=True, allow_public=True,
query_options=get_all_query_options,
) )
conform_output_tags(call, models) conform_output_tags(call, models)
call.result.data = {"models": models} call.result.data = {"models": models}

View File

@ -59,12 +59,6 @@ from utilities import safe_get
task_fields = set(Task.get_fields()) task_fields = set(Task.get_fields())
task_script_fields = set(get_fields(Script)) task_script_fields = set(get_fields(Script))
get_all_query_options = Task.QueryParameterOptions(
list_fields=("id", "user", "tags", "system_tags", "type", "status", "project"),
datetime_fields=("status_changed",),
pattern_fields=("name", "comment"),
fields=("parent",),
)
task_bll = TaskBLL() task_bll = TaskBLL()
event_bll = EventBLL() event_bll = EventBLL()
@ -145,7 +139,6 @@ def get_all_ex(call: APICall):
tasks = Task.get_many_with_join( tasks = Task.get_many_with_join(
company=call.identity.company, company=call.identity.company,
query_dict=call.data, query_dict=call.data,
query_options=get_all_query_options,
allow_public=True, # required in case projection is requested for public dataset/versions allow_public=True, # required in case projection is requested for public dataset/versions
) )
unprepare_from_saved(call, tasks) unprepare_from_saved(call, tasks)
@ -164,7 +157,6 @@ def get_all(call: APICall):
company=call.identity.company, company=call.identity.company,
parameters=call.data, parameters=call.data,
query_dict=call.data, query_dict=call.data,
query_options=get_all_query_options,
allow_public=True, # required in case projection is requested for public dataset/versions allow_public=True, # required in case projection is requested for public dataset/versions
) )
unprepare_from_saved(call, tasks) unprepare_from_saved(call, tasks)

View File

@ -1,5 +1,5 @@
from copy import deepcopy from copy import deepcopy
from typing import Dict, Tuple from typing import Tuple
import dpath import dpath
from boltons.iterutils import remap from boltons.iterutils import remap
@ -8,6 +8,7 @@ from mongoengine import Q
from apierrors import errors from apierrors import errors
from apimodels.base import UpdateResponse from apimodels.base import UpdateResponse
from apimodels.users import CreateRequest, SetPreferencesRequest from apimodels.users import CreateRequest, SetPreferencesRequest
from bll.project import ProjectBLL
from bll.user import UserBLL from bll.user import UserBLL
from config import config from config import config
from database.errors import translate_errors_context from database.errors import translate_errors_context
@ -19,10 +20,10 @@ from service_repo import APICall, endpoint
from utilities.json import loads, dumps from utilities.json import loads, dumps
log = config.logger(__file__) log = config.logger(__file__)
get_all_query_options = User.QueryParameterOptions(list_fields=("id",)) project_bll = ProjectBLL()
def get_user(call, user_id, only=None): def get_user(call, company_id, user_id, only=None):
""" """
Get user object by the user's ID Get user object by the user's ID
:param call: API call :param call: API call
@ -34,7 +35,7 @@ def get_user(call, user_id, only=None):
# allow system users to get info for all users # allow system users to get info for all users
query = dict(id=user_id) query = dict(id=user_id)
else: else:
query = dict(id=user_id, company=call.identity.company) query = dict(id=user_id, company=company_id)
with translate_errors_context("retrieving user"): with translate_errors_context("retrieving user"):
user = User.objects(**query) user = User.objects(**query)
@ -48,47 +49,53 @@ def get_user(call, user_id, only=None):
@endpoint("users.get_by_id", required_fields=["user"]) @endpoint("users.get_by_id", required_fields=["user"])
def get_by_id(call): def get_by_id(call: APICall, company_id, _):
assert isinstance(call, APICall)
user_id = call.data["user"] user_id = call.data["user"]
call.result.data = {"user": get_user(call, user_id)} call.result.data = {"user": get_user(call, company_id, user_id)}
@endpoint("users.get_all_ex", required_fields=[]) @endpoint("users.get_all_ex", required_fields=[])
def get_all_ex(call): def get_all_ex(call: APICall, company_id, _):
assert isinstance(call, APICall)
with translate_errors_context("retrieving users"): with translate_errors_context("retrieving users"):
res = User.get_many_with_join( res = User.get_many_with_join(company=company_id, query_dict=call.data)
company=call.identity.company,
query_dict=call.data, call.result.data = {"users": res}
query_options=get_all_query_options,
@endpoint("users.get_all_ex", min_version="2.8", required_fields=[])
def get_all_ex2_8(call: APICall, company_id, _):
with translate_errors_context("retrieving users"):
data = call.data
active_in_projects = call.data.get("active_in_projects", None)
if active_in_projects is not None:
active_users = project_bll.get_active_users(
company_id, active_in_projects, call.data.get("id")
) )
active_users.discard(None)
if not active_users:
call.result.data = {"users": []}
return
data = data.copy()
data["id"] = list(active_users)
res = User.get_many_with_join(company=company_id, query_dict=data)
call.result.data = {"users": res} call.result.data = {"users": res}
@endpoint("users.get_all", required_fields=[]) @endpoint("users.get_all", required_fields=[])
def get_all(call): def get_all(call: APICall, company_id, _):
assert isinstance(call, APICall)
with translate_errors_context("retrieving users"): with translate_errors_context("retrieving users"):
res = User.get_many( res = User.get_many(
company=call.identity.company, company=company_id, parameters=call.data, query_dict=call.data
parameters=call.data,
query_dict=call.data,
query_options=get_all_query_options,
) )
call.result.data = {"users": res} call.result.data = {"users": res}
@endpoint("users.get_current_user") @endpoint("users.get_current_user")
def get_current_user(call): def get_current_user(call: APICall, company_id, _):
assert isinstance(call, APICall)
with translate_errors_context("retrieving users"): with translate_errors_context("retrieving users"):
projection = ( projection = (
{"company.name"} {"company.name"}
.union(User.get_fields()) .union(User.get_fields())
@ -96,7 +103,7 @@ def get_current_user(call):
) )
res = User.get_many_with_join( res = User.get_many_with_join(
query=Q(id=call.identity.user), query=Q(id=call.identity.user),
company=call.identity.company, company=company_id,
override_projection=projection, override_projection=projection,
) )
@ -126,13 +133,11 @@ def create(call: APICall):
@endpoint("users.delete", required_fields=["user"]) @endpoint("users.delete", required_fields=["user"])
def delete(call): def delete(call: APICall):
assert isinstance(call, APICall)
UserBLL.delete(call.data["user"]) UserBLL.delete(call.data["user"])
def update_user(user_id, company_id, data): def update_user(user_id, company_id, data: dict) -> Tuple[int, dict]:
# type: (str, str, Dict) -> Tuple[int, Dict]
""" """
Update user. Update user.
:param user_id: user ID to update :param user_id: user ID to update
@ -150,31 +155,29 @@ def update_user(user_id, company_id, data):
@endpoint("users.update", required_fields=["user"], response_data_model=UpdateResponse) @endpoint("users.update", required_fields=["user"], response_data_model=UpdateResponse)
def update(call, company_id, _): def update(call, company_id, _):
assert isinstance(call, APICall)
user_id = call.data["user"] user_id = call.data["user"]
update_count, updated_fields = update_user(user_id, company_id, call.data) update_count, updated_fields = update_user(user_id, company_id, call.data)
call.result.data_model = UpdateResponse(updated=update_count, fields=updated_fields) call.result.data_model = UpdateResponse(updated=update_count, fields=updated_fields)
def get_user_preferences(call): def get_user_preferences(call: APICall, company_id):
user_id = call.identity.user user_id = call.identity.user
preferences = get_user(call, user_id, ["preferences"]).get("preferences") preferences = get_user(call, company_id, user_id, only=["preferences"]).get(
"preferences"
)
if preferences and isinstance(preferences, str): if preferences and isinstance(preferences, str):
preferences = loads(preferences) preferences = loads(preferences)
return preferences or {} return preferences or {}
@endpoint("users.get_preferences") @endpoint("users.get_preferences")
def get_preferences(call): def get_preferences(call: APICall, company_id, _):
assert isinstance(call, APICall) return {"preferences": get_user_preferences(call, company_id)}
return {"preferences": get_user_preferences(call)}
@endpoint("users.set_preferences", request_data_model=SetPreferencesRequest) @endpoint("users.set_preferences", request_data_model=SetPreferencesRequest)
def set_preferences(call, company_id, req_model): def set_preferences(call: APICall, company_id, request: SetPreferencesRequest):
# type: (APICall, str, SetPreferencesRequest) -> Dict changes = request.preferences
assert isinstance(call, APICall)
changes = req_model.preferences
def invalid_key(_, key, __): def invalid_key(_, key, __):
if not isinstance(key, str): if not isinstance(key, str):
@ -187,7 +190,7 @@ def set_preferences(call, company_id, req_model):
remap(changes, visit=invalid_key) remap(changes, visit=invalid_key)
base_preferences = get_user_preferences(call) base_preferences = get_user_preferences(call, company_id)
new_preferences = deepcopy(base_preferences) new_preferences = deepcopy(base_preferences)
for key, value in changes.items(): for key, value in changes.items():
try: try:

View File

@ -0,0 +1,89 @@
from typing import Sequence
from uuid import uuid4
from apierrors import errors
from config import config
from tests.automated import TestService
log = config.logger(__file__)
class TestUsersService(TestService):
def setUp(self, version="2.8"):
super(TestUsersService, self).setUp(version=version)
self.company = self.api.users.get_current_user().user.company.id
def new_user(self):
user_name = uuid4().hex
user_id = self.api.auth.create_user(
company=self.company, name=user_name, email="{0}@{0}.com".format(user_name)
).id
self.defer(self.api.users.delete, user=user_id)
return user_id
def test_active_users(self):
user_1 = self.new_user()
user_2 = self.new_user()
user_3 = self.new_user()
model = (
self.api.impersonate(user_2)
.models.create(name="test", uri="file:///a", labels={})
.id
)
self.defer(self.api.models.delete, model=model)
project = self.create_temp("projects", name="users test", description="")
task = (
self.api.impersonate(user_3)
.tasks.create(
name="test", type="testing", input=dict(view={}), project=project
)
.id
)
self.defer(self.api.tasks.delete, task=task, move_to_trash=False)
user_ids = [user_1, user_2, user_3]
# no projects filtering
users = self.api.users.get_all_ex(id=user_ids).users
self._assertUsers((user_1, user_2, user_3), users)
# all projects
users = self.api.users.get_all_ex(id=user_ids, active_in_projects=[]).users
self._assertUsers((user_2, user_3), users)
# specific project
users = self.api.users.get_all_ex(active_in_projects=[project]).users
self._assertUsers((user_3,), users)
def _assertUsers(self, expected: Sequence, users: Sequence):
self.assertEqual(set(expected), set(u.id for u in users))
def test_no_preferences(self):
user = self.new_user()
assert self.api.impersonate(user).users.get_preferences().preferences == {}
def _test_update(self, user, tests):
"""
Check that all for each (updates, expected_result) pair, ``updates`` yield ``result``.
"""
new_user_client = self.api.impersonate(user)
for update, expected in tests:
new_user_client.users.set_preferences(user=user, preferences=update)
preferences = new_user_client.users.get_preferences(user=user).preferences
self.assertEqual(preferences, expected)
def test_nested_update(self):
tests = [
({"a": 0}, {"a": 0}),
({"b": 1}, {"a": 0, "b": 1}),
({"section": {"a": 2}}, {"a": 0, "b": 1, "section": {"a": 2}}),
]
self._test_update(self.new_user(), tests)
def test_delete(self):
tests = [
({"section": {"a": 0, "b": 1}},) * 2,
({"section": {"a": None}}, {"section": {"a": None}}),
({"section": None}, {"section": None}),
]
self._test_update(self.new_user(), tests)