Add skip_empty parameter in get_configuration_names

This commit is contained in:
allegroai
2021-05-03 17:53:56 +03:00
parent dad935e81d
commit 9d9a44b927
5 changed files with 56 additions and 31 deletions

View File

@@ -105,7 +105,9 @@ class TestTasksHyperparams(TestService):
)
# clone task
new_task = self.api.tasks.clone(task=task, new_task_hyperparams=new_params_dict).id
new_task = self.api.tasks.clone(
task=task, new_task_hyperparams=new_params_dict
).id
try:
res = self.api.tasks.get_hyper_params(tasks=[new_task]).params[0]
self.assertEqual(new_params, res.hyperparams)
@@ -123,7 +125,9 @@ class TestTasksHyperparams(TestService):
task=task, hyperparams=[dict(section="test")]
)
self.api.tasks.edit_hyper_params(
task=task, hyperparams=[dict(section="test", name="x", value="123")], force=True
task=task,
hyperparams=[dict(section="test", name="x", value="123")],
force=True,
)
self.api.tasks.delete_hyper_params(
task=task, hyperparams=[dict(section="test")], force=True
@@ -146,7 +150,12 @@ class TestTasksHyperparams(TestService):
return [
dict(section="Args", name=k, value=str(v), type="legacy")
if not k.startswith("TF_DEFINE/")
else dict(section="TF_DEFINE", name=k[len("TF_DEFINE/"):], value=str(v), type="legacy")
else dict(
section="TF_DEFINE",
name=k[len("TF_DEFINE/") :],
value=str(v),
type="legacy",
)
for k, v in legacy.items()
]
@@ -168,6 +177,7 @@ class TestTasksHyperparams(TestService):
new_config = [
dict(name="param$1", type="type1", value="10"),
dict(name="param/2", type="type1", value="20"),
dict(name="param_empty", type="type1", value=""),
]
new_config_dict = self._config_dict_from_list(new_config)
task, _ = self.new_task(
@@ -188,7 +198,14 @@ class TestTasksHyperparams(TestService):
# names
res = self.api.tasks.get_configuration_names(tasks=[task]).configurations[0]
self.assertEqual(task, res.task)
self.assertEqual(["design", "param$1", "param/2"], res.names)
self.assertEqual(
["design", *[c["name"] for c in new_config if c["value"]]], res.names
)
res = self.api.tasks.get_configuration_names(
tasks=[task], skip_empty=False
).configurations[0]
self.assertEqual(task, res.task)
self.assertEqual(["design", *[c["name"] for c in new_config]], res.names)
# returned as one list with names filtering
res = self.api.tasks.get_configurations(
@@ -216,14 +233,14 @@ class TestTasksHyperparams(TestService):
# delete
new_to_delete = self._get_config_keys(new_config[1:])
self.api.tasks.delete_configuration(
task=task, configuration=new_to_delete
)
self.api.tasks.delete_configuration(task=task, configuration=new_to_delete)
res = self.api.tasks.get_configurations(tasks=[task]).configurations[0]
self.assertEqual(old_config + new_config[:1], res.configuration)
# clone task
new_task = self.api.tasks.clone(task=task, new_task_configuration=new_config_dict).id
new_task = self.api.tasks.clone(
task=task, new_task_configuration=new_config_dict
).id
try:
res = self.api.tasks.get_configurations(tasks=[new_task]).configurations[0]
self.assertEqual(new_config, res.configuration)
@@ -233,13 +250,9 @@ class TestTasksHyperparams(TestService):
# edit/delete of running task
self.api.tasks.started(task=task)
with self.api.raises(InvalidTaskStatus):
self.api.tasks.edit_configuration(
task=task, configuration=new_config
)
self.api.tasks.edit_configuration(task=task, configuration=new_config)
with self.api.raises(InvalidTaskStatus):
self.api.tasks.delete_configuration(
task=task, configuration=new_to_delete
)
self.api.tasks.delete_configuration(task=task, configuration=new_to_delete)
self.api.tasks.edit_configuration(
task=task, configuration=new_config, force=True
)
@@ -292,7 +305,9 @@ class TestTasksHyperparams(TestService):
task_id, _ = self.new_task(
execution={"parameters": legacy_params, "model_desc": legacy_config}
)
config = self._config_dict_from_list(self._new_config_from_legacy(legacy_config))
config = self._config_dict_from_list(
self._new_config_from_legacy(legacy_config)
)
params = self._param_dict_from_list(self._new_params_from_legacy(legacy_params))
old_api = APIClient(base_url="http://localhost:8008/v2.8")
@@ -304,7 +319,10 @@ class TestTasksHyperparams(TestService):
modified_params = {"legacy.2": "val2"}
modified_config = {"design": "by"}
old_api.tasks.edit(task=task_id, execution=dict(parameters=modified_params, model_desc=modified_config))
old_api.tasks.edit(
task=task_id,
execution=dict(parameters=modified_params, model_desc=modified_config),
)
task = old_api.tasks.get_all_ex(id=[task_id]).tasks[0]
self.assertEqual(modified_params, task.execution.parameters)
self.assertEqual(modified_config, task.execution.model_desc)