Add multi-models support

This commit is contained in:
allegroai 2021-05-03 17:46:00 +03:00
parent 3c5195028e
commit ef42d0265d
23 changed files with 690 additions and 113 deletions

View File

@ -107,6 +107,11 @@ class GetTypesRequest(models.Base):
projects = ListField(items_types=[str])
# A single input model entry supplied with a clone request:
# the name the model goes by inside the task, and the ID of the model itself.
class TaskInputModel(models.Base):
    name = StringField()
    model = StringField()
class CloneRequest(TaskRequest):
new_task_name = StringField()
new_task_comment = StringField()
@ -116,6 +121,7 @@ class CloneRequest(TaskRequest):
new_task_project = StringField()
new_task_hyperparams = DictField()
new_task_configuration = DictField()
new_task_input_models = ListField([TaskInputModel])
execution_overrides = DictField()
validate_references = BoolField(default=False)
new_project_name = StringField()
@ -224,3 +230,26 @@ class ArchiveRequest(MultiTaskRequest):
class ArchiveResponse(models.Base):
archived = IntField()
# The two task model lists a model item can belong to.
# The values are used to build the "models.<type>" field names in the task services.
class ModelItemType(object):
    input = "input"
    output = "output"
# Request payload for adding a model to a task, or replacing the task model
# that already carries the same name in the chosen list.
class AddUpdateModelRequest(TaskRequest):
    # name of the model inside the task
    name = StringField(required=True)
    # ID of the model
    model = StringField(required=True)
    # which task list the model belongs to; restricted to the ModelItemType options
    type = StringField(required=True, validators=Enum(*get_options(ModelItemType)))
    # optional iteration number, passed on when the task statistics are updated
    iteration = IntField()
# Identifies one task model by its name and by the list (input/output) it lives in.
class ModelItemKey(models.Base):
    name = StringField(required=True)
    # restricted to the ModelItemType options
    type = StringField(required=True, validators=Enum(*get_options(ModelItemType)))
# Request payload for removing models from a task.
class DeleteModelsRequest(TaskRequest):
    # keys of the task models to remove; at least one key is required
    models: Sequence[ModelItemKey] = ListField(
        [ModelItemKey], validators=Length(minimum_value=1)
    )

View File

@ -1,4 +1,3 @@
from datetime import datetime
from typing import Tuple, Set, Sequence
import attr
@ -125,20 +124,29 @@ def _delete_models(projects: Sequence[str]) -> Tuple[int, Set[str]]:
if not models:
return 0, set()
model_ids = {m.id for m in models}
Task.objects(execution__model__in=model_ids, project__nin=projects).update(
execution__model=None
model_ids = list({m.id for m in models})
Task._get_collection().update_many(
filter={
"project": {"$nin": projects},
"models.input.model": {"$in": model_ids},
},
update={"$set": {"models.input.$[elem].model": None}},
array_filters=[{"elem.model": {"$in": model_ids}}],
upsert=False,
)
model_tasks = {m.task for m in models if m.task}
model_tasks = list({m.task for m in models if m.task})
if model_tasks:
now = datetime.utcnow()
Task.objects(
id__in=model_tasks, project__nin=projects, output__model__in=model_ids
).update(
output__model=None,
output__error=f"model deleted on {now.isoformat()}",
last_change=now,
Task._get_collection().update_many(
filter={
"_id": {"$in": model_tasks},
"project": {"$nin": projects},
"models.output.model": {"$in": model_ids},
},
update={"$set": {"models.output.$[elem].model": None}},
array_filters=[{"elem.model": {"$in": model_ids}}],
upsert=False,
)
urls = {m.uri for m in models if m.uri}

View File

@ -26,6 +26,8 @@ from apiserver.database.model.task.task import (
TaskStatusMessage,
TaskSystemTags,
ArtifactModes,
ModelItem,
Models,
)
from apiserver.database.model import EntityVisibility
from apiserver.database.utils import get_company_or_none_constraint, id as create_id
@ -43,6 +45,7 @@ from .utils import (
update_project_time,
deleted_prefix,
)
from ...apimodels.tasks import TaskInputModel
log = config.logger(__file__)
org_bll = OrgBLL()
@ -145,19 +148,20 @@ class TaskBLL:
)
@staticmethod
def validate_execution_model(task, allow_only_public=False):
if not task.execution or not task.execution.model:
def validate_input_models(task, allow_only_public=False):
if not task.models.input:
return
company = None if allow_only_public else task.company
model_id = task.execution.model
model = Model.objects(
Q(id=model_id) & get_company_or_none_constraint(company)
).first()
if not model:
raise errors.bad_request.InvalidModelId(model=model_id)
model_ids = set(m.model for m in task.models.input)
models = Model.objects(
Q(id__in=model_ids) & get_company_or_none_constraint(company)
).only("id")
missing = model_ids - {m.id for m in models}
if missing:
raise errors.bad_request.InvalidModelId(models=missing)
return model
return
@classmethod
def clone_task(
@ -174,6 +178,7 @@ class TaskBLL:
hyperparams: Optional[dict] = None,
configuration: Optional[dict] = None,
execution_overrides: Optional[dict] = None,
input_models: Optional[Sequence[TaskInputModel]] = None,
validate_references: bool = False,
new_project_name: str = None,
) -> Tuple[Task, dict]:
@ -189,10 +194,16 @@ class TaskBLL:
task = cls.get_by_id(company_id=company_id, task_id=task_id, allow_public=True)
now = datetime.utcnow()
if input_models:
input_models = [ModelItem(model=m.model, name=m.name) for m in input_models]
execution_dict = task.execution.to_proper_dict() if task.execution else {}
execution_model_overriden = False
if execution_overrides:
execution_model_overriden = execution_overrides.get("model") is not None
execution_model = execution_overrides.pop("model", None)
if not input_models and execution_model:
input_models = [ModelItem(model=execution_model, name="input")]
artifacts_prepare_for_save({"execution": execution_overrides})
params_dict["execution"] = {}
@ -225,8 +236,6 @@ class TaskBLL:
)
new_project_data = {"id": project, "name": new_project_name}
now = datetime.utcnow()
def clean_system_tags(input_tags: Sequence[str]) -> Sequence[str]:
if not input_tags:
return input_tags
@ -262,13 +271,14 @@ class TaskBLL:
output=Output(destination=task.output.destination)
if task.output
else None,
models=Models(input=input_models or task.models.input),
execution=execution_dict,
configuration=params_dict.get("configuration") or task.configuration,
hyperparams=params_dict.get("hyperparams") or task.hyperparams,
)
cls.validate(
new_task,
validate_model=validate_references or execution_model_overriden,
validate_models=validate_references or input_models,
validate_parent=validate_references or parent,
validate_project=validate_references or project,
)
@ -295,7 +305,7 @@ class TaskBLL:
def validate(
cls,
task: Task,
validate_model=True,
validate_models=True,
validate_parent=True,
validate_project=True,
):
@ -318,8 +328,8 @@ class TaskBLL:
if validate_project and not project:
raise errors.bad_request.InvalidProjectId(id=task.project)
if validate_model:
cls.validate_execution_model(task)
if validate_models:
cls.validate_input_models(task)
@staticmethod
def get_unique_metric_variants(
@ -379,6 +389,7 @@ class TaskBLL:
tasks = Task.objects(id__in=task_ids, company=company_id).only(
"status", "started"
)
count = 0
for task in tasks:
updates = extra_updates
if task.status == TaskStatus.in_progress and task.started:
@ -388,12 +399,13 @@ class TaskBLL:
).total_seconds(),
**extra_updates,
}
Task.objects(id=task.id, company=company_id).update(
count += Task.objects(id=task.id, company=company_id).update(
upsert=False,
last_update=last_update,
last_change=last_update,
**updates,
)
return count
@staticmethod
def update_statistics(
@ -456,7 +468,7 @@ class TaskBLL:
}
extra_updates["metric_stats"] = metric_stats
TaskBLL.set_last_update(
return TaskBLL.set_last_update(
task_ids=[task_id],
company_id=company_id,
last_update=last_update,
@ -524,17 +536,11 @@ class TaskBLL:
task.save()
# publish task models
if task.output.model and publish_model:
output_model = (
Model.objects(id=task.output.model)
.only("id", "task", "ready")
.first()
)
if output_model and not output_model.ready:
if task.models.output and publish_model:
model_ids = [m.model for m in task.models.output]
for model in Model.objects(id__in=model_ids, ready__ne=True).only("id"):
cls.model_set_ready(
model_id=task.output.model,
company_id=company_id,
publish_task=False,
model_id=model.id, company_id=company_id, publish_task=False,
)
# set task status to published, and update (or set) its new output (view and models)

View File

@ -1,5 +1,5 @@
import itertools
from collections import defaultdict
from itertools import chain
from operator import attrgetter
from typing import Sequence, Generic, Callable, Type, Iterable, TypeVar, List, Set
@ -131,7 +131,7 @@ def collect_debug_image_urls(company: str, task: str) -> Set[str]:
metric_urls.discard(None)
urls[metric].update(metric_urls)
return set(itertools.chain.from_iterable(urls.values()))
return set(chain.from_iterable(urls.values()))
def cleanup_task(
@ -198,7 +198,7 @@ def cleanup_task(
)
def verify_task_children_and_ouptuts(task, force: bool) -> TaskOutputs[Model]:
def verify_task_children_and_ouptuts(task: Task, force: bool) -> TaskOutputs[Model]:
if not force:
with TimingContext("mongo", "count_published_children"):
published_children_count = Task.objects(
@ -224,16 +224,16 @@ def verify_task_children_and_ouptuts(task, force: bool) -> TaskOutputs[Model]:
models=len(models.published),
)
if task.output.model:
if task.models.output:
with TimingContext("mongo", "get_task_output_model"):
output_model = Model.objects(id=task.output.model).first()
if output_model:
model_ids = [m.model for m in task.models.output]
for output_model in Model.objects(id__in=model_ids):
if output_model.ready:
if not force:
raise errors.bad_request.TaskCannotBeDeleted(
"has output model, use force=True",
task=task.id,
model=task.output.model,
model=output_model.id,
)
models.published.append(output_model)
else:
@ -242,10 +242,13 @@ def verify_task_children_and_ouptuts(task, force: bool) -> TaskOutputs[Model]:
if models.draft:
with TimingContext("mongo", "get_execution_models"):
model_ids = models.draft.ids
dependent_tasks = Task.objects(execution__model__in=model_ids).only(
"id", "execution.model"
dependent_tasks = Task.objects(models__input__model__in=model_ids).only(
"id", "models__input"
)
input_models = [t.execution.model for t in dependent_tasks]
input_models = {
m.model
for m in chain.from_iterable(t.models.input for t in dependent_tasks)
}
if input_models:
models.draft = DocumentGroup(
Model, (m for m in models.draft if m.id not in input_models)

View File

@ -11,6 +11,5 @@ class Result(object):
class Output(EmbeddedDocument):
destination = StrippedStringField()
model = StringField(reference_field='Model')
error = StringField(user_set_allowed=True)
result = StringField(choices=get_options(Result))

View File

@ -1,4 +1,4 @@
from typing import Dict
from typing import Dict, Sequence
from mongoengine import (
StringField,
@ -17,6 +17,7 @@ from apiserver.database.fields import (
SafeDictField,
UnionField,
SafeSortedListField,
EmbeddedDocumentListField,
)
from apiserver.database.model import AttributedDocument
from apiserver.database.model.base import ProperDictMixin, GetMixin
@ -105,11 +106,21 @@ class ConfigurationItem(EmbeddedDocument, ProperDictMixin):
description = StringField()
# A named reference from a task to a model.
class ModelItem(EmbeddedDocument, ProperDictMixin):
    # name of the model inside the task
    name = StringField(required=True)
    # ID of the referenced Model document
    model = StringField(required=True, reference_field="Model")
    # optional timestamp; callers set it when the item is added or updated
    updated = DateTimeField()
# The model references held by a task, split into input and output lists.
# Both lists default to an empty list.
class Models(EmbeddedDocument, ProperDictMixin):
    input: Sequence[ModelItem] = EmbeddedDocumentListField(ModelItem, default=list)
    output: Sequence[ModelItem] = EmbeddedDocumentListField(ModelItem, default=list)
class Execution(EmbeddedDocument, ProperDictMixin):
meta = {"strict": strict}
test_split = IntField(default=0)
parameters = SafeDictField(default=dict)
model = StringField(reference_field="Model")
model_desc = SafeMapField(StringField(default=""))
model_labels = ModelLabels()
framework = StringField()
@ -155,7 +166,7 @@ class Task(AttributedDocument):
"active_duration",
"parent",
"project",
"execution.model",
"models.input.model",
("company", "name"),
("company", "user"),
("company", "status", "type"),
@ -169,8 +180,8 @@ class Task(AttributedDocument):
"$name",
"$id",
"$comment",
"$execution.model",
"$output.model",
"$models.input.model",
"$models.output.model",
"$script.repository",
"$script.entry_point",
],
@ -179,8 +190,8 @@ class Task(AttributedDocument):
"name": 10,
"id": 10,
"comment": 10,
"execution.model": 2,
"output.model": 2,
"models.output.model": 2,
"models.input.model": 2,
"script.repository": 1,
"script.entry_point": 1,
},
@ -238,6 +249,7 @@ class Task(AttributedDocument):
hyperparams = SafeMapField(field=SafeMapField(EmbeddedDocumentField(ParamsItem)))
configuration = SafeMapField(field=EmbeddedDocumentField(ConfigurationItem))
runtime = SafeDictField(default=dict)
models: Models = EmbeddedDocumentField(Models, default=Models)
docker_init_script = StringField()
def get_index_company(self) -> str:

View File

@ -10,13 +10,14 @@ from apiserver.database import utils
from apiserver.database import Database
from apiserver.database.model.version import Version as DatabaseVersion
migration_dir = Path(__file__).resolve().parent.with_name("migrations")
_migrations = "migrations"
_parent_dir = Path(__file__).resolve().parents[1]
_migration_dir = _parent_dir / _migrations
def check_mongo_empty() -> bool:
return not all(
get_db(alias).collection_names()
for alias in utils.get_options(Database)
get_db(alias).collection_names() for alias in utils.get_options(Database)
)
@ -41,8 +42,8 @@ def _apply_migrations(log: Logger):
log.info(f"Started mongodb migrations")
if not migration_dir.is_dir():
raise ValueError(f"Invalid migration dir {migration_dir}")
if not _migration_dir.is_dir():
raise ValueError(f"Invalid migration dir {_migration_dir}")
empty_dbs = check_mongo_empty()
last_version = get_last_server_version()
@ -50,7 +51,10 @@ def _apply_migrations(log: Logger):
try:
new_scripts = {
ver: path
for ver, path in ((parse(f.stem), f) for f in migration_dir.glob("*.py"))
for ver, path in (
(parse(f.stem.replace("_", ".")), f)
for f in _migration_dir.glob("*.py")
)
if ver > last_version
}
except ValueError as ex:
@ -64,7 +68,9 @@ def _apply_migrations(log: Logger):
if empty_dbs:
log.info(f"Skipping migration {script.name} (empty databases)")
else:
spec = importlib.util.spec_from_file_location(script.stem, str(script))
spec = importlib.util.spec_from_file_location(
".".join((_parent_dir.name, _migrations, script.stem)), str(script)
)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
@ -83,7 +89,7 @@ def _apply_migrations(log: Logger):
DatabaseVersion(
id=utils.id(),
num=script.stem,
num=str(script_version),
created=datetime.utcnow(),
desc="Applied on server startup",
).save()

View File

@ -49,7 +49,7 @@ from apiserver.database.model.task.task import Task, ArtifactModes, TaskStatus
from apiserver.database.utils import get_options
from apiserver.tools import safe_get
from apiserver.utilities import json
from apiserver.utilities.dicts import nested_get, nested_set
from apiserver.utilities.dicts import nested_get, nested_set, nested_delete
class PrePopulate:
@ -437,7 +437,9 @@ class PrePopulate:
if not orphans:
return
print(f"ERROR: the following projects are exported without their parents: {orphans}")
print(
f"ERROR: the following projects are exported without their parents: {orphans}"
)
exit(1)
@classmethod
@ -483,12 +485,13 @@ class PrePopulate:
cls._check_projects_hierarchy(entities[cls.project_cls])
model_ids = {
model_id
task_models = chain.from_iterable(
models
for task in entities[cls.task_cls]
for model_id in (task.output.model, task.execution.model)
if model_id
}
for models in (task.models.input, task.models.output)
if models
)
model_ids = {tm.model for tm in task_models}
if model_ids:
print("Reading models...")
entities[cls.model_cls] = set(cls.model_cls.objects(id__in=list(model_ids)))
@ -780,7 +783,32 @@ class PrePopulate:
artifacts_path,
value={get_artifact_id(a): a for a in artifacts},
)
item = json.dumps(task_data)
models = task_data.get("models", {})
now = datetime.utcnow()
for old_field, type_ in (
("execution.model", "input"),
("output.model", "output"),
):
old_path = old_field.split(".")
old_model = nested_get(task_data, old_path)
new_models = models.get(type_, [])
if old_model and not any(
m
for m in new_models
if m.get("model") == old_model or m.get("name") == type_
):
model_item = {"model": old_model, "name": type_, "updated": now}
if type_ == "input":
new_models = [model_item, *new_models]
else:
new_models = [*new_models, model_item]
models[type_] = new_models
nested_delete(task_data, old_path)
task_data["models"] = models
item = json.dumps(task_data)
print(item)
doc = cls_.from_json(item, created=True)
if hasattr(doc, "user"):

View File

@ -1,15 +1,6 @@
from collections import Collection
from typing import Sequence
from pymongo.database import Database
from pymongo.database import Database, Collection
def _drop_all_indices_from_collections(db: Database, names: Sequence[str]):
for collection_name in db.list_collection_names():
if collection_name not in names:
continue
collection: Collection = db[collection_name]
collection.drop_indexes()
from .utils import _drop_all_indices_from_collections
def migrate_auth(db: Database):

View File

@ -0,0 +1,80 @@
from datetime import datetime
from pymongo.collection import Collection
from pymongo.database import Database
from apiserver.utilities.dicts import nested_get
from .utils import _drop_all_indices_from_collections
def migrate_backend(db: Database):
    """
    Collect the task output models from the models collections
    Move the execution and output models to new models.input and output lists
    Drop the task indices to accommodate the change in schema

    Step 1: for every task referenced by documents in the "model" collection,
    store those models as the task's "models.output" list (only when the task
    has no output list yet).
    Step 2: for every task that still carries the legacy "execution.model" or
    "output.model" field, merge that model into "models.input" / "models.output"
    and unset the legacy fields.
    Step 3: drop the indices of all collections whose name starts with "task".
    """
    tasks: Collection = db["task"]
    models: Collection = db["model"]

    models_field = "models"
    # named *_mode so that the builtin input() is not shadowed
    input_mode = "input"
    output_mode = "output"
    now = datetime.utcnow()

    # Group the models by the task that created them
    pipeline = [
        {"$match": {"task": {"$exists": True}}},
        {"$project": {"name": 1, "task": 1}},
        {"$group": {"_id": "$task", "models": {"$push": "$$ROOT"}}},
    ]
    output_models = f"{models_field}.{output_mode}"
    for group in models.aggregate(pipeline=pipeline, allowDiskUse=True):
        task_id = group.get("_id")
        task_models = group.get("models")
        # Check the grouped models list, not the `models` collection handle
        if task_id and task_models:
            task_models = [
                {"model": m["_id"], "name": m.get("name", m["_id"]), "updated": now}
                for m in task_models
            ]
            # Only fill in the list for tasks that do not have output models yet
            tasks.update_one(
                {"_id": task_id, output_models: {"$in": [None, []]}},
                {"$set": {output_models: task_models}},
                upsert=False,
            )

    # Legacy single-model fields and the list each one is moved to
    fields = {input_mode: "execution.model", output_mode: "output.model"}
    query = {
        "$or": [
            {field: {"$exists": True, "$nin": [None, ""]}} for field in fields.values()
        ]
    }
    for doc in tasks.find(filter=query, projection=[*fields.values(), models_field]):
        set_commands = {}
        for mode, field in fields.items():
            value = nested_get(doc, field.split("."))
            if not value:
                continue

            # Name the item after the model, falling back to the mode name
            model_doc = models.find_one(filter={"_id": value}, projection=["name"])
            name = model_doc.get("name", mode) if model_doc else mode
            model_item = {"model": value, "name": name, "updated": now}

            # Drop existing items that clash with the new one by name or by model ID
            existing_models = nested_get(doc, (models_field, mode), default=[])
            existing_models = (
                m
                for m in existing_models
                if m.get("name") != name and m.get("model") != value
            )
            # The legacy input model goes first, the legacy output model goes last
            if mode == input_mode:
                updated_models = [model_item, *existing_models]
            else:
                updated_models = [*existing_models, model_item]
            set_commands[f"{models_field}.{mode}"] = updated_models

        tasks.update_one(
            {"_id": doc["_id"]},
            {
                "$unset": {field: 1 for field in fields.values()},
                **({"$set": set_commands} if set_commands else {}),
            },
        )

    _drop_all_indices_from_collections(db, ["task*"])

View File

@ -0,0 +1,20 @@
from typing import Sequence
from boltons.iterutils import partition
from pymongo.database import Database, Collection
def _drop_all_indices_from_collections(db: Database, names: Sequence[str]):
    """
    Drop all indices for the existing collections from the specified list

    :param db: the database whose collections are inspected
    :param names: collection names. A name ending with "*" is a prefix pattern
        that selects every collection whose name starts with the text before
        the asterisk(s); a bare "*" therefore selects all collections.
    """
    prefixes, exact_names = partition(names, key=lambda x: x.endswith("*"))
    # str.startswith accepts a tuple of options; an empty tuple matches nothing
    prefixes = tuple({p.rstrip("*") for p in prefixes})
    exact_names = set(exact_names)

    for collection_name in db.list_collection_names():
        # Test the match itself: an empty prefix (from a bare "*") is falsy
        # but must still select the collection
        if not (
            collection_name in exact_names or collection_name.startswith(prefixes)
        ):
            continue
        collection: Collection = db[collection_name]
        collection.drop_indexes()

View File

@ -40,6 +40,24 @@ _definitions {
}
}
}
model_type_enum {
type: string
enum: ["input", "output"]
}
task_model_item {
type: object
required: [ name, model]
properties {
name {
description: "The task model name"
type: string
}
model {
description: "The model ID"
type: string
}
}
}
script {
type: object
properties {
@ -207,6 +225,22 @@ _definitions {
}
}
}
task_models {
type: object
properties {
input {
description: "The list of task input models"
type: array
items {"$ref": "#/definitions/task_model_item"}
}
output {
description: "The list of task output models"
type: array
items {"$ref": "#/definitions/task_model_item"}
}
}
}
execution {
type: object
properties {
@ -454,6 +488,10 @@ _definitions {
description: "Task execution params"
"$ref": "#/definitions/execution"
}
models {
description: "Task models"
"$ref": "#/definitions/task_models"
}
// TODO: will be removed
script {
description: "Script info"
@ -833,6 +871,101 @@ clone {
}
}
}
"2.13": ${clone."2.12"}{
request {
properties {
new_task_input_models {
description: "The list of input models for the cloned task. If not specified, the input models are copied from the original task"
type: array
items {"$ref": "#/definitions/task_model_item"}
}
}
}
}
}
add_or_update_model {
"2.13" {
description: "Add or update task model"
request {
type: object
required: [task, name, model, type]
properties {
task {
description: "ID of the task"
type: string
}
name {
description: "The task model name"
type: string
}
model {
description: "The model ID"
type: string
}
type {
description: "The task model type"
"$ref": "#/definitions/model_type_enum"
}
iteration {
description: "Iteration (used to update task statistics)"
type: integer
}
}
}
response {
type: object
properties {
updated {
description: "Number of tasks updated (0 or 1)"
type: integer
enum: [0, 1]
}
}
}
}
}
delete_models {
"2.13" {
description: "Delete models from task"
request {
type: object
required: [ task, models ]
properties {
task {
description: "ID of the task"
type: string
}
models {
description: "The list of models to delete"
type: array
items {
type: object
required: [name, type]
properties {
name {
description: "The task model name"
type: string
}
type {
description: "The task model type"
"$ref": "#/definitions/model_type_enum"
}
}
}
}
}
}
response {
type: object
properties {
updated {
description: "Number of tasks updated (0 or 1)"
type: integer
enum: [0, 1]
}
}
}
}
}
create {
"2.1" {
@ -912,6 +1045,16 @@ create {
}
}
}
"2.13": ${create."2.1"} {
request {
properties {
models {
description: "Task models"
"$ref": "#/definitions/task_models"
}
}
}
}
}
validate {
"2.1" {
@ -986,6 +1129,16 @@ validate {
additionalProperties: false
}
}
"2.13": ${validate."2.1"} {
request {
properties {
models {
description: "Task models"
"$ref": "#/definitions/task_models"
}
}
}
}
}
update {
"2.1" {
@ -1161,6 +1314,16 @@ edit {
}
}
}
"2.13": ${edit."2.1"} {
request {
properties {
models {
description: "Task models"
"$ref": "#/definitions/task_models"
}
}
}
}
}
reset {
"2.1" {

View File

@ -25,14 +25,14 @@ from apiserver.database.errors import translate_errors_context
from apiserver.database.model import validate_id
from apiserver.database.model.model import Model
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task, TaskStatus
from apiserver.database.model.task.task import Task, TaskStatus, ModelItem
from apiserver.database.utils import (
parse_from_call,
get_company_or_none_constraint,
filter_fields,
)
from apiserver.service_repo import APICall, endpoint
from apiserver.services.utils import conform_tag_fields, conform_output_tags
from apiserver.services.utils import conform_tag_fields, conform_output_tags, ModelsBackwardsCompatibility
from apiserver.timing_context import TimingContext
log = config.logger(__file__)
@ -61,19 +61,20 @@ def get_by_id(call: APICall, company_id, _):
@endpoint("models.get_by_task_id", required_fields=["task"])
def get_by_task_id(call: APICall, company_id, _):
if call.requested_endpoint_version > ModelsBackwardsCompatibility.max_version:
raise errors.moved_permanently.NotSupported("use models.get_by_id/get_all apis")
task_id = call.data["task"]
with translate_errors_context():
query = dict(id=task_id, company=company_id)
task = Task.get(_only=["output"], **query)
task = Task.get(_only=["models"], **query)
if not task:
raise errors.bad_request.InvalidTaskId(**query)
if not task.output:
raise errors.bad_request.MissingTaskFields(field="output")
if not task.output.model:
raise errors.bad_request.MissingTaskFields(field="output.model")
if not task.models.output:
raise errors.bad_request.MissingTaskFields(field="models.output")
model_id = task.output.model
model_id = task.models.output[-1].model
model = Model.objects(
Q(id=model_id) & get_company_or_none_constraint(company_id)
).first()
@ -186,6 +187,9 @@ def _reset_cached_tags(company: str, projects: Sequence[str]):
@endpoint("models.update_for_task", required_fields=["task"])
def update_for_task(call: APICall, company_id, _):
if call.requested_endpoint_version > ModelsBackwardsCompatibility.max_version:
raise errors.moved_permanently.NotSupported("use tasks.add_or_update_model")
task_id = call.data["task"]
uri = call.data.get("uri")
iteration = call.data.get("iteration")
@ -201,7 +205,7 @@ def update_for_task(call: APICall, company_id, _):
task = Task.get_for_writing(
id=task_id,
company=company_id,
_only=["output", "execution", "name", "status", "project"],
_only=["models", "execution", "name", "status", "project"],
)
if not task:
raise errors.bad_request.InvalidTaskId(**query)
@ -226,12 +230,11 @@ def update_for_task(call: APICall, company_id, _):
if "comment" not in call.data:
call.data["comment"] = f"Created by task `{task.name}` ({task.id})"
if task.output and task.output.model:
if task.models.output:
# model exists, update
res = _update_model(
call, company_id, model_id=task.output.model
).to_struct()
res.update({"id": task.output.model, "created": False})
model_id = task.models.output[-1].model
res = _update_model(call, company_id, model_id=model_id).to_struct()
res.update({"id": model_id, "created": False})
call.result.data = res
return
@ -246,7 +249,7 @@ def update_for_task(call: APICall, company_id, _):
company=company_id,
project=task.project,
framework=task.execution.framework,
parent=task.execution.model,
parent=task.models.input[0].model if task.models.input else None,
design=task.execution.model_desc,
labels=task.execution.model_labels,
ready=(task.status == TaskStatus.published),
@ -259,7 +262,9 @@ def update_for_task(call: APICall, company_id, _):
task_id=task_id,
company_id=company_id,
last_iteration_max=iteration,
output__model=model.id,
models__output=[
ModelItem(model=model.id, name=model.name, updated=datetime.utcnow())
],
)
call.result.data = {"id": model.id, "created": True}
@ -465,7 +470,7 @@ def delete(call: APICall, company_id, request: DeleteModelRequest):
deleted_model_id = f"{deleted_prefix}{model_id}"
using_tasks = Task.objects(execution__model=model_id).only("id")
using_tasks = Task.objects(models__input__model=model_id).only("id")
if using_tasks:
if not force:
raise errors.bad_request.ModelInUse(
@ -473,23 +478,32 @@ def delete(call: APICall, company_id, request: DeleteModelRequest):
num_tasks=len(using_tasks),
)
# update deleted model id in using tasks
using_tasks.update(
execution__model=deleted_model_id, upsert=False, multi=True
Task._get_collection().update_many(
filter={"_id": {"$in": [t.id for t in using_tasks]}},
update={"$set": {"models.input.$[elem].model": deleted_model_id}},
array_filters=[{"elem.model": model_id}],
upsert=False,
)
if model.task:
task = Task.objects(id=model.task).first()
task: Task = Task.objects(id=model.task).first()
if task and task.status == TaskStatus.published:
if not force:
raise errors.bad_request.ModelCreatingTaskExists(
"and published, use force=True to delete", task=model.task
)
if task.output and task.output.model == model_id:
if task.models.output and model_id in task.models.output:
now = datetime.utcnow()
task.update(
output__model=deleted_model_id,
output__error=f"model deleted on {now.isoformat()}",
last_change=now,
# Replace the deleted model ID inside the creating task's output models list,
# record the deletion in output.error and bump the task's last_change time
Task._get_collection().update_one(
    filter={"_id": model.task, "models.output.model": model_id},
    update={
        # every modified field has to sit under an update operator;
        # a bare top-level field name is not a valid update document
        "$set": {
            "models.output.$[elem].model": deleted_model_id,
            "output.error": f"model deleted on {now.isoformat()}",
            "last_change": now,
        },
    },
    array_filters=[{"elem.model": model_id}],
    upsert=False,
)

View File

@ -44,6 +44,9 @@ from apiserver.apimodels.tasks import (
DeleteArtifactsRequest,
ArchiveResponse,
ArchiveRequest,
AddUpdateModelRequest,
DeleteModelsRequest,
ModelItemType,
)
from apiserver.bll.event import EventBLL
from apiserver.bll.organization import OrgBLL, Tags
@ -79,12 +82,14 @@ from apiserver.database.model.task.task import (
DEFAULT_LAST_ITERATION,
Execution,
ArtifactModes,
ModelItem,
)
from apiserver.database.utils import get_fields_attr, parse_from_call
from apiserver.database.utils import get_fields_attr, parse_from_call, get_options
from apiserver.service_repo import APICall, endpoint
from apiserver.services.utils import (
conform_tag_fields,
conform_output_tags,
ModelsBackwardsCompatibility,
)
from apiserver.timing_context import TimingContext
from apiserver.utilities.partial_version import PartialVersion
@ -329,6 +334,7 @@ create_fields = {
"parent": Task,
"project": None,
"input": None,
"models": None,
"output_dest": None,
"execution": None,
"hyperparams": None,
@ -341,6 +347,7 @@ def prepare_for_save(call: APICall, fields: dict, previous_task: Task = None):
conform_tag_fields(call, fields, validate=True)
params_prepare_for_save(fields, previous_task=previous_task)
artifacts_prepare_for_save(fields)
ModelsBackwardsCompatibility.prepare_for_save(call, fields)
# Strip all script fields (remove leading and trailing whitespace chars) to avoid unusable names and paths
for field in task_script_stripped_fields:
@ -361,6 +368,7 @@ def unprepare_from_saved(call: APICall, tasks_data: Union[Sequence[dict], dict])
tasks_data = [tasks_data]
conform_output_tags(call, tasks_data)
ModelsBackwardsCompatibility.unprepare_from_saved(call, tasks_data)
for data in tasks_data:
need_legacy_params = call.requested_endpoint_version < PartialVersion("2.9")
@ -389,6 +397,17 @@ def prepare_create_fields(
output = Output(destination=output_dest)
fields["output"] = output
# Add models updated time
models = fields.get("models")
if models:
now = datetime.utcnow()
for field in ("input", "output"):
field_models = models.get(field)
if not field_models:
continue
for model in field_models:
model["updated"] = now
return prepare_for_save(call, fields, previous_task=previous_task)
@ -456,6 +475,7 @@ def clone_task(call: APICall, company_id, request: CloneRequest):
hyperparams=request.new_task_hyperparams,
configuration=request.new_task_configuration,
execution_overrides=request.execution_overrides,
input_models=request.new_task_input_models,
validate_references=request.validate_references,
new_project_name=request.new_project_name,
)
@ -886,8 +906,8 @@ def reset(call: APICall, company_id, request: ResetRequest):
set__last_iteration=DEFAULT_LAST_ITERATION,
set__last_metrics={},
set__metric_stats={},
set__models__output=[],
unset__output__result=1,
unset__output__model=1,
unset__output__error=1,
unset__last_worker=1,
unset__last_worker_report=1,
@ -1130,3 +1150,44 @@ def move(call: APICall, company_id: str, request: MoveRequest):
update_project_time(projects)
return {"project_id": project_id}
@endpoint("tasks.add_or_update_model", min_version="2.13")
def add_or_update_model(_: APICall, company_id: str, request: AddUpdateModelRequest):
    """
    Add a model to the task's input or output models list, or replace the
    list item that already carries the requested name.

    Returns a dict whose "updated" value is the result of the task statistics update.
    """
    # Called for its access check only; the returned task is not used
    TaskBLL.get_task_with_access(
        request.task, company_id=company_id, requires_write_access=True, only=["id"]
    )

    field = f"models__{request.type}"
    item = ModelItem(name=request.name, model=request.model, updated=datetime.utcnow())

    # First try to replace, in place, an existing item with the same name
    replaced = Task.objects(
        **{"id": request.task, f"{field}__name": request.name}
    ).update_one(**{f"set__{field}__S": item})

    # Nothing was replaced: append the item as part of the statistics update
    extra_updates = {} if replaced else {f"push__{field}": item}

    return {
        "updated": TaskBLL.update_statistics(
            task_id=request.task,
            company_id=company_id,
            last_iteration_max=request.iteration,
            **extra_updates,
        )
    }
@endpoint("tasks.delete_models", min_version="2.13")
def delete_models(_: APICall, company_id: str, request: DeleteModelsRequest):
    """
    Remove the requested models, matched by name, from the task's input and
    output models lists.

    Returns a dict whose "updated" value is the result of the task update.
    """
    task = TaskBLL.get_task_with_access(
        request.task, company_id=company_id, requires_write_access=True, only=["id"]
    )

    # One pull command per models list that has at least one name to remove
    commands = {}
    for type_ in get_options(ModelItemType):
        names = [m.name for m in request.models if m.type == type_]
        if names:
            commands[f"pull__models__{type_}__name__in"] = names

    updated = task.update(last_change=datetime.utcnow(), **commands)
    return {"updated": updated}

View File

@ -1,3 +1,4 @@
from datetime import datetime
from typing import Union, Sequence, Tuple
from apiserver.apierrors import errors
@ -5,6 +6,7 @@ from apiserver.apimodels.organization import Filter
from apiserver.database.model.base import GetMixin
from apiserver.database.utils import partition_tags
from apiserver.service_repo import APICall
from apiserver.utilities.dicts import nested_set, nested_get, nested_delete
from apiserver.utilities.partial_version import PartialVersion
@ -84,3 +86,47 @@ def validate_tags(tags: Sequence[str], system_tags: Sequence[str]):
raise errors.bad_request.FieldsValueError(
"unsupported tag prefix", values=unsupported
)
class ModelsBackwardsCompatibility:
    """
    Converts between the legacy single-model task fields
    (``execution.model`` and ``output.model``) and the ``models`` field that
    holds "input" and "output" lists of model entries.

    Both conversions are skipped when the call's requested endpoint version
    is greater than ``max_version``.
    """

    max_version = PartialVersion("2.13")
    # models list name -> path of the matching legacy field
    mode_to_fields = {"input": ("execution", "model"), "output": ("output", "model")}
    models_field = "models"

    @classmethod
    def prepare_for_save(cls, call: APICall, fields: dict):
        """
        Move legacy model fields found in ``fields`` into the ``models`` lists.

        Each non-empty legacy value is deleted from ``fields`` and stored as
        a single-entry list under ``models.<mode>``, named after the mode and
        stamped with the current UTC time. ``fields`` is modified in place.
        """
        if call.requested_endpoint_version > cls.max_version:
            return

        for mode, legacy_path in cls.mode_to_fields.items():
            model_id = nested_get(fields, legacy_path)
            if model_id:
                nested_delete(fields, legacy_path)
                entry = {"name": mode, "model": model_id, "updated": datetime.utcnow()}
                nested_set(fields, (cls.models_field, mode), value=[entry])

    @classmethod
    def unprepare_from_saved(
        cls, call: APICall, tasks_data: Union[Sequence[dict], dict]
    ):
        """
        Fill the legacy model fields of returned task data from ``models``.

        Accepts a single task dict or a sequence of them and modifies them in
        place. The legacy input field takes the first input entry, the legacy
        output field takes the last output entry.
        """
        if call.requested_endpoint_version > cls.max_version:
            return

        tasks = [tasks_data] if isinstance(tasks_data, dict) else tasks_data
        for task in tasks:
            for mode, legacy_path in cls.mode_to_fields.items():
                models = nested_get(task, (cls.models_field, mode))
                if not models:
                    continue

                # first entry for "input", last entry for anything else
                position = 0 if mode == "input" else -1
                model = models[position]
                if model:
                    nested_set(task, legacy_path, model.get("model"))

View File

@ -0,0 +1,111 @@
from copy import deepcopy
from typing import Sequence, Optional
from apiserver.tests.automated import TestService
class TestTaskModels(TestService):
    """API tests for the task ``models`` field and the model add/update/delete calls."""

    def setUp(self, version="2.13"):
        super().setUp(version=version)

    def test_new_apis(self):
        # a task created without models reports empty lists
        task_without_models = self.new_task()
        self.assertModels(task_without_models, [], [])

        id1, id2 = self.new_model("model1"), self.new_model("model2")
        input_models = [
            {"name": "input1", "model": id1},
            {"name": "input2", "model": id2},
        ]
        output_models = [
            {"name": "output1", "model": "id3"},
            {"name": "output2", "model": "id4"},
        ]

        # models passed on task creation are stored as given
        task = self.new_task(models={"input": input_models, "output": output_models})
        self.assertModels(task, input_models, output_models)

        # add_or_update on an existing name replaces that entry's model
        res = self.api.tasks.add_or_update_model(
            task=task, name="input1", type="input", model="Test"
        )
        self.assertEqual(res.updated, 1)
        expected_input = deepcopy(input_models)
        expected_input[0]["model"] = "Test"
        self.assertModels(task, expected_input, output_models)

        # add_or_update on a new name appends a new model entry
        res = self.api.tasks.add_or_update_model(
            task=task, name="output3", type="output", model="TestOutput"
        )
        self.assertEqual(res.updated, 1)
        expected_output = deepcopy(output_models)
        expected_output.append({"name": "output3", "model": "TestOutput"})
        self.assertModels(task, expected_input, expected_output)

        # editing the task restores the original lists
        self.api.tasks.edit(
            task=task, models={"input": input_models, "output": output_models}
        )
        self.assertModels(task, input_models, output_models)

        # deleting models, including a name that is not present on the task
        keys_to_delete = [
            {"name": "input1", "type": "input"},
            {"name": "input2", "type": "input"},
            {"name": "output1", "type": "output"},
            {"name": "not_existing", "type": "output"},
        ]
        res = self.api.tasks.delete_models(task=task, models=keys_to_delete)
        self.assertEqual(res.updated, 1)
        self.assertModels(task, [], output_models[1:])

    def assertModels(
        self, task_id: str, input_models: Sequence[dict], output_models: Sequence[dict],
    ):
        """
        Fetch the task through get_all_ex, get_all and get_by_id and check,
        for each result, both the ``models`` lists and the legacy
        ``execution`` / ``output`` model fields.
        """

        def get_model_id(model: dict) -> Optional[str]:
            # The "model" value may be a plain id string, a dict carrying an
            # "id" key, an empty dict or missing altogether
            ref = model.get("model") if model else None
            if ref is None or isinstance(ref, str):
                return ref
            return None if ref == {} else ref.get("id")

        def compare_models(actual: Sequence[dict], expected: Sequence[dict]):
            actual_pairs = [(m["name"], get_model_id(m)) for m in actual]
            expected_pairs = [(m["name"], m["model"]) for m in expected]
            self.assertEqual(actual_pairs, expected_pairs)

        # legacy fields mirror the first input model and the last output model
        expected_execution = input_models[0]["model"] if input_models else None
        expected_output = output_models[-1]["model"] if output_models else None

        # all three calls are issued before any assertion runs
        fetched_tasks = [
            self.api.tasks.get_all_ex(id=task_id).tasks[0],
            self.api.tasks.get_all(id=task_id).tasks[0],
            self.api.tasks.get_by_id(task=task_id).task,
        ]
        for task in fetched_tasks:
            compare_models(task.models.input, input_models)
            compare_models(task.models.output, output_models)
            self.assertEqual(get_model_id(task.execution), expected_execution)
            self.assertEqual(get_model_id(task.output), expected_output)

    def new_task(self, **kwargs):
        """Create a temporary task, filling in required fields the caller omitted."""
        self.update_missing(
            kwargs, type="testing", name="test task models", input=dict(view=dict())
        )
        return self.create_temp("tasks", **kwargs)

    def new_model(self, name: str, **kwargs):
        """Create a temporary model with the given name."""
        return self.create_temp(
            "models", uri="file://test", name=name, labels={}, **kwargs
        )