Support projects.get_hyperparam_values

This commit is contained in:
allegroai 2021-05-03 17:34:40 +03:00
parent 4cd4b2914d
commit 7f4ad0d1ca
6 changed files with 270 additions and 16 deletions

View File

@ -18,7 +18,15 @@ class ProjectTagsRequest(TagsRequest):
projects = ListField(str)
class ProjectTaskParentsRequest(ProjectReq):
projects = ListField(str)
# Base request model for endpoints that accept a list of projects.
class MultiProjectReq(models.Base):
# IDs of the projects the request applies to
projects = fields.ListField(str)
# Request model that adds a task-state selector to the project list
# inherited from MultiProjectReq.
class ProjectTaskParentsRequest(MultiProjectReq):
# Enum field whose values come from EntityVisibility.
# NOTE(review): presumably selects which tasks (e.g. active vs. archived)
# are taken into account - confirm against the endpoint handler.
tasks_state = ActualEnumField(EntityVisibility)
# Request model for projects.get_hyperparam_values: identifies a single
# hyperparameter (section + name) within the project list inherited from
# MultiProjectReq.
class ProjectHyperparamValuesRequest(MultiProjectReq):
# hyperparameter section name (mandatory)
section = fields.StringField(required=True)
# hyperparameter name inside the section (mandatory)
name = fields.StringField(required=True)
# when true (the default) values are collected from public tasks as well
# as from the caller's company tasks; otherwise company tasks only
allow_public = fields.BoolField(default=True)

View File

@ -1,10 +1,12 @@
import json
from collections import OrderedDict
from datetime import datetime
from datetime import datetime, timedelta
from typing import Collection, Sequence, Tuple, Any, Optional, Dict
import dpath
import six
from mongoengine import Q
from redis import StrictRedis
from six import string_types
import apiserver.database.utils as dbutils
@ -29,6 +31,7 @@ from apiserver.database.model.task.task import (
from apiserver.database.model import EntityVisibility
from apiserver.database.utils import get_company_or_none_constraint, id as create_id
from apiserver.es_factory import es_factory
from apiserver.redis_manager import redman
from apiserver.service_repo import APICall
from apiserver.services.utils import validate_tags
from apiserver.timing_context import TimingContext
@ -44,10 +47,9 @@ project_bll = ProjectBLL()
class TaskBLL:
def __init__(self, events_es=None):
self.events_es = (
events_es if events_es is not None else es_factory.connect("events")
)
def __init__(self, events_es=None, redis=None):
    """
    Set up the clients used by this BLL.

    :param events_es: client for the "events" ES cluster; when falsy a
        connection is obtained from ``es_factory``
    :param redis: Redis client; when falsy the "apiserver" connection is
        obtained from ``redman``
    """
    if not events_es:
        events_es = es_factory.connect("events")
    if not redis:
        redis = redman.connection("apiserver")
    self.events_es = events_es
    self.redis: StrictRedis = redis
@classmethod
def get_types(cls, company, project_ids: Optional[Sequence]) -> set:
@ -240,7 +242,8 @@ class TaskBLL:
return [
tag
for tag in input_tags
if tag not in [TaskSystemTags.development, EntityVisibility.archived.value]
if tag
not in [TaskSystemTags.development, EntityVisibility.archived.value]
]
with TimingContext("mongo", "clone task"):
@ -632,6 +635,8 @@ class TaskBLL:
{"$unwind": "$names"},
{"$group": {"_id": {"section": "$section", "name": "$names.k"}}},
{"$sort": OrderedDict({"_id.section": 1, "_id.name": 1})},
{"$skip": page * page_size},
{"$limit": page_size},
{
"$group": {
"_id": 1,
@ -639,16 +644,9 @@ class TaskBLL:
"results": {"$push": "$$ROOT"},
}
},
{
"$project": {
"total": 1,
"results": {"$slice": ["$results", page * page_size, page_size]},
}
},
]
with translate_errors_context():
result = next(Task.aggregate(pipeline), None)
result = next(Task.aggregate(pipeline), None)
total = 0
remaining = 0
@ -669,6 +667,103 @@ class TaskBLL:
return total, remaining, results
HyperParamValues = Tuple[int, Sequence[str]]
def _get_cached_hyperparam_values(
    self, key: str, last_update: datetime
) -> Optional[HyperParamValues]:
    """
    Look up a previously cached (total, values) result in Redis.

    The cached entry is used only while ``last_update`` is less than
    ``services.tasks.hyperparam_values.cache_allowed_outdate_sec`` seconds
    (60 by default) ahead of the ``last_update`` stored with the entry.

    :param key: Redis key of the cached entry
    :param last_update: latest update time of the tasks the result covers
    :return: ``(total, values)`` on a usable cache hit, otherwise ``None``
        (missing entry, outdated entry or any error while reading it)
    """
    max_outdate_sec = config.get(
        "services.tasks.hyperparam_values.cache_allowed_outdate_sec", 60
    )
    max_outdate = timedelta(seconds=max_outdate_sec)

    # The cache is best effort: any failure is logged and treated as a miss
    try:
        raw_entry = self.redis.get(key)
        if not raw_entry:
            return None

        entry = json.loads(raw_entry)
        outdated_by = last_update - datetime.fromtimestamp(entry["last_update"])
        if outdated_by < max_outdate:
            return entry["total"], entry["values"]
    except Exception as ex:
        log.error(f"Error retrieving hyperparam cached values: {str(ex)}")

    return None
def get_hyperparam_distinct_values(
    self,
    company_id: str,
    project_ids: Sequence[str],
    section: str,
    name: str,
    allow_public: bool = True,
) -> HyperParamValues:
    """
    Get the distinct values of a single hyperparameter across tasks.

    The result is cached in Redis and reused as long as no matching task
    was updated beyond the allowed outdate window (see
    ``_get_cached_hyperparam_values``).

    :param company_id: company whose tasks are inspected
    :param project_ids: limit the search to these projects; an empty or
        missing list means no project limitation
    :param section: hyperparameter section name
    :param name: hyperparameter name inside the section
    :param allow_public: when true, tasks with no company (``None`` or
        ``""``) are inspected in addition to the company tasks
    :return: ``(total, values)`` where ``values`` holds the distinct
        values sorted ascending using the ``Task._numeric_locale``
        collation and ``total`` is their count. At most
        ``services.tasks.hyperparam_values.max_count`` (100 by default)
        values are returned. ``(0, [])`` is returned when no task has the
        parameter.
    """
    if allow_public:
        company_constraint = {"company": {"$in": [None, "", company_id]}}
    else:
        company_constraint = {"company": company_id}
    project_constraint = {"project": {"$in": project_ids}} if project_ids else {}

    key_path = f"hyperparams.{ParameterKeyEscaper.escape(section)}.{ParameterKeyEscaper.escape(name)}"
    # The most recently updated task that has the parameter tells whether
    # a cached result is still usable
    last_updated_task = (
        Task.objects(
            **company_constraint,
            **project_constraint,
            **{f"{key_path.replace('.', '__')}__exists": True},
        )
        .only("last_update")
        .order_by("-last_update")
        .limit(1)
        .first()
    )
    if not last_updated_task:
        return 0, []

    # The key parts are JSON-encoded as a list instead of being joined with
    # "_": section and name are free text and may contain underscores, so a
    # plain join lets different (section, name) pairs share one key, e.g.
    # ("a_b", "c") and ("a", "b_c"). Project ids are sorted since their
    # order does not affect the query; a missing list is treated as empty.
    redis_key = "hyperparam_values_" + json.dumps(
        [company_id, sorted(project_ids or []), section, name, allow_public]
    )
    last_update = last_updated_task.last_update or datetime.utcnow()
    cached_res = self._get_cached_hyperparam_values(
        key=redis_key, last_update=last_update
    )
    if cached_res is not None:
        return cached_res

    max_values = config.get("services.tasks.hyperparam_values.max_count", 100)
    pipeline = [
        {
            "$match": {
                **company_constraint,
                **project_constraint,
                key_path: {"$exists": True},
            }
        },
        {"$project": {"value": f"${key_path}.value"}},
        # one document per distinct value
        {"$group": {"_id": "$value"}},
        {"$sort": {"_id": 1}},
        {"$limit": max_values},
        # fold the (already limited) values into a single result document
        {
            "$group": {
                "_id": 1,
                "total": {"$sum": 1},
                "results": {"$push": "$$ROOT._id"},
            }
        },
    ]

    result = next(Task.aggregate(pipeline, collation=Task._numeric_locale), None)
    if not result:
        return 0, []

    total = int(result.get("total", 0))
    values = result.get("results", [])

    ttl = config.get("services.tasks.hyperparam_values.cache_ttl_sec", 86400)
    # Caching is best effort, same as the cache read: the values were
    # already computed, so a cache failure must not fail the request
    try:
        cached = dict(last_update=last_update.timestamp(), total=total, values=values)
        self.redis.setex(redis_key, ttl, json.dumps(cached))
    except Exception as ex:
        log.error(f"Error caching hyperparam values: {str(ex)}")

    return total, values
@classmethod
def dequeue_and_change_status(
cls, task: Task, company_id: str, status_message: str, status_reason: str,

View File

@ -9,3 +9,14 @@ non_responsive_tasks_watchdog {
}
multi_task_histogram_limit: 100
hyperparam_values {
# maximum number of distinct hyperparam values to retrieve
max_count: 100
# maximum time (in seconds) by which the cached result is allowed to be outdated
cache_allowed_outdate_sec: 60
# time to live (in seconds) of the cached result
cache_ttl_sec: 86400
}

View File

@ -531,6 +531,48 @@ get_unique_metric_variants {
}
}
}
get_hyperparam_values {
"2.13" {
description: """Get a list of distinct values for the chosen hyperparameter"""
request {
type: object
required: [section, name]
properties {
projects {
description: "Project IDs"
type: array
items {type: string}
}
section {
description: "Hyperparameter section name"
type: string
}
name {
description: "Hyperparameter name"
type: string
}
allow_public {
description: "If set to 'true' then collect values from both company and public tasks otherwise company tasks only. The default is 'true'"
type: boolean
}
}
}
response {
type: object
properties {
total {
description: "Total number of distinct parameter values"
type: integer
}
values {
description: "The list of the unique values for the parameter"
type: array
items {type: string}
}
}
}
}
}
get_hyper_parameters {
"2.9" {
description: """Get a list of all hyper parameter sections and names used in tasks within the given project."""

View File

@ -14,6 +14,7 @@ from apiserver.apimodels.projects import (
ProjectReq,
ProjectTagsRequest,
ProjectTaskParentsRequest,
ProjectHyperparamValuesRequest,
)
from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL
@ -397,6 +398,27 @@ def get_hyper_parameters(call: APICall, company_id: str, request: GetHyperParamR
}
@endpoint(
    "projects.get_hyperparam_values",
    min_version="2.13",
    request_data_model=ProjectHyperparamValuesRequest,
)
def get_hyperparam_values(
    call: APICall, company_id: str, request: ProjectHyperparamValuesRequest
):
    """
    Return the distinct values of the requested hyperparameter.

    The lookup is delegated to ``task_bll.get_hyperparam_distinct_values``;
    the reply carries the number of values under ``total`` and the values
    themselves under ``values``.
    """
    values_count, distinct_values = task_bll.get_hyperparam_distinct_values(
        company_id,
        project_ids=request.projects,
        section=request.section,
        name=request.name,
        allow_public=request.allow_public,
    )
    call.result.data = dict(total=values_count, values=distinct_values)
@endpoint(
"projects.get_task_tags", min_version="2.8", request_data_model=ProjectTagsRequest
)

View File

@ -0,0 +1,76 @@
from datetime import datetime, timedelta
from apiserver.tests.automated import TestService
class TestTasksFiltering(TestService):
    """API tests for hyperparameter value listing and task range filtering."""

    def setUp(self, **kwargs):
        # projects.get_hyperparam_values is available from API version 2.13
        super().setUp(version="2.13")

    def test_hyperparam_values(self):
        """Distinct values are returned per (section, name) of a project."""
        project_id = self.temp_project()
        tasks_per_param = 5
        # (section, name, whether each task gets its own value)
        varying = ("Se$tion1", "pa__ram1", True)
        constant = ("Section2", "param2", False)

        for section, name, unique_per_task in (varying, constant):
            for index in range(tasks_per_param):
                task_id = self.temp_task(project=project_id)
                self.api.tasks.edit_hyper_params(
                    task=task_id,
                    hyperparams=[
                        dict(
                            section=section,
                            name=name,
                            type="str",
                            value=str(index) if unique_per_task else "Constant",
                        )
                    ],
                )

        # (section, name, values expected back from the server)
        expectations = (
            (varying[0], varying[1], [str(i) for i in range(tasks_per_param)]),
            (constant[0], constant[1], ["Constant"]),
            ("missing", "missing", []),
        )
        for section, name, expected_values in expectations:
            response = self.api.projects.get_hyperparam_values(
                projects=[project_id], section=section, name=name
            )
            self.assertEqual(response.total, len(expected_values))
            self.assertEqual(response["values"], expected_values)

    def test_range_queries(self):
        """Tasks are filtered by a [from, to] range on their start time."""
        task_ids = [self.temp_task() for _ in range(5)]
        reference_time = datetime.utcnow()
        for task_id in task_ids:
            self.api.tasks.started(task=task_id)
        created = set(task_ids)

        # open-ended range starting at the reference time: all tasks match
        started_after = self.api.tasks.get_all_ex(
            started=[reference_time.isoformat(), None]
        ).tasks
        self.assertTrue(created.issubset({task.id for task in started_after}))

        # the minute preceding the reference time: not all the tasks match
        range_start = reference_time - timedelta(seconds=60)
        started_before = self.api.tasks.get_all_ex(
            started=[range_start.isoformat(), reference_time.isoformat()]
        ).tasks
        self.assertFalse(created.issubset({task.id for task in started_before}))

    def temp_project(self, **kwargs) -> str:
        """Create a temporary project, filling in the missing properties."""
        defaults = dict(
            name="Test tasks filtering",
            description="test",
            delete_params=dict(force=True),
        )
        self.update_missing(kwargs, **defaults)
        return self.create_temp("projects", **kwargs)

    def temp_task(self, **kwargs) -> str:
        """Create a temporary task, filling in the missing properties."""
        defaults = dict(
            type="testing",
            name="test tasks filtering",
            input=dict(view=dict()),
            delete_params=dict(force=True),
        )
        self.update_missing(kwargs, **defaults)
        return self.create_temp("tasks", **kwargs)