Support tags-per-project in tags related services

This commit is contained in:
allegroai 2020-06-21 23:54:05 +03:00
parent 5e095af3aa
commit 1ea6408d41
15 changed files with 571 additions and 147 deletions

View File

@ -2,6 +2,7 @@ from jsonmodels import fields, models
class Filter(models.Base):
tags = fields.ListField([str])
system_tags = fields.ListField([str])

View File

@ -1,5 +1,8 @@
from jsonmodels import models, fields
from apimodels import ListField
from apimodels.organization import TagsRequest
class ProjectReq(models.Base):
project = fields.StringField()
@ -14,3 +17,7 @@ class GetHyperParamResp(models.Base):
parameters = fields.ListField(str)
remaining = fields.IntField()
total = fields.IntField()
class ProjectTagsRequest(TagsRequest):
projects = ListField(str)

View File

@ -230,9 +230,31 @@ class EventBLL(object):
metric_hash = dbutils.hash_field_name(metric)
variant_hash = dbutils.hash_field_name(variant)
timestamp = last_events[metric_hash][variant_hash].get("timestamp", None)
if timestamp is None or timestamp < event["timestamp"]:
last_events[metric_hash][variant_hash] = event
last_event = last_events[metric_hash][variant_hash]
event_iter = event.get("iter", 0)
event_timestamp = event["timestamp"]
if (event_iter, event_timestamp) >= (
last_event.get("iter", event_iter),
last_event.get("timestamp", event_timestamp),
):
event_data = {
k: event[k]
for k in ("value", "metric", "variant", "iter", "timestamp")
if k in event
}
value = event_data.get("value")
if value is not None:
event_data["min_value"] = min(value, last_event.get("min_value", value))
event_data["max_value"] = max(value, last_event.get("max_value", value))
else:
event_data.update(
**{
k: last_event[k]
for k in ("value", "min_value", "max_value")
if k in last_event
}
)
last_events[metric_hash][variant_hash] = event_data
def _update_last_metric_events_for_task(self, last_events, event):
"""
@ -275,7 +297,13 @@ class EventBLL(object):
flatten_nested_items(
last_scalar_events,
nesting=2,
include_leaves=["value", "metric", "variant"],
include_leaves=[
"value",
"min_value",
"max_value",
"metric",
"variant",
],
)
)

View File

@ -1,6 +1,10 @@
from typing import Sequence
from collections import defaultdict
from enum import Enum
from itertools import chain
from typing import Sequence, Union, Type, Dict
from mongoengine import Q
from redis import Redis
from config import config
from database.model.base import GetMixin
@ -10,40 +14,65 @@ from redis_manager import redman
from utilities import json
log = config.logger(__file__)
_settings_prefix = "services.organization"
class OrgBLL:
class _TagsCache:
_tags_field = "tags"
_system_tags_field = "system_tags"
_settings_prefix = "services.organization"
def __init__(self, redis=None):
self.redis = redis or redman.connection("apiserver")
def __init__(self, db_cls: Union[Type[Model], Type[Task]], redis: Redis):
self.db_cls = db_cls
self.redis = redis
@property
def _tags_cache_expiration_seconds(self):
return config.get(
f"{self._settings_prefix}.tags_cache.expiration_seconds", 3600
)
return config.get(f"{_settings_prefix}.tags_cache.expiration_seconds", 3600)
@staticmethod
def _get_tags_cache_key(company, field: str, filter_: Sequence[str] = None):
filter_str = "_".join(filter_) if filter_ else ""
return f"{field}_{company}_{filter_str}"
@staticmethod
def _get_tags_from_db(company, field, filter_: Sequence[str] = None) -> set:
def _get_tags_from_db(
self,
company: str,
field: str,
project: str = None,
filter_: Dict[str, Sequence[str]] = None,
) -> set:
query = Q(company=company)
if filter_:
query &= GetMixin.get_list_field_query("system_tags", filter_)
for name, vals in filter_.items():
if vals:
query &= GetMixin.get_list_field_query(name, vals)
if project:
query &= Q(project=project)
tags = set()
for cls_ in (Task, Model):
tags |= set(cls_.objects(query).distinct(field))
return tags
return self.db_cls.objects(query).distinct(field)
def _get_tags_cache_key(
self,
company: str,
field: str,
project: str = None,
filter_: Dict[str, Sequence[str]] = None,
):
"""
Project None means 'from all company projects'
The key is built in the way that scanning company keys for 'all company projects'
will not return the keys related to the particular company projects and vice versa.
So that we can have a fine grain control on what redis keys to invalidate
"""
filter_str = None
if filter_:
filter_str = "_".join(
["filter", *chain.from_iterable([f, *v] for f, v in filter_.items())]
)
key_parts = [company, project, self.db_cls.__name__, field, filter_str]
return "_".join(filter(None, key_parts))
def get_tags(
self, company, include_system: bool = False, filter_: Sequence[str] = None
self,
company: str,
include_system: bool = False,
filter_: Dict[str, Sequence[str]] = None,
project: str = None,
) -> dict:
"""
Get tags and optionally system tags for the company
@ -51,35 +80,114 @@ class OrgBLL:
The function retrieves both cached values from Redis in one call
and re calculates any of them if missing in Redis
"""
fields = [
self._tags_field,
*([self._system_tags_field] if include_system else []),
fields = [self._tags_field]
if include_system:
fields.append(self._system_tags_field)
redis_keys = [
self._get_tags_cache_key(company, field=f, project=project, filter_=filter_)
for f in fields
]
redis_keys = [self._get_tags_cache_key(company, f, filter_) for f in fields]
cached = self.redis.mget(redis_keys)
ret = {}
for field, tag_data, key in zip(fields, cached, redis_keys):
if tag_data is not None:
tags = json.loads(tag_data)
else:
tags = list(self._get_tags_from_db(company, field, filter_))
tags = list(self._get_tags_from_db(company, field, project, filter_))
self.redis.setex(
key,
time=self._tags_cache_expiration_seconds,
value=json.dumps(tags),
)
ret[field] = tags
ret[field] = set(tags)
return ret
def update_org_tags(self, company, tags=None, system_tags=None, reset=False):
def update_tags(self, company: str, project: str, tags=None, system_tags=None):
"""
Updates system tags. If reset is set then both tags and system_tags
Updates tags. If reset is set then both tags and system_tags
are recalculated. Otherwise only those that are not 'None'
"""
if reset or tags is not None:
self.redis.delete(self._get_tags_cache_key(company, self._tags_field))
if reset or system_tags is not None:
self.redis.delete(
self._get_tags_cache_key(company, self._system_tags_field)
fields = [
field
for field, update in (
(self._tags_field, tags),
(self._system_tags_field, system_tags),
)
if update is not None
]
if not fields:
return
self._delete_redis_keys(company, projects=[project], fields=fields)
def reset_tags(self, company: str, projects: Sequence[str]):
self._delete_redis_keys(
company,
projects=projects,
fields=(self._tags_field, self._system_tags_field),
)
def _delete_redis_keys(
self, company: str, projects: [Sequence[str]], fields: Sequence[str]
):
redis_keys = list(
chain.from_iterable(
self.redis.keys(
self._get_tags_cache_key(company, field=f, project=p) + "*"
)
for f in fields
for p in set(projects) | {None}
)
)
if redis_keys:
self.redis.delete(*redis_keys)
class Tags(Enum):
Task = "task"
Model = "model"
class OrgBLL:
def __init__(self, redis=None):
self.redis = redis or redman.connection("apiserver")
self._task_tags = _TagsCache(Task, self.redis)
self._model_tags = _TagsCache(Model, self.redis)
def get_tags(
self,
company: str,
entity: Tags,
include_system: bool = False,
filter_: Dict[str, Sequence[str]] = None,
projects: Sequence[str] = None,
) -> dict:
tags_cache = self._get_tags_cache_for_entity(entity)
if not projects:
return tags_cache.get_tags(
company, include_system=include_system, filter_=filter_
)
ret = defaultdict(set)
for project in projects:
project_tags = tags_cache.get_tags(
company, include_system=include_system, filter_=filter_, project=project
)
for field, tags in project_tags.items():
ret[field] |= tags
return ret
def update_tags(
self, company: str, entity: Tags, project: str, tags=None, system_tags=None,
):
tags_cache = self._get_tags_cache_for_entity(entity)
tags_cache.update_tags(company, project, tags, system_tags)
def reset_tags(self, company: str, entity: Tags, projects: Sequence[str]):
tags_cache = self._get_tags_cache_for_entity(entity)
tags_cache.reset_tags(company, projects=projects)
def _get_tags_cache_for_entity(self, entity: Tags) -> _TagsCache:
return self._task_tags if entity == Tags.Task else self._model_tags

View File

@ -14,7 +14,7 @@ import database.utils as dbutils
import es_factory
from apierrors import errors
from apimodels.tasks import Artifact as ApiArtifact
from bll.organization import OrgBLL
from bll.organization import OrgBLL, Tags
from config import config
from database.errors import translate_errors_context
from database.model.model import Model
@ -229,7 +229,21 @@ class TaskBLL(object):
validate_project=validate_references or project,
)
new_task.save()
org_bll.update_org_tags(company_id, tags=tags, system_tags=system_tags)
if task.project == new_task.project:
updated_tags = tags
updated_system_tags = system_tags
else:
updated_tags = new_task.tags
updated_system_tags = new_task.system_tags
org_bll.update_tags(
company_id,
Tags.Task,
project=new_task.project,
tags=updated_tags,
system_tags=updated_system_tags,
)
return new_task
@classmethod
@ -346,10 +360,12 @@ class TaskBLL(object):
return "__".join((op, "last_metrics") + path)
for path, value in last_scalar_values:
extra_updates[op_path("set", *path)] = value
if path[-1] == "value":
if path[-1] == "min_value":
extra_updates[op_path("min", *path[:-1], "min_value")] = value
elif path[-1] == "max_value":
extra_updates[op_path("max", *path[:-1], "max_value")] = value
else:
extra_updates[op_path("set", *path)] = value
if last_events is not None:

View File

@ -1,43 +1,48 @@
_description: "This service provides organization level operations"
get_tags {
"2.8" {
description: "Get all the user and system tags used for the company tasks and models"
request {
type: object
properties {
include_system {
description: "If set to 'true' then the list of the system tags is also returned. The default value is 'false'"
type: boolean
default: false
}
filter {
description: "Filter on entities to collect tags from"
type: object
properties {
system_tags {
description: "The list of system tag values to filter by. Use 'null' value to specify empty tags. Use '__Snot' value to specify that the following value should be excluded"
type: array
items {type: string}
"2.8" {
description: "Get all the user and system tags used for the company tasks and models"
request {
type: object
properties {
include_system {
description: "If set to 'true' then the list of the system tags is also returned. The default value is 'false'"
type: boolean
default: false
}
filter {
description: "Filter on entities to collect tags from"
type: object
properties {
tags {
description: "The list of tag values to filter by. Use 'null' value to specify empty tags. Use '__Snot' value to specify that the following value should be excluded"
type: array
items {type: string}
}
system_tags {
description: "The list of system tag values to filter by. Use 'null' value to specify empty system tags. Use '__Snot' value to specify that the following value should be excluded"
type: array
items {type: string}
}
}
}
}
}
response {
type: object
properties {
tags {
description: "The list of unique tag values"
type: array
items {type: string}
}
system_tags {
description: "The list of unique system tag values. Returned only if 'include_system' is set to 'true' in the request"
type: array
items {type: string}
}
}
}
}
}
}
response {
type: object
properties {
tags {
description: "The list of unique tag values"
type: array
items {type: string}
}
system_tags {
description: "The list of unique system tag values. Returned only if 'include_system' is set to 'true' in the request"
type: array
items {type: string}
}
}
}
}
}

View File

@ -196,6 +196,52 @@ _definitions {
}
}
}
tags_request {
type: object
properties {
include_system {
description: "If set to 'true' then the list of the system tags is also returned. The default value is 'false'"
type: boolean
default: false
}
projects {
description: "The list of projects under which the tags are searched. If not passed or empty then all the projects are searched"
type: array
items { type: string }
}
filter {
description: "Filter on entities to collect tags from"
type: object
properties {
tags {
description: "The list of tag values to filter by. Use 'null' value to specify empty tags. Use '__Snot' value to specify that the following value should be excluded"
type: array
items {type: string}
}
system_tags {
description: "The list of system tag values to filter by. Use 'null' value to specify empty system tags. Use '__Snot' value to specify that the following value should be excluded"
type: array
items {type: string}
}
}
}
}
}
tags_response {
type: object
properties {
tags {
description: "The list of unique tag values"
type: array
items {type: string}
}
system_tags {
description: "The list of unique system tag values. Returned only if 'include_system' is set to 'true' in the request"
type: array
items {type: string}
}
}
}
}
create {
@ -508,7 +554,7 @@ get_hyper_parameters {
parameters {
description: "A list of hyper parameter names"
type: array
items { type: string }
items {type: string}
}
remaining {
description: "Remaining results"
@ -522,3 +568,17 @@ get_hyper_parameters {
}
}
}
get_task_tags {
"2.8" {
description: "Get user and system tags used for the tasks under the specified projects"
request = ${_definitions.tags_request}
response = ${_definitions.tags_response}
}
}
get_model_tags {
"2.8" {
description: "Get user and system tags used for the models under the specified projects"
request = ${_definitions.tags_request}
response = ${_definitions.tags_response}
}
}

View File

@ -1,4 +1,5 @@
from datetime import datetime
from typing import Sequence
from mongoengine import Q, EmbeddedDocument
@ -12,7 +13,7 @@ from apimodels.models import (
PublishModelResponse,
ModelTaskPublishResponse,
)
from bll.organization import OrgBLL
from bll.organization import OrgBLL, Tags
from bll.task import TaskBLL
from config import config
from database.errors import translate_errors_context
@ -128,9 +129,19 @@ def parse_model_fields(call, valid_fields):
return fields
def _update_org_tags(company, fields: dict):
org_bll.update_org_tags(
company, tags=fields.get("tags"), system_tags=fields.get("system_tags")
def _update_cached_tags(company: str, project: str, fields: dict):
org_bll.update_tags(
company,
Tags.Model,
project=project,
tags=fields.get("tags"),
system_tags=fields.get("system_tags"),
)
def _reset_cached_tags(company: str, projects: Sequence[str]):
org_bll.reset_tags(
company, Tags.Model, projects=projects,
)
@ -203,7 +214,7 @@ def update_for_task(call: APICall, company_id, _):
**fields,
)
model.save()
_update_org_tags(company_id, fields)
_update_cached_tags(company_id, project=model.project, fields=fields)
TaskBLL.update_statistics(
task_id=task_id,
@ -248,7 +259,7 @@ def create(call: APICall, company_id, req_model: CreateModelRequest):
**fields,
)
model.save()
_update_org_tags(company_id, fields)
_update_cached_tags(company_id, project=model.project, fields=fields)
call.result.data_model = CreateModelResponse(id=model.id, created=True)
@ -327,7 +338,15 @@ def edit(call: APICall, company_id, _):
if fields:
updated = model.update(upsert=False, **fields)
if updated:
_update_org_tags(company_id, fields)
new_project = fields.get("project", model.project)
if new_project != model.project:
_reset_cached_tags(
company_id, projects=[new_project, model.project]
)
else:
_update_cached_tags(
company_id, project=model.project, fields=fields
)
conform_output_tags(call, fields)
call.result.data_model = UpdateResponse(updated=updated, fields=fields)
else:
@ -355,7 +374,13 @@ def _update_model(call: APICall, company_id, model_id=None):
updated_count, updated_fields = Model.safe_update(company_id, model.id, data)
if updated_count:
_update_org_tags(company_id, updated_fields)
new_project = updated_fields.get("project", model.project)
if new_project != model.project:
_reset_cached_tags(company_id, projects=[new_project, model.project])
else:
_update_cached_tags(
company_id, project=model.project, fields=updated_fields
)
conform_output_tags(call, updated_fields)
return UpdateResponse(updated=updated_count, fields=updated_fields)
@ -395,7 +420,7 @@ def update(call: APICall, company_id, _):
with translate_errors_context():
query = dict(id=model_id, company=company_id)
model = Model.objects(**query).only("id", "task").first()
model = Model.objects(**query).only("id", "task", "project").first()
if not model:
raise errors.bad_request.InvalidModelId(**query)
@ -428,5 +453,5 @@ def update(call: APICall, company_id, _):
del_count = Model.objects(**query).delete()
if del_count:
org_bll.update_org_tags(company_id, reset=True)
_reset_cached_tags(company_id, projects=[model.project])
call.result.data = dict(deleted=del_count > 0)

View File

@ -1,13 +1,22 @@
from collections import defaultdict
from apimodels.organization import TagsRequest
from bll.organization import OrgBLL
from bll.organization import OrgBLL, Tags
from service_repo import endpoint, APICall
from services.utils import get_tags_filter_dictionary, get_tags_response
org_bll = OrgBLL()
@endpoint("organization.get_tags", request_data_model=TagsRequest)
def get_tags(call: APICall, company, request: TagsRequest):
filter_ = request.filter.system_tags if request.filter else None
call.result.data = org_bll.get_tags(
company, include_system=request.include_system, filter_=filter_
)
filter_dict = get_tags_filter_dictionary(request.filter)
ret = defaultdict(set)
for entity in Tags.Model, Tags.Task:
tags = org_bll.get_tags(
company, entity, include_system=request.include_system, filter_=filter_dict,
)
for field, vals in tags.items():
ret[field] |= vals
call.result.data = get_tags_response(ret)

View File

@ -9,7 +9,13 @@ from mongoengine import Q
import database
from apierrors import errors
from apimodels.base import UpdateResponse
from apimodels.projects import GetHyperParamReq, GetHyperParamResp, ProjectReq
from apimodels.projects import (
GetHyperParamReq,
GetHyperParamResp,
ProjectReq,
ProjectTagsRequest,
)
from bll.organization import OrgBLL, Tags
from bll.task import TaskBLL
from database.errors import translate_errors_context
from database.model import EntityVisibility
@ -18,9 +24,15 @@ from database.model.project import Project
from database.model.task.task import Task, TaskStatus
from database.utils import parse_from_call, get_options, get_company_or_none_constraint
from service_repo import APICall, endpoint
from services.utils import conform_tag_fields, conform_output_tags
from services.utils import (
conform_tag_fields,
conform_output_tags,
get_tags_filter_dictionary,
get_tags_response,
)
from timing_context import TimingContext
org_bll = OrgBLL()
task_bll = TaskBLL()
archived_tasks_cond = {"$in": [EntityVisibility.archived.value, "$system_tags"]}
@ -381,3 +393,31 @@ def get_hyper_parameters(call: APICall, company_id: str, request: GetHyperParamR
"remaining": remaining,
"parameters": parameters,
}
@endpoint(
"projects.get_task_tags", min_version="2.8", request_data_model=ProjectTagsRequest
)
def get_tags(call: APICall, company, request: ProjectTagsRequest):
ret = org_bll.get_tags(
company,
Tags.Task,
include_system=request.include_system,
filter_=get_tags_filter_dictionary(request.filter),
projects=request.projects,
)
call.result.data = get_tags_response(ret)
@endpoint(
"projects.get_model_tags", min_version="2.8", request_data_model=ProjectTagsRequest
)
def get_tags(call: APICall, company, request: ProjectTagsRequest):
ret = org_bll.get_tags(
company,
Tags.Model,
include_system=request.include_system,
filter_=get_tags_filter_dictionary(request.filter),
projects=request.projects,
)
call.result.data = get_tags_response(ret)

View File

@ -33,7 +33,7 @@ from apimodels.tasks import (
ResetRequest,
)
from bll.event import EventBLL
from bll.organization import OrgBLL
from bll.organization import OrgBLL, Tags
from bll.queue import QueueBLL
from bll.task import (
TaskBLL,
@ -343,9 +343,19 @@ def validate(call: APICall, company_id, req_model: CreateRequest):
_validate_and_get_task_from_call(call)
def _update_org_tags(company, fields: dict):
org_bll.update_org_tags(
company, tags=fields.get("tags"), system_tags=fields.get("system_tags")
def _update_cached_tags(company: str, project: str, fields: dict):
org_bll.update_tags(
company,
Tags.Task,
project=project,
tags=fields.get("tags"),
system_tags=fields.get("system_tags"),
)
def _reset_cached_tags(company: str, projects: Sequence[str]):
org_bll.reset_tags(
company, Tags.Task, projects=projects
)
@ -357,7 +367,7 @@ def create(call: APICall, company_id, req_model: CreateRequest):
with translate_errors_context(), TimingContext("mongo", "save_task"):
task.save()
_update_org_tags(company_id, fields)
_update_cached_tags(company_id, project=task.project, fields=fields)
update_project_time(task.project)
call.result.data_model = IdResponse(id=task.id)
@ -400,7 +410,9 @@ def update(call: APICall, company_id, req_model: UpdateRequest):
task_id = req_model.task
with translate_errors_context():
task = Task.get_for_writing(id=task_id, company=company_id, _only=["id"])
task = Task.get_for_writing(
id=task_id, company=company_id, _only=["id", "project"]
)
if not task:
raise errors.bad_request.InvalidTaskId(id=task_id)
@ -416,7 +428,13 @@ def update(call: APICall, company_id, req_model: UpdateRequest):
injected_update=dict(last_update=datetime.utcnow()),
)
if updated_count:
_update_org_tags(company_id, updated_fields)
new_project = updated_fields.get("project", task.project)
if new_project != task.project:
_reset_cached_tags(company_id, projects=[new_project, task.project])
else:
_update_cached_tags(
company_id, project=task.project, fields=updated_fields
)
update_project_time(updated_fields.get("project"))
unprepare_from_saved(call, updated_fields)
return UpdateResponse(updated=updated_count, fields=updated_fields)
@ -470,8 +488,10 @@ def update_batch(call: APICall, company_id, _):
now = datetime.utcnow()
bulk_ops = []
updated_projects = set()
for id, data in items.items():
fields, valid_fields = prepare_update_fields(call, tasks[id], data)
task = tasks[id]
fields, valid_fields = prepare_update_fields(call, task, data)
partial_update_dict = Task.get_safe_update_dict(fields)
if not partial_update_dict:
continue
@ -481,12 +501,20 @@ def update_batch(call: APICall, company_id, _):
)
bulk_ops.append(update_op)
new_project = partial_update_dict.get("project", task.project)
if new_project != task.project:
updated_projects.update({new_project, task.project})
elif any(f in partial_update_dict for f in ("tags", "system_tags")):
updated_projects.add(task.project)
updated = 0
if bulk_ops:
res = Task._get_collection().bulk_write(bulk_ops)
updated = res.modified_count
if updated:
org_bll.update_org_tags(company_id, reset=True)
if updated and updated_projects:
_reset_cached_tags(company_id, projects=list(updated_projects))
call.result.data = {"updated": updated}
@ -542,7 +570,15 @@ def edit(call: APICall, company_id, req_model: UpdateRequest):
fixed_fields.update(last_update=now)
updated = task.update(upsert=False, **fixed_fields)
if updated:
_update_org_tags(company_id, fixed_fields)
new_project = fixed_fields.get("project", task.project)
if new_project != task.project:
_reset_cached_tags(
company_id, projects=[new_project, task.project]
)
else:
_update_cached_tags(
company_id, project=task.project, fields=fixed_fields
)
update_project_time(fields.get("project"))
unprepare_from_saved(call, fields)
call.result.data_model = UpdateResponse(updated=updated, fields=fields)
@ -710,12 +746,11 @@ def reset(call: APICall, company_id, request: ResetRequest):
if request.clear_all:
updates.update(
set__execution=Execution(),
unset__script=1,
set__execution=Execution(), unset__script=1,
)
else:
updates.update(unset__execution__queue=1)
updates.update(
unset__execution__queue=1,
__raw__={"$pull": {"execution.artifacts": {"mode": {"$ne": "input"}}}},
)
@ -909,7 +944,8 @@ def delete(call: APICall, company_id, req_model: DeleteRequest):
task.switch_collection(collection_name)
task.delete()
org_bll.update_org_tags(company_id, reset=True)
_reset_cached_tags(company_id, projects=[task.project])
call.result.data = dict(deleted=True, **attr.asdict(result))

View File

@ -1,12 +1,28 @@
from typing import Union, Sequence, Tuple
from apierrors import errors
from apimodels.organization import Filter
from database.model.base import GetMixin
from database.utils import partition_tags
from service_repo import APICall
from service_repo.base import PartialVersion
def get_tags_filter_dictionary(input_: Filter) -> dict:
if not input_:
return {}
return {
field: vals
for field, vals in (("tags", input_.tags), ("system_tags", input_.system_tags))
if vals
}
def get_tags_response(ret: dict) -> dict:
return {field: sorted(vals) for field, vals in ret.items()}
def conform_output_tags(call: APICall, documents: Union[dict, Sequence[dict]]):
"""
For old clients both tags and system tags are returned in 'tags' field

View File

@ -1,36 +0,0 @@
from tests.automated import TestService
class TestOrganization(TestService):
def setUp(self, version="2.8"):
super().setUp(version=version)
def test_tags(self):
tag1 = "Orgtest tag1"
tag2 = "Orgtest tag2"
system_tag = "Orgtest system tag"
model = self.create_temp(
"models", name="test_org", uri="file:///a", tags=[tag1]
)
task = self.create_temp(
"tasks", name="test org", type="training", input=dict(view={}), tags=[tag1]
)
data = self.api.organization.get_tags()
self.assertTrue(tag1 in data.tags)
self.api.tasks.edit(task=task, tags=[tag2], system_tags=[system_tag])
data = self.api.organization.get_tags(include_system=True)
self.assertTrue({tag1, tag2}.issubset(set(data.tags)))
self.assertTrue(system_tag in data.system_tags)
data = self.api.organization.get_tags(
filter={"system_tags": ["__$not", system_tag]}
)
self.assertTrue(tag1 in data.tags)
self.assertFalse(tag2 in data.tags)
self.api.models.delete(model=model)
data = self.api.organization.get_tags()
self.assertFalse(tag1 in data.tags)
self.assertTrue(tag2 in data.tags)

View File

@ -0,0 +1,82 @@
from tests.automated import TestService
class TestProjectTags(TestService):
def setUp(self, version="2.8"):
super().setUp(version=version)
def test_project_tags(self):
tags_1 = ["Test tag 1", "Test tag 2"]
tags_2 = ["Test tag 3", "Test tag 4"]
p1 = self.create_temp("projects", name="Test tags1", description="test")
task1_1 = self.new_task(project=p1, tags=tags_1[:1])
task1_2 = self.new_task(project=p1, tags=tags_1[1:])
p2 = self.create_temp("projects", name="Test tasks2", description="test")
task2 = self.new_task(project=p2, tags=tags_2)
# test tags per project
data = self.api.projects.get_task_tags(projects=[p1])
self.assertEqual(set(tags_1), set(data.tags))
data = self.api.projects.get_model_tags(projects=[p1])
self.assertEqual(set(), set(data.tags))
data = self.api.projects.get_task_tags(projects=[p2])
self.assertEqual(set(tags_2), set(data.tags))
# test tags for projects list
data = self.api.projects.get_task_tags(projects=[p1, p2])
self.assertEqual(set(tags_1) | set(tags_2), set(data.tags))
# test tags for all projects
data = self.api.projects.get_task_tags(projects=[p1, p2])
self.assertTrue((set(tags_1) | set(tags_2)).issubset(data.tags))
# test move to another project
self.api.tasks.edit(task=task1_2, project=p2)
data = self.api.projects.get_task_tags(projects=[p1])
self.assertEqual(set(tags_1[:1]), set(data.tags))
data = self.api.projects.get_task_tags(projects=[p2])
self.assertEqual(set(tags_1[1:]) | set(tags_2), set(data.tags))
# test tags update
self.api.tasks.delete(task=task1_1, force=True)
self.api.tasks.delete(task=task2, force=True)
data = self.api.projects.get_task_tags(projects=[p1, p2])
self.assertEqual(set(tags_1[1:]), set(data.tags))
def test_organization_tags(self):
tag1 = "Orgtest tag1"
tag2 = "Orgtest tag2"
system_tag = "Orgtest system tag"
model = self.new_model(tags=[tag1])
task = self.new_task(tags=[tag1])
data = self.api.organization.get_tags()
self.assertTrue(tag1 in data.tags)
self.api.tasks.edit(task=task, tags=[tag2], system_tags=[system_tag])
data = self.api.organization.get_tags(include_system=True)
self.assertTrue({tag1, tag2}.issubset(set(data.tags)))
self.assertTrue(system_tag in data.system_tags)
data = self.api.organization.get_tags(
filter={"system_tags": ["__$not", system_tag]}
)
self.assertTrue(tag1 in data.tags)
self.assertFalse(tag2 in data.tags)
self.api.models.delete(model=model)
data = self.api.organization.get_tags()
self.assertFalse(tag1 in data.tags)
self.assertTrue(tag2 in data.tags)
def new_task(self, **kwargs):
self.update_missing(
kwargs, type="testing", name="test project tags", input=dict(view=dict())
)
return self.create_temp("tasks", **kwargs)
def new_model(self, **kwargs):
self.update_missing(kwargs, name="test project tags", uri="file:///a")
return self.create_temp("models", **kwargs)

View File

@ -8,6 +8,8 @@ from functools import partial
from statistics import mean
from typing import Sequence
from boltons.iterutils import first
import es_factory
from apierrors.errors.bad_request import EventsNotAdded
from tests.automated import TestService
@ -72,6 +74,31 @@ class TestTaskEvents(TestService):
),
)
def test_last_scalar_metrics(self):
metric = "Metric1"
variant = "Variant1"
iter_count = 100
task = self._temp_task()
events = [
{
**self._create_task_event("training_stats_scalar", task, iteration),
"metric": metric,
"variant": variant,
"value": iteration,
}
for iteration in range(iter_count)
]
# send 2 batches to check the interaction with already stored db value
# each batch contains multiple iterations
self.send_batch(events[:50])
self.send_batch(events[50:])
task_data = self.api.tasks.get_by_id(task=task).task
metric_data = first(first(task_data.last_metrics.values()).values())
self.assertEqual(iter_count - 1, metric_data.value)
self.assertEqual(iter_count - 1, metric_data.max_value)
self.assertEqual(0, metric_data.min_value)
def test_task_debug_images(self):
task = self._temp_task()
metric = "Metric1"