mirror of
https://github.com/clearml/clearml-server
synced 2025-01-31 10:56:48 +00:00
238 lines
8.2 KiB
Python
238 lines
8.2 KiB
Python
from datetime import datetime
|
|
from typing import Callable, Tuple, Sequence, Dict, Optional
|
|
|
|
from mongoengine import Q
|
|
|
|
from apiserver.apierrors import errors
|
|
from apiserver.apimodels.models import ModelTaskPublishResponse
|
|
from apiserver.bll.task.utils import deleted_prefix, get_last_metric_updates
|
|
from apiserver.database.model import EntityVisibility
|
|
from apiserver.database.model.model import Model
|
|
from apiserver.database.model.task.task import Task, TaskStatus
|
|
from apiserver.service_repo.auth import Identity
|
|
from .metadata import Metadata
|
|
|
|
|
|
class ModelBLL:
|
|
@classmethod
|
|
def get_company_model_by_id(
|
|
cls, company_id: str, model_id: str, only_fields=None
|
|
) -> Model:
|
|
query = dict(company=company_id, id=model_id)
|
|
qs = Model.objects(**query)
|
|
if only_fields:
|
|
qs = qs.only(*only_fields)
|
|
model = qs.first()
|
|
if not model:
|
|
raise errors.bad_request.InvalidModelId(**query)
|
|
return model
|
|
|
|
@staticmethod
|
|
def assert_exists(
|
|
company_id, model_ids, only=None, allow_public=False, return_models=True,
|
|
) -> Optional[Sequence[Model]]:
|
|
model_ids = [model_ids] if isinstance(model_ids, str) else model_ids
|
|
ids = set(model_ids)
|
|
query = Q(id__in=ids)
|
|
|
|
q = Model.get_many(
|
|
company=company_id,
|
|
query=query,
|
|
allow_public=allow_public,
|
|
return_dicts=False,
|
|
)
|
|
if only:
|
|
q = q.only(*only)
|
|
|
|
if q.count() != len(ids):
|
|
raise errors.bad_request.InvalidModelId(ids=model_ids)
|
|
|
|
if return_models:
|
|
return list(q)
|
|
|
|
@classmethod
|
|
def publish_model(
|
|
cls,
|
|
model_id: str,
|
|
company_id: str,
|
|
identity: Identity,
|
|
force_publish_task: bool = False,
|
|
publish_task_func: Callable[[str, str, Identity, bool], dict] = None,
|
|
) -> Tuple[int, ModelTaskPublishResponse]:
|
|
model = cls.get_company_model_by_id(company_id=company_id, model_id=model_id)
|
|
if model.ready:
|
|
raise errors.bad_request.ModelIsReady(company=company_id, model=model_id)
|
|
|
|
user_id = identity.user
|
|
published_task = None
|
|
if model.task and publish_task_func:
|
|
task = (
|
|
Task.objects(id=model.task, company=company_id)
|
|
.only("id", "status")
|
|
.first()
|
|
)
|
|
if task and task.status != TaskStatus.published:
|
|
task_publish_res = publish_task_func(
|
|
model.task, company_id, identity, force_publish_task
|
|
)
|
|
published_task = ModelTaskPublishResponse(
|
|
id=model.task, data=task_publish_res
|
|
)
|
|
|
|
now = datetime.utcnow()
|
|
updated = model.update(
|
|
upsert=False,
|
|
ready=True,
|
|
last_update=now,
|
|
last_change=now,
|
|
last_changed_by=user_id,
|
|
)
|
|
return updated, published_task
|
|
|
|
@classmethod
|
|
def delete_model(
|
|
cls, model_id: str, company_id: str, user_id: str, force: bool
|
|
) -> Tuple[int, Model]:
|
|
model = cls.get_company_model_by_id(
|
|
company_id=company_id,
|
|
model_id=model_id,
|
|
only_fields=("id", "task", "project", "uri"),
|
|
)
|
|
deleted_model_id = f"{deleted_prefix}{model_id}"
|
|
|
|
using_tasks = Task.objects(models__input__model=model_id).only("id")
|
|
if using_tasks:
|
|
if not force:
|
|
raise errors.bad_request.ModelInUse(
|
|
"as execution model, use force=True to delete",
|
|
num_tasks=len(using_tasks),
|
|
)
|
|
# update deleted model id in using tasks
|
|
Task._get_collection().update_many(
|
|
filter={"_id": {"$in": [t.id for t in using_tasks]}},
|
|
update={"$set": {"models.input.$[elem].model": deleted_model_id}},
|
|
array_filters=[{"elem.model": model_id}],
|
|
upsert=False,
|
|
)
|
|
|
|
if model.task:
|
|
task = Task.objects(id=model.task).first()
|
|
if task:
|
|
now = datetime.utcnow()
|
|
if task.status == TaskStatus.published:
|
|
if not force:
|
|
raise errors.bad_request.ModelCreatingTaskExists(
|
|
"and published, use force=True to delete", task=model.task
|
|
)
|
|
Task._get_collection().update_one(
|
|
filter={"_id": model.task, "models.output.model": model_id},
|
|
update={
|
|
"$set": {
|
|
"models.output.$[elem].model": deleted_model_id,
|
|
"output.error": f"model deleted on {now.isoformat()}",
|
|
"last_change": now,
|
|
"last_changed_by": user_id,
|
|
},
|
|
},
|
|
array_filters=[{"elem.model": model_id}],
|
|
upsert=False,
|
|
)
|
|
else:
|
|
task.update(
|
|
pull__models__output__model=model_id,
|
|
set__last_change=now,
|
|
set__last_changed_by=user_id,
|
|
)
|
|
|
|
del_count = Model.objects(id=model_id, company=company_id).delete()
|
|
return del_count, model
|
|
|
|
@classmethod
|
|
def archive_model(cls, model_id: str, company_id: str, user_id: str):
|
|
cls.get_company_model_by_id(
|
|
company_id=company_id, model_id=model_id, only_fields=("id",)
|
|
)
|
|
now = datetime.utcnow()
|
|
archived = Model.objects(company=company_id, id=model_id).update(
|
|
add_to_set__system_tags=EntityVisibility.archived.value,
|
|
last_change=now,
|
|
last_changed_by=user_id,
|
|
)
|
|
|
|
return archived
|
|
|
|
@classmethod
|
|
def unarchive_model(cls, model_id: str, company_id: str, user_id: str):
|
|
cls.get_company_model_by_id(
|
|
company_id=company_id, model_id=model_id, only_fields=("id",)
|
|
)
|
|
now = datetime.utcnow()
|
|
unarchived = Model.objects(company=company_id, id=model_id).update(
|
|
pull__system_tags=EntityVisibility.archived.value,
|
|
last_change=now,
|
|
last_changed_by=user_id,
|
|
)
|
|
|
|
return unarchived
|
|
|
|
@classmethod
|
|
def get_model_stats(
|
|
cls, company: str, model_ids: Sequence[str],
|
|
) -> Dict[str, dict]:
|
|
if not model_ids:
|
|
return {}
|
|
|
|
result = Model.aggregate(
|
|
[
|
|
{
|
|
"$match": {
|
|
"company": {"$in": ["", company]},
|
|
"_id": {"$in": model_ids},
|
|
}
|
|
},
|
|
{
|
|
"$addFields": {
|
|
"labels_count": {"$size": {"$objectToArray": "$labels"}}
|
|
}
|
|
},
|
|
{"$project": {"labels_count": 1}},
|
|
]
|
|
)
|
|
return {r.pop("_id"): r for r in result}
|
|
|
|
@staticmethod
|
|
def update_statistics(
|
|
company_id: str,
|
|
user_id: str,
|
|
model_id: str,
|
|
last_update: datetime = None,
|
|
last_iteration_max: int = None,
|
|
last_scalar_events: Dict[str, Dict[str, dict]] = None,
|
|
):
|
|
last_update = last_update or datetime.utcnow()
|
|
updates = {
|
|
"last_update": datetime.utcnow(),
|
|
"last_change": last_update,
|
|
"last_changed_by": user_id,
|
|
}
|
|
if last_iteration_max is not None:
|
|
updates.update(max__last_iteration=last_iteration_max)
|
|
|
|
raw_updates = {}
|
|
if last_scalar_events is not None:
|
|
raw_updates = {}
|
|
if last_scalar_events is not None:
|
|
get_last_metric_updates(
|
|
task_id=model_id,
|
|
last_scalar_events=last_scalar_events,
|
|
raw_updates=raw_updates,
|
|
extra_updates=updates,
|
|
model_events=True,
|
|
)
|
|
|
|
ret = Model.objects(id=model_id).update_one(**updates)
|
|
if ret and raw_updates:
|
|
Model.objects(id=model_id).update_one(__raw__=[{"$set": raw_updates}])
|
|
|
|
return ret
|