Add update_tags api to tasks and models

This commit is contained in:
allegroai 2023-11-17 09:37:25 +02:00
parent cc0129a800
commit 274c487b37
14 changed files with 210 additions and 29 deletions

View File

@ -333,3 +333,8 @@ class DeleteModelsRequest(TaskRequest):
class GetAllReq(models.Base):
allow_public = BoolField(default=True)
search_hidden = BoolField(default=False)
class UpdateTagsRequest(BatchRequest):
add_tags = ListField([str])
remove_tags = ListField([str])

View File

@ -1,8 +1,11 @@
from collections import defaultdict
from enum import Enum
from typing import Sequence, Dict
from typing import Sequence, Dict, Type
from apiserver.apierrors import errors
from apiserver.bll.util import update_project_time
from apiserver.config_repo import config
from apiserver.database.model.model import AttributedDocument
from apiserver.database.model.model import Model
from apiserver.database.model.task.task import Task
from apiserver.redis_manager import redman
@ -22,6 +25,51 @@ class OrgBLL:
self._task_tags = _TagsCache(Task, self.redis)
self._model_tags = _TagsCache(Model, self.redis)
def edit_entity_tags(
self,
company_id,
entity_cls: Type[AttributedDocument],
entity_ids: Sequence[str],
add_tags: Sequence[str],
remove_tags: Sequence[str],
) -> int:
if entity_cls not in (Task, Model):
raise errors.bad_request.ValidationError(
"Tags editing can be called on tasks or models only"
)
if not entity_ids:
raise errors.bad_request.ValidationError(
"No entity ids provided for editing tags"
)
if not (add_tags or remove_tags):
raise errors.bad_request.ValidationError(
"Either add tags or remove tags should be provided"
)
updated = 0
if add_tags:
updated += entity_cls.objects(company=company_id, id__in=entity_ids).update(
add_to_set__tags=add_tags
)
if remove_tags:
updated += entity_cls.objects(company=company_id, id__in=entity_ids).update(
pull_all__tags=remove_tags
)
if not updated:
return 0
projects = entity_cls.objects(company=company_id, id__in=entity_ids).distinct(
"project"
)
update_project_time(project_ids=projects)
self.update_tags(
company_id,
entity=Tags.Task if entity_cls is Task else Tags.Model,
projects=projects,
tags=add_tags or remove_tags
)
return updated
def get_tags(
self,
company_id: str,
@ -50,10 +98,10 @@ class OrgBLL:
return ret
def update_tags(
self, company_id: str, entity: Tags, project: str, tags=None, system_tags=None,
self, company_id: str, entity: Tags, projects: Sequence[str], tags=None, system_tags=None,
):
tags_cache = self._get_tags_cache_for_entity(entity)
tags_cache.update_tags(company_id, project, tags, system_tags)
tags_cache.update_tags(company_id, projects, tags, system_tags)
def reset_tags(self, company_id: str, entity: Tags, projects: Sequence[str]):
tags_cache = self._get_tags_cache_for_entity(entity)

View File

@ -107,7 +107,7 @@ class _TagsCache:
return ret
def update_tags(self, company_id: str, project: str, tags=None, system_tags=None):
def update_tags(self, company_id: str, projects: Sequence[str], tags=None, system_tags=None):
"""
Updates tags. If reset is set then both tags and system_tags
are recalculated. Otherwise only those that are not 'None'
@ -123,7 +123,7 @@ class _TagsCache:
if not fields:
return
self._delete_redis_keys(company_id, projects=[project], fields=fields)
self._delete_redis_keys(company_id, projects=projects, fields=fields)
def reset_tags(self, company_id: str, projects: Sequence[str]):
self._delete_redis_keys(

View File

@ -1,6 +1,5 @@
from .task_bll import TaskBLL
from .utils import (
ChangeStatusRequest,
update_project_time,
validate_status_change,
)

View File

@ -1,7 +1,7 @@
from datetime import timedelta, datetime
from time import sleep
from apiserver.bll.task import update_project_time
from apiserver.bll.util import update_project_time
from apiserver.config_repo import config
from apiserver.database.model.task.task import TaskStatus, Task
from apiserver.utilities.threads_manager import ThreadsManager

View File

@ -12,6 +12,7 @@ from apiserver.apimodels.tasks import TaskInputModel
from apiserver.bll.queue import QueueBLL
from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL
from apiserver.bll.util import update_project_time
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.model import Model
@ -31,7 +32,10 @@ from apiserver.database.model.task.task import (
)
from apiserver.database.model import EntityVisibility
from apiserver.database.model.queue import Queue
from apiserver.database.utils import get_company_or_none_constraint, id as create_id
from apiserver.database.utils import (
get_company_or_none_constraint,
id as create_id,
)
from apiserver.es_factory import es_factory
from apiserver.redis_manager import redman
from apiserver.services.utils import validate_tags, escape_dict_field, escape_dict
@ -39,7 +43,6 @@ from .artifacts import artifacts_prepare_for_save
from .param_utils import params_prepare_for_save
from .utils import (
ChangeStatusRequest,
update_project_time,
deleted_prefix,
get_last_metric_updates,
)
@ -78,7 +81,11 @@ class TaskBLL:
@staticmethod
def get_by_id(
company_id, task_id, required_status=None, only_fields=None, allow_public=False,
company_id,
task_id,
required_status=None,
only_fields=None,
allow_public=False,
):
if only_fields:
if isinstance(only_fields, string_types):
@ -313,7 +320,7 @@ class TaskBLL:
org_bll.update_tags(
company_id,
Tags.Task,
project=new_task.project,
projects=[new_task.project],
tags=updated_tags,
system_tags=updated_system_tags,
)

View File

@ -7,9 +7,9 @@ from apiserver.bll.task import (
TaskBLL,
validate_status_change,
ChangeStatusRequest,
update_project_time,
)
from apiserver.bll.task.task_cleanup import cleanup_task, CleanupResult
from apiserver.bll.util import update_project_time
from apiserver.config_repo import config
from apiserver.database.model import EntityVisibility
from apiserver.database.model.model import Model

View File

@ -1,14 +1,13 @@
from datetime import datetime
from typing import Sequence, Union
import attr
import six
from apiserver.apierrors import errors
from apiserver.bll.util import update_project_time
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.model import Model
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task, TaskStatus, TaskSystemTags
from apiserver.database.utils import get_options
from apiserver.utilities.attrs import typed_attrs
@ -158,16 +157,6 @@ def get_possible_status_changes(current_status):
return possible
def update_project_time(project_ids: Union[str, Sequence[str]]):
if not project_ids:
return
if isinstance(project_ids, str):
project_ids = [project_ids]
return Project.objects(id__in=project_ids).update(last_update=datetime.utcnow())
def get_task_for_update(
company_id: str, task_id: str, allow_all_statuses: bool = False, force: bool = False
) -> Task:

View File

@ -1,6 +1,7 @@
import functools
import itertools
from concurrent.futures.thread import ThreadPoolExecutor
from datetime import datetime
from typing import (
Optional,
Callable,
@ -8,11 +9,13 @@ from typing import (
Tuple,
Sequence,
TypeVar,
Union,
)
from boltons import iterutils
from apiserver.apierrors import APIError
from apiserver.database.model.project import Project
from apiserver.database.model.settings import Settings
@ -77,3 +80,13 @@ def run_batch_operation(
}
)
return results, failures
def update_project_time(project_ids: Union[str, Sequence[str]]):
if not project_ids:
return
if isinstance(project_ids, str):
project_ids = [project_ids]
return Project.objects(id__in=project_ids).update(last_update=datetime.utcnow())

View File

@ -1095,3 +1095,36 @@ delete_metadata {
}
}
}
update_tags {
"999.0" {
request {
type: object
properties {
ids {
type: array
description: IDs of the models to update
items {type: string}
}
add_tags {
type: array
description: User tags to add
items {type: string}
}
remove_tags {
type: array
description: User tags to remove
items {type: string}
}
}
}
response {
type: object
properties {
updated {
type: integer
description: The number of updated models
}
}
}
}
}

View File

@ -2058,3 +2058,36 @@ move {
}
}
}
update_tags {
"999.0" {
request {
type: object
properties {
ids {
type: array
description: IDs of the tasks to update
items {type: string}
}
add_tags {
type: array
description: User tags to add
items {type: string}
}
remove_tags {
type: array
description: User tags to remove
items {type: string}
}
}
}
response {
type: object
properties {
updated {
type: integer
description: The number of updated tasks
}
}
}
}
}

View File

@ -22,6 +22,7 @@ from apiserver.apimodels.models import (
ModelsDeleteManyRequest,
ModelsGetRequest,
)
from apiserver.apimodels.tasks import UpdateTagsRequest
from apiserver.bll.model import ModelBLL, Metadata
from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL
@ -219,7 +220,7 @@ def _update_cached_tags(company: str, project: str, fields: dict):
org_bll.update_tags(
company,
Tags.Model,
project=project,
projects=[project],
tags=fields.get("tags"),
system_tags=fields.get("system_tags"),
)
@ -678,6 +679,19 @@ def move(call: APICall, company_id: str, request: MoveRequest):
}
@endpoint("models.update_tags")
def update_tags(_, company_id: str, request: UpdateTagsRequest):
return {
"update": org_bll.edit_entity_tags(
company_id=company_id,
entity_cls=Model,
entity_ids=request.ids,
add_tags=request.add_tags,
remove_tags=request.remove_tags,
)
}
@endpoint("models.add_or_update_metadata", min_version="2.13")
def add_or_update_metadata(
call: APICall, company_id: str, request: AddOrUpdateMetadataRequest

View File

@ -67,6 +67,7 @@ from apiserver.apimodels.tasks import (
GetAllReq,
DequeueRequest,
DequeueManyRequest,
UpdateTagsRequest,
)
from apiserver.bll.event import EventBLL
from apiserver.bll.model import ModelBLL
@ -76,7 +77,6 @@ from apiserver.bll.queue import QueueBLL
from apiserver.bll.task import (
TaskBLL,
ChangeStatusRequest,
update_project_time,
)
from apiserver.bll.task.artifacts import (
artifacts_prepare_for_save,
@ -101,7 +101,7 @@ from apiserver.bll.task.task_operations import (
move_tasks_to_trash,
)
from apiserver.bll.task.utils import update_task, get_task_for_update, deleted_prefix
from apiserver.bll.util import run_batch_operation
from apiserver.bll.util import run_batch_operation, update_project_time
from apiserver.database.errors import translate_errors_context
from apiserver.database.model import EntityVisibility
from apiserver.database.model.task.output import Output
@ -112,7 +112,11 @@ from apiserver.database.model.task.task import (
ModelItem,
TaskModelTypes,
)
from apiserver.database.utils import get_fields_attr, parse_from_call, get_options
from apiserver.database.utils import (
get_fields_attr,
parse_from_call,
get_options,
)
from apiserver.service_repo import APICall, endpoint
from apiserver.services.utils import (
conform_tag_fields,
@ -493,7 +497,7 @@ def _update_cached_tags(company: str, project: str, fields: dict):
org_bll.update_tags(
company,
Tags.Task,
project=project,
projects=[project],
tags=fields.get("tags"),
system_tags=fields.get("system_tags"),
)
@ -1325,6 +1329,19 @@ def move(call: APICall, company_id: str, request: MoveRequest):
return {"project_id": project_id}
@endpoint("tasks.update_tags")
def update_tags(_, company_id: str, request: UpdateTagsRequest):
return {
"update": org_bll.edit_entity_tags(
company_id=company_id,
entity_cls=Task,
entity_ids=request.ids,
add_tags=request.add_tags,
remove_tags=request.remove_tags,
)
}
@endpoint("tasks.add_or_update_model", min_version="2.13")
def add_or_update_model(call: APICall, company_id: str, request: AddUpdateModelRequest):
get_task_for_update(company_id=company_id, task_id=request.task, force=True)

View File

@ -92,6 +92,29 @@ class TestProjectTags(TestService):
self.assertFalse(tag1 in data.tags)
self.assertTrue(tag2 in data.tags)
def test_tags_api(self):
p = self.create_temp("projects", name="Test tags api", description="test")
# task
initial_tags = ["Task tag"]
task = self.new_task(project=p, tags=initial_tags)
data = self.api.projects.get_task_tags(projects=[p])
self.assertEqual(data.tags, initial_tags)
new_tags = ["New task tag"]
self.api.tasks.update_tags(ids=[task], add_tags=new_tags, remove_tags=initial_tags)
data = self.api.projects.get_task_tags(projects=[p])
self.assertEqual(data.tags, new_tags)
# model
initial_tags = ["Model tag"]
model = self.new_model(project=p, tags=initial_tags)
data = self.api.projects.get_model_tags(projects=[p])
self.assertEqual(data.tags, initial_tags)
new_tags = ["New model tag"]
self.api.models.update_tags(ids=[model], add_tags=new_tags)
data = self.api.projects.get_model_tags(projects=[p])
self.assertEqual(set(data.tags), set([*new_tags, *initial_tags]))
def new_task(self, **kwargs):
self.update_missing(
kwargs, type="testing", name="test project tags"