From 6d507616b381099b137c3226ae180cb1775696ac Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Fri, 17 Nov 2023 09:34:13 +0200 Subject: [PATCH] Add pattern parameter to projects.get_hyperparam_values --- apiserver/apimodels/projects.py | 1 + apiserver/bll/project/project_queries.py | 38 +++++++++++++++---- apiserver/schema/services/projects.conf | 6 +++ apiserver/services/projects.py | 1 + .../tests/automated/test_tasks_filtering.py | 20 ++++++++-- 5 files changed, 54 insertions(+), 12 deletions(-) diff --git a/apiserver/apimodels/projects.py b/apiserver/apimodels/projects.py index 0fcc627..a352c2a 100644 --- a/apiserver/apimodels/projects.py +++ b/apiserver/apimodels/projects.py @@ -72,6 +72,7 @@ class MultiProjectPagedRequest(MultiProjectRequest): class ProjectHyperparamValuesRequest(MultiProjectPagedRequest): section = fields.StringField(required=True) name = fields.StringField(required=True) + pattern = fields.StringField() class ProjectModelMetadataValuesRequest(MultiProjectPagedRequest): diff --git a/apiserver/bll/project/project_queries.py b/apiserver/bll/project/project_queries.py index 6e9619a..3e95103 100644 --- a/apiserver/bll/project/project_queries.py +++ b/apiserver/bll/project/project_queries.py @@ -140,6 +140,7 @@ class ProjectQueries: name: str, include_subprojects: bool, allow_public: bool = True, + pattern: str = None, page: int = 0, page_size: int = 500, ) -> ParamValues: @@ -164,7 +165,20 @@ class ProjectQueries: if not last_updated_task: return 0, [] - redis_key = f"hyperparam_values_{company_id}_{'_'.join(project_ids)}_{section}_{name}_{allow_public}_{page}_{page_size}" + redis_key = "_".join( + str(part) + for part in ( + "hyperparam_values", + company_id, + "_".join(project_ids), + section, + name, + allow_public, + pattern, + page, + page_size, + ) + ) last_update = last_updated_task.last_update or datetime.utcnow() cached_res = self._get_cached_param_values( key=redis_key, @@ -176,14 +190,22 @@ class ProjectQueries: if cached_res: return cached_res - pipeline = [ - { - "$match": { - **company_constraint, - **project_constraint, - key_path: {"$exists": True}, + match_condition = { + **company_constraint, + **project_constraint, + key_path: {"$exists": True}, + } + if pattern: + match_condition["$expr"] = { + "$regexMatch": { + "input": f"${key_path}.value", + "regex": pattern, + "options": "i", } - }, + } + + pipeline = [ + {"$match": match_condition}, {"$project": {"value": f"${key_path}.value"}}, {"$group": {"_id": "$value"}}, {"$sort": {"_id": 1}}, diff --git a/apiserver/schema/services/projects.conf b/apiserver/schema/services/projects.conf index f85d3fd..4821eab 100644 --- a/apiserver/schema/services/projects.conf +++ b/apiserver/schema/services/projects.conf @@ -1000,6 +1000,12 @@ get_hyperparam_values { } } } + "999.0": ${get_hyperparam_values."2.26"} { + request.properties.pattern { + type: string + description: The search pattern regex + } + } } get_hyper_parameters { "2.9" { diff --git a/apiserver/services/projects.py b/apiserver/services/projects.py index e2ddc2b..858ab39 100644 --- a/apiserver/services/projects.py +++ b/apiserver/services/projects.py @@ -452,6 +452,7 @@ def get_hyperparam_values( name=request.name, include_subprojects=request.include_subprojects, allow_public=request.allow_public, + pattern=request.pattern, page=request.page, page_size=request.page_size, ) diff --git a/apiserver/tests/automated/test_tasks_filtering.py b/apiserver/tests/automated/test_tasks_filtering.py index c0905e8..5b52564 100644 --- a/apiserver/tests/automated/test_tasks_filtering.py +++ b/apiserver/tests/automated/test_tasks_filtering.py @@ -12,17 +12,17 @@ class TestTasksFiltering(TestService): param1 = ("Se$tion1", "pa__ram1", True) param2 = ("Section2", "param2", False) task_count = 5 - for p in (param1, param2): + for (section, name, unique_value) in (param1, param2): for idx in range(task_count): t = self.temp_task(project=project) self.api.tasks.edit_hyper_params( task=t, hyperparams=[ { - "section": p[0], - "name": p[1], + "section": section, + "name": name, "type": "str", - "value": str(idx) if p[2] else "Constant", + "value": str(idx) if unique_value else "Constant", } ], ) @@ -42,6 +42,18 @@ class TestTasksFiltering(TestService): self.assertEqual(res.total, 0) self.assertEqual(res["values"], []) + # search pattern + res = self.api.projects.get_hyperparam_values( + projects=[project], section=param1[0], name=param1[1], pattern="^1" + ) + self.assertEqual(res.total, 1) + self.assertEqual(res["values"], ["1"]) + + res = self.api.projects.get_hyperparam_values( + projects=[project], section=param1[0], name=param1[1], pattern="11" + ) + self.assertEqual(res.total, 0) + def test_datetime_queries(self): tasks = [self.temp_task() for _ in range(5)] now = datetime.utcnow()