mirror of
https://github.com/clearml/clearml-server
synced 2025-06-26 23:15:47 +00:00
Add organization.get_tags to obtain the set of all used task, model, queue and project tags
This commit is contained in:
10
server/apimodels/organization.py
Normal file
10
server/apimodels/organization.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from jsonmodels import fields, models
|
||||
|
||||
|
||||
class Filter(models.Base):
|
||||
system_tags = fields.ListField([str])
|
||||
|
||||
|
||||
class TagsRequest(models.Base):
|
||||
include_system = fields.BoolField(default=False)
|
||||
filter = fields.EmbeddedField(Filter)
|
||||
85
server/bll/organization/__init__.py
Normal file
85
server/bll/organization/__init__.py
Normal file
@@ -0,0 +1,85 @@
|
||||
from typing import Sequence
|
||||
|
||||
from mongoengine import Q
|
||||
|
||||
from config import config
|
||||
from database.model.base import GetMixin
|
||||
from database.model.model import Model
|
||||
from database.model.task.task import Task
|
||||
from redis_manager import redman
|
||||
from utilities import json
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
class OrgBLL:
|
||||
_tags_field = "tags"
|
||||
_system_tags_field = "system_tags"
|
||||
_settings_prefix = "services.organization"
|
||||
|
||||
def __init__(self, redis=None):
|
||||
self.redis = redis or redman.connection("apiserver")
|
||||
|
||||
@property
|
||||
def _tags_cache_expiration_seconds(self):
|
||||
return config.get(
|
||||
f"{self._settings_prefix}.tags_cache.expiration_seconds", 3600
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_tags_cache_key(company, field: str, filter_: Sequence[str] = None):
|
||||
filter_str = "_".join(filter_) if filter_ else ""
|
||||
return f"{field}_{company}_{filter_str}"
|
||||
|
||||
@staticmethod
|
||||
def _get_tags_from_db(company, field, filter_: Sequence[str] = None) -> set:
|
||||
query = Q(company=company)
|
||||
if filter_:
|
||||
query &= GetMixin.get_list_field_query("system_tags", filter_)
|
||||
|
||||
tags = set()
|
||||
for cls_ in (Task, Model):
|
||||
tags |= set(cls_.objects(query).distinct(field))
|
||||
return tags
|
||||
|
||||
def get_tags(
|
||||
self, company, include_system: bool = False, filter_: Sequence[str] = None
|
||||
) -> dict:
|
||||
"""
|
||||
Get tags and optionally system tags for the company
|
||||
Return the dictionary of tags per tags field name
|
||||
The function retrieves both cached values from Redis in one call
|
||||
and re calculates any of them if missing in Redis
|
||||
"""
|
||||
fields = [
|
||||
self._tags_field,
|
||||
*([self._system_tags_field] if include_system else []),
|
||||
]
|
||||
redis_keys = [self._get_tags_cache_key(company, f, filter_) for f in fields]
|
||||
cached = self.redis.mget(redis_keys)
|
||||
ret = {}
|
||||
for field, tag_data, key in zip(fields, cached, redis_keys):
|
||||
if tag_data is not None:
|
||||
tags = json.loads(tag_data)
|
||||
else:
|
||||
tags = list(self._get_tags_from_db(company, field, filter_))
|
||||
self.redis.setex(
|
||||
key,
|
||||
time=self._tags_cache_expiration_seconds,
|
||||
value=json.dumps(tags),
|
||||
)
|
||||
ret[field] = tags
|
||||
|
||||
return ret
|
||||
|
||||
def update_org_tags(self, company, tags=None, system_tags=None, reset=False):
|
||||
"""
|
||||
Updates system tags. If reset is set then both tags and system_tags
|
||||
are recalculated. Otherwise only those that are not 'None'
|
||||
"""
|
||||
if reset or tags is not None:
|
||||
self.redis.delete(self._get_tags_cache_key(company, self._tags_field))
|
||||
if reset or system_tags is not None:
|
||||
self.redis.delete(
|
||||
self._get_tags_cache_key(company, self._system_tags_field)
|
||||
)
|
||||
@@ -14,6 +14,7 @@ import database.utils as dbutils
|
||||
import es_factory
|
||||
from apierrors import errors
|
||||
from apimodels.tasks import Artifact as ApiArtifact
|
||||
from bll.organization import OrgBLL
|
||||
from config import config
|
||||
from database.errors import translate_errors_context
|
||||
from database.model.model import Model
|
||||
@@ -30,11 +31,13 @@ from database.model.task.task import (
|
||||
)
|
||||
from database.utils import get_company_or_none_constraint, id as create_id
|
||||
from service_repo import APICall
|
||||
from services.utils import validate_tags
|
||||
from timing_context import TimingContext
|
||||
from utilities.dicts import deep_merge
|
||||
from .utils import ChangeStatusRequest, validate_status_change, ParameterKeyEscaper
|
||||
|
||||
log = config.logger(__file__)
|
||||
org_bll = OrgBLL()
|
||||
|
||||
|
||||
class TaskBLL(object):
|
||||
@@ -166,6 +169,7 @@ class TaskBLL(object):
|
||||
execution_overrides: Optional[dict] = None,
|
||||
validate_references: bool = False,
|
||||
) -> Task:
|
||||
validate_tags(tags, system_tags)
|
||||
task = cls.get_by_id(company_id=company_id, task_id=task_id)
|
||||
execution_dict = task.execution.to_proper_dict() if task.execution else {}
|
||||
execution_model_overriden = False
|
||||
@@ -212,7 +216,7 @@ class TaskBLL(object):
|
||||
validate_project=validate_references or project,
|
||||
)
|
||||
new_task.save()
|
||||
|
||||
org_bll.update_org_tags(company_id, tags=tags, system_tags=system_tags)
|
||||
return new_task
|
||||
|
||||
@classmethod
|
||||
@@ -344,7 +348,7 @@ class TaskBLL(object):
|
||||
|
||||
metric_stats = {
|
||||
dbutils.hash_field_name(metric_key): MetricEventStats(
|
||||
metric=metric_key, event_stats_by_type=events_per_type(metric_data),
|
||||
metric=metric_key, event_stats_by_type=events_per_type(metric_data)
|
||||
)
|
||||
for metric_key, metric_data in last_events.items()
|
||||
}
|
||||
|
||||
@@ -33,8 +33,8 @@ log = config.logger(__file__)
|
||||
|
||||
class WorkerBLL:
|
||||
def __init__(self, es=None, redis=None):
|
||||
self.es_client = es if es is not None else es_factory.connect("workers")
|
||||
self.redis = redis if redis is not None else redman.connection("workers")
|
||||
self.es_client = es or es_factory.connect("workers")
|
||||
self.redis = redis or redman.connection("workers")
|
||||
self._stats = WorkerStats(self.es_client)
|
||||
|
||||
@property
|
||||
|
||||
3
server/config/default/services/organization.conf
Normal file
3
server/config/default/services/organization.conf
Normal file
@@ -0,0 +1,3 @@
|
||||
tags_cache {
|
||||
expiration_seconds: 3600
|
||||
}
|
||||
@@ -3,7 +3,7 @@ from collections import namedtuple
|
||||
from functools import reduce
|
||||
from typing import Collection, Sequence, Union, Optional
|
||||
|
||||
from boltons.iterutils import first
|
||||
from boltons.iterutils import first, bucketize
|
||||
from dateutil.parser import parse as parse_datetime
|
||||
from mongoengine import Q, Document, ListField, StringField
|
||||
from pymongo.command_cursor import CommandCursor
|
||||
@@ -98,6 +98,37 @@ class GetMixin(PropsMixin):
|
||||
self.list_fields = list_fields
|
||||
self.pattern_fields = pattern_fields
|
||||
|
||||
class ListFieldBucketHelper:
|
||||
op_prefix = "__$"
|
||||
legacy_exclude_prefix = "-"
|
||||
|
||||
_default = "in"
|
||||
_ops = {"not": "nin"}
|
||||
_next = _default
|
||||
|
||||
def __init__(self, legacy=False):
|
||||
self._legacy = legacy
|
||||
|
||||
def key(self, v):
|
||||
if v is None:
|
||||
self._next = self._default
|
||||
return self._default
|
||||
elif self._legacy and v.startswith(self.legacy_exclude_prefix):
|
||||
self._next = self._default
|
||||
return self._ops["not"]
|
||||
elif v.startswith(self.op_prefix):
|
||||
self._next = self._ops.get(v[len(self.op_prefix) :], self._default)
|
||||
return None
|
||||
|
||||
next_ = self._next
|
||||
self._next = self._default
|
||||
return next_
|
||||
|
||||
def value_transform(self, v):
|
||||
if self._legacy and v and v.startswith(self.legacy_exclude_prefix):
|
||||
return v[len(self.legacy_exclude_prefix) :]
|
||||
return v
|
||||
|
||||
get_all_query_options = QueryParameterOptions()
|
||||
|
||||
@classmethod
|
||||
@@ -175,17 +206,7 @@ class GetMixin(PropsMixin):
|
||||
for field in tuple(opts.list_fields or ()):
|
||||
data = parameters.pop(field, None)
|
||||
if data:
|
||||
if not isinstance(data, (list, tuple)):
|
||||
raise MakeGetAllQueryError("expected list", field)
|
||||
exclude = [t for t in data if t.startswith("-")]
|
||||
include = list(set(data).difference(exclude))
|
||||
mongoengine_field = field.replace(".", "__")
|
||||
if include:
|
||||
dict_query[f"{mongoengine_field}__in"] = include
|
||||
if exclude:
|
||||
dict_query[f"{mongoengine_field}__nin"] = [
|
||||
t[1:] for t in exclude
|
||||
]
|
||||
query &= cls.get_list_field_query(field, data)
|
||||
|
||||
for field in opts.fields or []:
|
||||
data = parameters.pop(field, None)
|
||||
@@ -229,6 +250,47 @@ class GetMixin(PropsMixin):
|
||||
|
||||
return query & RegexQ(**dict_query)
|
||||
|
||||
@classmethod
|
||||
def get_list_field_query(cls, field: str, data: Sequence[Optional[str]]) -> Q:
|
||||
"""
|
||||
Get a proper mongoengine Q object that represents an "or" query for the provided values
|
||||
with respect to the given list field, with support for "none of empty" in case a None value
|
||||
is included.
|
||||
|
||||
- Exclusion can be specified by a leading "-" for each value (API versions <2.8)
|
||||
or by a preceding "__$not" value (operator)
|
||||
"""
|
||||
if not isinstance(data, (list, tuple)):
|
||||
raise MakeGetAllQueryError("expected list", field)
|
||||
|
||||
# TODO: backwards compatibility only for older API versions
|
||||
helper = cls.ListFieldBucketHelper(legacy=True)
|
||||
actions = bucketize(
|
||||
data, key=helper.key, value_transform=helper.value_transform
|
||||
)
|
||||
|
||||
allow_empty = None in actions.get("in", {})
|
||||
mongoengine_field = field.replace(".", "__")
|
||||
|
||||
q = RegexQ()
|
||||
for action in filter(None, actions):
|
||||
q &= RegexQ(
|
||||
**{
|
||||
f"{mongoengine_field}__{action}": list(
|
||||
set(filter(None, actions[action]))
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
if not allow_empty:
|
||||
return q
|
||||
|
||||
return (
|
||||
q
|
||||
| Q(**{f"{mongoengine_field}__exists": False})
|
||||
| Q(**{mongoengine_field: []})
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _prepare_perm_query(cls, company, allow_public=False):
|
||||
if allow_public:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from mongoengine import Document, StringField, DateTimeField, ListField, BooleanField
|
||||
from mongoengine import Document, StringField, DateTimeField, BooleanField
|
||||
|
||||
from database import Database, strict
|
||||
from database.fields import StrippedStringField, SafeDictField
|
||||
from database.fields import StrippedStringField, SafeDictField, SafeSortedListField
|
||||
from database.model import DbModelMixin
|
||||
from database.model.base import GetMixin
|
||||
from database.model.model_labels import ModelLabels
|
||||
@@ -61,8 +61,8 @@ class Model(DbModelMixin, Document):
|
||||
created = DateTimeField(required=True, user_set_allowed=True)
|
||||
task = StringField(reference_field=Task)
|
||||
comment = StringField(user_set_allowed=True)
|
||||
tags = ListField(StringField(required=True), user_set_allowed=True)
|
||||
system_tags = ListField(StringField(required=True), user_set_allowed=True)
|
||||
tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
|
||||
system_tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
|
||||
uri = StrippedStringField(default="", user_set_allowed=True)
|
||||
framework = StringField()
|
||||
design = SafeDictField()
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from mongoengine import StringField, DateTimeField, ListField
|
||||
from mongoengine import StringField, DateTimeField
|
||||
|
||||
from database import Database, strict
|
||||
from database.fields import StrippedStringField
|
||||
from database.fields import StrippedStringField, SafeSortedListField
|
||||
from database.model import AttributedDocument
|
||||
from database.model.base import GetMixin
|
||||
|
||||
@@ -36,7 +36,7 @@ class Project(AttributedDocument):
|
||||
)
|
||||
description = StringField(required=True)
|
||||
created = DateTimeField(required=True)
|
||||
tags = ListField(StringField(required=True))
|
||||
system_tags = ListField(StringField(required=True))
|
||||
tags = SafeSortedListField(StringField(required=True))
|
||||
system_tags = SafeSortedListField(StringField(required=True))
|
||||
default_output_destination = StrippedStringField()
|
||||
last_update = DateTimeField()
|
||||
|
||||
@@ -4,11 +4,10 @@ from mongoengine import (
|
||||
StringField,
|
||||
DateTimeField,
|
||||
EmbeddedDocumentListField,
|
||||
ListField,
|
||||
)
|
||||
|
||||
from database import Database, strict
|
||||
from database.fields import StrippedStringField
|
||||
from database.fields import StrippedStringField, SafeSortedListField
|
||||
from database.model import DbModelMixin
|
||||
from database.model.base import ProperDictMixin, GetMixin
|
||||
from database.model.company import Company
|
||||
@@ -41,7 +40,7 @@ class Queue(DbModelMixin, Document):
|
||||
)
|
||||
company = StringField(required=True, reference_field=Company)
|
||||
created = DateTimeField(required=True)
|
||||
tags = ListField(StringField(required=True), default=list, user_set_allowed=True)
|
||||
system_tags = ListField(StringField(required=True), user_set_allowed=True)
|
||||
tags = SafeSortedListField(StringField(required=True), default=list, user_set_allowed=True)
|
||||
system_tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
|
||||
entries = EmbeddedDocumentListField(Entry, default=list)
|
||||
last_update = DateTimeField()
|
||||
|
||||
@@ -172,8 +172,8 @@ class Task(AttributedDocument):
|
||||
project = StringField(reference_field=Project, user_set_allowed=True)
|
||||
output = EmbeddedDocumentField(Output, default=Output)
|
||||
execution: Execution = EmbeddedDocumentField(Execution, default=Execution)
|
||||
tags = ListField(StringField(required=True), user_set_allowed=True)
|
||||
system_tags = ListField(StringField(required=True), user_set_allowed=True)
|
||||
tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
|
||||
system_tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
|
||||
script = EmbeddedDocumentField(Script)
|
||||
last_worker = StringField()
|
||||
last_worker_report = DateTimeField()
|
||||
|
||||
@@ -23,10 +23,25 @@ def migrate_auth(db: Database):
|
||||
|
||||
def migrate_backend(db: Database):
|
||||
"""
|
||||
Remove the old indices from the collections since
|
||||
they may come out of sync with the latest changes
|
||||
in the code and mongo libraries update
|
||||
1. Sort tags and system tags
|
||||
2. Remove the old indices from the collections since
|
||||
they may come out of sync with the latest changes
|
||||
in the code and mongo libraries update
|
||||
"""
|
||||
|
||||
fields = ("tags", "system_tags")
|
||||
query = {"$or": [{field: {"$exists": True, "$ne": []}} for field in fields]}
|
||||
for collection_name in ("task", "model", "project", "queue"):
|
||||
collection = db[collection_name]
|
||||
for doc in collection.find(filter=query, projection=fields):
|
||||
update = {
|
||||
field: sorted(doc[field])
|
||||
for field in fields
|
||||
if doc.get(field)
|
||||
}
|
||||
if update:
|
||||
collection.update_one({"_id": doc["_id"]}, {"$set": update})
|
||||
|
||||
_drop_all_indices_from_collections(
|
||||
db,
|
||||
[
|
||||
|
||||
43
server/schema/services/organization.conf
Normal file
43
server/schema/services/organization.conf
Normal file
@@ -0,0 +1,43 @@
|
||||
_description: "This service provides organization level operations"
|
||||
|
||||
get_tags {
|
||||
"2.8" {
|
||||
description: "Get all the user and system tags used for the company tasks and models"
|
||||
request {
|
||||
type: object
|
||||
properties {
|
||||
include_system {
|
||||
description: "If set to 'true' then the list of the system tags is also returned. The default value is 'false'"
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
filter {
|
||||
description: "Filter on entities to collect tags from"
|
||||
type: object
|
||||
properties {
|
||||
system_tags {
|
||||
description: "The list of system tag values to filter by. Use 'null' value to specify empty tags. Use '__Snot' value to specify that the following value should be excluded"
|
||||
type: array
|
||||
items {type: string}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
tags {
|
||||
description: "The list of unique tag values"
|
||||
type: array
|
||||
items {type: string}
|
||||
}
|
||||
system_tags {
|
||||
description: "The list of unique system tag values. Returned only if 'include_system' is set to 'true' in the request"
|
||||
type: array
|
||||
items {type: string}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -12,6 +12,7 @@ from apimodels.models import (
|
||||
PublishModelResponse,
|
||||
ModelTaskPublishResponse,
|
||||
)
|
||||
from bll.organization import OrgBLL
|
||||
from bll.task import TaskBLL
|
||||
from config import config
|
||||
from database.errors import translate_errors_context
|
||||
@@ -29,37 +30,34 @@ from services.utils import conform_tag_fields, conform_output_tags
|
||||
from timing_context import TimingContext
|
||||
|
||||
log = config.logger(__file__)
|
||||
org_bll = OrgBLL()
|
||||
|
||||
|
||||
@endpoint("models.get_by_id", required_fields=["model"])
|
||||
def get_by_id(call):
|
||||
assert isinstance(call, APICall)
|
||||
def get_by_id(call: APICall, company_id, _):
|
||||
model_id = call.data["model"]
|
||||
|
||||
with translate_errors_context():
|
||||
models = Model.get_many(
|
||||
company=call.identity.company,
|
||||
company=company_id,
|
||||
query_dict=call.data,
|
||||
query=Q(id=model_id),
|
||||
allow_public=True,
|
||||
)
|
||||
if not models:
|
||||
raise errors.bad_request.InvalidModelId(
|
||||
"no such public or company model",
|
||||
id=model_id,
|
||||
company=call.identity.company,
|
||||
"no such public or company model", id=model_id, company=company_id,
|
||||
)
|
||||
conform_output_tags(call, models[0])
|
||||
call.result.data = {"model": models[0]}
|
||||
|
||||
|
||||
@endpoint("models.get_by_task_id", required_fields=["task"])
|
||||
def get_by_task_id(call):
|
||||
assert isinstance(call, APICall)
|
||||
def get_by_task_id(call: APICall, company_id, _):
|
||||
task_id = call.data["task"]
|
||||
|
||||
with translate_errors_context():
|
||||
query = dict(id=task_id, company=call.identity.company)
|
||||
query = dict(id=task_id, company=company_id)
|
||||
task = Task.get(_only=["output"], **query)
|
||||
if not task:
|
||||
raise errors.bad_request.InvalidTaskId(**query)
|
||||
@@ -70,13 +68,11 @@ def get_by_task_id(call):
|
||||
|
||||
model_id = task.output.model
|
||||
model = Model.objects(
|
||||
Q(id=model_id) & get_company_or_none_constraint(call.identity.company)
|
||||
Q(id=model_id) & get_company_or_none_constraint(company_id)
|
||||
).first()
|
||||
if not model:
|
||||
raise errors.bad_request.InvalidModelId(
|
||||
"no such public or company model",
|
||||
id=model_id,
|
||||
company=call.identity.company,
|
||||
"no such public or company model", id=model_id, company=company_id,
|
||||
)
|
||||
model_dict = model.to_proper_dict()
|
||||
conform_output_tags(call, model_dict)
|
||||
@@ -84,24 +80,24 @@ def get_by_task_id(call):
|
||||
|
||||
|
||||
@endpoint("models.get_all_ex", required_fields=[])
|
||||
def get_all_ex(call: APICall):
|
||||
def get_all_ex(call: APICall, company_id, _):
|
||||
conform_tag_fields(call, call.data)
|
||||
with translate_errors_context():
|
||||
with TimingContext("mongo", "models_get_all_ex"):
|
||||
models = Model.get_many_with_join(
|
||||
company=call.identity.company, query_dict=call.data, allow_public=True
|
||||
company=company_id, query_dict=call.data, allow_public=True
|
||||
)
|
||||
conform_output_tags(call, models)
|
||||
call.result.data = {"models": models}
|
||||
|
||||
|
||||
@endpoint("models.get_all", required_fields=[])
|
||||
def get_all(call: APICall):
|
||||
def get_all(call: APICall, company_id, _):
|
||||
conform_tag_fields(call, call.data)
|
||||
with translate_errors_context():
|
||||
with TimingContext("mongo", "models_get_all"):
|
||||
models = Model.get_many(
|
||||
company=call.identity.company,
|
||||
company=company_id,
|
||||
parameters=call.data,
|
||||
query_dict=call.data,
|
||||
allow_public=True,
|
||||
@@ -128,13 +124,18 @@ create_fields = {
|
||||
|
||||
def parse_model_fields(call, valid_fields):
|
||||
fields = parse_from_call(call.data, valid_fields, Model.get_fields())
|
||||
conform_tag_fields(call, fields)
|
||||
conform_tag_fields(call, fields, validate=True)
|
||||
return fields
|
||||
|
||||
|
||||
def _update_org_tags(company, fields: dict):
|
||||
org_bll.update_org_tags(
|
||||
company, tags=fields.get("tags"), system_tags=fields.get("system_tags")
|
||||
)
|
||||
|
||||
|
||||
@endpoint("models.update_for_task", required_fields=["task"])
|
||||
def update_for_task(call, company_id, _):
|
||||
assert isinstance(call, APICall)
|
||||
def update_for_task(call: APICall, company_id, _):
|
||||
task_id = call.data["task"]
|
||||
uri = call.data.get("uri")
|
||||
iteration = call.data.get("iteration")
|
||||
@@ -177,7 +178,9 @@ def update_for_task(call, company_id, _):
|
||||
|
||||
if task.output and task.output.model:
|
||||
# model exists, update
|
||||
res = _update_model(call, model_id=task.output.model).to_struct()
|
||||
res = _update_model(
|
||||
call, company_id, model_id=task.output.model
|
||||
).to_struct()
|
||||
res.update({"id": task.output.model, "created": False})
|
||||
call.result.data = res
|
||||
return
|
||||
@@ -200,6 +203,7 @@ def update_for_task(call, company_id, _):
|
||||
**fields,
|
||||
)
|
||||
model.save()
|
||||
_update_org_tags(company_id, fields)
|
||||
|
||||
TaskBLL.update_statistics(
|
||||
task_id=task_id,
|
||||
@@ -216,48 +220,46 @@ def update_for_task(call, company_id, _):
|
||||
request_data_model=CreateModelRequest,
|
||||
response_data_model=CreateModelResponse,
|
||||
)
|
||||
def create(call, company, req_model):
|
||||
assert isinstance(call, APICall)
|
||||
assert isinstance(req_model, CreateModelRequest)
|
||||
identity = call.identity
|
||||
def create(call: APICall, company_id, req_model: CreateModelRequest):
|
||||
|
||||
if req_model.public:
|
||||
company = ""
|
||||
company_id = ""
|
||||
|
||||
with translate_errors_context():
|
||||
|
||||
project = req_model.project
|
||||
if project:
|
||||
validate_id(Project, company=company, project=project)
|
||||
validate_id(Project, company=company_id, project=project)
|
||||
|
||||
task = req_model.task
|
||||
req_data = req_model.to_struct()
|
||||
if task:
|
||||
validate_task(call, req_data)
|
||||
validate_task(company_id, req_data)
|
||||
|
||||
fields = filter_fields(Model, req_data)
|
||||
conform_tag_fields(call, fields)
|
||||
conform_tag_fields(call, fields, validate=True)
|
||||
|
||||
# create and save model
|
||||
model = Model(
|
||||
id=database.utils.id(),
|
||||
user=identity.user,
|
||||
company=company,
|
||||
user=call.identity.user,
|
||||
company=company_id,
|
||||
created=datetime.utcnow(),
|
||||
**fields,
|
||||
)
|
||||
model.save()
|
||||
_update_org_tags(company_id, fields)
|
||||
|
||||
call.result.data_model = CreateModelResponse(id=model.id, created=True)
|
||||
|
||||
|
||||
def prepare_update_fields(call, fields):
|
||||
def prepare_update_fields(call, company_id, fields: dict):
|
||||
fields = fields.copy()
|
||||
if "uri" in fields:
|
||||
# clear UI cache if URI is provided (model updated)
|
||||
fields["ui_cache"] = fields.pop("ui_cache", {})
|
||||
if "task" in fields:
|
||||
validate_task(call, fields)
|
||||
validate_task(company_id, fields)
|
||||
|
||||
if "labels" in fields:
|
||||
labels = fields["labels"]
|
||||
@@ -282,27 +284,26 @@ def prepare_update_fields(call, fields):
|
||||
"labels values must be integers", values=invalid_values
|
||||
)
|
||||
|
||||
conform_tag_fields(call, fields)
|
||||
conform_tag_fields(call, fields, validate=True)
|
||||
return fields
|
||||
|
||||
|
||||
def validate_task(call, fields):
|
||||
Task.get_for_writing(company=call.identity.company, id=fields["task"], _only=["id"])
|
||||
def validate_task(company_id, fields: dict):
|
||||
Task.get_for_writing(company=company_id, id=fields["task"], _only=["id"])
|
||||
|
||||
|
||||
@endpoint("models.edit", required_fields=["model"], response_data_model=UpdateResponse)
|
||||
def edit(call: APICall):
|
||||
identity = call.identity
|
||||
def edit(call: APICall, company_id, _):
|
||||
model_id = call.data["model"]
|
||||
|
||||
with translate_errors_context():
|
||||
query = dict(id=model_id, company=identity.company)
|
||||
query = dict(id=model_id, company=company_id)
|
||||
model = Model.objects(**query).first()
|
||||
if not model:
|
||||
raise errors.bad_request.InvalidModelId(**query)
|
||||
|
||||
fields = parse_model_fields(call, create_fields)
|
||||
fields = prepare_update_fields(call, fields)
|
||||
fields = prepare_update_fields(call, company_id, fields)
|
||||
|
||||
for key in fields:
|
||||
field = getattr(model, key, None)
|
||||
@@ -320,44 +321,41 @@ def edit(call: APICall):
|
||||
task_id = model.task or fields.get("task")
|
||||
if task_id and iteration is not None:
|
||||
TaskBLL.update_statistics(
|
||||
task_id=task_id,
|
||||
company_id=identity.company,
|
||||
last_iteration_max=iteration,
|
||||
task_id=task_id, company_id=company_id, last_iteration_max=iteration,
|
||||
)
|
||||
|
||||
if fields:
|
||||
updated = model.update(upsert=False, **fields)
|
||||
if updated:
|
||||
_update_org_tags(company_id, fields)
|
||||
conform_output_tags(call, fields)
|
||||
call.result.data_model = UpdateResponse(updated=updated, fields=fields)
|
||||
else:
|
||||
call.result.data_model = UpdateResponse(updated=0)
|
||||
|
||||
|
||||
def _update_model(call: APICall, model_id=None):
|
||||
identity = call.identity
|
||||
def _update_model(call: APICall, company_id, model_id=None):
|
||||
model_id = model_id or call.data["model"]
|
||||
|
||||
with translate_errors_context():
|
||||
# get model by id
|
||||
query = dict(id=model_id, company=identity.company)
|
||||
query = dict(id=model_id, company=company_id)
|
||||
model = Model.objects(**query).first()
|
||||
if not model:
|
||||
raise errors.bad_request.InvalidModelId(**query)
|
||||
|
||||
data = prepare_update_fields(call, call.data)
|
||||
data = prepare_update_fields(call, company_id, call.data)
|
||||
|
||||
task_id = data.get("task")
|
||||
iteration = data.get("iteration")
|
||||
if task_id and iteration is not None:
|
||||
TaskBLL.update_statistics(
|
||||
task_id=task_id,
|
||||
company_id=identity.company,
|
||||
last_iteration_max=iteration,
|
||||
task_id=task_id, company_id=company_id, last_iteration_max=iteration,
|
||||
)
|
||||
|
||||
updated_count, updated_fields = Model.safe_update(
|
||||
call.identity.company, model.id, data
|
||||
)
|
||||
updated_count, updated_fields = Model.safe_update(company_id, model.id, data)
|
||||
if updated_count:
|
||||
_update_org_tags(company_id, updated_fields)
|
||||
conform_output_tags(call, updated_fields)
|
||||
return UpdateResponse(updated=updated_count, fields=updated_fields)
|
||||
|
||||
@@ -365,8 +363,8 @@ def _update_model(call: APICall, model_id=None):
|
||||
@endpoint(
|
||||
"models.update", required_fields=["model"], response_data_model=UpdateResponse
|
||||
)
|
||||
def update(call):
|
||||
call.result.data_model = _update_model(call)
|
||||
def update(call, company_id, _):
|
||||
call.result.data_model = _update_model(call, company_id)
|
||||
|
||||
|
||||
@endpoint(
|
||||
@@ -374,10 +372,10 @@ def update(call):
|
||||
request_data_model=PublishModelRequest,
|
||||
response_data_model=PublishModelResponse,
|
||||
)
|
||||
def set_ready(call: APICall, company, req_model: PublishModelRequest):
|
||||
def set_ready(call: APICall, company_id, req_model: PublishModelRequest):
|
||||
updated, published_task_data = TaskBLL.model_set_ready(
|
||||
model_id=req_model.model,
|
||||
company_id=company,
|
||||
company_id=company_id,
|
||||
publish_task=req_model.publish_task,
|
||||
force_publish_task=req_model.force_publish_task,
|
||||
)
|
||||
@@ -391,14 +389,12 @@ def set_ready(call: APICall, company, req_model: PublishModelRequest):
|
||||
|
||||
|
||||
@endpoint("models.delete", required_fields=["model"])
|
||||
def update(call):
|
||||
assert isinstance(call, APICall)
|
||||
identity = call.identity
|
||||
def update(call: APICall, company_id, _):
|
||||
model_id = call.data["model"]
|
||||
force = call.data.get("force", False)
|
||||
|
||||
with translate_errors_context():
|
||||
query = dict(id=model_id, company=identity.company)
|
||||
query = dict(id=model_id, company=company_id)
|
||||
model = Model.objects(**query).only("id", "task").first()
|
||||
if not model:
|
||||
raise errors.bad_request.InvalidModelId(**query)
|
||||
@@ -431,4 +427,6 @@ def update(call):
|
||||
)
|
||||
|
||||
del_count = Model.objects(**query).delete()
|
||||
if del_count:
|
||||
org_bll.update_org_tags(company_id, reset=True)
|
||||
call.result.data = dict(deleted=del_count > 0)
|
||||
|
||||
13
server/services/organization.py
Normal file
13
server/services/organization.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from apimodels.organization import TagsRequest
|
||||
from bll.organization import OrgBLL
|
||||
from service_repo import endpoint, APICall
|
||||
|
||||
org_bll = OrgBLL()
|
||||
|
||||
|
||||
@endpoint("organization.get_tags", request_data_model=TagsRequest)
|
||||
def get_tags(call: APICall, company, request: TagsRequest):
|
||||
filter_ = request.filter.system_tags if request.filter else None
|
||||
call.result.data = org_bll.get_tags(
|
||||
company, include_system=request.include_system, filter_=filter_
|
||||
)
|
||||
@@ -155,10 +155,8 @@ def make_projects_get_all_pipelines(company_id, project_ids, specific_state=None
|
||||
{
|
||||
"$match": {
|
||||
"type": {"$in": ["training", "testing", "annotation"]},
|
||||
"project": {
|
||||
"company": {"$in": [None, "", company_id]},
|
||||
"$in": project_ids,
|
||||
},
|
||||
"project": {"$in": project_ids},
|
||||
"company": {"$in": [None, "", company_id]},
|
||||
}
|
||||
},
|
||||
ensure_valid_fields(),
|
||||
@@ -276,7 +274,7 @@ def create(call):
|
||||
|
||||
with translate_errors_context():
|
||||
fields = parse_from_call(call.data, create_fields, Project.get_fields())
|
||||
conform_tag_fields(call, fields)
|
||||
conform_tag_fields(call, fields, validate=True)
|
||||
now = datetime.utcnow()
|
||||
project = Project(
|
||||
id=database.utils.id(),
|
||||
@@ -313,7 +311,7 @@ def update(call: APICall):
|
||||
fields = parse_from_call(
|
||||
call.data, create_fields, Project.get_fields(), discard_none_values=False
|
||||
)
|
||||
conform_tag_fields(call, fields)
|
||||
conform_tag_fields(call, fields, validate=True)
|
||||
fields["last_update"] = datetime.utcnow()
|
||||
with TimingContext("mongo", "projects_update"):
|
||||
updated = project.update(upsert=False, **fields)
|
||||
|
||||
@@ -58,7 +58,9 @@ def get_all(call: APICall):
|
||||
|
||||
@endpoint("queues.create", min_version="2.4", request_data_model=CreateRequest)
|
||||
def create(call: APICall, company_id, request: CreateRequest):
|
||||
tags, system_tags = conform_tags(call, request.tags, request.system_tags)
|
||||
tags, system_tags = conform_tags(
|
||||
call, request.tags, request.system_tags, validate=True
|
||||
)
|
||||
queue = queue_bll.create(
|
||||
company_id=company_id, name=request.name, tags=tags, system_tags=system_tags
|
||||
)
|
||||
@@ -73,7 +75,7 @@ def create(call: APICall, company_id, request: CreateRequest):
|
||||
)
|
||||
def update(call: APICall, company_id, req_model: UpdateRequest):
|
||||
data = call.data_model_for_partial_update
|
||||
conform_tag_fields(call, data)
|
||||
conform_tag_fields(call, data, validate=True)
|
||||
updated, fields = queue_bll.update(
|
||||
company_id=company_id, queue_id=req_model.queue, **data
|
||||
)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from copy import deepcopy
|
||||
from datetime import datetime
|
||||
from operator import attrgetter
|
||||
from typing import Sequence, Callable, Type, TypeVar, Union
|
||||
from typing import Sequence, Callable, Type, TypeVar, Union, Tuple
|
||||
|
||||
import attr
|
||||
import dpath
|
||||
@@ -31,6 +31,7 @@ from apimodels.tasks import (
|
||||
AddOrUpdateArtifactsResponse,
|
||||
)
|
||||
from bll.event import EventBLL
|
||||
from bll.organization import OrgBLL
|
||||
from bll.queue import QueueBLL
|
||||
from bll.task import (
|
||||
TaskBLL,
|
||||
@@ -63,7 +64,7 @@ task_script_fields = set(get_fields(Script))
|
||||
task_bll = TaskBLL()
|
||||
event_bll = EventBLL()
|
||||
queue_bll = QueueBLL()
|
||||
|
||||
org_bll = OrgBLL()
|
||||
|
||||
NonResponsiveTasksWatchdog.start()
|
||||
|
||||
@@ -129,7 +130,7 @@ def escape_execution_parameters(call: APICall):
|
||||
|
||||
|
||||
@endpoint("tasks.get_all_ex", required_fields=[])
|
||||
def get_all_ex(call: APICall):
|
||||
def get_all_ex(call: APICall, company_id, _):
|
||||
conform_tag_fields(call, call.data)
|
||||
|
||||
escape_execution_parameters(call)
|
||||
@@ -137,7 +138,7 @@ def get_all_ex(call: APICall):
|
||||
with translate_errors_context():
|
||||
with TimingContext("mongo", "task_get_all_ex"):
|
||||
tasks = Task.get_many_with_join(
|
||||
company=call.identity.company,
|
||||
company=company_id,
|
||||
query_dict=call.data,
|
||||
allow_public=True, # required in case projection is requested for public dataset/versions
|
||||
)
|
||||
@@ -146,7 +147,7 @@ def get_all_ex(call: APICall):
|
||||
|
||||
|
||||
@endpoint("tasks.get_all", required_fields=[])
|
||||
def get_all(call: APICall):
|
||||
def get_all(call: APICall, company_id, _):
|
||||
conform_tag_fields(call, call.data)
|
||||
|
||||
escape_execution_parameters(call)
|
||||
@@ -154,7 +155,7 @@ def get_all(call: APICall):
|
||||
with translate_errors_context():
|
||||
with TimingContext("mongo", "task_get_all"):
|
||||
tasks = Task.get_many(
|
||||
company=call.identity.company,
|
||||
company=company_id,
|
||||
parameters=call.data,
|
||||
query_dict=call.data,
|
||||
allow_public=True, # required in case projection is requested for public dataset/versions
|
||||
@@ -255,7 +256,7 @@ create_fields = {
|
||||
|
||||
|
||||
def prepare_for_save(call: APICall, fields: dict):
|
||||
conform_tag_fields(call, fields)
|
||||
conform_tag_fields(call, fields, validate=True)
|
||||
|
||||
# Strip all script fields (remove leading and trailing whitespace chars) to avoid unusable names and paths
|
||||
for field in task_script_fields:
|
||||
@@ -315,7 +316,7 @@ def prepare_create_fields(
|
||||
return prepare_for_save(call, fields)
|
||||
|
||||
|
||||
def _validate_and_get_task_from_call(call: APICall, **kwargs):
|
||||
def _validate_and_get_task_from_call(call: APICall, **kwargs) -> Tuple[Task, dict]:
|
||||
with translate_errors_context(
|
||||
field_does_not_exist_cls=errors.bad_request.ValidationError
|
||||
), TimingContext("code", "parse_call"):
|
||||
@@ -325,7 +326,7 @@ def _validate_and_get_task_from_call(call: APICall, **kwargs):
|
||||
with TimingContext("code", "validate"):
|
||||
task_bll.validate(task)
|
||||
|
||||
return task
|
||||
return task, fields
|
||||
|
||||
|
||||
@endpoint("tasks.validate", request_data_model=CreateRequest)
|
||||
@@ -333,14 +334,21 @@ def validate(call: APICall, company_id, req_model: CreateRequest):
|
||||
_validate_and_get_task_from_call(call)
|
||||
|
||||
|
||||
def _update_org_tags(company, fields: dict):
|
||||
org_bll.update_org_tags(
|
||||
company, tags=fields.get("tags"), system_tags=fields.get("system_tags")
|
||||
)
|
||||
|
||||
|
||||
@endpoint(
|
||||
"tasks.create", request_data_model=CreateRequest, response_data_model=IdResponse
|
||||
)
|
||||
def create(call: APICall, company_id, req_model: CreateRequest):
|
||||
task = _validate_and_get_task_from_call(call)
|
||||
task, fields = _validate_and_get_task_from_call(call)
|
||||
|
||||
with translate_errors_context(), TimingContext("mongo", "save_task"):
|
||||
task.save()
|
||||
_update_org_tags(company_id, fields)
|
||||
update_project_time(task.project)
|
||||
|
||||
call.result.data_model = IdResponse(id=task.id)
|
||||
@@ -398,8 +406,9 @@ def update(call: APICall, company_id, req_model: UpdateRequest):
|
||||
partial_update_dict=partial_update_dict,
|
||||
injected_update=dict(last_update=datetime.utcnow()),
|
||||
)
|
||||
|
||||
update_project_time(updated_fields.get("project"))
|
||||
if updated_count:
|
||||
_update_org_tags(company_id, updated_fields)
|
||||
update_project_time(updated_fields.get("project"))
|
||||
unprepare_from_saved(call, updated_fields)
|
||||
return UpdateResponse(updated=updated_count, fields=updated_fields)
|
||||
|
||||
@@ -431,9 +440,7 @@ def set_requirements(call: APICall, company_id, req_model: SetRequirementsReques
|
||||
|
||||
|
||||
@endpoint("tasks.update_batch")
|
||||
def update_batch(call: APICall):
|
||||
identity = call.identity
|
||||
|
||||
def update_batch(call: APICall, company_id, _):
|
||||
items = call.batched_data
|
||||
if items is None:
|
||||
raise errors.bad_request.BatchContainsNoItems()
|
||||
@@ -443,7 +450,7 @@ def update_batch(call: APICall):
|
||||
tasks = {
|
||||
t.id: t
|
||||
for t in Task.get_many_for_writing(
|
||||
company=identity.company, query=Q(id__in=list(items))
|
||||
company=company_id, query=Q(id__in=list(items))
|
||||
)
|
||||
}
|
||||
|
||||
@@ -461,7 +468,7 @@ def update_batch(call: APICall):
|
||||
continue
|
||||
partial_update_dict.update(last_update=now)
|
||||
update_op = UpdateOne(
|
||||
{"_id": id, "company": identity.company}, {"$set": partial_update_dict}
|
||||
{"_id": id, "company": company_id}, {"$set": partial_update_dict}
|
||||
)
|
||||
bulk_ops.append(update_op)
|
||||
|
||||
@@ -469,7 +476,8 @@ def update_batch(call: APICall):
|
||||
if bulk_ops:
|
||||
res = Task._get_collection().bulk_write(bulk_ops)
|
||||
updated = res.modified_count
|
||||
|
||||
if updated:
|
||||
org_bll.update_org_tags(company_id, reset=True)
|
||||
call.result.data = {"updated": updated}
|
||||
|
||||
|
||||
@@ -524,7 +532,9 @@ def edit(call: APICall, company_id, req_model: UpdateRequest):
|
||||
fields.update(last_update=now)
|
||||
fixed_fields.update(last_update=now)
|
||||
updated = task.update(upsert=False, **fixed_fields)
|
||||
update_project_time(fields.get("project"))
|
||||
if updated:
|
||||
_update_org_tags(company_id, fixed_fields)
|
||||
update_project_time(fields.get("project"))
|
||||
unprepare_from_saved(call, fields)
|
||||
call.result.data_model = UpdateResponse(updated=updated, fields=fields)
|
||||
else:
|
||||
@@ -877,7 +887,7 @@ def delete(call: APICall, company_id, req_model: DeleteRequest):
|
||||
task.switch_collection(collection_name)
|
||||
|
||||
task.delete()
|
||||
|
||||
org_bll.update_org_tags(company_id, reset=True)
|
||||
call.result.data = dict(deleted=True, **attr.asdict(result))
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
from typing import Union, Sequence, Tuple
|
||||
|
||||
from apierrors import errors
|
||||
from database.model.base import GetMixin
|
||||
from database.utils import partition_tags
|
||||
from service_repo import APICall
|
||||
from service_repo.base import PartialVersion
|
||||
@@ -19,13 +21,13 @@ def conform_output_tags(call: APICall, documents: Union[dict, Sequence[dict]]):
|
||||
doc["tags"] = list(set(doc.get("tags", [])) | set(system_tags))
|
||||
|
||||
|
||||
def conform_tag_fields(call: APICall, document: dict):
|
||||
def conform_tag_fields(call: APICall, document: dict, validate=False):
|
||||
"""
|
||||
Upgrade old client tags in place
|
||||
"""
|
||||
if "tags" in document:
|
||||
tags, system_tags = conform_tags(
|
||||
call, document["tags"], document.get("system_tags")
|
||||
call, document["tags"], document.get("system_tags"), validate
|
||||
)
|
||||
if tags != document.get("tags"):
|
||||
document["tags"] = tags
|
||||
@@ -34,16 +36,18 @@ def conform_tag_fields(call: APICall, document: dict):
|
||||
|
||||
|
||||
def conform_tags(
|
||||
call: APICall, tags: Sequence, system_tags: Sequence
|
||||
call: APICall, tags: Sequence, system_tags: Sequence, validate=False
|
||||
) -> Tuple[Sequence, Sequence]:
|
||||
"""
|
||||
Make sure that 'tags' from the old SDK clients
|
||||
are correctly split into 'tags' and 'system_tags'
|
||||
Make sure that there are no duplicate tags
|
||||
"""
|
||||
if validate:
|
||||
validate_tags(tags, system_tags)
|
||||
if call.requested_endpoint_version < PartialVersion("2.3"):
|
||||
tags, system_tags = _upgrade_tags(call, tags, system_tags)
|
||||
return _get_unique_values(tags), _get_unique_values(system_tags)
|
||||
return tags, system_tags
|
||||
|
||||
|
||||
def _upgrade_tags(call: APICall, tags: Sequence, system_tags: Sequence):
|
||||
@@ -55,9 +59,12 @@ def _upgrade_tags(call: APICall, tags: Sequence, system_tags: Sequence):
|
||||
return tags, system_tags
|
||||
|
||||
|
||||
def _get_unique_values(values: Sequence) -> Sequence:
|
||||
"""Get unique values from the given sequence"""
|
||||
if not values:
|
||||
return values
|
||||
|
||||
return list(set(values))
|
||||
def validate_tags(tags: Sequence[str], system_tags: Sequence[str]):
|
||||
for values in filter(None, (tags, system_tags)):
|
||||
unsupported = [
|
||||
t for t in values if t.startswith(GetMixin.ListFieldBucketHelper.op_prefix)
|
||||
]
|
||||
if unsupported:
|
||||
raise errors.bad_request.FieldsValueError(
|
||||
"unsupported tag prefix", values=unsupported
|
||||
)
|
||||
|
||||
36
server/tests/automated/test_organization.py
Normal file
36
server/tests/automated/test_organization.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from tests.automated import TestService
|
||||
|
||||
|
||||
class TestOrganization(TestService):
|
||||
def setUp(self, version="2.8"):
|
||||
super().setUp(version=version)
|
||||
|
||||
def test_tags(self):
|
||||
tag1 = "Orgtest tag1"
|
||||
tag2 = "Orgtest tag2"
|
||||
system_tag = "Orgtest system tag"
|
||||
|
||||
model = self.create_temp(
|
||||
"models", name="test_org", uri="file:///a", tags=[tag1]
|
||||
)
|
||||
task = self.create_temp(
|
||||
"tasks", name="test org", type="training", input=dict(view={}), tags=[tag1]
|
||||
)
|
||||
data = self.api.organization.get_tags()
|
||||
self.assertTrue(tag1 in data.tags)
|
||||
|
||||
self.api.tasks.edit(task=task, tags=[tag2], system_tags=[system_tag])
|
||||
data = self.api.organization.get_tags(include_system=True)
|
||||
self.assertTrue({tag1, tag2}.issubset(set(data.tags)))
|
||||
self.assertTrue(system_tag in data.system_tags)
|
||||
|
||||
data = self.api.organization.get_tags(
|
||||
filter={"system_tags": ["__$not", system_tag]}
|
||||
)
|
||||
self.assertTrue(tag1 in data.tags)
|
||||
self.assertFalse(tag2 in data.tags)
|
||||
|
||||
self.api.models.delete(model=model)
|
||||
data = self.api.organization.get_tags()
|
||||
self.assertFalse(tag1 in data.tags)
|
||||
self.assertTrue(tag2 in data.tags)
|
||||
Reference in New Issue
Block a user