Add update_tags api to tasks and models

This commit is contained in:
allegroai 2023-11-17 09:37:25 +02:00
parent cc0129a800
commit 274c487b37
14 changed files with 210 additions and 29 deletions

View File

@ -333,3 +333,8 @@ class DeleteModelsRequest(TaskRequest):
class GetAllReq(models.Base): class GetAllReq(models.Base):
allow_public = BoolField(default=True) allow_public = BoolField(default=True)
search_hidden = BoolField(default=False) search_hidden = BoolField(default=False)
class UpdateTagsRequest(BatchRequest):
add_tags = ListField([str])
remove_tags = ListField([str])

View File

@ -1,8 +1,11 @@
from collections import defaultdict from collections import defaultdict
from enum import Enum from enum import Enum
from typing import Sequence, Dict from typing import Sequence, Dict, Type
from apiserver.apierrors import errors
from apiserver.bll.util import update_project_time
from apiserver.config_repo import config from apiserver.config_repo import config
from apiserver.database.model.model import AttributedDocument
from apiserver.database.model.model import Model from apiserver.database.model.model import Model
from apiserver.database.model.task.task import Task from apiserver.database.model.task.task import Task
from apiserver.redis_manager import redman from apiserver.redis_manager import redman
@ -22,6 +25,51 @@ class OrgBLL:
self._task_tags = _TagsCache(Task, self.redis) self._task_tags = _TagsCache(Task, self.redis)
self._model_tags = _TagsCache(Model, self.redis) self._model_tags = _TagsCache(Model, self.redis)
def edit_entity_tags(
self,
company_id,
entity_cls: Type[AttributedDocument],
entity_ids: Sequence[str],
add_tags: Sequence[str],
remove_tags: Sequence[str],
) -> int:
if entity_cls not in (Task, Model):
raise errors.bad_request.ValidationError(
"Tags editing can be called on tasks or models only"
)
if not entity_ids:
raise errors.bad_request.ValidationError(
"No entity ids provided for editing tags"
)
if not (add_tags or remove_tags):
raise errors.bad_request.ValidationError(
"Either add tags or remove tags should be provided"
)
updated = 0
if add_tags:
updated += entity_cls.objects(company=company_id, id__in=entity_ids).update(
add_to_set__tags=add_tags
)
if remove_tags:
updated += entity_cls.objects(company=company_id, id__in=entity_ids).update(
pull_all__tags=remove_tags
)
if not updated:
return 0
projects = entity_cls.objects(company=company_id, id__in=entity_ids).distinct(
"project"
)
update_project_time(project_ids=projects)
self.update_tags(
company_id,
entity=Tags.Task if entity_cls is Task else Tags.Model,
projects=projects,
tags=add_tags or remove_tags
)
return updated
def get_tags( def get_tags(
self, self,
company_id: str, company_id: str,
@ -50,10 +98,10 @@ class OrgBLL:
return ret return ret
def update_tags( def update_tags(
self, company_id: str, entity: Tags, project: str, tags=None, system_tags=None, self, company_id: str, entity: Tags, projects: Sequence[str], tags=None, system_tags=None,
): ):
tags_cache = self._get_tags_cache_for_entity(entity) tags_cache = self._get_tags_cache_for_entity(entity)
tags_cache.update_tags(company_id, project, tags, system_tags) tags_cache.update_tags(company_id, projects, tags, system_tags)
def reset_tags(self, company_id: str, entity: Tags, projects: Sequence[str]): def reset_tags(self, company_id: str, entity: Tags, projects: Sequence[str]):
tags_cache = self._get_tags_cache_for_entity(entity) tags_cache = self._get_tags_cache_for_entity(entity)

View File

@ -107,7 +107,7 @@ class _TagsCache:
return ret return ret
def update_tags(self, company_id: str, project: str, tags=None, system_tags=None): def update_tags(self, company_id: str, projects: Sequence[str], tags=None, system_tags=None):
""" """
Updates tags. If reset is set then both tags and system_tags Updates tags. If reset is set then both tags and system_tags
are recalculated. Otherwise only those that are not 'None' are recalculated. Otherwise only those that are not 'None'
@ -123,7 +123,7 @@ class _TagsCache:
if not fields: if not fields:
return return
self._delete_redis_keys(company_id, projects=[project], fields=fields) self._delete_redis_keys(company_id, projects=projects, fields=fields)
def reset_tags(self, company_id: str, projects: Sequence[str]): def reset_tags(self, company_id: str, projects: Sequence[str]):
self._delete_redis_keys( self._delete_redis_keys(

View File

@ -1,6 +1,5 @@
from .task_bll import TaskBLL from .task_bll import TaskBLL
from .utils import ( from .utils import (
ChangeStatusRequest, ChangeStatusRequest,
update_project_time,
validate_status_change, validate_status_change,
) )

View File

@ -1,7 +1,7 @@
from datetime import timedelta, datetime from datetime import timedelta, datetime
from time import sleep from time import sleep
from apiserver.bll.task import update_project_time from apiserver.bll.util import update_project_time
from apiserver.config_repo import config from apiserver.config_repo import config
from apiserver.database.model.task.task import TaskStatus, Task from apiserver.database.model.task.task import TaskStatus, Task
from apiserver.utilities.threads_manager import ThreadsManager from apiserver.utilities.threads_manager import ThreadsManager

View File

@ -12,6 +12,7 @@ from apiserver.apimodels.tasks import TaskInputModel
from apiserver.bll.queue import QueueBLL from apiserver.bll.queue import QueueBLL
from apiserver.bll.organization import OrgBLL, Tags from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL from apiserver.bll.project import ProjectBLL
from apiserver.bll.util import update_project_time
from apiserver.config_repo import config from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context from apiserver.database.errors import translate_errors_context
from apiserver.database.model.model import Model from apiserver.database.model.model import Model
@ -31,7 +32,10 @@ from apiserver.database.model.task.task import (
) )
from apiserver.database.model import EntityVisibility from apiserver.database.model import EntityVisibility
from apiserver.database.model.queue import Queue from apiserver.database.model.queue import Queue
from apiserver.database.utils import get_company_or_none_constraint, id as create_id from apiserver.database.utils import (
get_company_or_none_constraint,
id as create_id,
)
from apiserver.es_factory import es_factory from apiserver.es_factory import es_factory
from apiserver.redis_manager import redman from apiserver.redis_manager import redman
from apiserver.services.utils import validate_tags, escape_dict_field, escape_dict from apiserver.services.utils import validate_tags, escape_dict_field, escape_dict
@ -39,7 +43,6 @@ from .artifacts import artifacts_prepare_for_save
from .param_utils import params_prepare_for_save from .param_utils import params_prepare_for_save
from .utils import ( from .utils import (
ChangeStatusRequest, ChangeStatusRequest,
update_project_time,
deleted_prefix, deleted_prefix,
get_last_metric_updates, get_last_metric_updates,
) )
@ -78,7 +81,11 @@ class TaskBLL:
@staticmethod @staticmethod
def get_by_id( def get_by_id(
company_id, task_id, required_status=None, only_fields=None, allow_public=False, company_id,
task_id,
required_status=None,
only_fields=None,
allow_public=False,
): ):
if only_fields: if only_fields:
if isinstance(only_fields, string_types): if isinstance(only_fields, string_types):
@ -313,7 +320,7 @@ class TaskBLL:
org_bll.update_tags( org_bll.update_tags(
company_id, company_id,
Tags.Task, Tags.Task,
project=new_task.project, projects=[new_task.project],
tags=updated_tags, tags=updated_tags,
system_tags=updated_system_tags, system_tags=updated_system_tags,
) )

View File

@ -7,9 +7,9 @@ from apiserver.bll.task import (
TaskBLL, TaskBLL,
validate_status_change, validate_status_change,
ChangeStatusRequest, ChangeStatusRequest,
update_project_time,
) )
from apiserver.bll.task.task_cleanup import cleanup_task, CleanupResult from apiserver.bll.task.task_cleanup import cleanup_task, CleanupResult
from apiserver.bll.util import update_project_time
from apiserver.config_repo import config from apiserver.config_repo import config
from apiserver.database.model import EntityVisibility from apiserver.database.model import EntityVisibility
from apiserver.database.model.model import Model from apiserver.database.model.model import Model

View File

@ -1,14 +1,13 @@
from datetime import datetime from datetime import datetime
from typing import Sequence, Union
import attr import attr
import six import six
from apiserver.apierrors import errors from apiserver.apierrors import errors
from apiserver.bll.util import update_project_time
from apiserver.config_repo import config from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context from apiserver.database.errors import translate_errors_context
from apiserver.database.model.model import Model from apiserver.database.model.model import Model
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task, TaskStatus, TaskSystemTags from apiserver.database.model.task.task import Task, TaskStatus, TaskSystemTags
from apiserver.database.utils import get_options from apiserver.database.utils import get_options
from apiserver.utilities.attrs import typed_attrs from apiserver.utilities.attrs import typed_attrs
@ -158,16 +157,6 @@ def get_possible_status_changes(current_status):
return possible return possible
def update_project_time(project_ids: Union[str, Sequence[str]]):
if not project_ids:
return
if isinstance(project_ids, str):
project_ids = [project_ids]
return Project.objects(id__in=project_ids).update(last_update=datetime.utcnow())
def get_task_for_update( def get_task_for_update(
company_id: str, task_id: str, allow_all_statuses: bool = False, force: bool = False company_id: str, task_id: str, allow_all_statuses: bool = False, force: bool = False
) -> Task: ) -> Task:

View File

@ -1,6 +1,7 @@
import functools import functools
import itertools import itertools
from concurrent.futures.thread import ThreadPoolExecutor from concurrent.futures.thread import ThreadPoolExecutor
from datetime import datetime
from typing import ( from typing import (
Optional, Optional,
Callable, Callable,
@ -8,11 +9,13 @@ from typing import (
Tuple, Tuple,
Sequence, Sequence,
TypeVar, TypeVar,
Union,
) )
from boltons import iterutils from boltons import iterutils
from apiserver.apierrors import APIError from apiserver.apierrors import APIError
from apiserver.database.model.project import Project
from apiserver.database.model.settings import Settings from apiserver.database.model.settings import Settings
@ -77,3 +80,13 @@ def run_batch_operation(
} }
) )
return results, failures return results, failures
def update_project_time(project_ids: Union[str, Sequence[str]]):
if not project_ids:
return
if isinstance(project_ids, str):
project_ids = [project_ids]
return Project.objects(id__in=project_ids).update(last_update=datetime.utcnow())

View File

@ -1095,3 +1095,36 @@ delete_metadata {
} }
} }
} }
update_tags {
"999.0" {
request {
type: object
properties {
ids {
type: array
description: IDs of the models to update
items {type: string}
}
add_tags {
type: array
description: User tags to add
items {type: string}
}
remove_tags {
type: array
description: User tags to remove
items {type: string}
}
}
}
response {
type: object
properties {
updated {
type: integer
description: The number of updated models
}
}
}
}
}

View File

@ -2058,3 +2058,36 @@ move {
} }
} }
} }
update_tags {
"999.0" {
request {
type: object
properties {
ids {
type: array
description: IDs of the tasks to update
items {type: string}
}
add_tags {
type: array
description: User tags to add
items {type: string}
}
remove_tags {
type: array
description: User tags to remove
items {type: string}
}
}
}
response {
type: object
properties {
updated {
type: integer
description: The number of updated tasks
}
}
}
}
}

View File

@ -22,6 +22,7 @@ from apiserver.apimodels.models import (
ModelsDeleteManyRequest, ModelsDeleteManyRequest,
ModelsGetRequest, ModelsGetRequest,
) )
from apiserver.apimodels.tasks import UpdateTagsRequest
from apiserver.bll.model import ModelBLL, Metadata from apiserver.bll.model import ModelBLL, Metadata
from apiserver.bll.organization import OrgBLL, Tags from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL from apiserver.bll.project import ProjectBLL
@ -219,7 +220,7 @@ def _update_cached_tags(company: str, project: str, fields: dict):
org_bll.update_tags( org_bll.update_tags(
company, company,
Tags.Model, Tags.Model,
project=project, projects=[project],
tags=fields.get("tags"), tags=fields.get("tags"),
system_tags=fields.get("system_tags"), system_tags=fields.get("system_tags"),
) )
@ -678,6 +679,19 @@ def move(call: APICall, company_id: str, request: MoveRequest):
} }
@endpoint("models.update_tags")
def update_tags(_, company_id: str, request: UpdateTagsRequest):
return {
"update": org_bll.edit_entity_tags(
company_id=company_id,
entity_cls=Model,
entity_ids=request.ids,
add_tags=request.add_tags,
remove_tags=request.remove_tags,
)
}
@endpoint("models.add_or_update_metadata", min_version="2.13") @endpoint("models.add_or_update_metadata", min_version="2.13")
def add_or_update_metadata( def add_or_update_metadata(
call: APICall, company_id: str, request: AddOrUpdateMetadataRequest call: APICall, company_id: str, request: AddOrUpdateMetadataRequest

View File

@ -67,6 +67,7 @@ from apiserver.apimodels.tasks import (
GetAllReq, GetAllReq,
DequeueRequest, DequeueRequest,
DequeueManyRequest, DequeueManyRequest,
UpdateTagsRequest,
) )
from apiserver.bll.event import EventBLL from apiserver.bll.event import EventBLL
from apiserver.bll.model import ModelBLL from apiserver.bll.model import ModelBLL
@ -76,7 +77,6 @@ from apiserver.bll.queue import QueueBLL
from apiserver.bll.task import ( from apiserver.bll.task import (
TaskBLL, TaskBLL,
ChangeStatusRequest, ChangeStatusRequest,
update_project_time,
) )
from apiserver.bll.task.artifacts import ( from apiserver.bll.task.artifacts import (
artifacts_prepare_for_save, artifacts_prepare_for_save,
@ -101,7 +101,7 @@ from apiserver.bll.task.task_operations import (
move_tasks_to_trash, move_tasks_to_trash,
) )
from apiserver.bll.task.utils import update_task, get_task_for_update, deleted_prefix from apiserver.bll.task.utils import update_task, get_task_for_update, deleted_prefix
from apiserver.bll.util import run_batch_operation from apiserver.bll.util import run_batch_operation, update_project_time
from apiserver.database.errors import translate_errors_context from apiserver.database.errors import translate_errors_context
from apiserver.database.model import EntityVisibility from apiserver.database.model import EntityVisibility
from apiserver.database.model.task.output import Output from apiserver.database.model.task.output import Output
@ -112,7 +112,11 @@ from apiserver.database.model.task.task import (
ModelItem, ModelItem,
TaskModelTypes, TaskModelTypes,
) )
from apiserver.database.utils import get_fields_attr, parse_from_call, get_options from apiserver.database.utils import (
get_fields_attr,
parse_from_call,
get_options,
)
from apiserver.service_repo import APICall, endpoint from apiserver.service_repo import APICall, endpoint
from apiserver.services.utils import ( from apiserver.services.utils import (
conform_tag_fields, conform_tag_fields,
@ -493,7 +497,7 @@ def _update_cached_tags(company: str, project: str, fields: dict):
org_bll.update_tags( org_bll.update_tags(
company, company,
Tags.Task, Tags.Task,
project=project, projects=[project],
tags=fields.get("tags"), tags=fields.get("tags"),
system_tags=fields.get("system_tags"), system_tags=fields.get("system_tags"),
) )
@ -1325,6 +1329,19 @@ def move(call: APICall, company_id: str, request: MoveRequest):
return {"project_id": project_id} return {"project_id": project_id}
@endpoint("tasks.update_tags")
def update_tags(_, company_id: str, request: UpdateTagsRequest):
return {
"update": org_bll.edit_entity_tags(
company_id=company_id,
entity_cls=Task,
entity_ids=request.ids,
add_tags=request.add_tags,
remove_tags=request.remove_tags,
)
}
@endpoint("tasks.add_or_update_model", min_version="2.13") @endpoint("tasks.add_or_update_model", min_version="2.13")
def add_or_update_model(call: APICall, company_id: str, request: AddUpdateModelRequest): def add_or_update_model(call: APICall, company_id: str, request: AddUpdateModelRequest):
get_task_for_update(company_id=company_id, task_id=request.task, force=True) get_task_for_update(company_id=company_id, task_id=request.task, force=True)

View File

@ -92,6 +92,29 @@ class TestProjectTags(TestService):
self.assertFalse(tag1 in data.tags) self.assertFalse(tag1 in data.tags)
self.assertTrue(tag2 in data.tags) self.assertTrue(tag2 in data.tags)
def test_tags_api(self):
p = self.create_temp("projects", name="Test tags api", description="test")
# task
initial_tags = ["Task tag"]
task = self.new_task(project=p, tags=initial_tags)
data = self.api.projects.get_task_tags(projects=[p])
self.assertEqual(data.tags, initial_tags)
new_tags = ["New task tag"]
self.api.tasks.update_tags(ids=[task], add_tags=new_tags, remove_tags=initial_tags)
data = self.api.projects.get_task_tags(projects=[p])
self.assertEqual(data.tags, new_tags)
# model
initial_tags = ["Model tag"]
model = self.new_model(project=p, tags=initial_tags)
data = self.api.projects.get_model_tags(projects=[p])
self.assertEqual(data.tags, initial_tags)
new_tags = ["New model tag"]
self.api.models.update_tags(ids=[model], add_tags=new_tags)
data = self.api.projects.get_model_tags(projects=[p])
self.assertEqual(set(data.tags), set([*new_tags, *initial_tags]))
def new_task(self, **kwargs): def new_task(self, **kwargs):
self.update_missing( self.update_missing(
kwargs, type="testing", name="test project tags" kwargs, type="testing", name="test project tags"