Make sure model label values are integer

This commit is contained in:
allegroai 2023-01-24 16:11:12 +02:00
parent 15db9cdaef
commit f0d68b1ce9
2 changed files with 13 additions and 2 deletions

View File

@ -256,6 +256,16 @@ class TaskBLL:
not in [TaskSystemTags.development, EntityVisibility.archived.value] not in [TaskSystemTags.development, EntityVisibility.archived.value]
] ]
def ensure_int_labels(execution: dict) -> dict:
if not execution:
return execution
model_labels = execution.get("model_labels")
if model_labels:
execution["model_labels"] = {k: int(v) for k, v in model_labels.items()}
return execution
parent_task = ( parent_task = (
task.parent task.parent
if task.parent and not task.parent.startswith(deleted_prefix) if task.parent and not task.parent.startswith(deleted_prefix)
@ -280,7 +290,7 @@ class TaskBLL:
output=Output(destination=task.output.destination) if task.output else None, output=Output(destination=task.output.destination) if task.output else None,
models=Models(input=input_models or task.models.input), models=Models(input=input_models or task.models.input),
container=escape_dict(container) or task.container, container=escape_dict(container) or task.container,
execution=execution_dict, execution=ensure_int_labels(execution_dict),
configuration=params_dict.get("configuration") or task.configuration, configuration=params_dict.get("configuration") or task.configuration,
hyperparams=params_dict.get("hyperparams") or task.hyperparams, hyperparams=params_dict.get("hyperparams") or task.hyperparams,
) )

View File

@ -103,7 +103,7 @@ class TestTasksEdit(TestService):
new_name = "new test" new_name = "new test"
new_tags = ["by"] new_tags = ["by"]
execution_overrides = dict(framework="Caffe") execution_overrides = dict(framework="Caffe", model_labels={"test": 1.0})
new_task_id = self._clone_task( new_task_id = self._clone_task(
task=task, task=task,
new_task_name=new_name, new_task_name=new_name,
@ -120,6 +120,7 @@ class TestTasksEdit(TestService):
self.assertEqual(new_task.parent, task) self.assertEqual(new_task.parent, task)
# self.assertEqual(new_task.execution.parameters, execution["parameters"]) # self.assertEqual(new_task.execution.parameters, execution["parameters"])
self.assertEqual(new_task.execution.framework, execution_overrides["framework"]) self.assertEqual(new_task.execution.framework, execution_overrides["framework"])
self.assertEqual(new_task.execution.model_labels, {"test": 1})
self.assertEqual(new_task.system_tags, ["test"]) self.assertEqual(new_task.system_tags, ["test"])
def test_model_check_in_clone(self): def test_model_check_in_clone(self):