mirror of
https://github.com/clearml/clearml-server
synced 2025-01-31 19:06:55 +00:00
af09fba755
Add more info for projects
131 lines
4.9 KiB
Python
131 lines
4.9 KiB
Python
from datetime import datetime
|
|
from typing import Callable, Tuple
|
|
|
|
from apiserver.apierrors import errors
|
|
from apiserver.apimodels.models import ModelTaskPublishResponse
|
|
from apiserver.bll.task.utils import deleted_prefix
|
|
from apiserver.database.model import EntityVisibility
|
|
from apiserver.database.model.model import Model
|
|
from apiserver.database.model.task.task import Task, TaskStatus
|
|
from .metadata import Metadata
|
|
|
|
|
|
class ModelBLL:
    """Operations on Model documents: lookup, publish, delete and (un)archive."""

    @classmethod
    def get_company_model_by_id(
        cls, company_id: str, model_id: str, only_fields=None
    ) -> Model:
        """
        Return the model with the given id that is owned by the given company.

        :param company_id: id of the company the model must belong to
        :param model_id: id of the model to fetch
        :param only_fields: optional iterable of field names to load; when
            None or empty, all fields are loaded
        :raises errors.bad_request.InvalidModelId: if no matching model exists
        """
        query = dict(company=company_id, id=model_id)
        qs = Model.objects(**query)
        if only_fields:
            qs = qs.only(*only_fields)
        model = qs.first()
        if not model:
            raise errors.bad_request.InvalidModelId(**query)
        return model

    @classmethod
    def publish_model(
        cls,
        model_id: str,
        company_id: str,
        force_publish_task: bool = False,
        publish_task_func: Callable[[str, str, bool], dict] = None,
    ) -> Tuple[int, ModelTaskPublishResponse]:
        """
        Mark a model as ready, optionally publishing the task that created it.

        :param model_id: id of the model to publish
        :param company_id: id of the company that owns the model
        :param force_publish_task: passed through to ``publish_task_func``
        :param publish_task_func: called as
            ``publish_task_func(task_id, company_id, force_publish_task)`` to
            publish the model's task; when None, the task is left untouched
        :return: tuple of (result of the model update, task publish response).
            The second element is None when no task was published.
        :raises errors.bad_request.InvalidModelId: if the model does not exist
        :raises errors.bad_request.ModelIsReady: if the model is already ready
        """
        model = cls.get_company_model_by_id(company_id=company_id, model_id=model_id)
        if model.ready:
            raise errors.bad_request.ModelIsReady(company=company_id, model=model_id)

        published_task = None
        if model.task and publish_task_func:
            task = (
                Task.objects(id=model.task, company=company_id)
                .only("id", "status")
                .first()
            )
            # An already-published task is not published again
            if task and task.status != TaskStatus.published:
                task_publish_res = publish_task_func(
                    model.task, company_id, force_publish_task
                )
                published_task = ModelTaskPublishResponse(
                    id=model.task, data=task_publish_res
                )

        updated = model.update(upsert=False, ready=True, last_update=datetime.utcnow())
        return updated, published_task

    @classmethod
    def delete_model(
        cls, model_id: str, company_id: str, force: bool
    ) -> Tuple[int, Model]:
        """
        Delete a model, re-pointing task references to a "deleted" placeholder.

        Tasks that use the model as an input get that reference replaced with
        ``f"{deleted_prefix}{model_id}"``. If the task that created the model is
        published, its matching output reference is replaced the same way and
        an error note is written to the task's output.

        :param model_id: id of the model to delete
        :param company_id: id of the company that owns the model
        :param force: must be True to delete a model that is used as an input
            by any task, or whose creating task is published
        :return: tuple of (result of the delete query, the model document as
            loaded before deletion)
        :raises errors.bad_request.InvalidModelId: if the model does not exist
        :raises errors.bad_request.ModelInUse: if tasks use the model as input
            and ``force`` is False
        :raises errors.bad_request.ModelCreatingTaskExists: if the creating task
            is published and ``force`` is False
        """
        model = cls.get_company_model_by_id(
            company_id=company_id,
            model_id=model_id,
            only_fields=("id", "task", "project", "uri"),
        )
        deleted_model_id = f"{deleted_prefix}{model_id}"

        # Run the query once; the emptiness check, the count and the update
        # filter all use the same result
        using_task_ids = [
            t.id for t in Task.objects(models__input__model=model_id).only("id")
        ]
        if using_task_ids:
            if not force:
                raise errors.bad_request.ModelInUse(
                    "as execution model, use force=True to delete",
                    num_tasks=len(using_task_ids),
                )
            # update deleted model id in using tasks
            Task._get_collection().update_many(
                filter={"_id": {"$in": using_task_ids}},
                update={"$set": {"models.input.$[elem].model": deleted_model_id}},
                array_filters=[{"elem.model": model_id}],
                upsert=False,
            )

        if model.task:
            task = Task.objects(id=model.task).first()
            if task and task.status == TaskStatus.published:
                if not force:
                    raise errors.bad_request.ModelCreatingTaskExists(
                        "and published, use force=True to delete", task=model.task
                    )
                output_models = task.models.output if task.models else None
                # Output entries are sub-documents carrying a "model" field, so
                # match on that field; a plain `model_id in output_models`
                # compares a str against sub-documents and never matches.
                if output_models and any(
                    item.model == model_id for item in output_models
                ):
                    now = datetime.utcnow()
                    Task._get_collection().update_one(
                        filter={"_id": model.task, "models.output.model": model_id},
                        update={
                            "$set": {
                                "models.output.$[elem].model": deleted_model_id,
                                "output.error": f"model deleted on {now.isoformat()}",
                                # Must live inside $set: MongoDB rejects an
                                # update that mixes operators and plain fields
                                "last_change": now,
                            },
                        },
                        array_filters=[{"elem.model": model_id}],
                        upsert=False,
                    )

        del_count = Model.objects(id=model_id, company=company_id).delete()
        return del_count, model

    @classmethod
    def archive_model(cls, model_id: str, company_id: str):
        """
        Add the archived system tag to a model and bump its last_update time.

        :param model_id: id of the model to archive
        :param company_id: id of the company that owns the model
        :return: the result of the update query (presumably the number of
            updated documents; confirm against the ODM in use)
        :raises errors.bad_request.InvalidModelId: if the model does not exist
        """
        # Existence check only; raises if the model is not found
        cls.get_company_model_by_id(
            company_id=company_id, model_id=model_id, only_fields=("id",)
        )
        archived = Model.objects(company=company_id, id=model_id).update(
            add_to_set__system_tags=EntityVisibility.archived.value,
            last_update=datetime.utcnow(),
        )

        return archived

    @classmethod
    def unarchive_model(cls, model_id: str, company_id: str):
        """
        Remove the archived system tag from a model and bump its last_update time.

        :param model_id: id of the model to unarchive
        :param company_id: id of the company that owns the model
        :return: the result of the update query (presumably the number of
            updated documents; confirm against the ODM in use)
        :raises errors.bad_request.InvalidModelId: if the model does not exist
        """
        # Existence check only; raises if the model is not found
        cls.get_company_model_by_id(
            company_id=company_id, model_id=model_id, only_fields=("id",)
        )
        unarchived = Model.objects(company=company_id, id=model_id).update(
            pull__system_tags=EntityVisibility.archived.value,
            last_update=datetime.utcnow(),
        )

        return unarchived
|