From 1ea6408d419a5ae6f1082da3ecce5397355f9c14 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Sun, 21 Jun 2020 23:54:05 +0300 Subject: [PATCH] Support tags-per-project in tags related services --- server/apimodels/organization.py | 1 + server/apimodels/projects.py | 7 + server/bll/event/event_bll.py | 36 +++- server/bll/organization/__init__.py | 176 ++++++++++++++++---- server/bll/task/task_bll.py | 24 ++- server/schema/services/organization.conf | 77 +++++---- server/schema/services/projects.conf | 62 ++++++- server/services/models.py | 45 +++-- server/services/organization.py | 19 ++- server/services/projects.py | 44 ++++- server/services/tasks.py | 66 ++++++-- server/services/utils.py | 16 ++ server/tests/automated/test_organization.py | 36 ---- server/tests/automated/test_project_tags.py | 82 +++++++++ server/tests/automated/test_task_events.py | 27 +++ 15 files changed, 571 insertions(+), 147 deletions(-) delete mode 100644 server/tests/automated/test_organization.py create mode 100644 server/tests/automated/test_project_tags.py diff --git a/server/apimodels/organization.py b/server/apimodels/organization.py index 9c4c12d..a7b3f6f 100644 --- a/server/apimodels/organization.py +++ b/server/apimodels/organization.py @@ -2,6 +2,7 @@ from jsonmodels import fields, models class Filter(models.Base): + tags = fields.ListField([str]) system_tags = fields.ListField([str]) diff --git a/server/apimodels/projects.py b/server/apimodels/projects.py index f4e4f7e..d1780e5 100644 --- a/server/apimodels/projects.py +++ b/server/apimodels/projects.py @@ -1,5 +1,8 @@ from jsonmodels import models, fields +from apimodels import ListField +from apimodels.organization import TagsRequest + class ProjectReq(models.Base): project = fields.StringField() @@ -14,3 +17,7 @@ class GetHyperParamResp(models.Base): parameters = fields.ListField(str) remaining = fields.IntField() total = fields.IntField() + + +class ProjectTagsRequest(TagsRequest): + projects = ListField(str) diff --git a/server/bll/event/event_bll.py b/server/bll/event/event_bll.py index e846d9b..b50fbf8 100644 --- a/server/bll/event/event_bll.py +++ b/server/bll/event/event_bll.py @@ -230,9 +230,31 @@ class EventBLL(object): metric_hash = dbutils.hash_field_name(metric) variant_hash = dbutils.hash_field_name(variant) - timestamp = last_events[metric_hash][variant_hash].get("timestamp", None) - if timestamp is None or timestamp < event["timestamp"]: - last_events[metric_hash][variant_hash] = event + last_event = last_events[metric_hash][variant_hash] + event_iter = event.get("iter", 0) + event_timestamp = event["timestamp"] + if (event_iter, event_timestamp) >= ( + last_event.get("iter", event_iter), + last_event.get("timestamp", event_timestamp), + ): + event_data = { + k: event[k] + for k in ("value", "metric", "variant", "iter", "timestamp") + if k in event + } + value = event_data.get("value") + if value is not None: + event_data["min_value"] = min(value, last_event.get("min_value", value)) + event_data["max_value"] = max(value, last_event.get("max_value", value)) + else: + event_data.update( + **{ + k: last_event[k] + for k in ("value", "min_value", "max_value") + if k in last_event + } + ) + last_events[metric_hash][variant_hash] = event_data def _update_last_metric_events_for_task(self, last_events, event): """ @@ -275,7 +297,13 @@ class EventBLL(object): flatten_nested_items( last_scalar_events, nesting=2, - include_leaves=["value", "metric", "variant"], + include_leaves=[ + "value", + "min_value", + "max_value", + "metric", + "variant", + 
], ) ) diff --git a/server/bll/organization/__init__.py b/server/bll/organization/__init__.py index e8ab1fe..b5a85bf 100644 --- a/server/bll/organization/__init__.py +++ b/server/bll/organization/__init__.py @@ -1,6 +1,10 @@ -from typing import Sequence +from collections import defaultdict +from enum import Enum +from itertools import chain +from typing import Sequence, Union, Type, Dict from mongoengine import Q +from redis import Redis from config import config from database.model.base import GetMixin @@ -10,40 +14,65 @@ from redis_manager import redman from utilities import json log = config.logger(__file__) +_settings_prefix = "services.organization" -class OrgBLL: +class _TagsCache: _tags_field = "tags" _system_tags_field = "system_tags" - _settings_prefix = "services.organization" - def __init__(self, redis=None): - self.redis = redis or redman.connection("apiserver") + def __init__(self, db_cls: Union[Type[Model], Type[Task]], redis: Redis): + self.db_cls = db_cls + self.redis = redis @property def _tags_cache_expiration_seconds(self): - return config.get( - f"{self._settings_prefix}.tags_cache.expiration_seconds", 3600 - ) + return config.get(f"{_settings_prefix}.tags_cache.expiration_seconds", 3600) - @staticmethod - def _get_tags_cache_key(company, field: str, filter_: Sequence[str] = None): - filter_str = "_".join(filter_) if filter_ else "" - return f"{field}_{company}_{filter_str}" - - @staticmethod - def _get_tags_from_db(company, field, filter_: Sequence[str] = None) -> set: + def _get_tags_from_db( + self, + company: str, + field: str, + project: str = None, + filter_: Dict[str, Sequence[str]] = None, + ) -> set: query = Q(company=company) if filter_: - query &= GetMixin.get_list_field_query("system_tags", filter_) + for name, vals in filter_.items(): + if vals: + query &= GetMixin.get_list_field_query(name, vals) + if project: + query &= Q(project=project) - tags = set() - for cls_ in (Task, Model): - tags |= set(cls_.objects(query).distinct(field)) - return tags + return self.db_cls.objects(query).distinct(field) + + def _get_tags_cache_key( + self, + company: str, + field: str, + project: str = None, + filter_: Dict[str, Sequence[str]] = None, + ): + """ + Project None means 'from all company projects' + The key is built in the way that scanning company keys for 'all company projects' + will not return the keys related to the particular company projects and vice versa. 
+ So that we can have a fine grain control on what redis keys to invalidate + """ + filter_str = None + if filter_: + filter_str = "_".join( + ["filter", *chain.from_iterable([f, *v] for f, v in filter_.items())] + ) + key_parts = [company, project, self.db_cls.__name__, field, filter_str] + return "_".join(filter(None, key_parts)) def get_tags( - self, company, include_system: bool = False, filter_: Sequence[str] = None + self, + company: str, + include_system: bool = False, + filter_: Dict[str, Sequence[str]] = None, + project: str = None, ) -> dict: """ Get tags and optionally system tags for the company @@ -51,35 +80,114 @@ class OrgBLL: The function retrieves both cached values from Redis in one call and re calculates any of them if missing in Redis """ - fields = [ - self._tags_field, - *([self._system_tags_field] if include_system else []), + fields = [self._tags_field] + if include_system: + fields.append(self._system_tags_field) + redis_keys = [ + self._get_tags_cache_key(company, field=f, project=project, filter_=filter_) + for f in fields ] - redis_keys = [self._get_tags_cache_key(company, f, filter_) for f in fields] cached = self.redis.mget(redis_keys) ret = {} for field, tag_data, key in zip(fields, cached, redis_keys): if tag_data is not None: tags = json.loads(tag_data) else: - tags = list(self._get_tags_from_db(company, field, filter_)) + tags = list(self._get_tags_from_db(company, field, project, filter_)) self.redis.setex( key, time=self._tags_cache_expiration_seconds, value=json.dumps(tags), ) - ret[field] = tags + ret[field] = set(tags) return ret - def update_org_tags(self, company, tags=None, system_tags=None, reset=False): + def update_tags(self, company: str, project: str, tags=None, system_tags=None): """ - Updates system tags. If reset is set then both tags and system_tags + Updates tags. If reset is set then both tags and system_tags are recalculated. 
Otherwise only those that are not 'None' """ - if reset or tags is not None: - self.redis.delete(self._get_tags_cache_key(company, self._tags_field)) - if reset or system_tags is not None: - self.redis.delete( - self._get_tags_cache_key(company, self._system_tags_field) + fields = [ + field + for field, update in ( + (self._tags_field, tags), + (self._system_tags_field, system_tags), ) + if update is not None + ] + if not fields: + return + + self._delete_redis_keys(company, projects=[project], fields=fields) + + def reset_tags(self, company: str, projects: Sequence[str]): + self._delete_redis_keys( + company, + projects=projects, + fields=(self._tags_field, self._system_tags_field), + ) + + def _delete_redis_keys( + self, company: str, projects: [Sequence[str]], fields: Sequence[str] + ): + redis_keys = list( + chain.from_iterable( + self.redis.keys( + self._get_tags_cache_key(company, field=f, project=p) + "*" + ) + for f in fields + for p in set(projects) | {None} + ) + ) + if redis_keys: + self.redis.delete(*redis_keys) + + +class Tags(Enum): + Task = "task" + Model = "model" + + +class OrgBLL: + def __init__(self, redis=None): + self.redis = redis or redman.connection("apiserver") + self._task_tags = _TagsCache(Task, self.redis) + self._model_tags = _TagsCache(Model, self.redis) + + def get_tags( + self, + company: str, + entity: Tags, + include_system: bool = False, + filter_: Dict[str, Sequence[str]] = None, + projects: Sequence[str] = None, + ) -> dict: + tags_cache = self._get_tags_cache_for_entity(entity) + if not projects: + return tags_cache.get_tags( + company, include_system=include_system, filter_=filter_ + ) + + ret = defaultdict(set) + for project in projects: + project_tags = tags_cache.get_tags( + company, include_system=include_system, filter_=filter_, project=project + ) + for field, tags in project_tags.items(): + ret[field] |= tags + + return ret + + def update_tags( + self, company: str, entity: Tags, project: str, tags=None, system_tags=None, + ): + tags_cache = self._get_tags_cache_for_entity(entity) + tags_cache.update_tags(company, project, tags, system_tags) + + def reset_tags(self, company: str, entity: Tags, projects: Sequence[str]): + tags_cache = self._get_tags_cache_for_entity(entity) + tags_cache.reset_tags(company, projects=projects) + + def _get_tags_cache_for_entity(self, entity: Tags) -> _TagsCache: + return self._task_tags if entity == Tags.Task else self._model_tags diff --git a/server/bll/task/task_bll.py b/server/bll/task/task_bll.py index d682aa3..6f69ccf 100644 --- a/server/bll/task/task_bll.py +++ b/server/bll/task/task_bll.py @@ -14,7 +14,7 @@ import database.utils as dbutils import es_factory from apierrors import errors from apimodels.tasks import Artifact as ApiArtifact -from bll.organization import OrgBLL +from bll.organization import OrgBLL, Tags from config import config from database.errors import translate_errors_context from database.model.model import Model @@ -229,7 +229,21 @@ class TaskBLL(object): validate_project=validate_references or project, ) new_task.save() - org_bll.update_org_tags(company_id, tags=tags, system_tags=system_tags) + + if task.project == new_task.project: + updated_tags = tags + updated_system_tags = system_tags + else: + updated_tags = new_task.tags + updated_system_tags = new_task.system_tags + org_bll.update_tags( + company_id, + Tags.Task, + project=new_task.project, + tags=updated_tags, + system_tags=updated_system_tags, + ) + return new_task @classmethod @@ -346,10 +360,12 @@ class TaskBLL(object): 
return "__".join((op, "last_metrics") + path) for path, value in last_scalar_values: - extra_updates[op_path("set", *path)] = value - if path[-1] == "value": + if path[-1] == "min_value": extra_updates[op_path("min", *path[:-1], "min_value")] = value + elif path[-1] == "max_value": extra_updates[op_path("max", *path[:-1], "max_value")] = value + else: + extra_updates[op_path("set", *path)] = value if last_events is not None: diff --git a/server/schema/services/organization.conf b/server/schema/services/organization.conf index 978ab2e..db236a0 100644 --- a/server/schema/services/organization.conf +++ b/server/schema/services/organization.conf @@ -1,43 +1,48 @@ _description: "This service provides organization level operations" get_tags { - "2.8" { - description: "Get all the user and system tags used for the company tasks and models" - request { - type: object - properties { - include_system { - description: "If set to 'true' then the list of the system tags is also returned. The default value is 'false'" - type: boolean - default: false - } - filter { - description: "Filter on entities to collect tags from" - type: object - properties { - system_tags { - description: "The list of system tag values to filter by. Use 'null' value to specify empty tags. Use '__Snot' value to specify that the following value should be excluded" - type: array - items {type: string} + "2.8" { + description: "Get all the user and system tags used for the company tasks and models" + request { + type: object + properties { + include_system { + description: "If set to 'true' then the list of the system tags is also returned. The default value is 'false'" + type: boolean + default: false + } + filter { + description: "Filter on entities to collect tags from" + type: object + properties { + tags { + description: "The list of tag values to filter by. Use 'null' value to specify empty tags. Use '__Snot' value to specify that the following value should be excluded" + type: array + items {type: string} + } + system_tags { + description: "The list of system tag values to filter by. Use 'null' value to specify empty system tags. Use '__Snot' value to specify that the following value should be excluded" + type: array + items {type: string} + } + } + } + } + } + response { + type: object + properties { + tags { + description: "The list of unique tag values" + type: array + items {type: string} + } + system_tags { + description: "The list of unique system tag values. Returned only if 'include_system' is set to 'true' in the request" + type: array + items {type: string} + } } - } } - } } - response { - type: object - properties { - tags { - description: "The list of unique tag values" - type: array - items {type: string} - } - system_tags { - description: "The list of unique system tag values. Returned only if 'include_system' is set to 'true' in the request" - type: array - items {type: string} - } - } - } - } } \ No newline at end of file diff --git a/server/schema/services/projects.conf b/server/schema/services/projects.conf index d5f18c4..d4fc43e 100644 --- a/server/schema/services/projects.conf +++ b/server/schema/services/projects.conf @@ -196,6 +196,52 @@ _definitions { } } } + tags_request { + type: object + properties { + include_system { + description: "If set to 'true' then the list of the system tags is also returned. The default value is 'false'" + type: boolean + default: false + } + projects { + description: "The list of projects under which the tags are searched. 
If not passed or empty then all the projects are searched" + type: array + items { type: string } + } + filter { + description: "Filter on entities to collect tags from" + type: object + properties { + tags { + description: "The list of tag values to filter by. Use 'null' value to specify empty tags. Use '__Snot' value to specify that the following value should be excluded" + type: array + items {type: string} + } + system_tags { + description: "The list of system tag values to filter by. Use 'null' value to specify empty system tags. Use '__Snot' value to specify that the following value should be excluded" + type: array + items {type: string} + } + } + } + } + } + tags_response { + type: object + properties { + tags { + description: "The list of unique tag values" + type: array + items {type: string} + } + system_tags { + description: "The list of unique system tag values. Returned only if 'include_system' is set to 'true' in the request" + type: array + items {type: string} + } + } + } } create { @@ -508,7 +554,7 @@ get_hyper_parameters { parameters { description: "A list of hyper parameter names" type: array - items { type: string } + items {type: string} } remaining { description: "Remaining results" @@ -522,3 +568,17 @@ get_hyper_parameters { } } } +get_task_tags { + "2.8" { + description: "Get user and system tags used for the tasks under the specified projects" + request = ${_definitions.tags_request} + response = ${_definitions.tags_response} + } +} +get_model_tags { + "2.8" { + description: "Get user and system tags used for the models under the specified projects" + request = ${_definitions.tags_request} + response = ${_definitions.tags_response} + } +} \ No newline at end of file diff --git a/server/services/models.py b/server/services/models.py index 6d4348b..19bc3a1 100644 --- a/server/services/models.py +++ b/server/services/models.py @@ -1,4 +1,5 @@ from datetime import datetime +from typing import Sequence from mongoengine import Q, EmbeddedDocument @@ -12,7 +13,7 @@ from apimodels.models import ( PublishModelResponse, ModelTaskPublishResponse, ) -from bll.organization import OrgBLL +from bll.organization import OrgBLL, Tags from bll.task import TaskBLL from config import config from database.errors import translate_errors_context @@ -128,9 +129,19 @@ def parse_model_fields(call, valid_fields): return fields -def _update_org_tags(company, fields: dict): - org_bll.update_org_tags( - company, tags=fields.get("tags"), system_tags=fields.get("system_tags") +def _update_cached_tags(company: str, project: str, fields: dict): + org_bll.update_tags( + company, + Tags.Model, + project=project, + tags=fields.get("tags"), + system_tags=fields.get("system_tags"), + ) + + +def _reset_cached_tags(company: str, projects: Sequence[str]): + org_bll.reset_tags( + company, Tags.Model, projects=projects, ) @@ -203,7 +214,7 @@ def update_for_task(call: APICall, company_id, _): **fields, ) model.save() - _update_org_tags(company_id, fields) + _update_cached_tags(company_id, project=model.project, fields=fields) TaskBLL.update_statistics( task_id=task_id, @@ -248,7 +259,7 @@ def create(call: APICall, company_id, req_model: CreateModelRequest): **fields, ) model.save() - _update_org_tags(company_id, fields) + _update_cached_tags(company_id, project=model.project, fields=fields) call.result.data_model = CreateModelResponse(id=model.id, created=True) @@ -327,7 +338,15 @@ def edit(call: APICall, company_id, _): if fields: updated = model.update(upsert=False, **fields) if updated: - 
_update_org_tags(company_id, fields) + new_project = fields.get("project", model.project) + if new_project != model.project: + _reset_cached_tags( + company_id, projects=[new_project, model.project] + ) + else: + _update_cached_tags( + company_id, project=model.project, fields=fields + ) conform_output_tags(call, fields) call.result.data_model = UpdateResponse(updated=updated, fields=fields) else: @@ -355,7 +374,13 @@ def _update_model(call: APICall, company_id, model_id=None): updated_count, updated_fields = Model.safe_update(company_id, model.id, data) if updated_count: - _update_org_tags(company_id, updated_fields) + new_project = updated_fields.get("project", model.project) + if new_project != model.project: + _reset_cached_tags(company_id, projects=[new_project, model.project]) + else: + _update_cached_tags( + company_id, project=model.project, fields=updated_fields + ) conform_output_tags(call, updated_fields) return UpdateResponse(updated=updated_count, fields=updated_fields) @@ -395,7 +420,7 @@ def update(call: APICall, company_id, _): with translate_errors_context(): query = dict(id=model_id, company=company_id) - model = Model.objects(**query).only("id", "task").first() + model = Model.objects(**query).only("id", "task", "project").first() if not model: raise errors.bad_request.InvalidModelId(**query) @@ -428,5 +453,5 @@ def update(call: APICall, company_id, _): del_count = Model.objects(**query).delete() if del_count: - org_bll.update_org_tags(company_id, reset=True) + _reset_cached_tags(company_id, projects=[model.project]) call.result.data = dict(deleted=del_count > 0) diff --git a/server/services/organization.py b/server/services/organization.py index 49c7d56..ebbb1ee 100644 --- a/server/services/organization.py +++ b/server/services/organization.py @@ -1,13 +1,22 @@ +from collections import defaultdict + from apimodels.organization import TagsRequest -from bll.organization import OrgBLL +from bll.organization import OrgBLL, Tags from service_repo import endpoint, APICall +from services.utils import get_tags_filter_dictionary, get_tags_response org_bll = OrgBLL() @endpoint("organization.get_tags", request_data_model=TagsRequest) def get_tags(call: APICall, company, request: TagsRequest): - filter_ = request.filter.system_tags if request.filter else None - call.result.data = org_bll.get_tags( - company, include_system=request.include_system, filter_=filter_ - ) + filter_dict = get_tags_filter_dictionary(request.filter) + ret = defaultdict(set) + for entity in Tags.Model, Tags.Task: + tags = org_bll.get_tags( + company, entity, include_system=request.include_system, filter_=filter_dict, + ) + for field, vals in tags.items(): + ret[field] |= vals + + call.result.data = get_tags_response(ret) diff --git a/server/services/projects.py b/server/services/projects.py index b952290..0becbc8 100644 --- a/server/services/projects.py +++ b/server/services/projects.py @@ -9,7 +9,13 @@ from mongoengine import Q import database from apierrors import errors from apimodels.base import UpdateResponse -from apimodels.projects import GetHyperParamReq, GetHyperParamResp, ProjectReq +from apimodels.projects import ( + GetHyperParamReq, + GetHyperParamResp, + ProjectReq, + ProjectTagsRequest, +) +from bll.organization import OrgBLL, Tags from bll.task import TaskBLL from database.errors import translate_errors_context from database.model import EntityVisibility @@ -18,9 +24,15 @@ from database.model.project import Project from database.model.task.task import Task, TaskStatus from database.utils 
import parse_from_call, get_options, get_company_or_none_constraint from service_repo import APICall, endpoint -from services.utils import conform_tag_fields, conform_output_tags +from services.utils import ( + conform_tag_fields, + conform_output_tags, + get_tags_filter_dictionary, + get_tags_response, +) from timing_context import TimingContext +org_bll = OrgBLL() task_bll = TaskBLL() archived_tasks_cond = {"$in": [EntityVisibility.archived.value, "$system_tags"]} @@ -381,3 +393,31 @@ def get_hyper_parameters(call: APICall, company_id: str, request: GetHyperParamR "remaining": remaining, "parameters": parameters, } + + +@endpoint( + "projects.get_task_tags", min_version="2.8", request_data_model=ProjectTagsRequest +) +def get_tags(call: APICall, company, request: ProjectTagsRequest): + ret = org_bll.get_tags( + company, + Tags.Task, + include_system=request.include_system, + filter_=get_tags_filter_dictionary(request.filter), + projects=request.projects, + ) + call.result.data = get_tags_response(ret) + + +@endpoint( + "projects.get_model_tags", min_version="2.8", request_data_model=ProjectTagsRequest +) +def get_tags(call: APICall, company, request: ProjectTagsRequest): + ret = org_bll.get_tags( + company, + Tags.Model, + include_system=request.include_system, + filter_=get_tags_filter_dictionary(request.filter), + projects=request.projects, + ) + call.result.data = get_tags_response(ret) diff --git a/server/services/tasks.py b/server/services/tasks.py index 9cf0f8b..f4d15dc 100644 --- a/server/services/tasks.py +++ b/server/services/tasks.py @@ -33,7 +33,7 @@ from apimodels.tasks import ( ResetRequest, ) from bll.event import EventBLL -from bll.organization import OrgBLL +from bll.organization import OrgBLL, Tags from bll.queue import QueueBLL from bll.task import ( TaskBLL, @@ -343,9 +343,19 @@ def validate(call: APICall, company_id, req_model: CreateRequest): _validate_and_get_task_from_call(call) -def _update_org_tags(company, fields: dict): - org_bll.update_org_tags( - company, tags=fields.get("tags"), system_tags=fields.get("system_tags") +def _update_cached_tags(company: str, project: str, fields: dict): + org_bll.update_tags( + company, + Tags.Task, + project=project, + tags=fields.get("tags"), + system_tags=fields.get("system_tags"), + ) + + +def _reset_cached_tags(company: str, projects: Sequence[str]): + org_bll.reset_tags( + company, Tags.Task, projects=projects ) @@ -357,7 +367,7 @@ def create(call: APICall, company_id, req_model: CreateRequest): with translate_errors_context(), TimingContext("mongo", "save_task"): task.save() - _update_org_tags(company_id, fields) + _update_cached_tags(company_id, project=task.project, fields=fields) update_project_time(task.project) call.result.data_model = IdResponse(id=task.id) @@ -400,7 +410,9 @@ def update(call: APICall, company_id, req_model: UpdateRequest): task_id = req_model.task with translate_errors_context(): - task = Task.get_for_writing(id=task_id, company=company_id, _only=["id"]) + task = Task.get_for_writing( + id=task_id, company=company_id, _only=["id", "project"] + ) if not task: raise errors.bad_request.InvalidTaskId(id=task_id) @@ -416,7 +428,13 @@ def update(call: APICall, company_id, req_model: UpdateRequest): injected_update=dict(last_update=datetime.utcnow()), ) if updated_count: - _update_org_tags(company_id, updated_fields) + new_project = updated_fields.get("project", task.project) + if new_project != task.project: + _reset_cached_tags(company_id, projects=[new_project, task.project]) + else: + 
_update_cached_tags( + company_id, project=task.project, fields=updated_fields + ) update_project_time(updated_fields.get("project")) unprepare_from_saved(call, updated_fields) return UpdateResponse(updated=updated_count, fields=updated_fields) @@ -470,8 +488,10 @@ def update_batch(call: APICall, company_id, _): now = datetime.utcnow() bulk_ops = [] + updated_projects = set() for id, data in items.items(): - fields, valid_fields = prepare_update_fields(call, tasks[id], data) + task = tasks[id] + fields, valid_fields = prepare_update_fields(call, task, data) partial_update_dict = Task.get_safe_update_dict(fields) if not partial_update_dict: continue @@ -481,12 +501,20 @@ def update_batch(call: APICall, company_id, _): ) bulk_ops.append(update_op) + new_project = partial_update_dict.get("project", task.project) + if new_project != task.project: + updated_projects.update({new_project, task.project}) + elif any(f in partial_update_dict for f in ("tags", "system_tags")): + updated_projects.add(task.project) + updated = 0 if bulk_ops: res = Task._get_collection().bulk_write(bulk_ops) updated = res.modified_count - if updated: - org_bll.update_org_tags(company_id, reset=True) + + if updated and updated_projects: + _reset_cached_tags(company_id, projects=list(updated_projects)) + call.result.data = {"updated": updated} @@ -542,7 +570,15 @@ def edit(call: APICall, company_id, req_model: UpdateRequest): fixed_fields.update(last_update=now) updated = task.update(upsert=False, **fixed_fields) if updated: - _update_org_tags(company_id, fixed_fields) + new_project = fixed_fields.get("project", task.project) + if new_project != task.project: + _reset_cached_tags( + company_id, projects=[new_project, task.project] + ) + else: + _update_cached_tags( + company_id, project=task.project, fields=fixed_fields + ) update_project_time(fields.get("project")) unprepare_from_saved(call, fields) call.result.data_model = UpdateResponse(updated=updated, fields=fields) @@ -710,12 +746,11 @@ def reset(call: APICall, company_id, request: ResetRequest): if request.clear_all: updates.update( - set__execution=Execution(), - unset__script=1, + set__execution=Execution(), unset__script=1, ) else: - updates.update(unset__execution__queue=1) updates.update( + unset__execution__queue=1, __raw__={"$pull": {"execution.artifacts": {"mode": {"$ne": "input"}}}}, ) @@ -909,7 +944,8 @@ def delete(call: APICall, company_id, req_model: DeleteRequest): task.switch_collection(collection_name) task.delete() - org_bll.update_org_tags(company_id, reset=True) + _reset_cached_tags(company_id, projects=[task.project]) + call.result.data = dict(deleted=True, **attr.asdict(result)) diff --git a/server/services/utils.py b/server/services/utils.py index b9d8f6e..f02d425 100644 --- a/server/services/utils.py +++ b/server/services/utils.py @@ -1,12 +1,28 @@ from typing import Union, Sequence, Tuple from apierrors import errors +from apimodels.organization import Filter from database.model.base import GetMixin from database.utils import partition_tags from service_repo import APICall from service_repo.base import PartialVersion +def get_tags_filter_dictionary(input_: Filter) -> dict: + if not input_: + return {} + + return { + field: vals + for field, vals in (("tags", input_.tags), ("system_tags", input_.system_tags)) + if vals + } + + +def get_tags_response(ret: dict) -> dict: + return {field: sorted(vals) for field, vals in ret.items()} + + def conform_output_tags(call: APICall, documents: Union[dict, Sequence[dict]]): """ For old clients both tags 
and system tags are returned in 'tags' field diff --git a/server/tests/automated/test_organization.py b/server/tests/automated/test_organization.py deleted file mode 100644 index 566b016..0000000 --- a/server/tests/automated/test_organization.py +++ /dev/null @@ -1,36 +0,0 @@ -from tests.automated import TestService - - -class TestOrganization(TestService): - def setUp(self, version="2.8"): - super().setUp(version=version) - - def test_tags(self): - tag1 = "Orgtest tag1" - tag2 = "Orgtest tag2" - system_tag = "Orgtest system tag" - - model = self.create_temp( - "models", name="test_org", uri="file:///a", tags=[tag1] - ) - task = self.create_temp( - "tasks", name="test org", type="training", input=dict(view={}), tags=[tag1] - ) - data = self.api.organization.get_tags() - self.assertTrue(tag1 in data.tags) - - self.api.tasks.edit(task=task, tags=[tag2], system_tags=[system_tag]) - data = self.api.organization.get_tags(include_system=True) - self.assertTrue({tag1, tag2}.issubset(set(data.tags))) - self.assertTrue(system_tag in data.system_tags) - - data = self.api.organization.get_tags( - filter={"system_tags": ["__$not", system_tag]} - ) - self.assertTrue(tag1 in data.tags) - self.assertFalse(tag2 in data.tags) - - self.api.models.delete(model=model) - data = self.api.organization.get_tags() - self.assertFalse(tag1 in data.tags) - self.assertTrue(tag2 in data.tags) diff --git a/server/tests/automated/test_project_tags.py b/server/tests/automated/test_project_tags.py new file mode 100644 index 0000000..7e85b38 --- /dev/null +++ b/server/tests/automated/test_project_tags.py @@ -0,0 +1,82 @@ +from tests.automated import TestService + + +class TestProjectTags(TestService): + def setUp(self, version="2.8"): + super().setUp(version=version) + + def test_project_tags(self): + tags_1 = ["Test tag 1", "Test tag 2"] + tags_2 = ["Test tag 3", "Test tag 4"] + + p1 = self.create_temp("projects", name="Test tags1", description="test") + task1_1 = self.new_task(project=p1, tags=tags_1[:1]) + task1_2 = self.new_task(project=p1, tags=tags_1[1:]) + + p2 = self.create_temp("projects", name="Test tasks2", description="test") + task2 = self.new_task(project=p2, tags=tags_2) + + # test tags per project + data = self.api.projects.get_task_tags(projects=[p1]) + self.assertEqual(set(tags_1), set(data.tags)) + data = self.api.projects.get_model_tags(projects=[p1]) + self.assertEqual(set(), set(data.tags)) + data = self.api.projects.get_task_tags(projects=[p2]) + self.assertEqual(set(tags_2), set(data.tags)) + + # test tags for projects list + data = self.api.projects.get_task_tags(projects=[p1, p2]) + self.assertEqual(set(tags_1) | set(tags_2), set(data.tags)) + + # test tags for all projects + data = self.api.projects.get_task_tags(projects=[p1, p2]) + self.assertTrue((set(tags_1) | set(tags_2)).issubset(data.tags)) + + # test move to another project + self.api.tasks.edit(task=task1_2, project=p2) + data = self.api.projects.get_task_tags(projects=[p1]) + self.assertEqual(set(tags_1[:1]), set(data.tags)) + data = self.api.projects.get_task_tags(projects=[p2]) + self.assertEqual(set(tags_1[1:]) | set(tags_2), set(data.tags)) + + # test tags update + self.api.tasks.delete(task=task1_1, force=True) + self.api.tasks.delete(task=task2, force=True) + data = self.api.projects.get_task_tags(projects=[p1, p2]) + self.assertEqual(set(tags_1[1:]), set(data.tags)) + + def test_organization_tags(self): + tag1 = "Orgtest tag1" + tag2 = "Orgtest tag2" + system_tag = "Orgtest system tag" + + model = self.new_model(tags=[tag1]) + 
task = self.new_task(tags=[tag1]) + data = self.api.organization.get_tags() + self.assertTrue(tag1 in data.tags) + + self.api.tasks.edit(task=task, tags=[tag2], system_tags=[system_tag]) + data = self.api.organization.get_tags(include_system=True) + self.assertTrue({tag1, tag2}.issubset(set(data.tags))) + self.assertTrue(system_tag in data.system_tags) + + data = self.api.organization.get_tags( + filter={"system_tags": ["__$not", system_tag]} + ) + self.assertTrue(tag1 in data.tags) + self.assertFalse(tag2 in data.tags) + + self.api.models.delete(model=model) + data = self.api.organization.get_tags() + self.assertFalse(tag1 in data.tags) + self.assertTrue(tag2 in data.tags) + + def new_task(self, **kwargs): + self.update_missing( + kwargs, type="testing", name="test project tags", input=dict(view=dict()) + ) + return self.create_temp("tasks", **kwargs) + + def new_model(self, **kwargs): + self.update_missing(kwargs, name="test project tags", uri="file:///a") + return self.create_temp("models", **kwargs) diff --git a/server/tests/automated/test_task_events.py b/server/tests/automated/test_task_events.py index d88eb7f..90f1ea0 100644 --- a/server/tests/automated/test_task_events.py +++ b/server/tests/automated/test_task_events.py @@ -8,6 +8,8 @@ from functools import partial from statistics import mean from typing import Sequence +from boltons.iterutils import first + import es_factory from apierrors.errors.bad_request import EventsNotAdded from tests.automated import TestService @@ -72,6 +74,31 @@ class TestTaskEvents(TestService): ), ) + def test_last_scalar_metrics(self): + metric = "Metric1" + variant = "Variant1" + iter_count = 100 + task = self._temp_task() + events = [ + { + **self._create_task_event("training_stats_scalar", task, iteration), + "metric": metric, + "variant": variant, + "value": iteration, + } + for iteration in range(iter_count) + ] + # send 2 batches to check the interaction with already stored db value + # each batch contains multiple iterations + self.send_batch(events[:50]) + self.send_batch(events[50:]) + + task_data = self.api.tasks.get_by_id(task=task).task + metric_data = first(first(task_data.last_metrics.values()).values()) + self.assertEqual(iter_count - 1, metric_data.value) + self.assertEqual(iter_count - 1, metric_data.max_value) + self.assertEqual(0, metric_data.min_value) + def test_task_debug_images(self): task = self._temp_task() metric = "Metric1"
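
Reviewer note (illustrative, not part of the commit): the new project-scoped tag endpoints declared in projects.conf can be exercised the same way the bundled tests do. A minimal sketch, assuming the tests' APIClient instance (`api`) and two existing project ids; the endpoint names, the `include_system` flag and the `__$not` exclusion marker come from the schema above, while the concrete tag values and ids are made up:

    # Sketch only -- mirrors the calls made in test_project_tags.py.
    # `api` is the tests' APIClient; `p1`, `p2` are existing project ids.

    # user tags of all tasks under a single project
    res = api.projects.get_task_tags(projects=[p1])
    print(sorted(res.tags))

    # tags aggregated over several projects, system tags included
    res = api.projects.get_model_tags(projects=[p1, p2], include_system=True)
    print(sorted(res.tags), sorted(res.system_tags))

    # company-wide tags (tasks and models combined); '__$not' excludes
    # entities carrying the system tag value that follows it
    res = api.organization.get_tags(
        include_system=True,
        filter={"system_tags": ["__$not", "archived"]},
    )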
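
Reviewer note (illustrative, not part of the commit): the `_get_tags_cache_key` docstring claims that the wildcard scans issued by `_delete_redis_keys` for "all company projects" never pick up per-project keys, and vice versa. Below is a standalone restatement of the key layout used in the patch (`<company>[_<project>]_<DbCls>_<field>[_filter_...]`), kept separate from the server code; the sample ids are invented:

    from itertools import chain

    def tags_cache_key(company, db_cls_name, field, project=None, filter_=None):
        # Same construction as _TagsCache._get_tags_cache_key: the project id,
        # when present, sits between the company id and the entity class name,
        # so a scan of "<company>_<DbCls>_*" (all-company keys) cannot match
        # "<company>_<project>_<DbCls>_*" keys, and the other way around.
        filter_str = None
        if filter_:
            filter_str = "_".join(
                ["filter", *chain.from_iterable([f, *v] for f, v in filter_.items())]
            )
        key_parts = [company, project, db_cls_name, field, filter_str]
        return "_".join(filter(None, key_parts))

    # invented ids, for illustration only
    print(tags_cache_key("c1", "Task", "tags"))                 # c1_Task_tags
    print(tags_cache_key("c1", "Task", "tags", project="p42"))  # c1_p42_Task_tags
    print(tags_cache_key("c1", "Task", "tags",
                         filter_={"system_tags": ["archived"]}))
    # c1_Task_tags_filter_system_tags_archived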
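
Reviewer note (illustrative, not part of the commit): the reworked `_update_last_metric_event` now keeps the latest reported value together with running min/max per (metric, variant), and only accepts an event whose (iter, timestamp) pair is not older than the stored one; task_bll then persists these fields via set/min/max update operators. A condensed, dict-only restatement of that reducer with made-up events (not the server code itself), matching the behavior asserted by test_last_scalar_metrics:

    def reduce_last_metric(last_event: dict, event: dict) -> dict:
        # Order events by (iteration, timestamp), as the patch does;
        # older events never overwrite the stored entry.
        current = (event.get("iter", 0), event["timestamp"])
        stored = (
            last_event.get("iter", current[0]),
            last_event.get("timestamp", current[1]),
        )
        if current < stored:
            return last_event

        new = {
            k: event[k]
            for k in ("value", "metric", "variant", "iter", "timestamp")
            if k in event
        }
        value = new.get("value")
        if value is not None:
            # track running extremes alongside the last reported value
            new["min_value"] = min(value, last_event.get("min_value", value))
            new["max_value"] = max(value, last_event.get("max_value", value))
        else:
            # keep previously seen value/min/max when the event carries no value
            new.update(
                {k: last_event[k] for k in ("value", "min_value", "max_value") if k in last_event}
            )
        return new

    acc = {}
    for i, v in enumerate((5, 1, 9)):
        acc = reduce_last_metric(
            acc, {"metric": "m", "variant": "v", "iter": i, "timestamp": 1000 + i, "value": v}
        )
    print(acc["value"], acc["min_value"], acc["max_value"])  # 9 1 9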