Make sure model label values are integer

This commit is contained in:
allegroai 2023-01-24 16:11:12 +02:00
parent 15db9cdaef
commit f0d68b1ce9
2 changed files with 13 additions and 2 deletions

View File

@ -256,6 +256,16 @@ class TaskBLL:
not in [TaskSystemTags.development, EntityVisibility.archived.value]
]
def ensure_int_labels(execution: dict) -> dict:
if not execution:
return execution
model_labels = execution.get("model_labels")
if model_labels:
execution["model_labels"] = {k: int(v) for k, v in model_labels.items()}
return execution
parent_task = (
task.parent
if task.parent and not task.parent.startswith(deleted_prefix)
@ -280,7 +290,7 @@ class TaskBLL:
output=Output(destination=task.output.destination) if task.output else None,
models=Models(input=input_models or task.models.input),
container=escape_dict(container) or task.container,
execution=execution_dict,
execution=ensure_int_labels(execution_dict),
configuration=params_dict.get("configuration") or task.configuration,
hyperparams=params_dict.get("hyperparams") or task.hyperparams,
)

View File

@ -103,7 +103,7 @@ class TestTasksEdit(TestService):
new_name = "new test"
new_tags = ["by"]
execution_overrides = dict(framework="Caffe")
execution_overrides = dict(framework="Caffe", model_labels={"test": 1.0})
new_task_id = self._clone_task(
task=task,
new_task_name=new_name,
@ -120,6 +120,7 @@ class TestTasksEdit(TestService):
self.assertEqual(new_task.parent, task)
# self.assertEqual(new_task.execution.parameters, execution["parameters"])
self.assertEqual(new_task.execution.framework, execution_overrides["framework"])
self.assertEqual(new_task.execution.model_labels, {"test": 1})
self.assertEqual(new_task.system_tags, ["test"])
def test_model_check_in_clone(self):