clearml-server/apiserver/services/models.py

734 lines
23 KiB
Python
Raw Normal View History

2019-06-10 21:24:35 +00:00
from datetime import datetime
2021-05-03 14:52:54 +00:00
from functools import partial
2023-05-25 16:17:40 +00:00
from typing import Sequence, Union
2019-06-10 21:24:35 +00:00
from mongoengine import Q, EmbeddedDocument
2021-01-05 14:28:49 +00:00
from apiserver import database
from apiserver.apierrors import errors
from apiserver.apierrors.errors.bad_request import InvalidModelId
from apiserver.apimodels.base import UpdateResponse, MakePublicRequest, MoveRequest
2021-05-03 15:04:17 +00:00
from apiserver.apimodels.batch import BatchResponse, BatchRequest
2021-01-05 14:28:49 +00:00
from apiserver.apimodels.models import (
2019-06-10 21:24:35 +00:00
CreateModelRequest,
CreateModelResponse,
PublishModelRequest,
PublishModelResponse,
2020-07-06 18:50:43 +00:00
GetFrameworksRequest,
DeleteModelRequest,
DeleteMetadataRequest,
AddOrUpdateMetadataRequest,
2021-05-03 14:52:54 +00:00
ModelsPublishManyRequest,
ModelsDeleteManyRequest,
2022-07-08 14:39:41 +00:00
ModelsGetRequest,
ModelRequest,
TaskRequest,
UpdateForTaskRequest,
UpdateModelRequest,
2019-06-10 21:24:35 +00:00
)
from apiserver.apimodels.tasks import UpdateTagsRequest
from apiserver.bll.model import ModelBLL, Metadata
2021-01-05 14:28:49 +00:00
from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL
2021-01-05 14:28:49 +00:00
from apiserver.bll.task import TaskBLL
2021-05-03 14:52:54 +00:00
from apiserver.bll.task.task_operations import publish_task
from apiserver.bll.task.utils import get_task_with_write_access
2021-05-03 14:52:54 +00:00
from apiserver.bll.util import run_batch_operation
2021-01-05 14:44:31 +00:00
from apiserver.config_repo import config
2021-01-05 14:28:49 +00:00
from apiserver.database.model import validate_id
from apiserver.database.model.model import Model
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import (
Task,
TaskStatus,
ModelItem,
TaskModelNames,
TaskModelTypes,
)
2021-01-05 14:28:49 +00:00
from apiserver.database.utils import (
2019-06-10 21:24:35 +00:00
parse_from_call,
get_company_or_none_constraint,
filter_fields,
)
2021-01-05 14:28:49 +00:00
from apiserver.service_repo import APICall, endpoint
from apiserver.service_repo.auth import Identity
from apiserver.services.utils import (
conform_tag_fields,
conform_output_tags,
ModelsBackwardsCompatibility,
unescape_metadata,
escape_metadata,
process_include_subprojects,
)
2019-06-10 21:24:35 +00:00
# Module-level logger obtained from the server configuration
log = config.logger(__file__)
# Shared business-logic helper instances used by the endpoints in this module
org_bll = OrgBLL()
project_bll = ProjectBLL()
2019-06-10 21:24:35 +00:00
2023-05-25 16:17:40 +00:00
def conform_model_data(call: APICall, model_data: Union[Sequence[dict], dict]):
    """
    Prepare model data before it is returned to the caller.

    ``model_data`` (a single model dict or a sequence of them) is passed first
    to ``conform_output_tags`` and then to ``unescape_metadata``.
    """
    for conform in (conform_output_tags, unescape_metadata):
        conform(call, model_data)
@endpoint("models.get_by_id")
def get_by_id(call: APICall, company_id, request: ModelRequest):
    """
    Return a single model by its id.

    The lookup allows public models as well as the company's own models.

    :raises errors.bad_request.InvalidModelId: when no matching model is found
    """
    model_id = request.model
    found = Model.get_many(
        company=company_id,
        query_dict=Metadata.escape_query_parameters(call.data),
        query=Q(id=model_id),
        allow_public=True,
    )
    if not found:
        raise errors.bad_request.InvalidModelId(
            "no such public or company model",
            id=model_id,
            company=company_id,
        )

    model = found[0]
    conform_model_data(call, model)
    call.result.data = {"model": model}
2019-06-10 21:24:35 +00:00
@endpoint("models.get_by_task_id")
def get_by_task_id(call: APICall, company_id, request: TaskRequest):
    """
    Return the last output model of the given task.

    Only available up to ``ModelsBackwardsCompatibility.max_version``; newer
    endpoint versions are rejected.

    :raises errors.moved_permanently.NotSupported: for newer endpoint versions
    :raises errors.bad_request.InvalidTaskId: when the task is not found
    :raises errors.bad_request.MissingTaskFields: when the task has no output models
    :raises errors.bad_request.InvalidModelId: when the model is not found
    """
    if call.requested_endpoint_version > ModelsBackwardsCompatibility.max_version:
        raise errors.moved_permanently.NotSupported("use models.get_by_id/get_all apis")

    task_id = request.task
    task_query = {"id": task_id, "company": company_id}
    task = Task.get(_only=["models"], **task_query)
    if not task:
        raise errors.bad_request.InvalidTaskId(**task_query)

    output_models = task.models.output if task.models else None
    if not output_models:
        raise errors.bad_request.MissingTaskFields(field="models.output")

    # the last entry in the task's output models is the one returned
    model_id = output_models[-1].model
    constraint = Q(id=model_id) & get_company_or_none_constraint(company_id)
    model = Model.objects(constraint).first()
    if not model:
        raise errors.bad_request.InvalidModelId(
            "no such public or company model",
            id=model_id,
            company=company_id,
        )

    model_dict = model.to_proper_dict()
    conform_model_data(call, model_dict)
    call.result.data = {"model": model_dict}
2019-06-10 21:24:35 +00:00
2022-07-08 14:39:41 +00:00
@endpoint("models.get_all_ex", request_data_model=ModelsGetRequest)
def get_all_ex(call: APICall, company_id, request: ModelsGetRequest):
    """
    Return the models matching the query in the call data (joined variant).

    When ``request.include_stats`` is set, every returned model gets a
    ``stats`` entry taken from ``ModelBLL.get_model_stats``.
    """
    conform_tag_fields(call, call.data)
    call_data = Metadata.escape_query_parameters(call.data)
    process_include_subprojects(call_data)

    # filled in by the query with additional values to return to the caller
    ret_params = {}
    models = Model.get_many_with_join(
        company=company_id,
        query_dict=call_data,
        allow_public=request.allow_public,
        ret_params=ret_params,
    )
    conform_model_data(call, models)

    if request.include_stats:
        unique_ids = list({entry["id"] for entry in models})
        stats = ModelBLL.get_model_stats(company=company_id, model_ids=unique_ids)
        for entry in models:
            entry["stats"] = stats.get(entry["id"])

    call.result.data = {"models": models, **ret_params}
2019-06-10 21:24:35 +00:00
@endpoint("models.get_by_id_ex", required_fields=["id"])
def get_by_id_ex(call: APICall, company_id, _):
    """
    Return the models matching the call data (which must include ``id``),
    using the joined query; public models are allowed.
    """
    conform_tag_fields(call, call.data)
    query = Metadata.escape_query_parameters(call.data)

    models = Model.get_many_with_join(
        company=company_id,
        query_dict=query,
        allow_public=True,
    )
    conform_model_data(call, models)
    call.result.data = dict(models=models)
@endpoint("models.get_all")
def get_all(call: APICall, company_id, _):
    """
    Return the models matching the query in the call data; public models are allowed.

    Any values the query places in ``ret_params`` are merged into the result.
    """
    conform_tag_fields(call, call.data)
    query = Metadata.escape_query_parameters(call.data)
    process_include_subprojects(query)

    extra_result = {}
    models = Model.get_many(
        company=company_id,
        parameters=query,
        query_dict=query,
        allow_public=True,
        ret_params=extra_result,
    )
    conform_model_data(call, models)
    call.result.data = {"models": models, **extra_result}
2019-06-10 21:24:35 +00:00
2020-07-06 18:50:43 +00:00
@endpoint("models.get_frameworks", request_data_model=GetFrameworksRequest)
def get_frameworks(call: APICall, company_id, request: GetFrameworksRequest):
    """
    Return the sorted model frameworks reported by the project BLL for the
    company, limited to ``request.projects``.
    """
    frameworks = project_bll.get_model_frameworks(
        company_id, project_ids=request.projects
    )
    call.result.data = {"frameworks": sorted(frameworks)}
2019-06-10 21:24:35 +00:00
# Fields accepted when parsing model data from a call (passed to
# ``parse_from_call`` as the valid-fields mapping). Each value is either None,
# a builtin container type or a document class.
# NOTE(review): presumably the value tells ``parse_from_call`` how to validate
# the field - confirm against its implementation.
create_fields = dict(
    name=None,
    tags=list,
    system_tags=list,
    task=Task,
    comment=None,
    uri=None,
    project=Project,
    parent=Model,
    framework=None,
    design=dict,
    labels=dict,
    ready=None,
    metadata=list,
)
# Model fields whose modification also refreshes the model's ``last_update``
# timestamp (in addition to ``last_change``)
last_update_fields = (
    "uri", "framework", "design", "labels",
    "ready", "metadata", "system_tags", "tags",
)
2021-05-03 15:02:25 +00:00
2019-06-10 21:24:35 +00:00
def parse_model_fields(call, valid_fields):
    """
    Extract the model fields listed in ``valid_fields`` from the call data.

    The parsed fields then go through tag conforming (with validation) and
    metadata escaping before being returned.
    """
    parsed = parse_from_call(call.data, valid_fields, Model.get_fields())
    conform_tag_fields(call, parsed, validate=True)
    escape_metadata(parsed)
    return parsed
def _update_cached_tags(company: str, project: str, fields: dict):
    """
    Pass the ``tags`` / ``system_tags`` found in ``fields`` (None when absent)
    to the organization BLL tags update for the given model project.
    """
    tags = fields.get("tags")
    system_tags = fields.get("system_tags")
    org_bll.update_tags(
        company,
        Tags.Model,
        projects=[project],
        tags=tags,
        system_tags=system_tags,
    )
def _reset_cached_tags(company: str, projects: Sequence[str]):
    """Ask the organization BLL to reset the model tags of the given projects."""
    org_bll.reset_tags(company, Tags.Model, projects=projects)
@endpoint("models.update_for_task")
def update_for_task(call: APICall, company_id, request: UpdateForTaskRequest):
    """
    Create or update the output model of a task.

    Exactly one of ``uri`` / ``override_model_id`` must be provided:

    * ``override_model_id``: the existing company model is attached to the task.
    * ``uri``: if the task already has an output model it is updated from the
      call data; otherwise a new model is created from the call data and the
      task's execution info.

    Only available up to ``ModelsBackwardsCompatibility.max_version``.

    :raises errors.moved_permanently.NotSupported: for newer endpoint versions
    :raises errors.bad_request.MissingRequiredFields: unless exactly one of
        ``uri`` / ``override_model_id`` is passed
    :raises errors.bad_request.InvalidTaskStatus: when the task is neither
        created nor in progress
    """
    if call.requested_endpoint_version > ModelsBackwardsCompatibility.max_version:
        raise errors.moved_permanently.NotSupported("use tasks.add_or_update_model")

    task_id = request.task
    uri = request.uri
    iteration = request.iteration
    override_model_id = request.override_model_id

    # both missing or both present is an error
    if bool(uri) == bool(override_model_id):
        raise errors.bad_request.MissingRequiredFields(
            "exactly one field is required", fields=("uri", "override_model_id")
        )

    task = get_task_with_write_access(
        task_id=task_id,
        company_id=company_id,
        identity=call.identity,
        only=("models", "execution", "name", "status", "project"),
    )

    allowed_states = [TaskStatus.created, TaskStatus.in_progress]
    if task.status not in allowed_states:
        raise errors.bad_request.InvalidTaskStatus(
            f"model can only be updated for tasks in the {allowed_states} states",
            id=task_id,
            company=company_id,
        )

    if override_model_id:
        model = ModelBLL.get_company_model_by_id(
            company_id=company_id, model_id=override_model_id
        )
    else:
        if "name" not in call.data:
            # use task name if name not provided
            call.data["name"] = task.name

        if "comment" not in call.data:
            call.data["comment"] = f"Created by task `{task.name}` ({task.id})"

        if task.models and task.models.output:
            # model exists, update
            model_id = task.models.output[-1].model
            res = _update_model(call, company_id, model_id=model_id).to_struct()
            res.update({"id": model_id, "created": False})
            call.result.data = res
            return

        # new model, create
        fields = parse_model_fields(call, create_fields)

        # the first input model of the task (if any) becomes the parent
        has_input_models = task.models and task.models.input
        parent_model = task.models.input[0].model if has_input_models else None

        # create and save model
        now = datetime.utcnow()
        user_id = call.identity.user
        # NOTE(review): a key in `fields` that repeats one of the explicit
        # keyword arguments below (e.g. project, framework, parent, design,
        # labels, ready) makes this call raise TypeError - confirm that the
        # request schema keeps such keys out of the call data
        model = Model(
            id=database.utils.id(),
            created=now,
            last_update=now,
            last_change=now,
            last_changed_by=user_id,
            user=user_id,
            company=company_id,
            project=task.project,
            framework=task.execution.framework,
            parent=parent_model,
            design=task.execution.model_desc,
            labels=task.execution.model_labels,
            ready=(task.status == TaskStatus.published),
            **fields,
        )
        model.save()
        _update_cached_tags(company_id, project=model.project, fields=fields)

    output_item = ModelItem(
        model=model.id,
        name=TaskModelNames[TaskModelTypes.output],
        updated=datetime.utcnow(),
    )
    TaskBLL.update_statistics(
        task_id=task_id,
        company_id=company_id,
        user_id=call.identity.user,
        last_iteration_max=iteration,
        models__output=[output_item],
    )

    call.result.data = {"id": model.id, "created": True}
2019-06-10 21:24:35 +00:00
@endpoint(
    "models.create",
    request_data_model=CreateModelRequest,
    response_data_model=CreateModelResponse,
)
def create(call: APICall, company_id, req_model: CreateModelRequest):
    """
    Create a new model from the request.

    A public model is stored with an empty company id. The project id (when
    given) is validated, and write access to the task (when given) is checked
    before the model is saved.
    """
    if req_model.public:
        company_id = ""

    if req_model.project:
        validate_id(Project, company=company_id, project=req_model.project)

    req_data = req_model.to_struct()
    if req_model.task:
        validate_task(company_id, call.identity, req_data)

    fields = filter_fields(Model, req_data)
    conform_tag_fields(call, fields, validate=True)
    escape_metadata(fields)

    # create and save model
    now = datetime.utcnow()
    user_id = call.identity.user
    model = Model(
        id=database.utils.id(),
        user=user_id,
        company=company_id,
        created=now,
        last_update=now,
        last_change=now,
        last_changed_by=user_id,
        **fields,
    )
    model.save()
    _update_cached_tags(company_id, project=model.project, fields=fields)

    call.result.data_model = CreateModelResponse(id=model.id, created=True)
2019-06-10 21:24:35 +00:00
def _items_not_of_type(iterable, type_):
    """
    Collect the items of ``iterable`` that are not instances of ``type_``.

    Returns a set when the offending items are hashable, otherwise a list.
    """
    offending = [item for item in iterable if not isinstance(item, type_)]
    try:
        return set(offending)
    except TypeError:
        # Un-hashable, probably
        return offending


def prepare_update_fields(call, company_id, fields: dict):
    """
    Return a validated copy of ``fields`` that is ready for a model update.

    * When ``uri`` is present, ``ui_cache`` is always included in the update
      (empty unless the caller passed one).
    * When ``task`` is present, write access to that task is verified.
    * When ``labels`` is present, its keys must be strings and its values integers.

    Tags are then conformed (with validation) and metadata is escaped.

    :raises errors.bad_request.ValidationError: on invalid label keys or values
    """
    prepared = fields.copy()

    if "uri" in prepared:
        # clear UI cache if URI is provided (model updated)
        prepared["ui_cache"] = prepared.pop("ui_cache", {})

    if "task" in prepared:
        validate_task(company_id, call.identity, prepared)

    if "labels" in prepared:
        labels = prepared["labels"]

        bad_keys = _items_not_of_type(labels.keys(), str)
        if bad_keys:
            raise errors.bad_request.ValidationError(
                "labels keys must be strings", keys=bad_keys
            )

        bad_values = _items_not_of_type(labels.values(), int)
        if bad_values:
            raise errors.bad_request.ValidationError(
                "labels values must be integers", values=bad_values
            )

    conform_tag_fields(call, prepared, validate=True)
    escape_metadata(prepared)
    return prepared
def validate_task(company_id: str, identity: Identity, fields: dict):
    """
    Verify that the caller has write access to the task referenced by
    ``fields["task"]``; the access helper is expected to raise otherwise.
    """
    get_task_with_write_access(
        task_id=fields["task"],
        company_id=company_id,
        identity=identity,
        only=("id",),
    )
2019-06-10 21:24:35 +00:00
@endpoint("models.edit", response_data_model=UpdateResponse)
def edit(call: APICall, company_id, request: UpdateModelRequest):
    """
    Edit an existing company model with the fields passed in the call.

    Dict values that target an embedded document are merged into the document's
    current content instead of replacing it. When the model (or the call) names
    a task and an iteration is given, the task statistics are updated too.
    """
    model_id = request.model
    model = ModelBLL.get_company_model_by_id(company_id=company_id, model_id=model_id)

    fields = prepare_update_fields(
        call, company_id, parse_model_fields(call, create_fields)
    )

    # snapshot the items since values are reassigned inside the loop
    for key, value in list(fields.items()):
        current = getattr(model, key, None)
        if current and isinstance(value, dict) and isinstance(current, EmbeddedDocument):
            merged = current.to_mongo(use_db_field=False).to_dict()
            merged.update(value)
            fields[key] = merged

    iteration = request.iteration
    task_id = model.task or fields.get("task")
    if task_id and iteration is not None:
        TaskBLL.update_statistics(
            task_id=task_id,
            company_id=company_id,
            user_id=call.identity.user,
            last_iteration_max=iteration,
        )

    if not fields:
        # nothing to update
        call.result.data_model = UpdateResponse(updated=0)
        return

    now = datetime.utcnow()
    fields.update(
        last_change=now,
        last_changed_by=call.identity.user,
    )
    if any(name in fields for name in last_update_fields):
        fields.update(last_update=now)

    updated = model.update(upsert=False, **fields)
    if updated:
        new_project = fields.get("project", model.project)
        if new_project == model.project:
            _update_cached_tags(company_id, project=model.project, fields=fields)
        else:
            # the model changed projects: reset the tags of both projects
            _reset_cached_tags(company_id, projects=[new_project, model.project])

    conform_model_data(call, fields)
    call.result.data_model = UpdateResponse(updated=updated, fields=fields)
2019-06-10 21:24:35 +00:00
def _update_model(call: APICall, company_id, model_id):
    """
    Update a company model from the call data and return an ``UpdateResponse``.

    When the data names a task and an iteration, the task statistics are
    updated as well. ``last_update`` is refreshed only if one of
    ``last_update_fields`` was actually updated.
    """
    model = ModelBLL.get_company_model_by_id(company_id=company_id, model_id=model_id)
    data = prepare_update_fields(call, company_id, call.data)

    task_id, iteration = data.get("task"), data.get("iteration")
    if task_id and iteration is not None:
        TaskBLL.update_statistics(
            task_id=task_id,
            company_id=company_id,
            user_id=call.identity.user,
            last_iteration_max=iteration,
        )

    now = datetime.utcnow()
    change_info = {"last_change": now, "last_changed_by": call.identity.user}
    updated_count, updated_fields = Model.safe_update(
        company_id,
        model.id,
        data,
        injected_update=change_info,
    )

    if updated_count:
        if any(name in updated_fields for name in last_update_fields):
            model.update(upsert=False, last_update=now)

        new_project = updated_fields.get("project", model.project)
        if new_project == model.project:
            _update_cached_tags(
                company_id, project=model.project, fields=updated_fields
            )
        else:
            # the model changed projects: reset the tags of both projects
            _reset_cached_tags(company_id, projects=[new_project, model.project])

    conform_model_data(call, updated_fields)
    return UpdateResponse(updated=updated_count, fields=updated_fields)
2019-06-10 21:24:35 +00:00
@endpoint("models.update", response_data_model=UpdateResponse)
def update(call, company_id, request: UpdateModelRequest):
    """Update the requested model from the call data and return the update response."""
    response = _update_model(call, company_id, model_id=request.model)
    call.result.data_model = response
2019-06-10 21:24:35 +00:00
@endpoint(
    "models.set_ready",
    request_data_model=PublishModelRequest,
    response_data_model=PublishModelResponse,
)
def set_ready(call: APICall, company_id: str, request: PublishModelRequest):
    """
    Publish a model through ``ModelBLL.publish_model``.

    The task publishing function is passed on only when ``request.publish_task``
    is set.
    """
    publish_task_func = publish_task if request.publish_task else None
    updated, published_task = ModelBLL.publish_model(
        model_id=request.model,
        company_id=company_id,
        identity=call.identity,
        force_publish_task=request.force_publish_task,
        publish_task_func=publish_task_func,
    )
    call.result.data_model = PublishModelResponse(
        updated=updated,
        published_task=published_task,
    )
@endpoint(
    "models.publish_many",
    request_data_model=ModelsPublishManyRequest,
    response_data_model=BatchResponse,
)
def publish_many(call: APICall, company_id, request: ModelsPublishManyRequest):
    """
    Publish several models in one batch operation.

    The response lists, per successfully processed id, whether the model was
    updated and the published task info (if any), along with the failures.
    """
    publish_one = partial(
        ModelBLL.publish_model,
        company_id=company_id,
        identity=call.identity,
        force_publish_task=request.force_publish_task,
        publish_task_func=publish_task if request.publish_task else None,
    )
    results, failures = run_batch_operation(func=publish_one, ids=request.ids)

    succeeded = [
        {
            "id": model_id,
            "updated": bool(updated),
            "published_task": published_task.to_struct() if published_task else None,
        }
        for model_id, (updated, published_task) in results
    ]
    call.result.data_model = BatchResponse(succeeded=succeeded, failed=failures)
@endpoint("models.delete", request_data_model=DeleteModelRequest)
def delete(call: APICall, company_id, request: DeleteModelRequest):
    """
    Delete a single model.

    When something was deleted, the cached tags of the model's project (if it
    has one) are reset. The result reports whether the model was deleted and
    the model's uri.
    """
    del_count, model = ModelBLL.delete_model(
        model_id=request.model,
        company_id=company_id,
        user_id=call.identity.user,
        force=request.force,
        delete_external_artifacts=request.delete_external_artifacts,
    )
    if del_count:
        affected_projects = [model.project] if model.project else []
        _reset_cached_tags(company_id, projects=affected_projects)

    call.result.data = {"deleted": bool(del_count), "url": model.uri}
2019-06-10 21:24:35 +00:00
2021-05-03 14:52:54 +00:00
@endpoint(
    "models.delete_many",
    request_data_model=ModelsDeleteManyRequest,
    response_data_model=BatchResponse,
)
def delete(call: APICall, company_id, request: ModelsDeleteManyRequest):
    """
    Delete several models in one batch operation.

    The cached tags of the projects the processed models belong to are reset.
    Models without a project are skipped when collecting those projects, the
    same way the single-model ``models.delete`` endpoint does it.

    The response lists, per successfully processed id, whether the model was
    deleted and its uri, along with the failures.

    NOTE(review): this function reuses the name ``delete`` and therefore
    shadows the ``models.delete`` handler at module level; the name is kept
    as-is here.
    """
    results, failures = run_batch_operation(
        func=partial(
            ModelBLL.delete_model,
            company_id=company_id,
            user_id=call.identity.user,
            force=request.force,
            delete_external_artifacts=request.delete_external_artifacts,
        ),
        ids=request.ids,
    )

    if results:
        # skip models that have no project so that None is never passed on
        projects = {
            model.project for _, (_, model) in results if model.project
        }
        _reset_cached_tags(company_id, projects=list(projects))

    call.result.data_model = BatchResponse(
        succeeded=[
            dict(id=_id, deleted=bool(deleted), url=model.uri)
            for _id, (deleted, model) in results
        ],
        failed=failures,
    )
@endpoint(
    "models.archive_many",
    request_data_model=BatchRequest,
    response_data_model=BatchResponse,
)
def archive_many(call: APICall, company_id, request: BatchRequest):
    """
    Archive several models in one batch operation and report, per successfully
    processed id, whether it was archived, along with the failures.
    """
    archive_one = partial(
        ModelBLL.archive_model,
        company_id=company_id,
        user_id=call.identity.user,
    )
    results, failures = run_batch_operation(func=archive_one, ids=request.ids)

    succeeded = [
        {"id": model_id, "archived": bool(archived)}
        for model_id, archived in results
    ]
    call.result.data_model = BatchResponse(succeeded=succeeded, failed=failures)
2021-05-03 15:04:17 +00:00
@endpoint(
    "models.unarchive_many",
    request_data_model=BatchRequest,
    response_data_model=BatchResponse,
)
def unarchive_many(call: APICall, company_id, request: BatchRequest):
    """
    Unarchive several models in one batch operation and report, per
    successfully processed id, whether it was unarchived, along with the failures.
    """
    unarchive_one = partial(
        ModelBLL.unarchive_model,
        company_id=company_id,
        user_id=call.identity.user,
    )
    results, failures = run_batch_operation(func=unarchive_one, ids=request.ids)

    succeeded = [
        {"id": model_id, "unarchived": bool(unarchived)}
        for model_id, unarchived in results
    ]
    call.result.data_model = BatchResponse(succeeded=succeeded, failed=failures)
@endpoint("models.make_public", min_version="2.9", request_data_model=MakePublicRequest)
def make_public(call: APICall, company_id, request: MakePublicRequest):
    """Enable the public flag on the requested models via ``Model.set_public``."""
    result = Model.set_public(
        company_id=company_id,
        user_id=call.identity.user,
        ids=request.ids,
        invalid_cls=InvalidModelId,
        enabled=True,
    )
    call.result.data = result
@endpoint(
    "models.make_private", min_version="2.9", request_data_model=MakePublicRequest
)
def make_public(call: APICall, company_id, request: MakePublicRequest):
    """
    Disable the public flag on the requested models via ``Model.set_public``.

    NOTE(review): serves ``models.make_private`` although the function is named
    ``make_public`` (it shadows the ``models.make_public`` handler at module
    level); the name is kept as-is here.
    """
    result = Model.set_public(
        company_id=company_id,
        user_id=call.identity.user,
        ids=request.ids,
        invalid_cls=InvalidModelId,
        enabled=False,
    )
    call.result.data = result
@endpoint("models.move", request_data_model=MoveRequest)
def move(call: APICall, company_id: str, request: MoveRequest):
    """
    Move the requested models under a project given by id or by name.

    :raises errors.bad_request.MissingRequiredFields: when the call data has no
        ``project`` key and no project name was passed
    :return: dict holding the ``project_id`` returned by the project BLL
    """
    if "project" not in call.data and not request.project_name:
        raise errors.bad_request.MissingRequiredFields(
            "project or project_name is required"
        )

    project_id = project_bll.move_under_project(
        entity_cls=Model,
        user=call.identity.user,
        company=company_id,
        ids=request.ids,
        project=request.project,
        project_name=request.project_name,
    )
    return {"project_id": project_id}
@endpoint("models.update_tags")
def update_tags(_, company_id: str, request: UpdateTagsRequest):
    """
    Add and/or remove tags on the requested models.

    :return: dict with the ``updated`` value returned by the organization BLL
    """
    updated = org_bll.edit_entity_tags(
        company_id=company_id,
        entity_cls=Model,
        entity_ids=request.ids,
        add_tags=request.add_tags,
        remove_tags=request.remove_tags,
    )
    return {"updated": updated}
@endpoint("models.add_or_update_metadata", min_version="2.13")
def add_or_update_metadata(
    call: APICall, company_id: str, request: AddOrUpdateMetadataRequest
):
    """
    Add or update metadata items on a company model.

    The same timestamp is passed as both ``last_update`` and ``last_change``.

    :return: dict with the ``updated`` value returned by ``Metadata.edit_metadata``
    """
    model = ModelBLL.get_company_model_by_id(
        company_id=company_id, model_id=request.model
    )
    now = datetime.utcnow()
    updated = Metadata.edit_metadata(
        model,
        items=request.metadata,
        replace_metadata=request.replace_metadata,
        last_update=now,
        last_change=now,
        last_changed_by=call.identity.user,
    )
    return {"updated": updated}
@endpoint("models.delete_metadata", min_version="2.13")
def delete_metadata(call: APICall, company_id: str, request: DeleteMetadataRequest):
    """
    Delete the requested metadata keys from a company model.

    Only the model id is fetched. The same timestamp is passed as both
    ``last_update`` and ``last_change``.

    :return: dict with the ``updated`` value returned by ``Metadata.delete_metadata``
    """
    model = ModelBLL.get_company_model_by_id(
        company_id=company_id,
        model_id=request.model,
        only_fields=("id",),
    )
    now = datetime.utcnow()
    updated = Metadata.delete_metadata(
        model,
        keys=request.keys,
        last_update=now,
        last_change=now,
        last_changed_by=call.identity.user,
    )
    return {"updated": updated}