Rename default input and output models

Better handling of backwards compatibility in task models
Code cleanup
This commit is contained in:
allegroai 2021-05-03 17:56:50 +03:00
parent 3d22ca1888
commit 179661a0d4
8 changed files with 77 additions and 41 deletions

View File

@ -12,6 +12,7 @@ from apiserver.database.model.task.task import (
TaskType, TaskType,
ArtifactModes, ArtifactModes,
DEFAULT_ARTIFACT_MODE, DEFAULT_ARTIFACT_MODE,
TaskModelTypes,
) )
from apiserver.database.utils import get_options from apiserver.database.utils import get_options
@ -279,21 +280,16 @@ class PublishManyRequest(TaskBatchRequest):
force = BoolField(default=False) force = BoolField(default=False)
class ModelItemType(object):
input = "input"
output = "output"
class AddUpdateModelRequest(TaskRequest): class AddUpdateModelRequest(TaskRequest):
name = StringField(required=True) name = StringField(required=True)
model = StringField(required=True) model = StringField(required=True)
type = StringField(required=True, validators=Enum(*get_options(ModelItemType))) type = StringField(required=True, validators=Enum(*get_options(TaskModelTypes)))
iteration = IntField() iteration = IntField()
class ModelItemKey(models.Base): class ModelItemKey(models.Base):
name = StringField(required=True) name = StringField(required=True)
type = StringField(required=True, validators=Enum(*get_options(ModelItemType))) type = StringField(required=True, validators=Enum(*get_options(TaskModelTypes)))
class DeleteModelsRequest(TaskRequest): class DeleteModelsRequest(TaskRequest):

View File

@ -29,6 +29,8 @@ from apiserver.database.model.task.task import (
ModelItem, ModelItem,
Models, Models,
DEFAULT_ARTIFACT_MODE, DEFAULT_ARTIFACT_MODE,
TaskModelNames,
TaskModelTypes,
) )
from apiserver.database.model import EntityVisibility from apiserver.database.model import EntityVisibility
from apiserver.database.utils import get_company_or_none_constraint, id as create_id from apiserver.database.utils import get_company_or_none_constraint, id as create_id
@ -196,13 +198,21 @@ class TaskBLL:
now = datetime.utcnow() now = datetime.utcnow()
if input_models: if input_models:
input_models = [ModelItem(model=m.model, name=m.name) for m in input_models] input_models = [
ModelItem(model=m.model, name=m.name, updated=now) for m in input_models
]
execution_dict = task.execution.to_proper_dict() if task.execution else {} execution_dict = task.execution.to_proper_dict() if task.execution else {}
if execution_overrides: if execution_overrides:
execution_model = execution_overrides.pop("model", None) execution_model = execution_overrides.pop("model", None)
if not input_models and execution_model: if not input_models and execution_model:
input_models = [ModelItem(model=execution_model, name="input")] input_models = [
ModelItem(
model=execution_model,
name=TaskModelNames[TaskModelTypes.input],
updated=now,
)
]
docker_cmd = execution_overrides.pop("docker_cmd", None) docker_cmd = execution_overrides.pop("docker_cmd", None)
if not container and docker_cmd: if not container and docker_cmd:

View File

@ -106,6 +106,17 @@ class ConfigurationItem(EmbeddedDocument, ProperDictMixin):
description = StringField() description = StringField()
class TaskModelTypes:
input = "input"
output = "output"
TaskModelNames = {
TaskModelTypes.input: "Input Model",
TaskModelTypes.output: "Output Model",
}
class ModelItem(EmbeddedDocument, ProperDictMixin): class ModelItem(EmbeddedDocument, ProperDictMixin):
name = StringField(required=True) name = StringField(required=True)
model = StringField(required=True, reference_field="Model") model = StringField(required=True, reference_field="Model")

View File

@ -44,7 +44,13 @@ from apiserver.config.info import get_default_company
from apiserver.database.model import EntityVisibility, User from apiserver.database.model import EntityVisibility, User
from apiserver.database.model.model import Model from apiserver.database.model.model import Model
from apiserver.database.model.project import Project from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task, ArtifactModes, TaskStatus from apiserver.database.model.task.task import (
Task,
ArtifactModes,
TaskStatus,
TaskModelTypes,
TaskModelNames,
)
from apiserver.database.utils import get_options from apiserver.database.utils import get_options
from apiserver.utilities import json from apiserver.utilities import json
from apiserver.utilities.dicts import nested_get, nested_set, nested_delete from apiserver.utilities.dicts import nested_get, nested_set, nested_delete
@ -778,19 +784,20 @@ class PrePopulate:
models = task_data.get("models", {}) models = task_data.get("models", {})
now = datetime.utcnow() now = datetime.utcnow()
for old_field, type_ in ( for old_field, type_ in (
("execution.model", "input"), ("execution.model", TaskModelTypes.input),
("output.model", "output"), ("output.model", TaskModelTypes.output),
): ):
old_path = old_field.split(".") old_path = old_field.split(".")
old_model = nested_get(task_data, old_path) old_model = nested_get(task_data, old_path)
new_models = models.get(type_, []) new_models = models.get(type_, [])
name = TaskModelNames[type_]
if old_model and not any( if old_model and not any(
m m
for m in new_models for m in new_models
if m.get("model") == old_model or m.get("name") == type_ if m.get("model") == old_model or m.get("name") == name
): ):
model_item = {"model": old_model, "name": type_, "updated": now} model_item = {"model": old_model, "name": name, "updated": now}
if type_ == "input": if type_ == TaskModelTypes.input:
new_models = [model_item, *new_models] new_models = [model_item, *new_models]
else: else:
new_models = [*new_models, model_item] new_models = [*new_models, model_item]

View File

@ -3,6 +3,7 @@ from datetime import datetime
from pymongo.collection import Collection from pymongo.collection import Collection
from pymongo.database import Database from pymongo.database import Database
from apiserver.database.model.task.task import TaskModelTypes, TaskModelNames
from apiserver.services.utils import escape_dict from apiserver.services.utils import escape_dict
from apiserver.utilities.dicts import nested_get from apiserver.utilities.dicts import nested_get
from .utils import _drop_all_indices_from_collections from .utils import _drop_all_indices_from_collections
@ -17,8 +18,6 @@ def _migrate_task_models(db: Database):
models: Collection = db["model"] models: Collection = db["model"]
models_field = "models" models_field = "models"
input = "input"
output = "output"
now = datetime.utcnow() now = datetime.utcnow()
pipeline = [ pipeline = [
@ -26,7 +25,7 @@ def _migrate_task_models(db: Database):
{"$project": {"name": 1, "task": 1}}, {"$project": {"name": 1, "task": 1}},
{"$group": {"_id": "$task", "models": {"$push": "$$ROOT"}}}, {"$group": {"_id": "$task", "models": {"$push": "$$ROOT"}}},
] ]
output_models = f"{models_field}.{output}" output_models = f"{models_field}.{TaskModelTypes.output}"
for group in models.aggregate(pipeline=pipeline, allowDiskUse=True): for group in models.aggregate(pipeline=pipeline, allowDiskUse=True):
task_id = group.get("_id") task_id = group.get("_id")
task_models = group.get("models") task_models = group.get("models")
@ -41,19 +40,17 @@ def _migrate_task_models(db: Database):
upsert=False, upsert=False,
) )
fields = {input: "execution.model", output: "output.model"} fields = {
query = { TaskModelTypes.input: "execution.model",
"$or": [ TaskModelTypes.output: "output.model",
{field: {"$exists": True}} for field in fields.values()
]
} }
query = {"$or": [{field: {"$exists": True}} for field in fields.values()]}
for doc in tasks.find(filter=query, projection=[*fields.values(), models_field]): for doc in tasks.find(filter=query, projection=[*fields.values(), models_field]):
set_commands = {} set_commands = {}
for mode, field in fields.items(): for mode, field in fields.items():
value = nested_get(doc, field.split(".")) value = nested_get(doc, field.split("."))
if value: if value:
model_doc = models.find_one(filter={"_id": value}, projection=["name"]) name = TaskModelNames[mode]
name = model_doc.get("name", mode) if model_doc else mode
model_item = {"model": value, "name": name, "updated": now} model_item = {"model": value, "name": name, "updated": now}
existing_models = nested_get(doc, (models_field, mode), default=[]) existing_models = nested_get(doc, (models_field, mode), default=[])
existing_models = ( existing_models = (
@ -61,7 +58,7 @@ def _migrate_task_models(db: Database):
for m in existing_models for m in existing_models
if m.get("name") != name and m.get("model") != value if m.get("name") != name and m.get("model") != value
) )
if mode == input: if mode == TaskModelTypes.input:
updated_models = [model_item, *existing_models] updated_models = [model_item, *existing_models]
else: else:
updated_models = [*existing_models, model_item] updated_models = [*existing_models, model_item]
@ -94,7 +91,7 @@ def _migrate_docker_cmd(db: Database):
{ {
"$unset": {docker_cmd_field: 1}, "$unset": {docker_cmd_field: 1},
**({"$set": set_commands} if set_commands else {}), **({"$set": set_commands} if set_commands else {}),
} },
) )
@ -116,12 +113,7 @@ def _migrate_model_labels(db: Database):
set_commands[field] = escaped set_commands[field] = escaped
if set_commands: if set_commands:
tasks.update_one( tasks.update_one({"_id": doc["_id"]}, {"$set": set_commands})
{"_id": doc["_id"]},
{
"$set": set_commands
}
)
def migrate_backend(db: Database): def migrate_backend(db: Database):

View File

@ -38,7 +38,13 @@ from apiserver.database.model import validate_id
from apiserver.database.model.metadata import metadata_add_or_update, metadata_delete from apiserver.database.model.metadata import metadata_add_or_update, metadata_delete
from apiserver.database.model.model import Model from apiserver.database.model.model import Model
from apiserver.database.model.project import Project from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task, TaskStatus, ModelItem from apiserver.database.model.task.task import (
Task,
TaskStatus,
ModelItem,
TaskModelNames,
TaskModelTypes,
)
from apiserver.database.utils import ( from apiserver.database.utils import (
parse_from_call, parse_from_call,
get_company_or_none_constraint, get_company_or_none_constraint,
@ -287,7 +293,11 @@ def update_for_task(call: APICall, company_id, _):
company_id=company_id, company_id=company_id,
last_iteration_max=iteration, last_iteration_max=iteration,
models__output=[ models__output=[
ModelItem(model=model.id, name=model.name, updated=datetime.utcnow()) ModelItem(
model=model.id,
name=TaskModelNames[TaskModelTypes.output],
updated=datetime.utcnow(),
)
], ],
) )

View File

@ -47,7 +47,6 @@ from apiserver.apimodels.tasks import (
ArchiveRequest, ArchiveRequest,
AddUpdateModelRequest, AddUpdateModelRequest,
DeleteModelsRequest, DeleteModelsRequest,
ModelItemType,
StopManyResponse, StopManyResponse,
StopManyRequest, StopManyRequest,
EnqueueManyRequest, EnqueueManyRequest,
@ -98,6 +97,7 @@ from apiserver.database.model.task.task import (
TaskStatus, TaskStatus,
Script, Script,
ModelItem, ModelItem,
TaskModelTypes,
) )
from apiserver.database.utils import get_fields_attr, parse_from_call, get_options from apiserver.database.utils import get_fields_attr, parse_from_call, get_options
from apiserver.service_repo import APICall, endpoint from apiserver.service_repo import APICall, endpoint
@ -458,7 +458,7 @@ def prepare_create_fields(
models = fields.get("models") models = fields.get("models")
if models: if models:
now = datetime.utcnow() now = datetime.utcnow()
for field in ("input", "output"): for field in (TaskModelTypes.input, TaskModelTypes.output):
field_models = models.get(field) field_models = models.get(field)
if not field_models: if not field_models:
continue continue
@ -1242,7 +1242,7 @@ def delete_models(_: APICall, company_id: str, request: DeleteModelsRequest):
delete_names = { delete_names = {
type_: [m.name for m in request.models if m.type == type_] type_: [m.name for m in request.models if m.type == type_]
for type_ in get_options(ModelItemType) for type_ in get_options(TaskModelTypes)
} }
commands = { commands = {
f"pull__models__{field}__name__in": names f"pull__models__{field}__name__in": names

View File

@ -5,6 +5,7 @@ from apiserver.apierrors import errors
from apiserver.apimodels.metadata import MetadataItem as ApiMetadataItem from apiserver.apimodels.metadata import MetadataItem as ApiMetadataItem
from apiserver.apimodels.organization import Filter from apiserver.apimodels.organization import Filter
from apiserver.database.model.base import GetMixin from apiserver.database.model.base import GetMixin
from apiserver.database.model.task.task import TaskModelTypes, TaskModelNames
from apiserver.database.utils import partition_tags from apiserver.database.utils import partition_tags
from apiserver.service_repo import APICall from apiserver.service_repo import APICall
from apiserver.utilities.dicts import nested_set, nested_get, nested_delete from apiserver.utilities.dicts import nested_set, nested_get, nested_delete
@ -135,7 +136,10 @@ def unescape_dict_field(fields: dict, path: Union[str, Sequence[str]]):
class ModelsBackwardsCompatibility: class ModelsBackwardsCompatibility:
max_version = PartialVersion("2.13") max_version = PartialVersion("2.13")
mode_to_fields = {"input": ("execution", "model"), "output": ("output", "model")} mode_to_fields = {
TaskModelTypes.input: ("execution", "model"),
TaskModelTypes.output: ("output", "model"),
}
models_field = "models" models_field = "models"
@classmethod @classmethod
@ -149,7 +153,13 @@ class ModelsBackwardsCompatibility:
nested_set( nested_set(
fields, fields,
(cls.models_field, mode), (cls.models_field, mode),
value=[dict(name=mode, model=value, updated=datetime.utcnow())], value=[
dict(
name=TaskModelNames[mode],
model=value,
updated=datetime.utcnow(),
)
],
) )
nested_delete(fields, field) nested_delete(fields, field)
@ -170,7 +180,7 @@ class ModelsBackwardsCompatibility:
if not models: if not models:
continue continue
model = models[0] if mode == "input" else models[-1] model = models[0] if mode == TaskModelTypes.input else models[-1]
if model: if model:
nested_set(task, field, model.get("model")) nested_set(task, field, model.get("model"))