mirror of
https://github.com/clearml/clearml-server
synced 2025-04-05 21:46:29 +00:00
Escape task.container and task.execution.model_labels fields in DB
This commit is contained in:
parent
fdf6798d0c
commit
3d22ca1888
@ -40,7 +40,7 @@ def artifacts_unprepare_from_saved(fields):
|
||||
nested_set(
|
||||
fields,
|
||||
artifacts_field,
|
||||
value=sorted(artifacts.values(), key=itemgetter("key", "mode")),
|
||||
value=sorted(artifacts.values(), key=itemgetter("key")),
|
||||
)
|
||||
|
||||
|
||||
|
@ -28,13 +28,14 @@ from apiserver.database.model.task.task import (
|
||||
ArtifactModes,
|
||||
ModelItem,
|
||||
Models,
|
||||
DEFAULT_ARTIFACT_MODE,
|
||||
)
|
||||
from apiserver.database.model import EntityVisibility
|
||||
from apiserver.database.utils import get_company_or_none_constraint, id as create_id
|
||||
from apiserver.es_factory import es_factory
|
||||
from apiserver.redis_manager import redman
|
||||
from apiserver.service_repo import APICall
|
||||
from apiserver.services.utils import validate_tags
|
||||
from apiserver.services.utils import validate_tags, escape_dict_field, escape_dict
|
||||
from apiserver.timing_context import TimingContext
|
||||
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
|
||||
from .artifacts import artifacts_prepare_for_save
|
||||
@ -216,6 +217,8 @@ class TaskBLL:
|
||||
if legacy_value is not None:
|
||||
params_dict["execution"] = legacy_value
|
||||
|
||||
escape_dict_field(execution_overrides, "model_labels")
|
||||
|
||||
execution_dict.update(execution_overrides)
|
||||
|
||||
params_prepare_for_save(params_dict, previous_task=task)
|
||||
@ -225,7 +228,7 @@ class TaskBLL:
|
||||
execution_dict["artifacts"] = {
|
||||
k: a
|
||||
for k, a in artifacts.items()
|
||||
if a.get("mode") != ArtifactModes.output
|
||||
if a.get("mode", DEFAULT_ARTIFACT_MODE) != ArtifactModes.output
|
||||
}
|
||||
execution_dict.pop("queue", None)
|
||||
|
||||
@ -276,7 +279,7 @@ class TaskBLL:
|
||||
if task.output
|
||||
else None,
|
||||
models=Models(input=input_models or task.models.input),
|
||||
container=container or task.container,
|
||||
container=escape_dict(container) or task.container,
|
||||
execution=execution_dict,
|
||||
configuration=params_dict.get("configuration") or task.configuration,
|
||||
hyperparams=params_dict.get("hyperparams") or task.hyperparams,
|
||||
|
@ -3,6 +3,7 @@ from datetime import datetime
|
||||
from pymongo.collection import Collection
|
||||
from pymongo.database import Database
|
||||
|
||||
from apiserver.services.utils import escape_dict
|
||||
from apiserver.utilities.dicts import nested_get
|
||||
from .utils import _drop_all_indices_from_collections
|
||||
|
||||
@ -97,7 +98,34 @@ def _migrate_docker_cmd(db: Database):
|
||||
)
|
||||
|
||||
|
||||
def _migrate_model_labels(db: Database):
|
||||
tasks: Collection = db["task"]
|
||||
|
||||
fields = ("execution.model_labels", "container")
|
||||
query = {"$or": [{field: {"$nin": [None, {}]}} for field in fields]}
|
||||
|
||||
for doc in tasks.find(filter=query, projection=fields):
|
||||
set_commands = {}
|
||||
for field in fields:
|
||||
data = nested_get(doc, field.split("."))
|
||||
if not data:
|
||||
continue
|
||||
escaped = escape_dict(data)
|
||||
if data == escaped:
|
||||
continue
|
||||
set_commands[field] = escaped
|
||||
|
||||
if set_commands:
|
||||
tasks.update_one(
|
||||
{"_id": doc["_id"]},
|
||||
{
|
||||
"$set": set_commands
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def migrate_backend(db: Database):
|
||||
_migrate_task_models(db)
|
||||
_migrate_docker_cmd(db)
|
||||
_migrate_model_labels(db)
|
||||
_drop_all_indices_from_collections(db, ["task*"])
|
||||
|
@ -270,19 +270,6 @@ _definitions {
|
||||
}
|
||||
}
|
||||
}
|
||||
container {
|
||||
type: object
|
||||
properties {
|
||||
image {
|
||||
type: string
|
||||
description: "Docker image"
|
||||
}
|
||||
arguments {
|
||||
type: string
|
||||
description: "Docker command arguments"
|
||||
}
|
||||
}
|
||||
}
|
||||
execution {
|
||||
type: object
|
||||
properties {
|
||||
@ -532,7 +519,8 @@ _definitions {
|
||||
}
|
||||
container {
|
||||
description: "Docker container parameters"
|
||||
"$ref": "#/definitions/container"
|
||||
type: object
|
||||
additionalProperties { type: string }
|
||||
}
|
||||
models {
|
||||
description: "Task models"
|
||||
@ -1105,7 +1093,8 @@ create {
|
||||
}
|
||||
container {
|
||||
description: "Docker container parameters"
|
||||
"$ref": "#/definitions/container"
|
||||
type: object
|
||||
additionalProperties { type: string }
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1193,7 +1182,8 @@ validate {
|
||||
}
|
||||
container {
|
||||
description: "Docker container parameters"
|
||||
"$ref": "#/definitions/container"
|
||||
type: object
|
||||
additionalProperties { type: string }
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1354,7 +1344,8 @@ edit {
|
||||
}
|
||||
container {
|
||||
description: "Docker container parameters"
|
||||
"$ref": "#/definitions/container"
|
||||
type: object
|
||||
additionalProperties { type: string }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -106,6 +106,8 @@ from apiserver.services.utils import (
|
||||
conform_output_tags,
|
||||
ModelsBackwardsCompatibility,
|
||||
DockerCmdBackwardsCompatibility,
|
||||
escape_dict_field,
|
||||
unescape_dict_field,
|
||||
)
|
||||
from apiserver.timing_context import TimingContext
|
||||
from apiserver.utilities.partial_version import PartialVersion
|
||||
@ -385,6 +387,8 @@ create_fields = {
|
||||
"script": None,
|
||||
}
|
||||
|
||||
dict_fields_paths = [("execution", "model_labels"), "container"]
|
||||
|
||||
|
||||
def prepare_for_save(call: APICall, fields: dict, previous_task: Task = None):
|
||||
conform_tag_fields(call, fields, validate=True)
|
||||
@ -392,6 +396,8 @@ def prepare_for_save(call: APICall, fields: dict, previous_task: Task = None):
|
||||
artifacts_prepare_for_save(fields)
|
||||
ModelsBackwardsCompatibility.prepare_for_save(call, fields)
|
||||
DockerCmdBackwardsCompatibility.prepare_for_save(call, fields)
|
||||
for path in dict_fields_paths:
|
||||
escape_dict_field(fields, path)
|
||||
|
||||
# Strip all script fields (remove leading and trailing whitespace chars) to avoid unusable names and paths
|
||||
for field in task_script_stripped_fields:
|
||||
@ -412,6 +418,11 @@ def unprepare_from_saved(call: APICall, tasks_data: Union[Sequence[dict], dict])
|
||||
tasks_data = [tasks_data]
|
||||
|
||||
conform_output_tags(call, tasks_data)
|
||||
|
||||
for data in tasks_data:
|
||||
for path in dict_fields_paths:
|
||||
unescape_dict_field(data, path)
|
||||
|
||||
ModelsBackwardsCompatibility.unprepare_from_saved(call, tasks_data)
|
||||
DockerCmdBackwardsCompatibility.unprepare_from_saved(call, tasks_data)
|
||||
|
||||
|
@ -8,6 +8,7 @@ from apiserver.database.model.base import GetMixin
|
||||
from apiserver.database.utils import partition_tags
|
||||
from apiserver.service_repo import APICall
|
||||
from apiserver.utilities.dicts import nested_set, nested_get, nested_delete
|
||||
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
|
||||
from apiserver.utilities.partial_version import PartialVersion
|
||||
|
||||
|
||||
@ -96,6 +97,42 @@ def validate_tags(tags: Sequence[str], system_tags: Sequence[str]):
|
||||
)
|
||||
|
||||
|
||||
def escape_dict(data: dict) -> dict:
|
||||
if not data:
|
||||
return data
|
||||
|
||||
return {ParameterKeyEscaper.escape(k): v for k, v in data.items()}
|
||||
|
||||
|
||||
def unescape_dict(data: dict) -> dict:
|
||||
if not data:
|
||||
return data
|
||||
|
||||
return {ParameterKeyEscaper.unescape(k): v for k, v in data.items()}
|
||||
|
||||
|
||||
def escape_dict_field(fields: dict, path: Union[str, Sequence[str]]):
|
||||
if isinstance(path, str):
|
||||
path = (path,)
|
||||
|
||||
data = nested_get(fields, path)
|
||||
if not data or not isinstance(data, dict):
|
||||
return
|
||||
|
||||
nested_set(fields, path, escape_dict(data))
|
||||
|
||||
|
||||
def unescape_dict_field(fields: dict, path: Union[str, Sequence[str]]):
|
||||
if isinstance(path, str):
|
||||
path = (path,)
|
||||
|
||||
data = nested_get(fields, path)
|
||||
if not data or not isinstance(data, dict):
|
||||
return
|
||||
|
||||
nested_set(fields, path, unescape_dict(data))
|
||||
|
||||
|
||||
class ModelsBackwardsCompatibility:
|
||||
max_version = PartialVersion("2.13")
|
||||
mode_to_fields = {"input": ("execution", "model"), "output": ("output", "model")}
|
||||
|
@ -47,9 +47,9 @@ class TestTasksArtifacts(TestService):
|
||||
self._assertTaskArtifacts([a for a in artifacts if a["mode"] != "output"], res)
|
||||
|
||||
new_artifacts = [
|
||||
dict(key="x", type="str", uri="x_test"),
|
||||
dict(key="y", type="int", uri="y_test"),
|
||||
dict(key="z", type="int", uri="y_test"),
|
||||
dict(key="x", type="str", uri="x_test", mode="input"),
|
||||
dict(key="y", type="int", uri="y_test", mode="input"),
|
||||
dict(key="z", type="int", uri="y_test", mode="input"),
|
||||
]
|
||||
new_task = self.api.tasks.clone(
|
||||
task=task, execution_overrides={"artifacts": new_artifacts}
|
||||
|
@ -202,7 +202,6 @@ class TestTasksEdit(TestService):
|
||||
for task in tasks:
|
||||
self.assertIn(system_tag, task.system_tags)
|
||||
self.assertIn("archived", task.system_tags)
|
||||
self.assertNotIn("queue", task.execution)
|
||||
self.assertIn(status_message, task.status_message)
|
||||
self.assertIn(status_reason, task.status_reason)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user