mirror of
https://github.com/clearml/clearml-server
synced 2025-06-26 23:15:47 +00:00
Add multi-models support
This commit is contained in:
@@ -26,6 +26,8 @@ from apiserver.database.model.task.task import (
|
||||
TaskStatusMessage,
|
||||
TaskSystemTags,
|
||||
ArtifactModes,
|
||||
ModelItem,
|
||||
Models,
|
||||
)
|
||||
from apiserver.database.model import EntityVisibility
|
||||
from apiserver.database.utils import get_company_or_none_constraint, id as create_id
|
||||
@@ -43,6 +45,7 @@ from .utils import (
|
||||
update_project_time,
|
||||
deleted_prefix,
|
||||
)
|
||||
from ...apimodels.tasks import TaskInputModel
|
||||
|
||||
log = config.logger(__file__)
|
||||
org_bll = OrgBLL()
|
||||
@@ -145,19 +148,20 @@ class TaskBLL:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def validate_execution_model(task, allow_only_public=False):
|
||||
if not task.execution or not task.execution.model:
|
||||
def validate_input_models(task, allow_only_public=False):
|
||||
if not task.models.input:
|
||||
return
|
||||
|
||||
company = None if allow_only_public else task.company
|
||||
model_id = task.execution.model
|
||||
model = Model.objects(
|
||||
Q(id=model_id) & get_company_or_none_constraint(company)
|
||||
).first()
|
||||
if not model:
|
||||
raise errors.bad_request.InvalidModelId(model=model_id)
|
||||
model_ids = set(m.model for m in task.models.input)
|
||||
models = Model.objects(
|
||||
Q(id__in=model_ids) & get_company_or_none_constraint(company)
|
||||
).only("id")
|
||||
missing = model_ids - {m.id for m in models}
|
||||
if missing:
|
||||
raise errors.bad_request.InvalidModelId(models=missing)
|
||||
|
||||
return model
|
||||
return
|
||||
|
||||
@classmethod
|
||||
def clone_task(
|
||||
@@ -174,6 +178,7 @@ class TaskBLL:
|
||||
hyperparams: Optional[dict] = None,
|
||||
configuration: Optional[dict] = None,
|
||||
execution_overrides: Optional[dict] = None,
|
||||
input_models: Optional[Sequence[TaskInputModel]] = None,
|
||||
validate_references: bool = False,
|
||||
new_project_name: str = None,
|
||||
) -> Tuple[Task, dict]:
|
||||
@@ -189,10 +194,16 @@ class TaskBLL:
|
||||
|
||||
task = cls.get_by_id(company_id=company_id, task_id=task_id, allow_public=True)
|
||||
|
||||
now = datetime.utcnow()
|
||||
if input_models:
|
||||
input_models = [ModelItem(model=m.model, name=m.name) for m in input_models]
|
||||
|
||||
execution_dict = task.execution.to_proper_dict() if task.execution else {}
|
||||
execution_model_overriden = False
|
||||
if execution_overrides:
|
||||
execution_model_overriden = execution_overrides.get("model") is not None
|
||||
execution_model = execution_overrides.pop("model", None)
|
||||
if not input_models and execution_model:
|
||||
input_models = [ModelItem(model=execution_model, name="input")]
|
||||
|
||||
artifacts_prepare_for_save({"execution": execution_overrides})
|
||||
|
||||
params_dict["execution"] = {}
|
||||
@@ -225,8 +236,6 @@ class TaskBLL:
|
||||
)
|
||||
new_project_data = {"id": project, "name": new_project_name}
|
||||
|
||||
now = datetime.utcnow()
|
||||
|
||||
def clean_system_tags(input_tags: Sequence[str]) -> Sequence[str]:
|
||||
if not input_tags:
|
||||
return input_tags
|
||||
@@ -262,13 +271,14 @@ class TaskBLL:
|
||||
output=Output(destination=task.output.destination)
|
||||
if task.output
|
||||
else None,
|
||||
models=Models(input=input_models or task.models.input),
|
||||
execution=execution_dict,
|
||||
configuration=params_dict.get("configuration") or task.configuration,
|
||||
hyperparams=params_dict.get("hyperparams") or task.hyperparams,
|
||||
)
|
||||
cls.validate(
|
||||
new_task,
|
||||
validate_model=validate_references or execution_model_overriden,
|
||||
validate_models=validate_references or input_models,
|
||||
validate_parent=validate_references or parent,
|
||||
validate_project=validate_references or project,
|
||||
)
|
||||
@@ -295,7 +305,7 @@ class TaskBLL:
|
||||
def validate(
|
||||
cls,
|
||||
task: Task,
|
||||
validate_model=True,
|
||||
validate_models=True,
|
||||
validate_parent=True,
|
||||
validate_project=True,
|
||||
):
|
||||
@@ -318,8 +328,8 @@ class TaskBLL:
|
||||
if validate_project and not project:
|
||||
raise errors.bad_request.InvalidProjectId(id=task.project)
|
||||
|
||||
if validate_model:
|
||||
cls.validate_execution_model(task)
|
||||
if validate_models:
|
||||
cls.validate_input_models(task)
|
||||
|
||||
@staticmethod
|
||||
def get_unique_metric_variants(
|
||||
@@ -379,6 +389,7 @@ class TaskBLL:
|
||||
tasks = Task.objects(id__in=task_ids, company=company_id).only(
|
||||
"status", "started"
|
||||
)
|
||||
count = 0
|
||||
for task in tasks:
|
||||
updates = extra_updates
|
||||
if task.status == TaskStatus.in_progress and task.started:
|
||||
@@ -388,12 +399,13 @@ class TaskBLL:
|
||||
).total_seconds(),
|
||||
**extra_updates,
|
||||
}
|
||||
Task.objects(id=task.id, company=company_id).update(
|
||||
count += Task.objects(id=task.id, company=company_id).update(
|
||||
upsert=False,
|
||||
last_update=last_update,
|
||||
last_change=last_update,
|
||||
**updates,
|
||||
)
|
||||
return count
|
||||
|
||||
@staticmethod
|
||||
def update_statistics(
|
||||
@@ -456,7 +468,7 @@ class TaskBLL:
|
||||
}
|
||||
extra_updates["metric_stats"] = metric_stats
|
||||
|
||||
TaskBLL.set_last_update(
|
||||
return TaskBLL.set_last_update(
|
||||
task_ids=[task_id],
|
||||
company_id=company_id,
|
||||
last_update=last_update,
|
||||
@@ -524,17 +536,11 @@ class TaskBLL:
|
||||
task.save()
|
||||
|
||||
# publish task models
|
||||
if task.output.model and publish_model:
|
||||
output_model = (
|
||||
Model.objects(id=task.output.model)
|
||||
.only("id", "task", "ready")
|
||||
.first()
|
||||
)
|
||||
if output_model and not output_model.ready:
|
||||
if task.models.output and publish_model:
|
||||
model_ids = [m.model for m in task.models.output]
|
||||
for model in Model.objects(id__in=model_ids, ready__ne=True).only("id"):
|
||||
cls.model_set_ready(
|
||||
model_id=task.output.model,
|
||||
company_id=company_id,
|
||||
publish_task=False,
|
||||
model_id=model.id, company_id=company_id, publish_task=False,
|
||||
)
|
||||
|
||||
# set task status to published, and update (or set) it's new output (view and models)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import itertools
|
||||
from collections import defaultdict
|
||||
from itertools import chain
|
||||
from operator import attrgetter
|
||||
from typing import Sequence, Generic, Callable, Type, Iterable, TypeVar, List, Set
|
||||
|
||||
@@ -131,7 +131,7 @@ def collect_debug_image_urls(company: str, task: str) -> Set[str]:
|
||||
metric_urls.discard(None)
|
||||
urls[metric].update(metric_urls)
|
||||
|
||||
return set(itertools.chain.from_iterable(urls.values()))
|
||||
return set(chain.from_iterable(urls.values()))
|
||||
|
||||
|
||||
def cleanup_task(
|
||||
@@ -198,7 +198,7 @@ def cleanup_task(
|
||||
)
|
||||
|
||||
|
||||
def verify_task_children_and_ouptuts(task, force: bool) -> TaskOutputs[Model]:
|
||||
def verify_task_children_and_ouptuts(task: Task, force: bool) -> TaskOutputs[Model]:
|
||||
if not force:
|
||||
with TimingContext("mongo", "count_published_children"):
|
||||
published_children_count = Task.objects(
|
||||
@@ -224,16 +224,16 @@ def verify_task_children_and_ouptuts(task, force: bool) -> TaskOutputs[Model]:
|
||||
models=len(models.published),
|
||||
)
|
||||
|
||||
if task.output.model:
|
||||
if task.models.output:
|
||||
with TimingContext("mongo", "get_task_output_model"):
|
||||
output_model = Model.objects(id=task.output.model).first()
|
||||
if output_model:
|
||||
model_ids = [m.model for m in task.models.output]
|
||||
for output_model in Model.objects(id__in=model_ids):
|
||||
if output_model.ready:
|
||||
if not force:
|
||||
raise errors.bad_request.TaskCannotBeDeleted(
|
||||
"has output model, use force=True",
|
||||
task=task.id,
|
||||
model=task.output.model,
|
||||
model=output_model.id,
|
||||
)
|
||||
models.published.append(output_model)
|
||||
else:
|
||||
@@ -242,10 +242,13 @@ def verify_task_children_and_ouptuts(task, force: bool) -> TaskOutputs[Model]:
|
||||
if models.draft:
|
||||
with TimingContext("mongo", "get_execution_models"):
|
||||
model_ids = models.draft.ids
|
||||
dependent_tasks = Task.objects(execution__model__in=model_ids).only(
|
||||
"id", "execution.model"
|
||||
dependent_tasks = Task.objects(models__input__model__in=model_ids).only(
|
||||
"id", "models__input"
|
||||
)
|
||||
input_models = [t.execution.model for t in dependent_tasks]
|
||||
input_models = {
|
||||
m.model
|
||||
for m in chain.from_iterable(t.models.input for t in dependent_tasks)
|
||||
}
|
||||
if input_models:
|
||||
models.draft = DocumentGroup(
|
||||
Model, (m for m in models.draft if m.id not in input_models)
|
||||
|
||||
Reference in New Issue
Block a user