Rename default input and output models

Better handling of backwards compatibility in task models
Code cleanup
This commit is contained in:
allegroai 2021-05-03 17:56:50 +03:00
parent 3d22ca1888
commit 179661a0d4
8 changed files with 77 additions and 41 deletions
apiserver
apimodels
bll/task
database/model/task
mongo
initialize
migrations
services

View File

@ -12,6 +12,7 @@ from apiserver.database.model.task.task import (
TaskType,
ArtifactModes,
DEFAULT_ARTIFACT_MODE,
TaskModelTypes,
)
from apiserver.database.utils import get_options
@ -279,21 +280,16 @@ class PublishManyRequest(TaskBatchRequest):
force = BoolField(default=False)
class ModelItemType(object):
input = "input"
output = "output"
class AddUpdateModelRequest(TaskRequest):
name = StringField(required=True)
model = StringField(required=True)
type = StringField(required=True, validators=Enum(*get_options(ModelItemType)))
type = StringField(required=True, validators=Enum(*get_options(TaskModelTypes)))
iteration = IntField()
class ModelItemKey(models.Base):
name = StringField(required=True)
type = StringField(required=True, validators=Enum(*get_options(ModelItemType)))
type = StringField(required=True, validators=Enum(*get_options(TaskModelTypes)))
class DeleteModelsRequest(TaskRequest):

View File

@ -29,6 +29,8 @@ from apiserver.database.model.task.task import (
ModelItem,
Models,
DEFAULT_ARTIFACT_MODE,
TaskModelNames,
TaskModelTypes,
)
from apiserver.database.model import EntityVisibility
from apiserver.database.utils import get_company_or_none_constraint, id as create_id
@ -196,13 +198,21 @@ class TaskBLL:
now = datetime.utcnow()
if input_models:
input_models = [ModelItem(model=m.model, name=m.name) for m in input_models]
input_models = [
ModelItem(model=m.model, name=m.name, updated=now) for m in input_models
]
execution_dict = task.execution.to_proper_dict() if task.execution else {}
if execution_overrides:
execution_model = execution_overrides.pop("model", None)
if not input_models and execution_model:
input_models = [ModelItem(model=execution_model, name="input")]
input_models = [
ModelItem(
model=execution_model,
name=TaskModelNames[TaskModelTypes.input],
updated=now,
)
]
docker_cmd = execution_overrides.pop("docker_cmd", None)
if not container and docker_cmd:

View File

@ -106,6 +106,17 @@ class ConfigurationItem(EmbeddedDocument, ProperDictMixin):
description = StringField()
class TaskModelTypes:
input = "input"
output = "output"
TaskModelNames = {
TaskModelTypes.input: "Input Model",
TaskModelTypes.output: "Output Model",
}
class ModelItem(EmbeddedDocument, ProperDictMixin):
name = StringField(required=True)
model = StringField(required=True, reference_field="Model")

View File

@ -44,7 +44,13 @@ from apiserver.config.info import get_default_company
from apiserver.database.model import EntityVisibility, User
from apiserver.database.model.model import Model
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task, ArtifactModes, TaskStatus
from apiserver.database.model.task.task import (
Task,
ArtifactModes,
TaskStatus,
TaskModelTypes,
TaskModelNames,
)
from apiserver.database.utils import get_options
from apiserver.utilities import json
from apiserver.utilities.dicts import nested_get, nested_set, nested_delete
@ -778,19 +784,20 @@ class PrePopulate:
models = task_data.get("models", {})
now = datetime.utcnow()
for old_field, type_ in (
("execution.model", "input"),
("output.model", "output"),
("execution.model", TaskModelTypes.input),
("output.model", TaskModelTypes.output),
):
old_path = old_field.split(".")
old_model = nested_get(task_data, old_path)
new_models = models.get(type_, [])
name = TaskModelNames[type_]
if old_model and not any(
m
for m in new_models
if m.get("model") == old_model or m.get("name") == type_
if m.get("model") == old_model or m.get("name") == name
):
model_item = {"model": old_model, "name": type_, "updated": now}
if type_ == "input":
model_item = {"model": old_model, "name": name, "updated": now}
if type_ == TaskModelTypes.input:
new_models = [model_item, *new_models]
else:
new_models = [*new_models, model_item]

View File

@ -3,6 +3,7 @@ from datetime import datetime
from pymongo.collection import Collection
from pymongo.database import Database
from apiserver.database.model.task.task import TaskModelTypes, TaskModelNames
from apiserver.services.utils import escape_dict
from apiserver.utilities.dicts import nested_get
from .utils import _drop_all_indices_from_collections
@ -17,8 +18,6 @@ def _migrate_task_models(db: Database):
models: Collection = db["model"]
models_field = "models"
input = "input"
output = "output"
now = datetime.utcnow()
pipeline = [
@ -26,7 +25,7 @@ def _migrate_task_models(db: Database):
{"$project": {"name": 1, "task": 1}},
{"$group": {"_id": "$task", "models": {"$push": "$$ROOT"}}},
]
output_models = f"{models_field}.{output}"
output_models = f"{models_field}.{TaskModelTypes.output}"
for group in models.aggregate(pipeline=pipeline, allowDiskUse=True):
task_id = group.get("_id")
task_models = group.get("models")
@ -41,19 +40,17 @@ def _migrate_task_models(db: Database):
upsert=False,
)
fields = {input: "execution.model", output: "output.model"}
query = {
"$or": [
{field: {"$exists": True}} for field in fields.values()
]
fields = {
TaskModelTypes.input: "execution.model",
TaskModelTypes.output: "output.model",
}
query = {"$or": [{field: {"$exists": True}} for field in fields.values()]}
for doc in tasks.find(filter=query, projection=[*fields.values(), models_field]):
set_commands = {}
for mode, field in fields.items():
value = nested_get(doc, field.split("."))
if value:
model_doc = models.find_one(filter={"_id": value}, projection=["name"])
name = model_doc.get("name", mode) if model_doc else mode
name = TaskModelNames[mode]
model_item = {"model": value, "name": name, "updated": now}
existing_models = nested_get(doc, (models_field, mode), default=[])
existing_models = (
@ -61,7 +58,7 @@ def _migrate_task_models(db: Database):
for m in existing_models
if m.get("name") != name and m.get("model") != value
)
if mode == input:
if mode == TaskModelTypes.input:
updated_models = [model_item, *existing_models]
else:
updated_models = [*existing_models, model_item]
@ -94,7 +91,7 @@ def _migrate_docker_cmd(db: Database):
{
"$unset": {docker_cmd_field: 1},
**({"$set": set_commands} if set_commands else {}),
}
},
)
@ -116,12 +113,7 @@ def _migrate_model_labels(db: Database):
set_commands[field] = escaped
if set_commands:
tasks.update_one(
{"_id": doc["_id"]},
{
"$set": set_commands
}
)
tasks.update_one({"_id": doc["_id"]}, {"$set": set_commands})
def migrate_backend(db: Database):

View File

@ -38,7 +38,13 @@ from apiserver.database.model import validate_id
from apiserver.database.model.metadata import metadata_add_or_update, metadata_delete
from apiserver.database.model.model import Model
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task, TaskStatus, ModelItem
from apiserver.database.model.task.task import (
Task,
TaskStatus,
ModelItem,
TaskModelNames,
TaskModelTypes,
)
from apiserver.database.utils import (
parse_from_call,
get_company_or_none_constraint,
@ -287,7 +293,11 @@ def update_for_task(call: APICall, company_id, _):
company_id=company_id,
last_iteration_max=iteration,
models__output=[
ModelItem(model=model.id, name=model.name, updated=datetime.utcnow())
ModelItem(
model=model.id,
name=TaskModelNames[TaskModelTypes.output],
updated=datetime.utcnow(),
)
],
)

View File

@ -47,7 +47,6 @@ from apiserver.apimodels.tasks import (
ArchiveRequest,
AddUpdateModelRequest,
DeleteModelsRequest,
ModelItemType,
StopManyResponse,
StopManyRequest,
EnqueueManyRequest,
@ -98,6 +97,7 @@ from apiserver.database.model.task.task import (
TaskStatus,
Script,
ModelItem,
TaskModelTypes,
)
from apiserver.database.utils import get_fields_attr, parse_from_call, get_options
from apiserver.service_repo import APICall, endpoint
@ -458,7 +458,7 @@ def prepare_create_fields(
models = fields.get("models")
if models:
now = datetime.utcnow()
for field in ("input", "output"):
for field in (TaskModelTypes.input, TaskModelTypes.output):
field_models = models.get(field)
if not field_models:
continue
@ -1242,7 +1242,7 @@ def delete_models(_: APICall, company_id: str, request: DeleteModelsRequest):
delete_names = {
type_: [m.name for m in request.models if m.type == type_]
for type_ in get_options(ModelItemType)
for type_ in get_options(TaskModelTypes)
}
commands = {
f"pull__models__{field}__name__in": names

View File

@ -5,6 +5,7 @@ from apiserver.apierrors import errors
from apiserver.apimodels.metadata import MetadataItem as ApiMetadataItem
from apiserver.apimodels.organization import Filter
from apiserver.database.model.base import GetMixin
from apiserver.database.model.task.task import TaskModelTypes, TaskModelNames
from apiserver.database.utils import partition_tags
from apiserver.service_repo import APICall
from apiserver.utilities.dicts import nested_set, nested_get, nested_delete
@ -135,7 +136,10 @@ def unescape_dict_field(fields: dict, path: Union[str, Sequence[str]]):
class ModelsBackwardsCompatibility:
max_version = PartialVersion("2.13")
mode_to_fields = {"input": ("execution", "model"), "output": ("output", "model")}
mode_to_fields = {
TaskModelTypes.input: ("execution", "model"),
TaskModelTypes.output: ("output", "model"),
}
models_field = "models"
@classmethod
@ -149,7 +153,13 @@ class ModelsBackwardsCompatibility:
nested_set(
fields,
(cls.models_field, mode),
value=[dict(name=mode, model=value, updated=datetime.utcnow())],
value=[
dict(
name=TaskModelNames[mode],
model=value,
updated=datetime.utcnow(),
)
],
)
nested_delete(fields, field)
@ -170,7 +180,7 @@ class ModelsBackwardsCompatibility:
if not models:
continue
model = models[0] if mode == "input" else models[-1]
model = models[0] if mode == TaskModelTypes.input else models[-1]
if model:
nested_set(task, field, model.get("model"))