Add multi-models support

This commit is contained in:
allegroai
2021-05-03 17:46:00 +03:00
parent 3c5195028e
commit ef42d0265d
23 changed files with 690 additions and 113 deletions

View File

@@ -1,5 +1,5 @@
import itertools
from collections import defaultdict
from itertools import chain
from operator import attrgetter
from typing import Sequence, Generic, Callable, Type, Iterable, TypeVar, List, Set
@@ -131,7 +131,7 @@ def collect_debug_image_urls(company: str, task: str) -> Set[str]:
metric_urls.discard(None)
urls[metric].update(metric_urls)
return set(itertools.chain.from_iterable(urls.values()))
return set(chain.from_iterable(urls.values()))
def cleanup_task(
@@ -198,7 +198,7 @@ def cleanup_task(
)
def verify_task_children_and_ouptuts(task, force: bool) -> TaskOutputs[Model]:
def verify_task_children_and_ouptuts(task: Task, force: bool) -> TaskOutputs[Model]:
if not force:
with TimingContext("mongo", "count_published_children"):
published_children_count = Task.objects(
@@ -224,16 +224,16 @@ def verify_task_children_and_ouptuts(task, force: bool) -> TaskOutputs[Model]:
models=len(models.published),
)
if task.output.model:
if task.models.output:
with TimingContext("mongo", "get_task_output_model"):
output_model = Model.objects(id=task.output.model).first()
if output_model:
model_ids = [m.model for m in task.models.output]
for output_model in Model.objects(id__in=model_ids):
if output_model.ready:
if not force:
raise errors.bad_request.TaskCannotBeDeleted(
"has output model, use force=True",
task=task.id,
model=task.output.model,
model=output_model.id,
)
models.published.append(output_model)
else:
@@ -242,10 +242,13 @@ def verify_task_children_and_ouptuts(task, force: bool) -> TaskOutputs[Model]:
if models.draft:
with TimingContext("mongo", "get_execution_models"):
model_ids = models.draft.ids
dependent_tasks = Task.objects(execution__model__in=model_ids).only(
"id", "execution.model"
dependent_tasks = Task.objects(models__input__model__in=model_ids).only(
"id", "models__input"
)
input_models = [t.execution.model for t in dependent_tasks]
input_models = {
m.model
for m in chain.from_iterable(t.models.input for t in dependent_tasks)
}
if input_models:
models.draft = DocumentGroup(
Model, (m for m in models.draft if m.id not in input_models)