mirror of
https://github.com/clearml/clearml-server
synced 2025-04-20 06:04:37 +00:00
Add update_tags api to tasks and models
This commit is contained in:
parent
cc0129a800
commit
274c487b37
@ -333,3 +333,8 @@ class DeleteModelsRequest(TaskRequest):
|
||||
class GetAllReq(models.Base):
|
||||
allow_public = BoolField(default=True)
|
||||
search_hidden = BoolField(default=False)
|
||||
|
||||
|
||||
class UpdateTagsRequest(BatchRequest):
|
||||
add_tags = ListField([str])
|
||||
remove_tags = ListField([str])
|
||||
|
@ -1,8 +1,11 @@
|
||||
from collections import defaultdict
|
||||
from enum import Enum
|
||||
from typing import Sequence, Dict
|
||||
from typing import Sequence, Dict, Type
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.bll.util import update_project_time
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.model.model import AttributedDocument
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.task.task import Task
|
||||
from apiserver.redis_manager import redman
|
||||
@ -22,6 +25,51 @@ class OrgBLL:
|
||||
self._task_tags = _TagsCache(Task, self.redis)
|
||||
self._model_tags = _TagsCache(Model, self.redis)
|
||||
|
||||
def edit_entity_tags(
|
||||
self,
|
||||
company_id,
|
||||
entity_cls: Type[AttributedDocument],
|
||||
entity_ids: Sequence[str],
|
||||
add_tags: Sequence[str],
|
||||
remove_tags: Sequence[str],
|
||||
) -> int:
|
||||
if entity_cls not in (Task, Model):
|
||||
raise errors.bad_request.ValidationError(
|
||||
"Tags editing can be called on tasks or models only"
|
||||
)
|
||||
if not entity_ids:
|
||||
raise errors.bad_request.ValidationError(
|
||||
"No entity ids provided for editing tags"
|
||||
)
|
||||
if not (add_tags or remove_tags):
|
||||
raise errors.bad_request.ValidationError(
|
||||
"Either add tags or remove tags should be provided"
|
||||
)
|
||||
|
||||
updated = 0
|
||||
if add_tags:
|
||||
updated += entity_cls.objects(company=company_id, id__in=entity_ids).update(
|
||||
add_to_set__tags=add_tags
|
||||
)
|
||||
if remove_tags:
|
||||
updated += entity_cls.objects(company=company_id, id__in=entity_ids).update(
|
||||
pull_all__tags=remove_tags
|
||||
)
|
||||
if not updated:
|
||||
return 0
|
||||
|
||||
projects = entity_cls.objects(company=company_id, id__in=entity_ids).distinct(
|
||||
"project"
|
||||
)
|
||||
update_project_time(project_ids=projects)
|
||||
self.update_tags(
|
||||
company_id,
|
||||
entity=Tags.Task if entity_cls is Task else Tags.Model,
|
||||
projects=projects,
|
||||
tags=add_tags or remove_tags
|
||||
)
|
||||
return updated
|
||||
|
||||
def get_tags(
|
||||
self,
|
||||
company_id: str,
|
||||
@ -50,10 +98,10 @@ class OrgBLL:
|
||||
return ret
|
||||
|
||||
def update_tags(
|
||||
self, company_id: str, entity: Tags, project: str, tags=None, system_tags=None,
|
||||
self, company_id: str, entity: Tags, projects: Sequence[str], tags=None, system_tags=None,
|
||||
):
|
||||
tags_cache = self._get_tags_cache_for_entity(entity)
|
||||
tags_cache.update_tags(company_id, project, tags, system_tags)
|
||||
tags_cache.update_tags(company_id, projects, tags, system_tags)
|
||||
|
||||
def reset_tags(self, company_id: str, entity: Tags, projects: Sequence[str]):
|
||||
tags_cache = self._get_tags_cache_for_entity(entity)
|
||||
|
@ -107,7 +107,7 @@ class _TagsCache:
|
||||
|
||||
return ret
|
||||
|
||||
def update_tags(self, company_id: str, project: str, tags=None, system_tags=None):
|
||||
def update_tags(self, company_id: str, projects: Sequence[str], tags=None, system_tags=None):
|
||||
"""
|
||||
Updates tags. If reset is set then both tags and system_tags
|
||||
are recalculated. Otherwise only those that are not 'None'
|
||||
@ -123,7 +123,7 @@ class _TagsCache:
|
||||
if not fields:
|
||||
return
|
||||
|
||||
self._delete_redis_keys(company_id, projects=[project], fields=fields)
|
||||
self._delete_redis_keys(company_id, projects=projects, fields=fields)
|
||||
|
||||
def reset_tags(self, company_id: str, projects: Sequence[str]):
|
||||
self._delete_redis_keys(
|
||||
|
@ -1,6 +1,5 @@
|
||||
from .task_bll import TaskBLL
|
||||
from .utils import (
|
||||
ChangeStatusRequest,
|
||||
update_project_time,
|
||||
validate_status_change,
|
||||
)
|
||||
|
@ -1,7 +1,7 @@
|
||||
from datetime import timedelta, datetime
|
||||
from time import sleep
|
||||
|
||||
from apiserver.bll.task import update_project_time
|
||||
from apiserver.bll.util import update_project_time
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.model.task.task import TaskStatus, Task
|
||||
from apiserver.utilities.threads_manager import ThreadsManager
|
||||
|
@ -12,6 +12,7 @@ from apiserver.apimodels.tasks import TaskInputModel
|
||||
from apiserver.bll.queue import QueueBLL
|
||||
from apiserver.bll.organization import OrgBLL, Tags
|
||||
from apiserver.bll.project import ProjectBLL
|
||||
from apiserver.bll.util import update_project_time
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model.model import Model
|
||||
@ -31,7 +32,10 @@ from apiserver.database.model.task.task import (
|
||||
)
|
||||
from apiserver.database.model import EntityVisibility
|
||||
from apiserver.database.model.queue import Queue
|
||||
from apiserver.database.utils import get_company_or_none_constraint, id as create_id
|
||||
from apiserver.database.utils import (
|
||||
get_company_or_none_constraint,
|
||||
id as create_id,
|
||||
)
|
||||
from apiserver.es_factory import es_factory
|
||||
from apiserver.redis_manager import redman
|
||||
from apiserver.services.utils import validate_tags, escape_dict_field, escape_dict
|
||||
@ -39,7 +43,6 @@ from .artifacts import artifacts_prepare_for_save
|
||||
from .param_utils import params_prepare_for_save
|
||||
from .utils import (
|
||||
ChangeStatusRequest,
|
||||
update_project_time,
|
||||
deleted_prefix,
|
||||
get_last_metric_updates,
|
||||
)
|
||||
@ -78,7 +81,11 @@ class TaskBLL:
|
||||
|
||||
@staticmethod
|
||||
def get_by_id(
|
||||
company_id, task_id, required_status=None, only_fields=None, allow_public=False,
|
||||
company_id,
|
||||
task_id,
|
||||
required_status=None,
|
||||
only_fields=None,
|
||||
allow_public=False,
|
||||
):
|
||||
if only_fields:
|
||||
if isinstance(only_fields, string_types):
|
||||
@ -313,7 +320,7 @@ class TaskBLL:
|
||||
org_bll.update_tags(
|
||||
company_id,
|
||||
Tags.Task,
|
||||
project=new_task.project,
|
||||
projects=[new_task.project],
|
||||
tags=updated_tags,
|
||||
system_tags=updated_system_tags,
|
||||
)
|
||||
|
@ -7,9 +7,9 @@ from apiserver.bll.task import (
|
||||
TaskBLL,
|
||||
validate_status_change,
|
||||
ChangeStatusRequest,
|
||||
update_project_time,
|
||||
)
|
||||
from apiserver.bll.task.task_cleanup import cleanup_task, CleanupResult
|
||||
from apiserver.bll.util import update_project_time
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.model import EntityVisibility
|
||||
from apiserver.database.model.model import Model
|
||||
|
@ -1,14 +1,13 @@
|
||||
from datetime import datetime
|
||||
from typing import Sequence, Union
|
||||
|
||||
import attr
|
||||
import six
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.bll.util import update_project_time
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.project import Project
|
||||
from apiserver.database.model.task.task import Task, TaskStatus, TaskSystemTags
|
||||
from apiserver.database.utils import get_options
|
||||
from apiserver.utilities.attrs import typed_attrs
|
||||
@ -158,16 +157,6 @@ def get_possible_status_changes(current_status):
|
||||
return possible
|
||||
|
||||
|
||||
def update_project_time(project_ids: Union[str, Sequence[str]]):
|
||||
if not project_ids:
|
||||
return
|
||||
|
||||
if isinstance(project_ids, str):
|
||||
project_ids = [project_ids]
|
||||
|
||||
return Project.objects(id__in=project_ids).update(last_update=datetime.utcnow())
|
||||
|
||||
|
||||
def get_task_for_update(
|
||||
company_id: str, task_id: str, allow_all_statuses: bool = False, force: bool = False
|
||||
) -> Task:
|
||||
|
@ -1,6 +1,7 @@
|
||||
import functools
|
||||
import itertools
|
||||
from concurrent.futures.thread import ThreadPoolExecutor
|
||||
from datetime import datetime
|
||||
from typing import (
|
||||
Optional,
|
||||
Callable,
|
||||
@ -8,11 +9,13 @@ from typing import (
|
||||
Tuple,
|
||||
Sequence,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
from boltons import iterutils
|
||||
|
||||
from apiserver.apierrors import APIError
|
||||
from apiserver.database.model.project import Project
|
||||
from apiserver.database.model.settings import Settings
|
||||
|
||||
|
||||
@ -77,3 +80,13 @@ def run_batch_operation(
|
||||
}
|
||||
)
|
||||
return results, failures
|
||||
|
||||
|
||||
def update_project_time(project_ids: Union[str, Sequence[str]]):
|
||||
if not project_ids:
|
||||
return
|
||||
|
||||
if isinstance(project_ids, str):
|
||||
project_ids = [project_ids]
|
||||
|
||||
return Project.objects(id__in=project_ids).update(last_update=datetime.utcnow())
|
||||
|
@ -1095,3 +1095,36 @@ delete_metadata {
|
||||
}
|
||||
}
|
||||
}
|
||||
update_tags {
|
||||
"999.0" {
|
||||
request {
|
||||
type: object
|
||||
properties {
|
||||
ids {
|
||||
type: array
|
||||
description: IDs of the models to update
|
||||
items {type: string}
|
||||
}
|
||||
add_tags {
|
||||
type: array
|
||||
description: User tags to add
|
||||
items {type: string}
|
||||
}
|
||||
remove_tags {
|
||||
type: array
|
||||
description: User tags to remove
|
||||
items {type: string}
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
updated {
|
||||
type: integer
|
||||
description: The number of updated models
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -2058,3 +2058,36 @@ move {
|
||||
}
|
||||
}
|
||||
}
|
||||
update_tags {
|
||||
"999.0" {
|
||||
request {
|
||||
type: object
|
||||
properties {
|
||||
ids {
|
||||
type: array
|
||||
description: IDs of the tasks to update
|
||||
items {type: string}
|
||||
}
|
||||
add_tags {
|
||||
type: array
|
||||
description: User tags to add
|
||||
items {type: string}
|
||||
}
|
||||
remove_tags {
|
||||
type: array
|
||||
description: User tags to remove
|
||||
items {type: string}
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
updated {
|
||||
type: integer
|
||||
description: The number of updated tasks
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -22,6 +22,7 @@ from apiserver.apimodels.models import (
|
||||
ModelsDeleteManyRequest,
|
||||
ModelsGetRequest,
|
||||
)
|
||||
from apiserver.apimodels.tasks import UpdateTagsRequest
|
||||
from apiserver.bll.model import ModelBLL, Metadata
|
||||
from apiserver.bll.organization import OrgBLL, Tags
|
||||
from apiserver.bll.project import ProjectBLL
|
||||
@ -219,7 +220,7 @@ def _update_cached_tags(company: str, project: str, fields: dict):
|
||||
org_bll.update_tags(
|
||||
company,
|
||||
Tags.Model,
|
||||
project=project,
|
||||
projects=[project],
|
||||
tags=fields.get("tags"),
|
||||
system_tags=fields.get("system_tags"),
|
||||
)
|
||||
@ -678,6 +679,19 @@ def move(call: APICall, company_id: str, request: MoveRequest):
|
||||
}
|
||||
|
||||
|
||||
@endpoint("models.update_tags")
|
||||
def update_tags(_, company_id: str, request: UpdateTagsRequest):
|
||||
return {
|
||||
"update": org_bll.edit_entity_tags(
|
||||
company_id=company_id,
|
||||
entity_cls=Model,
|
||||
entity_ids=request.ids,
|
||||
add_tags=request.add_tags,
|
||||
remove_tags=request.remove_tags,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@endpoint("models.add_or_update_metadata", min_version="2.13")
|
||||
def add_or_update_metadata(
|
||||
call: APICall, company_id: str, request: AddOrUpdateMetadataRequest
|
||||
|
@ -67,6 +67,7 @@ from apiserver.apimodels.tasks import (
|
||||
GetAllReq,
|
||||
DequeueRequest,
|
||||
DequeueManyRequest,
|
||||
UpdateTagsRequest,
|
||||
)
|
||||
from apiserver.bll.event import EventBLL
|
||||
from apiserver.bll.model import ModelBLL
|
||||
@ -76,7 +77,6 @@ from apiserver.bll.queue import QueueBLL
|
||||
from apiserver.bll.task import (
|
||||
TaskBLL,
|
||||
ChangeStatusRequest,
|
||||
update_project_time,
|
||||
)
|
||||
from apiserver.bll.task.artifacts import (
|
||||
artifacts_prepare_for_save,
|
||||
@ -101,7 +101,7 @@ from apiserver.bll.task.task_operations import (
|
||||
move_tasks_to_trash,
|
||||
)
|
||||
from apiserver.bll.task.utils import update_task, get_task_for_update, deleted_prefix
|
||||
from apiserver.bll.util import run_batch_operation
|
||||
from apiserver.bll.util import run_batch_operation, update_project_time
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model import EntityVisibility
|
||||
from apiserver.database.model.task.output import Output
|
||||
@ -112,7 +112,11 @@ from apiserver.database.model.task.task import (
|
||||
ModelItem,
|
||||
TaskModelTypes,
|
||||
)
|
||||
from apiserver.database.utils import get_fields_attr, parse_from_call, get_options
|
||||
from apiserver.database.utils import (
|
||||
get_fields_attr,
|
||||
parse_from_call,
|
||||
get_options,
|
||||
)
|
||||
from apiserver.service_repo import APICall, endpoint
|
||||
from apiserver.services.utils import (
|
||||
conform_tag_fields,
|
||||
@ -493,7 +497,7 @@ def _update_cached_tags(company: str, project: str, fields: dict):
|
||||
org_bll.update_tags(
|
||||
company,
|
||||
Tags.Task,
|
||||
project=project,
|
||||
projects=[project],
|
||||
tags=fields.get("tags"),
|
||||
system_tags=fields.get("system_tags"),
|
||||
)
|
||||
@ -1325,6 +1329,19 @@ def move(call: APICall, company_id: str, request: MoveRequest):
|
||||
return {"project_id": project_id}
|
||||
|
||||
|
||||
@endpoint("tasks.update_tags")
|
||||
def update_tags(_, company_id: str, request: UpdateTagsRequest):
|
||||
return {
|
||||
"update": org_bll.edit_entity_tags(
|
||||
company_id=company_id,
|
||||
entity_cls=Task,
|
||||
entity_ids=request.ids,
|
||||
add_tags=request.add_tags,
|
||||
remove_tags=request.remove_tags,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@endpoint("tasks.add_or_update_model", min_version="2.13")
|
||||
def add_or_update_model(call: APICall, company_id: str, request: AddUpdateModelRequest):
|
||||
get_task_for_update(company_id=company_id, task_id=request.task, force=True)
|
||||
|
@ -92,6 +92,29 @@ class TestProjectTags(TestService):
|
||||
self.assertFalse(tag1 in data.tags)
|
||||
self.assertTrue(tag2 in data.tags)
|
||||
|
||||
def test_tags_api(self):
|
||||
p = self.create_temp("projects", name="Test tags api", description="test")
|
||||
|
||||
# task
|
||||
initial_tags = ["Task tag"]
|
||||
task = self.new_task(project=p, tags=initial_tags)
|
||||
data = self.api.projects.get_task_tags(projects=[p])
|
||||
self.assertEqual(data.tags, initial_tags)
|
||||
new_tags = ["New task tag"]
|
||||
self.api.tasks.update_tags(ids=[task], add_tags=new_tags, remove_tags=initial_tags)
|
||||
data = self.api.projects.get_task_tags(projects=[p])
|
||||
self.assertEqual(data.tags, new_tags)
|
||||
|
||||
# model
|
||||
initial_tags = ["Model tag"]
|
||||
model = self.new_model(project=p, tags=initial_tags)
|
||||
data = self.api.projects.get_model_tags(projects=[p])
|
||||
self.assertEqual(data.tags, initial_tags)
|
||||
new_tags = ["New model tag"]
|
||||
self.api.models.update_tags(ids=[model], add_tags=new_tags)
|
||||
data = self.api.projects.get_model_tags(projects=[p])
|
||||
self.assertEqual(set(data.tags), set([*new_tags, *initial_tags]))
|
||||
|
||||
def new_task(self, **kwargs):
|
||||
self.update_missing(
|
||||
kwargs, type="testing", name="test project tags"
|
||||
|
Loading…
Reference in New Issue
Block a user