Add skip_empty parameter in get_configuration_names

This commit is contained in:
allegroai 2021-05-03 17:53:56 +03:00
parent dad935e81d
commit 9d9a44b927
5 changed files with 56 additions and 31 deletions

View File

@ -201,7 +201,7 @@ class GetConfigurationsRequest(MultiTaskRequest):
class GetConfigurationNamesRequest(MultiTaskRequest):
pass
skip_empty = BoolField(default=True)
class Configuration(models.Base):

View File

@ -175,21 +175,23 @@ class HyperParams:
@classmethod
def get_configuration_names(
cls, company_id: str, task_ids: Sequence[str]
cls, company_id: str, task_ids: Sequence[str], skip_empty: bool
) -> Dict[str, list]:
with TimingContext("mongo", "get_configuration_names"):
pipeline = [
{
"$match": {
"company": {"$in": [None, "", company_id]},
"_id": {"$in": task_ids},
}
},
{"$project": {"items": {"$objectToArray": "$configuration"}}},
{"$unwind": "$items"},
{"$group": {"_id": "$_id", "names": {"$addToSet": "$items.k"}}},
]
skip_empty_condition = {"$match": {"items.v.value": {"$nin": [None, ""]}}}
pipeline = [
{
"$match": {
"company": {"$in": [None, "", company_id]},
"_id": {"$in": task_ids},
}
},
{"$project": {"items": {"$objectToArray": "$configuration"}}},
{"$unwind": "$items"},
*([skip_empty_condition] if skip_empty else []),
{"$group": {"_id": "$_id", "names": {"$addToSet": "$items.k"}}},
]
with TimingContext("mongo", "get_configuration_names"):
tasks = Task.aggregate(pipeline)
return {

View File

@ -2200,6 +2200,11 @@ get_configuration_names {
type: array
items { type: string }
}
skip_empty {
description: If set to 'true' then the names for configurations with missing values are not returned
type: boolean
default: true
}
}
}
response {

View File

@ -789,7 +789,7 @@ def get_configuration_names(
):
with translate_errors_context():
tasks_params = HyperParams.get_configuration_names(
company_id, task_ids=request.tasks
company_id, task_ids=request.tasks, skip_empty=request.skip_empty
)
call.result.data = {

View File

@ -105,7 +105,9 @@ class TestTasksHyperparams(TestService):
)
# clone task
new_task = self.api.tasks.clone(task=task, new_task_hyperparams=new_params_dict).id
new_task = self.api.tasks.clone(
task=task, new_task_hyperparams=new_params_dict
).id
try:
res = self.api.tasks.get_hyper_params(tasks=[new_task]).params[0]
self.assertEqual(new_params, res.hyperparams)
@ -123,7 +125,9 @@ class TestTasksHyperparams(TestService):
task=task, hyperparams=[dict(section="test")]
)
self.api.tasks.edit_hyper_params(
task=task, hyperparams=[dict(section="test", name="x", value="123")], force=True
task=task,
hyperparams=[dict(section="test", name="x", value="123")],
force=True,
)
self.api.tasks.delete_hyper_params(
task=task, hyperparams=[dict(section="test")], force=True
@ -146,7 +150,12 @@ class TestTasksHyperparams(TestService):
return [
dict(section="Args", name=k, value=str(v), type="legacy")
if not k.startswith("TF_DEFINE/")
else dict(section="TF_DEFINE", name=k[len("TF_DEFINE/"):], value=str(v), type="legacy")
else dict(
section="TF_DEFINE",
name=k[len("TF_DEFINE/") :],
value=str(v),
type="legacy",
)
for k, v in legacy.items()
]
@ -168,6 +177,7 @@ class TestTasksHyperparams(TestService):
new_config = [
dict(name="param$1", type="type1", value="10"),
dict(name="param/2", type="type1", value="20"),
dict(name="param_empty", type="type1", value=""),
]
new_config_dict = self._config_dict_from_list(new_config)
task, _ = self.new_task(
@ -188,7 +198,14 @@ class TestTasksHyperparams(TestService):
# names
res = self.api.tasks.get_configuration_names(tasks=[task]).configurations[0]
self.assertEqual(task, res.task)
self.assertEqual(["design", "param$1", "param/2"], res.names)
self.assertEqual(
["design", *[c["name"] for c in new_config if c["value"]]], res.names
)
res = self.api.tasks.get_configuration_names(
tasks=[task], skip_empty=False
).configurations[0]
self.assertEqual(task, res.task)
self.assertEqual(["design", *[c["name"] for c in new_config]], res.names)
# returned as one list with names filtering
res = self.api.tasks.get_configurations(
@ -216,14 +233,14 @@ class TestTasksHyperparams(TestService):
# delete
new_to_delete = self._get_config_keys(new_config[1:])
self.api.tasks.delete_configuration(
task=task, configuration=new_to_delete
)
self.api.tasks.delete_configuration(task=task, configuration=new_to_delete)
res = self.api.tasks.get_configurations(tasks=[task]).configurations[0]
self.assertEqual(old_config + new_config[:1], res.configuration)
# clone task
new_task = self.api.tasks.clone(task=task, new_task_configuration=new_config_dict).id
new_task = self.api.tasks.clone(
task=task, new_task_configuration=new_config_dict
).id
try:
res = self.api.tasks.get_configurations(tasks=[new_task]).configurations[0]
self.assertEqual(new_config, res.configuration)
@ -233,13 +250,9 @@ class TestTasksHyperparams(TestService):
# edit/delete of running task
self.api.tasks.started(task=task)
with self.api.raises(InvalidTaskStatus):
self.api.tasks.edit_configuration(
task=task, configuration=new_config
)
self.api.tasks.edit_configuration(task=task, configuration=new_config)
with self.api.raises(InvalidTaskStatus):
self.api.tasks.delete_configuration(
task=task, configuration=new_to_delete
)
self.api.tasks.delete_configuration(task=task, configuration=new_to_delete)
self.api.tasks.edit_configuration(
task=task, configuration=new_config, force=True
)
@ -292,7 +305,9 @@ class TestTasksHyperparams(TestService):
task_id, _ = self.new_task(
execution={"parameters": legacy_params, "model_desc": legacy_config}
)
config = self._config_dict_from_list(self._new_config_from_legacy(legacy_config))
config = self._config_dict_from_list(
self._new_config_from_legacy(legacy_config)
)
params = self._param_dict_from_list(self._new_params_from_legacy(legacy_params))
old_api = APIClient(base_url="http://localhost:8008/v2.8")
@ -304,7 +319,10 @@ class TestTasksHyperparams(TestService):
modified_params = {"legacy.2": "val2"}
modified_config = {"design": "by"}
old_api.tasks.edit(task=task_id, execution=dict(parameters=modified_params, model_desc=modified_config))
old_api.tasks.edit(
task=task_id,
execution=dict(parameters=modified_params, model_desc=modified_config),
)
task = old_api.tasks.get_all_ex(id=[task_id]).tasks[0]
self.assertEqual(modified_params, task.execution.parameters)
self.assertEqual(modified_config, task.execution.model_desc)