Add organization.get_tags to obtain the set of all used task, model, queue and project tags

This commit is contained in:
allegroai
2020-06-01 13:00:35 +03:00
parent bf7f0f646b
commit c85ab66ae6
19 changed files with 417 additions and 132 deletions

View File

@@ -0,0 +1,10 @@
from jsonmodels import fields, models
class Filter(models.Base):
    # Optional entity filter for tag collection: only entities whose
    # system_tags match these values contribute tags. Values follow the
    # list-field filter syntax (None for empty, "__$not" operator to exclude).
    system_tags = fields.ListField([str])
class TagsRequest(models.Base):
    # Request model for organization.get_tags
    # When True the response also includes the company's system tags
    include_system = fields.BoolField(default=False)
    # Optional filter restricting which entities' tags are collected
    filter = fields.EmbeddedField(Filter)

View File

@@ -0,0 +1,85 @@
from typing import Sequence
from mongoengine import Q
from config import config
from database.model.base import GetMixin
from database.model.model import Model
from database.model.task.task import Task
from redis_manager import redman
from utilities import json
log = config.logger(__file__)
class OrgBLL:
    """
    Organization-level business logic.

    Aggregates the sets of used "tags" and "system_tags" over the company's
    tasks and models, caching the per-(field, company, filter) results in
    Redis with a configurable TTL.
    """

    _tags_field = "tags"
    _system_tags_field = "system_tags"
    _settings_prefix = "services.organization"

    def __init__(self, redis=None):
        # Fall back to the shared apiserver Redis connection when none is injected
        self.redis = redis or redman.connection("apiserver")

    @property
    def _tags_cache_expiration_seconds(self):
        """Cache TTL (seconds) for the tag sets; configurable, defaults to 1 hour."""
        return config.get(
            f"{self._settings_prefix}.tags_cache.expiration_seconds", 3600
        )

    @staticmethod
    def _get_tags_cache_key(company, field: str, filter_: Sequence[str] = None):
        """Build the Redis key for a (field, company, filter) combination."""
        filter_str = "_".join(filter_) if filter_ else ""
        return f"{field}_{company}_{filter_str}"

    @staticmethod
    def _get_tags_from_db(company, field, filter_: Sequence[str] = None) -> set:
        """
        Collect the distinct values of `field` over the company's tasks and
        models, optionally filtered by the entities' system_tags.
        """
        query = Q(company=company)
        if filter_:
            query &= GetMixin.get_list_field_query("system_tags", filter_)
        tags = set()
        for cls_ in (Task, Model):
            tags |= set(cls_.objects(query).distinct(field))
        return tags

    def get_tags(
        self, company, include_system: bool = False, filter_: Sequence[str] = None
    ) -> dict:
        """
        Get tags and optionally system tags for the company.
        Return the dictionary of tags per tags field name.
        Cached values are retrieved from Redis in a single mget call;
        any missing entry is recalculated from the DB and written back
        with the configured expiration.
        """
        fields = [
            self._tags_field,
            *([self._system_tags_field] if include_system else []),
        ]
        redis_keys = [self._get_tags_cache_key(company, f, filter_) for f in fields]
        cached = self.redis.mget(redis_keys)
        ret = {}
        for field, tag_data, key in zip(fields, cached, redis_keys):
            if tag_data is not None:
                tags = json.loads(tag_data)
            else:
                tags = list(self._get_tags_from_db(company, field, filter_))
                self.redis.setex(
                    key,
                    time=self._tags_cache_expiration_seconds,
                    value=json.dumps(tags),
                )
            ret[field] = tags
        return ret

    def _delete_tags_cache(self, company, field: str):
        """
        Remove every cached entry of `field` for the company, regardless of the
        filter it was cached under. The previous implementation deleted only the
        unfiltered key, leaving filtered get_tags results stale until TTL expiry.
        NOTE(review): the glob pattern assumes company ids and field names do not
        contain Redis glob metacharacters — confirm against id generation.
        """
        prefix = self._get_tags_cache_key(company, field)
        keys = list(self.redis.scan_iter(f"{prefix}*"))
        if keys:
            self.redis.delete(*keys)

    def update_org_tags(self, company, tags=None, system_tags=None, reset=False):
        """
        Invalidate the cached tag sets. If reset is set then both tags and
        system_tags caches are dropped. Otherwise only those that are not 'None'.
        """
        if reset or tags is not None:
            self._delete_tags_cache(company, self._tags_field)
        if reset or system_tags is not None:
            self._delete_tags_cache(company, self._system_tags_field)

View File

@@ -14,6 +14,7 @@ import database.utils as dbutils
import es_factory
from apierrors import errors
from apimodels.tasks import Artifact as ApiArtifact
from bll.organization import OrgBLL
from config import config
from database.errors import translate_errors_context
from database.model.model import Model
@@ -30,11 +31,13 @@ from database.model.task.task import (
)
from database.utils import get_company_or_none_constraint, id as create_id
from service_repo import APICall
from services.utils import validate_tags
from timing_context import TimingContext
from utilities.dicts import deep_merge
from .utils import ChangeStatusRequest, validate_status_change, ParameterKeyEscaper
log = config.logger(__file__)
org_bll = OrgBLL()
class TaskBLL(object):
@@ -166,6 +169,7 @@ class TaskBLL(object):
execution_overrides: Optional[dict] = None,
validate_references: bool = False,
) -> Task:
validate_tags(tags, system_tags)
task = cls.get_by_id(company_id=company_id, task_id=task_id)
execution_dict = task.execution.to_proper_dict() if task.execution else {}
execution_model_overriden = False
@@ -212,7 +216,7 @@ class TaskBLL(object):
validate_project=validate_references or project,
)
new_task.save()
org_bll.update_org_tags(company_id, tags=tags, system_tags=system_tags)
return new_task
@classmethod
@@ -344,7 +348,7 @@ class TaskBLL(object):
metric_stats = {
dbutils.hash_field_name(metric_key): MetricEventStats(
metric=metric_key, event_stats_by_type=events_per_type(metric_data),
metric=metric_key, event_stats_by_type=events_per_type(metric_data)
)
for metric_key, metric_data in last_events.items()
}

View File

@@ -33,8 +33,8 @@ log = config.logger(__file__)
class WorkerBLL:
def __init__(self, es=None, redis=None):
self.es_client = es if es is not None else es_factory.connect("workers")
self.redis = redis if redis is not None else redman.connection("workers")
self.es_client = es or es_factory.connect("workers")
self.redis = redis or redman.connection("workers")
self._stats = WorkerStats(self.es_client)
@property

View File

@@ -0,0 +1,3 @@
tags_cache {
expiration_seconds: 3600
}

View File

@@ -3,7 +3,7 @@ from collections import namedtuple
from functools import reduce
from typing import Collection, Sequence, Union, Optional
from boltons.iterutils import first
from boltons.iterutils import first, bucketize
from dateutil.parser import parse as parse_datetime
from mongoengine import Q, Document, ListField, StringField
from pymongo.command_cursor import CommandCursor
@@ -98,6 +98,37 @@ class GetMixin(PropsMixin):
self.list_fields = list_fields
self.pattern_fields = pattern_fields
class ListFieldBucketHelper:
    """
    Stateful key/value callbacks for bucketizing a list-field filter value
    sequence into mongoengine operator buckets ("in" / "nin").

    The sequence is scanned in order: a value of the form "__$<op>" is an
    operator token that applies to the *following* value only; a leading "-"
    (legacy mode) excludes just that single value; any other value falls into
    the currently pending operator bucket (default "in").
    """

    op_prefix = "__$"
    legacy_exclude_prefix = "-"
    _default = "in"
    _ops = {"not": "nin"}
    # Operator to apply to the next plain value; reset to the default after use
    _next = _default

    def __init__(self, legacy=False):
        # legacy=True enables the pre-2.8 "-value" exclusion syntax
        self._legacy = legacy

    def key(self, v):
        """Return the bucket key for v (None for operator tokens)."""
        if v is None:
            # None stands for "empty value"; bucket it under the default op so
            # the caller can detect it and also match empty/missing fields
            self._next = self._default
            return self._default
        elif self._legacy and v.startswith(self.legacy_exclude_prefix):
            # Legacy "-value": exclude this single value
            self._next = self._default
            return self._ops["not"]
        elif v.startswith(self.op_prefix):
            # Operator token: remember the mapped op for the next value; the
            # token itself goes into the None bucket, ignored by the caller
            self._next = self._ops.get(v[len(self.op_prefix) :], self._default)
            return None
        next_ = self._next
        self._next = self._default
        return next_

    def value_transform(self, v):
        """Strip the legacy '-' prefix so the bucketed value is the bare tag."""
        if self._legacy and v and v.startswith(self.legacy_exclude_prefix):
            return v[len(self.legacy_exclude_prefix) :]
        return v
get_all_query_options = QueryParameterOptions()
@classmethod
@@ -175,17 +206,7 @@ class GetMixin(PropsMixin):
for field in tuple(opts.list_fields or ()):
data = parameters.pop(field, None)
if data:
if not isinstance(data, (list, tuple)):
raise MakeGetAllQueryError("expected list", field)
exclude = [t for t in data if t.startswith("-")]
include = list(set(data).difference(exclude))
mongoengine_field = field.replace(".", "__")
if include:
dict_query[f"{mongoengine_field}__in"] = include
if exclude:
dict_query[f"{mongoengine_field}__nin"] = [
t[1:] for t in exclude
]
query &= cls.get_list_field_query(field, data)
for field in opts.fields or []:
data = parameters.pop(field, None)
@@ -229,6 +250,47 @@ class GetMixin(PropsMixin):
return query & RegexQ(**dict_query)
@classmethod
def get_list_field_query(cls, field: str, data: Sequence[Optional[str]]) -> Q:
    """
    Get a proper mongoengine Q object that represents an "or" query for the provided values
    with respect to the given list field, with support for "none of empty" in case a None value
    is included.
    - Exclusion can be specified by a leading "-" for each value (API versions <2.8)
      or by a preceding "__$not" value (operator)
    """
    if not isinstance(data, (list, tuple)):
        raise MakeGetAllQueryError("expected list", field)

    # TODO: backwards compatibility only for older API versions
    helper = cls.ListFieldBucketHelper(legacy=True)
    actions = bucketize(
        data, key=helper.key, value_transform=helper.value_transform
    )

    # A None value in the "in" bucket means empty/missing fields should match too
    allow_empty = None in actions.get("in", {})

    # mongoengine uses "__" to address nested fields
    mongoengine_field = field.replace(".", "__")

    q = RegexQ()
    # Skip the None bucket (it holds the operator tokens themselves); AND one
    # condition per operator, dropping None markers inside each bucket
    for action in filter(None, actions):
        q &= RegexQ(
            **{
                f"{mongoengine_field}__{action}": list(
                    set(filter(None, actions[action]))
                )
            }
        )

    if not allow_empty:
        return q

    # Also match documents where the list field is missing or an empty list
    return (
        q
        | Q(**{f"{mongoengine_field}__exists": False})
        | Q(**{mongoengine_field: []})
    )
@classmethod
def _prepare_perm_query(cls, company, allow_public=False):
if allow_public:

View File

@@ -1,7 +1,7 @@
from mongoengine import Document, StringField, DateTimeField, ListField, BooleanField
from mongoengine import Document, StringField, DateTimeField, BooleanField
from database import Database, strict
from database.fields import StrippedStringField, SafeDictField
from database.fields import StrippedStringField, SafeDictField, SafeSortedListField
from database.model import DbModelMixin
from database.model.base import GetMixin
from database.model.model_labels import ModelLabels
@@ -61,8 +61,8 @@ class Model(DbModelMixin, Document):
created = DateTimeField(required=True, user_set_allowed=True)
task = StringField(reference_field=Task)
comment = StringField(user_set_allowed=True)
tags = ListField(StringField(required=True), user_set_allowed=True)
system_tags = ListField(StringField(required=True), user_set_allowed=True)
tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
system_tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
uri = StrippedStringField(default="", user_set_allowed=True)
framework = StringField()
design = SafeDictField()

View File

@@ -1,7 +1,7 @@
from mongoengine import StringField, DateTimeField, ListField
from mongoengine import StringField, DateTimeField
from database import Database, strict
from database.fields import StrippedStringField
from database.fields import StrippedStringField, SafeSortedListField
from database.model import AttributedDocument
from database.model.base import GetMixin
@@ -36,7 +36,7 @@ class Project(AttributedDocument):
)
description = StringField(required=True)
created = DateTimeField(required=True)
tags = ListField(StringField(required=True))
system_tags = ListField(StringField(required=True))
tags = SafeSortedListField(StringField(required=True))
system_tags = SafeSortedListField(StringField(required=True))
default_output_destination = StrippedStringField()
last_update = DateTimeField()

View File

@@ -4,11 +4,10 @@ from mongoengine import (
StringField,
DateTimeField,
EmbeddedDocumentListField,
ListField,
)
from database import Database, strict
from database.fields import StrippedStringField
from database.fields import StrippedStringField, SafeSortedListField
from database.model import DbModelMixin
from database.model.base import ProperDictMixin, GetMixin
from database.model.company import Company
@@ -41,7 +40,7 @@ class Queue(DbModelMixin, Document):
)
company = StringField(required=True, reference_field=Company)
created = DateTimeField(required=True)
tags = ListField(StringField(required=True), default=list, user_set_allowed=True)
system_tags = ListField(StringField(required=True), user_set_allowed=True)
tags = SafeSortedListField(StringField(required=True), default=list, user_set_allowed=True)
system_tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
entries = EmbeddedDocumentListField(Entry, default=list)
last_update = DateTimeField()

View File

@@ -172,8 +172,8 @@ class Task(AttributedDocument):
project = StringField(reference_field=Project, user_set_allowed=True)
output = EmbeddedDocumentField(Output, default=Output)
execution: Execution = EmbeddedDocumentField(Execution, default=Execution)
tags = ListField(StringField(required=True), user_set_allowed=True)
system_tags = ListField(StringField(required=True), user_set_allowed=True)
tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
system_tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
script = EmbeddedDocumentField(Script)
last_worker = StringField()
last_worker_report = DateTimeField()

View File

@@ -23,10 +23,25 @@ def migrate_auth(db: Database):
def migrate_backend(db: Database):
"""
Remove the old indices from the collections since
they may come out of sync with the latest changes
in the code and mongo libraries update
1. Sort tags and system tags
2. Remove the old indices from the collections since
they may come out of sync with the latest changes
in the code and mongo libraries update
"""
fields = ("tags", "system_tags")
query = {"$or": [{field: {"$exists": True, "$ne": []}} for field in fields]}
for collection_name in ("task", "model", "project", "queue"):
collection = db[collection_name]
for doc in collection.find(filter=query, projection=fields):
update = {
field: sorted(doc[field])
for field in fields
if doc.get(field)
}
if update:
collection.update_one({"_id": doc["_id"]}, {"$set": update})
_drop_all_indices_from_collections(
db,
[

View File

@@ -0,0 +1,43 @@
_description: "This service provides organization level operations"
get_tags {
"2.8" {
description: "Get all the user and system tags used for the company tasks and models"
request {
type: object
properties {
include_system {
description: "If set to 'true' then the list of the system tags is also returned. The default value is 'false'"
type: boolean
default: false
}
filter {
description: "Filter on entities to collect tags from"
type: object
properties {
system_tags {
description: "The list of system tag values to filter by. Use 'null' value to specify empty tags. Use '__$not' value to specify that the following value should be excluded"
type: array
items {type: string}
}
}
}
}
}
response {
type: object
properties {
tags {
description: "The list of unique tag values"
type: array
items {type: string}
}
system_tags {
description: "The list of unique system tag values. Returned only if 'include_system' is set to 'true' in the request"
type: array
items {type: string}
}
}
}
}
}

View File

@@ -12,6 +12,7 @@ from apimodels.models import (
PublishModelResponse,
ModelTaskPublishResponse,
)
from bll.organization import OrgBLL
from bll.task import TaskBLL
from config import config
from database.errors import translate_errors_context
@@ -29,37 +30,34 @@ from services.utils import conform_tag_fields, conform_output_tags
from timing_context import TimingContext
log = config.logger(__file__)
org_bll = OrgBLL()
@endpoint("models.get_by_id", required_fields=["model"])
def get_by_id(call):
assert isinstance(call, APICall)
def get_by_id(call: APICall, company_id, _):
model_id = call.data["model"]
with translate_errors_context():
models = Model.get_many(
company=call.identity.company,
company=company_id,
query_dict=call.data,
query=Q(id=model_id),
allow_public=True,
)
if not models:
raise errors.bad_request.InvalidModelId(
"no such public or company model",
id=model_id,
company=call.identity.company,
"no such public or company model", id=model_id, company=company_id,
)
conform_output_tags(call, models[0])
call.result.data = {"model": models[0]}
@endpoint("models.get_by_task_id", required_fields=["task"])
def get_by_task_id(call):
assert isinstance(call, APICall)
def get_by_task_id(call: APICall, company_id, _):
task_id = call.data["task"]
with translate_errors_context():
query = dict(id=task_id, company=call.identity.company)
query = dict(id=task_id, company=company_id)
task = Task.get(_only=["output"], **query)
if not task:
raise errors.bad_request.InvalidTaskId(**query)
@@ -70,13 +68,11 @@ def get_by_task_id(call):
model_id = task.output.model
model = Model.objects(
Q(id=model_id) & get_company_or_none_constraint(call.identity.company)
Q(id=model_id) & get_company_or_none_constraint(company_id)
).first()
if not model:
raise errors.bad_request.InvalidModelId(
"no such public or company model",
id=model_id,
company=call.identity.company,
"no such public or company model", id=model_id, company=company_id,
)
model_dict = model.to_proper_dict()
conform_output_tags(call, model_dict)
@@ -84,24 +80,24 @@ def get_by_task_id(call):
@endpoint("models.get_all_ex", required_fields=[])
def get_all_ex(call: APICall):
def get_all_ex(call: APICall, company_id, _):
conform_tag_fields(call, call.data)
with translate_errors_context():
with TimingContext("mongo", "models_get_all_ex"):
models = Model.get_many_with_join(
company=call.identity.company, query_dict=call.data, allow_public=True
company=company_id, query_dict=call.data, allow_public=True
)
conform_output_tags(call, models)
call.result.data = {"models": models}
@endpoint("models.get_all", required_fields=[])
def get_all(call: APICall):
def get_all(call: APICall, company_id, _):
conform_tag_fields(call, call.data)
with translate_errors_context():
with TimingContext("mongo", "models_get_all"):
models = Model.get_many(
company=call.identity.company,
company=company_id,
parameters=call.data,
query_dict=call.data,
allow_public=True,
@@ -128,13 +124,18 @@ create_fields = {
def parse_model_fields(call, valid_fields):
fields = parse_from_call(call.data, valid_fields, Model.get_fields())
conform_tag_fields(call, fields)
conform_tag_fields(call, fields, validate=True)
return fields
def _update_org_tags(company, fields: dict):
org_bll.update_org_tags(
company, tags=fields.get("tags"), system_tags=fields.get("system_tags")
)
@endpoint("models.update_for_task", required_fields=["task"])
def update_for_task(call, company_id, _):
assert isinstance(call, APICall)
def update_for_task(call: APICall, company_id, _):
task_id = call.data["task"]
uri = call.data.get("uri")
iteration = call.data.get("iteration")
@@ -177,7 +178,9 @@ def update_for_task(call, company_id, _):
if task.output and task.output.model:
# model exists, update
res = _update_model(call, model_id=task.output.model).to_struct()
res = _update_model(
call, company_id, model_id=task.output.model
).to_struct()
res.update({"id": task.output.model, "created": False})
call.result.data = res
return
@@ -200,6 +203,7 @@ def update_for_task(call, company_id, _):
**fields,
)
model.save()
_update_org_tags(company_id, fields)
TaskBLL.update_statistics(
task_id=task_id,
@@ -216,48 +220,46 @@ def update_for_task(call, company_id, _):
request_data_model=CreateModelRequest,
response_data_model=CreateModelResponse,
)
def create(call, company, req_model):
assert isinstance(call, APICall)
assert isinstance(req_model, CreateModelRequest)
identity = call.identity
def create(call: APICall, company_id, req_model: CreateModelRequest):
if req_model.public:
company = ""
company_id = ""
with translate_errors_context():
project = req_model.project
if project:
validate_id(Project, company=company, project=project)
validate_id(Project, company=company_id, project=project)
task = req_model.task
req_data = req_model.to_struct()
if task:
validate_task(call, req_data)
validate_task(company_id, req_data)
fields = filter_fields(Model, req_data)
conform_tag_fields(call, fields)
conform_tag_fields(call, fields, validate=True)
# create and save model
model = Model(
id=database.utils.id(),
user=identity.user,
company=company,
user=call.identity.user,
company=company_id,
created=datetime.utcnow(),
**fields,
)
model.save()
_update_org_tags(company_id, fields)
call.result.data_model = CreateModelResponse(id=model.id, created=True)
def prepare_update_fields(call, fields):
def prepare_update_fields(call, company_id, fields: dict):
fields = fields.copy()
if "uri" in fields:
# clear UI cache if URI is provided (model updated)
fields["ui_cache"] = fields.pop("ui_cache", {})
if "task" in fields:
validate_task(call, fields)
validate_task(company_id, fields)
if "labels" in fields:
labels = fields["labels"]
@@ -282,27 +284,26 @@ def prepare_update_fields(call, fields):
"labels values must be integers", values=invalid_values
)
conform_tag_fields(call, fields)
conform_tag_fields(call, fields, validate=True)
return fields
def validate_task(call, fields):
Task.get_for_writing(company=call.identity.company, id=fields["task"], _only=["id"])
def validate_task(company_id, fields: dict):
Task.get_for_writing(company=company_id, id=fields["task"], _only=["id"])
@endpoint("models.edit", required_fields=["model"], response_data_model=UpdateResponse)
def edit(call: APICall):
identity = call.identity
def edit(call: APICall, company_id, _):
model_id = call.data["model"]
with translate_errors_context():
query = dict(id=model_id, company=identity.company)
query = dict(id=model_id, company=company_id)
model = Model.objects(**query).first()
if not model:
raise errors.bad_request.InvalidModelId(**query)
fields = parse_model_fields(call, create_fields)
fields = prepare_update_fields(call, fields)
fields = prepare_update_fields(call, company_id, fields)
for key in fields:
field = getattr(model, key, None)
@@ -320,44 +321,41 @@ def edit(call: APICall):
task_id = model.task or fields.get("task")
if task_id and iteration is not None:
TaskBLL.update_statistics(
task_id=task_id,
company_id=identity.company,
last_iteration_max=iteration,
task_id=task_id, company_id=company_id, last_iteration_max=iteration,
)
if fields:
updated = model.update(upsert=False, **fields)
if updated:
_update_org_tags(company_id, fields)
conform_output_tags(call, fields)
call.result.data_model = UpdateResponse(updated=updated, fields=fields)
else:
call.result.data_model = UpdateResponse(updated=0)
def _update_model(call: APICall, model_id=None):
identity = call.identity
def _update_model(call: APICall, company_id, model_id=None):
model_id = model_id or call.data["model"]
with translate_errors_context():
# get model by id
query = dict(id=model_id, company=identity.company)
query = dict(id=model_id, company=company_id)
model = Model.objects(**query).first()
if not model:
raise errors.bad_request.InvalidModelId(**query)
data = prepare_update_fields(call, call.data)
data = prepare_update_fields(call, company_id, call.data)
task_id = data.get("task")
iteration = data.get("iteration")
if task_id and iteration is not None:
TaskBLL.update_statistics(
task_id=task_id,
company_id=identity.company,
last_iteration_max=iteration,
task_id=task_id, company_id=company_id, last_iteration_max=iteration,
)
updated_count, updated_fields = Model.safe_update(
call.identity.company, model.id, data
)
updated_count, updated_fields = Model.safe_update(company_id, model.id, data)
if updated_count:
_update_org_tags(company_id, updated_fields)
conform_output_tags(call, updated_fields)
return UpdateResponse(updated=updated_count, fields=updated_fields)
@@ -365,8 +363,8 @@ def _update_model(call: APICall, model_id=None):
@endpoint(
"models.update", required_fields=["model"], response_data_model=UpdateResponse
)
def update(call):
call.result.data_model = _update_model(call)
def update(call, company_id, _):
call.result.data_model = _update_model(call, company_id)
@endpoint(
@@ -374,10 +372,10 @@ def update(call):
request_data_model=PublishModelRequest,
response_data_model=PublishModelResponse,
)
def set_ready(call: APICall, company, req_model: PublishModelRequest):
def set_ready(call: APICall, company_id, req_model: PublishModelRequest):
updated, published_task_data = TaskBLL.model_set_ready(
model_id=req_model.model,
company_id=company,
company_id=company_id,
publish_task=req_model.publish_task,
force_publish_task=req_model.force_publish_task,
)
@@ -391,14 +389,12 @@ def set_ready(call: APICall, company, req_model: PublishModelRequest):
@endpoint("models.delete", required_fields=["model"])
def update(call):
assert isinstance(call, APICall)
identity = call.identity
def update(call: APICall, company_id, _):
model_id = call.data["model"]
force = call.data.get("force", False)
with translate_errors_context():
query = dict(id=model_id, company=identity.company)
query = dict(id=model_id, company=company_id)
model = Model.objects(**query).only("id", "task").first()
if not model:
raise errors.bad_request.InvalidModelId(**query)
@@ -431,4 +427,6 @@ def update(call):
)
del_count = Model.objects(**query).delete()
if del_count:
org_bll.update_org_tags(company_id, reset=True)
call.result.data = dict(deleted=del_count > 0)

View File

@@ -0,0 +1,13 @@
from apimodels.organization import TagsRequest
from bll.organization import OrgBLL
from service_repo import endpoint, APICall
org_bll = OrgBLL()
@endpoint("organization.get_tags", request_data_model=TagsRequest)
def get_tags(call: APICall, company, request: TagsRequest):
    """Return the unique tags (and, on request, system tags) used by the company."""
    system_tags_filter = None
    if request.filter:
        system_tags_filter = request.filter.system_tags
    call.result.data = org_bll.get_tags(
        company,
        include_system=request.include_system,
        filter_=system_tags_filter,
    )

View File

@@ -155,10 +155,8 @@ def make_projects_get_all_pipelines(company_id, project_ids, specific_state=None
{
"$match": {
"type": {"$in": ["training", "testing", "annotation"]},
"project": {
"company": {"$in": [None, "", company_id]},
"$in": project_ids,
},
"project": {"$in": project_ids},
"company": {"$in": [None, "", company_id]},
}
},
ensure_valid_fields(),
@@ -276,7 +274,7 @@ def create(call):
with translate_errors_context():
fields = parse_from_call(call.data, create_fields, Project.get_fields())
conform_tag_fields(call, fields)
conform_tag_fields(call, fields, validate=True)
now = datetime.utcnow()
project = Project(
id=database.utils.id(),
@@ -313,7 +311,7 @@ def update(call: APICall):
fields = parse_from_call(
call.data, create_fields, Project.get_fields(), discard_none_values=False
)
conform_tag_fields(call, fields)
conform_tag_fields(call, fields, validate=True)
fields["last_update"] = datetime.utcnow()
with TimingContext("mongo", "projects_update"):
updated = project.update(upsert=False, **fields)

View File

@@ -58,7 +58,9 @@ def get_all(call: APICall):
@endpoint("queues.create", min_version="2.4", request_data_model=CreateRequest)
def create(call: APICall, company_id, request: CreateRequest):
tags, system_tags = conform_tags(call, request.tags, request.system_tags)
tags, system_tags = conform_tags(
call, request.tags, request.system_tags, validate=True
)
queue = queue_bll.create(
company_id=company_id, name=request.name, tags=tags, system_tags=system_tags
)
@@ -73,7 +75,7 @@ def create(call: APICall, company_id, request: CreateRequest):
)
def update(call: APICall, company_id, req_model: UpdateRequest):
data = call.data_model_for_partial_update
conform_tag_fields(call, data)
conform_tag_fields(call, data, validate=True)
updated, fields = queue_bll.update(
company_id=company_id, queue_id=req_model.queue, **data
)

View File

@@ -1,7 +1,7 @@
from copy import deepcopy
from datetime import datetime
from operator import attrgetter
from typing import Sequence, Callable, Type, TypeVar, Union
from typing import Sequence, Callable, Type, TypeVar, Union, Tuple
import attr
import dpath
@@ -31,6 +31,7 @@ from apimodels.tasks import (
AddOrUpdateArtifactsResponse,
)
from bll.event import EventBLL
from bll.organization import OrgBLL
from bll.queue import QueueBLL
from bll.task import (
TaskBLL,
@@ -63,7 +64,7 @@ task_script_fields = set(get_fields(Script))
task_bll = TaskBLL()
event_bll = EventBLL()
queue_bll = QueueBLL()
org_bll = OrgBLL()
NonResponsiveTasksWatchdog.start()
@@ -129,7 +130,7 @@ def escape_execution_parameters(call: APICall):
@endpoint("tasks.get_all_ex", required_fields=[])
def get_all_ex(call: APICall):
def get_all_ex(call: APICall, company_id, _):
conform_tag_fields(call, call.data)
escape_execution_parameters(call)
@@ -137,7 +138,7 @@ def get_all_ex(call: APICall):
with translate_errors_context():
with TimingContext("mongo", "task_get_all_ex"):
tasks = Task.get_many_with_join(
company=call.identity.company,
company=company_id,
query_dict=call.data,
allow_public=True, # required in case projection is requested for public dataset/versions
)
@@ -146,7 +147,7 @@ def get_all_ex(call: APICall):
@endpoint("tasks.get_all", required_fields=[])
def get_all(call: APICall):
def get_all(call: APICall, company_id, _):
conform_tag_fields(call, call.data)
escape_execution_parameters(call)
@@ -154,7 +155,7 @@ def get_all(call: APICall):
with translate_errors_context():
with TimingContext("mongo", "task_get_all"):
tasks = Task.get_many(
company=call.identity.company,
company=company_id,
parameters=call.data,
query_dict=call.data,
allow_public=True, # required in case projection is requested for public dataset/versions
@@ -255,7 +256,7 @@ create_fields = {
def prepare_for_save(call: APICall, fields: dict):
conform_tag_fields(call, fields)
conform_tag_fields(call, fields, validate=True)
# Strip all script fields (remove leading and trailing whitespace chars) to avoid unusable names and paths
for field in task_script_fields:
@@ -315,7 +316,7 @@ def prepare_create_fields(
return prepare_for_save(call, fields)
def _validate_and_get_task_from_call(call: APICall, **kwargs):
def _validate_and_get_task_from_call(call: APICall, **kwargs) -> Tuple[Task, dict]:
with translate_errors_context(
field_does_not_exist_cls=errors.bad_request.ValidationError
), TimingContext("code", "parse_call"):
@@ -325,7 +326,7 @@ def _validate_and_get_task_from_call(call: APICall, **kwargs):
with TimingContext("code", "validate"):
task_bll.validate(task)
return task
return task, fields
@endpoint("tasks.validate", request_data_model=CreateRequest)
@@ -333,14 +334,21 @@ def validate(call: APICall, company_id, req_model: CreateRequest):
_validate_and_get_task_from_call(call)
def _update_org_tags(company, fields: dict):
org_bll.update_org_tags(
company, tags=fields.get("tags"), system_tags=fields.get("system_tags")
)
@endpoint(
"tasks.create", request_data_model=CreateRequest, response_data_model=IdResponse
)
def create(call: APICall, company_id, req_model: CreateRequest):
task = _validate_and_get_task_from_call(call)
task, fields = _validate_and_get_task_from_call(call)
with translate_errors_context(), TimingContext("mongo", "save_task"):
task.save()
_update_org_tags(company_id, fields)
update_project_time(task.project)
call.result.data_model = IdResponse(id=task.id)
@@ -398,8 +406,9 @@ def update(call: APICall, company_id, req_model: UpdateRequest):
partial_update_dict=partial_update_dict,
injected_update=dict(last_update=datetime.utcnow()),
)
update_project_time(updated_fields.get("project"))
if updated_count:
_update_org_tags(company_id, updated_fields)
update_project_time(updated_fields.get("project"))
unprepare_from_saved(call, updated_fields)
return UpdateResponse(updated=updated_count, fields=updated_fields)
@@ -431,9 +440,7 @@ def set_requirements(call: APICall, company_id, req_model: SetRequirementsReques
@endpoint("tasks.update_batch")
def update_batch(call: APICall):
identity = call.identity
def update_batch(call: APICall, company_id, _):
items = call.batched_data
if items is None:
raise errors.bad_request.BatchContainsNoItems()
@@ -443,7 +450,7 @@ def update_batch(call: APICall):
tasks = {
t.id: t
for t in Task.get_many_for_writing(
company=identity.company, query=Q(id__in=list(items))
company=company_id, query=Q(id__in=list(items))
)
}
@@ -461,7 +468,7 @@ def update_batch(call: APICall):
continue
partial_update_dict.update(last_update=now)
update_op = UpdateOne(
{"_id": id, "company": identity.company}, {"$set": partial_update_dict}
{"_id": id, "company": company_id}, {"$set": partial_update_dict}
)
bulk_ops.append(update_op)
@@ -469,7 +476,8 @@ def update_batch(call: APICall):
if bulk_ops:
res = Task._get_collection().bulk_write(bulk_ops)
updated = res.modified_count
if updated:
org_bll.update_org_tags(company_id, reset=True)
call.result.data = {"updated": updated}
@@ -524,7 +532,9 @@ def edit(call: APICall, company_id, req_model: UpdateRequest):
fields.update(last_update=now)
fixed_fields.update(last_update=now)
updated = task.update(upsert=False, **fixed_fields)
update_project_time(fields.get("project"))
if updated:
_update_org_tags(company_id, fixed_fields)
update_project_time(fields.get("project"))
unprepare_from_saved(call, fields)
call.result.data_model = UpdateResponse(updated=updated, fields=fields)
else:
@@ -877,7 +887,7 @@ def delete(call: APICall, company_id, req_model: DeleteRequest):
task.switch_collection(collection_name)
task.delete()
org_bll.update_org_tags(company_id, reset=True)
call.result.data = dict(deleted=True, **attr.asdict(result))

View File

@@ -1,5 +1,7 @@
from typing import Union, Sequence, Tuple
from apierrors import errors
from database.model.base import GetMixin
from database.utils import partition_tags
from service_repo import APICall
from service_repo.base import PartialVersion
@@ -19,13 +21,13 @@ def conform_output_tags(call: APICall, documents: Union[dict, Sequence[dict]]):
doc["tags"] = list(set(doc.get("tags", [])) | set(system_tags))
def conform_tag_fields(call: APICall, document: dict):
def conform_tag_fields(call: APICall, document: dict, validate=False):
"""
Upgrade old client tags in place
"""
if "tags" in document:
tags, system_tags = conform_tags(
call, document["tags"], document.get("system_tags")
call, document["tags"], document.get("system_tags"), validate
)
if tags != document.get("tags"):
document["tags"] = tags
@@ -34,16 +36,18 @@ def conform_tag_fields(call: APICall, document: dict):
def conform_tags(
call: APICall, tags: Sequence, system_tags: Sequence
call: APICall, tags: Sequence, system_tags: Sequence, validate=False
) -> Tuple[Sequence, Sequence]:
"""
Make sure that 'tags' from the old SDK clients
are correctly split into 'tags' and 'system_tags'
Make sure that there are no duplicate tags
"""
if validate:
validate_tags(tags, system_tags)
if call.requested_endpoint_version < PartialVersion("2.3"):
tags, system_tags = _upgrade_tags(call, tags, system_tags)
return _get_unique_values(tags), _get_unique_values(system_tags)
return tags, system_tags
def _upgrade_tags(call: APICall, tags: Sequence, system_tags: Sequence):
@@ -55,9 +59,12 @@ def _upgrade_tags(call: APICall, tags: Sequence, system_tags: Sequence):
return tags, system_tags
def _get_unique_values(values: Sequence) -> Sequence:
"""Get unique values from the given sequence"""
if not values:
return values
return list(set(values))
def validate_tags(tags: Sequence[str], system_tags: Sequence[str]):
for values in filter(None, (tags, system_tags)):
unsupported = [
t for t in values if t.startswith(GetMixin.ListFieldBucketHelper.op_prefix)
]
if unsupported:
raise errors.bad_request.FieldsValueError(
"unsupported tag prefix", values=unsupported
)

View File

@@ -0,0 +1,36 @@
from tests.automated import TestService
class TestOrganization(TestService):
    # End-to-end test for the organization.get_tags endpoint

    def setUp(self, version="2.8"):
        # organization.get_tags was introduced in API version 2.8
        super().setUp(version=version)

    def test_tags(self):
        tag1 = "Orgtest tag1"
        tag2 = "Orgtest tag2"
        system_tag = "Orgtest system tag"

        # Create a model and a task that both carry tag1
        model = self.create_temp(
            "models", name="test_org", uri="file:///a", tags=[tag1]
        )
        task = self.create_temp(
            "tasks", name="test org", type="training", input=dict(view={}), tags=[tag1]
        )
        # tag1 should now be reported among the company's used tags
        data = self.api.organization.get_tags()
        self.assertTrue(tag1 in data.tags)

        # Editing the task replaces its tags with tag2 and adds a system tag;
        # the tag cache must be refreshed accordingly
        self.api.tasks.edit(task=task, tags=[tag2], system_tags=[system_tag])
        data = self.api.organization.get_tags(include_system=True)
        self.assertTrue({tag1, tag2}.issubset(set(data.tags)))
        self.assertTrue(system_tag in data.system_tags)

        # Excluding entities carrying the system tag ("__$not" operator) should
        # drop the task's tag2 but keep the model's tag1
        data = self.api.organization.get_tags(
            filter={"system_tags": ["__$not", system_tag]}
        )
        self.assertTrue(tag1 in data.tags)
        self.assertFalse(tag2 in data.tags)

        # Deleting the model (the only remaining holder of tag1) must
        # invalidate the cache so tag1 disappears while tag2 remains
        self.api.models.delete(model=model)
        data = self.api.organization.get_tags()
        self.assertFalse(tag1 in data.tags)
        self.assertTrue(tag2 in data.tags)