diff --git a/apiserver/apimodels/tasks.py b/apiserver/apimodels/tasks.py index 39f7430..4a7ede4 100644 --- a/apiserver/apimodels/tasks.py +++ b/apiserver/apimodels/tasks.py @@ -201,7 +201,7 @@ class GetConfigurationsRequest(MultiTaskRequest): class GetConfigurationNamesRequest(MultiTaskRequest): - pass + skip_empty = BoolField(default=True) class Configuration(models.Base): diff --git a/apiserver/bll/task/hyperparams.py b/apiserver/bll/task/hyperparams.py index eb8e8bb..d61cc7b 100644 --- a/apiserver/bll/task/hyperparams.py +++ b/apiserver/bll/task/hyperparams.py @@ -175,21 +175,23 @@ class HyperParams: @classmethod def get_configuration_names( - cls, company_id: str, task_ids: Sequence[str] + cls, company_id: str, task_ids: Sequence[str], skip_empty: bool ) -> Dict[str, list]: - with TimingContext("mongo", "get_configuration_names"): - pipeline = [ - { - "$match": { - "company": {"$in": [None, "", company_id]}, - "_id": {"$in": task_ids}, - } - }, - {"$project": {"items": {"$objectToArray": "$configuration"}}}, - {"$unwind": "$items"}, - {"$group": {"_id": "$_id", "names": {"$addToSet": "$items.k"}}}, - ] + skip_empty_condition = {"$match": {"items.v.value": {"$nin": [None, ""]}}} + pipeline = [ + { + "$match": { + "company": {"$in": [None, "", company_id]}, + "_id": {"$in": task_ids}, + } + }, + {"$project": {"items": {"$objectToArray": "$configuration"}}}, + {"$unwind": "$items"}, + *([skip_empty_condition] if skip_empty else []), + {"$group": {"_id": "$_id", "names": {"$addToSet": "$items.k"}}}, + ] + with TimingContext("mongo", "get_configuration_names"): tasks = Task.aggregate(pipeline) return { diff --git a/apiserver/schema/services/tasks.conf b/apiserver/schema/services/tasks.conf index 6bd5000..2ac1f56 100644 --- a/apiserver/schema/services/tasks.conf +++ b/apiserver/schema/services/tasks.conf @@ -2200,6 +2200,11 @@ get_configuration_names { type: array items { type: string } } + skip_empty { + description: If set to 'true' then the names for configurations with missing values are not returned + type: boolean + default: true + } } } response { diff --git a/apiserver/services/tasks.py b/apiserver/services/tasks.py index f6be307..34a320e 100644 --- a/apiserver/services/tasks.py +++ b/apiserver/services/tasks.py @@ -789,7 +789,7 @@ def get_configuration_names( ): with translate_errors_context(): tasks_params = HyperParams.get_configuration_names( - company_id, task_ids=request.tasks + company_id, task_ids=request.tasks, skip_empty=request.skip_empty ) call.result.data = { diff --git a/apiserver/tests/automated/test_task_hyperparams.py b/apiserver/tests/automated/test_task_hyperparams.py index 67f0376..2f2784a 100644 --- a/apiserver/tests/automated/test_task_hyperparams.py +++ b/apiserver/tests/automated/test_task_hyperparams.py @@ -105,7 +105,9 @@ class TestTasksHyperparams(TestService): ) # clone task - new_task = self.api.tasks.clone(task=task, new_task_hyperparams=new_params_dict).id + new_task = self.api.tasks.clone( + task=task, new_task_hyperparams=new_params_dict + ).id try: res = self.api.tasks.get_hyper_params(tasks=[new_task]).params[0] self.assertEqual(new_params, res.hyperparams) @@ -123,7 +125,9 @@ class TestTasksHyperparams(TestService): task=task, hyperparams=[dict(section="test")] ) self.api.tasks.edit_hyper_params( - task=task, hyperparams=[dict(section="test", name="x", value="123")], force=True + task=task, + hyperparams=[dict(section="test", name="x", value="123")], + force=True, ) self.api.tasks.delete_hyper_params( task=task, hyperparams=[dict(section="test")], force=True @@ -146,7 +150,12 @@ class TestTasksHyperparams(TestService): return [ dict(section="Args", name=k, value=str(v), type="legacy") if not k.startswith("TF_DEFINE/") - else dict(section="TF_DEFINE", name=k[len("TF_DEFINE/"):], value=str(v), type="legacy") + else dict( + section="TF_DEFINE", + name=k[len("TF_DEFINE/") :], + value=str(v), + type="legacy", + ) for k, v in legacy.items() ] @@ -168,6 +177,7 @@ class TestTasksHyperparams(TestService): new_config = [ dict(name="param$1", type="type1", value="10"), dict(name="param/2", type="type1", value="20"), + dict(name="param_empty", type="type1", value=""), ] new_config_dict = self._config_dict_from_list(new_config) task, _ = self.new_task( @@ -188,7 +198,14 @@ class TestTasksHyperparams(TestService): # names res = self.api.tasks.get_configuration_names(tasks=[task]).configurations[0] self.assertEqual(task, res.task) - self.assertEqual(["design", "param$1", "param/2"], res.names) + self.assertEqual( + ["design", *[c["name"] for c in new_config if c["value"]]], res.names + ) + res = self.api.tasks.get_configuration_names( + tasks=[task], skip_empty=False + ).configurations[0] + self.assertEqual(task, res.task) + self.assertEqual(["design", *[c["name"] for c in new_config]], res.names) # returned as one list with names filtering res = self.api.tasks.get_configurations( @@ -216,14 +233,14 @@ class TestTasksHyperparams(TestService): # delete new_to_delete = self._get_config_keys(new_config[1:]) - self.api.tasks.delete_configuration( - task=task, configuration=new_to_delete - ) + self.api.tasks.delete_configuration(task=task, configuration=new_to_delete) res = self.api.tasks.get_configurations(tasks=[task]).configurations[0] self.assertEqual(old_config + new_config[:1], res.configuration) # clone task - new_task = self.api.tasks.clone(task=task, new_task_configuration=new_config_dict).id + new_task = self.api.tasks.clone( + task=task, new_task_configuration=new_config_dict + ).id try: res = self.api.tasks.get_configurations(tasks=[new_task]).configurations[0] self.assertEqual(new_config, res.configuration) @@ -233,13 +250,9 @@ class TestTasksHyperparams(TestService): # edit/delete of running task self.api.tasks.started(task=task) with self.api.raises(InvalidTaskStatus): - self.api.tasks.edit_configuration( - task=task, configuration=new_config - ) + self.api.tasks.edit_configuration(task=task, configuration=new_config) with self.api.raises(InvalidTaskStatus): - self.api.tasks.delete_configuration( - task=task, configuration=new_to_delete - ) + self.api.tasks.delete_configuration(task=task, configuration=new_to_delete) self.api.tasks.edit_configuration( task=task, configuration=new_config, force=True ) @@ -292,7 +305,9 @@ class TestTasksHyperparams(TestService): task_id, _ = self.new_task( execution={"parameters": legacy_params, "model_desc": legacy_config} ) - config = self._config_dict_from_list(self._new_config_from_legacy(legacy_config)) + config = self._config_dict_from_list( + self._new_config_from_legacy(legacy_config) + ) params = self._param_dict_from_list(self._new_params_from_legacy(legacy_params)) old_api = APIClient(base_url="http://localhost:8008/v2.8") @@ -304,7 +319,10 @@ class TestTasksHyperparams(TestService): modified_params = {"legacy.2": "val2"} modified_config = {"design": "by"} - old_api.tasks.edit(task=task_id, execution=dict(parameters=modified_params, model_desc=modified_config)) + old_api.tasks.edit( + task=task_id, + execution=dict(parameters=modified_params, model_desc=modified_config), + ) task = old_api.tasks.get_all_ex(id=[task_id]).tasks[0] self.assertEqual(modified_params, task.execution.parameters) self.assertEqual(modified_config, task.execution.model_desc)