clearml-server/apiserver/tests/automated/test_task_models.py
2023-05-25 19:13:10 +03:00

115 lines
4.0 KiB
Python

from copy import deepcopy
from typing import Sequence, Optional
from packaging.version import parse
from apiserver.tests.automated import TestService
class TestTaskModels(TestService):
    """API tests for the named input/output model lists attached to tasks."""

    def setUp(self, version="2.13"):
        # Forward the requested API version (default "2.13") to the base setUp.
        super().setUp(version=version)

    def test_new_apis(self):
        # A task created without models reports empty input and output lists.
        task_without_models = self.new_task()
        self.assertModels(task_without_models, [], [])

        first_id = self.new_model("model1")
        second_id = self.new_model("model2")
        input_models = [
            {"name": "input1", "model": first_id},
            {"name": "input2", "model": second_id},
        ]
        output_models = [
            {"name": "output1", "model": "id3"},
            {"name": "output2", "model": "id4"},
        ]

        # Create a task that already carries both model lists.
        task = self.new_task(
            models=dict(input=input_models, output=output_models)
        )
        self.assertModels(task, input_models, output_models)

        # add_or_update on an existing name replaces that entry's model.
        update_result = self.api.tasks.add_or_update_model(
            task=task, name="input1", type="input", model="Test"
        )
        self.assertEqual(update_result.updated, 1)
        modified_input = deepcopy(input_models)
        modified_input[0]["model"] = "Test"
        self.assertModels(task, modified_input, output_models)

        # add_or_update on a new name appends a new model entry.
        add_result = self.api.tasks.add_or_update_model(
            task=task, name="output3", type="output", model="TestOutput"
        )
        self.assertEqual(add_result.updated, 1)
        modified_output = [
            *deepcopy(output_models),
            {"name": "output3", "model": "TestOutput"},
        ]
        self.assertModels(task, modified_input, modified_output)

        # Editing the task restores the original model lists.
        self.api.tasks.edit(
            task=task, models=dict(input=input_models, output=output_models)
        )
        self.assertModels(task, input_models, output_models)

        # Delete three existing entries plus one name that is not on the task;
        # only the last output model is expected to remain afterwards.
        models_to_delete = [
            {"name": model_name, "type": model_type}
            for model_name, model_type in (
                ("input1", "input"),
                ("input2", "input"),
                ("output1", "output"),
                ("not_existing", "output"),
            )
        ]
        delete_result = self.api.tasks.delete_models(
            task=task, models=models_to_delete
        )
        self.assertEqual(delete_result.updated, 1)
        self.assertModels(task, [], output_models[1:])

    def assertModels(
        self, task_id: str, input_models: Sequence[dict], output_models: Sequence[dict],
    ):
        """Check the task's models against the expected name/model pairs.

        The task is fetched through tasks.get_all_ex, tasks.get_all and
        tasks.get_by_id, and every returned copy is verified. For API versions
        below 2.13 the model referenced by task.execution and task.output is
        also compared with the first input / last output model respectively.
        """

        def get_model_id(model: dict) -> Optional[str]:
            # The "model" value may be a plain id string or a dict holding "id".
            if not model:
                return None
            reference = model.get("model")
            if isinstance(reference, str):
                return reference
            if reference is None or reference == {}:
                return None
            return reference.get("id")

        def compare_models(actual: Sequence[dict], expected: Sequence[dict]):
            # Compare as ordered (name, model id) pairs.
            actual_pairs = [
                (entry["name"], get_model_id(entry)) for entry in actual
            ]
            expected_pairs = [
                (entry["name"], entry["model"]) for entry in expected
            ]
            self.assertEqual(actual_pairs, expected_pairs)

        # All three lookups are issued up front, before any comparison runs.
        fetched_tasks = [
            self.api.tasks.get_all_ex(id=task_id).tasks[0],
            self.api.tasks.get_all(id=task_id).tasks[0],
            self.api.tasks.get_by_id(task=task_id).task,
        ]
        for fetched in fetched_tasks:
            compare_models(fetched.models.input, input_models)
            compare_models(fetched.models.output, output_models)

            is_legacy_version = self._version < parse("2.13")
            if not is_legacy_version:
                continue

            actual_execution_model = get_model_id(fetched.execution)
            expected_execution_model = (
                input_models[0]["model"] if input_models else None
            )
            self.assertEqual(actual_execution_model, expected_execution_model)

            actual_output_model = get_model_id(fetched.output)
            expected_output_model = (
                output_models[-1]["model"] if output_models else None
            )
            self.assertEqual(actual_output_model, expected_output_model)

    def new_task(self, **kwargs):
        # Fill in type/name when the caller did not supply them, then create
        # the task through create_temp.
        defaults = {"type": "testing", "name": "test task models"}
        self.update_missing(kwargs, **defaults)
        return self.create_temp("tasks", **kwargs)

    def new_model(self, name: str, **kwargs):
        # Create a model through create_temp with a fixed uri and empty labels.
        model_fields = dict(uri="file://test", name=name, labels={}, **kwargs)
        return self.create_temp("models", **model_fields)