mirror of
https://github.com/clearml/clearml-server
synced 2025-04-26 00:49:45 +00:00
Rename default input and output models
Better handling of backwards compatibility in task models Code cleanup
This commit is contained in:
parent
3d22ca1888
commit
179661a0d4
apiserver
apimodels
bll/task
database/model/task
mongo
services
@ -12,6 +12,7 @@ from apiserver.database.model.task.task import (
|
||||
TaskType,
|
||||
ArtifactModes,
|
||||
DEFAULT_ARTIFACT_MODE,
|
||||
TaskModelTypes,
|
||||
)
|
||||
from apiserver.database.utils import get_options
|
||||
|
||||
@ -279,21 +280,16 @@ class PublishManyRequest(TaskBatchRequest):
|
||||
force = BoolField(default=False)
|
||||
|
||||
|
||||
class ModelItemType(object):
|
||||
input = "input"
|
||||
output = "output"
|
||||
|
||||
|
||||
class AddUpdateModelRequest(TaskRequest):
|
||||
name = StringField(required=True)
|
||||
model = StringField(required=True)
|
||||
type = StringField(required=True, validators=Enum(*get_options(ModelItemType)))
|
||||
type = StringField(required=True, validators=Enum(*get_options(TaskModelTypes)))
|
||||
iteration = IntField()
|
||||
|
||||
|
||||
class ModelItemKey(models.Base):
|
||||
name = StringField(required=True)
|
||||
type = StringField(required=True, validators=Enum(*get_options(ModelItemType)))
|
||||
type = StringField(required=True, validators=Enum(*get_options(TaskModelTypes)))
|
||||
|
||||
|
||||
class DeleteModelsRequest(TaskRequest):
|
||||
|
@ -29,6 +29,8 @@ from apiserver.database.model.task.task import (
|
||||
ModelItem,
|
||||
Models,
|
||||
DEFAULT_ARTIFACT_MODE,
|
||||
TaskModelNames,
|
||||
TaskModelTypes,
|
||||
)
|
||||
from apiserver.database.model import EntityVisibility
|
||||
from apiserver.database.utils import get_company_or_none_constraint, id as create_id
|
||||
@ -196,13 +198,21 @@ class TaskBLL:
|
||||
|
||||
now = datetime.utcnow()
|
||||
if input_models:
|
||||
input_models = [ModelItem(model=m.model, name=m.name) for m in input_models]
|
||||
input_models = [
|
||||
ModelItem(model=m.model, name=m.name, updated=now) for m in input_models
|
||||
]
|
||||
|
||||
execution_dict = task.execution.to_proper_dict() if task.execution else {}
|
||||
if execution_overrides:
|
||||
execution_model = execution_overrides.pop("model", None)
|
||||
if not input_models and execution_model:
|
||||
input_models = [ModelItem(model=execution_model, name="input")]
|
||||
input_models = [
|
||||
ModelItem(
|
||||
model=execution_model,
|
||||
name=TaskModelNames[TaskModelTypes.input],
|
||||
updated=now,
|
||||
)
|
||||
]
|
||||
|
||||
docker_cmd = execution_overrides.pop("docker_cmd", None)
|
||||
if not container and docker_cmd:
|
||||
|
@ -106,6 +106,17 @@ class ConfigurationItem(EmbeddedDocument, ProperDictMixin):
|
||||
description = StringField()
|
||||
|
||||
|
||||
class TaskModelTypes:
|
||||
input = "input"
|
||||
output = "output"
|
||||
|
||||
|
||||
TaskModelNames = {
|
||||
TaskModelTypes.input: "Input Model",
|
||||
TaskModelTypes.output: "Output Model",
|
||||
}
|
||||
|
||||
|
||||
class ModelItem(EmbeddedDocument, ProperDictMixin):
|
||||
name = StringField(required=True)
|
||||
model = StringField(required=True, reference_field="Model")
|
||||
|
@ -44,7 +44,13 @@ from apiserver.config.info import get_default_company
|
||||
from apiserver.database.model import EntityVisibility, User
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.project import Project
|
||||
from apiserver.database.model.task.task import Task, ArtifactModes, TaskStatus
|
||||
from apiserver.database.model.task.task import (
|
||||
Task,
|
||||
ArtifactModes,
|
||||
TaskStatus,
|
||||
TaskModelTypes,
|
||||
TaskModelNames,
|
||||
)
|
||||
from apiserver.database.utils import get_options
|
||||
from apiserver.utilities import json
|
||||
from apiserver.utilities.dicts import nested_get, nested_set, nested_delete
|
||||
@ -778,19 +784,20 @@ class PrePopulate:
|
||||
models = task_data.get("models", {})
|
||||
now = datetime.utcnow()
|
||||
for old_field, type_ in (
|
||||
("execution.model", "input"),
|
||||
("output.model", "output"),
|
||||
("execution.model", TaskModelTypes.input),
|
||||
("output.model", TaskModelTypes.output),
|
||||
):
|
||||
old_path = old_field.split(".")
|
||||
old_model = nested_get(task_data, old_path)
|
||||
new_models = models.get(type_, [])
|
||||
name = TaskModelNames[type_]
|
||||
if old_model and not any(
|
||||
m
|
||||
for m in new_models
|
||||
if m.get("model") == old_model or m.get("name") == type_
|
||||
if m.get("model") == old_model or m.get("name") == name
|
||||
):
|
||||
model_item = {"model": old_model, "name": type_, "updated": now}
|
||||
if type_ == "input":
|
||||
model_item = {"model": old_model, "name": name, "updated": now}
|
||||
if type_ == TaskModelTypes.input:
|
||||
new_models = [model_item, *new_models]
|
||||
else:
|
||||
new_models = [*new_models, model_item]
|
||||
|
@ -3,6 +3,7 @@ from datetime import datetime
|
||||
from pymongo.collection import Collection
|
||||
from pymongo.database import Database
|
||||
|
||||
from apiserver.database.model.task.task import TaskModelTypes, TaskModelNames
|
||||
from apiserver.services.utils import escape_dict
|
||||
from apiserver.utilities.dicts import nested_get
|
||||
from .utils import _drop_all_indices_from_collections
|
||||
@ -17,8 +18,6 @@ def _migrate_task_models(db: Database):
|
||||
models: Collection = db["model"]
|
||||
|
||||
models_field = "models"
|
||||
input = "input"
|
||||
output = "output"
|
||||
now = datetime.utcnow()
|
||||
|
||||
pipeline = [
|
||||
@ -26,7 +25,7 @@ def _migrate_task_models(db: Database):
|
||||
{"$project": {"name": 1, "task": 1}},
|
||||
{"$group": {"_id": "$task", "models": {"$push": "$$ROOT"}}},
|
||||
]
|
||||
output_models = f"{models_field}.{output}"
|
||||
output_models = f"{models_field}.{TaskModelTypes.output}"
|
||||
for group in models.aggregate(pipeline=pipeline, allowDiskUse=True):
|
||||
task_id = group.get("_id")
|
||||
task_models = group.get("models")
|
||||
@ -41,19 +40,17 @@ def _migrate_task_models(db: Database):
|
||||
upsert=False,
|
||||
)
|
||||
|
||||
fields = {input: "execution.model", output: "output.model"}
|
||||
query = {
|
||||
"$or": [
|
||||
{field: {"$exists": True}} for field in fields.values()
|
||||
]
|
||||
fields = {
|
||||
TaskModelTypes.input: "execution.model",
|
||||
TaskModelTypes.output: "output.model",
|
||||
}
|
||||
query = {"$or": [{field: {"$exists": True}} for field in fields.values()]}
|
||||
for doc in tasks.find(filter=query, projection=[*fields.values(), models_field]):
|
||||
set_commands = {}
|
||||
for mode, field in fields.items():
|
||||
value = nested_get(doc, field.split("."))
|
||||
if value:
|
||||
model_doc = models.find_one(filter={"_id": value}, projection=["name"])
|
||||
name = model_doc.get("name", mode) if model_doc else mode
|
||||
name = TaskModelNames[mode]
|
||||
model_item = {"model": value, "name": name, "updated": now}
|
||||
existing_models = nested_get(doc, (models_field, mode), default=[])
|
||||
existing_models = (
|
||||
@ -61,7 +58,7 @@ def _migrate_task_models(db: Database):
|
||||
for m in existing_models
|
||||
if m.get("name") != name and m.get("model") != value
|
||||
)
|
||||
if mode == input:
|
||||
if mode == TaskModelTypes.input:
|
||||
updated_models = [model_item, *existing_models]
|
||||
else:
|
||||
updated_models = [*existing_models, model_item]
|
||||
@ -94,7 +91,7 @@ def _migrate_docker_cmd(db: Database):
|
||||
{
|
||||
"$unset": {docker_cmd_field: 1},
|
||||
**({"$set": set_commands} if set_commands else {}),
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@ -116,12 +113,7 @@ def _migrate_model_labels(db: Database):
|
||||
set_commands[field] = escaped
|
||||
|
||||
if set_commands:
|
||||
tasks.update_one(
|
||||
{"_id": doc["_id"]},
|
||||
{
|
||||
"$set": set_commands
|
||||
}
|
||||
)
|
||||
tasks.update_one({"_id": doc["_id"]}, {"$set": set_commands})
|
||||
|
||||
|
||||
def migrate_backend(db: Database):
|
||||
|
@ -38,7 +38,13 @@ from apiserver.database.model import validate_id
|
||||
from apiserver.database.model.metadata import metadata_add_or_update, metadata_delete
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.project import Project
|
||||
from apiserver.database.model.task.task import Task, TaskStatus, ModelItem
|
||||
from apiserver.database.model.task.task import (
|
||||
Task,
|
||||
TaskStatus,
|
||||
ModelItem,
|
||||
TaskModelNames,
|
||||
TaskModelTypes,
|
||||
)
|
||||
from apiserver.database.utils import (
|
||||
parse_from_call,
|
||||
get_company_or_none_constraint,
|
||||
@ -287,7 +293,11 @@ def update_for_task(call: APICall, company_id, _):
|
||||
company_id=company_id,
|
||||
last_iteration_max=iteration,
|
||||
models__output=[
|
||||
ModelItem(model=model.id, name=model.name, updated=datetime.utcnow())
|
||||
ModelItem(
|
||||
model=model.id,
|
||||
name=TaskModelNames[TaskModelTypes.output],
|
||||
updated=datetime.utcnow(),
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -47,7 +47,6 @@ from apiserver.apimodels.tasks import (
|
||||
ArchiveRequest,
|
||||
AddUpdateModelRequest,
|
||||
DeleteModelsRequest,
|
||||
ModelItemType,
|
||||
StopManyResponse,
|
||||
StopManyRequest,
|
||||
EnqueueManyRequest,
|
||||
@ -98,6 +97,7 @@ from apiserver.database.model.task.task import (
|
||||
TaskStatus,
|
||||
Script,
|
||||
ModelItem,
|
||||
TaskModelTypes,
|
||||
)
|
||||
from apiserver.database.utils import get_fields_attr, parse_from_call, get_options
|
||||
from apiserver.service_repo import APICall, endpoint
|
||||
@ -458,7 +458,7 @@ def prepare_create_fields(
|
||||
models = fields.get("models")
|
||||
if models:
|
||||
now = datetime.utcnow()
|
||||
for field in ("input", "output"):
|
||||
for field in (TaskModelTypes.input, TaskModelTypes.output):
|
||||
field_models = models.get(field)
|
||||
if not field_models:
|
||||
continue
|
||||
@ -1242,7 +1242,7 @@ def delete_models(_: APICall, company_id: str, request: DeleteModelsRequest):
|
||||
|
||||
delete_names = {
|
||||
type_: [m.name for m in request.models if m.type == type_]
|
||||
for type_ in get_options(ModelItemType)
|
||||
for type_ in get_options(TaskModelTypes)
|
||||
}
|
||||
commands = {
|
||||
f"pull__models__{field}__name__in": names
|
||||
|
@ -5,6 +5,7 @@ from apiserver.apierrors import errors
|
||||
from apiserver.apimodels.metadata import MetadataItem as ApiMetadataItem
|
||||
from apiserver.apimodels.organization import Filter
|
||||
from apiserver.database.model.base import GetMixin
|
||||
from apiserver.database.model.task.task import TaskModelTypes, TaskModelNames
|
||||
from apiserver.database.utils import partition_tags
|
||||
from apiserver.service_repo import APICall
|
||||
from apiserver.utilities.dicts import nested_set, nested_get, nested_delete
|
||||
@ -135,7 +136,10 @@ def unescape_dict_field(fields: dict, path: Union[str, Sequence[str]]):
|
||||
|
||||
class ModelsBackwardsCompatibility:
|
||||
max_version = PartialVersion("2.13")
|
||||
mode_to_fields = {"input": ("execution", "model"), "output": ("output", "model")}
|
||||
mode_to_fields = {
|
||||
TaskModelTypes.input: ("execution", "model"),
|
||||
TaskModelTypes.output: ("output", "model"),
|
||||
}
|
||||
models_field = "models"
|
||||
|
||||
@classmethod
|
||||
@ -149,7 +153,13 @@ class ModelsBackwardsCompatibility:
|
||||
nested_set(
|
||||
fields,
|
||||
(cls.models_field, mode),
|
||||
value=[dict(name=mode, model=value, updated=datetime.utcnow())],
|
||||
value=[
|
||||
dict(
|
||||
name=TaskModelNames[mode],
|
||||
model=value,
|
||||
updated=datetime.utcnow(),
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
nested_delete(fields, field)
|
||||
@ -170,7 +180,7 @@ class ModelsBackwardsCompatibility:
|
||||
if not models:
|
||||
continue
|
||||
|
||||
model = models[0] if mode == "input" else models[-1]
|
||||
model = models[0] if mode == TaskModelTypes.input else models[-1]
|
||||
if model:
|
||||
nested_set(task, field, model.get("model"))
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user