Add pattern parameter to projects.get_hyperparam_values

This commit is contained in:
allegroai 2023-11-17 09:34:13 +02:00
parent d0252a6dd9
commit 6d507616b3
5 changed files with 54 additions and 12 deletions

View File

@ -72,6 +72,7 @@ class MultiProjectPagedRequest(MultiProjectRequest):
class ProjectHyperparamValuesRequest(MultiProjectPagedRequest): class ProjectHyperparamValuesRequest(MultiProjectPagedRequest):
section = fields.StringField(required=True) section = fields.StringField(required=True)
name = fields.StringField(required=True) name = fields.StringField(required=True)
pattern = fields.StringField()
class ProjectModelMetadataValuesRequest(MultiProjectPagedRequest): class ProjectModelMetadataValuesRequest(MultiProjectPagedRequest):

View File

@ -140,6 +140,7 @@ class ProjectQueries:
name: str, name: str,
include_subprojects: bool, include_subprojects: bool,
allow_public: bool = True, allow_public: bool = True,
pattern: str = None,
page: int = 0, page: int = 0,
page_size: int = 500, page_size: int = 500,
) -> ParamValues: ) -> ParamValues:
@ -164,7 +165,20 @@ class ProjectQueries:
if not last_updated_task: if not last_updated_task:
return 0, [] return 0, []
redis_key = f"hyperparam_values_{company_id}_{'_'.join(project_ids)}_{section}_{name}_{allow_public}_{page}_{page_size}" redis_key = "_".join(
str(part)
for part in (
"hyperparam_values",
company_id,
"_".join(project_ids),
section,
name,
allow_public,
pattern,
page,
page_size,
)
)
last_update = last_updated_task.last_update or datetime.utcnow() last_update = last_updated_task.last_update or datetime.utcnow()
cached_res = self._get_cached_param_values( cached_res = self._get_cached_param_values(
key=redis_key, key=redis_key,
@ -176,14 +190,22 @@ class ProjectQueries:
if cached_res: if cached_res:
return cached_res return cached_res
pipeline = [ match_condition = {
{ **company_constraint,
"$match": { **project_constraint,
**company_constraint, key_path: {"$exists": True},
**project_constraint, }
key_path: {"$exists": True}, if pattern:
match_condition["$expr"] = {
"$regexMatch": {
"input": f"${key_path}.value",
"regex": pattern,
"options": "i",
} }
}, }
pipeline = [
{"$match": match_condition},
{"$project": {"value": f"${key_path}.value"}}, {"$project": {"value": f"${key_path}.value"}},
{"$group": {"_id": "$value"}}, {"$group": {"_id": "$value"}},
{"$sort": {"_id": 1}}, {"$sort": {"_id": 1}},

View File

@ -1000,6 +1000,12 @@ get_hyperparam_values {
} }
} }
} }
"999.0": ${get_hyperparam_values."2.26"} {
request.properties.pattern {
type: string
description: The search pattern regex
}
}
} }
get_hyper_parameters { get_hyper_parameters {
"2.9" { "2.9" {

View File

@ -452,6 +452,7 @@ def get_hyperparam_values(
name=request.name, name=request.name,
include_subprojects=request.include_subprojects, include_subprojects=request.include_subprojects,
allow_public=request.allow_public, allow_public=request.allow_public,
pattern=request.pattern,
page=request.page, page=request.page,
page_size=request.page_size, page_size=request.page_size,
) )

View File

@ -12,17 +12,17 @@ class TestTasksFiltering(TestService):
param1 = ("Se$tion1", "pa__ram1", True) param1 = ("Se$tion1", "pa__ram1", True)
param2 = ("Section2", "param2", False) param2 = ("Section2", "param2", False)
task_count = 5 task_count = 5
for p in (param1, param2): for (section, name, unique_value) in (param1, param2):
for idx in range(task_count): for idx in range(task_count):
t = self.temp_task(project=project) t = self.temp_task(project=project)
self.api.tasks.edit_hyper_params( self.api.tasks.edit_hyper_params(
task=t, task=t,
hyperparams=[ hyperparams=[
{ {
"section": p[0], "section": section,
"name": p[1], "name": name,
"type": "str", "type": "str",
"value": str(idx) if p[2] else "Constant", "value": str(idx) if unique_value else "Constant",
} }
], ],
) )
@ -42,6 +42,18 @@ class TestTasksFiltering(TestService):
self.assertEqual(res.total, 0) self.assertEqual(res.total, 0)
self.assertEqual(res["values"], []) self.assertEqual(res["values"], [])
# search pattern
res = self.api.projects.get_hyperparam_values(
projects=[project], section=param1[0], name=param1[1], pattern="^1"
)
self.assertEqual(res.total, 1)
self.assertEqual(res["values"], ["1"])
res = self.api.projects.get_hyperparam_values(
projects=[project], section=param1[0], name=param1[1], pattern="11"
)
self.assertEqual(res.total, 0)
def test_datetime_queries(self): def test_datetime_queries(self):
tasks = [self.temp_task() for _ in range(5)] tasks = [self.temp_task() for _ in range(5)]
now = datetime.utcnow() now = datetime.utcnow()