mirror of
https://github.com/clearml/clearml-server
synced 2025-03-09 21:51:54 +00:00
Add filtering on child projects in projects.get_all_ex
This commit is contained in:
parent
2fb9288a6c
commit
74200a24bd
@ -1,4 +1,5 @@
|
||||
from jsonmodels import models, fields
|
||||
from jsonmodels.fields import EmbeddedField
|
||||
|
||||
from apiserver.apimodels import ListField, ActualEnumField, DictField
|
||||
from apiserver.apimodels.organization import TagsRequest
|
||||
@ -56,6 +57,10 @@ class ProjectModelMetadataValuesRequest(MultiProjectRequest):
|
||||
allow_public = fields.BoolField(default=True)
|
||||
|
||||
|
||||
class ChildrenCondition(models.Base):
|
||||
system_tags = fields.ListField([str])
|
||||
|
||||
|
||||
class ProjectsGetRequest(models.Base):
|
||||
include_dataset_stats = fields.BoolField(default=False)
|
||||
include_stats = fields.BoolField(default=False)
|
||||
@ -68,3 +73,4 @@ class ProjectsGetRequest(models.Base):
|
||||
shallow_search = fields.BoolField(default=False)
|
||||
search_hidden = fields.BoolField(default=False)
|
||||
allow_public = fields.BoolField(default=True)
|
||||
children_condition = EmbeddedField(ChildrenCondition)
|
||||
|
@ -571,7 +571,7 @@ class ProjectBLL:
|
||||
search_hidden: bool = False,
|
||||
filter_: Mapping[str, Any] = None,
|
||||
users: Sequence[str] = None,
|
||||
user_active_project_ids: Sequence[str] = None,
|
||||
selected_project_ids: Sequence[str] = None,
|
||||
) -> Tuple[Dict[str, dict], Dict[str, dict]]:
|
||||
if not project_ids:
|
||||
return {}, {}
|
||||
@ -581,7 +581,7 @@ class ProjectBLL:
|
||||
project_ids,
|
||||
_only=("id", "name"),
|
||||
search_hidden=search_hidden,
|
||||
allowed_ids=user_active_project_ids,
|
||||
allowed_ids=selected_project_ids,
|
||||
)
|
||||
if include_children
|
||||
else {}
|
||||
@ -753,46 +753,65 @@ class ProjectBLL:
|
||||
return tags, system_tags
|
||||
|
||||
@classmethod
|
||||
def get_projects_with_active_user(
|
||||
def get_projects_with_selected_children(
|
||||
cls,
|
||||
company: str,
|
||||
users: Sequence[str],
|
||||
users: Sequence[str] = None,
|
||||
project_ids: Optional[Sequence[str]] = None,
|
||||
allow_public: bool = True,
|
||||
children_condition: Mapping[str, Any] = None,
|
||||
) -> Tuple[Sequence[str], Sequence[str]]:
|
||||
"""
|
||||
Get the projects ids where user created any tasks including all the parents of these projects
|
||||
Get the projects ids matching children_condition (if passed) or where the passed user created any tasks
|
||||
including all the parents of these projects
|
||||
If project ids are specified then filter the results by these project ids
|
||||
"""
|
||||
query = Q(user__in=users)
|
||||
if not (users or children_condition):
|
||||
raise errors.bad_request.ValidationError(
|
||||
"Either active users or children_condition should be specified"
|
||||
)
|
||||
|
||||
if allow_public:
|
||||
query &= get_company_or_none_constraint(company)
|
||||
projects_query = Project.prepare_query(
|
||||
company, parameters=children_condition, allow_public=allow_public
|
||||
)
|
||||
if children_condition:
|
||||
contained_entities_query = None
|
||||
else:
|
||||
query &= Q(company=company)
|
||||
contained_entities_query = (
|
||||
get_company_or_none_constraint(company)
|
||||
if allow_public
|
||||
else Q(company=company)
|
||||
)
|
||||
|
||||
if users:
|
||||
user_query = Q(user__in=users)
|
||||
projects_query &= user_query
|
||||
if contained_entities_query:
|
||||
contained_entities_query &= user_query
|
||||
|
||||
user_projects_query = query
|
||||
if project_ids:
|
||||
ids_with_children = _ids_with_children(project_ids)
|
||||
query &= Q(project__in=ids_with_children)
|
||||
user_projects_query &= Q(id__in=ids_with_children)
|
||||
projects_query &= Q(id__in=ids_with_children)
|
||||
if contained_entities_query:
|
||||
contained_entities_query &= Q(project__in=ids_with_children)
|
||||
|
||||
res = {p.id for p in Project.objects(user_projects_query).only("id")}
|
||||
for cls_ in (Task, Model):
|
||||
res |= set(cls_.objects(query).distinct(field="project"))
|
||||
res = {p.id for p in Project.objects(projects_query).only("id")}
|
||||
if contained_entities_query:
|
||||
for cls_ in (Task, Model):
|
||||
res |= set(cls_.objects(contained_entities_query).distinct(field="project"))
|
||||
|
||||
res = list(res)
|
||||
if not res:
|
||||
return res, res
|
||||
|
||||
user_active_project_ids = _ids_with_parents(res)
|
||||
selected_project_ids = _ids_with_parents(res)
|
||||
filtered_ids = (
|
||||
list(set(user_active_project_ids) & set(project_ids))
|
||||
list(set(selected_project_ids) & set(project_ids))
|
||||
if project_ids
|
||||
else list(user_active_project_ids)
|
||||
else list(selected_project_ids)
|
||||
)
|
||||
|
||||
return filtered_ids, user_active_project_ids
|
||||
return filtered_ids, selected_project_ids
|
||||
|
||||
@classmethod
|
||||
def get_task_parents(
|
||||
|
@ -41,10 +41,6 @@
|
||||
# controls whether FieldDoesNotExist exception will be raised for any extra attribute existing in stored data
|
||||
# but not declared in a data model
|
||||
strict: false
|
||||
|
||||
aggregate {
|
||||
allow_disk_use: true
|
||||
}
|
||||
}
|
||||
|
||||
elastic {
|
||||
|
@ -2,3 +2,8 @@ max_page_size: 500
|
||||
|
||||
# expiration time in seconds for the redis scroll states in get_many family of apis
|
||||
scroll_state_expiration_seconds: 600
|
||||
|
||||
allow_disk_use {
|
||||
# sort: true
|
||||
aggregate: true
|
||||
}
|
@ -17,7 +17,7 @@ from typing import (
|
||||
|
||||
from boltons.iterutils import first, partition
|
||||
from dateutil.parser import parse as parse_datetime
|
||||
from mongoengine import Q, Document, ListField, StringField, IntField
|
||||
from mongoengine import Q, Document, ListField, StringField, IntField, QuerySet
|
||||
from pymongo.command_cursor import CommandCursor
|
||||
|
||||
from apiserver.apierrors import errors, APIError
|
||||
@ -39,7 +39,7 @@ from apiserver.redis_manager import redman
|
||||
from apiserver.utilities.dicts import project_dict, exclude_fields_from_dict
|
||||
|
||||
log = config.logger("dbmodel")
|
||||
|
||||
mongo_conf = config.get("services._mongo")
|
||||
ACCESS_REGEX = re.compile(r"^(?P<prefix>>=|>|<=|<)?(?P<value>.*)$")
|
||||
ACCESS_MODIFIER = {">=": "gte", ">": "gt", "<=": "lte", "<": "lt"}
|
||||
|
||||
@ -158,7 +158,9 @@ class GetMixin(PropsMixin):
|
||||
def _get_op(self, v: str, translate: bool = False) -> Optional[str]:
|
||||
try:
|
||||
op = (
|
||||
v[len(self.op_prefix) :] if v and v.startswith(self.op_prefix) else None
|
||||
v[len(self.op_prefix) :]
|
||||
if v and v.startswith(self.op_prefix)
|
||||
else None
|
||||
)
|
||||
if translate:
|
||||
tup = self._ops.get(op, None)
|
||||
@ -166,7 +168,9 @@ class GetMixin(PropsMixin):
|
||||
return op
|
||||
except AttributeError:
|
||||
raise errors.bad_request.FieldsValueError(
|
||||
"invalid value type, string expected", field=self._field, value=str(v)
|
||||
"invalid value type, string expected",
|
||||
field=self._field,
|
||||
value=str(v),
|
||||
)
|
||||
|
||||
def _key(self, v) -> Optional[Union[str, bool]]:
|
||||
@ -233,8 +237,8 @@ class GetMixin(PropsMixin):
|
||||
cls._cache_manager = RedisCacheManager(
|
||||
state_class=cls.GetManyScrollState,
|
||||
redis=redman.connection("apiserver"),
|
||||
expiration_interval=config.get(
|
||||
"services._mongo.scroll_state_expiration_seconds", 600
|
||||
expiration_interval=mongo_conf.get(
|
||||
"scroll_state_expiration_seconds", 600
|
||||
),
|
||||
)
|
||||
|
||||
@ -451,7 +455,9 @@ class GetMixin(PropsMixin):
|
||||
raise
|
||||
except Exception as ex:
|
||||
raise errors.bad_request.FieldsValueError(
|
||||
"failed parsing query field", error=str(ex), **({"field": field} if field else {})
|
||||
"failed parsing query field",
|
||||
error=str(ex),
|
||||
**({"field": field} if field else {}),
|
||||
)
|
||||
|
||||
return query & RegexQ(**dict_query)
|
||||
@ -570,7 +576,7 @@ class GetMixin(PropsMixin):
|
||||
if start is not None:
|
||||
return start, cls.validate_scroll_size(parameters)
|
||||
|
||||
max_page_size = config.get("services._mongo.max_page_size", 500)
|
||||
max_page_size = mongo_conf.get("max_page_size", 500)
|
||||
page = parameters.get("page", default_page)
|
||||
if page is not None and page < 0:
|
||||
raise errors.bad_request.ValidationError("page must be >=0", field="page")
|
||||
@ -880,6 +886,13 @@ class GetMixin(PropsMixin):
|
||||
|
||||
return cls._get_many_no_company(query=_query, override_projection=projection)
|
||||
|
||||
@staticmethod
|
||||
def _get_qs_with_ordering(qs: QuerySet, order_by: Sequence):
|
||||
disk_use_setting = mongo_conf.get("allow_disk_use.sort", None)
|
||||
if disk_use_setting is not None:
|
||||
qs = qs.allow_disk_use(disk_use_setting)
|
||||
return qs.order_by(*order_by)
|
||||
|
||||
@classmethod
|
||||
def _get_many_no_company(
|
||||
cls: Union["GetMixin", Document],
|
||||
@ -1173,7 +1186,7 @@ class DbModelMixin(GetMixin, ProperDictMixin, UpdateMixin):
|
||||
kwargs.update(
|
||||
allowDiskUse=allow_disk_use
|
||||
if allow_disk_use is not None
|
||||
else config.get("apiserver.mongo.aggregate.allow_disk_use", True)
|
||||
else mongo_conf.get("allow_disk_use.aggregate", True)
|
||||
)
|
||||
return cls.objects.aggregate(pipeline, **kwargs)
|
||||
|
||||
|
@ -19,6 +19,7 @@ from apiserver.database.fields import (
|
||||
SafeSortedListField,
|
||||
EmbeddedDocumentListField,
|
||||
NullableStringField,
|
||||
NoneType,
|
||||
)
|
||||
from apiserver.database.model import AttributedDocument
|
||||
from apiserver.database.model.base import ProperDictMixin, GetMixin
|
||||
@ -89,7 +90,9 @@ class Artifact(EmbeddedDocument):
|
||||
content_size = LongField()
|
||||
timestamp = LongField()
|
||||
type_data = EmbeddedDocumentField(ArtifactTypeData)
|
||||
display_data = SafeSortedListField(ListField(UnionField((int, float, str))))
|
||||
display_data = SafeSortedListField(
|
||||
ListField(UnionField((int, float, str, NoneType)))
|
||||
)
|
||||
|
||||
|
||||
class ParamsItem(EmbeddedDocument, ProperDictMixin):
|
||||
@ -231,6 +234,7 @@ class Task(AttributedDocument):
|
||||
range_fields=("started", "active_duration", "last_metrics.*", "last_iteration"),
|
||||
datetime_fields=("status_changed", "last_update"),
|
||||
pattern_fields=("name", "comment", "report"),
|
||||
fields=("execution.queue", "runtime.*", "models.input.model"),
|
||||
)
|
||||
|
||||
id = StringField(primary_key=True)
|
||||
|
@ -620,6 +620,19 @@ get_all_ex {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.24": ${get_all_ex."2.23"} {
|
||||
request.properties.children_condition {
|
||||
description: The filter that any of the child projects should match in order that the parent will be included
|
||||
type: object
|
||||
properties {
|
||||
system_tags {
|
||||
description: The list of system tags to match from
|
||||
type: string
|
||||
}
|
||||
}
|
||||
additionalProperties: true
|
||||
}
|
||||
}
|
||||
}
|
||||
update {
|
||||
"2.1" {
|
||||
|
@ -76,7 +76,7 @@ def get_entities_count(call: APICall, company, request: EntitiesCountRequest):
|
||||
requested_ids = data.get("id")
|
||||
if isinstance(requested_ids, str):
|
||||
requested_ids = [requested_ids]
|
||||
ids, _ = project_bll.get_projects_with_active_user(
|
||||
ids, _ = project_bll.get_projects_with_selected_children(
|
||||
company=company,
|
||||
users=request.active_users,
|
||||
project_ids=requested_ids,
|
||||
|
@ -114,13 +114,16 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
|
||||
_adjust_search_parameters(
|
||||
data, shallow_search=request.shallow_search,
|
||||
)
|
||||
user_active_project_ids = None
|
||||
if request.active_users:
|
||||
ids, user_active_project_ids = project_bll.get_projects_with_active_user(
|
||||
selected_project_ids = None
|
||||
if request.active_users or request.children_condition:
|
||||
ids, selected_project_ids = project_bll.get_projects_with_selected_children(
|
||||
company=company_id,
|
||||
users=request.active_users,
|
||||
project_ids=requested_ids,
|
||||
allow_public=allow_public,
|
||||
children_condition=request.children_condition.to_struct()
|
||||
if request.children_condition
|
||||
else None,
|
||||
)
|
||||
if not ids:
|
||||
return {"projects": []}
|
||||
@ -158,7 +161,7 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
|
||||
search_hidden=request.search_hidden,
|
||||
filter_=request.include_stats_filter,
|
||||
users=request.active_users,
|
||||
user_active_project_ids=user_active_project_ids,
|
||||
selected_project_ids=selected_project_ids,
|
||||
)
|
||||
|
||||
for project in projects:
|
||||
|
@ -33,6 +33,31 @@ class TestSubProjects(TestService):
|
||||
).projects[0]
|
||||
self.assertEqual(data.dataset_stats, {"file_count": 2, "total_size": 1000})
|
||||
|
||||
def test_query_children(self):
|
||||
test_root_name = "TestQueryChildren"
|
||||
test_root = self._temp_project(name=test_root_name)
|
||||
child_with_tag = self._temp_project(
|
||||
name=f"{test_root_name}/Project1/WithTag", system_tags=["test"]
|
||||
)
|
||||
child_without_tag = self._temp_project(name=f"{test_root_name}/Project2/WithoutTag")
|
||||
|
||||
projects = self.api.projects.get_all_ex(parent=[test_root], shallow_search=True).projects
|
||||
self.assertEqual({p.basename for p in projects}, {"Project1", "Project2"})
|
||||
|
||||
projects = self.api.projects.get_all_ex(
|
||||
parent=[test_root], children_condition={"system_tags": ["test"]}, shallow_search=True
|
||||
).projects
|
||||
self.assertEqual({p.basename for p in projects}, {"Project1"})
|
||||
projects = self.api.projects.get_all_ex(
|
||||
parent=[projects[0].id], children_condition={"system_tags": ["test"]}, shallow_search=True
|
||||
).projects
|
||||
self.assertEqual(projects[0].id, child_with_tag)
|
||||
|
||||
projects = self.api.projects.get_all_ex(
|
||||
parent=[test_root], children_condition={"system_tags": ["not existent"]}, shallow_search=True
|
||||
).projects
|
||||
self.assertEqual(len(projects), 0)
|
||||
|
||||
def test_project_aggregations(self):
|
||||
"""This test requires user with user_auth_only... credentials in db"""
|
||||
user2_client = APIClient(
|
||||
|
Loading…
Reference in New Issue
Block a user