diff --git a/apiserver/schema/services/models.conf b/apiserver/schema/services/models.conf index 3d5a713..ae44395 100644 --- a/apiserver/schema/services/models.conf +++ b/apiserver/schema/services/models.conf @@ -395,6 +395,17 @@ get_all { description: "Scroll ID that can be used with the next calls to get_all to retrieve more data" } } + "999.0": ${get_all."2.15"} { + request { + properties { + include_subprojects { + description: "If set to 'true' and project field is set then models from the subprojects are searched too" + type: boolean + default: false + } + } + } + } } get_frameworks { "2.8" { diff --git a/apiserver/schema/services/tasks.conf b/apiserver/schema/services/tasks.conf index 462ab72..4d9526a 100644 --- a/apiserver/schema/services/tasks.conf +++ b/apiserver/schema/services/tasks.conf @@ -334,6 +334,17 @@ get_all { description: "Scroll ID that can be used with the next calls to get_all to retrieve more data" } } + "999.0": ${get_all."2.15"} { + request { + properties { + include_subprojects { + description: "If set to 'true' and project field is set then tasks from the subprojects are searched too" + type: boolean + default: false + } + } + } + } } get_types { "2.8" { diff --git a/apiserver/services/models.py b/apiserver/services/models.py index a51aca5..1816520 100644 --- a/apiserver/services/models.py +++ b/apiserver/services/models.py @@ -67,11 +67,10 @@ def conform_model_data(call: APICall, model_data: Union[Sequence[dict], dict]): @endpoint("models.get_by_id", required_fields=["model"]) def get_by_id(call: APICall, company_id, _): model_id = call.data["model"] - - Metadata.escape_query_parameters(call) + call_data = Metadata.escape_query_parameters(call) models = Model.get_many( company=company_id, - query_dict=call.data, + query_dict=call_data, query=Q(id=model_id), allow_public=True, ) @@ -113,12 +112,12 @@ def get_by_task_id(call: APICall, company_id, _): @endpoint("models.get_all_ex", request_data_model=ModelsGetRequest) def get_all_ex(call: APICall, company_id, request: ModelsGetRequest): conform_tag_fields(call, call.data) - process_include_subprojects(call.data) - Metadata.escape_query_parameters(call) + call_data = Metadata.escape_query_parameters(call) + process_include_subprojects(call_data) ret_params = {} models = Model.get_many_with_join( company=company_id, - query_dict=call.data, + query_dict=call_data, allow_public=request.allow_public, ret_params=ret_params, ) @@ -140,9 +139,9 @@ def get_all_ex(call: APICall, company_id, request: ModelsGetRequest): @endpoint("models.get_by_id_ex", required_fields=["id"]) def get_by_id_ex(call: APICall, company_id, _): conform_tag_fields(call, call.data) - Metadata.escape_query_parameters(call) + call_data = Metadata.escape_query_parameters(call) models = Model.get_many_with_join( - company=company_id, query_dict=call.data, allow_public=True + company=company_id, query_dict=call_data, allow_public=True ) conform_model_data(call, models) call.result.data = {"models": models} @@ -151,12 +150,13 @@ def get_by_id_ex(call: APICall, company_id, _): @endpoint("models.get_all", required_fields=[]) def get_all(call: APICall, company_id, _): conform_tag_fields(call, call.data) - Metadata.escape_query_parameters(call) + call_data = Metadata.escape_query_parameters(call) + process_include_subprojects(call_data) ret_params = {} models = Model.get_many( company=company_id, - parameters=call.data, - query_dict=call.data, + parameters=call_data, + query_dict=call_data, allow_public=True, ret_params=ret_params, ) diff --git a/apiserver/services/pipelines.py b/apiserver/services/pipelines.py index ab10a69..0767489 100644 --- a/apiserver/services/pipelines.py +++ b/apiserver/services/pipelines.py @@ -1,12 +1,21 @@ import re +from functools import partial -from apiserver.apimodels.pipelines import StartPipelineResponse, StartPipelineRequest +import attr + +from apiserver.apierrors.errors.bad_request import CannotRemoveAllRuns +from apiserver.apimodels.pipelines import ( + StartPipelineResponse, + StartPipelineRequest, + DeleteRunsRequest, +) from apiserver.bll.organization import OrgBLL from apiserver.bll.project import ProjectBLL from apiserver.bll.task import TaskBLL -from apiserver.bll.task.task_operations import enqueue_task +from apiserver.bll.task.task_operations import enqueue_task, delete_task +from apiserver.bll.util import run_batch_operation from apiserver.database.model.project import Project -from apiserver.database.model.task.task import Task +from apiserver.database.model.task.task import Task, TaskType from apiserver.service_repo import APICall, endpoint org_bll = OrgBLL() @@ -31,6 +40,45 @@ def _update_task_name(task: Task): task.update(name=new_name) +@endpoint("pipelines.delete_runs") +def delete_runs(call: APICall, company_id: str, request: DeleteRunsRequest): + existing_runs = set( + Task.objects(project=request.project, type=TaskType.controller).scalar("id") + ) + if not existing_runs.difference(request.ids): + raise CannotRemoveAllRuns(project=request.project) + + # make sure that only controller tasks are deleted + ids = existing_runs.intersection(request.ids) + if not ids: + return dict(succeeded=[], failed=[]) + + results, failures = run_batch_operation( + func=partial( + delete_task, + company_id=company_id, + user_id=call.identity.user, + move_to_trash=False, + force=True, + return_file_urls=False, + delete_output_models=True, + status_message="", + status_reason="Pipeline run deleted", + delete_external_artifacts=True, + ), + ids=list(ids), + ) + + succeeded = [] + if results: + for _id, (deleted, task, cleanup_res) in results: + succeeded.append( + dict(id=_id, deleted=bool(deleted), **attr.asdict(cleanup_res)) + ) + + call.result.data = dict(succeeded=succeeded, failed=failures) + + @endpoint( "pipelines.start_pipeline", response_data_model=StartPipelineResponse, ) diff --git a/apiserver/services/queues.py b/apiserver/services/queues.py index 27188a2..54e255a 100644 --- a/apiserver/services/queues.py +++ b/apiserver/services/queues.py @@ -83,11 +83,11 @@ def get_all_ex(call: APICall, company: str, request: GetAllRequest): conform_tag_fields(call, call.data) ret_params = {} - Metadata.escape_query_parameters(call) + call_data = Metadata.escape_query_parameters(call) queues = queue_bll.get_queue_infos( company_id=company, - query_dict=call.data, - query=_hidden_query(call.data), + query_dict=call_data, + query=_hidden_query(call_data), max_task_entries=request.max_task_entries, ret_params=ret_params, ) @@ -99,11 +99,11 @@ def get_all_ex(call: APICall, company: str, request: GetAllRequest): def get_all(call: APICall, company: str, request: GetAllRequest): conform_tag_fields(call, call.data) ret_params = {} - Metadata.escape_query_parameters(call) + call_data = Metadata.escape_query_parameters(call) queues = queue_bll.get_all( company_id=company, - query_dict=call.data, - query=_hidden_query(call.data), + query_dict=call_data, + query=_hidden_query(call_data), max_task_entries=request.max_task_entries, ret_params=ret_params, ) diff --git a/apiserver/services/tasks.py b/apiserver/services/tasks.py index 776f983..9f0961d 100644 --- a/apiserver/services/tasks.py +++ b/apiserver/services/tasks.py @@ -204,9 +204,7 @@ def _hidden_query(data: dict) -> Q: @endpoint("tasks.get_all_ex") def get_all_ex(call: APICall, company_id, request: GetAllReq): conform_tag_fields(call, call.data) - call_data = escape_execution_parameters(call) - process_include_subprojects(call_data) ret_params = {} tasks = Task.get_many_with_join( @@ -223,9 +221,7 @@ def get_all_ex(call: APICall, company_id, request: GetAllReq): @endpoint("tasks.get_by_id_ex", required_fields=["id"]) def get_by_id_ex(call: APICall, company_id, _): conform_tag_fields(call, call.data) - call_data = escape_execution_parameters(call) - tasks = Task.get_many_with_join( company=company_id, query_dict=call_data, allow_public=True, ) @@ -237,8 +233,8 @@ def get_by_id_ex(call: APICall, company_id, _): @endpoint("tasks.get_all", required_fields=[]) def get_all(call: APICall, company_id, _): conform_tag_fields(call, call.data) - call_data = escape_execution_parameters(call) + process_include_subprojects(call_data) ret_params = {} tasks = Task.get_many( diff --git a/apiserver/tests/automated/test_subprojects.py b/apiserver/tests/automated/test_subprojects.py index 487d179..54c2c12 100644 --- a/apiserver/tests/automated/test_subprojects.py +++ b/apiserver/tests/automated/test_subprojects.py @@ -342,6 +342,32 @@ class TestSubProjects(TestService): self.assertEqual([p.id for p in res], [project2]) self.api.projects.delete(project=project1, force=True) + def test_include_subprojects(self): + project1, _ = self._temp_project_with_tasks(name="project1x") + project2, _ = self._temp_project_with_tasks(name="project1x/project22") + self._temp_model(project=project1) + self._temp_model(project=project2) + + # tasks + res = self.api.tasks.get_all_ex(project=project1).tasks + self.assertEqual(len(res), 2) + res = self.api.tasks.get_all(project=project1).tasks + self.assertEqual(len(res), 2) + res = self.api.tasks.get_all_ex(project=project1, include_subprojects=True).tasks + self.assertEqual(len(res), 4) + res = self.api.tasks.get_all(project=project1, include_subprojects=True).tasks + self.assertEqual(len(res), 4) + + # models + res = self.api.models.get_all_ex(project=project1).models + self.assertEqual(len(res), 1) + res = self.api.models.get_all(project=project1).models + self.assertEqual(len(res), 1) + res = self.api.models.get_all_ex(project=project1, include_subprojects=True).models + self.assertEqual(len(res), 2) + res = self.api.models.get_all(project=project1, include_subprojects=True).models + self.assertEqual(len(res), 2) + def test_get_all_with_check_own_contents(self): project1, _ = self._temp_project_with_tasks(name="project1x") project2 = self._temp_project(name="project2x")