# clearml-server/apiserver/bll/organization/__init__.py
# 2023-11-17 09:37:25 +02:00
#
# 112 lines
# 3.7 KiB
# Python

from collections import defaultdict
from enum import Enum
from typing import Sequence, Dict, Type
from apiserver.apierrors import errors
from apiserver.bll.util import update_project_time
from apiserver.config_repo import config
from apiserver.database.model.model import AttributedDocument
from apiserver.database.model.model import Model
from apiserver.database.model.task.task import Task
from apiserver.redis_manager import redman
from .tags_cache import _TagsCache
# Module-level logger, created through the apiserver config for this file.
log = config.logger(__file__)
class Tags(Enum):
    """Kinds of entities whose tags are tracked; values are the lowercase entity names."""

    Task = "task"
    Model = "model"
class OrgBLL:
    """
    Organization-level business logic for task and model tags.

    Holds one ``_TagsCache`` per supported entity type (``Task`` and
    ``Model``), both built on the same Redis connection, and exposes helpers
    to edit the tags of entities and to read, update or reset the cached tags.
    """

    def __init__(self, redis=None):
        """
        :param redis: Redis connection to use. When omitted (or falsy) the
            "apiserver" connection is requested from ``redman``.
        """
        self.redis = redis or redman.connection("apiserver")
        self._task_tags = _TagsCache(Task, self.redis)
        self._model_tags = _TagsCache(Model, self.redis)

    def edit_entity_tags(
        self,
        company_id,
        entity_cls: Type[AttributedDocument],
        entity_ids: Sequence[str],
        add_tags: Sequence[str],
        remove_tags: Sequence[str],
    ) -> int:
        """
        Add tags to and/or remove tags from the given tasks or models.

        :param company_id: Company used to filter the edited entities.
        :param entity_cls: Document class to edit; must be ``Task`` or ``Model``.
        :param entity_ids: Ids of the entities to edit; must not be empty.
        :param add_tags: Tags to add (``add_to_set`` update on ``tags``).
        :param remove_tags: Tags to remove (``pull_all`` update on ``tags``).
        :return: Sum of the values returned by the add and the remove updates,
            so an entity affected by both is counted twice. 0 when neither
            update reported a change.
        :raises errors.bad_request.ValidationError: if ``entity_cls`` is not
            ``Task`` or ``Model``, if ``entity_ids`` is empty, or if neither
            ``add_tags`` nor ``remove_tags`` is provided.
        """
        if entity_cls not in (Task, Model):
            raise errors.bad_request.ValidationError(
                "Tags editing can be called on tasks or models only"
            )

        if not entity_ids:
            raise errors.bad_request.ValidationError(
                "No entity ids provided for editing tags"
            )

        if not (add_tags or remove_tags):
            raise errors.bad_request.ValidationError(
                "Either add tags or remove tags should be provided"
            )

        # Additions and removals are issued as two separate updates over the
        # same company/id filter; their results are accumulated.
        updated = 0
        if add_tags:
            updated += entity_cls.objects(company=company_id, id__in=entity_ids).update(
                add_to_set__tags=add_tags
            )
        if remove_tags:
            updated += entity_cls.objects(company=company_id, id__in=entity_ids).update(
                pull_all__tags=remove_tags
            )

        # Nothing changed: skip the project time and tags cache updates.
        if not updated:
            return 0

        # Distinct projects of the edited entities.
        projects = entity_cls.objects(company=company_id, id__in=entity_ids).distinct(
            "project"
        )
        update_project_time(project_ids=projects)
        # NOTE(review): when both add_tags and remove_tags are given, only
        # add_tags is forwarded to the tags cache - confirm that
        # _TagsCache.update_tags does not depend on the actual tag values.
        self.update_tags(
            company_id,
            entity=Tags.Task if entity_cls is Task else Tags.Model,
            projects=projects,
            tags=add_tags or remove_tags,
        )
        return updated

    def get_tags(
        self,
        company_id: str,
        entity: Tags,
        include_system: bool = False,
        filter_: Dict[str, Sequence[str]] = None,
        projects: Sequence[str] = None,
    ) -> dict:
        """
        Get the cached tags for the given entity type.

        :param company_id: Company to get the tags for.
        :param entity: Entity type selecting which tags cache is queried.
        :param include_system: Passed through to the tags cache.
        :param filter_: Passed through to the tags cache.
        :param projects: Optional project ids. When empty or omitted a single
            cache query without a project is made.
        :return: Without ``projects``, the cache result as is. Otherwise a
            ``defaultdict(set)`` mapping each field returned by the cache to
            the union of its tags over all the requested projects.
        """
        tags_cache = self._get_tags_cache_for_entity(entity)
        if not projects:
            return tags_cache.get_tags(
                company_id, include_system=include_system, filter_=filter_
            )

        # Query the cache per project and merge the per-field tag sets.
        ret = defaultdict(set)
        for project in projects:
            project_tags = tags_cache.get_tags(
                company_id,
                include_system=include_system,
                filter_=filter_,
                project=project,
            )
            for field, tags in project_tags.items():
                ret[field] |= tags

        return ret

    def update_tags(
        self,
        company_id: str,
        entity: Tags,
        projects: Sequence[str],
        tags=None,
        system_tags=None,
    ):
        """
        Forward a tags update to the tags cache of the given entity type.

        :param company_id: Company the update applies to.
        :param entity: Entity type selecting which tags cache is updated.
        :param projects: Projects the update applies to.
        :param tags: Passed through to the cache's ``update_tags``.
        :param system_tags: Passed through to the cache's ``update_tags``.
        """
        tags_cache = self._get_tags_cache_for_entity(entity)
        tags_cache.update_tags(company_id, projects, tags, system_tags)

    def reset_tags(self, company_id: str, entity: Tags, projects: Sequence[str]):
        """
        Forward a tags reset to the tags cache of the given entity type.

        :param company_id: Company the reset applies to.
        :param entity: Entity type selecting which tags cache is reset.
        :param projects: Projects the reset applies to.
        """
        tags_cache = self._get_tags_cache_for_entity(entity)
        tags_cache.reset_tags(company_id, projects=projects)

    def _get_tags_cache_for_entity(self, entity: Tags) -> _TagsCache:
        """
        Return the tags cache for ``entity``.

        Any value other than ``Tags.Task`` selects the model tags cache.
        """
        return self._task_tags if entity == Tags.Task else self._model_tags