From f0d68b1ce9cae972d0348c8ae3844567531fc4c6 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Tue, 24 Jan 2023 16:11:12 +0200 Subject: [PATCH] Make sure model label values are integer --- apiserver/bll/task/task_bll.py | 12 +++++++++++- apiserver/tests/automated/test_tasks_edit.py | 3 ++- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/apiserver/bll/task/task_bll.py b/apiserver/bll/task/task_bll.py index 31cfae1..a7a93aa 100644 --- a/apiserver/bll/task/task_bll.py +++ b/apiserver/bll/task/task_bll.py @@ -256,6 +256,16 @@ class TaskBLL: not in [TaskSystemTags.development, EntityVisibility.archived.value] ] + def ensure_int_labels(execution: dict) -> dict: + if not execution: + return execution + + model_labels = execution.get("model_labels") + if model_labels: + execution["model_labels"] = {k: int(v) for k, v in model_labels.items()} + + return execution + parent_task = ( task.parent if task.parent and not task.parent.startswith(deleted_prefix) @@ -280,7 +290,7 @@ class TaskBLL: output=Output(destination=task.output.destination) if task.output else None, models=Models(input=input_models or task.models.input), container=escape_dict(container) or task.container, - execution=execution_dict, + execution=ensure_int_labels(execution_dict), configuration=params_dict.get("configuration") or task.configuration, hyperparams=params_dict.get("hyperparams") or task.hyperparams, ) diff --git a/apiserver/tests/automated/test_tasks_edit.py b/apiserver/tests/automated/test_tasks_edit.py index b55b9c8..b444472 100644 --- a/apiserver/tests/automated/test_tasks_edit.py +++ b/apiserver/tests/automated/test_tasks_edit.py @@ -103,7 +103,7 @@ class TestTasksEdit(TestService): new_name = "new test" new_tags = ["by"] - execution_overrides = dict(framework="Caffe") + execution_overrides = dict(framework="Caffe", model_labels={"test": 1.0}) new_task_id = self._clone_task( task=task, new_task_name=new_name, @@ -120,6 +120,7 @@ class TestTasksEdit(TestService): self.assertEqual(new_task.parent, task) # self.assertEqual(new_task.execution.parameters, execution["parameters"]) self.assertEqual(new_task.execution.framework, execution_overrides["framework"]) + self.assertEqual(new_task.execution.model_labels, {"test": 1}) self.assertEqual(new_task.system_tags, ["test"]) def test_model_check_in_clone(self):