diff --git a/apiserver/bll/task/artifacts.py b/apiserver/bll/task/artifacts.py index 36205f4..b44e428 100644 --- a/apiserver/bll/task/artifacts.py +++ b/apiserver/bll/task/artifacts.py @@ -40,7 +40,7 @@ def artifacts_unprepare_from_saved(fields): nested_set( fields, artifacts_field, - value=sorted(artifacts.values(), key=itemgetter("key", "mode")), + value=sorted(artifacts.values(), key=itemgetter("key")), ) diff --git a/apiserver/bll/task/task_bll.py b/apiserver/bll/task/task_bll.py index 691f3e8..f6f0604 100644 --- a/apiserver/bll/task/task_bll.py +++ b/apiserver/bll/task/task_bll.py @@ -28,13 +28,14 @@ from apiserver.database.model.task.task import ( ArtifactModes, ModelItem, Models, + DEFAULT_ARTIFACT_MODE, ) from apiserver.database.model import EntityVisibility from apiserver.database.utils import get_company_or_none_constraint, id as create_id from apiserver.es_factory import es_factory from apiserver.redis_manager import redman from apiserver.service_repo import APICall -from apiserver.services.utils import validate_tags +from apiserver.services.utils import validate_tags, escape_dict_field, escape_dict from apiserver.timing_context import TimingContext from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper from .artifacts import artifacts_prepare_for_save @@ -216,6 +217,8 @@ class TaskBLL: if legacy_value is not None: params_dict["execution"] = legacy_value + escape_dict_field(execution_overrides, "model_labels") + execution_dict.update(execution_overrides) params_prepare_for_save(params_dict, previous_task=task) @@ -225,7 +228,7 @@ class TaskBLL: execution_dict["artifacts"] = { k: a for k, a in artifacts.items() - if a.get("mode") != ArtifactModes.output + if a.get("mode", DEFAULT_ARTIFACT_MODE) != ArtifactModes.output } execution_dict.pop("queue", None) @@ -276,7 +279,7 @@ class TaskBLL: if task.output else None, models=Models(input=input_models or task.models.input), - container=container or task.container, + container=escape_dict(container) or task.container, execution=execution_dict, configuration=params_dict.get("configuration") or task.configuration, hyperparams=params_dict.get("hyperparams") or task.hyperparams, diff --git a/apiserver/mongo/migrations/0_18_0.py b/apiserver/mongo/migrations/0_18_0.py index a5db0e3..b183f03 100644 --- a/apiserver/mongo/migrations/0_18_0.py +++ b/apiserver/mongo/migrations/0_18_0.py @@ -3,6 +3,7 @@ from datetime import datetime from pymongo.collection import Collection from pymongo.database import Database +from apiserver.services.utils import escape_dict from apiserver.utilities.dicts import nested_get from .utils import _drop_all_indices_from_collections @@ -97,7 +98,34 @@ def _migrate_docker_cmd(db: Database): ) +def _migrate_model_labels(db: Database): + tasks: Collection = db["task"] + + fields = ("execution.model_labels", "container") + query = {"$or": [{field: {"$nin": [None, {}]}} for field in fields]} + + for doc in tasks.find(filter=query, projection=fields): + set_commands = {} + for field in fields: + data = nested_get(doc, field.split(".")) + if not data: + continue + escaped = escape_dict(data) + if data == escaped: + continue + set_commands[field] = escaped + + if set_commands: + tasks.update_one( + {"_id": doc["_id"]}, + { + "$set": set_commands + } + ) + + def migrate_backend(db: Database): _migrate_task_models(db) _migrate_docker_cmd(db) + _migrate_model_labels(db) _drop_all_indices_from_collections(db, ["task*"]) diff --git a/apiserver/schema/services/tasks.conf b/apiserver/schema/services/tasks.conf index 2ac1f56..e7491e5 100644 --- a/apiserver/schema/services/tasks.conf +++ b/apiserver/schema/services/tasks.conf @@ -270,19 +270,6 @@ _definitions { } } } - container { - type: object - properties { - image { - type: string - description: "Docker image" - } - arguments { - type: string - description: "Docker command arguments" - } - } - } execution { type: object properties { @@ -532,7 +519,8 @@ _definitions { } container { description: "Docker container parameters" - "$ref": "#/definitions/container" + type: object + additionalProperties { type: string } } models { description: "Task models" @@ -1105,7 +1093,8 @@ create { } container { description: "Docker container parameters" - "$ref": "#/definitions/container" + type: object + additionalProperties { type: string } } } } @@ -1193,7 +1182,8 @@ validate { } container { description: "Docker container parameters" - "$ref": "#/definitions/container" + type: object + additionalProperties { type: string } } } } @@ -1354,7 +1344,8 @@ edit { } container { description: "Docker container parameters" - "$ref": "#/definitions/container" + type: object + additionalProperties { type: string } } } } diff --git a/apiserver/services/tasks.py b/apiserver/services/tasks.py index 34a320e..2e51afe 100644 --- a/apiserver/services/tasks.py +++ b/apiserver/services/tasks.py @@ -106,6 +106,8 @@ from apiserver.services.utils import ( conform_output_tags, ModelsBackwardsCompatibility, DockerCmdBackwardsCompatibility, + escape_dict_field, + unescape_dict_field, ) from apiserver.timing_context import TimingContext from apiserver.utilities.partial_version import PartialVersion @@ -385,6 +387,8 @@ create_fields = { "script": None, } +dict_fields_paths = [("execution", "model_labels"), "container"] + def prepare_for_save(call: APICall, fields: dict, previous_task: Task = None): conform_tag_fields(call, fields, validate=True) @@ -392,6 +396,8 @@ def prepare_for_save(call: APICall, fields: dict, previous_task: Task = None): artifacts_prepare_for_save(fields) ModelsBackwardsCompatibility.prepare_for_save(call, fields) DockerCmdBackwardsCompatibility.prepare_for_save(call, fields) + for path in dict_fields_paths: + escape_dict_field(fields, path) # Strip all script fields (remove leading and trailing whitespace chars) to avoid unusable names and paths for field in task_script_stripped_fields: @@ -412,6 +418,11 @@ def unprepare_from_saved(call: APICall, tasks_data: Union[Sequence[dict], dict]) tasks_data = [tasks_data] conform_output_tags(call, tasks_data) + + for data in tasks_data: + for path in dict_fields_paths: + unescape_dict_field(data, path) + ModelsBackwardsCompatibility.unprepare_from_saved(call, tasks_data) DockerCmdBackwardsCompatibility.unprepare_from_saved(call, tasks_data) diff --git a/apiserver/services/utils.py b/apiserver/services/utils.py index 947fd73..281ef08 100644 --- a/apiserver/services/utils.py +++ b/apiserver/services/utils.py @@ -8,6 +8,7 @@ from apiserver.database.model.base import GetMixin from apiserver.database.utils import partition_tags from apiserver.service_repo import APICall from apiserver.utilities.dicts import nested_set, nested_get, nested_delete +from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper from apiserver.utilities.partial_version import PartialVersion @@ -96,6 +97,42 @@ def validate_tags(tags: Sequence[str], system_tags: Sequence[str]): ) +def escape_dict(data: dict) -> dict: + if not data: + return data + + return {ParameterKeyEscaper.escape(k): v for k, v in data.items()} + + +def unescape_dict(data: dict) -> dict: + if not data: + return data + + return {ParameterKeyEscaper.unescape(k): v for k, v in data.items()} + + +def escape_dict_field(fields: dict, path: Union[str, Sequence[str]]): + if isinstance(path, str): + path = (path,) + + data = nested_get(fields, path) + if not data or not isinstance(data, dict): + return + + nested_set(fields, path, escape_dict(data)) + + +def unescape_dict_field(fields: dict, path: Union[str, Sequence[str]]): + if isinstance(path, str): + path = (path,) + + data = nested_get(fields, path) + if not data or not isinstance(data, dict): + return + + nested_set(fields, path, unescape_dict(data)) + + class ModelsBackwardsCompatibility: max_version = PartialVersion("2.13") mode_to_fields = {"input": ("execution", "model"), "output": ("output", "model")} diff --git a/apiserver/tests/automated/test_task_artifacts.py b/apiserver/tests/automated/test_task_artifacts.py index 44e4f86..60e0487 100644 --- a/apiserver/tests/automated/test_task_artifacts.py +++ b/apiserver/tests/automated/test_task_artifacts.py @@ -47,9 +47,9 @@ class TestTasksArtifacts(TestService): self._assertTaskArtifacts([a for a in artifacts if a["mode"] != "output"], res) new_artifacts = [ - dict(key="x", type="str", uri="x_test"), - dict(key="y", type="int", uri="y_test"), - dict(key="z", type="int", uri="y_test"), + dict(key="x", type="str", uri="x_test", mode="input"), + dict(key="y", type="int", uri="y_test", mode="input"), + dict(key="z", type="int", uri="y_test", mode="input"), ] new_task = self.api.tasks.clone( task=task, execution_overrides={"artifacts": new_artifacts} diff --git a/apiserver/tests/automated/test_tasks_edit.py b/apiserver/tests/automated/test_tasks_edit.py index d6ac6e2..b55b9c8 100644 --- a/apiserver/tests/automated/test_tasks_edit.py +++ b/apiserver/tests/automated/test_tasks_edit.py @@ -202,7 +202,6 @@ class TestTasksEdit(TestService): for task in tasks: self.assertIn(system_tag, task.system_tags) self.assertIn("archived", task.system_tags) - self.assertNotIn("queue", task.execution) self.assertIn(status_message, task.status_message) self.assertIn(status_reason, task.status_reason)