from datetime import datetime
from typing import Sequence

import attr
from mongoengine import Q

from apiserver.apierrors import errors
from apiserver.apierrors.errors.bad_request import InvalidProjectId
from apiserver.apimodels.base import UpdateResponse, MakePublicRequest, IdResponse
from apiserver.apimodels.projects import (
    GetHyperParamReq,
    ProjectReq,
    ProjectTagsRequest,
    ProjectTaskParentsRequest,
    ProjectHyperparamValuesRequest,
    ProjectsGetRequest,
    DeleteRequest,
)
from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL
from apiserver.bll.project.project_cleanup import delete_project
from apiserver.bll.task import TaskBLL
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.project import Project
from apiserver.database.utils import (
    parse_from_call,
    get_company_or_none_constraint,
)
from apiserver.service_repo import APICall, endpoint
from apiserver.services.utils import (
    conform_tag_fields,
    conform_output_tags,
    get_tags_filter_dictionary,
    get_tags_response,
)
from apiserver.timing_context import TimingContext

# Module-level business-logic singletons shared by the handlers in this module.
org_bll = OrgBLL()
task_bll = TaskBLL()
project_bll = ProjectBLL()

# Fields accepted from the request by projects.create / projects.update; passed
# to parse_from_call together with Project.get_fields().
# NOTE(review): the values look like per-field type hints (None = take the value
# as-is, list = expect a list) — confirm against parse_from_call.
create_fields = {
    "name": None,
    "description": None,
    "tags": list,
    "system_tags": list,
    "default_output_destination": None,
}

# Query options used by projects.get_all and projects.get_all_ex.
# NOTE(review): presumably pattern_fields are matched as search patterns and
# list_fields accept lists of values — confirm in Project.QueryParameterOptions.
get_all_query_options = Project.QueryParameterOptions(
    pattern_fields=("name", "description"),
    list_fields=("tags", "system_tags", "id"),
)


@endpoint("projects.get_by_id", required_fields=["project"])
def get_by_id(call: APICall):
    """
    Return a single project, looked up by ``call.data["project"]``.

    The query is combined with ``get_company_or_none_constraint`` for the
    caller's company (presumably: the caller's own projects plus company-less,
    i.e. public, ones — confirm in apiserver.database.utils).

    :raises errors.bad_request.InvalidProjectId: if no matching project is found.
    """
    # The type is expressed as an annotation (like every other handler here)
    # instead of an `assert isinstance`, which would be stripped under -O.
    project_id = call.data["project"]

    with translate_errors_context():
        with TimingContext("mongo", "projects_by_id"):
            query = Q(id=project_id) & get_company_or_none_constraint(
                call.identity.company
            )
            project = Project.objects(query).first()
        if not project:
            raise errors.bad_request.InvalidProjectId(id=project_id)

        project_dict = project.to_proper_dict()
        conform_output_tags(call, project_dict)

        call.result.data = {"project": project_dict}
@endpoint("projects.get_all_ex", request_data_model=ProjectsGetRequest)
def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
    """
    Return the projects matching the query in ``call.data``, optionally with stats.

    If ``request.active_users`` is set, the query is first narrowed to the
    project ids returned by ``ProjectBLL.get_projects_with_active_user``; when
    that yields nothing, an empty list is returned without further querying.
    If ``request.include_stats`` is set, each returned project dict gets a
    ``stats`` entry from ``ProjectBLL.get_project_stats``.
    """
    conform_tag_fields(call, call.data)
    allow_public = not request.non_public
    with TimingContext("mongo", "projects_get_all"):
        if request.active_users:
            ids = project_bll.get_projects_with_active_user(
                company=company_id,
                users=request.active_users,
                project_ids=call.data.get("id"),
                allow_public=allow_public,
            )
            if not ids:
                call.result.data = {"projects": []}
                return
            # Restrict the main query below to the projects found above.
            call.data["id"] = ids

        projects = Project.get_many_with_join(
            company=company_id,
            query_dict=call.data,
            query_options=get_all_query_options,
            allow_public=allow_public,
        )
        conform_output_tags(call, projects)

        if not request.include_stats:
            call.result.data = {"projects": projects}
            return

        project_ids = {project["id"] for project in projects}
        stats = project_bll.get_project_stats(
            company=company_id,
            project_ids=list(project_ids),
            specific_state=request.stats_for_state,
        )

        for project in projects:
            project["stats"] = stats[project["id"]]

        call.result.data = {"projects": projects}


@endpoint("projects.get_all")
def get_all(call: APICall):
    """Return the projects matching the query in ``call.data`` (public ones included)."""
    conform_tag_fields(call, call.data)
    with translate_errors_context(), TimingContext("mongo", "projects_get_all"):
        projects = Project.get_many(
            company=call.identity.company,
            query_dict=call.data,
            query_options=get_all_query_options,
            parameters=call.data,
            allow_public=True,
        )
        conform_output_tags(call, projects)
        call.result.data = {"projects": projects}


@endpoint(
    "projects.create",
    required_fields=["name", "description"],
    response_data_model=IdResponse,
)
def create(call: APICall):
    """Create a project from the ``create_fields`` in ``call.data`` and return its id."""
    identity = call.identity

    with translate_errors_context():
        fields = parse_from_call(call.data, create_fields, Project.get_fields())
        conform_tag_fields(call, fields, validate=True)

        return IdResponse(
            id=ProjectBLL.create(
                user=identity.user,
                company=identity.company,
                **fields,
            )
        )


@endpoint(
    "projects.update", required_fields=["project"], response_data_model=UpdateResponse
)
def update(call: APICall):
    """
    update

    :summary: Update project information.
              See `projects.create` for parameters.
    :return: updated - `int` - number of projects updated
             fields - `[string]` - updated fields
    """
    project_id = call.data["project"]

    with translate_errors_context():
        project = Project.get_for_writing(company=call.identity.company, id=project_id)
        if not project:
            raise errors.bad_request.InvalidProjectId(id=project_id)

        # discard_none_values=False: explicit None values in the request are kept,
        # so callers can clear a field.
        fields = parse_from_call(
            call.data, create_fields, Project.get_fields(), discard_none_values=False
        )
        conform_tag_fields(call, fields, validate=True)
        fields["last_update"] = datetime.utcnow()
        with TimingContext("mongo", "projects_update"):
            updated = project.update(upsert=False, **fields)
        conform_output_tags(call, fields)
        call.result.data_model = UpdateResponse(updated=updated, fields=fields)


def _reset_cached_tags(company: str, projects: Sequence[str]):
    """Reset the organization's task and model tags for the given projects."""
    org_bll.reset_tags(company, Tags.Task, projects=projects)
    org_bll.reset_tags(company, Tags.Model, projects=projects)


@endpoint("projects.delete", request_data_model=DeleteRequest)
def delete(call: APICall, company_id: str, request: DeleteRequest):
    """
    Delete a project via ``delete_project`` and reset its cached tags.

    ``force`` and ``delete_contents`` are forwarded from the request; the attrs
    fields of the cleanup result are returned as the response data.
    """
    res = delete_project(
        company=company_id,
        project_id=request.project,
        force=request.force,
        delete_contents=request.delete_contents,
    )
    _reset_cached_tags(company_id, projects=[request.project])
    # attr.asdict already returns a new dict; no extra copy is needed.
    call.result.data = attr.asdict(res)


@endpoint("projects.get_unique_metric_variants", request_data_model=ProjectReq)
def get_unique_metric_variants(call: APICall, company_id: str, request: ProjectReq):
    """
    Return the unique metric variants for ``request.project``.

    When no project is given, ``None`` is passed as the project list
    (presumably meaning "no project restriction" — confirm in TaskBLL).
    """
    metrics = task_bll.get_unique_metric_variants(
        company_id, [request.project] if request.project else None
    )

    call.result.data = {"metrics": metrics}


@endpoint(
    "projects.get_hyper_parameters",
    min_version="2.9",
    request_data_model=GetHyperParamReq,
)
def get_hyper_parameters(call: APICall, company_id: str, request: GetHyperParamReq):
    """
    Return one page of hyper-parameters aggregated for ``request.project``.

    The response holds ``total``, ``remaining`` and ``parameters`` as returned by
    ``TaskBLL.get_aggregated_project_parameters``; ``None`` is passed as the
    project list when no project is given.
    """
    total, remaining, parameters = TaskBLL.get_aggregated_project_parameters(
        company_id,
        project_ids=[request.project] if request.project else None,
        page=request.page,
        page_size=request.page_size,
    )

    call.result.data = {
        "total": total,
        "remaining": remaining,
        "parameters": parameters,
    }


@endpoint(
    "projects.get_hyperparam_values",
    min_version="2.13",
    request_data_model=ProjectHyperparamValuesRequest,
)
def get_hyperparam_values(
    call: APICall, company_id: str, request: ProjectHyperparamValuesRequest
):
    """Return the distinct values (and their total) of one hyper-parameter section/name."""
    total, values = task_bll.get_hyperparam_distinct_values(
        company_id,
        project_ids=request.projects,
        section=request.section,
        name=request.name,
        allow_public=request.allow_public,
    )
    call.result.data = {
        "total": total,
        "values": values,
    }


def _get_project_tags(call: APICall, company, request: ProjectTagsRequest, entity):
    """
    Set the response to the tags of ``entity`` (a ``Tags`` member) in the requested projects.

    Shared by the task-tags and model-tags endpoints, which differ only in ``entity``.
    """
    ret = org_bll.get_tags(
        company,
        entity,
        include_system=request.include_system,
        filter_=get_tags_filter_dictionary(request.filter),
        projects=request.projects,
    )
    call.result.data = get_tags_response(ret)


# Handlers are registered by the @endpoint name string; each needs a distinct
# Python name so that one module attribute does not shadow another.
@endpoint(
    "projects.get_task_tags", min_version="2.8", request_data_model=ProjectTagsRequest
)
def get_task_tags(call: APICall, company, request: ProjectTagsRequest):
    """Return the task tags used in the requested projects."""
    _get_project_tags(call, company, request, Tags.Task)


@endpoint(
    "projects.get_model_tags", min_version="2.8", request_data_model=ProjectTagsRequest
)
def get_model_tags(call: APICall, company, request: ProjectTagsRequest):
    """Return the model tags used in the requested projects."""
    _get_project_tags(call, company, request, Tags.Model)


@endpoint(
    "projects.make_public", min_version="2.9", request_data_model=MakePublicRequest
)
def make_public(call: APICall, company_id, request: MakePublicRequest):
    """Mark the projects in ``request.ids`` as public."""
    call.result.data = Project.set_public(
        company_id, ids=request.ids, invalid_cls=InvalidProjectId, enabled=True
    )


@endpoint(
    "projects.make_private", min_version="2.9", request_data_model=MakePublicRequest
)
def make_private(call: APICall, company_id, request: MakePublicRequest):
    """Mark the projects in ``request.ids`` as private (i.e. not public)."""
    call.result.data = Project.set_public(
        company_id, ids=request.ids, invalid_cls=InvalidProjectId, enabled=False
    )


@endpoint(
    "projects.get_task_parents",
    min_version="2.12",
    request_data_model=ProjectTaskParentsRequest,
)
def get_task_parents(
    call: APICall, company_id: str, request: ProjectTaskParentsRequest
):
    """Return the parent tasks of tasks in the requested projects, filtered by ``tasks_state``."""
    call.result.data = {
        "parents": org_bll.get_parent_tasks(
            company_id, projects=request.projects, state=request.tasks_state
        )
    }