From 3d22ca18888fc91b819a40a01cdf05fea5a8d978 Mon Sep 17 00:00:00 2001
From: allegroai <>
Date: Mon, 3 May 2021 17:56:17 +0300
Subject: [PATCH] Escape task.container and task.execution.model_labels fields
 in DB

---
 apiserver/bll/task/artifacts.py               |  2 +-
 apiserver/bll/task/task_bll.py                |  9 +++--
 apiserver/mongo/migrations/0_18_0.py          | 28 ++++++++++++++
 apiserver/schema/services/tasks.conf          | 25 ++++---------
 apiserver/services/tasks.py                   | 11 ++++++
 apiserver/services/utils.py                   | 37 +++++++++++++++++++
 .../tests/automated/test_task_artifacts.py    |  6 +--
 apiserver/tests/automated/test_tasks_edit.py  |  1 -
 8 files changed, 94 insertions(+), 25 deletions(-)

diff --git a/apiserver/bll/task/artifacts.py b/apiserver/bll/task/artifacts.py
index 36205f4..b44e428 100644
--- a/apiserver/bll/task/artifacts.py
+++ b/apiserver/bll/task/artifacts.py
@@ -40,7 +40,7 @@ def artifacts_unprepare_from_saved(fields):
     nested_set(
         fields,
         artifacts_field,
-        value=sorted(artifacts.values(), key=itemgetter("key", "mode")),
+        value=sorted(artifacts.values(), key=itemgetter("key")),
     )
 
 
diff --git a/apiserver/bll/task/task_bll.py b/apiserver/bll/task/task_bll.py
index 691f3e8..f6f0604 100644
--- a/apiserver/bll/task/task_bll.py
+++ b/apiserver/bll/task/task_bll.py
@@ -28,13 +28,14 @@ from apiserver.database.model.task.task import (
     ArtifactModes,
     ModelItem,
     Models,
+    DEFAULT_ARTIFACT_MODE,
 )
 from apiserver.database.model import EntityVisibility
 from apiserver.database.utils import get_company_or_none_constraint, id as create_id
 from apiserver.es_factory import es_factory
 from apiserver.redis_manager import redman
 from apiserver.service_repo import APICall
-from apiserver.services.utils import validate_tags
+from apiserver.services.utils import validate_tags, escape_dict_field, escape_dict
 from apiserver.timing_context import TimingContext
 from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
 from .artifacts import artifacts_prepare_for_save
@@ -216,6 +217,8 @@ class TaskBLL:
                 if legacy_value is not None:
                     params_dict["execution"] = legacy_value
 
+            escape_dict_field(execution_overrides, "model_labels")
+
             execution_dict.update(execution_overrides)
 
         params_prepare_for_save(params_dict, previous_task=task)
@@ -225,7 +228,7 @@ class TaskBLL:
             execution_dict["artifacts"] = {
                 k: a
                 for k, a in artifacts.items()
-                if a.get("mode") != ArtifactModes.output
+                if a.get("mode", DEFAULT_ARTIFACT_MODE) != ArtifactModes.output
             }
         execution_dict.pop("queue", None)
 
@@ -276,7 +279,7 @@ class TaskBLL:
                 if task.output
                 else None,
                 models=Models(input=input_models or task.models.input),
-                container=container or task.container,
+                container=escape_dict(container) or task.container,
                 execution=execution_dict,
                 configuration=params_dict.get("configuration") or task.configuration,
                 hyperparams=params_dict.get("hyperparams") or task.hyperparams,
diff --git a/apiserver/mongo/migrations/0_18_0.py b/apiserver/mongo/migrations/0_18_0.py
index a5db0e3..b183f03 100644
--- a/apiserver/mongo/migrations/0_18_0.py
+++ b/apiserver/mongo/migrations/0_18_0.py
@@ -3,6 +3,7 @@ from datetime import datetime
 from pymongo.collection import Collection
 from pymongo.database import Database
 
+from apiserver.services.utils import escape_dict
 from apiserver.utilities.dicts import nested_get
 from .utils import _drop_all_indices_from_collections
 
@@ -97,7 +98,34 @@ def _migrate_docker_cmd(db: Database):
         )
 
 
+def _migrate_model_labels(db: Database):
+    tasks: Collection = db["task"]
+
+    fields = ("execution.model_labels", "container")
+    query = {"$or": [{field: {"$nin": [None, {}]}} for field in fields]}
+
+    for doc in tasks.find(filter=query, projection=fields):
+        set_commands = {}
+        for field in fields:
+            data = nested_get(doc, field.split("."))
+            if not data:
+                continue
+            escaped = escape_dict(data)
+            if data == escaped:
+                continue
+            set_commands[field] = escaped
+
+        if set_commands:
+            tasks.update_one(
+                {"_id": doc["_id"]},
+                {
+                    "$set": set_commands
+                }
+            )
+
+
 def migrate_backend(db: Database):
     _migrate_task_models(db)
     _migrate_docker_cmd(db)
+    _migrate_model_labels(db)
     _drop_all_indices_from_collections(db, ["task*"])
diff --git a/apiserver/schema/services/tasks.conf b/apiserver/schema/services/tasks.conf
index 2ac1f56..e7491e5 100644
--- a/apiserver/schema/services/tasks.conf
+++ b/apiserver/schema/services/tasks.conf
@@ -270,19 +270,6 @@ _definitions {
             }
         }
     }
-    container {
-        type: object
-        properties {
-            image {
-                type: string
-                description: "Docker image"
-            }
-            arguments {
-                type: string
-                description: "Docker command arguments"
-            }
-        }
-    }
     execution {
         type: object
         properties {
@@ -532,7 +519,8 @@ _definitions {
             }
             container {
                 description: "Docker container parameters"
-                "$ref": "#/definitions/container"
+                type: object
+                additionalProperties { type: string }
             }
             models {
                 description: "Task models"
@@ -1105,7 +1093,8 @@ create {
                 }
                 container {
                     description: "Docker container parameters"
-                    "$ref": "#/definitions/container"
+                    type: object
+                    additionalProperties { type: string }
                 }
             }
         }
@@ -1193,7 +1182,8 @@ validate {
                 }
                 container {
                     description: "Docker container parameters"
-                    "$ref": "#/definitions/container"
+                    type: object
+                    additionalProperties { type: string }
                 }
             }
         }
@@ -1354,7 +1344,8 @@ edit {
                 }
                 container {
                     description: "Docker container parameters"
-                    "$ref": "#/definitions/container"
+                    type: object
+                    additionalProperties { type: string }
                 }
             }
         }
diff --git a/apiserver/services/tasks.py b/apiserver/services/tasks.py
index 34a320e..2e51afe 100644
--- a/apiserver/services/tasks.py
+++ b/apiserver/services/tasks.py
@@ -106,6 +106,8 @@ from apiserver.services.utils import (
     conform_output_tags,
     ModelsBackwardsCompatibility,
     DockerCmdBackwardsCompatibility,
+    escape_dict_field,
+    unescape_dict_field,
 )
 from apiserver.timing_context import TimingContext
 from apiserver.utilities.partial_version import PartialVersion
@@ -385,6 +387,8 @@ create_fields = {
     "script": None,
 }
 
+dict_fields_paths = [("execution", "model_labels"), "container"]
+
 
 def prepare_for_save(call: APICall, fields: dict, previous_task: Task = None):
     conform_tag_fields(call, fields, validate=True)
@@ -392,6 +396,8 @@ def prepare_for_save(call: APICall, fields: dict, previous_task: Task = None):
     artifacts_prepare_for_save(fields)
     ModelsBackwardsCompatibility.prepare_for_save(call, fields)
     DockerCmdBackwardsCompatibility.prepare_for_save(call, fields)
+    for path in dict_fields_paths:
+        escape_dict_field(fields, path)
 
     # Strip all script fields (remove leading and trailing whitespace chars) to avoid unusable names and paths
     for field in task_script_stripped_fields:
@@ -412,6 +418,11 @@ def unprepare_from_saved(call: APICall, tasks_data: Union[Sequence[dict], dict])
         tasks_data = [tasks_data]
 
     conform_output_tags(call, tasks_data)
+
+    for data in tasks_data:
+        for path in dict_fields_paths:
+            unescape_dict_field(data, path)
+
     ModelsBackwardsCompatibility.unprepare_from_saved(call, tasks_data)
     DockerCmdBackwardsCompatibility.unprepare_from_saved(call, tasks_data)
 
diff --git a/apiserver/services/utils.py b/apiserver/services/utils.py
index 947fd73..281ef08 100644
--- a/apiserver/services/utils.py
+++ b/apiserver/services/utils.py
@@ -8,6 +8,7 @@ from apiserver.database.model.base import GetMixin
 from apiserver.database.utils import partition_tags
 from apiserver.service_repo import APICall
 from apiserver.utilities.dicts import nested_set, nested_get, nested_delete
+from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
 from apiserver.utilities.partial_version import PartialVersion
 
 
@@ -96,6 +97,42 @@ def validate_tags(tags: Sequence[str], system_tags: Sequence[str]):
             )
 
 
+def escape_dict(data: dict) -> dict:
+    if not data:
+        return data
+
+    return {ParameterKeyEscaper.escape(k): v for k, v in data.items()}
+
+
+def unescape_dict(data: dict) -> dict:
+    if not data:
+        return data
+
+    return {ParameterKeyEscaper.unescape(k): v for k, v in data.items()}
+
+
+def escape_dict_field(fields: dict, path: Union[str, Sequence[str]]):
+    if isinstance(path, str):
+        path = (path,)
+
+    data = nested_get(fields, path)
+    if not data or not isinstance(data, dict):
+        return
+
+    nested_set(fields, path, escape_dict(data))
+
+
+def unescape_dict_field(fields: dict, path: Union[str, Sequence[str]]):
+    if isinstance(path, str):
+        path = (path,)
+
+    data = nested_get(fields, path)
+    if not data or not isinstance(data, dict):
+        return
+
+    nested_set(fields, path, unescape_dict(data))
+
+
 class ModelsBackwardsCompatibility:
     max_version = PartialVersion("2.13")
     mode_to_fields = {"input": ("execution", "model"), "output": ("output", "model")}
diff --git a/apiserver/tests/automated/test_task_artifacts.py b/apiserver/tests/automated/test_task_artifacts.py
index 44e4f86..60e0487 100644
--- a/apiserver/tests/automated/test_task_artifacts.py
+++ b/apiserver/tests/automated/test_task_artifacts.py
@@ -47,9 +47,9 @@ class TestTasksArtifacts(TestService):
         self._assertTaskArtifacts([a for a in artifacts if a["mode"] != "output"], res)
 
         new_artifacts = [
-            dict(key="x", type="str", uri="x_test"),
-            dict(key="y", type="int", uri="y_test"),
-            dict(key="z", type="int", uri="y_test"),
+            dict(key="x", type="str", uri="x_test", mode="input"),
+            dict(key="y", type="int", uri="y_test", mode="input"),
+            dict(key="z", type="int", uri="y_test", mode="input"),
         ]
         new_task = self.api.tasks.clone(
             task=task, execution_overrides={"artifacts": new_artifacts}
diff --git a/apiserver/tests/automated/test_tasks_edit.py b/apiserver/tests/automated/test_tasks_edit.py
index d6ac6e2..b55b9c8 100644
--- a/apiserver/tests/automated/test_tasks_edit.py
+++ b/apiserver/tests/automated/test_tasks_edit.py
@@ -202,7 +202,6 @@ class TestTasksEdit(TestService):
         for task in tasks:
             self.assertIn(system_tag, task.system_tags)
             self.assertIn("archived", task.system_tags)
-            self.assertNotIn("queue", task.execution)
             self.assertIn(status_message, task.status_message)
             self.assertIn(status_reason, task.status_reason)