mirror of
https://github.com/clearml/clearml-server
synced 2025-06-23 08:45:30 +00:00
Add filtering on child projects in projects.get_all_ex
This commit is contained in:
parent
2fb9288a6c
commit
74200a24bd
@ -1,4 +1,5 @@
|
|||||||
from jsonmodels import models, fields
|
from jsonmodels import models, fields
|
||||||
|
from jsonmodels.fields import EmbeddedField
|
||||||
|
|
||||||
from apiserver.apimodels import ListField, ActualEnumField, DictField
|
from apiserver.apimodels import ListField, ActualEnumField, DictField
|
||||||
from apiserver.apimodels.organization import TagsRequest
|
from apiserver.apimodels.organization import TagsRequest
|
||||||
@ -56,6 +57,10 @@ class ProjectModelMetadataValuesRequest(MultiProjectRequest):
|
|||||||
allow_public = fields.BoolField(default=True)
|
allow_public = fields.BoolField(default=True)
|
||||||
|
|
||||||
|
|
||||||
|
class ChildrenCondition(models.Base):
|
||||||
|
system_tags = fields.ListField([str])
|
||||||
|
|
||||||
|
|
||||||
class ProjectsGetRequest(models.Base):
|
class ProjectsGetRequest(models.Base):
|
||||||
include_dataset_stats = fields.BoolField(default=False)
|
include_dataset_stats = fields.BoolField(default=False)
|
||||||
include_stats = fields.BoolField(default=False)
|
include_stats = fields.BoolField(default=False)
|
||||||
@ -68,3 +73,4 @@ class ProjectsGetRequest(models.Base):
|
|||||||
shallow_search = fields.BoolField(default=False)
|
shallow_search = fields.BoolField(default=False)
|
||||||
search_hidden = fields.BoolField(default=False)
|
search_hidden = fields.BoolField(default=False)
|
||||||
allow_public = fields.BoolField(default=True)
|
allow_public = fields.BoolField(default=True)
|
||||||
|
children_condition = EmbeddedField(ChildrenCondition)
|
||||||
|
@ -571,7 +571,7 @@ class ProjectBLL:
|
|||||||
search_hidden: bool = False,
|
search_hidden: bool = False,
|
||||||
filter_: Mapping[str, Any] = None,
|
filter_: Mapping[str, Any] = None,
|
||||||
users: Sequence[str] = None,
|
users: Sequence[str] = None,
|
||||||
user_active_project_ids: Sequence[str] = None,
|
selected_project_ids: Sequence[str] = None,
|
||||||
) -> Tuple[Dict[str, dict], Dict[str, dict]]:
|
) -> Tuple[Dict[str, dict], Dict[str, dict]]:
|
||||||
if not project_ids:
|
if not project_ids:
|
||||||
return {}, {}
|
return {}, {}
|
||||||
@ -581,7 +581,7 @@ class ProjectBLL:
|
|||||||
project_ids,
|
project_ids,
|
||||||
_only=("id", "name"),
|
_only=("id", "name"),
|
||||||
search_hidden=search_hidden,
|
search_hidden=search_hidden,
|
||||||
allowed_ids=user_active_project_ids,
|
allowed_ids=selected_project_ids,
|
||||||
)
|
)
|
||||||
if include_children
|
if include_children
|
||||||
else {}
|
else {}
|
||||||
@ -753,46 +753,65 @@ class ProjectBLL:
|
|||||||
return tags, system_tags
|
return tags, system_tags
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_projects_with_active_user(
|
def get_projects_with_selected_children(
|
||||||
cls,
|
cls,
|
||||||
company: str,
|
company: str,
|
||||||
users: Sequence[str],
|
users: Sequence[str] = None,
|
||||||
project_ids: Optional[Sequence[str]] = None,
|
project_ids: Optional[Sequence[str]] = None,
|
||||||
allow_public: bool = True,
|
allow_public: bool = True,
|
||||||
|
children_condition: Mapping[str, Any] = None,
|
||||||
) -> Tuple[Sequence[str], Sequence[str]]:
|
) -> Tuple[Sequence[str], Sequence[str]]:
|
||||||
"""
|
"""
|
||||||
Get the projects ids where user created any tasks including all the parents of these projects
|
Get the projects ids matching children_condition (if passed) or where the passed user created any tasks
|
||||||
|
including all the parents of these projects
|
||||||
If project ids are specified then filter the results by these project ids
|
If project ids are specified then filter the results by these project ids
|
||||||
"""
|
"""
|
||||||
query = Q(user__in=users)
|
if not (users or children_condition):
|
||||||
|
raise errors.bad_request.ValidationError(
|
||||||
|
"Either active users or children_condition should be specified"
|
||||||
|
)
|
||||||
|
|
||||||
if allow_public:
|
projects_query = Project.prepare_query(
|
||||||
query &= get_company_or_none_constraint(company)
|
company, parameters=children_condition, allow_public=allow_public
|
||||||
|
)
|
||||||
|
if children_condition:
|
||||||
|
contained_entities_query = None
|
||||||
else:
|
else:
|
||||||
query &= Q(company=company)
|
contained_entities_query = (
|
||||||
|
get_company_or_none_constraint(company)
|
||||||
|
if allow_public
|
||||||
|
else Q(company=company)
|
||||||
|
)
|
||||||
|
|
||||||
|
if users:
|
||||||
|
user_query = Q(user__in=users)
|
||||||
|
projects_query &= user_query
|
||||||
|
if contained_entities_query:
|
||||||
|
contained_entities_query &= user_query
|
||||||
|
|
||||||
user_projects_query = query
|
|
||||||
if project_ids:
|
if project_ids:
|
||||||
ids_with_children = _ids_with_children(project_ids)
|
ids_with_children = _ids_with_children(project_ids)
|
||||||
query &= Q(project__in=ids_with_children)
|
projects_query &= Q(id__in=ids_with_children)
|
||||||
user_projects_query &= Q(id__in=ids_with_children)
|
if contained_entities_query:
|
||||||
|
contained_entities_query &= Q(project__in=ids_with_children)
|
||||||
|
|
||||||
res = {p.id for p in Project.objects(user_projects_query).only("id")}
|
res = {p.id for p in Project.objects(projects_query).only("id")}
|
||||||
|
if contained_entities_query:
|
||||||
for cls_ in (Task, Model):
|
for cls_ in (Task, Model):
|
||||||
res |= set(cls_.objects(query).distinct(field="project"))
|
res |= set(cls_.objects(contained_entities_query).distinct(field="project"))
|
||||||
|
|
||||||
res = list(res)
|
res = list(res)
|
||||||
if not res:
|
if not res:
|
||||||
return res, res
|
return res, res
|
||||||
|
|
||||||
user_active_project_ids = _ids_with_parents(res)
|
selected_project_ids = _ids_with_parents(res)
|
||||||
filtered_ids = (
|
filtered_ids = (
|
||||||
list(set(user_active_project_ids) & set(project_ids))
|
list(set(selected_project_ids) & set(project_ids))
|
||||||
if project_ids
|
if project_ids
|
||||||
else list(user_active_project_ids)
|
else list(selected_project_ids)
|
||||||
)
|
)
|
||||||
|
|
||||||
return filtered_ids, user_active_project_ids
|
return filtered_ids, selected_project_ids
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_task_parents(
|
def get_task_parents(
|
||||||
|
@ -41,10 +41,6 @@
|
|||||||
# controls whether FieldDoesNotExist exception will be raised for any extra attribute existing in stored data
|
# controls whether FieldDoesNotExist exception will be raised for any extra attribute existing in stored data
|
||||||
# but not declared in a data model
|
# but not declared in a data model
|
||||||
strict: false
|
strict: false
|
||||||
|
|
||||||
aggregate {
|
|
||||||
allow_disk_use: true
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
elastic {
|
elastic {
|
||||||
|
@ -2,3 +2,8 @@ max_page_size: 500
|
|||||||
|
|
||||||
# expiration time in seconds for the redis scroll states in get_many family of apis
|
# expiration time in seconds for the redis scroll states in get_many family of apis
|
||||||
scroll_state_expiration_seconds: 600
|
scroll_state_expiration_seconds: 600
|
||||||
|
|
||||||
|
allow_disk_use {
|
||||||
|
# sort: true
|
||||||
|
aggregate: true
|
||||||
|
}
|
@ -17,7 +17,7 @@ from typing import (
|
|||||||
|
|
||||||
from boltons.iterutils import first, partition
|
from boltons.iterutils import first, partition
|
||||||
from dateutil.parser import parse as parse_datetime
|
from dateutil.parser import parse as parse_datetime
|
||||||
from mongoengine import Q, Document, ListField, StringField, IntField
|
from mongoengine import Q, Document, ListField, StringField, IntField, QuerySet
|
||||||
from pymongo.command_cursor import CommandCursor
|
from pymongo.command_cursor import CommandCursor
|
||||||
|
|
||||||
from apiserver.apierrors import errors, APIError
|
from apiserver.apierrors import errors, APIError
|
||||||
@ -39,7 +39,7 @@ from apiserver.redis_manager import redman
|
|||||||
from apiserver.utilities.dicts import project_dict, exclude_fields_from_dict
|
from apiserver.utilities.dicts import project_dict, exclude_fields_from_dict
|
||||||
|
|
||||||
log = config.logger("dbmodel")
|
log = config.logger("dbmodel")
|
||||||
|
mongo_conf = config.get("services._mongo")
|
||||||
ACCESS_REGEX = re.compile(r"^(?P<prefix>>=|>|<=|<)?(?P<value>.*)$")
|
ACCESS_REGEX = re.compile(r"^(?P<prefix>>=|>|<=|<)?(?P<value>.*)$")
|
||||||
ACCESS_MODIFIER = {">=": "gte", ">": "gt", "<=": "lte", "<": "lt"}
|
ACCESS_MODIFIER = {">=": "gte", ">": "gt", "<=": "lte", "<": "lt"}
|
||||||
|
|
||||||
@ -158,7 +158,9 @@ class GetMixin(PropsMixin):
|
|||||||
def _get_op(self, v: str, translate: bool = False) -> Optional[str]:
|
def _get_op(self, v: str, translate: bool = False) -> Optional[str]:
|
||||||
try:
|
try:
|
||||||
op = (
|
op = (
|
||||||
v[len(self.op_prefix) :] if v and v.startswith(self.op_prefix) else None
|
v[len(self.op_prefix) :]
|
||||||
|
if v and v.startswith(self.op_prefix)
|
||||||
|
else None
|
||||||
)
|
)
|
||||||
if translate:
|
if translate:
|
||||||
tup = self._ops.get(op, None)
|
tup = self._ops.get(op, None)
|
||||||
@ -166,7 +168,9 @@ class GetMixin(PropsMixin):
|
|||||||
return op
|
return op
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
raise errors.bad_request.FieldsValueError(
|
raise errors.bad_request.FieldsValueError(
|
||||||
"invalid value type, string expected", field=self._field, value=str(v)
|
"invalid value type, string expected",
|
||||||
|
field=self._field,
|
||||||
|
value=str(v),
|
||||||
)
|
)
|
||||||
|
|
||||||
def _key(self, v) -> Optional[Union[str, bool]]:
|
def _key(self, v) -> Optional[Union[str, bool]]:
|
||||||
@ -233,8 +237,8 @@ class GetMixin(PropsMixin):
|
|||||||
cls._cache_manager = RedisCacheManager(
|
cls._cache_manager = RedisCacheManager(
|
||||||
state_class=cls.GetManyScrollState,
|
state_class=cls.GetManyScrollState,
|
||||||
redis=redman.connection("apiserver"),
|
redis=redman.connection("apiserver"),
|
||||||
expiration_interval=config.get(
|
expiration_interval=mongo_conf.get(
|
||||||
"services._mongo.scroll_state_expiration_seconds", 600
|
"scroll_state_expiration_seconds", 600
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -451,7 +455,9 @@ class GetMixin(PropsMixin):
|
|||||||
raise
|
raise
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
raise errors.bad_request.FieldsValueError(
|
raise errors.bad_request.FieldsValueError(
|
||||||
"failed parsing query field", error=str(ex), **({"field": field} if field else {})
|
"failed parsing query field",
|
||||||
|
error=str(ex),
|
||||||
|
**({"field": field} if field else {}),
|
||||||
)
|
)
|
||||||
|
|
||||||
return query & RegexQ(**dict_query)
|
return query & RegexQ(**dict_query)
|
||||||
@ -570,7 +576,7 @@ class GetMixin(PropsMixin):
|
|||||||
if start is not None:
|
if start is not None:
|
||||||
return start, cls.validate_scroll_size(parameters)
|
return start, cls.validate_scroll_size(parameters)
|
||||||
|
|
||||||
max_page_size = config.get("services._mongo.max_page_size", 500)
|
max_page_size = mongo_conf.get("max_page_size", 500)
|
||||||
page = parameters.get("page", default_page)
|
page = parameters.get("page", default_page)
|
||||||
if page is not None and page < 0:
|
if page is not None and page < 0:
|
||||||
raise errors.bad_request.ValidationError("page must be >=0", field="page")
|
raise errors.bad_request.ValidationError("page must be >=0", field="page")
|
||||||
@ -880,6 +886,13 @@ class GetMixin(PropsMixin):
|
|||||||
|
|
||||||
return cls._get_many_no_company(query=_query, override_projection=projection)
|
return cls._get_many_no_company(query=_query, override_projection=projection)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_qs_with_ordering(qs: QuerySet, order_by: Sequence):
|
||||||
|
disk_use_setting = mongo_conf.get("allow_disk_use.sort", None)
|
||||||
|
if disk_use_setting is not None:
|
||||||
|
qs = qs.allow_disk_use(disk_use_setting)
|
||||||
|
return qs.order_by(*order_by)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _get_many_no_company(
|
def _get_many_no_company(
|
||||||
cls: Union["GetMixin", Document],
|
cls: Union["GetMixin", Document],
|
||||||
@ -1173,7 +1186,7 @@ class DbModelMixin(GetMixin, ProperDictMixin, UpdateMixin):
|
|||||||
kwargs.update(
|
kwargs.update(
|
||||||
allowDiskUse=allow_disk_use
|
allowDiskUse=allow_disk_use
|
||||||
if allow_disk_use is not None
|
if allow_disk_use is not None
|
||||||
else config.get("apiserver.mongo.aggregate.allow_disk_use", True)
|
else mongo_conf.get("allow_disk_use.aggregate", True)
|
||||||
)
|
)
|
||||||
return cls.objects.aggregate(pipeline, **kwargs)
|
return cls.objects.aggregate(pipeline, **kwargs)
|
||||||
|
|
||||||
|
@ -19,6 +19,7 @@ from apiserver.database.fields import (
|
|||||||
SafeSortedListField,
|
SafeSortedListField,
|
||||||
EmbeddedDocumentListField,
|
EmbeddedDocumentListField,
|
||||||
NullableStringField,
|
NullableStringField,
|
||||||
|
NoneType,
|
||||||
)
|
)
|
||||||
from apiserver.database.model import AttributedDocument
|
from apiserver.database.model import AttributedDocument
|
||||||
from apiserver.database.model.base import ProperDictMixin, GetMixin
|
from apiserver.database.model.base import ProperDictMixin, GetMixin
|
||||||
@ -89,7 +90,9 @@ class Artifact(EmbeddedDocument):
|
|||||||
content_size = LongField()
|
content_size = LongField()
|
||||||
timestamp = LongField()
|
timestamp = LongField()
|
||||||
type_data = EmbeddedDocumentField(ArtifactTypeData)
|
type_data = EmbeddedDocumentField(ArtifactTypeData)
|
||||||
display_data = SafeSortedListField(ListField(UnionField((int, float, str))))
|
display_data = SafeSortedListField(
|
||||||
|
ListField(UnionField((int, float, str, NoneType)))
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ParamsItem(EmbeddedDocument, ProperDictMixin):
|
class ParamsItem(EmbeddedDocument, ProperDictMixin):
|
||||||
@ -231,6 +234,7 @@ class Task(AttributedDocument):
|
|||||||
range_fields=("started", "active_duration", "last_metrics.*", "last_iteration"),
|
range_fields=("started", "active_duration", "last_metrics.*", "last_iteration"),
|
||||||
datetime_fields=("status_changed", "last_update"),
|
datetime_fields=("status_changed", "last_update"),
|
||||||
pattern_fields=("name", "comment", "report"),
|
pattern_fields=("name", "comment", "report"),
|
||||||
|
fields=("execution.queue", "runtime.*", "models.input.model"),
|
||||||
)
|
)
|
||||||
|
|
||||||
id = StringField(primary_key=True)
|
id = StringField(primary_key=True)
|
||||||
|
@ -620,6 +620,19 @@ get_all_ex {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
"2.24": ${get_all_ex."2.23"} {
|
||||||
|
request.properties.children_condition {
|
||||||
|
description: The filter that any of the child projects should match in order that the parent will be included
|
||||||
|
type: object
|
||||||
|
properties {
|
||||||
|
system_tags {
|
||||||
|
description: The list of system tags to match from
|
||||||
|
type: string
|
||||||
|
}
|
||||||
|
}
|
||||||
|
additionalProperties: true
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
update {
|
update {
|
||||||
"2.1" {
|
"2.1" {
|
||||||
|
@ -76,7 +76,7 @@ def get_entities_count(call: APICall, company, request: EntitiesCountRequest):
|
|||||||
requested_ids = data.get("id")
|
requested_ids = data.get("id")
|
||||||
if isinstance(requested_ids, str):
|
if isinstance(requested_ids, str):
|
||||||
requested_ids = [requested_ids]
|
requested_ids = [requested_ids]
|
||||||
ids, _ = project_bll.get_projects_with_active_user(
|
ids, _ = project_bll.get_projects_with_selected_children(
|
||||||
company=company,
|
company=company,
|
||||||
users=request.active_users,
|
users=request.active_users,
|
||||||
project_ids=requested_ids,
|
project_ids=requested_ids,
|
||||||
|
@ -114,13 +114,16 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
|
|||||||
_adjust_search_parameters(
|
_adjust_search_parameters(
|
||||||
data, shallow_search=request.shallow_search,
|
data, shallow_search=request.shallow_search,
|
||||||
)
|
)
|
||||||
user_active_project_ids = None
|
selected_project_ids = None
|
||||||
if request.active_users:
|
if request.active_users or request.children_condition:
|
||||||
ids, user_active_project_ids = project_bll.get_projects_with_active_user(
|
ids, selected_project_ids = project_bll.get_projects_with_selected_children(
|
||||||
company=company_id,
|
company=company_id,
|
||||||
users=request.active_users,
|
users=request.active_users,
|
||||||
project_ids=requested_ids,
|
project_ids=requested_ids,
|
||||||
allow_public=allow_public,
|
allow_public=allow_public,
|
||||||
|
children_condition=request.children_condition.to_struct()
|
||||||
|
if request.children_condition
|
||||||
|
else None,
|
||||||
)
|
)
|
||||||
if not ids:
|
if not ids:
|
||||||
return {"projects": []}
|
return {"projects": []}
|
||||||
@ -158,7 +161,7 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
|
|||||||
search_hidden=request.search_hidden,
|
search_hidden=request.search_hidden,
|
||||||
filter_=request.include_stats_filter,
|
filter_=request.include_stats_filter,
|
||||||
users=request.active_users,
|
users=request.active_users,
|
||||||
user_active_project_ids=user_active_project_ids,
|
selected_project_ids=selected_project_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
for project in projects:
|
for project in projects:
|
||||||
|
@ -33,6 +33,31 @@ class TestSubProjects(TestService):
|
|||||||
).projects[0]
|
).projects[0]
|
||||||
self.assertEqual(data.dataset_stats, {"file_count": 2, "total_size": 1000})
|
self.assertEqual(data.dataset_stats, {"file_count": 2, "total_size": 1000})
|
||||||
|
|
||||||
|
def test_query_children(self):
|
||||||
|
test_root_name = "TestQueryChildren"
|
||||||
|
test_root = self._temp_project(name=test_root_name)
|
||||||
|
child_with_tag = self._temp_project(
|
||||||
|
name=f"{test_root_name}/Project1/WithTag", system_tags=["test"]
|
||||||
|
)
|
||||||
|
child_without_tag = self._temp_project(name=f"{test_root_name}/Project2/WithoutTag")
|
||||||
|
|
||||||
|
projects = self.api.projects.get_all_ex(parent=[test_root], shallow_search=True).projects
|
||||||
|
self.assertEqual({p.basename for p in projects}, {"Project1", "Project2"})
|
||||||
|
|
||||||
|
projects = self.api.projects.get_all_ex(
|
||||||
|
parent=[test_root], children_condition={"system_tags": ["test"]}, shallow_search=True
|
||||||
|
).projects
|
||||||
|
self.assertEqual({p.basename for p in projects}, {"Project1"})
|
||||||
|
projects = self.api.projects.get_all_ex(
|
||||||
|
parent=[projects[0].id], children_condition={"system_tags": ["test"]}, shallow_search=True
|
||||||
|
).projects
|
||||||
|
self.assertEqual(projects[0].id, child_with_tag)
|
||||||
|
|
||||||
|
projects = self.api.projects.get_all_ex(
|
||||||
|
parent=[test_root], children_condition={"system_tags": ["not existent"]}, shallow_search=True
|
||||||
|
).projects
|
||||||
|
self.assertEqual(len(projects), 0)
|
||||||
|
|
||||||
def test_project_aggregations(self):
|
def test_project_aggregations(self):
|
||||||
"""This test requires user with user_auth_only... credentials in db"""
|
"""This test requires user with user_auth_only... credentials in db"""
|
||||||
user2_client = APIClient(
|
user2_client = APIClient(
|
||||||
|
Loading…
Reference in New Issue
Block a user