Add preliminary support for datasets under projects

This commit is contained in:
allegroai 2022-11-29 17:27:02 +02:00
parent e5230edac3
commit 53c9b5525e
3 changed files with 18 additions and 26 deletions

View File

@ -24,7 +24,7 @@ from apiserver.apimodels.models import (
) )
from apiserver.bll.model import ModelBLL, Metadata from apiserver.bll.model import ModelBLL, Metadata
from apiserver.bll.organization import OrgBLL, Tags from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL, project_ids_with_children from apiserver.bll.project import ProjectBLL
from apiserver.bll.task import TaskBLL from apiserver.bll.task import TaskBLL
from apiserver.bll.task.task_operations import publish_task from apiserver.bll.task.task_operations import publish_task
from apiserver.bll.util import run_batch_operation from apiserver.bll.util import run_batch_operation
@ -51,6 +51,7 @@ from apiserver.services.utils import (
ModelsBackwardsCompatibility, ModelsBackwardsCompatibility,
unescape_metadata, unescape_metadata,
escape_metadata, escape_metadata,
process_include_subprojects,
) )
log = config.logger(__file__) log = config.logger(__file__)
@ -106,21 +107,10 @@ def get_by_task_id(call: APICall, company_id, _):
call.result.data = {"model": model_dict} call.result.data = {"model": model_dict}
def _process_include_subprojects(call_data: dict):
include_subprojects = call_data.pop("include_subprojects", False)
project_ids = call_data.get("project")
if not project_ids or not include_subprojects:
return
if not isinstance(project_ids, list):
project_ids = [project_ids]
call_data["project"] = project_ids_with_children(project_ids)
@endpoint("models.get_all_ex", request_data_model=ModelsGetRequest) @endpoint("models.get_all_ex", request_data_model=ModelsGetRequest)
def get_all_ex(call: APICall, company_id, request: ModelsGetRequest): def get_all_ex(call: APICall, company_id, request: ModelsGetRequest):
conform_tag_fields(call, call.data) conform_tag_fields(call, call.data)
_process_include_subprojects(call.data) process_include_subprojects(call.data)
Metadata.escape_query_parameters(call) Metadata.escape_query_parameters(call)
ret_params = {} ret_params = {}
models = Model.get_many_with_join( models = Model.get_many_with_join(

View File

@ -68,7 +68,7 @@ from apiserver.apimodels.tasks import (
from apiserver.bll.event import EventBLL from apiserver.bll.event import EventBLL
from apiserver.bll.model import ModelBLL from apiserver.bll.model import ModelBLL
from apiserver.bll.organization import OrgBLL, Tags from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL, project_ids_with_children from apiserver.bll.project import ProjectBLL
from apiserver.bll.queue import QueueBLL from apiserver.bll.queue import QueueBLL
from apiserver.bll.task import ( from apiserver.bll.task import (
TaskBLL, TaskBLL,
@ -118,6 +118,7 @@ from apiserver.services.utils import (
DockerCmdBackwardsCompatibility, DockerCmdBackwardsCompatibility,
escape_dict_field, escape_dict_field,
unescape_dict_field, unescape_dict_field,
process_include_subprojects,
) )
from apiserver.utilities.dicts import nested_get from apiserver.utilities.dicts import nested_get
from apiserver.utilities.partial_version import PartialVersion from apiserver.utilities.partial_version import PartialVersion
@ -203,17 +204,6 @@ def escape_execution_parameters(call: APICall) -> dict:
return call_data return call_data
def _process_include_subprojects(call_data: dict):
include_subprojects = call_data.pop("include_subprojects", False)
project_ids = call_data.get("project")
if not project_ids or not include_subprojects:
return
if not isinstance(project_ids, list):
project_ids = [project_ids]
call_data["project"] = project_ids_with_children(project_ids)
def _hidden_query(data: dict) -> Q: def _hidden_query(data: dict) -> Q:
""" """
1. Add only non-hidden tasks search condition (unless specifically specified differently) 1. Add only non-hidden tasks search condition (unless specifically specified differently)
@ -230,7 +220,7 @@ def get_all_ex(call: APICall, company_id, _):
call_data = escape_execution_parameters(call) call_data = escape_execution_parameters(call)
_process_include_subprojects(call_data) process_include_subprojects(call_data)
ret_params = {} ret_params = {}
tasks = Task.get_many_with_join( tasks = Task.get_many_with_join(
company=company_id, company=company_id,

View File

@ -3,6 +3,7 @@ from typing import Union, Sequence, Tuple
from apiserver.apierrors import errors from apiserver.apierrors import errors
from apiserver.apimodels.organization import Filter from apiserver.apimodels.organization import Filter
from apiserver.bll.project import project_ids_with_children
from apiserver.database.model.base import GetMixin from apiserver.database.model.base import GetMixin
from apiserver.database.model.task.task import TaskModelTypes, TaskModelNames from apiserver.database.model.task.task import TaskModelTypes, TaskModelNames
from apiserver.database.utils import partition_tags from apiserver.database.utils import partition_tags
@ -12,6 +13,17 @@ from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
from apiserver.utilities.partial_version import PartialVersion from apiserver.utilities.partial_version import PartialVersion
def process_include_subprojects(call_data: dict):
include_subprojects = call_data.pop("include_subprojects", False)
project_ids = call_data.get("project")
if not project_ids or not include_subprojects:
return
if not isinstance(project_ids, list):
project_ids = [project_ids]
call_data["project"] = project_ids_with_children(project_ids)
def get_tags_filter_dictionary(input_: Filter) -> dict: def get_tags_filter_dictionary(input_: Filter) -> dict:
if not input_: if not input_:
return {} return {}