Mirror of https://github.com/clearml/clearml-server (synced 2025-06-26 23:15:47 +00:00)

Commit 1ea6408d41: Support tags-per-project in tags related services
Parent commit: 5e095af3aa
@@ -2,6 +2,7 @@ from jsonmodels import fields, models


 class Filter(models.Base):
+    tags = fields.ListField([str])
     system_tags = fields.ListField([str])
@@ -1,5 +1,8 @@
 from jsonmodels import models, fields

+from apimodels import ListField
+from apimodels.organization import TagsRequest
+

 class ProjectReq(models.Base):
     project = fields.StringField()
@@ -14,3 +17,7 @@ class GetHyperParamResp(models.Base):
     parameters = fields.ListField(str)
     remaining = fields.IntField()
     total = fields.IntField()
+
+
+class ProjectTagsRequest(TagsRequest):
+    projects = ListField(str)
@@ -230,9 +230,31 @@ class EventBLL(object):
         metric_hash = dbutils.hash_field_name(metric)
         variant_hash = dbutils.hash_field_name(variant)

-        timestamp = last_events[metric_hash][variant_hash].get("timestamp", None)
-        if timestamp is None or timestamp < event["timestamp"]:
-            last_events[metric_hash][variant_hash] = event
+        last_event = last_events[metric_hash][variant_hash]
+        event_iter = event.get("iter", 0)
+        event_timestamp = event["timestamp"]
+        if (event_iter, event_timestamp) >= (
+            last_event.get("iter", event_iter),
+            last_event.get("timestamp", event_timestamp),
+        ):
+            event_data = {
+                k: event[k]
+                for k in ("value", "metric", "variant", "iter", "timestamp")
+                if k in event
+            }
+            value = event_data.get("value")
+            if value is not None:
+                event_data["min_value"] = min(value, last_event.get("min_value", value))
+                event_data["max_value"] = max(value, last_event.get("max_value", value))
+            else:
+                event_data.update(
+                    **{
+                        k: last_event[k]
+                        for k in ("value", "min_value", "max_value")
+                        if k in last_event
+                    }
+                )
+            last_events[metric_hash][variant_hash] = event_data

     def _update_last_metric_events_for_task(self, last_events, event):
         """
@@ -275,7 +297,13 @@ class EventBLL(object):
             flatten_nested_items(
                 last_scalar_events,
                 nesting=2,
-                include_leaves=["value", "metric", "variant"],
+                include_leaves=[
+                    "value",
+                    "min_value",
+                    "max_value",
+                    "metric",
+                    "variant",
+                ],
             )
         )
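
Note on the hunk above: the last-event bookkeeping now keeps the event with the highest (iteration, timestamp) pair and folds a running min_value/max_value into the stored entry. A minimal standalone sketch of that merge rule follows; the function name and the sample events are illustrative only, not part of the server code.

def merge_last_event(last_event: dict, event: dict) -> dict:
    # Accept the new event only if it is not older than the stored one,
    # ordered first by iteration and then by timestamp (same rule as the diff).
    if (event.get("iter", 0), event["timestamp"]) < (
        last_event.get("iter", event.get("iter", 0)),
        last_event.get("timestamp", event["timestamp"]),
    ):
        return last_event
    merged = {k: event[k] for k in ("value", "iter", "timestamp") if k in event}
    value = merged.get("value")
    if value is not None:
        # Carry the running extremes forward from the previously stored entry.
        merged["min_value"] = min(value, last_event.get("min_value", value))
        merged["max_value"] = max(value, last_event.get("max_value", value))
    return merged


state = {}
for ev in ({"iter": 1, "timestamp": 10, "value": 5}, {"iter": 2, "timestamp": 20, "value": 3}):
    state = merge_last_event(state, ev)
print(state)  # {'value': 3, 'iter': 2, 'timestamp': 20, 'min_value': 3, 'max_value': 5}
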
@@ -1,6 +1,10 @@
-from typing import Sequence
+from collections import defaultdict
+from enum import Enum
+from itertools import chain
+from typing import Sequence, Union, Type, Dict

 from mongoengine import Q
+from redis import Redis

 from config import config
 from database.model.base import GetMixin
@@ -10,40 +14,65 @@ from redis_manager import redman
 from utilities import json

 log = config.logger(__file__)
+_settings_prefix = "services.organization"


-class OrgBLL:
+class _TagsCache:
     _tags_field = "tags"
     _system_tags_field = "system_tags"
-    _settings_prefix = "services.organization"

-    def __init__(self, redis=None):
-        self.redis = redis or redman.connection("apiserver")
+    def __init__(self, db_cls: Union[Type[Model], Type[Task]], redis: Redis):
+        self.db_cls = db_cls
+        self.redis = redis

     @property
     def _tags_cache_expiration_seconds(self):
-        return config.get(
-            f"{self._settings_prefix}.tags_cache.expiration_seconds", 3600
-        )
+        return config.get(f"{_settings_prefix}.tags_cache.expiration_seconds", 3600)

-    @staticmethod
-    def _get_tags_cache_key(company, field: str, filter_: Sequence[str] = None):
-        filter_str = "_".join(filter_) if filter_ else ""
-        return f"{field}_{company}_{filter_str}"
-
-    @staticmethod
-    def _get_tags_from_db(company, field, filter_: Sequence[str] = None) -> set:
+    def _get_tags_from_db(
+        self,
+        company: str,
+        field: str,
+        project: str = None,
+        filter_: Dict[str, Sequence[str]] = None,
+    ) -> set:
         query = Q(company=company)
         if filter_:
-            query &= GetMixin.get_list_field_query("system_tags", filter_)
+            for name, vals in filter_.items():
+                if vals:
+                    query &= GetMixin.get_list_field_query(name, vals)
+        if project:
+            query &= Q(project=project)

-        tags = set()
-        for cls_ in (Task, Model):
-            tags |= set(cls_.objects(query).distinct(field))
-        return tags
+        return self.db_cls.objects(query).distinct(field)
+
+    def _get_tags_cache_key(
+        self,
+        company: str,
+        field: str,
+        project: str = None,
+        filter_: Dict[str, Sequence[str]] = None,
+    ):
+        """
+        Project None means 'from all company projects'
+        The key is built in the way that scanning company keys for 'all company projects'
+        will not return the keys related to the particular company projects and vice versa.
+        So that we can have a fine grain control on what redis keys to invalidate
+        """
+        filter_str = None
+        if filter_:
+            filter_str = "_".join(
+                ["filter", *chain.from_iterable([f, *v] for f, v in filter_.items())]
+            )
+        key_parts = [company, project, self.db_cls.__name__, field, filter_str]
+        return "_".join(filter(None, key_parts))

     def get_tags(
-        self, company, include_system: bool = False, filter_: Sequence[str] = None
+        self,
+        company: str,
+        include_system: bool = False,
+        filter_: Dict[str, Sequence[str]] = None,
+        project: str = None,
     ) -> dict:
         """
         Get tags and optionally system tags for the company
@@ -51,35 +80,114 @@ class OrgBLL:
         The function retrieves both cached values from Redis in one call
         and re calculates any of them if missing in Redis
         """
-        fields = [
-            self._tags_field,
-            *([self._system_tags_field] if include_system else []),
+        fields = [self._tags_field]
+        if include_system:
+            fields.append(self._system_tags_field)
+        redis_keys = [
+            self._get_tags_cache_key(company, field=f, project=project, filter_=filter_)
+            for f in fields
         ]
-        redis_keys = [self._get_tags_cache_key(company, f, filter_) for f in fields]
         cached = self.redis.mget(redis_keys)
         ret = {}
         for field, tag_data, key in zip(fields, cached, redis_keys):
            if tag_data is not None:
                tags = json.loads(tag_data)
            else:
-                tags = list(self._get_tags_from_db(company, field, filter_))
+                tags = list(self._get_tags_from_db(company, field, project, filter_))
                self.redis.setex(
                    key,
                    time=self._tags_cache_expiration_seconds,
                    value=json.dumps(tags),
                )
-            ret[field] = tags
+            ret[field] = set(tags)

         return ret

-    def update_org_tags(self, company, tags=None, system_tags=None, reset=False):
+    def update_tags(self, company: str, project: str, tags=None, system_tags=None):
         """
-        Updates system tags. If reset is set then both tags and system_tags
+        Updates tags. If reset is set then both tags and system_tags
         are recalculated. Otherwise only those that are not 'None'
         """
-        if reset or tags is not None:
-            self.redis.delete(self._get_tags_cache_key(company, self._tags_field))
-        if reset or system_tags is not None:
-            self.redis.delete(
-                self._get_tags_cache_key(company, self._system_tags_field)
+        fields = [
+            field
+            for field, update in (
+                (self._tags_field, tags),
+                (self._system_tags_field, system_tags),
             )
+            if update is not None
+        ]
+        if not fields:
+            return
+
+        self._delete_redis_keys(company, projects=[project], fields=fields)
+
+    def reset_tags(self, company: str, projects: Sequence[str]):
+        self._delete_redis_keys(
+            company,
+            projects=projects,
+            fields=(self._tags_field, self._system_tags_field),
+        )
+
+    def _delete_redis_keys(
+        self, company: str, projects: [Sequence[str]], fields: Sequence[str]
+    ):
+        redis_keys = list(
+            chain.from_iterable(
+                self.redis.keys(
+                    self._get_tags_cache_key(company, field=f, project=p) + "*"
+                )
+                for f in fields
+                for p in set(projects) | {None}
+            )
+        )
+        if redis_keys:
+            self.redis.delete(*redis_keys)
+
+
+class Tags(Enum):
+    Task = "task"
+    Model = "model"
+
+
+class OrgBLL:
+    def __init__(self, redis=None):
+        self.redis = redis or redman.connection("apiserver")
+        self._task_tags = _TagsCache(Task, self.redis)
+        self._model_tags = _TagsCache(Model, self.redis)
+
+    def get_tags(
+        self,
+        company: str,
+        entity: Tags,
+        include_system: bool = False,
+        filter_: Dict[str, Sequence[str]] = None,
+        projects: Sequence[str] = None,
+    ) -> dict:
+        tags_cache = self._get_tags_cache_for_entity(entity)
+        if not projects:
+            return tags_cache.get_tags(
+                company, include_system=include_system, filter_=filter_
+            )
+
+        ret = defaultdict(set)
+        for project in projects:
+            project_tags = tags_cache.get_tags(
+                company, include_system=include_system, filter_=filter_, project=project
+            )
+            for field, tags in project_tags.items():
+                ret[field] |= tags
+
+        return ret
+
+    def update_tags(
+        self, company: str, entity: Tags, project: str, tags=None, system_tags=None,
+    ):
+        tags_cache = self._get_tags_cache_for_entity(entity)
+        tags_cache.update_tags(company, project, tags, system_tags)
+
+    def reset_tags(self, company: str, entity: Tags, projects: Sequence[str]):
+        tags_cache = self._get_tags_cache_for_entity(entity)
+        tags_cache.reset_tags(company, projects=projects)
+
+    def _get_tags_cache_for_entity(self, entity: Tags) -> _TagsCache:
+        return self._task_tags if entity == Tags.Task else self._model_tags
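
The docstring of _get_tags_cache_key explains that the project segment is placed before the entity name so that company-wide keys and per-project keys never match each other's wildcard scans. A small sketch reproducing the key layout outside the class; the company, project and filter values below are made up:

from itertools import chain

def tags_cache_key(company, db_cls_name, field, project=None, filter_=None):
    # Mirrors _TagsCache._get_tags_cache_key from the hunk above.
    filter_str = None
    if filter_:
        filter_str = "_".join(
            ["filter", *chain.from_iterable([f, *v] for f, v in filter_.items())]
        )
    return "_".join(filter(None, [company, project, db_cls_name, field, filter_str]))

print(tags_cache_key("company1", "Task", "tags"))                   # company1_Task_tags
print(tags_cache_key("company1", "Task", "tags", project="proj1"))  # company1_proj1_Task_tags
print(tags_cache_key("company1", "Model", "system_tags", filter_={"system_tags": ["archived"]}))
# company1_Model_system_tags_filter_system_tags_archived

Because _delete_redis_keys scans for '<key prefix>*', invalidating project 'proj1' (pattern 'company1_proj1_Task_tags*') cannot touch the company-wide key 'company1_Task_tags', and the company-wide pattern cannot touch the per-project keys.
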
@@ -14,7 +14,7 @@ import database.utils as dbutils
 import es_factory
 from apierrors import errors
 from apimodels.tasks import Artifact as ApiArtifact
-from bll.organization import OrgBLL
+from bll.organization import OrgBLL, Tags
 from config import config
 from database.errors import translate_errors_context
 from database.model.model import Model
@@ -229,7 +229,21 @@ class TaskBLL(object):
             validate_project=validate_references or project,
         )
         new_task.save()
-        org_bll.update_org_tags(company_id, tags=tags, system_tags=system_tags)
+
+        if task.project == new_task.project:
+            updated_tags = tags
+            updated_system_tags = system_tags
+        else:
+            updated_tags = new_task.tags
+            updated_system_tags = new_task.system_tags
+        org_bll.update_tags(
+            company_id,
+            Tags.Task,
+            project=new_task.project,
+            tags=updated_tags,
+            system_tags=updated_system_tags,
+        )

         return new_task

     @classmethod
@@ -346,10 +360,12 @@ class TaskBLL(object):
             return "__".join((op, "last_metrics") + path)

         for path, value in last_scalar_values:
-            extra_updates[op_path("set", *path)] = value
-            if path[-1] == "value":
+            if path[-1] == "min_value":
                 extra_updates[op_path("min", *path[:-1], "min_value")] = value
+            elif path[-1] == "max_value":
                 extra_updates[op_path("max", *path[:-1], "max_value")] = value
+            else:
+                extra_updates[op_path("set", *path)] = value

         if last_events is not None:
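
For clarity, the update keys produced by the branch above look like the following. The metric/variant hashes are placeholders, and the snippet only reproduces the key construction shown in the diff; the resulting kwargs are presumably applied later as mongoengine-style set/min/max updates.

def op_path(op, *path):
    # Same joining rule as the nested op_path helper in the diff.
    return "__".join((op, "last_metrics") + path)

extra_updates = {}
for path, value in [
    (("mhash", "vhash", "value"), 99),
    (("mhash", "vhash", "min_value"), 0),
    (("mhash", "vhash", "max_value"), 99),
]:
    if path[-1] == "min_value":
        extra_updates[op_path("min", *path)] = value
    elif path[-1] == "max_value":
        extra_updates[op_path("max", *path)] = value
    else:
        extra_updates[op_path("set", *path)] = value

print(extra_updates)
# {'set__last_metrics__mhash__vhash__value': 99,
#  'min__last_metrics__mhash__vhash__min_value': 0,
#  'max__last_metrics__mhash__vhash__max_value': 99}
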
@@ -1,43 +1,48 @@
 _description: "This service provides organization level operations"

 get_tags {
     "2.8" {
         description: "Get all the user and system tags used for the company tasks and models"
         request {
             type: object
             properties {
                 include_system {
                     description: "If set to 'true' then the list of the system tags is also returned. The default value is 'false'"
                     type: boolean
                     default: false
                 }
                 filter {
                     description: "Filter on entities to collect tags from"
                     type: object
                     properties {
+                        tags {
+                            description: "The list of tag values to filter by. Use 'null' value to specify empty tags. Use '__$not' value to specify that the following value should be excluded"
+                            type: array
+                            items {type: string}
+                        }
                         system_tags {
-                            description: "The list of system tag values to filter by. Use 'null' value to specify empty tags. Use '__$not' value to specify that the following value should be excluded"
+                            description: "The list of system tag values to filter by. Use 'null' value to specify empty system tags. Use '__$not' value to specify that the following value should be excluded"
                             type: array
                             items {type: string}
                         }
                     }
                 }
             }
         }
         response {
             type: object
             properties {
                 tags {
                     description: "The list of unique tag values"
                     type: array
                     items {type: string}
                 }
                 system_tags {
                     description: "The list of unique system tag values. Returned only if 'include_system' is set to 'true' in the request"
                     type: array
                     items {type: string}
                 }
             }
         }
     }
 }
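
Going by the updated schema, an organization.get_tags request that uses the new per-field filter could be shaped roughly like the payload below; the tag values are invented for illustration only.

request = {
    "include_system": True,
    "filter": {
        "tags": ["best"],
        "system_tags": ["__$not", "archived"],
    },
}
# A matching response carries de-duplicated tag lists:
response = {"tags": ["best", "top"], "system_tags": ["development"]}
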
@@ -196,6 +196,52 @@ _definitions {
             }
         }
     }
+    tags_request {
+        type: object
+        properties {
+            include_system {
+                description: "If set to 'true' then the list of the system tags is also returned. The default value is 'false'"
+                type: boolean
+                default: false
+            }
+            projects {
+                description: "The list of projects under which the tags are searched. If not passed or empty then all the projects are searched"
+                type: array
+                items { type: string }
+            }
+            filter {
+                description: "Filter on entities to collect tags from"
+                type: object
+                properties {
+                    tags {
+                        description: "The list of tag values to filter by. Use 'null' value to specify empty tags. Use '__$not' value to specify that the following value should be excluded"
+                        type: array
+                        items {type: string}
+                    }
+                    system_tags {
+                        description: "The list of system tag values to filter by. Use 'null' value to specify empty system tags. Use '__$not' value to specify that the following value should be excluded"
+                        type: array
+                        items {type: string}
+                    }
+                }
+            }
+        }
+    }
+    tags_response {
+        type: object
+        properties {
+            tags {
+                description: "The list of unique tag values"
+                type: array
+                items {type: string}
+            }
+            system_tags {
+                description: "The list of unique system tag values. Returned only if 'include_system' is set to 'true' in the request"
+                type: array
+                items {type: string}
+            }
+        }
+    }
 }

 create {
@@ -508,7 +554,7 @@ get_hyper_parameters {
                 parameters {
                     description: "A list of hyper parameter names"
                     type: array
-                    items { type: string }
+                    items {type: string}
                 }
                 remaining {
                     description: "Remaining results"
@@ -522,3 +568,17 @@ get_hyper_parameters {
         }
     }
 }
+get_task_tags {
+    "2.8" {
+        description: "Get user and system tags used for the tasks under the specified projects"
+        request = ${_definitions.tags_request}
+        response = ${_definitions.tags_response}
+    }
+}
+get_model_tags {
+    "2.8" {
+        description: "Get user and system tags used for the models under the specified projects"
+        request = ${_definitions.tags_request}
+        response = ${_definitions.tags_response}
+    }
+}
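
Based on the tags_request/tags_response definitions above, a projects.get_task_tags call scoped to two projects could look like this; the project IDs and values are placeholders, not real data.

request = {
    "include_system": False,
    "projects": ["<project-id-1>", "<project-id-2>"],
    "filter": {"system_tags": ["__$not", "archived"]},
}
# The response aggregates tags from both projects:
response = {"tags": ["Test tag 1", "Test tag 2"]}
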
@@ -1,4 +1,5 @@
 from datetime import datetime
+from typing import Sequence

 from mongoengine import Q, EmbeddedDocument

@@ -12,7 +13,7 @@ from apimodels.models import (
     PublishModelResponse,
     ModelTaskPublishResponse,
 )
-from bll.organization import OrgBLL
+from bll.organization import OrgBLL, Tags
 from bll.task import TaskBLL
 from config import config
 from database.errors import translate_errors_context
@@ -128,9 +129,19 @@ def parse_model_fields(call, valid_fields):
     return fields


-def _update_org_tags(company, fields: dict):
-    org_bll.update_org_tags(
-        company, tags=fields.get("tags"), system_tags=fields.get("system_tags")
+def _update_cached_tags(company: str, project: str, fields: dict):
+    org_bll.update_tags(
+        company,
+        Tags.Model,
+        project=project,
+        tags=fields.get("tags"),
+        system_tags=fields.get("system_tags"),
     )


+def _reset_cached_tags(company: str, projects: Sequence[str]):
+    org_bll.reset_tags(
+        company, Tags.Model, projects=projects,
+    )
+
+
@@ -203,7 +214,7 @@ def update_for_task(call: APICall, company_id, _):
         **fields,
     )
     model.save()
-    _update_org_tags(company_id, fields)
+    _update_cached_tags(company_id, project=model.project, fields=fields)

     TaskBLL.update_statistics(
         task_id=task_id,
@@ -248,7 +259,7 @@ def create(call: APICall, company_id, req_model: CreateModelRequest):
         **fields,
     )
     model.save()
-    _update_org_tags(company_id, fields)
+    _update_cached_tags(company_id, project=model.project, fields=fields)

     call.result.data_model = CreateModelResponse(id=model.id, created=True)
@@ -327,7 +338,15 @@ def edit(call: APICall, company_id, _):
         if fields:
             updated = model.update(upsert=False, **fields)
             if updated:
-                _update_org_tags(company_id, fields)
+                new_project = fields.get("project", model.project)
+                if new_project != model.project:
+                    _reset_cached_tags(
+                        company_id, projects=[new_project, model.project]
+                    )
+                else:
+                    _update_cached_tags(
+                        company_id, project=model.project, fields=fields
+                    )
             conform_output_tags(call, fields)
             call.result.data_model = UpdateResponse(updated=updated, fields=fields)
         else:
@@ -355,7 +374,13 @@ def _update_model(call: APICall, company_id, model_id=None):

     updated_count, updated_fields = Model.safe_update(company_id, model.id, data)
     if updated_count:
-        _update_org_tags(company_id, updated_fields)
+        new_project = updated_fields.get("project", model.project)
+        if new_project != model.project:
+            _reset_cached_tags(company_id, projects=[new_project, model.project])
+        else:
+            _update_cached_tags(
+                company_id, project=model.project, fields=updated_fields
+            )
     conform_output_tags(call, updated_fields)
     return UpdateResponse(updated=updated_count, fields=updated_fields)
@@ -395,7 +420,7 @@ def update(call: APICall, company_id, _):

     with translate_errors_context():
         query = dict(id=model_id, company=company_id)
-        model = Model.objects(**query).only("id", "task").first()
+        model = Model.objects(**query).only("id", "task", "project").first()
         if not model:
             raise errors.bad_request.InvalidModelId(**query)

@@ -428,5 +453,5 @@ def update(call: APICall, company_id, _):

         del_count = Model.objects(**query).delete()
         if del_count:
-            org_bll.update_org_tags(company_id, reset=True)
+            _reset_cached_tags(company_id, projects=[model.project])
         call.result.data = dict(deleted=del_count > 0)
@@ -1,13 +1,22 @@
+from collections import defaultdict
+
 from apimodels.organization import TagsRequest
-from bll.organization import OrgBLL
+from bll.organization import OrgBLL, Tags
 from service_repo import endpoint, APICall
+from services.utils import get_tags_filter_dictionary, get_tags_response

 org_bll = OrgBLL()


 @endpoint("organization.get_tags", request_data_model=TagsRequest)
 def get_tags(call: APICall, company, request: TagsRequest):
-    filter_ = request.filter.system_tags if request.filter else None
-    call.result.data = org_bll.get_tags(
-        company, include_system=request.include_system, filter_=filter_
-    )
+    filter_dict = get_tags_filter_dictionary(request.filter)
+    ret = defaultdict(set)
+    for entity in Tags.Model, Tags.Task:
+        tags = org_bll.get_tags(
+            company, entity, include_system=request.include_system, filter_=filter_dict,
+        )
+        for field, vals in tags.items():
+            ret[field] |= vals
+
+    call.result.data = get_tags_response(ret)
@@ -9,7 +9,13 @@ from mongoengine import Q
 import database
 from apierrors import errors
 from apimodels.base import UpdateResponse
-from apimodels.projects import GetHyperParamReq, GetHyperParamResp, ProjectReq
+from apimodels.projects import (
+    GetHyperParamReq,
+    GetHyperParamResp,
+    ProjectReq,
+    ProjectTagsRequest,
+)
+from bll.organization import OrgBLL, Tags
 from bll.task import TaskBLL
 from database.errors import translate_errors_context
 from database.model import EntityVisibility
@@ -18,9 +24,15 @@ from database.model.project import Project
 from database.model.task.task import Task, TaskStatus
 from database.utils import parse_from_call, get_options, get_company_or_none_constraint
 from service_repo import APICall, endpoint
-from services.utils import conform_tag_fields, conform_output_tags
+from services.utils import (
+    conform_tag_fields,
+    conform_output_tags,
+    get_tags_filter_dictionary,
+    get_tags_response,
+)
 from timing_context import TimingContext

+org_bll = OrgBLL()
 task_bll = TaskBLL()
 archived_tasks_cond = {"$in": [EntityVisibility.archived.value, "$system_tags"]}

@@ -381,3 +393,31 @@ def get_hyper_parameters(call: APICall, company_id: str, request: GetHyperParamReq
         "remaining": remaining,
         "parameters": parameters,
     }
+
+
+@endpoint(
+    "projects.get_task_tags", min_version="2.8", request_data_model=ProjectTagsRequest
+)
+def get_tags(call: APICall, company, request: ProjectTagsRequest):
+    ret = org_bll.get_tags(
+        company,
+        Tags.Task,
+        include_system=request.include_system,
+        filter_=get_tags_filter_dictionary(request.filter),
+        projects=request.projects,
+    )
+    call.result.data = get_tags_response(ret)
+
+
+@endpoint(
+    "projects.get_model_tags", min_version="2.8", request_data_model=ProjectTagsRequest
+)
+def get_tags(call: APICall, company, request: ProjectTagsRequest):
+    ret = org_bll.get_tags(
+        company,
+        Tags.Model,
+        include_system=request.include_system,
+        filter_=get_tags_filter_dictionary(request.filter),
+        projects=request.projects,
+    )
+    call.result.data = get_tags_response(ret)
@@ -33,7 +33,7 @@ from apimodels.tasks import (
     ResetRequest,
 )
 from bll.event import EventBLL
-from bll.organization import OrgBLL
+from bll.organization import OrgBLL, Tags
 from bll.queue import QueueBLL
 from bll.task import (
     TaskBLL,
@@ -343,9 +343,19 @@ def validate(call: APICall, company_id, req_model: CreateRequest):
     _validate_and_get_task_from_call(call)


-def _update_org_tags(company, fields: dict):
-    org_bll.update_org_tags(
-        company, tags=fields.get("tags"), system_tags=fields.get("system_tags")
+def _update_cached_tags(company: str, project: str, fields: dict):
+    org_bll.update_tags(
+        company,
+        Tags.Task,
+        project=project,
+        tags=fields.get("tags"),
+        system_tags=fields.get("system_tags"),
     )


+def _reset_cached_tags(company: str, projects: Sequence[str]):
+    org_bll.reset_tags(
+        company, Tags.Task, projects=projects
+    )
+
+
@@ -357,7 +367,7 @@ def create(call: APICall, company_id, req_model: CreateRequest):

     with translate_errors_context(), TimingContext("mongo", "save_task"):
         task.save()
-        _update_org_tags(company_id, fields)
+        _update_cached_tags(company_id, project=task.project, fields=fields)
         update_project_time(task.project)

     call.result.data_model = IdResponse(id=task.id)
@@ -400,7 +410,9 @@ def update(call: APICall, company_id, req_model: UpdateRequest):
     task_id = req_model.task

     with translate_errors_context():
-        task = Task.get_for_writing(id=task_id, company=company_id, _only=["id"])
+        task = Task.get_for_writing(
+            id=task_id, company=company_id, _only=["id", "project"]
+        )
         if not task:
             raise errors.bad_request.InvalidTaskId(id=task_id)

@@ -416,7 +428,13 @@ def update(call: APICall, company_id, req_model: UpdateRequest):
             injected_update=dict(last_update=datetime.utcnow()),
         )
         if updated_count:
-            _update_org_tags(company_id, updated_fields)
+            new_project = updated_fields.get("project", task.project)
+            if new_project != task.project:
+                _reset_cached_tags(company_id, projects=[new_project, task.project])
+            else:
+                _update_cached_tags(
+                    company_id, project=task.project, fields=updated_fields
+                )
         update_project_time(updated_fields.get("project"))
         unprepare_from_saved(call, updated_fields)
         return UpdateResponse(updated=updated_count, fields=updated_fields)
@@ -470,8 +488,10 @@ def update_batch(call: APICall, company_id, _):
         now = datetime.utcnow()

         bulk_ops = []
+        updated_projects = set()
         for id, data in items.items():
-            fields, valid_fields = prepare_update_fields(call, tasks[id], data)
+            task = tasks[id]
+            fields, valid_fields = prepare_update_fields(call, task, data)
             partial_update_dict = Task.get_safe_update_dict(fields)
             if not partial_update_dict:
                 continue
@@ -481,12 +501,20 @@ def update_batch(call: APICall, company_id, _):
             )
             bulk_ops.append(update_op)

+            new_project = partial_update_dict.get("project", task.project)
+            if new_project != task.project:
+                updated_projects.update({new_project, task.project})
+            elif any(f in partial_update_dict for f in ("tags", "system_tags")):
+                updated_projects.add(task.project)
+
         updated = 0
         if bulk_ops:
             res = Task._get_collection().bulk_write(bulk_ops)
             updated = res.modified_count
-            if updated:
-                org_bll.update_org_tags(company_id, reset=True)
+
+        if updated and updated_projects:
+            _reset_cached_tags(company_id, projects=list(updated_projects))

         call.result.data = {"updated": updated}
@@ -542,7 +570,15 @@ def edit(call: APICall, company_id, req_model: UpdateRequest):
             fixed_fields.update(last_update=now)
             updated = task.update(upsert=False, **fixed_fields)
             if updated:
-                _update_org_tags(company_id, fixed_fields)
+                new_project = fixed_fields.get("project", task.project)
+                if new_project != task.project:
+                    _reset_cached_tags(
+                        company_id, projects=[new_project, task.project]
+                    )
+                else:
+                    _update_cached_tags(
+                        company_id, project=task.project, fields=fixed_fields
+                    )
             update_project_time(fields.get("project"))
             unprepare_from_saved(call, fields)
             call.result.data_model = UpdateResponse(updated=updated, fields=fields)
@@ -710,12 +746,11 @@ def reset(call: APICall, company_id, request: ResetRequest):

     if request.clear_all:
         updates.update(
-            set__execution=Execution(),
-            unset__script=1,
+            set__execution=Execution(), unset__script=1,
         )
     else:
-        updates.update(unset__execution__queue=1)
+        updates.update(
+            unset__execution__queue=1,
+            __raw__={"$pull": {"execution.artifacts": {"mode": {"$ne": "input"}}}},
+        )

@@ -909,7 +944,8 @@ def delete(call: APICall, company_id, req_model: DeleteRequest):
             task.switch_collection(collection_name)

         task.delete()
-        org_bll.update_org_tags(company_id, reset=True)
+        _reset_cached_tags(company_id, projects=[task.project])
+
         call.result.data = dict(deleted=True, **attr.asdict(result))
|
@ -1,12 +1,28 @@
|
||||
from typing import Union, Sequence, Tuple
|
||||
|
||||
from apierrors import errors
|
||||
from apimodels.organization import Filter
|
||||
from database.model.base import GetMixin
|
||||
from database.utils import partition_tags
|
||||
from service_repo import APICall
|
||||
from service_repo.base import PartialVersion
|
||||
|
||||
|
||||
def get_tags_filter_dictionary(input_: Filter) -> dict:
|
||||
if not input_:
|
||||
return {}
|
||||
|
||||
return {
|
||||
field: vals
|
||||
for field, vals in (("tags", input_.tags), ("system_tags", input_.system_tags))
|
||||
if vals
|
||||
}
|
||||
|
||||
|
||||
def get_tags_response(ret: dict) -> dict:
|
||||
return {field: sorted(vals) for field, vals in ret.items()}
|
||||
|
||||
|
||||
def conform_output_tags(call: APICall, documents: Union[dict, Sequence[dict]]):
|
||||
"""
|
||||
For old clients both tags and system tags are returned in 'tags' field
|
||||
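
A quick sketch of what the two new helpers return; the Filter values are examples only, and the snippet assumes it runs inside the server package so the imports resolve:

from apimodels.organization import Filter
from services.utils import get_tags_filter_dictionary, get_tags_response

# Empty fields are dropped, so only non-empty filters reach the tag cache query:
print(get_tags_filter_dictionary(Filter(tags=["best"], system_tags=[])))  # {'tags': ['best']}
print(get_tags_filter_dictionary(None))                                   # {}

# Responses are sorted per field so the API returns stable lists:
print(get_tags_response({"tags": {"b", "a"}, "system_tags": set()}))
# {'tags': ['a', 'b'], 'system_tags': []}
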
|
@ -1,36 +0,0 @@
|
||||
from tests.automated import TestService
|
||||
|
||||
|
||||
class TestOrganization(TestService):
|
||||
def setUp(self, version="2.8"):
|
||||
super().setUp(version=version)
|
||||
|
||||
def test_tags(self):
|
||||
tag1 = "Orgtest tag1"
|
||||
tag2 = "Orgtest tag2"
|
||||
system_tag = "Orgtest system tag"
|
||||
|
||||
model = self.create_temp(
|
||||
"models", name="test_org", uri="file:///a", tags=[tag1]
|
||||
)
|
||||
task = self.create_temp(
|
||||
"tasks", name="test org", type="training", input=dict(view={}), tags=[tag1]
|
||||
)
|
||||
data = self.api.organization.get_tags()
|
||||
self.assertTrue(tag1 in data.tags)
|
||||
|
||||
self.api.tasks.edit(task=task, tags=[tag2], system_tags=[system_tag])
|
||||
data = self.api.organization.get_tags(include_system=True)
|
||||
self.assertTrue({tag1, tag2}.issubset(set(data.tags)))
|
||||
self.assertTrue(system_tag in data.system_tags)
|
||||
|
||||
data = self.api.organization.get_tags(
|
||||
filter={"system_tags": ["__$not", system_tag]}
|
||||
)
|
||||
self.assertTrue(tag1 in data.tags)
|
||||
self.assertFalse(tag2 in data.tags)
|
||||
|
||||
self.api.models.delete(model=model)
|
||||
data = self.api.organization.get_tags()
|
||||
self.assertFalse(tag1 in data.tags)
|
||||
self.assertTrue(tag2 in data.tags)
|
server/tests/automated/test_project_tags.py (new file, 82 lines)
@@ -0,0 +1,82 @@
+from tests.automated import TestService
+
+
+class TestProjectTags(TestService):
+    def setUp(self, version="2.8"):
+        super().setUp(version=version)
+
+    def test_project_tags(self):
+        tags_1 = ["Test tag 1", "Test tag 2"]
+        tags_2 = ["Test tag 3", "Test tag 4"]
+
+        p1 = self.create_temp("projects", name="Test tags1", description="test")
+        task1_1 = self.new_task(project=p1, tags=tags_1[:1])
+        task1_2 = self.new_task(project=p1, tags=tags_1[1:])
+
+        p2 = self.create_temp("projects", name="Test tasks2", description="test")
+        task2 = self.new_task(project=p2, tags=tags_2)
+
+        # test tags per project
+        data = self.api.projects.get_task_tags(projects=[p1])
+        self.assertEqual(set(tags_1), set(data.tags))
+        data = self.api.projects.get_model_tags(projects=[p1])
+        self.assertEqual(set(), set(data.tags))
+        data = self.api.projects.get_task_tags(projects=[p2])
+        self.assertEqual(set(tags_2), set(data.tags))
+
+        # test tags for projects list
+        data = self.api.projects.get_task_tags(projects=[p1, p2])
+        self.assertEqual(set(tags_1) | set(tags_2), set(data.tags))
+
+        # test tags for all projects
+        data = self.api.projects.get_task_tags(projects=[p1, p2])
+        self.assertTrue((set(tags_1) | set(tags_2)).issubset(data.tags))
+
+        # test move to another project
+        self.api.tasks.edit(task=task1_2, project=p2)
+        data = self.api.projects.get_task_tags(projects=[p1])
+        self.assertEqual(set(tags_1[:1]), set(data.tags))
+        data = self.api.projects.get_task_tags(projects=[p2])
+        self.assertEqual(set(tags_1[1:]) | set(tags_2), set(data.tags))
+
+        # test tags update
+        self.api.tasks.delete(task=task1_1, force=True)
+        self.api.tasks.delete(task=task2, force=True)
+        data = self.api.projects.get_task_tags(projects=[p1, p2])
+        self.assertEqual(set(tags_1[1:]), set(data.tags))
+
+    def test_organization_tags(self):
+        tag1 = "Orgtest tag1"
+        tag2 = "Orgtest tag2"
+        system_tag = "Orgtest system tag"
+
+        model = self.new_model(tags=[tag1])
+        task = self.new_task(tags=[tag1])
+        data = self.api.organization.get_tags()
+        self.assertTrue(tag1 in data.tags)
+
+        self.api.tasks.edit(task=task, tags=[tag2], system_tags=[system_tag])
+        data = self.api.organization.get_tags(include_system=True)
+        self.assertTrue({tag1, tag2}.issubset(set(data.tags)))
+        self.assertTrue(system_tag in data.system_tags)
+
+        data = self.api.organization.get_tags(
+            filter={"system_tags": ["__$not", system_tag]}
+        )
+        self.assertTrue(tag1 in data.tags)
+        self.assertFalse(tag2 in data.tags)
+
+        self.api.models.delete(model=model)
+        data = self.api.organization.get_tags()
+        self.assertFalse(tag1 in data.tags)
+        self.assertTrue(tag2 in data.tags)
+
+    def new_task(self, **kwargs):
+        self.update_missing(
+            kwargs, type="testing", name="test project tags", input=dict(view=dict())
+        )
+        return self.create_temp("tasks", **kwargs)
+
+    def new_model(self, **kwargs):
+        self.update_missing(kwargs, name="test project tags", uri="file:///a")
+        return self.create_temp("models", **kwargs)
@@ -8,6 +8,8 @@ from functools import partial
 from statistics import mean
 from typing import Sequence

+from boltons.iterutils import first
+
 import es_factory
 from apierrors.errors.bad_request import EventsNotAdded
 from tests.automated import TestService
@@ -72,6 +74,31 @@ class TestTaskEvents(TestService):
             ),
         )

+    def test_last_scalar_metrics(self):
+        metric = "Metric1"
+        variant = "Variant1"
+        iter_count = 100
+        task = self._temp_task()
+        events = [
+            {
+                **self._create_task_event("training_stats_scalar", task, iteration),
+                "metric": metric,
+                "variant": variant,
+                "value": iteration,
+            }
+            for iteration in range(iter_count)
+        ]
+        # send 2 batches to check the interaction with already stored db value
+        # each batch contains multiple iterations
+        self.send_batch(events[:50])
+        self.send_batch(events[50:])
+
+        task_data = self.api.tasks.get_by_id(task=task).task
+        metric_data = first(first(task_data.last_metrics.values()).values())
+        self.assertEqual(iter_count - 1, metric_data.value)
+        self.assertEqual(iter_count - 1, metric_data.max_value)
+        self.assertEqual(0, metric_data.min_value)
+
     def test_task_debug_images(self):
         task = self._temp_task()
         metric = "Metric1"