2019-06-10 21:24:35 +00:00
|
|
|
from datetime import datetime
|
2021-05-03 14:52:54 +00:00
|
|
|
from functools import partial
|
2021-05-03 15:07:37 +00:00
|
|
|
from typing import Sequence
|
2019-06-10 21:24:35 +00:00
|
|
|
|
|
|
|
from mongoengine import Q, EmbeddedDocument
|
|
|
|
|
2021-01-05 14:28:49 +00:00
|
|
|
from apiserver import database
|
|
|
|
from apiserver.apierrors import errors
|
|
|
|
from apiserver.apierrors.errors.bad_request import InvalidModelId
|
2021-01-05 16:05:44 +00:00
|
|
|
from apiserver.apimodels.base import UpdateResponse, MakePublicRequest, MoveRequest
|
2021-05-03 15:04:17 +00:00
|
|
|
from apiserver.apimodels.batch import BatchResponse, BatchRequest
|
2021-01-05 14:28:49 +00:00
|
|
|
from apiserver.apimodels.models import (
|
2019-06-10 21:24:35 +00:00
|
|
|
CreateModelRequest,
|
|
|
|
CreateModelResponse,
|
|
|
|
PublishModelRequest,
|
|
|
|
PublishModelResponse,
|
2020-07-06 18:50:43 +00:00
|
|
|
GetFrameworksRequest,
|
2021-05-03 14:38:09 +00:00
|
|
|
DeleteModelRequest,
|
2021-05-03 14:50:25 +00:00
|
|
|
DeleteMetadataRequest,
|
|
|
|
AddOrUpdateMetadataRequest,
|
2021-05-03 14:52:54 +00:00
|
|
|
ModelsPublishManyRequest,
|
|
|
|
ModelsDeleteManyRequest,
|
2019-06-10 21:24:35 +00:00
|
|
|
)
|
2021-05-03 14:52:54 +00:00
|
|
|
from apiserver.bll.model import ModelBLL
|
2021-01-05 14:28:49 +00:00
|
|
|
from apiserver.bll.organization import OrgBLL, Tags
|
2021-05-03 14:44:54 +00:00
|
|
|
from apiserver.bll.project import ProjectBLL, project_ids_with_children
|
2021-01-05 14:28:49 +00:00
|
|
|
from apiserver.bll.task import TaskBLL
|
2021-05-03 14:52:54 +00:00
|
|
|
from apiserver.bll.task.task_operations import publish_task
|
|
|
|
from apiserver.bll.util import run_batch_operation
|
2021-01-05 14:44:31 +00:00
|
|
|
from apiserver.config_repo import config
|
2021-01-05 14:28:49 +00:00
|
|
|
from apiserver.database.errors import translate_errors_context
|
|
|
|
from apiserver.database.model import validate_id
|
2021-05-03 14:50:25 +00:00
|
|
|
from apiserver.database.model.metadata import metadata_add_or_update, metadata_delete
|
2021-01-05 14:28:49 +00:00
|
|
|
from apiserver.database.model.model import Model
|
|
|
|
from apiserver.database.model.project import Project
|
2021-05-03 14:56:50 +00:00
|
|
|
from apiserver.database.model.task.task import (
|
|
|
|
Task,
|
|
|
|
TaskStatus,
|
|
|
|
ModelItem,
|
|
|
|
TaskModelNames,
|
|
|
|
TaskModelTypes,
|
|
|
|
)
|
2021-01-05 14:28:49 +00:00
|
|
|
from apiserver.database.utils import (
|
2019-06-10 21:24:35 +00:00
|
|
|
parse_from_call,
|
|
|
|
get_company_or_none_constraint,
|
|
|
|
filter_fields,
|
|
|
|
)
|
2021-01-05 14:28:49 +00:00
|
|
|
from apiserver.service_repo import APICall, endpoint
|
2021-05-03 14:50:25 +00:00
|
|
|
from apiserver.services.utils import (
|
|
|
|
conform_tag_fields,
|
|
|
|
conform_output_tags,
|
|
|
|
ModelsBackwardsCompatibility,
|
|
|
|
validate_metadata,
|
|
|
|
get_metadata_from_api,
|
|
|
|
)
|
2021-01-05 14:28:49 +00:00
|
|
|
from apiserver.timing_context import TimingContext
|
2019-06-10 21:24:35 +00:00
|
|
|
|
|
|
|
# Module logger plus the shared organization/project business-logic helpers
# used by the handlers in this module.
log = config.logger(__file__)
org_bll, project_bll = OrgBLL(), ProjectBLL()
|
2019-06-10 21:24:35 +00:00
|
|
|
|
|
|
|
|
|
|
|
@endpoint("models.get_by_id", required_fields=["model"])
def get_by_id(call: APICall, company_id, _):
    """
    Return a single company-owned or public model by its ID.

    The ID is read from call.data["model"]; the whole call.data is also passed to
    Model.get_many as its query dict.

    :raises errors.bad_request.InvalidModelId: if no visible model matches
    """
    requested_id = call.data["model"]

    with translate_errors_context():
        found = Model.get_many(
            company=company_id,
            query_dict=call.data,
            query=Q(id=requested_id),
            allow_public=True,
        )
        if not found:
            raise errors.bad_request.InvalidModelId(
                "no such public or company model", id=requested_id, company=company_id,
            )

        first_model = found[0]
        conform_output_tags(call, first_model)
        call.result.data = {"model": first_model}
|
2019-06-10 21:24:35 +00:00
|
|
|
|
|
|
|
|
|
|
|
@endpoint("models.get_by_task_id", required_fields=["task"])
def get_by_task_id(call: APICall, company_id, _):
    """
    Return the output model of a task (legacy endpoint).

    Only served for endpoint versions up to ModelsBackwardsCompatibility.max_version;
    newer clients get a NotSupported error pointing at models.get_by_id/get_all.
    The last entry of the task's output models list is the one returned.

    :raises errors.bad_request.InvalidTaskId: if the task is not found for the company
    :raises errors.bad_request.MissingTaskFields: if the task has no output models
    :raises errors.bad_request.InvalidModelId: if the referenced model is not visible
    """
    if call.requested_endpoint_version > ModelsBackwardsCompatibility.max_version:
        raise errors.moved_permanently.NotSupported("use models.get_by_id/get_all apis")

    task_id = call.data["task"]

    with translate_errors_context():
        query = dict(id=task_id, company=company_id)
        task = Task.get(_only=["models"], **query)
        if not task:
            raise errors.bad_request.InvalidTaskId(**query)

        outputs = task.models.output if task.models else None
        if not outputs:
            raise errors.bad_request.MissingTaskFields(field="models.output")

        model_id = outputs[-1].model
        model = Model.objects(
            Q(id=model_id) & get_company_or_none_constraint(company_id)
        ).first()
        if not model:
            raise errors.bad_request.InvalidModelId(
                "no such public or company model", id=model_id, company=company_id,
            )

        result = model.to_proper_dict()
        conform_output_tags(call, result)
        call.result.data = {"model": result}
|
2019-06-10 21:24:35 +00:00
|
|
|
|
|
|
|
|
2021-05-03 14:44:54 +00:00
|
|
|
def _process_include_subprojects(call_data: dict):
    """
    Optionally expand the "project" filter of call_data with descendant projects.

    The "include_subprojects" key is always removed from call_data. When that flag
    was truthy and a non-empty "project" filter is present, call_data["project"] is
    replaced in place by project_ids_with_children(...) over the requested IDs.
    A single (non-list) project ID is wrapped in a one-item list first.
    """
    wants_children = call_data.pop("include_subprojects", False)
    requested = call_data.get("project")
    if requested and wants_children:
        ids = requested if isinstance(requested, list) else [requested]
        call_data["project"] = project_ids_with_children(ids)
|
|
|
|
|
|
|
|
|
2019-06-10 21:24:35 +00:00
|
|
|
@endpoint("models.get_all_ex", required_fields=[])
def get_all_ex(call: APICall, company_id, _):
    """
    List company and public models matching call.data using Model.get_many_with_join.

    Tag filter fields are conformed first, and an "include_subprojects" flag in the
    request expands the "project" filter to descendant projects before querying.
    """
    conform_tag_fields(call, call.data)

    with translate_errors_context():
        _process_include_subprojects(call.data)

        with TimingContext("mongo", "models_get_all_ex"):
            found = Model.get_many_with_join(
                company=company_id, query_dict=call.data, allow_public=True
            )

        conform_output_tags(call, found)
        call.result.data = {"models": found}
|
|
|
|
|
|
|
|
|
2021-01-05 15:44:59 +00:00
|
|
|
@endpoint("models.get_by_id_ex", required_fields=["id"])
def get_by_id_ex(call: APICall, company_id, _):
    """
    Return company and public models selected by call.data (which must contain "id").

    The query is executed through Model.get_many_with_join with call.data as the
    query dict; tag fields are conformed on the way in and on the way out.
    """
    conform_tag_fields(call, call.data)

    with translate_errors_context():
        with TimingContext("mongo", "models_get_by_id_ex"):
            found = Model.get_many_with_join(
                company=company_id, query_dict=call.data, allow_public=True
            )

        conform_output_tags(call, found)
        call.result.data = {"models": found}
|
|
|
|
|
|
|
|
|
2019-06-10 21:24:35 +00:00
|
|
|
@endpoint("models.get_all", required_fields=[])
def get_all(call: APICall, company_id, _):
    """
    List company and public models matching call.data.

    call.data is passed to Model.get_many both as the query parameters and as the
    query dict.
    """
    conform_tag_fields(call, call.data)

    with translate_errors_context():
        with TimingContext("mongo", "models_get_all"):
            found = Model.get_many(
                company=company_id,
                parameters=call.data,
                query_dict=call.data,
                allow_public=True,
            )

        conform_output_tags(call, found)
        call.result.data = {"models": found}
|
|
|
|
|
|
|
|
|
2020-07-06 18:50:43 +00:00
|
|
|
@endpoint("models.get_frameworks", request_data_model=GetFrameworksRequest)
def get_frameworks(call: APICall, company_id, request: GetFrameworksRequest):
    """
    Return the model frameworks reported by ProjectBLL.get_model_frameworks for the
    company (optionally restricted to request.projects), sorted.
    """
    frameworks = project_bll.get_model_frameworks(
        company_id, project_ids=request.projects
    )
    call.result.data = {"frameworks": sorted(frameworks)}
|
|
|
|
|
|
|
|
|
2019-06-10 21:24:35 +00:00
|
|
|
# Fields accepted from the client when creating or editing a model. Each value is
# handed to parse_from_call as the field's expected kind: None, a builtin
# container type, or a referenced document class (the exact handling is defined
# by parse_from_call).
create_fields = dict(
    name=None,
    tags=list,
    system_tags=list,
    task=Task,
    comment=None,
    uri=None,
    project=Project,
    parent=Model,
    framework=None,
    design=None,
    labels=dict,
    ready=None,
    metadata=list,
)

# Updating any of these fields also refreshes the model's "last_update" timestamp.
last_update_fields = (
    "uri",
    "framework",
    "design",
    "labels",
    "ready",
    "metadata",
)
|
|
|
|
|
2019-06-10 21:24:35 +00:00
|
|
|
|
|
|
|
def parse_model_fields(call, valid_fields):
    """
    Extract the allowed model fields from call.data and validate them.

    :param call: the API call whose data is parsed
    :param valid_fields: mapping of accepted field names (such as create_fields)
    :return: dict of parsed fields; tag fields are conformed/validated and a
        non-empty "metadata" list is validated
    """
    parsed = parse_from_call(call.data, valid_fields, Model.get_fields())
    conform_tag_fields(call, parsed, validate=True)
    if parsed.get("metadata"):
        validate_metadata(parsed["metadata"])
    return parsed
|
|
|
|
|
|
|
|
|
2020-06-21 20:54:05 +00:00
|
|
|
def _update_cached_tags(company: str, project: str, fields: dict):
    """
    Forward the "tags" and "system_tags" found in ``fields`` (either may be absent)
    to OrgBLL.update_tags for the model tags of ``project``.
    """
    tags = fields.get("tags")
    system_tags = fields.get("system_tags")
    org_bll.update_tags(
        company, Tags.Model, project=project, tags=tags, system_tags=system_tags,
    )
|
|
|
|
|
|
|
|
|
|
|
|
def _reset_cached_tags(company: str, projects: Sequence[str]):
    """Ask OrgBLL to reset the cached model tags for the given projects."""
    org_bll.reset_tags(company, Tags.Model, projects=projects)
|
|
|
|
|
|
|
|
|
2019-06-10 21:24:35 +00:00
|
|
|
@endpoint("models.update_for_task", required_fields=["task"])
def update_for_task(call: APICall, company_id, _):
    """
    Legacy endpoint: attach, update or create the output model of a task.

    Exactly one of "uri" or "override_model_id" must be present in call.data.
    - override_model_id: an existing company model is used as the task output.
    - uri: if the task already has output models, the last one is updated with
      call.data and the update result plus {"id": ..., "created": False} is
      returned; otherwise a new model is built from the task's execution info.
    The task must be in the created or in_progress state. Except for the
    "update existing output model" path, the chosen model is passed as
    models__output (and "iteration" as last_iteration_max) to
    TaskBLL.update_statistics.
    Only served for endpoint versions up to ModelsBackwardsCompatibility.max_version.
    """
    # newer API versions must use tasks.add_or_update_model instead
    if call.requested_endpoint_version > ModelsBackwardsCompatibility.max_version:
        raise errors.moved_permanently.NotSupported("use tasks.add_or_update_model")

    task_id = call.data["task"]
    uri = call.data.get("uri")
    iteration = call.data.get("iteration")
    override_model_id = call.data.get("override_model_id")
    # exactly one of uri / override_model_id is required (neither or both is an error)
    if not (uri or override_model_id) or (uri and override_model_id):
        raise errors.bad_request.MissingRequiredFields(
            "exactly one field is required", fields=("uri", "override_model_id")
        )

    with translate_errors_context():

        query = dict(id=task_id, company=company_id)
        task = Task.get_for_writing(
            id=task_id,
            company=company_id,
            _only=["models", "execution", "name", "status", "project"],
        )
        if not task:
            raise errors.bad_request.InvalidTaskId(**query)

        allowed_states = [TaskStatus.created, TaskStatus.in_progress]
        if task.status not in allowed_states:
            raise errors.bad_request.InvalidTaskStatus(
                f"model can only be updated for tasks in the {allowed_states} states",
                **query,
            )

        if override_model_id:
            # use an existing company model as the task output
            model = ModelBLL.get_company_model_by_id(
                company_id=company_id, model_id=override_model_id
            )
        else:
            if "name" not in call.data:
                # use task name if name not provided
                call.data["name"] = task.name

            if "comment" not in call.data:
                call.data["comment"] = f"Created by task `{task.name}` ({task.id})"

            if task.models and task.models.output:
                # model exists, update
                # (the last entry of the task's output models is the one updated)
                model_id = task.models.output[-1].model
                res = _update_model(call, company_id, model_id=model_id).to_struct()
                res.update({"id": model_id, "created": False})
                call.result.data = res
                return

            # new model, create
            fields = parse_model_fields(call, create_fields)

            # create and save model
            now = datetime.utcnow()
            model = Model(
                id=database.utils.id(),
                created=now,
                last_update=now,
                user=call.identity.user,
                company=company_id,
                project=task.project,
                framework=task.execution.framework,
                # the first input model of the task (if any) becomes the parent
                parent=task.models.input[0].model
                if task.models and task.models.input
                else None,
                design=task.execution.model_desc,
                labels=task.execution.model_labels,
                # NOTE(review): the status check above only admits created/in_progress
                # tasks, so this always evaluates to False - confirm intended
                ready=(task.status == TaskStatus.published),
                **fields,
            )
            model.save()
            _update_cached_tags(company_id, project=model.project, fields=fields)

        # pass the chosen (overriding or newly created) model as the task's output
        TaskBLL.update_statistics(
            task_id=task_id,
            company_id=company_id,
            last_iteration_max=iteration,
            models__output=[
                ModelItem(
                    model=model.id,
                    name=TaskModelNames[TaskModelTypes.output],
                    updated=datetime.utcnow(),
                )
            ],
        )

        # NOTE(review): "created" is True even when override_model_id attached an
        # existing model - confirm this is the intended legacy response
        call.result.data = {"id": model.id, "created": True}
|
|
|
|
|
|
|
|
|
|
|
|
@endpoint(
    "models.create",
    request_data_model=CreateModelRequest,
    response_data_model=CreateModelResponse,
)
def create(call: APICall, company_id, req_model: CreateModelRequest):
    """
    Create and save a new model document.

    Public models are stored with an empty company ID. A given project is
    validated with validate_id and a given task with validate_task (both against
    the owning company) before the model is created.
    """
    owner_company = "" if req_model.public else company_id

    with translate_errors_context():
        if req_model.project:
            validate_id(Project, company=owner_company, project=req_model.project)

        req_data = req_model.to_struct()
        if req_model.task:
            validate_task(owner_company, req_data)

        fields = filter_fields(Model, req_data)
        conform_tag_fields(call, fields, validate=True)
        validate_metadata(fields.get("metadata"))

        # create and save model
        now = datetime.utcnow()
        model = Model(
            id=database.utils.id(),
            user=call.identity.user,
            company=owner_company,
            created=now,
            last_update=now,
            **fields,
        )
        model.save()
        _update_cached_tags(owner_company, project=model.project, fields=fields)

        call.result.data_model = CreateModelResponse(id=model.id, created=True)
|
|
|
|
|
|
|
|
|
2020-06-01 10:00:35 +00:00
|
|
|
def _find_other_types(iterable, type_):
    """
    Return the items of ``iterable`` that are not instances of ``type_``.

    The result is a set when all offending items are hashable, otherwise a list.
    """
    res = [x for x in iterable if not isinstance(x, type_)]
    try:
        return set(res)
    except TypeError:
        # un-hashable items (e.g. lists or dicts) cannot be put in a set
        return res


def prepare_update_fields(call, company_id, fields: dict):
    """
    Validate and normalize a dict of model fields before an update.

    Works on a shallow copy of ``fields``; the caller's dict is not modified.
    - "uri" present: "ui_cache" is set to the provided value, or reset to {}.
    - "task" present: the task is checked with validate_task.
    - "labels" present: must be a dict with string keys and integer values.
    - tag fields are conformed and validated.

    :raises errors.bad_request.ValidationError: on malformed labels
    :return: the normalized copy of ``fields``
    """
    fields = fields.copy()
    if "uri" in fields:
        # clear UI cache if URI is provided (model updated)
        fields["ui_cache"] = fields.pop("ui_cache", {})
    if "task" in fields:
        validate_task(company_id, fields)

    if "labels" in fields:
        labels = fields["labels"]
        if not isinstance(labels, dict):
            # without this check a non-dict (or null) value fails on labels.keys()
            # with an unhandled AttributeError instead of a validation error
            raise errors.bad_request.ValidationError("labels must be a dictionary")

        invalid_keys = _find_other_types(labels.keys(), str)
        if invalid_keys:
            raise errors.bad_request.ValidationError(
                "labels keys must be strings", keys=invalid_keys
            )

        invalid_values = _find_other_types(labels.values(), int)
        if invalid_values:
            raise errors.bad_request.ValidationError(
                "labels values must be integers", values=invalid_values
            )

    conform_tag_fields(call, fields, validate=True)
    return fields
|
|
|
|
|
|
|
|
|
2020-06-01 10:00:35 +00:00
|
|
|
def validate_task(company_id, fields: dict):
    """
    Check that the task referenced by fields["task"] can be fetched with
    Task.get_for_writing for the given company.

    Any error for a missing or inaccessible task is raised by that call.
    """
    task_id = fields["task"]
    Task.get_for_writing(id=task_id, company=company_id, _only=["id"])
|
2019-06-10 21:24:35 +00:00
|
|
|
|
|
|
|
|
|
|
|
@endpoint("models.edit", required_fields=["model"], response_data_model=UpdateResponse)
def edit(call: APICall, company_id, _):
    """
    Edit a company model with the create-style fields found in call.data.

    Dict values for fields that currently hold an embedded document are merged
    into the existing value instead of replacing it. If "iteration" is given and
    the model has (or is being given) a task, the task's last iteration is passed
    to TaskBLL.update_statistics. After a successful update the cached tags are
    updated, or reset for both projects when the model moves between projects.
    """
    model_id = call.data["model"]

    with translate_errors_context():
        model = ModelBLL.get_company_model_by_id(
            company_id=company_id, model_id=model_id
        )

        fields = prepare_update_fields(
            call, company_id, parse_model_fields(call, create_fields)
        )

        # merge partial dict updates into existing embedded documents
        for name, new_value in fields.items():
            current = getattr(model, name, None)
            if not current:
                continue
            if isinstance(new_value, dict) and isinstance(current, EmbeddedDocument):
                merged = current.to_mongo(use_db_field=False).to_dict()
                merged.update(new_value)
                fields[name] = merged

        iteration = call.data.get("iteration")
        task_id = model.task or fields.get("task")
        if task_id and iteration is not None:
            TaskBLL.update_statistics(
                task_id=task_id, company_id=company_id, last_iteration_max=iteration,
            )

        if not fields:
            call.result.data_model = UpdateResponse(updated=0)
            return

        if any(name in fields for name in last_update_fields):
            fields.update(last_update=datetime.utcnow())

        updated = model.update(upsert=False, **fields)
        if updated:
            old_project = model.project
            new_project = fields.get("project", old_project)
            if new_project == old_project:
                _update_cached_tags(company_id, project=old_project, fields=fields)
            else:
                _reset_cached_tags(company_id, projects=[new_project, old_project])

        conform_output_tags(call, fields)
        call.result.data_model = UpdateResponse(updated=updated, fields=fields)
|
|
|
|
|
|
|
|
|
2020-06-01 10:00:35 +00:00
|
|
|
def _update_model(call: APICall, company_id, model_id=None):
    """
    Apply the fields in call.data to an existing company model.

    :param call: API call whose data holds the fields to update
    :param company_id: company owning the model
    :param model_id: model to update; defaults to call.data["model"]
    :return: UpdateResponse with the updated document count and the applied fields
    """
    model_id = model_id or call.data["model"]

    with translate_errors_context():
        model = ModelBLL.get_company_model_by_id(
            company_id=company_id, model_id=model_id
        )

        data = prepare_update_fields(call, company_id, call.data)

        # Validate all input before any side effect, so that a request with
        # invalid metadata does not update the task statistics below.
        metadata = data.get("metadata")
        if metadata:
            validate_metadata(metadata)

        task_id = data.get("task")
        iteration = data.get("iteration")
        if task_id and iteration is not None:
            TaskBLL.update_statistics(
                task_id=task_id, company_id=company_id, last_iteration_max=iteration,
            )

        updated_count, updated_fields = Model.safe_update(company_id, model.id, data)
        if updated_count:
            if any(uf in updated_fields for uf in last_update_fields):
                model.update(upsert=False, last_update=datetime.utcnow())

            # model.project still holds the pre-update value here
            new_project = updated_fields.get("project", model.project)
            if new_project != model.project:
                _reset_cached_tags(company_id, projects=[new_project, model.project])
            else:
                _update_cached_tags(
                    company_id, project=model.project, fields=updated_fields
                )

        conform_output_tags(call, updated_fields)
        return UpdateResponse(updated=updated_count, fields=updated_fields)
|
|
|
|
|
|
|
|
|
|
|
|
@endpoint(
    "models.update", required_fields=["model"], response_data_model=UpdateResponse
)
def update(call, company_id, _):
    """Update the model given by call.data["model"] with the remaining call data."""
    response = _update_model(call, company_id)
    call.result.data_model = response
|
2019-06-10 21:24:35 +00:00
|
|
|
|
|
|
|
|
|
|
|
@endpoint(
    "models.set_ready",
    request_data_model=PublishModelRequest,
    response_data_model=PublishModelResponse,
)
def set_ready(call: APICall, company_id: str, request: PublishModelRequest):
    """
    Publish (mark ready) a single model via ModelBLL.publish_model.

    When request.publish_task is set, publish_task is handed to publish_model as
    the task-publishing callback; force_publish_task is forwarded as-is.
    """
    task_publisher = publish_task if request.publish_task else None
    updated, published_task = ModelBLL.publish_model(
        model_id=request.model,
        company_id=company_id,
        force_publish_task=request.force_publish_task,
        publish_task_func=task_publisher,
    )
    call.result.data_model = PublishModelResponse(
        updated=updated, published_task=published_task
    )
|
|
|
|
|
|
|
|
|
|
|
|
@endpoint(
    "models.publish_many",
    request_data_model=ModelsPublishManyRequest,
    response_data_model=BatchResponse,
)
def publish_many(call: APICall, company_id, request: ModelsPublishManyRequest):
    """
    Publish several models in one batch.

    Each ID is processed by ModelBLL.publish_model through run_batch_operation.
    Successes (with the published task, if any) and failures are reported
    separately in a BatchResponse.
    """
    results, failures = run_batch_operation(
        func=partial(
            ModelBLL.publish_model,
            company_id=company_id,
            force_publish_task=request.force_publish_task,
            publish_task_func=publish_task if request.publish_task else None,
        ),
        ids=request.ids,
    )

    call.result.data_model = BatchResponse(
        succeeded=[
            dict(
                id=_id,
                updated=bool(updated),
                # publish_model may yield no task (presumably when no
                # publish_task_func is given), so only serialize a real result
                published_task=published_task.to_struct() if published_task else None,
            )
            for _id, (updated, published_task) in results
        ],
        failed=failures,
    )
|
|
|
|
|
|
|
|
|
2021-05-03 14:38:09 +00:00
|
|
|
@endpoint("models.delete", request_data_model=DeleteModelRequest)
def delete(call: APICall, company_id, request: DeleteModelRequest):
    """
    Delete a single model through ModelBLL.delete_model.

    Responds with whether the model was deleted and the model's URI. When a
    document was deleted, the cached tags of the model's project (if any) are reset.
    """
    del_count, model = ModelBLL.delete_model(
        model_id=request.model, company_id=company_id, force=request.force
    )
    if del_count:
        affected = [model.project] if model.project else []
        _reset_cached_tags(company_id, projects=affected)

    call.result.data = dict(deleted=bool(del_count), url=model.uri)
|
2019-06-10 21:24:35 +00:00
|
|
|
|
|
|
|
|
2021-05-03 14:52:54 +00:00
|
|
|
@endpoint(
    "models.delete_many",
    request_data_model=ModelsDeleteManyRequest,
    response_data_model=BatchResponse,
)
def delete_many(call: APICall, company_id, request: ModelsDeleteManyRequest):
    """
    Delete several models in one batch.

    Named delete_many so it no longer shadows the models.delete handler (it was
    previously a second module-level ``delete`` definition).
    Each ID is processed by ModelBLL.delete_model through run_batch_operation.
    The response lists per-model deletion status and URI, plus the failures.
    """
    results, failures = run_batch_operation(
        func=partial(ModelBLL.delete_model, company_id=company_id, force=request.force),
        ids=request.ids,
    )

    # Same cache policy as models.delete: reset only when something was deleted,
    # and never pass a missing (None) project to the tag cache reset.
    deleted_models = [model for _, (deleted, model) in results if deleted]
    if deleted_models:
        projects = {model.project for model in deleted_models if model.project}
        _reset_cached_tags(company_id, projects=list(projects))

    call.result.data_model = BatchResponse(
        succeeded=[
            dict(id=_id, deleted=bool(deleted), url=model.uri)
            for _id, (deleted, model) in results
        ],
        failed=failures,
    )
|
|
|
|
|
|
|
|
|
|
|
|
@endpoint(
    "models.archive_many",
    request_data_model=BatchRequest,
    response_data_model=BatchResponse,
)
def archive_many(call: APICall, company_id, request: BatchRequest):
    """Archive several models via ModelBLL.archive_model and report per-ID results."""
    archive = partial(ModelBLL.archive_model, company_id=company_id)
    results, failures = run_batch_operation(func=archive, ids=request.ids)

    succeeded = [dict(id=_id, archived=bool(archived)) for _id, archived in results]
    call.result.data_model = BatchResponse(succeeded=succeeded, failed=failures)
|
2020-08-10 05:30:40 +00:00
|
|
|
|
|
|
|
|
2021-05-03 15:04:17 +00:00
|
|
|
@endpoint(
    "models.unarchive_many",
    request_data_model=BatchRequest,
    response_data_model=BatchResponse,
)
def unarchive_many(call: APICall, company_id, request: BatchRequest):
    """Unarchive several models via ModelBLL.unarchive_model and report per-ID results."""
    unarchive = partial(ModelBLL.unarchive_model, company_id=company_id)
    results, failures = run_batch_operation(func=unarchive, ids=request.ids)

    succeeded = [
        dict(id=_id, unarchived=bool(unarchived)) for _id, unarchived in results
    ]
    call.result.data_model = BatchResponse(succeeded=succeeded, failed=failures)
|
|
|
|
|
|
|
|
|
2020-08-10 05:30:40 +00:00
|
|
|
@endpoint("models.make_public", min_version="2.9", request_data_model=MakePublicRequest)
def make_public(call: APICall, company_id, request: MakePublicRequest):
    """Mark the given company models as public via Model.set_public(enabled=True)."""
    call.result.data = Model.set_public(
        company_id, ids=request.ids, enabled=True, invalid_cls=InvalidModelId
    )
|
2020-08-10 05:30:40 +00:00
|
|
|
|
|
|
|
|
|
|
|
@endpoint(
    "models.make_private", min_version="2.9", request_data_model=MakePublicRequest
)
def make_private(call: APICall, company_id, request: MakePublicRequest):
    """
    Make the given public models private again via Model.set_public(enabled=False).

    Named make_private so it no longer shadows the models.make_public handler (it
    was previously a second module-level ``make_public`` definition).
    """
    call.result.data = Model.set_public(
        company_id, ids=request.ids, invalid_cls=InvalidModelId, enabled=False
    )
|
2021-01-05 16:05:44 +00:00
|
|
|
|
|
|
|
|
|
|
|
@endpoint("models.move", request_data_model=MoveRequest)
def move(call: APICall, company_id: str, request: MoveRequest):
    """
    Move models under a target project identified by ID or by name.

    :raises errors.bad_request.MissingRequiredFields: if neither project nor
        project_name is given
    :return: {"project_id": <value returned by ProjectBLL.move_under_project>}
    """
    if not request.project and not request.project_name:
        raise errors.bad_request.MissingRequiredFields(
            "project or project_name is required"
        )

    target_project_id = project_bll.move_under_project(
        entity_cls=Model,
        user=call.identity.user,
        company=company_id,
        ids=request.ids,
        project=request.project,
        project_name=request.project_name,
    )
    return {"project_id": target_project_id}
|
2021-05-03 14:50:25 +00:00
|
|
|
|
|
|
|
|
|
|
|
@endpoint("models.add_or_update_metadata", min_version="2.13")
def add_or_update_metadata(
    _: APICall, company_id: str, request: AddOrUpdateMetadataRequest
):
    """
    Add metadata items to a company model, or update existing ones.

    The model is looked up only to verify it exists for the company. When the
    metadata changed, the model's last_update timestamp is refreshed.

    :return: {"updated": <result of metadata_add_or_update>}
    """
    model_id = request.model
    # only existence/ownership matters here, so fetch just the id
    # (the same lookup delete_metadata performs)
    ModelBLL.get_company_model_by_id(
        company_id=company_id, model_id=model_id, only_fields=("id",)
    )

    updated = metadata_add_or_update(
        cls=Model, _id=model_id, items=get_metadata_from_api(request.metadata),
    )
    if updated:
        Model.objects(id=model_id).update_one(last_update=datetime.utcnow())

    return {"updated": updated}
|
2021-05-03 14:50:25 +00:00
|
|
|
|
|
|
|
|
|
|
|
@endpoint("models.delete_metadata", min_version="2.13")
def delete_metadata(_: APICall, company_id: str, request: DeleteMetadataRequest):
    """
    Remove the metadata items with the given keys from a company model.

    The model's last_update timestamp is refreshed when anything was removed.

    :return: {"updated": <result of metadata_delete>}
    """
    target_id = request.model
    ModelBLL.get_company_model_by_id(
        company_id=company_id, model_id=target_id, only_fields=("id",)
    )

    removed = metadata_delete(cls=Model, _id=target_id, keys=request.keys)
    if removed:
        Model.objects(id=target_id).update_one(last_update=datetime.utcnow())

    return {"updated": removed}
|