More sub-projects support and fixes

This commit is contained in:
allegroai 2021-05-03 17:44:54 +03:00
parent 0d5174c453
commit 3c5195028e
17 changed files with 225 additions and 142 deletions

View File

@ -1,3 +1,8 @@
301 {
_: "moved_permanently"
1: ["not_supported", "this endpoint is no longer supported for the requested API version"]
}
400 { 400 {
_: "bad_request" _: "bad_request"
1: ["not_supported", "endpoint is not supported"] 1: ["not_supported", "endpoint is not supported"]

View File

@ -22,7 +22,12 @@ class DeleteRequest(ProjectRequest):
delete_contents = fields.BoolField(default=False) delete_contents = fields.BoolField(default=False)
class GetHyperParamRequest(ProjectRequest): class ProjectOrNoneRequest(models.Base):
project = fields.StringField()
include_subprojects = fields.BoolField(default=True)
class GetHyperParamRequest(ProjectOrNoneRequest):
page = fields.IntField(default=0) page = fields.IntField(default=0)
page_size = fields.IntField(default=500) page_size = fields.IntField(default=500)
@ -33,6 +38,7 @@ class ProjectTagsRequest(TagsRequest):
class MultiProjectRequest(models.Base): class MultiProjectRequest(models.Base):
projects = fields.ListField(str) projects = fields.ListField(str)
include_subprojects = fields.BoolField(default=True)
class ProjectTaskParentsRequest(MultiProjectRequest): class ProjectTaskParentsRequest(MultiProjectRequest):

View File

@ -147,7 +147,7 @@ class ProjectBLL:
user: str, user: str,
company: str, company: str,
name: str, name: str,
description: str, description: str = "",
tags: Sequence[str] = None, tags: Sequence[str] = None,
system_tags: Sequence[str] = None, system_tags: Sequence[str] = None,
default_output_destination: str = None, default_output_destination: str = None,
@ -507,20 +507,24 @@ class ProjectBLL:
company, company,
project_ids: Sequence[str], project_ids: Sequence[str],
user_ids: Optional[Sequence[str]] = None, user_ids: Optional[Sequence[str]] = None,
) -> set: ) -> Set[str]:
""" """
Get the set of user ids that created tasks/models/dataviews in the given projects Get the set of user ids that created tasks/models/dataviews in the given projects
If project_ids is empty then all projects are examined If project_ids is empty then all projects are examined
If user_ids are passed then only subset of these users is returned If user_ids are passed then only subset of these users is returned
""" """
with TimingContext("mongo", "active_users_in_projects"): with TimingContext("mongo", "active_users_in_projects"):
res = set()
query = Q(company=company) query = Q(company=company)
if user_ids:
query &= Q(user__in=user_ids)
projects_query = query
if project_ids: if project_ids:
project_ids = _ids_with_children(project_ids) project_ids = _ids_with_children(project_ids)
query &= Q(project__in=project_ids) query &= Q(project__in=project_ids)
if user_ids: projects_query &= Q(id__in=project_ids)
query &= Q(user__in=user_ids)
res = set(Project.objects(projects_query).distinct(field="user"))
for cls_ in (Task, Model): for cls_ in (Task, Model):
res |= set(cls_.objects(query).distinct(field="user")) res |= set(cls_.objects(query).distinct(field="user"))
@ -545,10 +549,17 @@ class ProjectBLL:
else: else:
query &= Q(company=company) query &= Q(company=company)
user_projects_query = query
if project_ids: if project_ids:
query &= Q(project__in=_ids_with_children(project_ids)) ids_with_children = _ids_with_children(project_ids)
query &= Q(project__in=ids_with_children)
user_projects_query &= Q(id__in=ids_with_children)
res = Task.objects(query).distinct(field="project") res = {p.id for p in Project.objects(user_projects_query).only("id")}
for cls_ in (Task, Model):
res |= set(cls_.objects(query).distinct(field="project"))
res = list(res)
if not res: if not res:
return res return res
@ -563,6 +574,7 @@ class ProjectBLL:
cls, cls,
company_id: str, company_id: str,
projects: Sequence[str], projects: Sequence[str],
include_subprojects: bool,
state: Optional[EntityVisibility] = None, state: Optional[EntityVisibility] = None,
) -> Sequence[dict]: ) -> Sequence[dict]:
""" """
@ -571,7 +583,8 @@ class ProjectBLL:
""" """
query = Q(company=company_id) query = Q(company=company_id)
if projects: if projects:
projects = _ids_with_children(projects) if include_subprojects:
projects = _ids_with_children(projects)
query &= Q(project__in=projects) query &= Q(project__in=projects)
if state == EntityVisibility.archived: if state == EntityVisibility.archived:
query &= Q(system_tags__in=[EntityVisibility.archived.value]) query &= Q(system_tags__in=[EntityVisibility.archived.value])

View File

@ -13,7 +13,7 @@ import apiserver.database.utils as dbutils
from apiserver.apierrors import errors from apiserver.apierrors import errors
from apiserver.bll.queue import QueueBLL from apiserver.bll.queue import QueueBLL
from apiserver.bll.organization import OrgBLL, Tags from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL from apiserver.bll.project import ProjectBLL, project_ids_with_children
from apiserver.config_repo import config from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context from apiserver.database.errors import translate_errors_context
from apiserver.database.model.model import Model from apiserver.database.model.model import Model
@ -37,7 +37,12 @@ from apiserver.timing_context import TimingContext
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
from .artifacts import artifacts_prepare_for_save from .artifacts import artifacts_prepare_for_save
from .param_utils import params_prepare_for_save from .param_utils import params_prepare_for_save
from .utils import ChangeStatusRequest, validate_status_change, update_project_time, deleted_prefix from .utils import (
ChangeStatusRequest,
validate_status_change,
update_project_time,
deleted_prefix,
)
log = config.logger(__file__) log = config.logger(__file__)
org_bll = OrgBLL() org_bll = OrgBLL()
@ -317,12 +322,19 @@ class TaskBLL:
cls.validate_execution_model(task) cls.validate_execution_model(task)
@staticmethod @staticmethod
def get_unique_metric_variants(company_id, project_ids=None): def get_unique_metric_variants(
company_id, project_ids: Sequence[str], include_subprojects: bool
):
if project_ids:
if include_subprojects:
project_ids = project_ids_with_children(project_ids)
project_constraint = {"project": {"$in": project_ids}}
else:
project_constraint = {}
pipeline = [ pipeline = [
{ {
"$match": dict( "$match": dict(
company={"$in": [None, "", company_id]}, company={"$in": [None, "", company_id]}, **project_constraint,
**({"project": {"$in": project_ids}} if project_ids else {}),
) )
}, },
{"$project": {"metrics": {"$objectToArray": "$last_metrics"}}}, {"$project": {"metrics": {"$objectToArray": "$last_metrics"}}},
@ -601,11 +613,17 @@ class TaskBLL:
@staticmethod @staticmethod
def get_aggregated_project_parameters( def get_aggregated_project_parameters(
company_id, company_id,
project_ids: Sequence[str] = None, project_ids: Sequence[str],
include_subprojects: bool,
page: int = 0, page: int = 0,
page_size: int = 500, page_size: int = 500,
) -> Tuple[int, int, Sequence[dict]]: ) -> Tuple[int, int, Sequence[dict]]:
if project_ids:
if include_subprojects:
project_ids = project_ids_with_children(project_ids)
project_constraint = {"project": {"$in": project_ids}}
else:
project_constraint = {}
page = max(0, page) page = max(0, page)
page_size = max(1, page_size) page_size = max(1, page_size)
pipeline = [ pipeline = [
@ -613,7 +631,7 @@ class TaskBLL:
"$match": { "$match": {
"company": {"$in": [None, "", company_id]}, "company": {"$in": [None, "", company_id]},
"hyperparams": {"$exists": True, "$gt": {}}, "hyperparams": {"$exists": True, "$gt": {}},
**({"project": {"$in": project_ids}} if project_ids else {}), **project_constraint,
} }
}, },
{"$project": {"sections": {"$objectToArray": "$hyperparams"}}}, {"$project": {"sections": {"$objectToArray": "$hyperparams"}}},
@ -687,6 +705,7 @@ class TaskBLL:
project_ids: Sequence[str], project_ids: Sequence[str],
section: str, section: str,
name: str, name: str,
include_subprojects: bool,
allow_public: bool = True, allow_public: bool = True,
) -> HyperParamValues: ) -> HyperParamValues:
if allow_public: if allow_public:
@ -694,6 +713,8 @@ class TaskBLL:
else: else:
company_constraint = {"company": company_id} company_constraint = {"company": company_id}
if project_ids: if project_ids:
if include_subprojects:
project_ids = project_ids_with_children(project_ids)
project_constraint = {"project": {"$in": project_ids}} project_constraint = {"project": {"$in": project_ids}}
else: else:
project_constraint = {} project_constraint = {}

View File

@ -193,8 +193,7 @@ class GetMixin(PropsMixin):
""" """
Pop the parameters that match the specified patterns and return Pop the parameters that match the specified patterns and return
the dictionary of matching parameters the dictionary of matching parameters
For backwards compatibility with the previous version of the code Pop None parameters since they are not the real queries
the None values are discarded
""" """
if not patterns: if not patterns:
return {} return {}
@ -351,11 +350,7 @@ class GetMixin(PropsMixin):
q = RegexQ() q = RegexQ()
for action in filter(None, actions): for action in filter(None, actions):
q &= RegexQ( q &= RegexQ(
**{ **{f"{mongoengine_field}__{action}": list(set(actions[action]))}
f"{mongoengine_field}__{action}": list(
set(filter(None, actions[action]))
)
}
) )
if not allow_empty: if not allow_empty:

View File

@ -36,7 +36,7 @@ class Project(AttributedDocument):
min_length=3, min_length=3,
sparse=True, sparse=True,
) )
description = StringField(required=True) description = StringField()
created = DateTimeField(required=True) created = DateTimeField(required=True)
tags = SafeSortedListField(StringField(required=True)) tags = SafeSortedListField(StringField(required=True))
system_tags = SafeSortedListField(StringField(required=True)) system_tags = SafeSortedListField(StringField(required=True))

View File

@ -115,7 +115,7 @@ class Execution(EmbeddedDocument, ProperDictMixin):
framework = StringField() framework = StringField()
artifacts: Dict[str, Artifact] = SafeMapField(field=EmbeddedDocumentField(Artifact)) artifacts: Dict[str, Artifact] = SafeMapField(field=EmbeddedDocumentField(Artifact))
docker_cmd = StringField() docker_cmd = StringField()
queue = StringField() queue = StringField(reference_field="Queue")
""" Queue ID where task was queued """ """ Queue ID where task was queued """

View File

@ -151,6 +151,17 @@ get_by_id_ex {
get_all_ex { get_all_ex {
internal: true internal: true
"2.1": ${get_all."2.1"} "2.1": ${get_all."2.1"}
"2.13": ${get_all_ex."2.1"} {
request {
properties {
include_subprojects {
description: "If set to 'true' and project field is set then models from the subprojects are searched too"
type: boolean
default: false
}
}
}
}
} }
get_all { get_all {
"2.1" { "2.1" {

View File

@ -283,17 +283,14 @@ create {
description: "Create a new project" description: "Create a new project"
request { request {
type: object type: object
required :[ required :[name]
name
description
]
properties { properties {
name { name {
description: "Project name Unique within the company." description: "Project name Unique within the company."
type: string type: string
} }
description { description {
description: "Project description. " description: "Project description."
type: string type: string
} }
tags { tags {
@ -677,6 +674,17 @@ get_unique_metric_variants {
} }
} }
} }
"2.13": ${get_unique_metric_variants."2.1"} {
request {
properties {
include_subprojects {
description: "If set to 'true' and the project field is set then the result includes metrics/variants from the subproject tasks"
type: boolean
default: true
}
}
}
}
} }
get_hyperparam_values { get_hyperparam_values {
"2.13" { "2.13" {
@ -702,6 +710,11 @@ get_hyperparam_values {
description: "If set to 'true' then collect values from both company and public tasks otherwise company tasks only. The default is 'true'" description: "If set to 'true' then collect values from both company and public tasks otherwise company tasks only. The default is 'true'"
type: boolean type: boolean
} }
include_subprojects {
description: "If set to 'true' and the project field is set then the result includes hyper parameters values from the subproject tasks"
type: boolean
default: true
}
} }
} }
response { response {
@ -761,6 +774,17 @@ get_hyper_parameters {
} }
} }
} }
"2.13": ${get_hyper_parameters."2.9"} {
request {
properties {
include_subprojects {
description: "If set to 'true' and the project field is set then the result includes hyper parameters from the subproject tasks"
type: boolean
default: true
}
}
}
}
} }
get_task_tags { get_task_tags {
@ -830,7 +854,7 @@ make_private {
} }
get_task_parents { get_task_parents {
"2.12" { "2.12" {
description: "Get unique parent tasks for the tasks in the specified pprojects" description: "Get unique parent tasks for the tasks in the specified projects"
request { request {
type: object type: object
properties { properties {
@ -881,4 +905,15 @@ get_task_parents {
} }
} }
} }
"2.13": ${get_task_parents."2.12"} {
request {
properties {
include_subprojects {
description: "If set to 'true' and the projects field is not empty then the result includes tasks parents from the subproject tasks"
type: boolean
default: true
}
}
}
}
} }

View File

@ -583,6 +583,17 @@ get_by_id_ex {
get_all_ex { get_all_ex {
internal: true internal: true
"2.1": ${get_all."2.1"} "2.1": ${get_all."2.1"}
"2.13": ${get_all_ex."2.1"} {
request {
properties {
include_subprojects {
description: "If set to 'true' and project field is set then tasks from the subprojects are searched too"
type: boolean
default: false
}
}
}
}
} }
get_all { get_all {
"2.1" { "2.1" {

View File

@ -17,7 +17,7 @@ from apiserver.apimodels.models import (
DeleteModelRequest, DeleteModelRequest,
) )
from apiserver.bll.organization import OrgBLL, Tags from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL from apiserver.bll.project import ProjectBLL, project_ids_with_children
from apiserver.bll.task import TaskBLL from apiserver.bll.task import TaskBLL
from apiserver.bll.task.utils import deleted_prefix from apiserver.bll.task.utils import deleted_prefix
from apiserver.config_repo import config from apiserver.config_repo import config
@ -86,10 +86,22 @@ def get_by_task_id(call: APICall, company_id, _):
call.result.data = {"model": model_dict} call.result.data = {"model": model_dict}
def _process_include_subprojects(call_data: dict):
include_subprojects = call_data.pop("include_subprojects", False)
project_ids = call_data.get("project")
if not project_ids or not include_subprojects:
return
if not isinstance(project_ids, list):
project_ids = [project_ids]
call_data["project"] = project_ids_with_children(project_ids)
@endpoint("models.get_all_ex", required_fields=[]) @endpoint("models.get_all_ex", required_fields=[])
def get_all_ex(call: APICall, company_id, _): def get_all_ex(call: APICall, company_id, _):
conform_tag_fields(call, call.data) conform_tag_fields(call, call.data)
with translate_errors_context(): with translate_errors_context():
_process_include_subprojects(call.data)
with TimingContext("mongo", "models_get_all_ex"): with TimingContext("mongo", "models_get_all_ex"):
models = Model.get_many_with_join( models = Model.get_many_with_join(
company=company_id, query_dict=call.data, allow_public=True company=company_id, query_dict=call.data, allow_public=True

View File

@ -8,12 +8,14 @@ from apiserver.apierrors.errors.bad_request import InvalidProjectId
from apiserver.apimodels.base import UpdateResponse, MakePublicRequest, IdResponse from apiserver.apimodels.base import UpdateResponse, MakePublicRequest, IdResponse
from apiserver.apimodels.projects import ( from apiserver.apimodels.projects import (
GetHyperParamRequest, GetHyperParamRequest,
ProjectRequest,
ProjectTagsRequest, ProjectTagsRequest,
ProjectTaskParentsRequest, ProjectTaskParentsRequest,
ProjectHyperparamValuesRequest, ProjectHyperparamValuesRequest,
ProjectsGetRequest, ProjectsGetRequest,
DeleteRequest, MoveRequest, MergeRequest, DeleteRequest,
MoveRequest,
MergeRequest,
ProjectOrNoneRequest,
) )
from apiserver.bll.organization import OrgBLL, Tags from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL from apiserver.bll.project import ProjectBLL
@ -80,14 +82,13 @@ def _adjust_search_parameters(data: dict, shallow_search: bool):
return return
if "parent" not in data: if "parent" not in data:
data["parent"] = [None, ""] data["parent"] = [None]
@endpoint("projects.get_all_ex", request_data_model=ProjectsGetRequest) @endpoint("projects.get_all_ex", request_data_model=ProjectsGetRequest)
def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest): def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
conform_tag_fields(call, call.data) conform_tag_fields(call, call.data)
allow_public = not request.non_public allow_public = not request.non_public
shallow_search = request.shallow_search or request.include_stats
with TimingContext("mongo", "projects_get_all"): with TimingContext("mongo", "projects_get_all"):
data = call.data data = call.data
if request.active_users: if request.active_users:
@ -102,12 +103,10 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
return return
data["id"] = ids data["id"] = ids
_adjust_search_parameters(data, shallow_search=shallow_search) _adjust_search_parameters(data, shallow_search=request.shallow_search)
projects = Project.get_many_with_join( projects = Project.get_many_with_join(
company=company_id, company=company_id, query_dict=data, allow_public=allow_public,
query_dict=data,
allow_public=allow_public,
) )
conform_output_tags(call, projects) conform_output_tags(call, projects)
@ -147,9 +146,7 @@ def get_all(call: APICall):
@endpoint( @endpoint(
"projects.create", "projects.create", required_fields=["name"], response_data_model=IdResponse,
required_fields=["name", "description"],
response_data_model=IdResponse,
) )
def create(call: APICall): def create(call: APICall):
identity = call.identity identity = call.identity
@ -232,11 +229,17 @@ def delete(call: APICall, company_id: str, request: DeleteRequest):
call.result.data = {**attr.asdict(res)} call.result.data = {**attr.asdict(res)}
@endpoint("projects.get_unique_metric_variants", request_data_model=ProjectRequest) @endpoint(
def get_unique_metric_variants(call: APICall, company_id: str, request: ProjectRequest): "projects.get_unique_metric_variants", request_data_model=ProjectOrNoneRequest
)
def get_unique_metric_variants(
call: APICall, company_id: str, request: ProjectOrNoneRequest
):
metrics = task_bll.get_unique_metric_variants( metrics = task_bll.get_unique_metric_variants(
company_id, [request.project] if request.project else None company_id,
[request.project] if request.project else None,
include_subprojects=request.include_subprojects,
) )
call.result.data = {"metrics": metrics} call.result.data = {"metrics": metrics}
@ -252,6 +255,7 @@ def get_hyper_parameters(call: APICall, company_id: str, request: GetHyperParamR
total, remaining, parameters = TaskBLL.get_aggregated_project_parameters( total, remaining, parameters = TaskBLL.get_aggregated_project_parameters(
company_id, company_id,
project_ids=[request.project] if request.project else None, project_ids=[request.project] if request.project else None,
include_subprojects=request.include_subprojects,
page=request.page, page=request.page,
page_size=request.page_size, page_size=request.page_size,
) )
@ -276,6 +280,7 @@ def get_hyperparam_values(
project_ids=request.projects, project_ids=request.projects,
section=request.section, section=request.section,
name=request.name, name=request.name,
include_subprojects=request.include_subprojects,
allow_public=request.allow_public, allow_public=request.allow_public,
) )
call.result.data = { call.result.data = {
@ -340,6 +345,9 @@ def get_task_parents(
): ):
call.result.data = { call.result.data = {
"parents": project_bll.get_task_parents( "parents": project_bll.get_task_parents(
company_id, projects=request.projects, state=request.tasks_state company_id,
projects=request.projects,
include_subprojects=request.include_subprojects,
state=request.tasks_state,
) )
} }

View File

@ -47,7 +47,7 @@ from apiserver.apimodels.tasks import (
) )
from apiserver.bll.event import EventBLL from apiserver.bll.event import EventBLL
from apiserver.bll.organization import OrgBLL, Tags from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL from apiserver.bll.project import ProjectBLL, project_ids_with_children
from apiserver.bll.queue import QueueBLL from apiserver.bll.queue import QueueBLL
from apiserver.bll.task import ( from apiserver.bll.task import (
TaskBLL, TaskBLL,
@ -152,27 +152,49 @@ def get_by_id(call: APICall, company_id, req_model: TaskRequest):
call.result.data = {"task": task_dict} call.result.data = {"task": task_dict}
def escape_execution_parameters(call: APICall): def escape_execution_parameters(call: APICall) -> dict:
projection = Task.get_projection(call.data) if not call.data:
if projection: return call.data
Task.set_projection(call.data, escape_paths(projection))
ordering = Task.get_ordering(call.data) keys = list(call.data)
call_data = {
safe_key: call.data[key] for key, safe_key in zip(keys, escape_paths(keys))
}
projection = Task.get_projection(call_data)
if projection:
Task.set_projection(call_data, escape_paths(projection))
ordering = Task.get_ordering(call_data)
if ordering: if ordering:
Task.set_ordering(call.data, escape_paths(ordering)) Task.set_ordering(call_data, escape_paths(ordering))
return call_data
def _process_include_subprojects(call_data: dict):
include_subprojects = call_data.pop("include_subprojects", False)
project_ids = call_data.get("project")
if not project_ids or not include_subprojects:
return
if not isinstance(project_ids, list):
project_ids = [project_ids]
call_data["project"] = project_ids_with_children(project_ids)
@endpoint("tasks.get_all_ex", required_fields=[]) @endpoint("tasks.get_all_ex", required_fields=[])
def get_all_ex(call: APICall, company_id, _): def get_all_ex(call: APICall, company_id, _):
conform_tag_fields(call, call.data) conform_tag_fields(call, call.data)
escape_execution_parameters(call) call_data = escape_execution_parameters(call)
with translate_errors_context(): with translate_errors_context():
with TimingContext("mongo", "task_get_all_ex"): with TimingContext("mongo", "task_get_all_ex"):
_process_include_subprojects(call_data)
tasks = Task.get_many_with_join( tasks = Task.get_many_with_join(
company=company_id, company=company_id,
query_dict=call.data, query_dict=call_data,
allow_public=True, # required in case projection is requested for public dataset/versions allow_public=True, # required in case projection is requested for public dataset/versions
) )
unprepare_from_saved(call, tasks) unprepare_from_saved(call, tasks)
@ -183,12 +205,12 @@ def get_all_ex(call: APICall, company_id, _):
def get_by_id_ex(call: APICall, company_id, _): def get_by_id_ex(call: APICall, company_id, _):
conform_tag_fields(call, call.data) conform_tag_fields(call, call.data)
escape_execution_parameters(call) call_data = escape_execution_parameters(call)
with translate_errors_context(): with translate_errors_context():
with TimingContext("mongo", "task_get_by_id_ex"): with TimingContext("mongo", "task_get_by_id_ex"):
tasks = Task.get_many_with_join( tasks = Task.get_many_with_join(
company=company_id, query_dict=call.data, allow_public=True, company=company_id, query_dict=call_data, allow_public=True,
) )
unprepare_from_saved(call, tasks) unprepare_from_saved(call, tasks)
@ -199,14 +221,14 @@ def get_by_id_ex(call: APICall, company_id, _):
def get_all(call: APICall, company_id, _): def get_all(call: APICall, company_id, _):
conform_tag_fields(call, call.data) conform_tag_fields(call, call.data)
escape_execution_parameters(call) call_data = escape_execution_parameters(call)
with translate_errors_context(): with translate_errors_context():
with TimingContext("mongo", "task_get_all"): with TimingContext("mongo", "task_get_all"):
tasks = Task.get_many( tasks = Task.get_many(
company=company_id, company=company_id,
parameters=call.data, parameters=call_data,
query_dict=call.data, query_dict=call_data,
allow_public=True, # required in case projection is requested for public dataset/versions allow_public=True, # required in case projection is requested for public dataset/versions
) )
unprepare_from_saved(call, tasks) unprepare_from_saved(call, tasks)
@ -216,7 +238,9 @@ def get_all(call: APICall, company_id, _):
@endpoint("tasks.get_types", request_data_model=GetTypesRequest) @endpoint("tasks.get_types", request_data_model=GetTypesRequest)
def get_types(call: APICall, company_id, request: GetTypesRequest): def get_types(call: APICall, company_id, request: GetTypesRequest):
call.result.data = { call.result.data = {
"types": list(project_bll.get_task_types(company_id, project_ids=request.projects)) "types": list(
project_bll.get_task_types(company_id, project_ids=request.projects)
)
} }

View File

@ -1,5 +1,4 @@
import json import json
import logging
import os import os
import time import time
from contextlib import contextmanager from contextlib import contextmanager
@ -10,16 +9,14 @@ import requests
import six import six
from boltons.iterutils import remap from boltons.iterutils import remap
from boltons.typeutils import issubclass from boltons.typeutils import issubclass
from pyhocon import ConfigFactory
from requests.adapters import HTTPAdapter from requests.adapters import HTTPAdapter
from requests.auth import HTTPBasicAuth from requests.auth import HTTPBasicAuth
from requests.packages.urllib3.util.retry import Retry from requests.packages.urllib3.util.retry import Retry
from apiserver.apierrors.base import BaseError from apiserver.apierrors.base import BaseError
from apiserver.config_repo import config
config = ConfigFactory.parse_file("api_client.conf") log = config.logger(__file__)
log = logging.getLogger("api_client")
class APICallResult: class APICallResult:
@ -111,7 +108,7 @@ class APIClient:
self.api_key = ( self.api_key = (
api_key api_key
or os.environ.get("SM_API_KEY") or os.environ.get("SM_API_KEY")
or config.get("api_key") or config.get("apiclient.api_key")
) )
if not self.api_key: if not self.api_key:
raise ValueError("APIClient requires api_key in constructor or config") raise ValueError("APIClient requires api_key in constructor or config")
@ -119,7 +116,7 @@ class APIClient:
self.secret_key = ( self.secret_key = (
secret_key secret_key
or os.environ.get("SM_API_SECRET") or os.environ.get("SM_API_SECRET")
or config.get("secret_key") or config.get("apiclient.secret_key")
) )
if not self.secret_key: if not self.secret_key:
raise ValueError( raise ValueError(
@ -127,7 +124,7 @@ class APIClient:
) )
self.base_url = ( self.base_url = (
base_url or os.environ.get("SM_API_URL") or config.get("base_url") base_url or os.environ.get("SM_API_URL") or config.get("apiclient.base_url")
) )
if not self.base_url: if not self.base_url:
raise ValueError("APIClient requires base_url in constructor or config") raise ValueError("APIClient requires base_url in constructor or config")
@ -139,9 +136,9 @@ class APIClient:
# create http session # create http session
self.http_session = requests.session() self.http_session = requests.session()
retries = config.get("retries", 7) retries = config.get("apiclient.retries", 7)
backoff_factor = config.get("backoff_factor", 0.3) backoff_factor = config.get("apiclient.backoff_factor", 0.3)
status_forcelist = config.get("status_forcelist", (500, 502, 504)) status_forcelist = config.get("apiclient.status_forcelist", (500, 502, 504))
retry = Retry( retry = Retry(
total=retries, total=retries,
read=retries, read=retries,

View File

@ -1,65 +0,0 @@
from boltons.iterutils import first
from apiserver.tests.automated import TestService
class TestProjectsRetrieval(TestService):
def setUp(self, **kwargs):
super().setUp(version="2.13")
def test_active_user(self):
user = self.api.users.get_current_user().user.id
project1 = self.temp_project(name="Project retrieval1")
project2 = self.temp_project(name="Project retrieval2")
self.temp_task(project=project2)
projects = self.api.projects.get_all_ex().projects
self.assertTrue({project1, project2}.issubset({p.id for p in projects}))
projects = self.api.projects.get_all_ex(active_users=[user]).projects
ids = {p.id for p in projects}
self.assertFalse(project1 in ids)
self.assertTrue(project2 in ids)
def test_stats(self):
project = self.temp_project()
self.temp_task(project=project)
self.temp_task(project=project)
archived_task = self.temp_task(project=project)
self.api.tasks.archive(tasks=[archived_task])
p = self._get_project(project)
self.assertFalse("stats" in p)
p = self._get_project(project, include_stats=True)
self.assertFalse("archived" in p.stats)
self.assertTrue(p.stats.active.status_count.created, 2)
p = self._get_project(project, include_stats=True, stats_for_state=None)
self.assertTrue(p.stats.active.status_count.created, 2)
self.assertTrue(p.stats.archived.status_count.created, 1)
def _get_project(self, project, **kwargs):
projects = self.api.projects.get_all_ex(id=[project], **kwargs).projects
p = first(p for p in projects if p.id == project)
self.assertIsNotNone(p)
return p
def temp_project(self, **kwargs) -> str:
self.update_missing(
kwargs,
name="Test projects retrieval",
description="test",
delete_params=dict(force=True),
)
return self.create_temp("projects", **kwargs)
def temp_task(self, **kwargs) -> str:
self.update_missing(
kwargs,
type="testing",
name="test projects retrieval",
input=dict(view=dict()),
delete_params=dict(force=True),
)
return self.create_temp("tasks", **kwargs)

View File

@ -4,8 +4,10 @@ from typing import Sequence, Optional, Tuple
from boltons.iterutils import first from boltons.iterutils import first
from apiserver.apierrors import errors from apiserver.apierrors import errors
from apiserver.config_repo import config
from apiserver.database.model import EntityVisibility from apiserver.database.model import EntityVisibility
from apiserver.database.utils import id as db_id from apiserver.database.utils import id as db_id
from apiserver.tests.api_client import APIClient
from apiserver.tests.automated import TestService from apiserver.tests.automated import TestService
@ -14,7 +16,14 @@ class TestSubProjects(TestService):
super().setUp(version="2.13") super().setUp(version="2.13")
def test_project_aggregations(self): def test_project_aggregations(self):
child = self._temp_project(name="Aggregation/Pr1") """This test requires user with user_auth_only... credentials in db"""
user2_client = APIClient(
api_key=config.get("apiclient.user_auth_only"),
secret_key=config.get("apiclient.user_auth_only_secret"),
base_url=f"http://localhost:8008/v2.13",
)
child = self._temp_project(name="Aggregation/Pr1", client=user2_client)
project = self.api.projects.get_all_ex(name="^Aggregation$").projects[0].id project = self.api.projects.get_all_ex(name="^Aggregation$").projects[0].id
child_project = self.api.projects.get_all_ex(id=[child]).projects[0] child_project = self.api.projects.get_all_ex(id=[child]).projects[0]
self.assertEqual(child_project.parent.id, project) self.assertEqual(child_project.parent.id, project)
@ -210,12 +219,13 @@ class TestSubProjects(TestService):
delete_params = dict(can_fail=True, force=True) delete_params = dict(can_fail=True, force=True)
def _temp_project(self, name, **kwargs): def _temp_project(self, name, client=None, **kwargs):
return self.create_temp( return self.create_temp(
"projects", "projects",
delete_params=self.delete_params, delete_params=self.delete_params,
name=name, name=name,
description="", description="",
client=client,
**kwargs, **kwargs,
) )

View File

@ -225,4 +225,4 @@ class TestTasksEdit(TestService):
self.api.tasks.enqueue(task=task_id, queue=queue_id) self.api.tasks.enqueue(task=task_id, queue=queue_id)
task = self.api.tasks.get_all_ex(id=task_id, projection=projection).tasks[0] task = self.api.tasks.get_all_ex(id=task_id, projection=projection).tasks[0]
self.assertEqual(task.status, "queued") self.assertEqual(task.status, "queued")
self.assertEqual(task.execution.queue, queue_id) self.assertEqual(task.execution.queue.id, queue_id)