Add multi-models support

This commit is contained in:
allegroai
2021-05-03 17:46:00 +03:00
parent 3c5195028e
commit ef42d0265d
23 changed files with 690 additions and 113 deletions

View File

@@ -26,6 +26,8 @@ from apiserver.database.model.task.task import (
TaskStatusMessage,
TaskSystemTags,
ArtifactModes,
ModelItem,
Models,
)
from apiserver.database.model import EntityVisibility
from apiserver.database.utils import get_company_or_none_constraint, id as create_id
@@ -43,6 +45,7 @@ from .utils import (
update_project_time,
deleted_prefix,
)
from ...apimodels.tasks import TaskInputModel
log = config.logger(__file__)
org_bll = OrgBLL()
@@ -145,19 +148,20 @@ class TaskBLL:
)
@staticmethod
def validate_execution_model(task, allow_only_public=False):
if not task.execution or not task.execution.model:
def validate_input_models(task, allow_only_public=False):
if not task.models.input:
return
company = None if allow_only_public else task.company
model_id = task.execution.model
model = Model.objects(
Q(id=model_id) & get_company_or_none_constraint(company)
).first()
if not model:
raise errors.bad_request.InvalidModelId(model=model_id)
model_ids = set(m.model for m in task.models.input)
models = Model.objects(
Q(id__in=model_ids) & get_company_or_none_constraint(company)
).only("id")
missing = model_ids - {m.id for m in models}
if missing:
raise errors.bad_request.InvalidModelId(models=missing)
return model
return
@classmethod
def clone_task(
@@ -174,6 +178,7 @@ class TaskBLL:
hyperparams: Optional[dict] = None,
configuration: Optional[dict] = None,
execution_overrides: Optional[dict] = None,
input_models: Optional[Sequence[TaskInputModel]] = None,
validate_references: bool = False,
new_project_name: str = None,
) -> Tuple[Task, dict]:
@@ -189,10 +194,16 @@ class TaskBLL:
task = cls.get_by_id(company_id=company_id, task_id=task_id, allow_public=True)
now = datetime.utcnow()
if input_models:
input_models = [ModelItem(model=m.model, name=m.name) for m in input_models]
execution_dict = task.execution.to_proper_dict() if task.execution else {}
execution_model_overriden = False
if execution_overrides:
execution_model_overriden = execution_overrides.get("model") is not None
execution_model = execution_overrides.pop("model", None)
if not input_models and execution_model:
input_models = [ModelItem(model=execution_model, name="input")]
artifacts_prepare_for_save({"execution": execution_overrides})
params_dict["execution"] = {}
@@ -225,8 +236,6 @@ class TaskBLL:
)
new_project_data = {"id": project, "name": new_project_name}
now = datetime.utcnow()
def clean_system_tags(input_tags: Sequence[str]) -> Sequence[str]:
if not input_tags:
return input_tags
@@ -262,13 +271,14 @@ class TaskBLL:
output=Output(destination=task.output.destination)
if task.output
else None,
models=Models(input=input_models or task.models.input),
execution=execution_dict,
configuration=params_dict.get("configuration") or task.configuration,
hyperparams=params_dict.get("hyperparams") or task.hyperparams,
)
cls.validate(
new_task,
validate_model=validate_references or execution_model_overriden,
validate_models=validate_references or input_models,
validate_parent=validate_references or parent,
validate_project=validate_references or project,
)
@@ -295,7 +305,7 @@ class TaskBLL:
def validate(
cls,
task: Task,
validate_model=True,
validate_models=True,
validate_parent=True,
validate_project=True,
):
@@ -318,8 +328,8 @@ class TaskBLL:
if validate_project and not project:
raise errors.bad_request.InvalidProjectId(id=task.project)
if validate_model:
cls.validate_execution_model(task)
if validate_models:
cls.validate_input_models(task)
@staticmethod
def get_unique_metric_variants(
@@ -379,6 +389,7 @@ class TaskBLL:
tasks = Task.objects(id__in=task_ids, company=company_id).only(
"status", "started"
)
count = 0
for task in tasks:
updates = extra_updates
if task.status == TaskStatus.in_progress and task.started:
@@ -388,12 +399,13 @@ class TaskBLL:
).total_seconds(),
**extra_updates,
}
Task.objects(id=task.id, company=company_id).update(
count += Task.objects(id=task.id, company=company_id).update(
upsert=False,
last_update=last_update,
last_change=last_update,
**updates,
)
return count
@staticmethod
def update_statistics(
@@ -456,7 +468,7 @@ class TaskBLL:
}
extra_updates["metric_stats"] = metric_stats
TaskBLL.set_last_update(
return TaskBLL.set_last_update(
task_ids=[task_id],
company_id=company_id,
last_update=last_update,
@@ -524,17 +536,11 @@ class TaskBLL:
task.save()
# publish task models
if task.output.model and publish_model:
output_model = (
Model.objects(id=task.output.model)
.only("id", "task", "ready")
.first()
)
if output_model and not output_model.ready:
if task.models.output and publish_model:
model_ids = [m.model for m in task.models.output]
for model in Model.objects(id__in=model_ids, ready__ne=True).only("id"):
cls.model_set_ready(
model_id=task.output.model,
company_id=company_id,
publish_task=False,
model_id=model.id, company_id=company_id, publish_task=False,
)
# set task status to published, and update (or set) its new output (view and models)

View File

@@ -1,5 +1,5 @@
import itertools
from collections import defaultdict
from itertools import chain
from operator import attrgetter
from typing import Sequence, Generic, Callable, Type, Iterable, TypeVar, List, Set
@@ -131,7 +131,7 @@ def collect_debug_image_urls(company: str, task: str) -> Set[str]:
metric_urls.discard(None)
urls[metric].update(metric_urls)
return set(itertools.chain.from_iterable(urls.values()))
return set(chain.from_iterable(urls.values()))
def cleanup_task(
@@ -198,7 +198,7 @@ def cleanup_task(
)
def verify_task_children_and_ouptuts(task, force: bool) -> TaskOutputs[Model]:
def verify_task_children_and_ouptuts(task: Task, force: bool) -> TaskOutputs[Model]:
if not force:
with TimingContext("mongo", "count_published_children"):
published_children_count = Task.objects(
@@ -224,16 +224,16 @@ def verify_task_children_and_ouptuts(task, force: bool) -> TaskOutputs[Model]:
models=len(models.published),
)
if task.output.model:
if task.models.output:
with TimingContext("mongo", "get_task_output_model"):
output_model = Model.objects(id=task.output.model).first()
if output_model:
model_ids = [m.model for m in task.models.output]
for output_model in Model.objects(id__in=model_ids):
if output_model.ready:
if not force:
raise errors.bad_request.TaskCannotBeDeleted(
"has output model, use force=True",
task=task.id,
model=task.output.model,
model=output_model.id,
)
models.published.append(output_model)
else:
@@ -242,10 +242,13 @@ def verify_task_children_and_ouptuts(task, force: bool) -> TaskOutputs[Model]:
if models.draft:
with TimingContext("mongo", "get_execution_models"):
model_ids = models.draft.ids
dependent_tasks = Task.objects(execution__model__in=model_ids).only(
"id", "execution.model"
dependent_tasks = Task.objects(models__input__model__in=model_ids).only(
"id", "models__input"
)
input_models = [t.execution.model for t in dependent_tasks]
input_models = {
m.model
for m in chain.from_iterable(t.models.input for t in dependent_tasks)
}
if input_models:
models.draft = DocumentGroup(
Model, (m for m in models.draft if m.id not in input_models)