diff --git a/apiserver/apierrors/errors.conf b/apiserver/apierrors/errors.conf index a78d780..e604822 100644 --- a/apiserver/apierrors/errors.conf +++ b/apiserver/apierrors/errors.conf @@ -1,3 +1,8 @@ +301 { + _: "moved_permanently" + 1: ["not_supported", "this endpoint is no longer supported for the requested API version"] +} + 400 { _: "bad_request" 1: ["not_supported", "endpoint is not supported"] diff --git a/apiserver/apimodels/projects.py b/apiserver/apimodels/projects.py index 454c5dc..fcc7f66 100644 --- a/apiserver/apimodels/projects.py +++ b/apiserver/apimodels/projects.py @@ -22,7 +22,12 @@ class DeleteRequest(ProjectRequest): delete_contents = fields.BoolField(default=False) -class GetHyperParamRequest(ProjectRequest): +class ProjectOrNoneRequest(models.Base): + project = fields.StringField() + include_subprojects = fields.BoolField(default=True) + + +class GetHyperParamRequest(ProjectOrNoneRequest): page = fields.IntField(default=0) page_size = fields.IntField(default=500) @@ -33,6 +38,7 @@ class ProjectTagsRequest(TagsRequest): class MultiProjectRequest(models.Base): projects = fields.ListField(str) + include_subprojects = fields.BoolField(default=True) class ProjectTaskParentsRequest(MultiProjectRequest): diff --git a/apiserver/bll/project/project_bll.py b/apiserver/bll/project/project_bll.py index 236bb56..c692e4b 100644 --- a/apiserver/bll/project/project_bll.py +++ b/apiserver/bll/project/project_bll.py @@ -147,7 +147,7 @@ class ProjectBLL: user: str, company: str, name: str, - description: str, + description: str = "", tags: Sequence[str] = None, system_tags: Sequence[str] = None, default_output_destination: str = None, @@ -507,20 +507,24 @@ class ProjectBLL: company, project_ids: Sequence[str], user_ids: Optional[Sequence[str]] = None, - ) -> set: + ) -> Set[str]: """ Get the set of user ids that created tasks/models/dataviews in the given projects If project_ids is empty then all projects are examined If user_ids are passed then only subset of these users is returned """ with TimingContext("mongo", "active_users_in_projects"): - res = set() query = Q(company=company) + if user_ids: + query &= Q(user__in=user_ids) + + projects_query = query if project_ids: project_ids = _ids_with_children(project_ids) query &= Q(project__in=project_ids) - if user_ids: - query &= Q(user__in=user_ids) + projects_query &= Q(id__in=project_ids) + + res = set(Project.objects(projects_query).distinct(field="user")) for cls_ in (Task, Model): res |= set(cls_.objects(query).distinct(field="user")) @@ -545,10 +549,17 @@ class ProjectBLL: else: query &= Q(company=company) + user_projects_query = query if project_ids: - query &= Q(project__in=_ids_with_children(project_ids)) + ids_with_children = _ids_with_children(project_ids) + query &= Q(project__in=ids_with_children) + user_projects_query &= Q(id__in=ids_with_children) - res = Task.objects(query).distinct(field="project") + res = {p.id for p in Project.objects(user_projects_query).only("id")} + for cls_ in (Task, Model): + res |= set(cls_.objects(query).distinct(field="project")) + + res = list(res) if not res: return res @@ -563,6 +574,7 @@ class ProjectBLL: cls, company_id: str, projects: Sequence[str], + include_subprojects: bool, state: Optional[EntityVisibility] = None, ) -> Sequence[dict]: """ @@ -571,7 +583,8 @@ class ProjectBLL: """ query = Q(company=company_id) if projects: - projects = _ids_with_children(projects) + if include_subprojects: + projects = _ids_with_children(projects) query &= Q(project__in=projects) if state == EntityVisibility.archived: query &= Q(system_tags__in=[EntityVisibility.archived.value]) diff --git a/apiserver/bll/task/task_bll.py b/apiserver/bll/task/task_bll.py index 1335786..d3b2529 100644 --- a/apiserver/bll/task/task_bll.py +++ b/apiserver/bll/task/task_bll.py @@ -13,7 +13,7 @@ import apiserver.database.utils as dbutils from apiserver.apierrors import errors from apiserver.bll.queue import QueueBLL from apiserver.bll.organization import OrgBLL, Tags -from apiserver.bll.project import ProjectBLL +from apiserver.bll.project import ProjectBLL, project_ids_with_children from apiserver.config_repo import config from apiserver.database.errors import translate_errors_context from apiserver.database.model.model import Model @@ -37,7 +37,12 @@ from apiserver.timing_context import TimingContext from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper from .artifacts import artifacts_prepare_for_save from .param_utils import params_prepare_for_save -from .utils import ChangeStatusRequest, validate_status_change, update_project_time, deleted_prefix +from .utils import ( + ChangeStatusRequest, + validate_status_change, + update_project_time, + deleted_prefix, +) log = config.logger(__file__) org_bll = OrgBLL() @@ -317,12 +322,19 @@ class TaskBLL: cls.validate_execution_model(task) @staticmethod - def get_unique_metric_variants(company_id, project_ids=None): + def get_unique_metric_variants( + company_id, project_ids: Sequence[str], include_subprojects: bool + ): + if project_ids: + if include_subprojects: + project_ids = project_ids_with_children(project_ids) + project_constraint = {"project": {"$in": project_ids}} + else: + project_constraint = {} pipeline = [ { "$match": dict( - company={"$in": [None, "", company_id]}, - **({"project": {"$in": project_ids}} if project_ids else {}), + company={"$in": [None, "", company_id]}, **project_constraint, ) }, {"$project": {"metrics": {"$objectToArray": "$last_metrics"}}}, @@ -601,11 +613,17 @@ class TaskBLL: @staticmethod def get_aggregated_project_parameters( company_id, - project_ids: Sequence[str] = None, + project_ids: Sequence[str], + include_subprojects: bool, page: int = 0, page_size: int = 500, ) -> Tuple[int, int, Sequence[dict]]: - + if project_ids: + if include_subprojects: + project_ids = project_ids_with_children(project_ids) + project_constraint = {"project": {"$in": project_ids}} + else: + project_constraint = {} page = max(0, page) page_size = max(1, page_size) pipeline = [ @@ -613,7 +631,7 @@ class TaskBLL: "$match": { "company": {"$in": [None, "", company_id]}, "hyperparams": {"$exists": True, "$gt": {}}, - **({"project": {"$in": project_ids}} if project_ids else {}), + **project_constraint, } }, {"$project": {"sections": {"$objectToArray": "$hyperparams"}}}, @@ -687,6 +705,7 @@ class TaskBLL: project_ids: Sequence[str], section: str, name: str, + include_subprojects: bool, allow_public: bool = True, ) -> HyperParamValues: if allow_public: @@ -694,6 +713,8 @@ class TaskBLL: else: company_constraint = {"company": company_id} if project_ids: + if include_subprojects: + project_ids = project_ids_with_children(project_ids) project_constraint = {"project": {"$in": project_ids}} else: project_constraint = {} diff --git a/apiserver/database/model/base.py b/apiserver/database/model/base.py index edc37ba..b201f86 100644 --- a/apiserver/database/model/base.py +++ b/apiserver/database/model/base.py @@ -193,8 +193,7 @@ class GetMixin(PropsMixin): """ Pop the parameters that match the specified patterns and return the dictionary of matching parameters - For backwards compatibility with the previous version of the code - the None values are discarded + Pop None parameters since they are not the real queries """ if not patterns: return {} @@ -351,11 +350,7 @@ class GetMixin(PropsMixin): q = RegexQ() for action in filter(None, actions): q &= RegexQ( - **{ - f"{mongoengine_field}__{action}": list( - set(filter(None, actions[action])) - ) - } + **{f"{mongoengine_field}__{action}": list(set(actions[action]))} ) if not allow_empty: diff --git a/apiserver/database/model/project.py b/apiserver/database/model/project.py index e4ff6d4..7588806 100644 --- a/apiserver/database/model/project.py +++ b/apiserver/database/model/project.py @@ -36,7 +36,7 @@ class Project(AttributedDocument): min_length=3, sparse=True, ) - description = StringField(required=True) + description = StringField() created = DateTimeField(required=True) tags = SafeSortedListField(StringField(required=True)) system_tags = SafeSortedListField(StringField(required=True)) diff --git a/apiserver/database/model/task/task.py b/apiserver/database/model/task/task.py index 0b68af7..28d09fe 100644 --- a/apiserver/database/model/task/task.py +++ b/apiserver/database/model/task/task.py @@ -115,7 +115,7 @@ class Execution(EmbeddedDocument, ProperDictMixin): framework = StringField() artifacts: Dict[str, Artifact] = SafeMapField(field=EmbeddedDocumentField(Artifact)) docker_cmd = StringField() - queue = StringField() + queue = StringField(reference_field="Queue") """ Queue ID where task was queued """ diff --git a/apiserver/schema/services/models.conf b/apiserver/schema/services/models.conf index d7fd4d8..2bfaea3 100644 --- a/apiserver/schema/services/models.conf +++ b/apiserver/schema/services/models.conf @@ -151,6 +151,17 @@ get_by_id_ex { get_all_ex { internal: true "2.1": ${get_all."2.1"} + "2.13": ${get_all_ex."2.1"} { + request { + properties { + include_subprojects { + description: "If set to 'true' and project field is set then models from the subprojects are searched too" + type: boolean + default: false + } + } + } + } } get_all { "2.1" { diff --git a/apiserver/schema/services/projects.conf b/apiserver/schema/services/projects.conf index 6a76c98..c4d9a9a 100644 --- a/apiserver/schema/services/projects.conf +++ b/apiserver/schema/services/projects.conf @@ -283,17 +283,14 @@ create { description: "Create a new project" request { type: object - required :[ - name - description - ] + required :[name] properties { name { description: "Project name Unique within the company." type: string } description { - description: "Project description. " + description: "Project description." type: string } tags { @@ -677,6 +674,17 @@ get_unique_metric_variants { } } } + "2.13": ${get_unique_metric_variants."2.1"} { + request { + properties { + include_subprojects { + description: "If set to 'true' and the project field is set then the result includes metrics/variants from the subproject tasks" + type: boolean + default: true + } + } + } + } } get_hyperparam_values { "2.13" { @@ -702,6 +710,11 @@ get_hyperparam_values { description: "If set to 'true' then collect values from both company and public tasks otherwise company tasks only. The default is 'true'" type: boolean } + include_subprojects { + description: "If set to 'true' and the project field is set then the result includes hyper parameters values from the subproject tasks" + type: boolean + default: true + } } } response { @@ -761,6 +774,17 @@ get_hyper_parameters { } } } + "2.13": ${get_hyper_parameters."2.9"} { + request { + properties { + include_subprojects { + description: "If set to 'true' and the project field is set then the result includes hyper parameters from the subproject tasks" + type: boolean + default: true + } + } + } + } } get_task_tags { @@ -830,7 +854,7 @@ make_private { } get_task_parents { "2.12" { - description: "Get unique parent tasks for the tasks in the specified pprojects" + description: "Get unique parent tasks for the tasks in the specified projects" request { type: object properties { @@ -881,4 +905,15 @@ get_task_parents { } } } + "2.13": ${get_task_parents."2.12"} { + request { + properties { + include_subprojects { + description: "If set to 'true' and the projects field is not empty then the result includes tasks parents from the subproject tasks" + type: boolean + default: true + } + } + } + } } \ No newline at end of file diff --git a/apiserver/schema/services/tasks.conf b/apiserver/schema/services/tasks.conf index c305735..a178290 100644 --- a/apiserver/schema/services/tasks.conf +++ b/apiserver/schema/services/tasks.conf @@ -583,6 +583,17 @@ get_by_id_ex { get_all_ex { internal: true "2.1": ${get_all."2.1"} + "2.13": ${get_all_ex."2.1"} { + request { + properties { + include_subprojects { + description: "If set to 'true' and project field is set then tasks from the subprojects are searched too" + type: boolean + default: false + } + } + } + } } get_all { "2.1" { diff --git a/apiserver/services/models.py b/apiserver/services/models.py index 1ac3138..c1988c3 100644 --- a/apiserver/services/models.py +++ b/apiserver/services/models.py @@ -17,7 +17,7 @@ from apiserver.apimodels.models import ( DeleteModelRequest, ) from apiserver.bll.organization import OrgBLL, Tags -from apiserver.bll.project import ProjectBLL +from apiserver.bll.project import ProjectBLL, project_ids_with_children from apiserver.bll.task import TaskBLL from apiserver.bll.task.utils import deleted_prefix from apiserver.config_repo import config @@ -86,10 +86,22 @@ def get_by_task_id(call: APICall, company_id, _): call.result.data = {"model": model_dict} +def _process_include_subprojects(call_data: dict): + include_subprojects = call_data.pop("include_subprojects", False) + project_ids = call_data.get("project") + if not project_ids or not include_subprojects: + return + + if not isinstance(project_ids, list): + project_ids = [project_ids] + call_data["project"] = project_ids_with_children(project_ids) + + @endpoint("models.get_all_ex", required_fields=[]) def get_all_ex(call: APICall, company_id, _): conform_tag_fields(call, call.data) with translate_errors_context(): + _process_include_subprojects(call.data) with TimingContext("mongo", "models_get_all_ex"): models = Model.get_many_with_join( company=company_id, query_dict=call.data, allow_public=True diff --git a/apiserver/services/projects.py b/apiserver/services/projects.py index 658f97a..eec2177 100644 --- a/apiserver/services/projects.py +++ b/apiserver/services/projects.py @@ -8,12 +8,14 @@ from apiserver.apierrors.errors.bad_request import InvalidProjectId from apiserver.apimodels.base import UpdateResponse, MakePublicRequest, IdResponse from apiserver.apimodels.projects import ( GetHyperParamRequest, - ProjectRequest, ProjectTagsRequest, ProjectTaskParentsRequest, ProjectHyperparamValuesRequest, ProjectsGetRequest, - DeleteRequest, MoveRequest, MergeRequest, + DeleteRequest, + MoveRequest, + MergeRequest, + ProjectOrNoneRequest, ) from apiserver.bll.organization import OrgBLL, Tags from apiserver.bll.project import ProjectBLL @@ -80,14 +82,13 @@ def _adjust_search_parameters(data: dict, shallow_search: bool): return if "parent" not in data: - data["parent"] = [None, ""] + data["parent"] = [None] @endpoint("projects.get_all_ex", request_data_model=ProjectsGetRequest) def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest): conform_tag_fields(call, call.data) allow_public = not request.non_public - shallow_search = request.shallow_search or request.include_stats with TimingContext("mongo", "projects_get_all"): data = call.data if request.active_users: @@ -102,12 +103,10 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest): return data["id"] = ids - _adjust_search_parameters(data, shallow_search=shallow_search) + _adjust_search_parameters(data, shallow_search=request.shallow_search) projects = Project.get_many_with_join( - company=company_id, - query_dict=data, - allow_public=allow_public, + company=company_id, query_dict=data, allow_public=allow_public, ) conform_output_tags(call, projects) @@ -147,9 +146,7 @@ def get_all(call: APICall): @endpoint( - "projects.create", - required_fields=["name", "description"], - response_data_model=IdResponse, + "projects.create", required_fields=["name"], response_data_model=IdResponse, ) def create(call: APICall): identity = call.identity @@ -232,11 +229,17 @@ def delete(call: APICall, company_id: str, request: DeleteRequest): call.result.data = {**attr.asdict(res)} -@endpoint("projects.get_unique_metric_variants", request_data_model=ProjectRequest) -def get_unique_metric_variants(call: APICall, company_id: str, request: ProjectRequest): +@endpoint( + "projects.get_unique_metric_variants", request_data_model=ProjectOrNoneRequest +) +def get_unique_metric_variants( + call: APICall, company_id: str, request: ProjectOrNoneRequest +): metrics = task_bll.get_unique_metric_variants( - company_id, [request.project] if request.project else None + company_id, + [request.project] if request.project else None, + include_subprojects=request.include_subprojects, ) call.result.data = {"metrics": metrics} @@ -252,6 +255,7 @@ def get_hyper_parameters(call: APICall, company_id: str, request: GetHyperParamR total, remaining, parameters = TaskBLL.get_aggregated_project_parameters( company_id, project_ids=[request.project] if request.project else None, + include_subprojects=request.include_subprojects, page=request.page, page_size=request.page_size, ) @@ -276,6 +280,7 @@ def get_hyperparam_values( project_ids=request.projects, section=request.section, name=request.name, + include_subprojects=request.include_subprojects, allow_public=request.allow_public, ) call.result.data = { @@ -340,6 +345,9 @@ def get_task_parents( ): call.result.data = { "parents": project_bll.get_task_parents( - company_id, projects=request.projects, state=request.tasks_state + company_id, + projects=request.projects, + include_subprojects=request.include_subprojects, + state=request.tasks_state, ) } diff --git a/apiserver/services/tasks.py b/apiserver/services/tasks.py index 93bc6d0..b926972 100644 --- a/apiserver/services/tasks.py +++ b/apiserver/services/tasks.py @@ -47,7 +47,7 @@ from apiserver.apimodels.tasks import ( ) from apiserver.bll.event import EventBLL from apiserver.bll.organization import OrgBLL, Tags -from apiserver.bll.project import ProjectBLL +from apiserver.bll.project import ProjectBLL, project_ids_with_children from apiserver.bll.queue import QueueBLL from apiserver.bll.task import ( TaskBLL, @@ -152,27 +152,49 @@ def get_by_id(call: APICall, company_id, req_model: TaskRequest): call.result.data = {"task": task_dict} -def escape_execution_parameters(call: APICall): - projection = Task.get_projection(call.data) - if projection: - Task.set_projection(call.data, escape_paths(projection)) +def escape_execution_parameters(call: APICall) -> dict: + if not call.data: + return call.data - ordering = Task.get_ordering(call.data) + keys = list(call.data) + call_data = { + safe_key: call.data[key] for key, safe_key in zip(keys, escape_paths(keys)) + } + + projection = Task.get_projection(call_data) + if projection: + Task.set_projection(call_data, escape_paths(projection)) + + ordering = Task.get_ordering(call_data) if ordering: - Task.set_ordering(call.data, escape_paths(ordering)) + Task.set_ordering(call_data, escape_paths(ordering)) + + return call_data + + +def _process_include_subprojects(call_data: dict): + include_subprojects = call_data.pop("include_subprojects", False) + project_ids = call_data.get("project") + if not project_ids or not include_subprojects: + return + + if not isinstance(project_ids, list): + project_ids = [project_ids] + call_data["project"] = project_ids_with_children(project_ids) @endpoint("tasks.get_all_ex", required_fields=[]) def get_all_ex(call: APICall, company_id, _): conform_tag_fields(call, call.data) - escape_execution_parameters(call) + call_data = escape_execution_parameters(call) with translate_errors_context(): with TimingContext("mongo", "task_get_all_ex"): + _process_include_subprojects(call_data) tasks = Task.get_many_with_join( company=company_id, - query_dict=call.data, + query_dict=call_data, allow_public=True, # required in case projection is requested for public dataset/versions ) unprepare_from_saved(call, tasks) @@ -183,12 +205,12 @@ def get_all_ex(call: APICall, company_id, _): def get_by_id_ex(call: APICall, company_id, _): conform_tag_fields(call, call.data) - escape_execution_parameters(call) + call_data = escape_execution_parameters(call) with translate_errors_context(): with TimingContext("mongo", "task_get_by_id_ex"): tasks = Task.get_many_with_join( - company=company_id, query_dict=call.data, allow_public=True, + company=company_id, query_dict=call_data, allow_public=True, ) unprepare_from_saved(call, tasks) @@ -199,14 +221,14 @@ def get_by_id_ex(call: APICall, company_id, _): def get_all(call: APICall, company_id, _): conform_tag_fields(call, call.data) - escape_execution_parameters(call) + call_data = escape_execution_parameters(call) with translate_errors_context(): with TimingContext("mongo", "task_get_all"): tasks = Task.get_many( company=company_id, - parameters=call.data, - query_dict=call.data, + parameters=call_data, + query_dict=call_data, allow_public=True, # required in case projection is requested for public dataset/versions ) unprepare_from_saved(call, tasks) @@ -216,7 +238,9 @@ def get_all(call: APICall, company_id, _): @endpoint("tasks.get_types", request_data_model=GetTypesRequest) def get_types(call: APICall, company_id, request: GetTypesRequest): call.result.data = { - "types": list(project_bll.get_task_types(company_id, project_ids=request.projects)) + "types": list( + project_bll.get_task_types(company_id, project_ids=request.projects) + ) } diff --git a/apiserver/tests/api_client.py b/apiserver/tests/api_client.py index b67a733..2c81f4e 100644 --- a/apiserver/tests/api_client.py +++ b/apiserver/tests/api_client.py @@ -1,5 +1,4 @@ import json -import logging import os import time from contextlib import contextmanager @@ -10,16 +9,14 @@ import requests import six from boltons.iterutils import remap from boltons.typeutils import issubclass -from pyhocon import ConfigFactory from requests.adapters import HTTPAdapter from requests.auth import HTTPBasicAuth from requests.packages.urllib3.util.retry import Retry from apiserver.apierrors.base import BaseError +from apiserver.config_repo import config -config = ConfigFactory.parse_file("api_client.conf") - -log = logging.getLogger("api_client") +log = config.logger(__file__) class APICallResult: @@ -111,7 +108,7 @@ class APIClient: self.api_key = ( api_key or os.environ.get("SM_API_KEY") - or config.get("api_key") + or config.get("apiclient.api_key") ) if not self.api_key: raise ValueError("APIClient requires api_key in constructor or config") @@ -119,7 +116,7 @@ class APIClient: self.secret_key = ( secret_key or os.environ.get("SM_API_SECRET") - or config.get("secret_key") + or config.get("apiclient.secret_key") ) if not self.secret_key: raise ValueError( @@ -127,7 +124,7 @@ class APIClient: ) self.base_url = ( - base_url or os.environ.get("SM_API_URL") or config.get("base_url") + base_url or os.environ.get("SM_API_URL") or config.get("apiclient.base_url") ) if not self.base_url: raise ValueError("APIClient requires base_url in constructor or config") @@ -139,9 +136,9 @@ class APIClient: # create http session self.http_session = requests.session() - retries = config.get("retries", 7) - backoff_factor = config.get("backoff_factor", 0.3) - status_forcelist = config.get("status_forcelist", (500, 502, 504)) + retries = config.get("apiclient.retries", 7) + backoff_factor = config.get("apiclient.backoff_factor", 0.3) + status_forcelist = config.get("apiclient.status_forcelist", (500, 502, 504)) retry = Retry( total=retries, read=retries, diff --git a/apiserver/tests/automated/test_projects_retrieval.py b/apiserver/tests/automated/test_projects_retrieval.py deleted file mode 100644 index effa739..0000000 --- a/apiserver/tests/automated/test_projects_retrieval.py +++ /dev/null @@ -1,65 +0,0 @@ -from boltons.iterutils import first - -from apiserver.tests.automated import TestService - - -class TestProjectsRetrieval(TestService): - def setUp(self, **kwargs): - super().setUp(version="2.13") - - def test_active_user(self): - user = self.api.users.get_current_user().user.id - project1 = self.temp_project(name="Project retrieval1") - project2 = self.temp_project(name="Project retrieval2") - self.temp_task(project=project2) - - projects = self.api.projects.get_all_ex().projects - self.assertTrue({project1, project2}.issubset({p.id for p in projects})) - - projects = self.api.projects.get_all_ex(active_users=[user]).projects - ids = {p.id for p in projects} - self.assertFalse(project1 in ids) - self.assertTrue(project2 in ids) - - def test_stats(self): - project = self.temp_project() - self.temp_task(project=project) - self.temp_task(project=project) - archived_task = self.temp_task(project=project) - self.api.tasks.archive(tasks=[archived_task]) - - p = self._get_project(project) - self.assertFalse("stats" in p) - - p = self._get_project(project, include_stats=True) - self.assertFalse("archived" in p.stats) - self.assertTrue(p.stats.active.status_count.created, 2) - - p = self._get_project(project, include_stats=True, stats_for_state=None) - self.assertTrue(p.stats.active.status_count.created, 2) - self.assertTrue(p.stats.archived.status_count.created, 1) - - def _get_project(self, project, **kwargs): - projects = self.api.projects.get_all_ex(id=[project], **kwargs).projects - p = first(p for p in projects if p.id == project) - self.assertIsNotNone(p) - return p - - def temp_project(self, **kwargs) -> str: - self.update_missing( - kwargs, - name="Test projects retrieval", - description="test", - delete_params=dict(force=True), - ) - return self.create_temp("projects", **kwargs) - - def temp_task(self, **kwargs) -> str: - self.update_missing( - kwargs, - type="testing", - name="test projects retrieval", - input=dict(view=dict()), - delete_params=dict(force=True), - ) - return self.create_temp("tasks", **kwargs) diff --git a/apiserver/tests/automated/test_subprojects.py b/apiserver/tests/automated/test_subprojects.py index 4a9446a..87bb753 100644 --- a/apiserver/tests/automated/test_subprojects.py +++ b/apiserver/tests/automated/test_subprojects.py @@ -4,8 +4,10 @@ from typing import Sequence, Optional, Tuple from boltons.iterutils import first from apiserver.apierrors import errors +from apiserver.config_repo import config from apiserver.database.model import EntityVisibility from apiserver.database.utils import id as db_id +from apiserver.tests.api_client import APIClient from apiserver.tests.automated import TestService @@ -14,7 +16,14 @@ class TestSubProjects(TestService): super().setUp(version="2.13") def test_project_aggregations(self): - child = self._temp_project(name="Aggregation/Pr1") + """This test requires user with user_auth_only... credentials in db""" + user2_client = APIClient( + api_key=config.get("apiclient.user_auth_only"), + secret_key=config.get("apiclient.user_auth_only_secret"), + base_url=f"http://localhost:8008/v2.13", + ) + + child = self._temp_project(name="Aggregation/Pr1", client=user2_client) project = self.api.projects.get_all_ex(name="^Aggregation$").projects[0].id child_project = self.api.projects.get_all_ex(id=[child]).projects[0] self.assertEqual(child_project.parent.id, project) @@ -210,12 +219,13 @@ class TestSubProjects(TestService): delete_params = dict(can_fail=True, force=True) - def _temp_project(self, name, **kwargs): + def _temp_project(self, name, client=None, **kwargs): return self.create_temp( "projects", delete_params=self.delete_params, name=name, description="", + client=client, **kwargs, ) diff --git a/apiserver/tests/automated/test_tasks_edit.py b/apiserver/tests/automated/test_tasks_edit.py index 54c26c7..d6ac6e2 100644 --- a/apiserver/tests/automated/test_tasks_edit.py +++ b/apiserver/tests/automated/test_tasks_edit.py @@ -225,4 +225,4 @@ class TestTasksEdit(TestService): self.api.tasks.enqueue(task=task_id, queue=queue_id) task = self.api.tasks.get_all_ex(id=task_id, projection=projection).tasks[0] self.assertEqual(task.status, "queued") - self.assertEqual(task.execution.queue, queue_id) + self.assertEqual(task.execution.queue.id, queue_id)