diff --git a/apiserver/bll/organization/__init__.py b/apiserver/bll/organization/__init__.py index 329752c..9fd64e2 100644 --- a/apiserver/bll/organization/__init__.py +++ b/apiserver/bll/organization/__init__.py @@ -1,147 +1,14 @@ from collections import defaultdict from enum import Enum -from itertools import chain -from typing import Sequence, Union, Type, Dict - -from mongoengine import Q -from redis import Redis +from typing import Sequence, Dict from apiserver.config_repo import config -from apiserver.database.model.base import GetMixin from apiserver.database.model.model import Model from apiserver.database.model.task.task import Task from apiserver.redis_manager import redman -from apiserver.utilities import json +from .tags_cache import _TagsCache log = config.logger(__file__) -_settings_prefix = "services.organization" - - -class _TagsCache: - _tags_field = "tags" - _system_tags_field = "system_tags" - - def __init__(self, db_cls: Union[Type[Model], Type[Task]], redis: Redis): - self.db_cls = db_cls - self.redis = redis - - @property - def _tags_cache_expiration_seconds(self): - return config.get(f"{_settings_prefix}.tags_cache.expiration_seconds", 3600) - - def _get_tags_from_db( - self, - company: str, - field: str, - project: str = None, - filter_: Dict[str, Sequence[str]] = None, - ) -> set: - query = Q(company=company) - if filter_: - for name, vals in filter_.items(): - if vals: - query &= GetMixin.get_list_field_query(name, vals) - if project: - query &= Q(project=project) - - return self.db_cls.objects(query).distinct(field) - - def _get_tags_cache_key( - self, - company: str, - field: str, - project: str = None, - filter_: Dict[str, Sequence[str]] = None, - ): - """ - Project None means 'from all company projects' - The key is built in the way that scanning company keys for 'all company projects' - will not return the keys related to the particular company projects and vice versa. - So that we can have a fine grain control on what redis keys to invalidate - """ - filter_str = None - if filter_: - filter_str = "_".join( - ["filter", *chain.from_iterable([f, *v] for f, v in filter_.items())] - ) - key_parts = [company, project, self.db_cls.__name__, field, filter_str] - return "_".join(filter(None, key_parts)) - - def get_tags( - self, - company: str, - include_system: bool = False, - filter_: Dict[str, Sequence[str]] = None, - project: str = None, - ) -> dict: - """ - Get tags and optionally system tags for the company - Return the dictionary of tags per tags field name - The function retrieves both cached values from Redis in one call - and re calculates any of them if missing in Redis - """ - fields = [self._tags_field] - if include_system: - fields.append(self._system_tags_field) - redis_keys = [ - self._get_tags_cache_key(company, field=f, project=project, filter_=filter_) - for f in fields - ] - cached = self.redis.mget(redis_keys) - ret = {} - for field, tag_data, key in zip(fields, cached, redis_keys): - if tag_data is not None: - tags = json.loads(tag_data) - else: - tags = list(self._get_tags_from_db(company, field, project, filter_)) - self.redis.setex( - key, - time=self._tags_cache_expiration_seconds, - value=json.dumps(tags), - ) - ret[field] = set(tags) - - return ret - - def update_tags(self, company: str, project: str, tags=None, system_tags=None): - """ - Updates tags. If reset is set then both tags and system_tags - are recalculated. Otherwise only those that are not 'None' - """ - fields = [ - field - for field, update in ( - (self._tags_field, tags), - (self._system_tags_field, system_tags), - ) - if update is not None - ] - if not fields: - return - - self._delete_redis_keys(company, projects=[project], fields=fields) - - def reset_tags(self, company: str, projects: Sequence[str]): - self._delete_redis_keys( - company, - projects=projects, - fields=(self._tags_field, self._system_tags_field), - ) - - def _delete_redis_keys( - self, company: str, projects: [Sequence[str]], fields: Sequence[str] - ): - redis_keys = list( - chain.from_iterable( - self.redis.keys( - self._get_tags_cache_key(company, field=f, project=p) + "*" - ) - for f in fields - for p in set(projects) | {None} - ) - ) - if redis_keys: - self.redis.delete(*redis_keys) class Tags(Enum): @@ -157,7 +24,7 @@ class OrgBLL: def get_tags( self, - company: str, + company_id: str, entity: Tags, include_system: bool = False, filter_: Dict[str, Sequence[str]] = None, @@ -166,13 +33,16 @@ class OrgBLL: tags_cache = self._get_tags_cache_for_entity(entity) if not projects: return tags_cache.get_tags( - company, include_system=include_system, filter_=filter_ + company_id, include_system=include_system, filter_=filter_ ) ret = defaultdict(set) for project in projects: project_tags = tags_cache.get_tags( - company, include_system=include_system, filter_=filter_, project=project + company_id, + include_system=include_system, + filter_=filter_, + project=project, ) for field, tags in project_tags.items(): ret[field] |= tags @@ -180,14 +50,14 @@ class OrgBLL: return ret def update_tags( - self, company: str, entity: Tags, project: str, tags=None, system_tags=None, + self, company_id: str, entity: Tags, project: str, tags=None, system_tags=None, ): tags_cache = self._get_tags_cache_for_entity(entity) - tags_cache.update_tags(company, project, tags, system_tags) + tags_cache.update_tags(company_id, project, tags, system_tags) - def reset_tags(self, company: str, entity: Tags, projects: Sequence[str]): + def reset_tags(self, company_id: str, entity: Tags, projects: Sequence[str]): tags_cache = self._get_tags_cache_for_entity(entity) - tags_cache.reset_tags(company, projects=projects) + tags_cache.reset_tags(company_id, projects=projects) def _get_tags_cache_for_entity(self, entity: Tags) -> _TagsCache: return self._task_tags if entity == Tags.Task else self._model_tags diff --git a/apiserver/bll/organization/tags_cache.py b/apiserver/bll/organization/tags_cache.py new file mode 100644 index 0000000..33e7e2c --- /dev/null +++ b/apiserver/bll/organization/tags_cache.py @@ -0,0 +1,144 @@ +from itertools import chain +from typing import Sequence, Union, Type, Dict + +from mongoengine import Q +from redis import Redis + +from apiserver.config_repo import config +from apiserver.database.model.base import GetMixin +from apiserver.database.model.model import Model +from apiserver.database.model.task.task import Task + +log = config.logger(__file__) +_settings_prefix = "services.organization" + + +class _TagsCache: + _tags_field = "tags" + _system_tags_field = "system_tags" + _dummy_tag = "__dummy__" + # prepend our list in redis with this tag since empty lists are auto deleted + + def __init__(self, db_cls: Union[Type[Model], Type[Task]], redis: Redis): + self.db_cls = db_cls + self.redis = redis + + @property + def _tags_cache_expiration_seconds(self): + return config.get(f"{_settings_prefix}.tags_cache.expiration_seconds", 3600) + + def _get_tags_from_db( + self, + company_id: str, + field: str, + project: str = None, + filter_: Dict[str, Sequence[str]] = None, + ) -> set: + query = Q(company=company_id) + if filter_: + for name, vals in filter_.items(): + if vals: + query &= GetMixin.get_list_field_query(name, vals) + if project: + query &= Q(project=project) + + return self.db_cls.objects(query).distinct(field) + + def _get_tags_cache_key( + self, + company_id: str, + field: str, + project: str = None, + filter_: Dict[str, Sequence[str]] = None, + ): + """ + Project None means 'from all company projects' + The key is built in the way that scanning company keys for 'all company projects' + will not return the keys related to the particular company projects and vice versa. + So that we can have a fine grain control on what redis keys to invalidate + """ + filter_str = None + if filter_: + filter_str = "_".join( + ["filter", *chain.from_iterable([f, *v] for f, v in filter_.items())] + ) + key_parts = [field, company_id, project, self.db_cls.__name__, filter_str] + return "_".join(filter(None, key_parts)) + + def get_tags( + self, + company_id: str, + include_system: bool = False, + filter_: Dict[str, Sequence[str]] = None, + project: str = None, + ) -> dict: + """ + Get tags and optionally system tags for the company + Return the dictionary of tags per tags field name + The function retrieves both cached values from Redis in one call + and re calculates any of them if missing in Redis + """ + fields = [self._tags_field] + if include_system: + fields.append(self._system_tags_field) + + ret = {} + for field in fields: + redis_key = self._get_tags_cache_key( + company_id, field=field, project=project, filter_=filter_ + ) + cached_tags = self.redis.lrange(redis_key, 0, -1) + if cached_tags: + tags = [c.decode() for c in cached_tags[1:]] + else: + tags = list( + self._get_tags_from_db( + company_id, field=field, project=project, filter_=filter_ + ) + ) + self.redis.rpush(redis_key, self._dummy_tag, *tags) + self.redis.expire(redis_key, self._tags_cache_expiration_seconds) + + ret[field] = set(tags) + + return ret + + def update_tags(self, company_id: str, project: str, tags=None, system_tags=None): + """ + Updates tags. If reset is set then both tags and system_tags + are recalculated. Otherwise only those that are not 'None' + """ + fields = [ + field + for field, update in ( + (self._tags_field, tags), + (self._system_tags_field, system_tags), + ) + if update is not None + ] + if not fields: + return + + self._delete_redis_keys(company_id, projects=[project], fields=fields) + + def reset_tags(self, company_id: str, projects: Sequence[str]): + self._delete_redis_keys( + company_id, + projects=projects, + fields=(self._tags_field, self._system_tags_field), + ) + + def _delete_redis_keys( + self, company_id: str, projects: [Sequence[str]], fields: Sequence[str] + ): + redis_keys = list( + chain.from_iterable( + self.redis.keys( + self._get_tags_cache_key(company_id, field=f, project=p) + "*" + ) + for f in fields + for p in set(projects) | {None} + ) + ) + if redis_keys: + self.redis.delete(*redis_keys) diff --git a/apiserver/tests/automated/test_project_tags.py b/apiserver/tests/automated/test_project_tags.py index d53c893..614ea7c 100644 --- a/apiserver/tests/automated/test_project_tags.py +++ b/apiserver/tests/automated/test_project_tags.py @@ -1,3 +1,5 @@ +from datetime import datetime + from apiserver.tests.automated import TestService @@ -46,7 +48,7 @@ class TestProjectTags(TestService): self.assertEqual(set(tags_1[1:]), set(data.tags)) def test_organization_tags(self): - tag1 = "Orgtest tag1" + tag1 = datetime.utcnow().isoformat() tag2 = "Orgtest tag2" system_tag = "Orgtest system tag"