More sub-projects support and fixes

This commit is contained in:
allegroai 2021-05-03 17:44:54 +03:00
parent 0d5174c453
commit 3c5195028e
17 changed files with 225 additions and 142 deletions

View File

@ -1,3 +1,8 @@
301 {
_: "moved_permanently"
1: ["not_supported", "this endpoint is no longer supported for the requested API version"]
}
400 {
_: "bad_request"
1: ["not_supported", "endpoint is not supported"]

View File

@ -22,7 +22,12 @@ class DeleteRequest(ProjectRequest):
delete_contents = fields.BoolField(default=False)
class GetHyperParamRequest(ProjectRequest):
class ProjectOrNoneRequest(models.Base):
project = fields.StringField()
include_subprojects = fields.BoolField(default=True)
class GetHyperParamRequest(ProjectOrNoneRequest):
page = fields.IntField(default=0)
page_size = fields.IntField(default=500)
@ -33,6 +38,7 @@ class ProjectTagsRequest(TagsRequest):
class MultiProjectRequest(models.Base):
projects = fields.ListField(str)
include_subprojects = fields.BoolField(default=True)
class ProjectTaskParentsRequest(MultiProjectRequest):

View File

@ -147,7 +147,7 @@ class ProjectBLL:
user: str,
company: str,
name: str,
description: str,
description: str = "",
tags: Sequence[str] = None,
system_tags: Sequence[str] = None,
default_output_destination: str = None,
@ -507,20 +507,24 @@ class ProjectBLL:
company,
project_ids: Sequence[str],
user_ids: Optional[Sequence[str]] = None,
) -> set:
) -> Set[str]:
"""
Get the set of user ids that created tasks/models/dataviews in the given projects
If project_ids is empty then all projects are examined
If user_ids are passed then only subset of these users is returned
"""
with TimingContext("mongo", "active_users_in_projects"):
res = set()
query = Q(company=company)
if user_ids:
query &= Q(user__in=user_ids)
projects_query = query
if project_ids:
project_ids = _ids_with_children(project_ids)
query &= Q(project__in=project_ids)
if user_ids:
query &= Q(user__in=user_ids)
projects_query &= Q(id__in=project_ids)
res = set(Project.objects(projects_query).distinct(field="user"))
for cls_ in (Task, Model):
res |= set(cls_.objects(query).distinct(field="user"))
@ -545,10 +549,17 @@ class ProjectBLL:
else:
query &= Q(company=company)
user_projects_query = query
if project_ids:
query &= Q(project__in=_ids_with_children(project_ids))
ids_with_children = _ids_with_children(project_ids)
query &= Q(project__in=ids_with_children)
user_projects_query &= Q(id__in=ids_with_children)
res = Task.objects(query).distinct(field="project")
res = {p.id for p in Project.objects(user_projects_query).only("id")}
for cls_ in (Task, Model):
res |= set(cls_.objects(query).distinct(field="project"))
res = list(res)
if not res:
return res
@ -563,6 +574,7 @@ class ProjectBLL:
cls,
company_id: str,
projects: Sequence[str],
include_subprojects: bool,
state: Optional[EntityVisibility] = None,
) -> Sequence[dict]:
"""
@ -571,7 +583,8 @@ class ProjectBLL:
"""
query = Q(company=company_id)
if projects:
projects = _ids_with_children(projects)
if include_subprojects:
projects = _ids_with_children(projects)
query &= Q(project__in=projects)
if state == EntityVisibility.archived:
query &= Q(system_tags__in=[EntityVisibility.archived.value])

View File

@ -13,7 +13,7 @@ import apiserver.database.utils as dbutils
from apiserver.apierrors import errors
from apiserver.bll.queue import QueueBLL
from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL
from apiserver.bll.project import ProjectBLL, project_ids_with_children
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.model import Model
@ -37,7 +37,12 @@ from apiserver.timing_context import TimingContext
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
from .artifacts import artifacts_prepare_for_save
from .param_utils import params_prepare_for_save
from .utils import ChangeStatusRequest, validate_status_change, update_project_time, deleted_prefix
from .utils import (
ChangeStatusRequest,
validate_status_change,
update_project_time,
deleted_prefix,
)
log = config.logger(__file__)
org_bll = OrgBLL()
@ -317,12 +322,19 @@ class TaskBLL:
cls.validate_execution_model(task)
@staticmethod
def get_unique_metric_variants(company_id, project_ids=None):
def get_unique_metric_variants(
company_id, project_ids: Sequence[str], include_subprojects: bool
):
if project_ids:
if include_subprojects:
project_ids = project_ids_with_children(project_ids)
project_constraint = {"project": {"$in": project_ids}}
else:
project_constraint = {}
pipeline = [
{
"$match": dict(
company={"$in": [None, "", company_id]},
**({"project": {"$in": project_ids}} if project_ids else {}),
company={"$in": [None, "", company_id]}, **project_constraint,
)
},
{"$project": {"metrics": {"$objectToArray": "$last_metrics"}}},
@ -601,11 +613,17 @@ class TaskBLL:
@staticmethod
def get_aggregated_project_parameters(
company_id,
project_ids: Sequence[str] = None,
project_ids: Sequence[str],
include_subprojects: bool,
page: int = 0,
page_size: int = 500,
) -> Tuple[int, int, Sequence[dict]]:
if project_ids:
if include_subprojects:
project_ids = project_ids_with_children(project_ids)
project_constraint = {"project": {"$in": project_ids}}
else:
project_constraint = {}
page = max(0, page)
page_size = max(1, page_size)
pipeline = [
@ -613,7 +631,7 @@ class TaskBLL:
"$match": {
"company": {"$in": [None, "", company_id]},
"hyperparams": {"$exists": True, "$gt": {}},
**({"project": {"$in": project_ids}} if project_ids else {}),
**project_constraint,
}
},
{"$project": {"sections": {"$objectToArray": "$hyperparams"}}},
@ -687,6 +705,7 @@ class TaskBLL:
project_ids: Sequence[str],
section: str,
name: str,
include_subprojects: bool,
allow_public: bool = True,
) -> HyperParamValues:
if allow_public:
@ -694,6 +713,8 @@ class TaskBLL:
else:
company_constraint = {"company": company_id}
if project_ids:
if include_subprojects:
project_ids = project_ids_with_children(project_ids)
project_constraint = {"project": {"$in": project_ids}}
else:
project_constraint = {}

View File

@ -193,8 +193,7 @@ class GetMixin(PropsMixin):
"""
Pop the parameters that match the specified patterns and return
the dictionary of matching parameters
For backwards compatibility with the previous version of the code
the None values are discarded
Pop None parameters since they are not the real queries
"""
if not patterns:
return {}
@ -351,11 +350,7 @@ class GetMixin(PropsMixin):
q = RegexQ()
for action in filter(None, actions):
q &= RegexQ(
**{
f"{mongoengine_field}__{action}": list(
set(filter(None, actions[action]))
)
}
**{f"{mongoengine_field}__{action}": list(set(actions[action]))}
)
if not allow_empty:

View File

@ -36,7 +36,7 @@ class Project(AttributedDocument):
min_length=3,
sparse=True,
)
description = StringField(required=True)
description = StringField()
created = DateTimeField(required=True)
tags = SafeSortedListField(StringField(required=True))
system_tags = SafeSortedListField(StringField(required=True))

View File

@ -115,7 +115,7 @@ class Execution(EmbeddedDocument, ProperDictMixin):
framework = StringField()
artifacts: Dict[str, Artifact] = SafeMapField(field=EmbeddedDocumentField(Artifact))
docker_cmd = StringField()
queue = StringField()
queue = StringField(reference_field="Queue")
""" Queue ID where task was queued """

View File

@ -151,6 +151,17 @@ get_by_id_ex {
get_all_ex {
internal: true
"2.1": ${get_all."2.1"}
"2.13": ${get_all_ex."2.1"} {
request {
properties {
include_subprojects {
description: "If set to 'true' and project field is set then models from the subprojects are searched too"
type: boolean
default: false
}
}
}
}
}
get_all {
"2.1" {

View File

@ -283,17 +283,14 @@ create {
description: "Create a new project"
request {
type: object
required :[
name
description
]
required :[name]
properties {
name {
description: "Project name. Unique within the company."
type: string
}
description {
description: "Project description. "
description: "Project description."
type: string
}
tags {
@ -677,6 +674,17 @@ get_unique_metric_variants {
}
}
}
"2.13": ${get_unique_metric_variants."2.1"} {
request {
properties {
include_subprojects {
description: "If set to 'true' and the project field is set then the result includes metrics/variants from the subproject tasks"
type: boolean
default: true
}
}
}
}
}
get_hyperparam_values {
"2.13" {
@ -702,6 +710,11 @@ get_hyperparam_values {
description: "If set to 'true' then collect values from both company and public tasks otherwise company tasks only. The default is 'true'"
type: boolean
}
include_subprojects {
description: "If set to 'true' and the project field is set then the result includes hyper parameters values from the subproject tasks"
type: boolean
default: true
}
}
}
response {
@ -761,6 +774,17 @@ get_hyper_parameters {
}
}
}
"2.13": ${get_hyper_parameters."2.9"} {
request {
properties {
include_subprojects {
description: "If set to 'true' and the project field is set then the result includes hyper parameters from the subproject tasks"
type: boolean
default: true
}
}
}
}
}
get_task_tags {
@ -830,7 +854,7 @@ make_private {
}
get_task_parents {
"2.12" {
description: "Get unique parent tasks for the tasks in the specified pprojects"
description: "Get unique parent tasks for the tasks in the specified projects"
request {
type: object
properties {
@ -881,4 +905,15 @@ get_task_parents {
}
}
}
"2.13": ${get_task_parents."2.12"} {
request {
properties {
include_subprojects {
description: "If set to 'true' and the projects field is not empty then the result includes tasks parents from the subproject tasks"
type: boolean
default: true
}
}
}
}
}

View File

@ -583,6 +583,17 @@ get_by_id_ex {
get_all_ex {
internal: true
"2.1": ${get_all."2.1"}
"2.13": ${get_all_ex."2.1"} {
request {
properties {
include_subprojects {
description: "If set to 'true' and project field is set then tasks from the subprojects are searched too"
type: boolean
default: false
}
}
}
}
}
get_all {
"2.1" {

View File

@ -17,7 +17,7 @@ from apiserver.apimodels.models import (
DeleteModelRequest,
)
from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL
from apiserver.bll.project import ProjectBLL, project_ids_with_children
from apiserver.bll.task import TaskBLL
from apiserver.bll.task.utils import deleted_prefix
from apiserver.config_repo import config
@ -86,10 +86,22 @@ def get_by_task_id(call: APICall, company_id, _):
call.result.data = {"model": model_dict}
def _process_include_subprojects(call_data: dict):
include_subprojects = call_data.pop("include_subprojects", False)
project_ids = call_data.get("project")
if not project_ids or not include_subprojects:
return
if not isinstance(project_ids, list):
project_ids = [project_ids]
call_data["project"] = project_ids_with_children(project_ids)
@endpoint("models.get_all_ex", required_fields=[])
def get_all_ex(call: APICall, company_id, _):
conform_tag_fields(call, call.data)
with translate_errors_context():
_process_include_subprojects(call.data)
with TimingContext("mongo", "models_get_all_ex"):
models = Model.get_many_with_join(
company=company_id, query_dict=call.data, allow_public=True

View File

@ -8,12 +8,14 @@ from apiserver.apierrors.errors.bad_request import InvalidProjectId
from apiserver.apimodels.base import UpdateResponse, MakePublicRequest, IdResponse
from apiserver.apimodels.projects import (
GetHyperParamRequest,
ProjectRequest,
ProjectTagsRequest,
ProjectTaskParentsRequest,
ProjectHyperparamValuesRequest,
ProjectsGetRequest,
DeleteRequest, MoveRequest, MergeRequest,
DeleteRequest,
MoveRequest,
MergeRequest,
ProjectOrNoneRequest,
)
from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL
@ -80,14 +82,13 @@ def _adjust_search_parameters(data: dict, shallow_search: bool):
return
if "parent" not in data:
data["parent"] = [None, ""]
data["parent"] = [None]
@endpoint("projects.get_all_ex", request_data_model=ProjectsGetRequest)
def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
conform_tag_fields(call, call.data)
allow_public = not request.non_public
shallow_search = request.shallow_search or request.include_stats
with TimingContext("mongo", "projects_get_all"):
data = call.data
if request.active_users:
@ -102,12 +103,10 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
return
data["id"] = ids
_adjust_search_parameters(data, shallow_search=shallow_search)
_adjust_search_parameters(data, shallow_search=request.shallow_search)
projects = Project.get_many_with_join(
company=company_id,
query_dict=data,
allow_public=allow_public,
company=company_id, query_dict=data, allow_public=allow_public,
)
conform_output_tags(call, projects)
@ -147,9 +146,7 @@ def get_all(call: APICall):
@endpoint(
"projects.create",
required_fields=["name", "description"],
response_data_model=IdResponse,
"projects.create", required_fields=["name"], response_data_model=IdResponse,
)
def create(call: APICall):
identity = call.identity
@ -232,11 +229,17 @@ def delete(call: APICall, company_id: str, request: DeleteRequest):
call.result.data = {**attr.asdict(res)}
@endpoint("projects.get_unique_metric_variants", request_data_model=ProjectRequest)
def get_unique_metric_variants(call: APICall, company_id: str, request: ProjectRequest):
@endpoint(
"projects.get_unique_metric_variants", request_data_model=ProjectOrNoneRequest
)
def get_unique_metric_variants(
call: APICall, company_id: str, request: ProjectOrNoneRequest
):
metrics = task_bll.get_unique_metric_variants(
company_id, [request.project] if request.project else None
company_id,
[request.project] if request.project else None,
include_subprojects=request.include_subprojects,
)
call.result.data = {"metrics": metrics}
@ -252,6 +255,7 @@ def get_hyper_parameters(call: APICall, company_id: str, request: GetHyperParamR
total, remaining, parameters = TaskBLL.get_aggregated_project_parameters(
company_id,
project_ids=[request.project] if request.project else None,
include_subprojects=request.include_subprojects,
page=request.page,
page_size=request.page_size,
)
@ -276,6 +280,7 @@ def get_hyperparam_values(
project_ids=request.projects,
section=request.section,
name=request.name,
include_subprojects=request.include_subprojects,
allow_public=request.allow_public,
)
call.result.data = {
@ -340,6 +345,9 @@ def get_task_parents(
):
call.result.data = {
"parents": project_bll.get_task_parents(
company_id, projects=request.projects, state=request.tasks_state
company_id,
projects=request.projects,
include_subprojects=request.include_subprojects,
state=request.tasks_state,
)
}

View File

@ -47,7 +47,7 @@ from apiserver.apimodels.tasks import (
)
from apiserver.bll.event import EventBLL
from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL
from apiserver.bll.project import ProjectBLL, project_ids_with_children
from apiserver.bll.queue import QueueBLL
from apiserver.bll.task import (
TaskBLL,
@ -152,27 +152,49 @@ def get_by_id(call: APICall, company_id, req_model: TaskRequest):
call.result.data = {"task": task_dict}
def escape_execution_parameters(call: APICall):
projection = Task.get_projection(call.data)
if projection:
Task.set_projection(call.data, escape_paths(projection))
def escape_execution_parameters(call: APICall) -> dict:
if not call.data:
return call.data
ordering = Task.get_ordering(call.data)
keys = list(call.data)
call_data = {
safe_key: call.data[key] for key, safe_key in zip(keys, escape_paths(keys))
}
projection = Task.get_projection(call_data)
if projection:
Task.set_projection(call_data, escape_paths(projection))
ordering = Task.get_ordering(call_data)
if ordering:
Task.set_ordering(call.data, escape_paths(ordering))
Task.set_ordering(call_data, escape_paths(ordering))
return call_data
def _process_include_subprojects(call_data: dict):
include_subprojects = call_data.pop("include_subprojects", False)
project_ids = call_data.get("project")
if not project_ids or not include_subprojects:
return
if not isinstance(project_ids, list):
project_ids = [project_ids]
call_data["project"] = project_ids_with_children(project_ids)
@endpoint("tasks.get_all_ex", required_fields=[])
def get_all_ex(call: APICall, company_id, _):
conform_tag_fields(call, call.data)
escape_execution_parameters(call)
call_data = escape_execution_parameters(call)
with translate_errors_context():
with TimingContext("mongo", "task_get_all_ex"):
_process_include_subprojects(call_data)
tasks = Task.get_many_with_join(
company=company_id,
query_dict=call.data,
query_dict=call_data,
allow_public=True, # required in case projection is requested for public dataset/versions
)
unprepare_from_saved(call, tasks)
@ -183,12 +205,12 @@ def get_all_ex(call: APICall, company_id, _):
def get_by_id_ex(call: APICall, company_id, _):
conform_tag_fields(call, call.data)
escape_execution_parameters(call)
call_data = escape_execution_parameters(call)
with translate_errors_context():
with TimingContext("mongo", "task_get_by_id_ex"):
tasks = Task.get_many_with_join(
company=company_id, query_dict=call.data, allow_public=True,
company=company_id, query_dict=call_data, allow_public=True,
)
unprepare_from_saved(call, tasks)
@ -199,14 +221,14 @@ def get_by_id_ex(call: APICall, company_id, _):
def get_all(call: APICall, company_id, _):
conform_tag_fields(call, call.data)
escape_execution_parameters(call)
call_data = escape_execution_parameters(call)
with translate_errors_context():
with TimingContext("mongo", "task_get_all"):
tasks = Task.get_many(
company=company_id,
parameters=call.data,
query_dict=call.data,
parameters=call_data,
query_dict=call_data,
allow_public=True, # required in case projection is requested for public dataset/versions
)
unprepare_from_saved(call, tasks)
@ -216,7 +238,9 @@ def get_all(call: APICall, company_id, _):
@endpoint("tasks.get_types", request_data_model=GetTypesRequest)
def get_types(call: APICall, company_id, request: GetTypesRequest):
call.result.data = {
"types": list(project_bll.get_task_types(company_id, project_ids=request.projects))
"types": list(
project_bll.get_task_types(company_id, project_ids=request.projects)
)
}

View File

@ -1,5 +1,4 @@
import json
import logging
import os
import time
from contextlib import contextmanager
@ -10,16 +9,14 @@ import requests
import six
from boltons.iterutils import remap
from boltons.typeutils import issubclass
from pyhocon import ConfigFactory
from requests.adapters import HTTPAdapter
from requests.auth import HTTPBasicAuth
from requests.packages.urllib3.util.retry import Retry
from apiserver.apierrors.base import BaseError
from apiserver.config_repo import config
config = ConfigFactory.parse_file("api_client.conf")
log = logging.getLogger("api_client")
log = config.logger(__file__)
class APICallResult:
@ -111,7 +108,7 @@ class APIClient:
self.api_key = (
api_key
or os.environ.get("SM_API_KEY")
or config.get("api_key")
or config.get("apiclient.api_key")
)
if not self.api_key:
raise ValueError("APIClient requires api_key in constructor or config")
@ -119,7 +116,7 @@ class APIClient:
self.secret_key = (
secret_key
or os.environ.get("SM_API_SECRET")
or config.get("secret_key")
or config.get("apiclient.secret_key")
)
if not self.secret_key:
raise ValueError(
@ -127,7 +124,7 @@ class APIClient:
)
self.base_url = (
base_url or os.environ.get("SM_API_URL") or config.get("base_url")
base_url or os.environ.get("SM_API_URL") or config.get("apiclient.base_url")
)
if not self.base_url:
raise ValueError("APIClient requires base_url in constructor or config")
@ -139,9 +136,9 @@ class APIClient:
# create http session
self.http_session = requests.session()
retries = config.get("retries", 7)
backoff_factor = config.get("backoff_factor", 0.3)
status_forcelist = config.get("status_forcelist", (500, 502, 504))
retries = config.get("apiclient.retries", 7)
backoff_factor = config.get("apiclient.backoff_factor", 0.3)
status_forcelist = config.get("apiclient.status_forcelist", (500, 502, 504))
retry = Retry(
total=retries,
read=retries,

View File

@ -1,65 +0,0 @@
from boltons.iterutils import first
from apiserver.tests.automated import TestService
class TestProjectsRetrieval(TestService):
    """Automated tests for retrieving projects through projects.get_all_ex."""

    def setUp(self, **kwargs):
        # All calls in this suite are issued against API version 2.13
        super().setUp(version="2.13")

    def test_active_user(self):
        """
        A project with a task of the current user is returned by the
        active_users filter while a project without tasks is not.
        """
        user = self.api.users.get_current_user().user.id
        project1 = self.temp_project(name="Project retrieval1")
        project2 = self.temp_project(name="Project retrieval2")
        self.temp_task(project=project2)

        # without the filter both projects are returned
        projects = self.api.projects.get_all_ex().projects
        self.assertTrue({project1, project2}.issubset({p.id for p in projects}))

        # with the filter only the project that contains the user task is returned
        projects = self.api.projects.get_all_ex(active_users=[user]).projects
        ids = {p.id for p in projects}
        self.assertNotIn(project1, ids)
        self.assertIn(project2, ids)

    def test_stats(self):
        """
        Stats are returned only on request; archived stats are returned only
        when stats_for_state is explicitly passed as None.
        """
        project = self.temp_project()
        self.temp_task(project=project)
        self.temp_task(project=project)
        archived_task = self.temp_task(project=project)
        self.api.tasks.archive(tasks=[archived_task])

        p = self._get_project(project)
        self.assertNotIn("stats", p)

        # assertEqual is required here: assertTrue(value, 2) would take 2 as
        # the failure message and pass for any truthy count
        p = self._get_project(project, include_stats=True)
        self.assertNotIn("archived", p.stats)
        self.assertEqual(p.stats.active.status_count.created, 2)

        p = self._get_project(project, include_stats=True, stats_for_state=None)
        self.assertEqual(p.stats.active.status_count.created, 2)
        self.assertEqual(p.stats.archived.status_count.created, 1)

    def _get_project(self, project, **kwargs):
        """Return the project with the given id as reported by projects.get_all_ex."""
        projects = self.api.projects.get_all_ex(id=[project], **kwargs).projects
        p = first(p for p in projects if p.id == project)
        self.assertIsNotNone(p)
        return p

    def temp_project(self, **kwargs) -> str:
        """Create a temporary project; passed kwargs override the defaults below."""
        self.update_missing(
            kwargs,
            name="Test projects retrieval",
            description="test",
            delete_params=dict(force=True),
        )
        return self.create_temp("projects", **kwargs)

    def temp_task(self, **kwargs) -> str:
        """Create a temporary task; passed kwargs override the defaults below."""
        self.update_missing(
            kwargs,
            type="testing",
            name="test projects retrieval",
            input=dict(view=dict()),
            delete_params=dict(force=True),
        )
        return self.create_temp("tasks", **kwargs)

View File

@ -4,8 +4,10 @@ from typing import Sequence, Optional, Tuple
from boltons.iterutils import first
from apiserver.apierrors import errors
from apiserver.config_repo import config
from apiserver.database.model import EntityVisibility
from apiserver.database.utils import id as db_id
from apiserver.tests.api_client import APIClient
from apiserver.tests.automated import TestService
@ -14,7 +16,14 @@ class TestSubProjects(TestService):
super().setUp(version="2.13")
def test_project_aggregations(self):
child = self._temp_project(name="Aggregation/Pr1")
"""This test requires user with user_auth_only... credentials in db"""
user2_client = APIClient(
api_key=config.get("apiclient.user_auth_only"),
secret_key=config.get("apiclient.user_auth_only_secret"),
base_url=f"http://localhost:8008/v2.13",
)
child = self._temp_project(name="Aggregation/Pr1", client=user2_client)
project = self.api.projects.get_all_ex(name="^Aggregation$").projects[0].id
child_project = self.api.projects.get_all_ex(id=[child]).projects[0]
self.assertEqual(child_project.parent.id, project)
@ -210,12 +219,13 @@ class TestSubProjects(TestService):
delete_params = dict(can_fail=True, force=True)
def _temp_project(self, name, **kwargs):
def _temp_project(self, name, client=None, **kwargs):
return self.create_temp(
"projects",
delete_params=self.delete_params,
name=name,
description="",
client=client,
**kwargs,
)

View File

@ -225,4 +225,4 @@ class TestTasksEdit(TestService):
self.api.tasks.enqueue(task=task_id, queue=queue_id)
task = self.api.tasks.get_all_ex(id=task_id, projection=projection).tasks[0]
self.assertEqual(task.status, "queued")
self.assertEqual(task.execution.queue, queue_id)
self.assertEqual(task.execution.queue.id, queue_id)