Add skip_empty parameter in get_configuration_names

This commit is contained in:
allegroai 2021-05-03 17:53:56 +03:00
parent dad935e81d
commit 9d9a44b927
5 changed files with 56 additions and 31 deletions

View File

@ -201,7 +201,7 @@ class GetConfigurationsRequest(MultiTaskRequest):
class GetConfigurationNamesRequest(MultiTaskRequest): class GetConfigurationNamesRequest(MultiTaskRequest):
pass skip_empty = BoolField(default=True)
class Configuration(models.Base): class Configuration(models.Base):

View File

@ -175,21 +175,23 @@ class HyperParams:
@classmethod @classmethod
def get_configuration_names( def get_configuration_names(
cls, company_id: str, task_ids: Sequence[str] cls, company_id: str, task_ids: Sequence[str], skip_empty: bool
) -> Dict[str, list]: ) -> Dict[str, list]:
with TimingContext("mongo", "get_configuration_names"): skip_empty_condition = {"$match": {"items.v.value": {"$nin": [None, ""]}}}
pipeline = [ pipeline = [
{ {
"$match": { "$match": {
"company": {"$in": [None, "", company_id]}, "company": {"$in": [None, "", company_id]},
"_id": {"$in": task_ids}, "_id": {"$in": task_ids},
} }
}, },
{"$project": {"items": {"$objectToArray": "$configuration"}}}, {"$project": {"items": {"$objectToArray": "$configuration"}}},
{"$unwind": "$items"}, {"$unwind": "$items"},
{"$group": {"_id": "$_id", "names": {"$addToSet": "$items.k"}}}, *([skip_empty_condition] if skip_empty else []),
] {"$group": {"_id": "$_id", "names": {"$addToSet": "$items.k"}}},
]
with TimingContext("mongo", "get_configuration_names"):
tasks = Task.aggregate(pipeline) tasks = Task.aggregate(pipeline)
return { return {

View File

@ -2200,6 +2200,11 @@ get_configuration_names {
type: array type: array
items { type: string } items { type: string }
} }
skip_empty {
description: If set to 'true' then the names for configurations with missing values are not returned
type: boolean
default: true
}
} }
} }
response { response {

View File

@ -789,7 +789,7 @@ def get_configuration_names(
): ):
with translate_errors_context(): with translate_errors_context():
tasks_params = HyperParams.get_configuration_names( tasks_params = HyperParams.get_configuration_names(
company_id, task_ids=request.tasks company_id, task_ids=request.tasks, skip_empty=request.skip_empty
) )
call.result.data = { call.result.data = {

View File

@ -105,7 +105,9 @@ class TestTasksHyperparams(TestService):
) )
# clone task # clone task
new_task = self.api.tasks.clone(task=task, new_task_hyperparams=new_params_dict).id new_task = self.api.tasks.clone(
task=task, new_task_hyperparams=new_params_dict
).id
try: try:
res = self.api.tasks.get_hyper_params(tasks=[new_task]).params[0] res = self.api.tasks.get_hyper_params(tasks=[new_task]).params[0]
self.assertEqual(new_params, res.hyperparams) self.assertEqual(new_params, res.hyperparams)
@ -123,7 +125,9 @@ class TestTasksHyperparams(TestService):
task=task, hyperparams=[dict(section="test")] task=task, hyperparams=[dict(section="test")]
) )
self.api.tasks.edit_hyper_params( self.api.tasks.edit_hyper_params(
task=task, hyperparams=[dict(section="test", name="x", value="123")], force=True task=task,
hyperparams=[dict(section="test", name="x", value="123")],
force=True,
) )
self.api.tasks.delete_hyper_params( self.api.tasks.delete_hyper_params(
task=task, hyperparams=[dict(section="test")], force=True task=task, hyperparams=[dict(section="test")], force=True
@ -146,7 +150,12 @@ class TestTasksHyperparams(TestService):
return [ return [
dict(section="Args", name=k, value=str(v), type="legacy") dict(section="Args", name=k, value=str(v), type="legacy")
if not k.startswith("TF_DEFINE/") if not k.startswith("TF_DEFINE/")
else dict(section="TF_DEFINE", name=k[len("TF_DEFINE/"):], value=str(v), type="legacy") else dict(
section="TF_DEFINE",
name=k[len("TF_DEFINE/") :],
value=str(v),
type="legacy",
)
for k, v in legacy.items() for k, v in legacy.items()
] ]
@ -168,6 +177,7 @@ class TestTasksHyperparams(TestService):
new_config = [ new_config = [
dict(name="param$1", type="type1", value="10"), dict(name="param$1", type="type1", value="10"),
dict(name="param/2", type="type1", value="20"), dict(name="param/2", type="type1", value="20"),
dict(name="param_empty", type="type1", value=""),
] ]
new_config_dict = self._config_dict_from_list(new_config) new_config_dict = self._config_dict_from_list(new_config)
task, _ = self.new_task( task, _ = self.new_task(
@ -188,7 +198,14 @@ class TestTasksHyperparams(TestService):
# names # names
res = self.api.tasks.get_configuration_names(tasks=[task]).configurations[0] res = self.api.tasks.get_configuration_names(tasks=[task]).configurations[0]
self.assertEqual(task, res.task) self.assertEqual(task, res.task)
self.assertEqual(["design", "param$1", "param/2"], res.names) self.assertEqual(
["design", *[c["name"] for c in new_config if c["value"]]], res.names
)
res = self.api.tasks.get_configuration_names(
tasks=[task], skip_empty=False
).configurations[0]
self.assertEqual(task, res.task)
self.assertEqual(["design", *[c["name"] for c in new_config]], res.names)
# returned as one list with names filtering # returned as one list with names filtering
res = self.api.tasks.get_configurations( res = self.api.tasks.get_configurations(
@ -216,14 +233,14 @@ class TestTasksHyperparams(TestService):
# delete # delete
new_to_delete = self._get_config_keys(new_config[1:]) new_to_delete = self._get_config_keys(new_config[1:])
self.api.tasks.delete_configuration( self.api.tasks.delete_configuration(task=task, configuration=new_to_delete)
task=task, configuration=new_to_delete
)
res = self.api.tasks.get_configurations(tasks=[task]).configurations[0] res = self.api.tasks.get_configurations(tasks=[task]).configurations[0]
self.assertEqual(old_config + new_config[:1], res.configuration) self.assertEqual(old_config + new_config[:1], res.configuration)
# clone task # clone task
new_task = self.api.tasks.clone(task=task, new_task_configuration=new_config_dict).id new_task = self.api.tasks.clone(
task=task, new_task_configuration=new_config_dict
).id
try: try:
res = self.api.tasks.get_configurations(tasks=[new_task]).configurations[0] res = self.api.tasks.get_configurations(tasks=[new_task]).configurations[0]
self.assertEqual(new_config, res.configuration) self.assertEqual(new_config, res.configuration)
@ -233,13 +250,9 @@ class TestTasksHyperparams(TestService):
# edit/delete of running task # edit/delete of running task
self.api.tasks.started(task=task) self.api.tasks.started(task=task)
with self.api.raises(InvalidTaskStatus): with self.api.raises(InvalidTaskStatus):
self.api.tasks.edit_configuration( self.api.tasks.edit_configuration(task=task, configuration=new_config)
task=task, configuration=new_config
)
with self.api.raises(InvalidTaskStatus): with self.api.raises(InvalidTaskStatus):
self.api.tasks.delete_configuration( self.api.tasks.delete_configuration(task=task, configuration=new_to_delete)
task=task, configuration=new_to_delete
)
self.api.tasks.edit_configuration( self.api.tasks.edit_configuration(
task=task, configuration=new_config, force=True task=task, configuration=new_config, force=True
) )
@ -292,7 +305,9 @@ class TestTasksHyperparams(TestService):
task_id, _ = self.new_task( task_id, _ = self.new_task(
execution={"parameters": legacy_params, "model_desc": legacy_config} execution={"parameters": legacy_params, "model_desc": legacy_config}
) )
config = self._config_dict_from_list(self._new_config_from_legacy(legacy_config)) config = self._config_dict_from_list(
self._new_config_from_legacy(legacy_config)
)
params = self._param_dict_from_list(self._new_params_from_legacy(legacy_params)) params = self._param_dict_from_list(self._new_params_from_legacy(legacy_params))
old_api = APIClient(base_url="http://localhost:8008/v2.8") old_api = APIClient(base_url="http://localhost:8008/v2.8")
@ -304,7 +319,10 @@ class TestTasksHyperparams(TestService):
modified_params = {"legacy.2": "val2"} modified_params = {"legacy.2": "val2"}
modified_config = {"design": "by"} modified_config = {"design": "by"}
old_api.tasks.edit(task=task_id, execution=dict(parameters=modified_params, model_desc=modified_config)) old_api.tasks.edit(
task=task_id,
execution=dict(parameters=modified_params, model_desc=modified_config),
)
task = old_api.tasks.get_all_ex(id=[task_id]).tasks[0] task = old_api.tasks.get_all_ex(id=[task_id]).tasks[0]
self.assertEqual(modified_params, task.execution.parameters) self.assertEqual(modified_params, task.execution.parameters)
self.assertEqual(modified_config, task.execution.model_desc) self.assertEqual(modified_config, task.execution.model_desc)