from copy import deepcopy from typing import Tuple import dpath from boltons.iterutils import remap from mongoengine import Q from apiserver.apierrors import errors from apiserver.apimodels.base import UpdateResponse from apiserver.apimodels.users import CreateRequest, SetPreferencesRequest from apiserver.bll.project import ProjectBLL from apiserver.bll.user import UserBLL from apiserver.config_repo import config from apiserver.database.errors import translate_errors_context from apiserver.database.model.auth import Role from apiserver.database.model.company import Company from apiserver.database.model.user import User from apiserver.database.utils import parse_from_call from apiserver.service_repo import APICall, endpoint from apiserver.utilities.json import loads, dumps log = config.logger(__file__) project_bll = ProjectBLL() def get_user(call, company_id, user_id, only=None): """ Get user object by the user's ID :param call: API call :param user_id: user ID :param only: fields to include in projection, by default all :return: User object """ if call.identity.role in (Role.system,): # allow system users to get info for all users query = dict(id=user_id) else: query = dict(id=user_id, company=company_id) with translate_errors_context("retrieving user"): user = User.objects(**query) if only: user = user.only(*only) res = user.first() if not res: raise errors.bad_request.InvalidUserId(**query) return res.to_proper_dict() @endpoint("users.get_by_id", required_fields=["user"]) def get_by_id(call: APICall, company_id, _): user_id = call.data["user"] call.result.data = {"user": get_user(call, company_id, user_id)} @endpoint("users.get_all_ex", required_fields=[]) def get_all_ex(call: APICall, company_id, _): with translate_errors_context("retrieving users"): res = User.get_many_with_join(company=company_id, query_dict=call.data) call.result.data = {"users": res} @endpoint("users.get_all_ex", min_version="2.8", required_fields=[]) def get_all_ex2_8(call: APICall, company_id, _): with translate_errors_context("retrieving users"): data = call.data active_in_projects = call.data.get("active_in_projects", None) if active_in_projects is not None: active_users = project_bll.get_active_users( company_id, active_in_projects, call.data.get("id") ) active_users.discard(None) if not active_users: call.result.data = {"users": []} return data = data.copy() data["id"] = list(active_users) res = User.get_many_with_join(company=company_id, query_dict=data) call.result.data = {"users": res} @endpoint("users.get_all", required_fields=[]) def get_all(call: APICall, company_id, _): with translate_errors_context("retrieving users"): res = User.get_many( company=company_id, parameters=call.data, query_dict=call.data ) call.result.data = {"users": res} @endpoint("users.get_current_user") def get_current_user(call: APICall, company_id, _): with translate_errors_context("retrieving users"): projection = ( {"company.name"} .union(User.get_fields()) .difference(User.get_exclude_fields()) ) res = User.get_many_with_join( query=Q(id=call.identity.user), company=company_id, override_projection=projection, ) if not res: raise errors.bad_request.InvalidUser("failed loading user") user = res[0] user["role"] = call.identity.role resp = {"user": user} call.result.data = resp create_fields = { "name": None, "family_name": None, "given_name": None, "avatar": None, "company": Company, "preferences": dict, } @endpoint("users.create", request_data_model=CreateRequest) def create(call: APICall): UserBLL.create(call.data_model) @endpoint("users.delete", required_fields=["user"]) def delete(call: APICall): UserBLL.delete(call.data["user"]) def update_user(user_id, company_id, data: dict) -> Tuple[int, dict]: """ Update user. :param user_id: user ID to update :param company_id: ID of company user belongs to :param data: mapping to update user by :return: (updated fields count, updated fields) pair """ update_fields = { k: v for k, v in create_fields.items() if k in User.user_set_allowed() } partial_update_dict = parse_from_call(data, update_fields, User.get_fields()) with translate_errors_context("updating user"): return User.safe_update(company_id, user_id, partial_update_dict) @endpoint("users.update", required_fields=["user"], response_data_model=UpdateResponse) def update(call, company_id, _): user_id = call.data["user"] update_count, updated_fields = update_user(user_id, company_id, call.data) call.result.data_model = UpdateResponse(updated=update_count, fields=updated_fields) def get_user_preferences(call: APICall, company_id): user_id = call.identity.user preferences = get_user(call, company_id, user_id, only=["preferences"]).get( "preferences" ) if preferences and isinstance(preferences, str): preferences = loads(preferences) return preferences or {} @endpoint("users.get_preferences") def get_preferences(call: APICall, company_id, _): return {"preferences": get_user_preferences(call, company_id)} @endpoint("users.set_preferences", request_data_model=SetPreferencesRequest) def set_preferences(call: APICall, company_id, request: SetPreferencesRequest): changes = request.preferences def invalid_key(_, key, __): if not isinstance(key, str): return True elif key.startswith("$") or "." in key: raise errors.bad_request.FieldsValueError( f"Key {key} is invalid. Keys cannot start with '$' or contain '.'." ) return True remap(changes, visit=invalid_key) base_preferences = get_user_preferences(call, company_id) new_preferences = deepcopy(base_preferences) for key, value in changes.items(): try: dpath.new(new_preferences, key, value, separator=".") except Exception: log.exception( 'invalid preferences update for user "{}": key=`%s`, value=`%s`', key, value, ) raise errors.bad_request.InvalidPreferencesUpdate(key=key, value=value) if new_preferences == base_preferences: updated, fields = 0, {} else: with translate_errors_context("updating user preferences"): updated = User.objects(id=call.identity.user, company=company_id).update( upsert=False, preferences=dumps(new_preferences) ) return { "updated": updated, "fields": {"preferences": new_preferences} if updated else {}, }