mirror of
https://github.com/clearml/clearml-server
synced 2025-06-26 23:15:47 +00:00
More sub-projects support and fixes
This commit is contained in:
parent
0d5174c453
commit
3c5195028e
@ -1,3 +1,8 @@
|
||||
301 {
|
||||
_: "moved_permanently"
|
||||
1: ["not_supported", "this endpoint is no longer supported for the requested API version"]
|
||||
}
|
||||
|
||||
400 {
|
||||
_: "bad_request"
|
||||
1: ["not_supported", "endpoint is not supported"]
|
||||
|
@ -22,7 +22,12 @@ class DeleteRequest(ProjectRequest):
|
||||
delete_contents = fields.BoolField(default=False)
|
||||
|
||||
|
||||
class GetHyperParamRequest(ProjectRequest):
|
||||
class ProjectOrNoneRequest(models.Base):
|
||||
project = fields.StringField()
|
||||
include_subprojects = fields.BoolField(default=True)
|
||||
|
||||
|
||||
class GetHyperParamRequest(ProjectOrNoneRequest):
|
||||
page = fields.IntField(default=0)
|
||||
page_size = fields.IntField(default=500)
|
||||
|
||||
@ -33,6 +38,7 @@ class ProjectTagsRequest(TagsRequest):
|
||||
|
||||
class MultiProjectRequest(models.Base):
|
||||
projects = fields.ListField(str)
|
||||
include_subprojects = fields.BoolField(default=True)
|
||||
|
||||
|
||||
class ProjectTaskParentsRequest(MultiProjectRequest):
|
||||
|
@ -147,7 +147,7 @@ class ProjectBLL:
|
||||
user: str,
|
||||
company: str,
|
||||
name: str,
|
||||
description: str,
|
||||
description: str = "",
|
||||
tags: Sequence[str] = None,
|
||||
system_tags: Sequence[str] = None,
|
||||
default_output_destination: str = None,
|
||||
@ -507,20 +507,24 @@ class ProjectBLL:
|
||||
company,
|
||||
project_ids: Sequence[str],
|
||||
user_ids: Optional[Sequence[str]] = None,
|
||||
) -> set:
|
||||
) -> Set[str]:
|
||||
"""
|
||||
Get the set of user ids that created tasks/models/dataviews in the given projects
|
||||
If project_ids is empty then all projects are examined
|
||||
If user_ids are passed then only subset of these users is returned
|
||||
"""
|
||||
with TimingContext("mongo", "active_users_in_projects"):
|
||||
res = set()
|
||||
query = Q(company=company)
|
||||
if user_ids:
|
||||
query &= Q(user__in=user_ids)
|
||||
|
||||
projects_query = query
|
||||
if project_ids:
|
||||
project_ids = _ids_with_children(project_ids)
|
||||
query &= Q(project__in=project_ids)
|
||||
if user_ids:
|
||||
query &= Q(user__in=user_ids)
|
||||
projects_query &= Q(id__in=project_ids)
|
||||
|
||||
res = set(Project.objects(projects_query).distinct(field="user"))
|
||||
for cls_ in (Task, Model):
|
||||
res |= set(cls_.objects(query).distinct(field="user"))
|
||||
|
||||
@ -545,10 +549,17 @@ class ProjectBLL:
|
||||
else:
|
||||
query &= Q(company=company)
|
||||
|
||||
user_projects_query = query
|
||||
if project_ids:
|
||||
query &= Q(project__in=_ids_with_children(project_ids))
|
||||
ids_with_children = _ids_with_children(project_ids)
|
||||
query &= Q(project__in=ids_with_children)
|
||||
user_projects_query &= Q(id__in=ids_with_children)
|
||||
|
||||
res = Task.objects(query).distinct(field="project")
|
||||
res = {p.id for p in Project.objects(user_projects_query).only("id")}
|
||||
for cls_ in (Task, Model):
|
||||
res |= set(cls_.objects(query).distinct(field="project"))
|
||||
|
||||
res = list(res)
|
||||
if not res:
|
||||
return res
|
||||
|
||||
@ -563,6 +574,7 @@ class ProjectBLL:
|
||||
cls,
|
||||
company_id: str,
|
||||
projects: Sequence[str],
|
||||
include_subprojects: bool,
|
||||
state: Optional[EntityVisibility] = None,
|
||||
) -> Sequence[dict]:
|
||||
"""
|
||||
@ -571,7 +583,8 @@ class ProjectBLL:
|
||||
"""
|
||||
query = Q(company=company_id)
|
||||
if projects:
|
||||
projects = _ids_with_children(projects)
|
||||
if include_subprojects:
|
||||
projects = _ids_with_children(projects)
|
||||
query &= Q(project__in=projects)
|
||||
if state == EntityVisibility.archived:
|
||||
query &= Q(system_tags__in=[EntityVisibility.archived.value])
|
||||
|
@ -13,7 +13,7 @@ import apiserver.database.utils as dbutils
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.bll.queue import QueueBLL
|
||||
from apiserver.bll.organization import OrgBLL, Tags
|
||||
from apiserver.bll.project import ProjectBLL
|
||||
from apiserver.bll.project import ProjectBLL, project_ids_with_children
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model.model import Model
|
||||
@ -37,7 +37,12 @@ from apiserver.timing_context import TimingContext
|
||||
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
|
||||
from .artifacts import artifacts_prepare_for_save
|
||||
from .param_utils import params_prepare_for_save
|
||||
from .utils import ChangeStatusRequest, validate_status_change, update_project_time, deleted_prefix
|
||||
from .utils import (
|
||||
ChangeStatusRequest,
|
||||
validate_status_change,
|
||||
update_project_time,
|
||||
deleted_prefix,
|
||||
)
|
||||
|
||||
log = config.logger(__file__)
|
||||
org_bll = OrgBLL()
|
||||
@ -317,12 +322,19 @@ class TaskBLL:
|
||||
cls.validate_execution_model(task)
|
||||
|
||||
@staticmethod
|
||||
def get_unique_metric_variants(company_id, project_ids=None):
|
||||
def get_unique_metric_variants(
|
||||
company_id, project_ids: Sequence[str], include_subprojects: bool
|
||||
):
|
||||
if project_ids:
|
||||
if include_subprojects:
|
||||
project_ids = project_ids_with_children(project_ids)
|
||||
project_constraint = {"project": {"$in": project_ids}}
|
||||
else:
|
||||
project_constraint = {}
|
||||
pipeline = [
|
||||
{
|
||||
"$match": dict(
|
||||
company={"$in": [None, "", company_id]},
|
||||
**({"project": {"$in": project_ids}} if project_ids else {}),
|
||||
company={"$in": [None, "", company_id]}, **project_constraint,
|
||||
)
|
||||
},
|
||||
{"$project": {"metrics": {"$objectToArray": "$last_metrics"}}},
|
||||
@ -601,11 +613,17 @@ class TaskBLL:
|
||||
@staticmethod
|
||||
def get_aggregated_project_parameters(
|
||||
company_id,
|
||||
project_ids: Sequence[str] = None,
|
||||
project_ids: Sequence[str],
|
||||
include_subprojects: bool,
|
||||
page: int = 0,
|
||||
page_size: int = 500,
|
||||
) -> Tuple[int, int, Sequence[dict]]:
|
||||
|
||||
if project_ids:
|
||||
if include_subprojects:
|
||||
project_ids = project_ids_with_children(project_ids)
|
||||
project_constraint = {"project": {"$in": project_ids}}
|
||||
else:
|
||||
project_constraint = {}
|
||||
page = max(0, page)
|
||||
page_size = max(1, page_size)
|
||||
pipeline = [
|
||||
@ -613,7 +631,7 @@ class TaskBLL:
|
||||
"$match": {
|
||||
"company": {"$in": [None, "", company_id]},
|
||||
"hyperparams": {"$exists": True, "$gt": {}},
|
||||
**({"project": {"$in": project_ids}} if project_ids else {}),
|
||||
**project_constraint,
|
||||
}
|
||||
},
|
||||
{"$project": {"sections": {"$objectToArray": "$hyperparams"}}},
|
||||
@ -687,6 +705,7 @@ class TaskBLL:
|
||||
project_ids: Sequence[str],
|
||||
section: str,
|
||||
name: str,
|
||||
include_subprojects: bool,
|
||||
allow_public: bool = True,
|
||||
) -> HyperParamValues:
|
||||
if allow_public:
|
||||
@ -694,6 +713,8 @@ class TaskBLL:
|
||||
else:
|
||||
company_constraint = {"company": company_id}
|
||||
if project_ids:
|
||||
if include_subprojects:
|
||||
project_ids = project_ids_with_children(project_ids)
|
||||
project_constraint = {"project": {"$in": project_ids}}
|
||||
else:
|
||||
project_constraint = {}
|
||||
|
@ -193,8 +193,7 @@ class GetMixin(PropsMixin):
|
||||
"""
|
||||
Pop the parameters that match the specified patterns and return
|
||||
the dictionary of matching parameters
|
||||
For backwards compatibility with the previous version of the code
|
||||
the None values are discarded
|
||||
Pop None parameters since they are not the real queries
|
||||
"""
|
||||
if not patterns:
|
||||
return {}
|
||||
@ -351,11 +350,7 @@ class GetMixin(PropsMixin):
|
||||
q = RegexQ()
|
||||
for action in filter(None, actions):
|
||||
q &= RegexQ(
|
||||
**{
|
||||
f"{mongoengine_field}__{action}": list(
|
||||
set(filter(None, actions[action]))
|
||||
)
|
||||
}
|
||||
**{f"{mongoengine_field}__{action}": list(set(actions[action]))}
|
||||
)
|
||||
|
||||
if not allow_empty:
|
||||
|
@ -36,7 +36,7 @@ class Project(AttributedDocument):
|
||||
min_length=3,
|
||||
sparse=True,
|
||||
)
|
||||
description = StringField(required=True)
|
||||
description = StringField()
|
||||
created = DateTimeField(required=True)
|
||||
tags = SafeSortedListField(StringField(required=True))
|
||||
system_tags = SafeSortedListField(StringField(required=True))
|
||||
|
@ -115,7 +115,7 @@ class Execution(EmbeddedDocument, ProperDictMixin):
|
||||
framework = StringField()
|
||||
artifacts: Dict[str, Artifact] = SafeMapField(field=EmbeddedDocumentField(Artifact))
|
||||
docker_cmd = StringField()
|
||||
queue = StringField()
|
||||
queue = StringField(reference_field="Queue")
|
||||
""" Queue ID where task was queued """
|
||||
|
||||
|
||||
|
@ -151,6 +151,17 @@ get_by_id_ex {
|
||||
get_all_ex {
|
||||
internal: true
|
||||
"2.1": ${get_all."2.1"}
|
||||
"2.13": ${get_all_ex."2.1"} {
|
||||
request {
|
||||
properties {
|
||||
include_subprojects {
|
||||
description: "If set to 'true' and project field is set then models from the subprojects are searched too"
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
get_all {
|
||||
"2.1" {
|
||||
|
@ -283,17 +283,14 @@ create {
|
||||
description: "Create a new project"
|
||||
request {
|
||||
type: object
|
||||
required :[
|
||||
name
|
||||
description
|
||||
]
|
||||
required :[name]
|
||||
properties {
|
||||
name {
|
||||
description: "Project name Unique within the company."
|
||||
type: string
|
||||
}
|
||||
description {
|
||||
description: "Project description. "
|
||||
description: "Project description."
|
||||
type: string
|
||||
}
|
||||
tags {
|
||||
@ -677,6 +674,17 @@ get_unique_metric_variants {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.13": ${get_unique_metric_variants."2.1"} {
|
||||
request {
|
||||
properties {
|
||||
include_subprojects {
|
||||
description: "If set to 'true' and the project field is set then the result includes metrics/variants from the subproject tasks"
|
||||
type: boolean
|
||||
default: true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
get_hyperparam_values {
|
||||
"2.13" {
|
||||
@ -702,6 +710,11 @@ get_hyperparam_values {
|
||||
description: "If set to 'true' then collect values from both company and public tasks otherwise company tasks only. The default is 'true'"
|
||||
type: boolean
|
||||
}
|
||||
include_subprojects {
|
||||
description: "If set to 'true' and the project field is set then the result includes hyper parameters values from the subproject tasks"
|
||||
type: boolean
|
||||
default: true
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
@ -761,6 +774,17 @@ get_hyper_parameters {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.13": ${get_hyper_parameters."2.9"} {
|
||||
request {
|
||||
properties {
|
||||
include_subprojects {
|
||||
description: "If set to 'true' and the project field is set then the result includes hyper parameters from the subproject tasks"
|
||||
type: boolean
|
||||
default: true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
get_task_tags {
|
||||
@ -830,7 +854,7 @@ make_private {
|
||||
}
|
||||
get_task_parents {
|
||||
"2.12" {
|
||||
description: "Get unique parent tasks for the tasks in the specified pprojects"
|
||||
description: "Get unique parent tasks for the tasks in the specified projects"
|
||||
request {
|
||||
type: object
|
||||
properties {
|
||||
@ -881,4 +905,15 @@ get_task_parents {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.13": ${get_task_parents."2.12"} {
|
||||
request {
|
||||
properties {
|
||||
include_subprojects {
|
||||
description: "If set to 'true' and the projects field is not empty then the result includes tasks parents from the subproject tasks"
|
||||
type: boolean
|
||||
default: true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -583,6 +583,17 @@ get_by_id_ex {
|
||||
get_all_ex {
|
||||
internal: true
|
||||
"2.1": ${get_all."2.1"}
|
||||
"2.13": ${get_all_ex."2.1"} {
|
||||
request {
|
||||
properties {
|
||||
include_subprojects {
|
||||
description: "If set to 'true' and project field is set then tasks from the subprojects are searched too"
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
get_all {
|
||||
"2.1" {
|
||||
|
@ -17,7 +17,7 @@ from apiserver.apimodels.models import (
|
||||
DeleteModelRequest,
|
||||
)
|
||||
from apiserver.bll.organization import OrgBLL, Tags
|
||||
from apiserver.bll.project import ProjectBLL
|
||||
from apiserver.bll.project import ProjectBLL, project_ids_with_children
|
||||
from apiserver.bll.task import TaskBLL
|
||||
from apiserver.bll.task.utils import deleted_prefix
|
||||
from apiserver.config_repo import config
|
||||
@ -86,10 +86,22 @@ def get_by_task_id(call: APICall, company_id, _):
|
||||
call.result.data = {"model": model_dict}
|
||||
|
||||
|
||||
def _process_include_subprojects(call_data: dict):
|
||||
include_subprojects = call_data.pop("include_subprojects", False)
|
||||
project_ids = call_data.get("project")
|
||||
if not project_ids or not include_subprojects:
|
||||
return
|
||||
|
||||
if not isinstance(project_ids, list):
|
||||
project_ids = [project_ids]
|
||||
call_data["project"] = project_ids_with_children(project_ids)
|
||||
|
||||
|
||||
@endpoint("models.get_all_ex", required_fields=[])
|
||||
def get_all_ex(call: APICall, company_id, _):
|
||||
conform_tag_fields(call, call.data)
|
||||
with translate_errors_context():
|
||||
_process_include_subprojects(call.data)
|
||||
with TimingContext("mongo", "models_get_all_ex"):
|
||||
models = Model.get_many_with_join(
|
||||
company=company_id, query_dict=call.data, allow_public=True
|
||||
|
@ -8,12 +8,14 @@ from apiserver.apierrors.errors.bad_request import InvalidProjectId
|
||||
from apiserver.apimodels.base import UpdateResponse, MakePublicRequest, IdResponse
|
||||
from apiserver.apimodels.projects import (
|
||||
GetHyperParamRequest,
|
||||
ProjectRequest,
|
||||
ProjectTagsRequest,
|
||||
ProjectTaskParentsRequest,
|
||||
ProjectHyperparamValuesRequest,
|
||||
ProjectsGetRequest,
|
||||
DeleteRequest, MoveRequest, MergeRequest,
|
||||
DeleteRequest,
|
||||
MoveRequest,
|
||||
MergeRequest,
|
||||
ProjectOrNoneRequest,
|
||||
)
|
||||
from apiserver.bll.organization import OrgBLL, Tags
|
||||
from apiserver.bll.project import ProjectBLL
|
||||
@ -80,14 +82,13 @@ def _adjust_search_parameters(data: dict, shallow_search: bool):
|
||||
return
|
||||
|
||||
if "parent" not in data:
|
||||
data["parent"] = [None, ""]
|
||||
data["parent"] = [None]
|
||||
|
||||
|
||||
@endpoint("projects.get_all_ex", request_data_model=ProjectsGetRequest)
|
||||
def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
|
||||
conform_tag_fields(call, call.data)
|
||||
allow_public = not request.non_public
|
||||
shallow_search = request.shallow_search or request.include_stats
|
||||
with TimingContext("mongo", "projects_get_all"):
|
||||
data = call.data
|
||||
if request.active_users:
|
||||
@ -102,12 +103,10 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
|
||||
return
|
||||
data["id"] = ids
|
||||
|
||||
_adjust_search_parameters(data, shallow_search=shallow_search)
|
||||
_adjust_search_parameters(data, shallow_search=request.shallow_search)
|
||||
|
||||
projects = Project.get_many_with_join(
|
||||
company=company_id,
|
||||
query_dict=data,
|
||||
allow_public=allow_public,
|
||||
company=company_id, query_dict=data, allow_public=allow_public,
|
||||
)
|
||||
|
||||
conform_output_tags(call, projects)
|
||||
@ -147,9 +146,7 @@ def get_all(call: APICall):
|
||||
|
||||
|
||||
@endpoint(
|
||||
"projects.create",
|
||||
required_fields=["name", "description"],
|
||||
response_data_model=IdResponse,
|
||||
"projects.create", required_fields=["name"], response_data_model=IdResponse,
|
||||
)
|
||||
def create(call: APICall):
|
||||
identity = call.identity
|
||||
@ -232,11 +229,17 @@ def delete(call: APICall, company_id: str, request: DeleteRequest):
|
||||
call.result.data = {**attr.asdict(res)}
|
||||
|
||||
|
||||
@endpoint("projects.get_unique_metric_variants", request_data_model=ProjectRequest)
|
||||
def get_unique_metric_variants(call: APICall, company_id: str, request: ProjectRequest):
|
||||
@endpoint(
|
||||
"projects.get_unique_metric_variants", request_data_model=ProjectOrNoneRequest
|
||||
)
|
||||
def get_unique_metric_variants(
|
||||
call: APICall, company_id: str, request: ProjectOrNoneRequest
|
||||
):
|
||||
|
||||
metrics = task_bll.get_unique_metric_variants(
|
||||
company_id, [request.project] if request.project else None
|
||||
company_id,
|
||||
[request.project] if request.project else None,
|
||||
include_subprojects=request.include_subprojects,
|
||||
)
|
||||
|
||||
call.result.data = {"metrics": metrics}
|
||||
@ -252,6 +255,7 @@ def get_hyper_parameters(call: APICall, company_id: str, request: GetHyperParamR
|
||||
total, remaining, parameters = TaskBLL.get_aggregated_project_parameters(
|
||||
company_id,
|
||||
project_ids=[request.project] if request.project else None,
|
||||
include_subprojects=request.include_subprojects,
|
||||
page=request.page,
|
||||
page_size=request.page_size,
|
||||
)
|
||||
@ -276,6 +280,7 @@ def get_hyperparam_values(
|
||||
project_ids=request.projects,
|
||||
section=request.section,
|
||||
name=request.name,
|
||||
include_subprojects=request.include_subprojects,
|
||||
allow_public=request.allow_public,
|
||||
)
|
||||
call.result.data = {
|
||||
@ -340,6 +345,9 @@ def get_task_parents(
|
||||
):
|
||||
call.result.data = {
|
||||
"parents": project_bll.get_task_parents(
|
||||
company_id, projects=request.projects, state=request.tasks_state
|
||||
company_id,
|
||||
projects=request.projects,
|
||||
include_subprojects=request.include_subprojects,
|
||||
state=request.tasks_state,
|
||||
)
|
||||
}
|
||||
|
@ -47,7 +47,7 @@ from apiserver.apimodels.tasks import (
|
||||
)
|
||||
from apiserver.bll.event import EventBLL
|
||||
from apiserver.bll.organization import OrgBLL, Tags
|
||||
from apiserver.bll.project import ProjectBLL
|
||||
from apiserver.bll.project import ProjectBLL, project_ids_with_children
|
||||
from apiserver.bll.queue import QueueBLL
|
||||
from apiserver.bll.task import (
|
||||
TaskBLL,
|
||||
@ -152,27 +152,49 @@ def get_by_id(call: APICall, company_id, req_model: TaskRequest):
|
||||
call.result.data = {"task": task_dict}
|
||||
|
||||
|
||||
def escape_execution_parameters(call: APICall):
|
||||
projection = Task.get_projection(call.data)
|
||||
if projection:
|
||||
Task.set_projection(call.data, escape_paths(projection))
|
||||
def escape_execution_parameters(call: APICall) -> dict:
|
||||
if not call.data:
|
||||
return call.data
|
||||
|
||||
ordering = Task.get_ordering(call.data)
|
||||
keys = list(call.data)
|
||||
call_data = {
|
||||
safe_key: call.data[key] for key, safe_key in zip(keys, escape_paths(keys))
|
||||
}
|
||||
|
||||
projection = Task.get_projection(call_data)
|
||||
if projection:
|
||||
Task.set_projection(call_data, escape_paths(projection))
|
||||
|
||||
ordering = Task.get_ordering(call_data)
|
||||
if ordering:
|
||||
Task.set_ordering(call.data, escape_paths(ordering))
|
||||
Task.set_ordering(call_data, escape_paths(ordering))
|
||||
|
||||
return call_data
|
||||
|
||||
|
||||
def _process_include_subprojects(call_data: dict):
|
||||
include_subprojects = call_data.pop("include_subprojects", False)
|
||||
project_ids = call_data.get("project")
|
||||
if not project_ids or not include_subprojects:
|
||||
return
|
||||
|
||||
if not isinstance(project_ids, list):
|
||||
project_ids = [project_ids]
|
||||
call_data["project"] = project_ids_with_children(project_ids)
|
||||
|
||||
|
||||
@endpoint("tasks.get_all_ex", required_fields=[])
|
||||
def get_all_ex(call: APICall, company_id, _):
|
||||
conform_tag_fields(call, call.data)
|
||||
|
||||
escape_execution_parameters(call)
|
||||
call_data = escape_execution_parameters(call)
|
||||
|
||||
with translate_errors_context():
|
||||
with TimingContext("mongo", "task_get_all_ex"):
|
||||
_process_include_subprojects(call_data)
|
||||
tasks = Task.get_many_with_join(
|
||||
company=company_id,
|
||||
query_dict=call.data,
|
||||
query_dict=call_data,
|
||||
allow_public=True, # required in case projection is requested for public dataset/versions
|
||||
)
|
||||
unprepare_from_saved(call, tasks)
|
||||
@ -183,12 +205,12 @@ def get_all_ex(call: APICall, company_id, _):
|
||||
def get_by_id_ex(call: APICall, company_id, _):
|
||||
conform_tag_fields(call, call.data)
|
||||
|
||||
escape_execution_parameters(call)
|
||||
call_data = escape_execution_parameters(call)
|
||||
|
||||
with translate_errors_context():
|
||||
with TimingContext("mongo", "task_get_by_id_ex"):
|
||||
tasks = Task.get_many_with_join(
|
||||
company=company_id, query_dict=call.data, allow_public=True,
|
||||
company=company_id, query_dict=call_data, allow_public=True,
|
||||
)
|
||||
|
||||
unprepare_from_saved(call, tasks)
|
||||
@ -199,14 +221,14 @@ def get_by_id_ex(call: APICall, company_id, _):
|
||||
def get_all(call: APICall, company_id, _):
|
||||
conform_tag_fields(call, call.data)
|
||||
|
||||
escape_execution_parameters(call)
|
||||
call_data = escape_execution_parameters(call)
|
||||
|
||||
with translate_errors_context():
|
||||
with TimingContext("mongo", "task_get_all"):
|
||||
tasks = Task.get_many(
|
||||
company=company_id,
|
||||
parameters=call.data,
|
||||
query_dict=call.data,
|
||||
parameters=call_data,
|
||||
query_dict=call_data,
|
||||
allow_public=True, # required in case projection is requested for public dataset/versions
|
||||
)
|
||||
unprepare_from_saved(call, tasks)
|
||||
@ -216,7 +238,9 @@ def get_all(call: APICall, company_id, _):
|
||||
@endpoint("tasks.get_types", request_data_model=GetTypesRequest)
|
||||
def get_types(call: APICall, company_id, request: GetTypesRequest):
|
||||
call.result.data = {
|
||||
"types": list(project_bll.get_task_types(company_id, project_ids=request.projects))
|
||||
"types": list(
|
||||
project_bll.get_task_types(company_id, project_ids=request.projects)
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
|
@ -1,5 +1,4 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
@ -10,16 +9,14 @@ import requests
|
||||
import six
|
||||
from boltons.iterutils import remap
|
||||
from boltons.typeutils import issubclass
|
||||
from pyhocon import ConfigFactory
|
||||
from requests.adapters import HTTPAdapter
|
||||
from requests.auth import HTTPBasicAuth
|
||||
from requests.packages.urllib3.util.retry import Retry
|
||||
|
||||
from apiserver.apierrors.base import BaseError
|
||||
from apiserver.config_repo import config
|
||||
|
||||
config = ConfigFactory.parse_file("api_client.conf")
|
||||
|
||||
log = logging.getLogger("api_client")
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
class APICallResult:
|
||||
@ -111,7 +108,7 @@ class APIClient:
|
||||
self.api_key = (
|
||||
api_key
|
||||
or os.environ.get("SM_API_KEY")
|
||||
or config.get("api_key")
|
||||
or config.get("apiclient.api_key")
|
||||
)
|
||||
if not self.api_key:
|
||||
raise ValueError("APIClient requires api_key in constructor or config")
|
||||
@ -119,7 +116,7 @@ class APIClient:
|
||||
self.secret_key = (
|
||||
secret_key
|
||||
or os.environ.get("SM_API_SECRET")
|
||||
or config.get("secret_key")
|
||||
or config.get("apiclient.secret_key")
|
||||
)
|
||||
if not self.secret_key:
|
||||
raise ValueError(
|
||||
@ -127,7 +124,7 @@ class APIClient:
|
||||
)
|
||||
|
||||
self.base_url = (
|
||||
base_url or os.environ.get("SM_API_URL") or config.get("base_url")
|
||||
base_url or os.environ.get("SM_API_URL") or config.get("apiclient.base_url")
|
||||
)
|
||||
if not self.base_url:
|
||||
raise ValueError("APIClient requires base_url in constructor or config")
|
||||
@ -139,9 +136,9 @@ class APIClient:
|
||||
|
||||
# create http session
|
||||
self.http_session = requests.session()
|
||||
retries = config.get("retries", 7)
|
||||
backoff_factor = config.get("backoff_factor", 0.3)
|
||||
status_forcelist = config.get("status_forcelist", (500, 502, 504))
|
||||
retries = config.get("apiclient.retries", 7)
|
||||
backoff_factor = config.get("apiclient.backoff_factor", 0.3)
|
||||
status_forcelist = config.get("apiclient.status_forcelist", (500, 502, 504))
|
||||
retry = Retry(
|
||||
total=retries,
|
||||
read=retries,
|
||||
|
@ -1,65 +0,0 @@
|
||||
from boltons.iterutils import first
|
||||
|
||||
from apiserver.tests.automated import TestService
|
||||
|
||||
|
||||
class TestProjectsRetrieval(TestService):
|
||||
def setUp(self, **kwargs):
|
||||
super().setUp(version="2.13")
|
||||
|
||||
def test_active_user(self):
|
||||
user = self.api.users.get_current_user().user.id
|
||||
project1 = self.temp_project(name="Project retrieval1")
|
||||
project2 = self.temp_project(name="Project retrieval2")
|
||||
self.temp_task(project=project2)
|
||||
|
||||
projects = self.api.projects.get_all_ex().projects
|
||||
self.assertTrue({project1, project2}.issubset({p.id for p in projects}))
|
||||
|
||||
projects = self.api.projects.get_all_ex(active_users=[user]).projects
|
||||
ids = {p.id for p in projects}
|
||||
self.assertFalse(project1 in ids)
|
||||
self.assertTrue(project2 in ids)
|
||||
|
||||
def test_stats(self):
|
||||
project = self.temp_project()
|
||||
self.temp_task(project=project)
|
||||
self.temp_task(project=project)
|
||||
archived_task = self.temp_task(project=project)
|
||||
self.api.tasks.archive(tasks=[archived_task])
|
||||
|
||||
p = self._get_project(project)
|
||||
self.assertFalse("stats" in p)
|
||||
|
||||
p = self._get_project(project, include_stats=True)
|
||||
self.assertFalse("archived" in p.stats)
|
||||
self.assertTrue(p.stats.active.status_count.created, 2)
|
||||
|
||||
p = self._get_project(project, include_stats=True, stats_for_state=None)
|
||||
self.assertTrue(p.stats.active.status_count.created, 2)
|
||||
self.assertTrue(p.stats.archived.status_count.created, 1)
|
||||
|
||||
def _get_project(self, project, **kwargs):
|
||||
projects = self.api.projects.get_all_ex(id=[project], **kwargs).projects
|
||||
p = first(p for p in projects if p.id == project)
|
||||
self.assertIsNotNone(p)
|
||||
return p
|
||||
|
||||
def temp_project(self, **kwargs) -> str:
|
||||
self.update_missing(
|
||||
kwargs,
|
||||
name="Test projects retrieval",
|
||||
description="test",
|
||||
delete_params=dict(force=True),
|
||||
)
|
||||
return self.create_temp("projects", **kwargs)
|
||||
|
||||
def temp_task(self, **kwargs) -> str:
|
||||
self.update_missing(
|
||||
kwargs,
|
||||
type="testing",
|
||||
name="test projects retrieval",
|
||||
input=dict(view=dict()),
|
||||
delete_params=dict(force=True),
|
||||
)
|
||||
return self.create_temp("tasks", **kwargs)
|
@ -4,8 +4,10 @@ from typing import Sequence, Optional, Tuple
|
||||
from boltons.iterutils import first
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.model import EntityVisibility
|
||||
from apiserver.database.utils import id as db_id
|
||||
from apiserver.tests.api_client import APIClient
|
||||
from apiserver.tests.automated import TestService
|
||||
|
||||
|
||||
@ -14,7 +16,14 @@ class TestSubProjects(TestService):
|
||||
super().setUp(version="2.13")
|
||||
|
||||
def test_project_aggregations(self):
|
||||
child = self._temp_project(name="Aggregation/Pr1")
|
||||
"""This test requires user with user_auth_only... credentials in db"""
|
||||
user2_client = APIClient(
|
||||
api_key=config.get("apiclient.user_auth_only"),
|
||||
secret_key=config.get("apiclient.user_auth_only_secret"),
|
||||
base_url=f"http://localhost:8008/v2.13",
|
||||
)
|
||||
|
||||
child = self._temp_project(name="Aggregation/Pr1", client=user2_client)
|
||||
project = self.api.projects.get_all_ex(name="^Aggregation$").projects[0].id
|
||||
child_project = self.api.projects.get_all_ex(id=[child]).projects[0]
|
||||
self.assertEqual(child_project.parent.id, project)
|
||||
@ -210,12 +219,13 @@ class TestSubProjects(TestService):
|
||||
|
||||
delete_params = dict(can_fail=True, force=True)
|
||||
|
||||
def _temp_project(self, name, **kwargs):
|
||||
def _temp_project(self, name, client=None, **kwargs):
|
||||
return self.create_temp(
|
||||
"projects",
|
||||
delete_params=self.delete_params,
|
||||
name=name,
|
||||
description="",
|
||||
client=client,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
@ -225,4 +225,4 @@ class TestTasksEdit(TestService):
|
||||
self.api.tasks.enqueue(task=task_id, queue=queue_id)
|
||||
task = self.api.tasks.get_all_ex(id=task_id, projection=projection).tasks[0]
|
||||
self.assertEqual(task.status, "queued")
|
||||
self.assertEqual(task.execution.queue, queue_id)
|
||||
self.assertEqual(task.execution.queue.id, queue_id)
|
||||
|
Loading…
Reference in New Issue
Block a user