mirror of
https://github.com/clearml/clearml-server
synced 2025-06-26 23:15:47 +00:00
Rename default input and output models
Better handling of backwards compatibility in task models Code cleanup
This commit is contained in:
parent
3d22ca1888
commit
179661a0d4
@ -12,6 +12,7 @@ from apiserver.database.model.task.task import (
|
|||||||
TaskType,
|
TaskType,
|
||||||
ArtifactModes,
|
ArtifactModes,
|
||||||
DEFAULT_ARTIFACT_MODE,
|
DEFAULT_ARTIFACT_MODE,
|
||||||
|
TaskModelTypes,
|
||||||
)
|
)
|
||||||
from apiserver.database.utils import get_options
|
from apiserver.database.utils import get_options
|
||||||
|
|
||||||
@ -279,21 +280,16 @@ class PublishManyRequest(TaskBatchRequest):
|
|||||||
force = BoolField(default=False)
|
force = BoolField(default=False)
|
||||||
|
|
||||||
|
|
||||||
class ModelItemType(object):
|
|
||||||
input = "input"
|
|
||||||
output = "output"
|
|
||||||
|
|
||||||
|
|
||||||
class AddUpdateModelRequest(TaskRequest):
|
class AddUpdateModelRequest(TaskRequest):
|
||||||
name = StringField(required=True)
|
name = StringField(required=True)
|
||||||
model = StringField(required=True)
|
model = StringField(required=True)
|
||||||
type = StringField(required=True, validators=Enum(*get_options(ModelItemType)))
|
type = StringField(required=True, validators=Enum(*get_options(TaskModelTypes)))
|
||||||
iteration = IntField()
|
iteration = IntField()
|
||||||
|
|
||||||
|
|
||||||
class ModelItemKey(models.Base):
|
class ModelItemKey(models.Base):
|
||||||
name = StringField(required=True)
|
name = StringField(required=True)
|
||||||
type = StringField(required=True, validators=Enum(*get_options(ModelItemType)))
|
type = StringField(required=True, validators=Enum(*get_options(TaskModelTypes)))
|
||||||
|
|
||||||
|
|
||||||
class DeleteModelsRequest(TaskRequest):
|
class DeleteModelsRequest(TaskRequest):
|
||||||
|
@ -29,6 +29,8 @@ from apiserver.database.model.task.task import (
|
|||||||
ModelItem,
|
ModelItem,
|
||||||
Models,
|
Models,
|
||||||
DEFAULT_ARTIFACT_MODE,
|
DEFAULT_ARTIFACT_MODE,
|
||||||
|
TaskModelNames,
|
||||||
|
TaskModelTypes,
|
||||||
)
|
)
|
||||||
from apiserver.database.model import EntityVisibility
|
from apiserver.database.model import EntityVisibility
|
||||||
from apiserver.database.utils import get_company_or_none_constraint, id as create_id
|
from apiserver.database.utils import get_company_or_none_constraint, id as create_id
|
||||||
@ -196,13 +198,21 @@ class TaskBLL:
|
|||||||
|
|
||||||
now = datetime.utcnow()
|
now = datetime.utcnow()
|
||||||
if input_models:
|
if input_models:
|
||||||
input_models = [ModelItem(model=m.model, name=m.name) for m in input_models]
|
input_models = [
|
||||||
|
ModelItem(model=m.model, name=m.name, updated=now) for m in input_models
|
||||||
|
]
|
||||||
|
|
||||||
execution_dict = task.execution.to_proper_dict() if task.execution else {}
|
execution_dict = task.execution.to_proper_dict() if task.execution else {}
|
||||||
if execution_overrides:
|
if execution_overrides:
|
||||||
execution_model = execution_overrides.pop("model", None)
|
execution_model = execution_overrides.pop("model", None)
|
||||||
if not input_models and execution_model:
|
if not input_models and execution_model:
|
||||||
input_models = [ModelItem(model=execution_model, name="input")]
|
input_models = [
|
||||||
|
ModelItem(
|
||||||
|
model=execution_model,
|
||||||
|
name=TaskModelNames[TaskModelTypes.input],
|
||||||
|
updated=now,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
docker_cmd = execution_overrides.pop("docker_cmd", None)
|
docker_cmd = execution_overrides.pop("docker_cmd", None)
|
||||||
if not container and docker_cmd:
|
if not container and docker_cmd:
|
||||||
|
@ -106,6 +106,17 @@ class ConfigurationItem(EmbeddedDocument, ProperDictMixin):
|
|||||||
description = StringField()
|
description = StringField()
|
||||||
|
|
||||||
|
|
||||||
|
class TaskModelTypes:
|
||||||
|
input = "input"
|
||||||
|
output = "output"
|
||||||
|
|
||||||
|
|
||||||
|
TaskModelNames = {
|
||||||
|
TaskModelTypes.input: "Input Model",
|
||||||
|
TaskModelTypes.output: "Output Model",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class ModelItem(EmbeddedDocument, ProperDictMixin):
|
class ModelItem(EmbeddedDocument, ProperDictMixin):
|
||||||
name = StringField(required=True)
|
name = StringField(required=True)
|
||||||
model = StringField(required=True, reference_field="Model")
|
model = StringField(required=True, reference_field="Model")
|
||||||
|
@ -44,7 +44,13 @@ from apiserver.config.info import get_default_company
|
|||||||
from apiserver.database.model import EntityVisibility, User
|
from apiserver.database.model import EntityVisibility, User
|
||||||
from apiserver.database.model.model import Model
|
from apiserver.database.model.model import Model
|
||||||
from apiserver.database.model.project import Project
|
from apiserver.database.model.project import Project
|
||||||
from apiserver.database.model.task.task import Task, ArtifactModes, TaskStatus
|
from apiserver.database.model.task.task import (
|
||||||
|
Task,
|
||||||
|
ArtifactModes,
|
||||||
|
TaskStatus,
|
||||||
|
TaskModelTypes,
|
||||||
|
TaskModelNames,
|
||||||
|
)
|
||||||
from apiserver.database.utils import get_options
|
from apiserver.database.utils import get_options
|
||||||
from apiserver.utilities import json
|
from apiserver.utilities import json
|
||||||
from apiserver.utilities.dicts import nested_get, nested_set, nested_delete
|
from apiserver.utilities.dicts import nested_get, nested_set, nested_delete
|
||||||
@ -778,19 +784,20 @@ class PrePopulate:
|
|||||||
models = task_data.get("models", {})
|
models = task_data.get("models", {})
|
||||||
now = datetime.utcnow()
|
now = datetime.utcnow()
|
||||||
for old_field, type_ in (
|
for old_field, type_ in (
|
||||||
("execution.model", "input"),
|
("execution.model", TaskModelTypes.input),
|
||||||
("output.model", "output"),
|
("output.model", TaskModelTypes.output),
|
||||||
):
|
):
|
||||||
old_path = old_field.split(".")
|
old_path = old_field.split(".")
|
||||||
old_model = nested_get(task_data, old_path)
|
old_model = nested_get(task_data, old_path)
|
||||||
new_models = models.get(type_, [])
|
new_models = models.get(type_, [])
|
||||||
|
name = TaskModelNames[type_]
|
||||||
if old_model and not any(
|
if old_model and not any(
|
||||||
m
|
m
|
||||||
for m in new_models
|
for m in new_models
|
||||||
if m.get("model") == old_model or m.get("name") == type_
|
if m.get("model") == old_model or m.get("name") == name
|
||||||
):
|
):
|
||||||
model_item = {"model": old_model, "name": type_, "updated": now}
|
model_item = {"model": old_model, "name": name, "updated": now}
|
||||||
if type_ == "input":
|
if type_ == TaskModelTypes.input:
|
||||||
new_models = [model_item, *new_models]
|
new_models = [model_item, *new_models]
|
||||||
else:
|
else:
|
||||||
new_models = [*new_models, model_item]
|
new_models = [*new_models, model_item]
|
||||||
|
@ -3,6 +3,7 @@ from datetime import datetime
|
|||||||
from pymongo.collection import Collection
|
from pymongo.collection import Collection
|
||||||
from pymongo.database import Database
|
from pymongo.database import Database
|
||||||
|
|
||||||
|
from apiserver.database.model.task.task import TaskModelTypes, TaskModelNames
|
||||||
from apiserver.services.utils import escape_dict
|
from apiserver.services.utils import escape_dict
|
||||||
from apiserver.utilities.dicts import nested_get
|
from apiserver.utilities.dicts import nested_get
|
||||||
from .utils import _drop_all_indices_from_collections
|
from .utils import _drop_all_indices_from_collections
|
||||||
@ -17,8 +18,6 @@ def _migrate_task_models(db: Database):
|
|||||||
models: Collection = db["model"]
|
models: Collection = db["model"]
|
||||||
|
|
||||||
models_field = "models"
|
models_field = "models"
|
||||||
input = "input"
|
|
||||||
output = "output"
|
|
||||||
now = datetime.utcnow()
|
now = datetime.utcnow()
|
||||||
|
|
||||||
pipeline = [
|
pipeline = [
|
||||||
@ -26,7 +25,7 @@ def _migrate_task_models(db: Database):
|
|||||||
{"$project": {"name": 1, "task": 1}},
|
{"$project": {"name": 1, "task": 1}},
|
||||||
{"$group": {"_id": "$task", "models": {"$push": "$$ROOT"}}},
|
{"$group": {"_id": "$task", "models": {"$push": "$$ROOT"}}},
|
||||||
]
|
]
|
||||||
output_models = f"{models_field}.{output}"
|
output_models = f"{models_field}.{TaskModelTypes.output}"
|
||||||
for group in models.aggregate(pipeline=pipeline, allowDiskUse=True):
|
for group in models.aggregate(pipeline=pipeline, allowDiskUse=True):
|
||||||
task_id = group.get("_id")
|
task_id = group.get("_id")
|
||||||
task_models = group.get("models")
|
task_models = group.get("models")
|
||||||
@ -41,19 +40,17 @@ def _migrate_task_models(db: Database):
|
|||||||
upsert=False,
|
upsert=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
fields = {input: "execution.model", output: "output.model"}
|
fields = {
|
||||||
query = {
|
TaskModelTypes.input: "execution.model",
|
||||||
"$or": [
|
TaskModelTypes.output: "output.model",
|
||||||
{field: {"$exists": True}} for field in fields.values()
|
|
||||||
]
|
|
||||||
}
|
}
|
||||||
|
query = {"$or": [{field: {"$exists": True}} for field in fields.values()]}
|
||||||
for doc in tasks.find(filter=query, projection=[*fields.values(), models_field]):
|
for doc in tasks.find(filter=query, projection=[*fields.values(), models_field]):
|
||||||
set_commands = {}
|
set_commands = {}
|
||||||
for mode, field in fields.items():
|
for mode, field in fields.items():
|
||||||
value = nested_get(doc, field.split("."))
|
value = nested_get(doc, field.split("."))
|
||||||
if value:
|
if value:
|
||||||
model_doc = models.find_one(filter={"_id": value}, projection=["name"])
|
name = TaskModelNames[mode]
|
||||||
name = model_doc.get("name", mode) if model_doc else mode
|
|
||||||
model_item = {"model": value, "name": name, "updated": now}
|
model_item = {"model": value, "name": name, "updated": now}
|
||||||
existing_models = nested_get(doc, (models_field, mode), default=[])
|
existing_models = nested_get(doc, (models_field, mode), default=[])
|
||||||
existing_models = (
|
existing_models = (
|
||||||
@ -61,7 +58,7 @@ def _migrate_task_models(db: Database):
|
|||||||
for m in existing_models
|
for m in existing_models
|
||||||
if m.get("name") != name and m.get("model") != value
|
if m.get("name") != name and m.get("model") != value
|
||||||
)
|
)
|
||||||
if mode == input:
|
if mode == TaskModelTypes.input:
|
||||||
updated_models = [model_item, *existing_models]
|
updated_models = [model_item, *existing_models]
|
||||||
else:
|
else:
|
||||||
updated_models = [*existing_models, model_item]
|
updated_models = [*existing_models, model_item]
|
||||||
@ -94,7 +91,7 @@ def _migrate_docker_cmd(db: Database):
|
|||||||
{
|
{
|
||||||
"$unset": {docker_cmd_field: 1},
|
"$unset": {docker_cmd_field: 1},
|
||||||
**({"$set": set_commands} if set_commands else {}),
|
**({"$set": set_commands} if set_commands else {}),
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -116,12 +113,7 @@ def _migrate_model_labels(db: Database):
|
|||||||
set_commands[field] = escaped
|
set_commands[field] = escaped
|
||||||
|
|
||||||
if set_commands:
|
if set_commands:
|
||||||
tasks.update_one(
|
tasks.update_one({"_id": doc["_id"]}, {"$set": set_commands})
|
||||||
{"_id": doc["_id"]},
|
|
||||||
{
|
|
||||||
"$set": set_commands
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def migrate_backend(db: Database):
|
def migrate_backend(db: Database):
|
||||||
|
@ -38,7 +38,13 @@ from apiserver.database.model import validate_id
|
|||||||
from apiserver.database.model.metadata import metadata_add_or_update, metadata_delete
|
from apiserver.database.model.metadata import metadata_add_or_update, metadata_delete
|
||||||
from apiserver.database.model.model import Model
|
from apiserver.database.model.model import Model
|
||||||
from apiserver.database.model.project import Project
|
from apiserver.database.model.project import Project
|
||||||
from apiserver.database.model.task.task import Task, TaskStatus, ModelItem
|
from apiserver.database.model.task.task import (
|
||||||
|
Task,
|
||||||
|
TaskStatus,
|
||||||
|
ModelItem,
|
||||||
|
TaskModelNames,
|
||||||
|
TaskModelTypes,
|
||||||
|
)
|
||||||
from apiserver.database.utils import (
|
from apiserver.database.utils import (
|
||||||
parse_from_call,
|
parse_from_call,
|
||||||
get_company_or_none_constraint,
|
get_company_or_none_constraint,
|
||||||
@ -287,7 +293,11 @@ def update_for_task(call: APICall, company_id, _):
|
|||||||
company_id=company_id,
|
company_id=company_id,
|
||||||
last_iteration_max=iteration,
|
last_iteration_max=iteration,
|
||||||
models__output=[
|
models__output=[
|
||||||
ModelItem(model=model.id, name=model.name, updated=datetime.utcnow())
|
ModelItem(
|
||||||
|
model=model.id,
|
||||||
|
name=TaskModelNames[TaskModelTypes.output],
|
||||||
|
updated=datetime.utcnow(),
|
||||||
|
)
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -47,7 +47,6 @@ from apiserver.apimodels.tasks import (
|
|||||||
ArchiveRequest,
|
ArchiveRequest,
|
||||||
AddUpdateModelRequest,
|
AddUpdateModelRequest,
|
||||||
DeleteModelsRequest,
|
DeleteModelsRequest,
|
||||||
ModelItemType,
|
|
||||||
StopManyResponse,
|
StopManyResponse,
|
||||||
StopManyRequest,
|
StopManyRequest,
|
||||||
EnqueueManyRequest,
|
EnqueueManyRequest,
|
||||||
@ -98,6 +97,7 @@ from apiserver.database.model.task.task import (
|
|||||||
TaskStatus,
|
TaskStatus,
|
||||||
Script,
|
Script,
|
||||||
ModelItem,
|
ModelItem,
|
||||||
|
TaskModelTypes,
|
||||||
)
|
)
|
||||||
from apiserver.database.utils import get_fields_attr, parse_from_call, get_options
|
from apiserver.database.utils import get_fields_attr, parse_from_call, get_options
|
||||||
from apiserver.service_repo import APICall, endpoint
|
from apiserver.service_repo import APICall, endpoint
|
||||||
@ -458,7 +458,7 @@ def prepare_create_fields(
|
|||||||
models = fields.get("models")
|
models = fields.get("models")
|
||||||
if models:
|
if models:
|
||||||
now = datetime.utcnow()
|
now = datetime.utcnow()
|
||||||
for field in ("input", "output"):
|
for field in (TaskModelTypes.input, TaskModelTypes.output):
|
||||||
field_models = models.get(field)
|
field_models = models.get(field)
|
||||||
if not field_models:
|
if not field_models:
|
||||||
continue
|
continue
|
||||||
@ -1242,7 +1242,7 @@ def delete_models(_: APICall, company_id: str, request: DeleteModelsRequest):
|
|||||||
|
|
||||||
delete_names = {
|
delete_names = {
|
||||||
type_: [m.name for m in request.models if m.type == type_]
|
type_: [m.name for m in request.models if m.type == type_]
|
||||||
for type_ in get_options(ModelItemType)
|
for type_ in get_options(TaskModelTypes)
|
||||||
}
|
}
|
||||||
commands = {
|
commands = {
|
||||||
f"pull__models__{field}__name__in": names
|
f"pull__models__{field}__name__in": names
|
||||||
|
@ -5,6 +5,7 @@ from apiserver.apierrors import errors
|
|||||||
from apiserver.apimodels.metadata import MetadataItem as ApiMetadataItem
|
from apiserver.apimodels.metadata import MetadataItem as ApiMetadataItem
|
||||||
from apiserver.apimodels.organization import Filter
|
from apiserver.apimodels.organization import Filter
|
||||||
from apiserver.database.model.base import GetMixin
|
from apiserver.database.model.base import GetMixin
|
||||||
|
from apiserver.database.model.task.task import TaskModelTypes, TaskModelNames
|
||||||
from apiserver.database.utils import partition_tags
|
from apiserver.database.utils import partition_tags
|
||||||
from apiserver.service_repo import APICall
|
from apiserver.service_repo import APICall
|
||||||
from apiserver.utilities.dicts import nested_set, nested_get, nested_delete
|
from apiserver.utilities.dicts import nested_set, nested_get, nested_delete
|
||||||
@ -135,7 +136,10 @@ def unescape_dict_field(fields: dict, path: Union[str, Sequence[str]]):
|
|||||||
|
|
||||||
class ModelsBackwardsCompatibility:
|
class ModelsBackwardsCompatibility:
|
||||||
max_version = PartialVersion("2.13")
|
max_version = PartialVersion("2.13")
|
||||||
mode_to_fields = {"input": ("execution", "model"), "output": ("output", "model")}
|
mode_to_fields = {
|
||||||
|
TaskModelTypes.input: ("execution", "model"),
|
||||||
|
TaskModelTypes.output: ("output", "model"),
|
||||||
|
}
|
||||||
models_field = "models"
|
models_field = "models"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -149,7 +153,13 @@ class ModelsBackwardsCompatibility:
|
|||||||
nested_set(
|
nested_set(
|
||||||
fields,
|
fields,
|
||||||
(cls.models_field, mode),
|
(cls.models_field, mode),
|
||||||
value=[dict(name=mode, model=value, updated=datetime.utcnow())],
|
value=[
|
||||||
|
dict(
|
||||||
|
name=TaskModelNames[mode],
|
||||||
|
model=value,
|
||||||
|
updated=datetime.utcnow(),
|
||||||
|
)
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
nested_delete(fields, field)
|
nested_delete(fields, field)
|
||||||
@ -170,7 +180,7 @@ class ModelsBackwardsCompatibility:
|
|||||||
if not models:
|
if not models:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
model = models[0] if mode == "input" else models[-1]
|
model = models[0] if mode == TaskModelTypes.input else models[-1]
|
||||||
if model:
|
if model:
|
||||||
nested_set(task, field, model.get("model"))
|
nested_set(task, field, model.get("model"))
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user