mirror of
https://github.com/clearml/clearml-server
synced 2025-04-16 21:41:37 +00:00
Escape task.container and task.execution.model_labels fields in DB
This commit is contained in:
parent
fdf6798d0c
commit
3d22ca1888
@ -40,7 +40,7 @@ def artifacts_unprepare_from_saved(fields):
|
|||||||
nested_set(
|
nested_set(
|
||||||
fields,
|
fields,
|
||||||
artifacts_field,
|
artifacts_field,
|
||||||
value=sorted(artifacts.values(), key=itemgetter("key", "mode")),
|
value=sorted(artifacts.values(), key=itemgetter("key")),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -28,13 +28,14 @@ from apiserver.database.model.task.task import (
|
|||||||
ArtifactModes,
|
ArtifactModes,
|
||||||
ModelItem,
|
ModelItem,
|
||||||
Models,
|
Models,
|
||||||
|
DEFAULT_ARTIFACT_MODE,
|
||||||
)
|
)
|
||||||
from apiserver.database.model import EntityVisibility
|
from apiserver.database.model import EntityVisibility
|
||||||
from apiserver.database.utils import get_company_or_none_constraint, id as create_id
|
from apiserver.database.utils import get_company_or_none_constraint, id as create_id
|
||||||
from apiserver.es_factory import es_factory
|
from apiserver.es_factory import es_factory
|
||||||
from apiserver.redis_manager import redman
|
from apiserver.redis_manager import redman
|
||||||
from apiserver.service_repo import APICall
|
from apiserver.service_repo import APICall
|
||||||
from apiserver.services.utils import validate_tags
|
from apiserver.services.utils import validate_tags, escape_dict_field, escape_dict
|
||||||
from apiserver.timing_context import TimingContext
|
from apiserver.timing_context import TimingContext
|
||||||
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
|
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
|
||||||
from .artifacts import artifacts_prepare_for_save
|
from .artifacts import artifacts_prepare_for_save
|
||||||
@ -216,6 +217,8 @@ class TaskBLL:
|
|||||||
if legacy_value is not None:
|
if legacy_value is not None:
|
||||||
params_dict["execution"] = legacy_value
|
params_dict["execution"] = legacy_value
|
||||||
|
|
||||||
|
escape_dict_field(execution_overrides, "model_labels")
|
||||||
|
|
||||||
execution_dict.update(execution_overrides)
|
execution_dict.update(execution_overrides)
|
||||||
|
|
||||||
params_prepare_for_save(params_dict, previous_task=task)
|
params_prepare_for_save(params_dict, previous_task=task)
|
||||||
@ -225,7 +228,7 @@ class TaskBLL:
|
|||||||
execution_dict["artifacts"] = {
|
execution_dict["artifacts"] = {
|
||||||
k: a
|
k: a
|
||||||
for k, a in artifacts.items()
|
for k, a in artifacts.items()
|
||||||
if a.get("mode") != ArtifactModes.output
|
if a.get("mode", DEFAULT_ARTIFACT_MODE) != ArtifactModes.output
|
||||||
}
|
}
|
||||||
execution_dict.pop("queue", None)
|
execution_dict.pop("queue", None)
|
||||||
|
|
||||||
@ -276,7 +279,7 @@ class TaskBLL:
|
|||||||
if task.output
|
if task.output
|
||||||
else None,
|
else None,
|
||||||
models=Models(input=input_models or task.models.input),
|
models=Models(input=input_models or task.models.input),
|
||||||
container=container or task.container,
|
container=escape_dict(container) or task.container,
|
||||||
execution=execution_dict,
|
execution=execution_dict,
|
||||||
configuration=params_dict.get("configuration") or task.configuration,
|
configuration=params_dict.get("configuration") or task.configuration,
|
||||||
hyperparams=params_dict.get("hyperparams") or task.hyperparams,
|
hyperparams=params_dict.get("hyperparams") or task.hyperparams,
|
||||||
|
@ -3,6 +3,7 @@ from datetime import datetime
|
|||||||
from pymongo.collection import Collection
|
from pymongo.collection import Collection
|
||||||
from pymongo.database import Database
|
from pymongo.database import Database
|
||||||
|
|
||||||
|
from apiserver.services.utils import escape_dict
|
||||||
from apiserver.utilities.dicts import nested_get
|
from apiserver.utilities.dicts import nested_get
|
||||||
from .utils import _drop_all_indices_from_collections
|
from .utils import _drop_all_indices_from_collections
|
||||||
|
|
||||||
@ -97,7 +98,34 @@ def _migrate_docker_cmd(db: Database):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _migrate_model_labels(db: Database):
    """Escape the keys of the dict fields task.execution.model_labels and
    task.container for every task document that has a non-empty value there.

    Keys are escaped via escape_dict (ParameterKeyEscaper) so they are safe
    to store as MongoDB field names; documents already escaped are skipped.
    """
    collection: Collection = db["task"]

    dict_fields = ("execution.model_labels", "container")
    # only visit documents where at least one of the fields is non-empty
    non_empty = {"$nin": [None, {}]}
    query = {"$or": [{f: non_empty} for f in dict_fields]}

    for task_doc in collection.find(filter=query, projection=dict_fields):
        updates = {}
        for f in dict_fields:
            value = nested_get(task_doc, f.split("."))
            if not value:
                continue
            escaped = escape_dict(value)
            # write back only when escaping actually changed something
            if escaped != value:
                updates[f] = escaped

        if updates:
            collection.update_one({"_id": task_doc["_id"]}, {"$set": updates})
||||||
|
|
||||||
|
|
||||||
def migrate_backend(db: Database):
    """Run all task-collection migrations for this version, in order,
    then drop the task indices so they get rebuilt on startup.
    """
    for migration in (
        _migrate_task_models,
        _migrate_docker_cmd,
        _migrate_model_labels,
    ):
        migration(db)

    _drop_all_indices_from_collections(db, ["task*"])
||||||
|
@ -270,19 +270,6 @@ _definitions {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
container {
|
|
||||||
type: object
|
|
||||||
properties {
|
|
||||||
image {
|
|
||||||
type: string
|
|
||||||
description: "Docker image"
|
|
||||||
}
|
|
||||||
arguments {
|
|
||||||
type: string
|
|
||||||
description: "Docker command arguments"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
execution {
|
execution {
|
||||||
type: object
|
type: object
|
||||||
properties {
|
properties {
|
||||||
@ -532,7 +519,8 @@ _definitions {
|
|||||||
}
|
}
|
||||||
container {
|
container {
|
||||||
description: "Docker container parameters"
|
description: "Docker container parameters"
|
||||||
"$ref": "#/definitions/container"
|
type: object
|
||||||
|
additionalProperties { type: string }
|
||||||
}
|
}
|
||||||
models {
|
models {
|
||||||
description: "Task models"
|
description: "Task models"
|
||||||
@ -1105,7 +1093,8 @@ create {
|
|||||||
}
|
}
|
||||||
container {
|
container {
|
||||||
description: "Docker container parameters"
|
description: "Docker container parameters"
|
||||||
"$ref": "#/definitions/container"
|
type: object
|
||||||
|
additionalProperties { type: string }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1193,7 +1182,8 @@ validate {
|
|||||||
}
|
}
|
||||||
container {
|
container {
|
||||||
description: "Docker container parameters"
|
description: "Docker container parameters"
|
||||||
"$ref": "#/definitions/container"
|
type: object
|
||||||
|
additionalProperties { type: string }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1354,7 +1344,8 @@ edit {
|
|||||||
}
|
}
|
||||||
container {
|
container {
|
||||||
description: "Docker container parameters"
|
description: "Docker container parameters"
|
||||||
"$ref": "#/definitions/container"
|
type: object
|
||||||
|
additionalProperties { type: string }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -106,6 +106,8 @@ from apiserver.services.utils import (
|
|||||||
conform_output_tags,
|
conform_output_tags,
|
||||||
ModelsBackwardsCompatibility,
|
ModelsBackwardsCompatibility,
|
||||||
DockerCmdBackwardsCompatibility,
|
DockerCmdBackwardsCompatibility,
|
||||||
|
escape_dict_field,
|
||||||
|
unescape_dict_field,
|
||||||
)
|
)
|
||||||
from apiserver.timing_context import TimingContext
|
from apiserver.timing_context import TimingContext
|
||||||
from apiserver.utilities.partial_version import PartialVersion
|
from apiserver.utilities.partial_version import PartialVersion
|
||||||
@ -385,6 +387,8 @@ create_fields = {
|
|||||||
"script": None,
|
"script": None,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
dict_fields_paths = [("execution", "model_labels"), "container"]
|
||||||
|
|
||||||
|
|
||||||
def prepare_for_save(call: APICall, fields: dict, previous_task: Task = None):
|
def prepare_for_save(call: APICall, fields: dict, previous_task: Task = None):
|
||||||
conform_tag_fields(call, fields, validate=True)
|
conform_tag_fields(call, fields, validate=True)
|
||||||
@ -392,6 +396,8 @@ def prepare_for_save(call: APICall, fields: dict, previous_task: Task = None):
|
|||||||
artifacts_prepare_for_save(fields)
|
artifacts_prepare_for_save(fields)
|
||||||
ModelsBackwardsCompatibility.prepare_for_save(call, fields)
|
ModelsBackwardsCompatibility.prepare_for_save(call, fields)
|
||||||
DockerCmdBackwardsCompatibility.prepare_for_save(call, fields)
|
DockerCmdBackwardsCompatibility.prepare_for_save(call, fields)
|
||||||
|
for path in dict_fields_paths:
|
||||||
|
escape_dict_field(fields, path)
|
||||||
|
|
||||||
# Strip all script fields (remove leading and trailing whitespace chars) to avoid unusable names and paths
|
# Strip all script fields (remove leading and trailing whitespace chars) to avoid unusable names and paths
|
||||||
for field in task_script_stripped_fields:
|
for field in task_script_stripped_fields:
|
||||||
@ -412,6 +418,11 @@ def unprepare_from_saved(call: APICall, tasks_data: Union[Sequence[dict], dict])
|
|||||||
tasks_data = [tasks_data]
|
tasks_data = [tasks_data]
|
||||||
|
|
||||||
conform_output_tags(call, tasks_data)
|
conform_output_tags(call, tasks_data)
|
||||||
|
|
||||||
|
for data in tasks_data:
|
||||||
|
for path in dict_fields_paths:
|
||||||
|
unescape_dict_field(data, path)
|
||||||
|
|
||||||
ModelsBackwardsCompatibility.unprepare_from_saved(call, tasks_data)
|
ModelsBackwardsCompatibility.unprepare_from_saved(call, tasks_data)
|
||||||
DockerCmdBackwardsCompatibility.unprepare_from_saved(call, tasks_data)
|
DockerCmdBackwardsCompatibility.unprepare_from_saved(call, tasks_data)
|
||||||
|
|
||||||
|
@ -8,6 +8,7 @@ from apiserver.database.model.base import GetMixin
|
|||||||
from apiserver.database.utils import partition_tags
|
from apiserver.database.utils import partition_tags
|
||||||
from apiserver.service_repo import APICall
|
from apiserver.service_repo import APICall
|
||||||
from apiserver.utilities.dicts import nested_set, nested_get, nested_delete
|
from apiserver.utilities.dicts import nested_set, nested_get, nested_delete
|
||||||
|
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
|
||||||
from apiserver.utilities.partial_version import PartialVersion
|
from apiserver.utilities.partial_version import PartialVersion
|
||||||
|
|
||||||
|
|
||||||
@ -96,6 +97,42 @@ def validate_tags(tags: Sequence[str], system_tags: Sequence[str]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def escape_dict(data: dict) -> dict:
    """Return a copy of *data* with each key escaped via ParameterKeyEscaper.

    Falsy input (None or an empty dict) is returned unchanged.
    """
    if not data:
        return data

    escaped = {}
    for key, value in data.items():
        escaped[ParameterKeyEscaper.escape(key)] = value
    return escaped
||||||
|
|
||||||
|
|
||||||
|
def unescape_dict(data: dict) -> dict:
    """Return a copy of *data* with each key unescaped via ParameterKeyEscaper.

    Falsy input (None or an empty dict) is returned unchanged.
    """
    if not data:
        return data

    return dict(
        (ParameterKeyEscaper.unescape(key), value) for key, value in data.items()
    )
|
||||||
|
|
||||||
|
|
||||||
|
def escape_dict_field(fields: dict, path: Union[str, Sequence[str]]):
    """Escape, in place, the keys of the dict stored at *path* inside *fields*.

    *path* may be a single field name or a sequence of nested keys.
    Nothing happens when the value at *path* is missing, empty, or not a dict.
    """
    if isinstance(path, str):
        path = (path,)

    value = nested_get(fields, path)
    if isinstance(value, dict) and value:
        nested_set(fields, path, escape_dict(value))
|
||||||
|
|
||||||
|
|
||||||
|
def unescape_dict_field(fields: dict, path: Union[str, Sequence[str]]):
    """Unescape, in place, the keys of the dict stored at *path* inside *fields*.

    *path* may be a single field name or a sequence of nested keys.
    Nothing happens when the value at *path* is missing, empty, or not a dict.
    """
    if isinstance(path, str):
        path = (path,)

    value = nested_get(fields, path)
    if isinstance(value, dict) and value:
        nested_set(fields, path, unescape_dict(value))
|
||||||
|
|
||||||
|
|
||||||
class ModelsBackwardsCompatibility:
|
class ModelsBackwardsCompatibility:
|
||||||
max_version = PartialVersion("2.13")
|
max_version = PartialVersion("2.13")
|
||||||
mode_to_fields = {"input": ("execution", "model"), "output": ("output", "model")}
|
mode_to_fields = {"input": ("execution", "model"), "output": ("output", "model")}
|
||||||
|
@ -47,9 +47,9 @@ class TestTasksArtifacts(TestService):
|
|||||||
self._assertTaskArtifacts([a for a in artifacts if a["mode"] != "output"], res)
|
self._assertTaskArtifacts([a for a in artifacts if a["mode"] != "output"], res)
|
||||||
|
|
||||||
new_artifacts = [
|
new_artifacts = [
|
||||||
dict(key="x", type="str", uri="x_test"),
|
dict(key="x", type="str", uri="x_test", mode="input"),
|
||||||
dict(key="y", type="int", uri="y_test"),
|
dict(key="y", type="int", uri="y_test", mode="input"),
|
||||||
dict(key="z", type="int", uri="y_test"),
|
dict(key="z", type="int", uri="y_test", mode="input"),
|
||||||
]
|
]
|
||||||
new_task = self.api.tasks.clone(
|
new_task = self.api.tasks.clone(
|
||||||
task=task, execution_overrides={"artifacts": new_artifacts}
|
task=task, execution_overrides={"artifacts": new_artifacts}
|
||||||
|
@ -202,7 +202,6 @@ class TestTasksEdit(TestService):
|
|||||||
for task in tasks:
|
for task in tasks:
|
||||||
self.assertIn(system_tag, task.system_tags)
|
self.assertIn(system_tag, task.system_tags)
|
||||||
self.assertIn("archived", task.system_tags)
|
self.assertIn("archived", task.system_tags)
|
||||||
self.assertNotIn("queue", task.execution)
|
|
||||||
self.assertIn(status_message, task.status_message)
|
self.assertIn(status_message, task.status_message)
|
||||||
self.assertIn(status_reason, task.status_reason)
|
self.assertIn(status_reason, task.status_reason)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user