mirror of
https://github.com/clearml/clearml-server
synced 2025-06-26 23:15:47 +00:00
Add multi-models support
This commit is contained in:
parent
3c5195028e
commit
ef42d0265d
@ -107,6 +107,11 @@ class GetTypesRequest(models.Base):
|
||||
projects = ListField(items_types=[str])
|
||||
|
||||
|
||||
class TaskInputModel(models.Base):
|
||||
name = StringField()
|
||||
model = StringField()
|
||||
|
||||
|
||||
class CloneRequest(TaskRequest):
|
||||
new_task_name = StringField()
|
||||
new_task_comment = StringField()
|
||||
@ -116,6 +121,7 @@ class CloneRequest(TaskRequest):
|
||||
new_task_project = StringField()
|
||||
new_task_hyperparams = DictField()
|
||||
new_task_configuration = DictField()
|
||||
new_task_input_models = ListField([TaskInputModel])
|
||||
execution_overrides = DictField()
|
||||
validate_references = BoolField(default=False)
|
||||
new_project_name = StringField()
|
||||
@ -224,3 +230,26 @@ class ArchiveRequest(MultiTaskRequest):
|
||||
|
||||
class ArchiveResponse(models.Base):
|
||||
archived = IntField()
|
||||
|
||||
|
||||
class ModelItemType(object):
|
||||
input = "input"
|
||||
output = "output"
|
||||
|
||||
|
||||
class AddUpdateModelRequest(TaskRequest):
|
||||
name = StringField(required=True)
|
||||
model = StringField(required=True)
|
||||
type = StringField(required=True, validators=Enum(*get_options(ModelItemType)))
|
||||
iteration = IntField()
|
||||
|
||||
|
||||
class ModelItemKey(models.Base):
|
||||
name = StringField(required=True)
|
||||
type = StringField(required=True, validators=Enum(*get_options(ModelItemType)))
|
||||
|
||||
|
||||
class DeleteModelsRequest(TaskRequest):
|
||||
models: Sequence[ModelItemKey] = ListField(
|
||||
[ModelItemKey], validators=Length(minimum_value=1)
|
||||
)
|
||||
|
@ -1,4 +1,3 @@
|
||||
from datetime import datetime
|
||||
from typing import Tuple, Set, Sequence
|
||||
|
||||
import attr
|
||||
@ -125,20 +124,29 @@ def _delete_models(projects: Sequence[str]) -> Tuple[int, Set[str]]:
|
||||
if not models:
|
||||
return 0, set()
|
||||
|
||||
model_ids = {m.id for m in models}
|
||||
Task.objects(execution__model__in=model_ids, project__nin=projects).update(
|
||||
execution__model=None
|
||||
model_ids = list({m.id for m in models})
|
||||
|
||||
Task._get_collection().update_many(
|
||||
filter={
|
||||
"project": {"$nin": projects},
|
||||
"models.input.model": {"$in": model_ids},
|
||||
},
|
||||
update={"$set": {"models.input.$[elem].model": None}},
|
||||
array_filters=[{"elem.model": {"$in": model_ids}}],
|
||||
upsert=False,
|
||||
)
|
||||
|
||||
model_tasks = {m.task for m in models if m.task}
|
||||
model_tasks = list({m.task for m in models if m.task})
|
||||
if model_tasks:
|
||||
now = datetime.utcnow()
|
||||
Task.objects(
|
||||
id__in=model_tasks, project__nin=projects, output__model__in=model_ids
|
||||
).update(
|
||||
output__model=None,
|
||||
output__error=f"model deleted on {now.isoformat()}",
|
||||
last_change=now,
|
||||
Task._get_collection().update_many(
|
||||
filter={
|
||||
"_id": {"$in": model_tasks},
|
||||
"project": {"$nin": projects},
|
||||
"models.output.model": {"$in": model_ids},
|
||||
},
|
||||
update={"$set": {"models.output.$[elem].model": None}},
|
||||
array_filters=[{"elem.model": {"$in": model_ids}}],
|
||||
upsert=False,
|
||||
)
|
||||
|
||||
urls = {m.uri for m in models if m.uri}
|
||||
|
@ -26,6 +26,8 @@ from apiserver.database.model.task.task import (
|
||||
TaskStatusMessage,
|
||||
TaskSystemTags,
|
||||
ArtifactModes,
|
||||
ModelItem,
|
||||
Models,
|
||||
)
|
||||
from apiserver.database.model import EntityVisibility
|
||||
from apiserver.database.utils import get_company_or_none_constraint, id as create_id
|
||||
@ -43,6 +45,7 @@ from .utils import (
|
||||
update_project_time,
|
||||
deleted_prefix,
|
||||
)
|
||||
from ...apimodels.tasks import TaskInputModel
|
||||
|
||||
log = config.logger(__file__)
|
||||
org_bll = OrgBLL()
|
||||
@ -145,19 +148,20 @@ class TaskBLL:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def validate_execution_model(task, allow_only_public=False):
|
||||
if not task.execution or not task.execution.model:
|
||||
def validate_input_models(task, allow_only_public=False):
|
||||
if not task.models.input:
|
||||
return
|
||||
|
||||
company = None if allow_only_public else task.company
|
||||
model_id = task.execution.model
|
||||
model = Model.objects(
|
||||
Q(id=model_id) & get_company_or_none_constraint(company)
|
||||
).first()
|
||||
if not model:
|
||||
raise errors.bad_request.InvalidModelId(model=model_id)
|
||||
model_ids = set(m.model for m in task.models.input)
|
||||
models = Model.objects(
|
||||
Q(id__in=model_ids) & get_company_or_none_constraint(company)
|
||||
).only("id")
|
||||
missing = model_ids - {m.id for m in models}
|
||||
if missing:
|
||||
raise errors.bad_request.InvalidModelId(models=missing)
|
||||
|
||||
return model
|
||||
return
|
||||
|
||||
@classmethod
|
||||
def clone_task(
|
||||
@ -174,6 +178,7 @@ class TaskBLL:
|
||||
hyperparams: Optional[dict] = None,
|
||||
configuration: Optional[dict] = None,
|
||||
execution_overrides: Optional[dict] = None,
|
||||
input_models: Optional[Sequence[TaskInputModel]] = None,
|
||||
validate_references: bool = False,
|
||||
new_project_name: str = None,
|
||||
) -> Tuple[Task, dict]:
|
||||
@ -189,10 +194,16 @@ class TaskBLL:
|
||||
|
||||
task = cls.get_by_id(company_id=company_id, task_id=task_id, allow_public=True)
|
||||
|
||||
now = datetime.utcnow()
|
||||
if input_models:
|
||||
input_models = [ModelItem(model=m.model, name=m.name) for m in input_models]
|
||||
|
||||
execution_dict = task.execution.to_proper_dict() if task.execution else {}
|
||||
execution_model_overriden = False
|
||||
if execution_overrides:
|
||||
execution_model_overriden = execution_overrides.get("model") is not None
|
||||
execution_model = execution_overrides.pop("model", None)
|
||||
if not input_models and execution_model:
|
||||
input_models = [ModelItem(model=execution_model, name="input")]
|
||||
|
||||
artifacts_prepare_for_save({"execution": execution_overrides})
|
||||
|
||||
params_dict["execution"] = {}
|
||||
@ -225,8 +236,6 @@ class TaskBLL:
|
||||
)
|
||||
new_project_data = {"id": project, "name": new_project_name}
|
||||
|
||||
now = datetime.utcnow()
|
||||
|
||||
def clean_system_tags(input_tags: Sequence[str]) -> Sequence[str]:
|
||||
if not input_tags:
|
||||
return input_tags
|
||||
@ -262,13 +271,14 @@ class TaskBLL:
|
||||
output=Output(destination=task.output.destination)
|
||||
if task.output
|
||||
else None,
|
||||
models=Models(input=input_models or task.models.input),
|
||||
execution=execution_dict,
|
||||
configuration=params_dict.get("configuration") or task.configuration,
|
||||
hyperparams=params_dict.get("hyperparams") or task.hyperparams,
|
||||
)
|
||||
cls.validate(
|
||||
new_task,
|
||||
validate_model=validate_references or execution_model_overriden,
|
||||
validate_models=validate_references or input_models,
|
||||
validate_parent=validate_references or parent,
|
||||
validate_project=validate_references or project,
|
||||
)
|
||||
@ -295,7 +305,7 @@ class TaskBLL:
|
||||
def validate(
|
||||
cls,
|
||||
task: Task,
|
||||
validate_model=True,
|
||||
validate_models=True,
|
||||
validate_parent=True,
|
||||
validate_project=True,
|
||||
):
|
||||
@ -318,8 +328,8 @@ class TaskBLL:
|
||||
if validate_project and not project:
|
||||
raise errors.bad_request.InvalidProjectId(id=task.project)
|
||||
|
||||
if validate_model:
|
||||
cls.validate_execution_model(task)
|
||||
if validate_models:
|
||||
cls.validate_input_models(task)
|
||||
|
||||
@staticmethod
|
||||
def get_unique_metric_variants(
|
||||
@ -379,6 +389,7 @@ class TaskBLL:
|
||||
tasks = Task.objects(id__in=task_ids, company=company_id).only(
|
||||
"status", "started"
|
||||
)
|
||||
count = 0
|
||||
for task in tasks:
|
||||
updates = extra_updates
|
||||
if task.status == TaskStatus.in_progress and task.started:
|
||||
@ -388,12 +399,13 @@ class TaskBLL:
|
||||
).total_seconds(),
|
||||
**extra_updates,
|
||||
}
|
||||
Task.objects(id=task.id, company=company_id).update(
|
||||
count += Task.objects(id=task.id, company=company_id).update(
|
||||
upsert=False,
|
||||
last_update=last_update,
|
||||
last_change=last_update,
|
||||
**updates,
|
||||
)
|
||||
return count
|
||||
|
||||
@staticmethod
|
||||
def update_statistics(
|
||||
@ -456,7 +468,7 @@ class TaskBLL:
|
||||
}
|
||||
extra_updates["metric_stats"] = metric_stats
|
||||
|
||||
TaskBLL.set_last_update(
|
||||
return TaskBLL.set_last_update(
|
||||
task_ids=[task_id],
|
||||
company_id=company_id,
|
||||
last_update=last_update,
|
||||
@ -524,17 +536,11 @@ class TaskBLL:
|
||||
task.save()
|
||||
|
||||
# publish task models
|
||||
if task.output.model and publish_model:
|
||||
output_model = (
|
||||
Model.objects(id=task.output.model)
|
||||
.only("id", "task", "ready")
|
||||
.first()
|
||||
)
|
||||
if output_model and not output_model.ready:
|
||||
if task.models.output and publish_model:
|
||||
model_ids = [m.model for m in task.models.output]
|
||||
for model in Model.objects(id__in=model_ids, ready__ne=True).only("id"):
|
||||
cls.model_set_ready(
|
||||
model_id=task.output.model,
|
||||
company_id=company_id,
|
||||
publish_task=False,
|
||||
model_id=model.id, company_id=company_id, publish_task=False,
|
||||
)
|
||||
|
||||
# set task status to published, and update (or set) its new output (view and models)
|
||||
|
@ -1,5 +1,5 @@
|
||||
import itertools
|
||||
from collections import defaultdict
|
||||
from itertools import chain
|
||||
from operator import attrgetter
|
||||
from typing import Sequence, Generic, Callable, Type, Iterable, TypeVar, List, Set
|
||||
|
||||
@ -131,7 +131,7 @@ def collect_debug_image_urls(company: str, task: str) -> Set[str]:
|
||||
metric_urls.discard(None)
|
||||
urls[metric].update(metric_urls)
|
||||
|
||||
return set(itertools.chain.from_iterable(urls.values()))
|
||||
return set(chain.from_iterable(urls.values()))
|
||||
|
||||
|
||||
def cleanup_task(
|
||||
@ -198,7 +198,7 @@ def cleanup_task(
|
||||
)
|
||||
|
||||
|
||||
def verify_task_children_and_ouptuts(task, force: bool) -> TaskOutputs[Model]:
|
||||
def verify_task_children_and_ouptuts(task: Task, force: bool) -> TaskOutputs[Model]:
|
||||
if not force:
|
||||
with TimingContext("mongo", "count_published_children"):
|
||||
published_children_count = Task.objects(
|
||||
@ -224,16 +224,16 @@ def verify_task_children_and_ouptuts(task, force: bool) -> TaskOutputs[Model]:
|
||||
models=len(models.published),
|
||||
)
|
||||
|
||||
if task.output.model:
|
||||
if task.models.output:
|
||||
with TimingContext("mongo", "get_task_output_model"):
|
||||
output_model = Model.objects(id=task.output.model).first()
|
||||
if output_model:
|
||||
model_ids = [m.model for m in task.models.output]
|
||||
for output_model in Model.objects(id__in=model_ids):
|
||||
if output_model.ready:
|
||||
if not force:
|
||||
raise errors.bad_request.TaskCannotBeDeleted(
|
||||
"has output model, use force=True",
|
||||
task=task.id,
|
||||
model=task.output.model,
|
||||
model=output_model.id,
|
||||
)
|
||||
models.published.append(output_model)
|
||||
else:
|
||||
@ -242,10 +242,13 @@ def verify_task_children_and_ouptuts(task, force: bool) -> TaskOutputs[Model]:
|
||||
if models.draft:
|
||||
with TimingContext("mongo", "get_execution_models"):
|
||||
model_ids = models.draft.ids
|
||||
dependent_tasks = Task.objects(execution__model__in=model_ids).only(
|
||||
"id", "execution.model"
|
||||
dependent_tasks = Task.objects(models__input__model__in=model_ids).only(
|
||||
"id", "models__input"
|
||||
)
|
||||
input_models = [t.execution.model for t in dependent_tasks]
|
||||
input_models = {
|
||||
m.model
|
||||
for m in chain.from_iterable(t.models.input for t in dependent_tasks)
|
||||
}
|
||||
if input_models:
|
||||
models.draft = DocumentGroup(
|
||||
Model, (m for m in models.draft if m.id not in input_models)
|
||||
|
@ -11,6 +11,5 @@ class Result(object):
|
||||
|
||||
class Output(EmbeddedDocument):
|
||||
destination = StrippedStringField()
|
||||
model = StringField(reference_field='Model')
|
||||
error = StringField(user_set_allowed=True)
|
||||
result = StringField(choices=get_options(Result))
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Dict
|
||||
from typing import Dict, Sequence
|
||||
|
||||
from mongoengine import (
|
||||
StringField,
|
||||
@ -17,6 +17,7 @@ from apiserver.database.fields import (
|
||||
SafeDictField,
|
||||
UnionField,
|
||||
SafeSortedListField,
|
||||
EmbeddedDocumentListField,
|
||||
)
|
||||
from apiserver.database.model import AttributedDocument
|
||||
from apiserver.database.model.base import ProperDictMixin, GetMixin
|
||||
@ -105,11 +106,21 @@ class ConfigurationItem(EmbeddedDocument, ProperDictMixin):
|
||||
description = StringField()
|
||||
|
||||
|
||||
class ModelItem(EmbeddedDocument, ProperDictMixin):
|
||||
name = StringField(required=True)
|
||||
model = StringField(required=True, reference_field="Model")
|
||||
updated = DateTimeField()
|
||||
|
||||
|
||||
class Models(EmbeddedDocument, ProperDictMixin):
|
||||
input: Sequence[ModelItem] = EmbeddedDocumentListField(ModelItem, default=list)
|
||||
output: Sequence[ModelItem] = EmbeddedDocumentListField(ModelItem, default=list)
|
||||
|
||||
|
||||
class Execution(EmbeddedDocument, ProperDictMixin):
|
||||
meta = {"strict": strict}
|
||||
test_split = IntField(default=0)
|
||||
parameters = SafeDictField(default=dict)
|
||||
model = StringField(reference_field="Model")
|
||||
model_desc = SafeMapField(StringField(default=""))
|
||||
model_labels = ModelLabels()
|
||||
framework = StringField()
|
||||
@ -155,7 +166,7 @@ class Task(AttributedDocument):
|
||||
"active_duration",
|
||||
"parent",
|
||||
"project",
|
||||
"execution.model",
|
||||
"models.input.model",
|
||||
("company", "name"),
|
||||
("company", "user"),
|
||||
("company", "status", "type"),
|
||||
@ -169,8 +180,8 @@ class Task(AttributedDocument):
|
||||
"$name",
|
||||
"$id",
|
||||
"$comment",
|
||||
"$execution.model",
|
||||
"$output.model",
|
||||
"$models.input.model",
|
||||
"$models.output.model",
|
||||
"$script.repository",
|
||||
"$script.entry_point",
|
||||
],
|
||||
@ -179,8 +190,8 @@ class Task(AttributedDocument):
|
||||
"name": 10,
|
||||
"id": 10,
|
||||
"comment": 10,
|
||||
"execution.model": 2,
|
||||
"output.model": 2,
|
||||
"models.output.model": 2,
|
||||
"models.input.model": 2,
|
||||
"script.repository": 1,
|
||||
"script.entry_point": 1,
|
||||
},
|
||||
@ -238,6 +249,7 @@ class Task(AttributedDocument):
|
||||
hyperparams = SafeMapField(field=SafeMapField(EmbeddedDocumentField(ParamsItem)))
|
||||
configuration = SafeMapField(field=EmbeddedDocumentField(ConfigurationItem))
|
||||
runtime = SafeDictField(default=dict)
|
||||
models: Models = EmbeddedDocumentField(Models, default=Models)
|
||||
docker_init_script = StringField()
|
||||
|
||||
def get_index_company(self) -> str:
|
||||
|
@ -10,13 +10,14 @@ from apiserver.database import utils
|
||||
from apiserver.database import Database
|
||||
from apiserver.database.model.version import Version as DatabaseVersion
|
||||
|
||||
migration_dir = Path(__file__).resolve().parent.with_name("migrations")
|
||||
_migrations = "migrations"
|
||||
_parent_dir = Path(__file__).resolve().parents[1]
|
||||
_migration_dir = _parent_dir / _migrations
|
||||
|
||||
|
||||
def check_mongo_empty() -> bool:
|
||||
return not all(
|
||||
get_db(alias).collection_names()
|
||||
for alias in utils.get_options(Database)
|
||||
get_db(alias).collection_names() for alias in utils.get_options(Database)
|
||||
)
|
||||
|
||||
|
||||
@ -41,8 +42,8 @@ def _apply_migrations(log: Logger):
|
||||
|
||||
log.info(f"Started mongodb migrations")
|
||||
|
||||
if not migration_dir.is_dir():
|
||||
raise ValueError(f"Invalid migration dir {migration_dir}")
|
||||
if not _migration_dir.is_dir():
|
||||
raise ValueError(f"Invalid migration dir {_migration_dir}")
|
||||
|
||||
empty_dbs = check_mongo_empty()
|
||||
last_version = get_last_server_version()
|
||||
@ -50,7 +51,10 @@ def _apply_migrations(log: Logger):
|
||||
try:
|
||||
new_scripts = {
|
||||
ver: path
|
||||
for ver, path in ((parse(f.stem), f) for f in migration_dir.glob("*.py"))
|
||||
for ver, path in (
|
||||
(parse(f.stem.replace("_", ".")), f)
|
||||
for f in _migration_dir.glob("*.py")
|
||||
)
|
||||
if ver > last_version
|
||||
}
|
||||
except ValueError as ex:
|
||||
@ -64,7 +68,9 @@ def _apply_migrations(log: Logger):
|
||||
if empty_dbs:
|
||||
log.info(f"Skipping migration {script.name} (empty databases)")
|
||||
else:
|
||||
spec = importlib.util.spec_from_file_location(script.stem, str(script))
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
".".join((_parent_dir.name, _migrations, script.stem)), str(script)
|
||||
)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
@ -83,7 +89,7 @@ def _apply_migrations(log: Logger):
|
||||
|
||||
DatabaseVersion(
|
||||
id=utils.id(),
|
||||
num=script.stem,
|
||||
num=str(script_version),
|
||||
created=datetime.utcnow(),
|
||||
desc="Applied on server startup",
|
||||
).save()
|
||||
|
@ -49,7 +49,7 @@ from apiserver.database.model.task.task import Task, ArtifactModes, TaskStatus
|
||||
from apiserver.database.utils import get_options
|
||||
from apiserver.tools import safe_get
|
||||
from apiserver.utilities import json
|
||||
from apiserver.utilities.dicts import nested_get, nested_set
|
||||
from apiserver.utilities.dicts import nested_get, nested_set, nested_delete
|
||||
|
||||
|
||||
class PrePopulate:
|
||||
@ -437,7 +437,9 @@ class PrePopulate:
|
||||
if not orphans:
|
||||
return
|
||||
|
||||
print(f"ERROR: the following projects are exported without their parents: {orphans}")
|
||||
print(
|
||||
f"ERROR: the following projects are exported without their parents: {orphans}"
|
||||
)
|
||||
exit(1)
|
||||
|
||||
@classmethod
|
||||
@ -483,12 +485,13 @@ class PrePopulate:
|
||||
|
||||
cls._check_projects_hierarchy(entities[cls.project_cls])
|
||||
|
||||
model_ids = {
|
||||
model_id
|
||||
task_models = chain.from_iterable(
|
||||
models
|
||||
for task in entities[cls.task_cls]
|
||||
for model_id in (task.output.model, task.execution.model)
|
||||
if model_id
|
||||
}
|
||||
for models in (task.models.input, task.models.output)
|
||||
if models
|
||||
)
|
||||
model_ids = {tm.model for tm in task_models}
|
||||
if model_ids:
|
||||
print("Reading models...")
|
||||
entities[cls.model_cls] = set(cls.model_cls.objects(id__in=list(model_ids)))
|
||||
@ -780,7 +783,32 @@ class PrePopulate:
|
||||
artifacts_path,
|
||||
value={get_artifact_id(a): a for a in artifacts},
|
||||
)
|
||||
item = json.dumps(task_data)
|
||||
|
||||
models = task_data.get("models", {})
|
||||
now = datetime.utcnow()
|
||||
for old_field, type_ in (
|
||||
("execution.model", "input"),
|
||||
("output.model", "output"),
|
||||
):
|
||||
old_path = old_field.split(".")
|
||||
old_model = nested_get(task_data, old_path)
|
||||
new_models = models.get(type_, [])
|
||||
if old_model and not any(
|
||||
m
|
||||
for m in new_models
|
||||
if m.get("model") == old_model or m.get("name") == type_
|
||||
):
|
||||
model_item = {"model": old_model, "name": type_, "updated": now}
|
||||
if type_ == "input":
|
||||
new_models = [model_item, *new_models]
|
||||
else:
|
||||
new_models = [*new_models, model_item]
|
||||
models[type_] = new_models
|
||||
nested_delete(task_data, old_path)
|
||||
task_data["models"] = models
|
||||
|
||||
item = json.dumps(task_data)
|
||||
print(item)
|
||||
|
||||
doc = cls_.from_json(item, created=True)
|
||||
if hasattr(doc, "user"):
|
||||
|
@ -1,15 +1,6 @@
|
||||
from collections import Collection
|
||||
from typing import Sequence
|
||||
from pymongo.database import Database
|
||||
|
||||
from pymongo.database import Database, Collection
|
||||
|
||||
|
||||
def _drop_all_indices_from_collections(db: Database, names: Sequence[str]):
|
||||
for collection_name in db.list_collection_names():
|
||||
if collection_name not in names:
|
||||
continue
|
||||
collection: Collection = db[collection_name]
|
||||
collection.drop_indexes()
|
||||
from .utils import _drop_all_indices_from_collections
|
||||
|
||||
|
||||
def migrate_auth(db: Database):
|
80
apiserver/mongo/migrations/0_18_0.py
Normal file
80
apiserver/mongo/migrations/0_18_0.py
Normal file
@ -0,0 +1,80 @@
|
||||
from datetime import datetime
|
||||
|
||||
from pymongo.collection import Collection
|
||||
from pymongo.database import Database
|
||||
|
||||
from apiserver.utilities.dicts import nested_get
|
||||
from .utils import _drop_all_indices_from_collections
|
||||
|
||||
|
||||
def migrate_backend(db: Database):
    """
    Migrate task documents to the multi-model schema.

    The migration runs in three steps:

    1. Collect the task output models from the models collection: every model
       document that has a ``task`` field is grouped by that task, and the
       resulting list is written to the task's ``models.output`` field, but
       only when that field is currently missing, ``None`` or empty.
    2. Move the legacy ``execution.model`` and ``output.model`` task fields
       into the new ``models.input`` and ``models.output`` lists and unset the
       legacy fields.
    3. Drop the indices of all collections whose name starts with ``task`` to
       accommodate the change in schema.

    :param db: the pymongo database holding the ``task`` and ``model``
        collections
    """
    tasks: Collection = db["task"]
    models: Collection = db["model"]

    models_field = "models"
    # Named *_mode so that the builtin ``input`` is not shadowed
    input_mode = "input"
    output_mode = "output"
    now = datetime.utcnow()

    # Step 1: group the models by the task that they reference
    pipeline = [
        {"$match": {"task": {"$exists": True}}},
        {"$project": {"name": 1, "task": 1}},
        {"$group": {"_id": "$task", "models": {"$push": "$$ROOT"}}},
    ]
    output_models = f"{models_field}.{output_mode}"
    for group in models.aggregate(pipeline=pipeline, allowDiskUse=True):
        task_id = group.get("_id")
        task_models = group.get("models")
        # Check the aggregated list and not the ``models`` collection object:
        # a pymongo Collection must not be used in a boolean context
        if task_id and task_models:
            task_models = [
                {"model": m["_id"], "name": m.get("name", m["_id"]), "updated": now}
                for m in task_models
            ]
            tasks.update_one(
                # do not override output models that are already set on the task
                {"_id": task_id, output_models: {"$in": [None, []]}},
                {"$set": {output_models: task_models}},
                upsert=False,
            )

    # Step 2: move the legacy single-model fields into the new lists
    fields = {input_mode: "execution.model", output_mode: "output.model"}
    query = {
        "$or": [
            {field: {"$exists": True, "$nin": [None, ""]}} for field in fields.values()
        ]
    }
    for doc in tasks.find(filter=query, projection=[*fields.values(), models_field]):
        set_commands = {}
        for mode, field in fields.items():
            value = nested_get(doc, field.split("."))
            if not value:
                continue

            # The item is named after the model; fall back to the mode name
            # when the model document (or its name) is not found
            model_doc = models.find_one(filter={"_id": value}, projection=["name"])
            name = model_doc.get("name", mode) if model_doc else mode
            model_item = {"model": value, "name": name, "updated": now}

            # ``or []`` guards against a stored ``None`` in place of the list
            existing_models = nested_get(doc, (models_field, mode), default=[]) or []
            existing_models = (
                m
                for m in existing_models
                if m.get("name") != name and m.get("model") != value
            )
            # The legacy input model goes first in the list,
            # the legacy output model goes last
            if mode == input_mode:
                updated_models = [model_item, *existing_models]
            else:
                updated_models = [*existing_models, model_item]
            set_commands[f"{models_field}.{mode}"] = updated_models

        tasks.update_one(
            {"_id": doc["_id"]},
            {
                "$unset": {field: 1 for field in fields.values()},
                **({"$set": set_commands} if set_commands else {}),
            },
        )

    # Step 3: drop the indices of the task collections
    _drop_all_indices_from_collections(db, ["task*"])
|
20
apiserver/mongo/migrations/utils.py
Normal file
20
apiserver/mongo/migrations/utils.py
Normal file
@ -0,0 +1,20 @@
|
||||
from typing import Sequence
|
||||
|
||||
from boltons.iterutils import partition
|
||||
from pymongo.database import Database, Collection
|
||||
|
||||
|
||||
def _drop_all_indices_from_collections(db: Database, names: Sequence[str]):
    """
    Drop all indices for the existing collections from the specified list

    :param db: the database whose collections should be processed
    :param names: collection names. A name ending with ``*`` is treated as a
        prefix and matches every collection whose name starts with the text
        preceding the asterisk(s); any other name must match the collection
        name exactly. A bare ``*`` therefore matches every collection.
    """
    exact_names = set()
    prefix_list = []
    for name in names:
        if name.endswith("*"):
            prefix_list.append(name.rstrip("*"))
        else:
            exact_names.add(name)
    # str.startswith accepts a tuple of alternatives (an empty tuple never matches)
    prefixes = tuple(prefix_list)

    for collection_name in db.list_collection_names():
        if collection_name in exact_names or collection_name.startswith(prefixes):
            collection: Collection = db[collection_name]
            collection.drop_indexes()
|
@ -40,6 +40,24 @@ _definitions {
|
||||
}
|
||||
}
|
||||
}
|
||||
model_type_enum {
|
||||
type: string
|
||||
enum: ["input", "output"]
|
||||
}
|
||||
task_model_item {
|
||||
type: object
|
||||
required: [ name, model]
|
||||
properties {
|
||||
name {
|
||||
description: "The task model name"
|
||||
type: string
|
||||
}
|
||||
model {
|
||||
description: "The model ID"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
script {
|
||||
type: object
|
||||
properties {
|
||||
@ -207,6 +225,22 @@ _definitions {
|
||||
}
|
||||
}
|
||||
}
|
||||
task_models {
|
||||
type: object
|
||||
properties {
|
||||
input {
|
||||
description: "The list of task input models"
|
||||
type: array
|
||||
items {"$ref": "#/definitions/task_model_item"}
|
||||
|
||||
}
|
||||
output {
|
||||
description: "The list of task output models"
|
||||
type: array
|
||||
items {"$ref": "#/definitions/task_model_item"}
|
||||
}
|
||||
}
|
||||
}
|
||||
execution {
|
||||
type: object
|
||||
properties {
|
||||
@ -454,6 +488,10 @@ _definitions {
|
||||
description: "Task execution params"
|
||||
"$ref": "#/definitions/execution"
|
||||
}
|
||||
models {
|
||||
description: "Task models"
|
||||
"$ref": "#/definitions/task_models"
|
||||
}
|
||||
// TODO: will be removed
|
||||
script {
|
||||
description: "Script info"
|
||||
@ -833,6 +871,101 @@ clone {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.13": ${clone."2.12"}{
|
||||
request {
|
||||
properties {
|
||||
new_task_input_models {
|
||||
description: "The list of input models for the cloned task. If not specified then copied from the original task"
|
||||
type: array
|
||||
items {"$ref": "#/definitions/task_model_item"}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
add_or_update_model {
|
||||
"2.13" {
|
||||
description: "Add or update task model"
|
||||
request {
|
||||
type: object
|
||||
required: [task, name, model, type]
|
||||
properties {
|
||||
task {
|
||||
description: "ID of the task"
|
||||
type: string
|
||||
}
|
||||
name {
|
||||
description: "The task model name"
|
||||
type: string
|
||||
}
|
||||
model {
|
||||
description: "The model ID"
|
||||
type: string
|
||||
}
|
||||
type {
|
||||
description: "The task model type"
|
||||
"$ref": "#/definitions/model_type_enum"
|
||||
}
|
||||
iteration {
|
||||
description: "Iteration (used to update task statistics)"
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
updated {
|
||||
description: "Number of tasks updated (0 or 1)"
|
||||
type: integer
|
||||
enum: [0, 1]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
delete_models {
|
||||
"2.13" {
|
||||
description: "Delete models from task"
|
||||
request {
|
||||
type: object
|
||||
required: [ task, models ]
|
||||
properties {
|
||||
task {
|
||||
description: "ID of the task"
|
||||
type: string
|
||||
}
|
||||
models {
|
||||
description: "The list of models to delete"
|
||||
type: array
|
||||
items {
|
||||
type: object
|
||||
required: [name, type]
|
||||
properties {
|
||||
name {
|
||||
description: "The task model name"
|
||||
type: string
|
||||
}
|
||||
type {
|
||||
description: "The task model type"
|
||||
"$ref": "#/definitions/model_type_enum"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
updated {
|
||||
description: "Number of tasks updated (0 or 1)"
|
||||
type: integer
|
||||
enum: [0, 1]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
create {
|
||||
"2.1" {
|
||||
@ -912,6 +1045,16 @@ create {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.13": ${create."2.1"} {
|
||||
request {
|
||||
properties {
|
||||
models {
|
||||
description: "Task models"
|
||||
"$ref": "#/definitions/task_models"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
validate {
|
||||
"2.1" {
|
||||
@ -986,6 +1129,16 @@ validate {
|
||||
additionalProperties: false
|
||||
}
|
||||
}
|
||||
"2.13": ${validate."2.1"} {
|
||||
request {
|
||||
properties {
|
||||
models {
|
||||
description: "Task models"
|
||||
"$ref": "#/definitions/task_models"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
update {
|
||||
"2.1" {
|
||||
@ -1161,6 +1314,16 @@ edit {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.13": ${edit."2.1"} {
|
||||
request {
|
||||
properties {
|
||||
models {
|
||||
description: "Task models"
|
||||
"$ref": "#/definitions/task_models"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
reset {
|
||||
"2.1" {
|
||||
|
@ -25,14 +25,14 @@ from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model import validate_id
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.project import Project
|
||||
from apiserver.database.model.task.task import Task, TaskStatus
|
||||
from apiserver.database.model.task.task import Task, TaskStatus, ModelItem
|
||||
from apiserver.database.utils import (
|
||||
parse_from_call,
|
||||
get_company_or_none_constraint,
|
||||
filter_fields,
|
||||
)
|
||||
from apiserver.service_repo import APICall, endpoint
|
||||
from apiserver.services.utils import conform_tag_fields, conform_output_tags
|
||||
from apiserver.services.utils import conform_tag_fields, conform_output_tags, ModelsBackwardsCompatibility
|
||||
from apiserver.timing_context import TimingContext
|
||||
|
||||
log = config.logger(__file__)
|
||||
@ -61,19 +61,20 @@ def get_by_id(call: APICall, company_id, _):
|
||||
|
||||
@endpoint("models.get_by_task_id", required_fields=["task"])
|
||||
def get_by_task_id(call: APICall, company_id, _):
|
||||
if call.requested_endpoint_version > ModelsBackwardsCompatibility.max_version:
|
||||
raise errors.moved_permanently.NotSupported("use models.get_by_id/get_all apis")
|
||||
|
||||
task_id = call.data["task"]
|
||||
|
||||
with translate_errors_context():
|
||||
query = dict(id=task_id, company=company_id)
|
||||
task = Task.get(_only=["output"], **query)
|
||||
task = Task.get(_only=["models"], **query)
|
||||
if not task:
|
||||
raise errors.bad_request.InvalidTaskId(**query)
|
||||
if not task.output:
|
||||
raise errors.bad_request.MissingTaskFields(field="output")
|
||||
if not task.output.model:
|
||||
raise errors.bad_request.MissingTaskFields(field="output.model")
|
||||
if not task.models.output:
|
||||
raise errors.bad_request.MissingTaskFields(field="models.output")
|
||||
|
||||
model_id = task.output.model
|
||||
model_id = task.models.output[-1].model
|
||||
model = Model.objects(
|
||||
Q(id=model_id) & get_company_or_none_constraint(company_id)
|
||||
).first()
|
||||
@ -186,6 +187,9 @@ def _reset_cached_tags(company: str, projects: Sequence[str]):
|
||||
|
||||
@endpoint("models.update_for_task", required_fields=["task"])
|
||||
def update_for_task(call: APICall, company_id, _):
|
||||
if call.requested_endpoint_version > ModelsBackwardsCompatibility.max_version:
|
||||
raise errors.moved_permanently.NotSupported("use tasks.add_or_update_model")
|
||||
|
||||
task_id = call.data["task"]
|
||||
uri = call.data.get("uri")
|
||||
iteration = call.data.get("iteration")
|
||||
@ -201,7 +205,7 @@ def update_for_task(call: APICall, company_id, _):
|
||||
task = Task.get_for_writing(
|
||||
id=task_id,
|
||||
company=company_id,
|
||||
_only=["output", "execution", "name", "status", "project"],
|
||||
_only=["models", "execution", "name", "status", "project"],
|
||||
)
|
||||
if not task:
|
||||
raise errors.bad_request.InvalidTaskId(**query)
|
||||
@ -226,12 +230,11 @@ def update_for_task(call: APICall, company_id, _):
|
||||
if "comment" not in call.data:
|
||||
call.data["comment"] = f"Created by task `{task.name}` ({task.id})"
|
||||
|
||||
if task.output and task.output.model:
|
||||
if task.models.output:
|
||||
# model exists, update
|
||||
res = _update_model(
|
||||
call, company_id, model_id=task.output.model
|
||||
).to_struct()
|
||||
res.update({"id": task.output.model, "created": False})
|
||||
model_id = task.models.output[-1].model
|
||||
res = _update_model(call, company_id, model_id=model_id).to_struct()
|
||||
res.update({"id": model_id, "created": False})
|
||||
call.result.data = res
|
||||
return
|
||||
|
||||
@ -246,7 +249,7 @@ def update_for_task(call: APICall, company_id, _):
|
||||
company=company_id,
|
||||
project=task.project,
|
||||
framework=task.execution.framework,
|
||||
parent=task.execution.model,
|
||||
parent=task.models.input[0].model if task.models.input else None,
|
||||
design=task.execution.model_desc,
|
||||
labels=task.execution.model_labels,
|
||||
ready=(task.status == TaskStatus.published),
|
||||
@ -259,7 +262,9 @@ def update_for_task(call: APICall, company_id, _):
|
||||
task_id=task_id,
|
||||
company_id=company_id,
|
||||
last_iteration_max=iteration,
|
||||
output__model=model.id,
|
||||
models__output=[
|
||||
ModelItem(model=model.id, name=model.name, updated=datetime.utcnow())
|
||||
],
|
||||
)
|
||||
|
||||
call.result.data = {"id": model.id, "created": True}
|
||||
@ -465,7 +470,7 @@ def delete(call: APICall, company_id, request: DeleteModelRequest):
|
||||
|
||||
deleted_model_id = f"{deleted_prefix}{model_id}"
|
||||
|
||||
using_tasks = Task.objects(execution__model=model_id).only("id")
|
||||
using_tasks = Task.objects(models__input__model=model_id).only("id")
|
||||
if using_tasks:
|
||||
if not force:
|
||||
raise errors.bad_request.ModelInUse(
|
||||
@ -473,23 +478,32 @@ def delete(call: APICall, company_id, request: DeleteModelRequest):
|
||||
num_tasks=len(using_tasks),
|
||||
)
|
||||
# update deleted model id in using tasks
|
||||
using_tasks.update(
|
||||
execution__model=deleted_model_id, upsert=False, multi=True
|
||||
Task._get_collection().update_many(
|
||||
filter={"_id": {"$in": [t.id for t in using_tasks]}},
|
||||
update={"$set": {"models.input.$[elem].model": deleted_model_id}},
|
||||
array_filters=[{"elem.model": model_id}],
|
||||
upsert=False,
|
||||
)
|
||||
|
||||
if model.task:
|
||||
task = Task.objects(id=model.task).first()
|
||||
task: Task = Task.objects(id=model.task).first()
|
||||
if task and task.status == TaskStatus.published:
|
||||
if not force:
|
||||
raise errors.bad_request.ModelCreatingTaskExists(
|
||||
"and published, use force=True to delete", task=model.task
|
||||
)
|
||||
if task.output and task.output.model == model_id:
|
||||
if task.models.output and model_id in task.models.output:
|
||||
now = datetime.utcnow()
|
||||
task.update(
|
||||
output__model=deleted_model_id,
|
||||
output__error=f"model deleted on {now.isoformat()}",
|
||||
last_change=now,
|
||||
Task._get_collection().update_one(
|
||||
filter={"_id": model.task, "models.output.model": model_id},
|
||||
update={
|
||||
"$set": {
|
||||
"models.output.$[elem].model": deleted_model_id,
|
||||
"output.error": f"model deleted on {now.isoformat()}",
|
||||
},
|
||||
"last_change": now,
|
||||
},
|
||||
array_filters=[{"elem.model": model_id}],
|
||||
upsert=False,
|
||||
)
|
||||
|
||||
|
@ -44,6 +44,9 @@ from apiserver.apimodels.tasks import (
|
||||
DeleteArtifactsRequest,
|
||||
ArchiveResponse,
|
||||
ArchiveRequest,
|
||||
AddUpdateModelRequest,
|
||||
DeleteModelsRequest,
|
||||
ModelItemType,
|
||||
)
|
||||
from apiserver.bll.event import EventBLL
|
||||
from apiserver.bll.organization import OrgBLL, Tags
|
||||
@ -79,12 +82,14 @@ from apiserver.database.model.task.task import (
|
||||
DEFAULT_LAST_ITERATION,
|
||||
Execution,
|
||||
ArtifactModes,
|
||||
ModelItem,
|
||||
)
|
||||
from apiserver.database.utils import get_fields_attr, parse_from_call
|
||||
from apiserver.database.utils import get_fields_attr, parse_from_call, get_options
|
||||
from apiserver.service_repo import APICall, endpoint
|
||||
from apiserver.services.utils import (
|
||||
conform_tag_fields,
|
||||
conform_output_tags,
|
||||
ModelsBackwardsCompatibility,
|
||||
)
|
||||
from apiserver.timing_context import TimingContext
|
||||
from apiserver.utilities.partial_version import PartialVersion
|
||||
@ -329,6 +334,7 @@ create_fields = {
|
||||
"parent": Task,
|
||||
"project": None,
|
||||
"input": None,
|
||||
"models": None,
|
||||
"output_dest": None,
|
||||
"execution": None,
|
||||
"hyperparams": None,
|
||||
@ -341,6 +347,7 @@ def prepare_for_save(call: APICall, fields: dict, previous_task: Task = None):
|
||||
conform_tag_fields(call, fields, validate=True)
|
||||
params_prepare_for_save(fields, previous_task=previous_task)
|
||||
artifacts_prepare_for_save(fields)
|
||||
ModelsBackwardsCompatibility.prepare_for_save(call, fields)
|
||||
|
||||
# Strip all script fields (remove leading and trailing whitespace chars) to avoid unusable names and paths
|
||||
for field in task_script_stripped_fields:
|
||||
@ -361,6 +368,7 @@ def unprepare_from_saved(call: APICall, tasks_data: Union[Sequence[dict], dict])
|
||||
tasks_data = [tasks_data]
|
||||
|
||||
conform_output_tags(call, tasks_data)
|
||||
ModelsBackwardsCompatibility.unprepare_from_saved(call, tasks_data)
|
||||
|
||||
for data in tasks_data:
|
||||
need_legacy_params = call.requested_endpoint_version < PartialVersion("2.9")
|
||||
@ -389,6 +397,17 @@ def prepare_create_fields(
|
||||
output = Output(destination=output_dest)
|
||||
fields["output"] = output
|
||||
|
||||
# Add models updated time
|
||||
models = fields.get("models")
|
||||
if models:
|
||||
now = datetime.utcnow()
|
||||
for field in ("input", "output"):
|
||||
field_models = models.get(field)
|
||||
if not field_models:
|
||||
continue
|
||||
for model in field_models:
|
||||
model["updated"] = now
|
||||
|
||||
return prepare_for_save(call, fields, previous_task=previous_task)
|
||||
|
||||
|
||||
@ -456,6 +475,7 @@ def clone_task(call: APICall, company_id, request: CloneRequest):
|
||||
hyperparams=request.new_task_hyperparams,
|
||||
configuration=request.new_task_configuration,
|
||||
execution_overrides=request.execution_overrides,
|
||||
input_models=request.new_task_input_models,
|
||||
validate_references=request.validate_references,
|
||||
new_project_name=request.new_project_name,
|
||||
)
|
||||
@ -886,8 +906,8 @@ def reset(call: APICall, company_id, request: ResetRequest):
|
||||
set__last_iteration=DEFAULT_LAST_ITERATION,
|
||||
set__last_metrics={},
|
||||
set__metric_stats={},
|
||||
set__models__output=[],
|
||||
unset__output__result=1,
|
||||
unset__output__model=1,
|
||||
unset__output__error=1,
|
||||
unset__last_worker=1,
|
||||
unset__last_worker_report=1,
|
||||
@ -1130,3 +1150,44 @@ def move(call: APICall, company_id: str, request: MoveRequest):
|
||||
update_project_time(projects)
|
||||
|
||||
return {"project_id": project_id}
|
||||
|
||||
|
||||
@endpoint("tasks.add_or_update_model", min_version="2.13")
def add_or_update_model(_: APICall, company_id: str, request: AddUpdateModelRequest):
    """
    Set a named model entry in a task's ``models.<type>`` list.

    An existing entry with the same name is overwritten in place; if no
    entry was updated, the new item is pushed to the list together with the
    task statistics update.

    :param _: the API call (unused)
    :param company_id: company used for the access check and statistics update
    :param request: task id, model name/id, model type and optional iteration
    :return: dict whose ``updated`` key holds the result of
        ``TaskBLL.update_statistics``
    """
    task_id = request.task

    # Access check only: the returned task object is not needed here
    TaskBLL.get_task_with_access(
        task_id, company_id=company_id, requires_write_access=True, only=["id"]
    )

    target = f"models__{request.type}"
    item = ModelItem(name=request.name, model=request.model, updated=datetime.utcnow())

    # Try to replace the matching list element (positional "S" operator)
    replaced = Task.objects(
        **{"id": task_id, f"{target}__name": request.name}
    ).update_one(**{f"set__{target}__S": item})

    # Nothing replaced: append the item as part of the statistics update
    extra_updates = {} if replaced else {f"push__{target}": item}
    updated = TaskBLL.update_statistics(
        task_id=task_id,
        company_id=company_id,
        last_iteration_max=request.iteration,
        **extra_updates,
    )

    return {"updated": updated}
|
||||
|
||||
|
||||
@endpoint("tasks.delete_models", min_version="2.13")
def delete_models(_: APICall, company_id: str, request: DeleteModelsRequest):
    """
    Remove named entries from a task's ``models`` lists.

    The requested (name, type) keys are grouped by model type and each
    non-empty group becomes one ``pull`` command on ``models.<type>``.

    :param _: the API call (unused)
    :param company_id: company used for the access check
    :param request: task id and the list of model keys to remove
    :return: dict whose ``updated`` key holds the result of ``task.update``
    """
    task = TaskBLL.get_task_with_access(
        request.task, company_id=company_id, requires_write_access=True, only=["id"]
    )

    # One pull command per model type that has at least one requested name
    commands = {}
    for model_type in get_options(ModelItemType):
        names = [item.name for item in request.models if item.type == model_type]
        if names:
            commands[f"pull__models__{model_type}__name__in"] = names

    updated = task.update(last_change=datetime.utcnow(), **commands)
    return {"updated": updated}
|
||||
|
@ -1,3 +1,4 @@
|
||||
from datetime import datetime
|
||||
from typing import Union, Sequence, Tuple
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
@ -5,6 +6,7 @@ from apiserver.apimodels.organization import Filter
|
||||
from apiserver.database.model.base import GetMixin
|
||||
from apiserver.database.utils import partition_tags
|
||||
from apiserver.service_repo import APICall
|
||||
from apiserver.utilities.dicts import nested_set, nested_get, nested_delete
|
||||
from apiserver.utilities.partial_version import PartialVersion
|
||||
|
||||
|
||||
@ -84,3 +86,47 @@ def validate_tags(tags: Sequence[str], system_tags: Sequence[str]):
|
||||
raise errors.bad_request.FieldsValueError(
|
||||
"unsupported tag prefix", values=unsupported
|
||||
)
|
||||
|
||||
|
||||
class ModelsBackwardsCompatibility:
    """
    Converts between the legacy single-model task fields
    (``execution.model`` / ``output.model``) and the ``models.input`` /
    ``models.output`` lists, for calls whose requested endpoint version
    does not exceed ``max_version``.
    """

    # Calls requesting a version above this one are left untouched
    max_version = PartialVersion("2.13")
    # models list name -> path of the matching legacy field
    mode_to_fields = {"input": ("execution", "model"), "output": ("output", "model")}
    models_field = "models"

    @classmethod
    def prepare_for_save(cls, call: APICall, fields: dict):
        """
        Move legacy model ids found in ``fields`` into the models lists.

        Each non-empty legacy value is deleted from ``fields`` and stored as
        a single-item list under ``models.<mode>``, named after the mode and
        stamped with the current UTC time. ``fields`` is modified in place.
        """
        if call.requested_endpoint_version > cls.max_version:
            return

        for mode, legacy_path in cls.mode_to_fields.items():
            model_id = nested_get(fields, legacy_path)
            if model_id:
                nested_delete(fields, legacy_path)
                entry = {"name": mode, "model": model_id, "updated": datetime.utcnow()}
                nested_set(fields, (cls.models_field, mode), value=[entry])

    @classmethod
    def unprepare_from_saved(
        cls, call: APICall, tasks_data: Union[Sequence[dict], dict]
    ):
        """
        Fill the legacy model fields of returned tasks from the models lists.

        The first input model and the last output model are written to the
        legacy fields. ``tasks_data`` may be one task dict or a sequence of
        them; the dicts are modified in place.
        """
        if call.requested_endpoint_version > cls.max_version:
            return

        tasks = [tasks_data] if isinstance(tasks_data, dict) else tasks_data

        for task in tasks:
            for mode, legacy_path in cls.mode_to_fields.items():
                entries = nested_get(task, (cls.models_field, mode))
                if not entries:
                    continue

                # legacy input is the first entry, legacy output the last one
                if mode == "input":
                    chosen = entries[0]
                else:
                    chosen = entries[-1]

                if chosen:
                    nested_set(task, legacy_path, chosen.get("model"))
|
||||
|
111
apiserver/tests/automated/test_task_models.py
Normal file
111
apiserver/tests/automated/test_task_models.py
Normal file
@ -0,0 +1,111 @@
|
||||
from copy import deepcopy
|
||||
from typing import Sequence, Optional
|
||||
|
||||
from apiserver.tests.automated import TestService
|
||||
|
||||
|
||||
class TestTaskModels(TestService):
    """Tests for the task ``models`` field and the model add/update/delete APIs."""

    def setUp(self, version="2.13"):
        super().setUp(version=version)

    def test_new_apis(self):
        # no models
        self.assertModels(self.new_task(), [], [])

        first_id = self.new_model("model1")
        second_id = self.new_model("model2")
        inputs = [
            {"name": "input1", "model": first_id},
            {"name": "input2", "model": second_id},
        ]
        outputs = [
            {"name": "output1", "model": "id3"},
            {"name": "output2", "model": "id4"},
        ]

        # task creation with models
        task = self.new_task(models={"input": inputs, "output": outputs})
        self.assertModels(task, inputs, outputs)

        # add_or_update existing model
        res = self.api.tasks.add_or_update_model(
            task=task, name="input1", type="input", model="Test"
        )
        self.assertEqual(res.updated, 1)
        changed_inputs = deepcopy(inputs)
        changed_inputs[0]["model"] = "Test"
        self.assertModels(task, changed_inputs, outputs)

        # add_or_update new model
        res = self.api.tasks.add_or_update_model(
            task=task, name="output3", type="output", model="TestOutput"
        )
        self.assertEqual(res.updated, 1)
        changed_outputs = deepcopy(outputs)
        changed_outputs.append({"name": "output3", "model": "TestOutput"})
        self.assertModels(task, changed_inputs, changed_outputs)

        # task editing restores the original lists
        self.api.tasks.edit(task=task, models={"input": inputs, "output": outputs})
        self.assertModels(task, inputs, outputs)

        # delete models
        keys_to_delete = [
            {"name": "input1", "type": "input"},
            {"name": "input2", "type": "input"},
            {"name": "output1", "type": "output"},
            {"name": "not_existing", "type": "output"},
        ]
        res = self.api.tasks.delete_models(task=task, models=keys_to_delete)
        self.assertEqual(res.updated, 1)
        self.assertModels(task, [], outputs[1:])

    @staticmethod
    def _model_id(container: dict) -> Optional[str]:
        """Return the model id held under ``container["model"]``, or None.

        The value may be a plain id string or a mapping with an ``id`` key.
        """
        if not container:
            return None
        ref = container.get("model")
        if isinstance(ref, str):
            return ref
        if ref is None or ref == {}:
            return None
        return ref.get("id")

    def _compare_models(self, actual: Sequence[dict], expected: Sequence[dict]):
        """Assert that both lists hold the same (name, model id) pairs in order."""
        actual_pairs = [(m["name"], self._model_id(m)) for m in actual]
        expected_pairs = [(m["name"], m["model"]) for m in expected]
        self.assertEqual(actual_pairs, expected_pairs)

    def assertModels(
        self, task_id: str, input_models: Sequence[dict], output_models: Sequence[dict],
    ):
        """Check the task's models lists and legacy model fields.

        The task is fetched through get_all_ex, get_all and get_by_id, and
        each result is checked against the expected input/output models.
        """
        expected_legacy_input = input_models[0]["model"] if input_models else None
        expected_legacy_output = output_models[-1]["model"] if output_models else None

        # all three fetches are issued before any assertion runs
        fetched = (
            self.api.tasks.get_all_ex(id=task_id).tasks[0],
            self.api.tasks.get_all(id=task_id).tasks[0],
            self.api.tasks.get_by_id(task=task_id).task,
        )
        for task in fetched:
            self._compare_models(task.models.input, input_models)
            self._compare_models(task.models.output, output_models)
            # legacy fields mirror the first input and the last output model
            self.assertEqual(self._model_id(task.execution), expected_legacy_input)
            self.assertEqual(self._model_id(task.output), expected_legacy_output)

    def new_task(self, **kwargs):
        """Create a temporary task, filling in type, name and input if missing."""
        self.update_missing(
            kwargs, type="testing", name="test task models", input={"view": {}}
        )
        return self.create_temp("tasks", **kwargs)

    def new_model(self, name: str, **kwargs):
        """Create a temporary model with the given name."""
        return self.create_temp(
            "models", uri="file://test", name=name, labels={}, **kwargs
        )
|
Loading…
Reference in New Issue
Block a user