Add include_subprojects to tasks/models.get_all endpoints

Fix escaping metadata for tasks, models and queues
This commit is contained in:
allegroai 2023-07-26 18:24:49 +03:00
parent a83a932e84
commit 8135cf5258
7 changed files with 117 additions and 25 deletions

View File

@ -395,6 +395,17 @@ get_all {
description: "Scroll ID that can be used with the next calls to get_all to retrieve more data" description: "Scroll ID that can be used with the next calls to get_all to retrieve more data"
} }
} }
"999.0": ${get_all."2.15"} {
request {
properties {
include_subprojects {
description: "If set to 'true' and project field is set then models from the subprojects are searched too"
type: boolean
default: false
}
}
}
}
} }
get_frameworks { get_frameworks {
"2.8" { "2.8" {

View File

@ -334,6 +334,17 @@ get_all {
description: "Scroll ID that can be used with the next calls to get_all to retrieve more data" description: "Scroll ID that can be used with the next calls to get_all to retrieve more data"
} }
} }
"999.0": ${get_all."2.15"} {
request {
properties {
include_subprojects {
description: "If set to 'true' and project field is set then tasks from the subprojects are searched too"
type: boolean
default: false
}
}
}
}
} }
get_types { get_types {
"2.8" { "2.8" {

View File

@ -67,11 +67,10 @@ def conform_model_data(call: APICall, model_data: Union[Sequence[dict], dict]):
@endpoint("models.get_by_id", required_fields=["model"]) @endpoint("models.get_by_id", required_fields=["model"])
def get_by_id(call: APICall, company_id, _): def get_by_id(call: APICall, company_id, _):
model_id = call.data["model"] model_id = call.data["model"]
call_data = Metadata.escape_query_parameters(call)
Metadata.escape_query_parameters(call)
models = Model.get_many( models = Model.get_many(
company=company_id, company=company_id,
query_dict=call.data, query_dict=call_data,
query=Q(id=model_id), query=Q(id=model_id),
allow_public=True, allow_public=True,
) )
@ -113,12 +112,12 @@ def get_by_task_id(call: APICall, company_id, _):
@endpoint("models.get_all_ex", request_data_model=ModelsGetRequest) @endpoint("models.get_all_ex", request_data_model=ModelsGetRequest)
def get_all_ex(call: APICall, company_id, request: ModelsGetRequest): def get_all_ex(call: APICall, company_id, request: ModelsGetRequest):
conform_tag_fields(call, call.data) conform_tag_fields(call, call.data)
process_include_subprojects(call.data) call_data = Metadata.escape_query_parameters(call)
Metadata.escape_query_parameters(call) process_include_subprojects(call_data)
ret_params = {} ret_params = {}
models = Model.get_many_with_join( models = Model.get_many_with_join(
company=company_id, company=company_id,
query_dict=call.data, query_dict=call_data,
allow_public=request.allow_public, allow_public=request.allow_public,
ret_params=ret_params, ret_params=ret_params,
) )
@ -140,9 +139,9 @@ def get_all_ex(call: APICall, company_id, request: ModelsGetRequest):
@endpoint("models.get_by_id_ex", required_fields=["id"]) @endpoint("models.get_by_id_ex", required_fields=["id"])
def get_by_id_ex(call: APICall, company_id, _): def get_by_id_ex(call: APICall, company_id, _):
conform_tag_fields(call, call.data) conform_tag_fields(call, call.data)
Metadata.escape_query_parameters(call) call_data = Metadata.escape_query_parameters(call)
models = Model.get_many_with_join( models = Model.get_many_with_join(
company=company_id, query_dict=call.data, allow_public=True company=company_id, query_dict=call_data, allow_public=True
) )
conform_model_data(call, models) conform_model_data(call, models)
call.result.data = {"models": models} call.result.data = {"models": models}
@ -151,12 +150,13 @@ def get_by_id_ex(call: APICall, company_id, _):
@endpoint("models.get_all", required_fields=[]) @endpoint("models.get_all", required_fields=[])
def get_all(call: APICall, company_id, _): def get_all(call: APICall, company_id, _):
conform_tag_fields(call, call.data) conform_tag_fields(call, call.data)
Metadata.escape_query_parameters(call) call_data = Metadata.escape_query_parameters(call)
process_include_subprojects(call_data)
ret_params = {} ret_params = {}
models = Model.get_many( models = Model.get_many(
company=company_id, company=company_id,
parameters=call.data, parameters=call_data,
query_dict=call.data, query_dict=call_data,
allow_public=True, allow_public=True,
ret_params=ret_params, ret_params=ret_params,
) )

View File

@ -1,12 +1,21 @@
import re import re
from functools import partial
from apiserver.apimodels.pipelines import StartPipelineResponse, StartPipelineRequest import attr
from apiserver.apierrors.errors.bad_request import CannotRemoveAllRuns
from apiserver.apimodels.pipelines import (
StartPipelineResponse,
StartPipelineRequest,
DeleteRunsRequest,
)
from apiserver.bll.organization import OrgBLL from apiserver.bll.organization import OrgBLL
from apiserver.bll.project import ProjectBLL from apiserver.bll.project import ProjectBLL
from apiserver.bll.task import TaskBLL from apiserver.bll.task import TaskBLL
from apiserver.bll.task.task_operations import enqueue_task from apiserver.bll.task.task_operations import enqueue_task, delete_task
from apiserver.bll.util import run_batch_operation
from apiserver.database.model.project import Project from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task from apiserver.database.model.task.task import Task, TaskType
from apiserver.service_repo import APICall, endpoint from apiserver.service_repo import APICall, endpoint
org_bll = OrgBLL() org_bll = OrgBLL()
@ -31,6 +40,45 @@ def _update_task_name(task: Task):
task.update(name=new_name) task.update(name=new_name)
@endpoint("pipelines.delete_runs")
def delete_runs(call: APICall, company_id: str, request: DeleteRunsRequest):
existing_runs = set(
Task.objects(project=request.project, type=TaskType.controller).scalar("id")
)
if not existing_runs.difference(request.ids):
raise CannotRemoveAllRuns(project=request.project)
# make sure that only controller tasks are deleted
ids = existing_runs.intersection(request.ids)
if not ids:
return dict(succeeded=[], failed=[])
results, failures = run_batch_operation(
func=partial(
delete_task,
company_id=company_id,
user_id=call.identity.user,
move_to_trash=False,
force=True,
return_file_urls=False,
delete_output_models=True,
status_message="",
status_reason="Pipeline run deleted",
delete_external_artifacts=True,
),
ids=list(ids),
)
succeeded = []
if results:
for _id, (deleted, task, cleanup_res) in results:
succeeded.append(
dict(id=_id, deleted=bool(deleted), **attr.asdict(cleanup_res))
)
call.result.data = dict(succeeded=succeeded, failed=failures)
@endpoint( @endpoint(
"pipelines.start_pipeline", response_data_model=StartPipelineResponse, "pipelines.start_pipeline", response_data_model=StartPipelineResponse,
) )

View File

@ -83,11 +83,11 @@ def get_all_ex(call: APICall, company: str, request: GetAllRequest):
conform_tag_fields(call, call.data) conform_tag_fields(call, call.data)
ret_params = {} ret_params = {}
Metadata.escape_query_parameters(call) call_data = Metadata.escape_query_parameters(call)
queues = queue_bll.get_queue_infos( queues = queue_bll.get_queue_infos(
company_id=company, company_id=company,
query_dict=call.data, query_dict=call_data,
query=_hidden_query(call.data), query=_hidden_query(call_data),
max_task_entries=request.max_task_entries, max_task_entries=request.max_task_entries,
ret_params=ret_params, ret_params=ret_params,
) )
@ -99,11 +99,11 @@ def get_all_ex(call: APICall, company: str, request: GetAllRequest):
def get_all(call: APICall, company: str, request: GetAllRequest): def get_all(call: APICall, company: str, request: GetAllRequest):
conform_tag_fields(call, call.data) conform_tag_fields(call, call.data)
ret_params = {} ret_params = {}
Metadata.escape_query_parameters(call) call_data = Metadata.escape_query_parameters(call)
queues = queue_bll.get_all( queues = queue_bll.get_all(
company_id=company, company_id=company,
query_dict=call.data, query_dict=call_data,
query=_hidden_query(call.data), query=_hidden_query(call_data),
max_task_entries=request.max_task_entries, max_task_entries=request.max_task_entries,
ret_params=ret_params, ret_params=ret_params,
) )

View File

@ -204,9 +204,7 @@ def _hidden_query(data: dict) -> Q:
@endpoint("tasks.get_all_ex") @endpoint("tasks.get_all_ex")
def get_all_ex(call: APICall, company_id, request: GetAllReq): def get_all_ex(call: APICall, company_id, request: GetAllReq):
conform_tag_fields(call, call.data) conform_tag_fields(call, call.data)
call_data = escape_execution_parameters(call) call_data = escape_execution_parameters(call)
process_include_subprojects(call_data) process_include_subprojects(call_data)
ret_params = {} ret_params = {}
tasks = Task.get_many_with_join( tasks = Task.get_many_with_join(
@ -223,9 +221,7 @@ def get_all_ex(call: APICall, company_id, request: GetAllReq):
@endpoint("tasks.get_by_id_ex", required_fields=["id"]) @endpoint("tasks.get_by_id_ex", required_fields=["id"])
def get_by_id_ex(call: APICall, company_id, _): def get_by_id_ex(call: APICall, company_id, _):
conform_tag_fields(call, call.data) conform_tag_fields(call, call.data)
call_data = escape_execution_parameters(call) call_data = escape_execution_parameters(call)
tasks = Task.get_many_with_join( tasks = Task.get_many_with_join(
company=company_id, query_dict=call_data, allow_public=True, company=company_id, query_dict=call_data, allow_public=True,
) )
@ -237,8 +233,8 @@ def get_by_id_ex(call: APICall, company_id, _):
@endpoint("tasks.get_all", required_fields=[]) @endpoint("tasks.get_all", required_fields=[])
def get_all(call: APICall, company_id, _): def get_all(call: APICall, company_id, _):
conform_tag_fields(call, call.data) conform_tag_fields(call, call.data)
call_data = escape_execution_parameters(call) call_data = escape_execution_parameters(call)
process_include_subprojects(call_data)
ret_params = {} ret_params = {}
tasks = Task.get_many( tasks = Task.get_many(

View File

@ -342,6 +342,32 @@ class TestSubProjects(TestService):
self.assertEqual([p.id for p in res], [project2]) self.assertEqual([p.id for p in res], [project2])
self.api.projects.delete(project=project1, force=True) self.api.projects.delete(project=project1, force=True)
def test_include_subprojects(self):
project1, _ = self._temp_project_with_tasks(name="project1x")
project2, _ = self._temp_project_with_tasks(name="project1x/project22")
self._temp_model(project=project1)
self._temp_model(project=project2)
# tasks
res = self.api.tasks.get_all_ex(project=project1).tasks
self.assertEqual(len(res), 2)
res = self.api.tasks.get_all(project=project1).tasks
self.assertEqual(len(res), 2)
res = self.api.tasks.get_all_ex(project=project1, include_subprojects=True).tasks
self.assertEqual(len(res), 4)
res = self.api.tasks.get_all(project=project1, include_subprojects=True).tasks
self.assertEqual(len(res), 4)
# models
res = self.api.models.get_all_ex(project=project1).models
self.assertEqual(len(res), 1)
res = self.api.models.get_all(project=project1).models
self.assertEqual(len(res), 1)
res = self.api.models.get_all_ex(project=project1, include_subprojects=True).models
self.assertEqual(len(res), 2)
res = self.api.models.get_all(project=project1, include_subprojects=True).models
self.assertEqual(len(res), 2)
def test_get_all_with_check_own_contents(self): def test_get_all_with_check_own_contents(self):
project1, _ = self._temp_project_with_tasks(name="project1x") project1, _ = self._temp_project_with_tasks(name="project1x")
project2 = self._temp_project(name="project2x") project2 = self._temp_project(name="project2x")