Add pattern parameter to projects.get_hyperparam_values

This commit is contained in:
allegroai 2023-11-17 09:34:13 +02:00
parent d0252a6dd9
commit 6d507616b3
5 changed files with 54 additions and 12 deletions

View File

@ -72,6 +72,7 @@ class MultiProjectPagedRequest(MultiProjectRequest):
class ProjectHyperparamValuesRequest(MultiProjectPagedRequest):
section = fields.StringField(required=True)
name = fields.StringField(required=True)
pattern = fields.StringField()
class ProjectModelMetadataValuesRequest(MultiProjectPagedRequest):

View File

@ -140,6 +140,7 @@ class ProjectQueries:
name: str,
include_subprojects: bool,
allow_public: bool = True,
pattern: str = None,
page: int = 0,
page_size: int = 500,
) -> ParamValues:
@ -164,7 +165,20 @@ class ProjectQueries:
if not last_updated_task:
return 0, []
redis_key = f"hyperparam_values_{company_id}_{'_'.join(project_ids)}_{section}_{name}_{allow_public}_{page}_{page_size}"
redis_key = "_".join(
str(part)
for part in (
"hyperparam_values",
company_id,
"_".join(project_ids),
section,
name,
allow_public,
pattern,
page,
page_size,
)
)
last_update = last_updated_task.last_update or datetime.utcnow()
cached_res = self._get_cached_param_values(
key=redis_key,
@ -176,14 +190,22 @@ class ProjectQueries:
if cached_res:
return cached_res
pipeline = [
{
"$match": {
**company_constraint,
**project_constraint,
key_path: {"$exists": True},
match_condition = {
**company_constraint,
**project_constraint,
key_path: {"$exists": True},
}
if pattern:
match_condition["$expr"] = {
"$regexMatch": {
"input": f"${key_path}.value",
"regex": pattern,
"options": "i",
}
},
}
pipeline = [
{"$match": match_condition},
{"$project": {"value": f"${key_path}.value"}},
{"$group": {"_id": "$value"}},
{"$sort": {"_id": 1}},

View File

@ -1000,6 +1000,12 @@ get_hyperparam_values {
}
}
}
"999.0": ${get_hyperparam_values."2.26"} {
request.properties.pattern {
type: string
description: The search pattern regex
}
}
}
get_hyper_parameters {
"2.9" {

View File

@ -452,6 +452,7 @@ def get_hyperparam_values(
name=request.name,
include_subprojects=request.include_subprojects,
allow_public=request.allow_public,
pattern=request.pattern,
page=request.page,
page_size=request.page_size,
)

View File

@ -12,17 +12,17 @@ class TestTasksFiltering(TestService):
param1 = ("Se$tion1", "pa__ram1", True)
param2 = ("Section2", "param2", False)
task_count = 5
for p in (param1, param2):
for (section, name, unique_value) in (param1, param2):
for idx in range(task_count):
t = self.temp_task(project=project)
self.api.tasks.edit_hyper_params(
task=t,
hyperparams=[
{
"section": p[0],
"name": p[1],
"section": section,
"name": name,
"type": "str",
"value": str(idx) if p[2] else "Constant",
"value": str(idx) if unique_value else "Constant",
}
],
)
@ -42,6 +42,18 @@ class TestTasksFiltering(TestService):
self.assertEqual(res.total, 0)
self.assertEqual(res["values"], [])
# search pattern
res = self.api.projects.get_hyperparam_values(
projects=[project], section=param1[0], name=param1[1], pattern="^1"
)
self.assertEqual(res.total, 1)
self.assertEqual(res["values"], ["1"])
res = self.api.projects.get_hyperparam_values(
projects=[project], section=param1[0], name=param1[1], pattern="11"
)
self.assertEqual(res.total, 0)
def test_datetime_queries(self):
tasks = [self.temp_task() for _ in range(5)]
now = datetime.utcnow()