Support projects.get_hyperparam_values

This commit is contained in:
allegroai
2021-05-03 17:34:40 +03:00
parent 4cd4b2914d
commit 7f4ad0d1ca
6 changed files with 270 additions and 16 deletions

View File

@@ -1,10 +1,12 @@
import json
from collections import OrderedDict
from datetime import datetime
from datetime import datetime, timedelta
from typing import Collection, Sequence, Tuple, Any, Optional, Dict
import dpath
import six
from mongoengine import Q
from redis import StrictRedis
from six import string_types
import apiserver.database.utils as dbutils
@@ -29,6 +31,7 @@ from apiserver.database.model.task.task import (
from apiserver.database.model import EntityVisibility
from apiserver.database.utils import get_company_or_none_constraint, id as create_id
from apiserver.es_factory import es_factory
from apiserver.redis_manager import redman
from apiserver.service_repo import APICall
from apiserver.services.utils import validate_tags
from apiserver.timing_context import TimingContext
@@ -44,10 +47,9 @@ project_bll = ProjectBLL()
class TaskBLL:
def __init__(self, events_es=None):
self.events_es = (
events_es if events_es is not None else es_factory.connect("events")
)
def __init__(self, events_es=None, redis=None):
self.events_es = events_es or es_factory.connect("events")
self.redis: StrictRedis = redis or redman.connection("apiserver")
@classmethod
def get_types(cls, company, project_ids: Optional[Sequence]) -> set:
@@ -240,7 +242,8 @@ class TaskBLL:
return [
tag
for tag in input_tags
if tag not in [TaskSystemTags.development, EntityVisibility.archived.value]
if tag
not in [TaskSystemTags.development, EntityVisibility.archived.value]
]
with TimingContext("mongo", "clone task"):
@@ -632,6 +635,8 @@ class TaskBLL:
{"$unwind": "$names"},
{"$group": {"_id": {"section": "$section", "name": "$names.k"}}},
{"$sort": OrderedDict({"_id.section": 1, "_id.name": 1})},
{"$skip": page * page_size},
{"$limit": page_size},
{
"$group": {
"_id": 1,
@@ -639,16 +644,9 @@ class TaskBLL:
"results": {"$push": "$$ROOT"},
}
},
{
"$project": {
"total": 1,
"results": {"$slice": ["$results", page * page_size, page_size]},
}
},
]
with translate_errors_context():
result = next(Task.aggregate(pipeline), None)
result = next(Task.aggregate(pipeline), None)
total = 0
remaining = 0
@@ -669,6 +667,103 @@ class TaskBLL:
return total, remaining, results
HyperParamValues = Tuple[int, Sequence[str]]
def _get_cached_hyperparam_values(
self, key: str, last_update: datetime
) -> Optional[HyperParamValues]:
allowed_delta = timedelta(
seconds=config.get(
"services.tasks.hyperparam_values.cache_allowed_outdate_sec", 60
)
)
try:
cached = self.redis.get(key)
if not cached:
return
data = json.loads(cached)
cached_last_update = datetime.fromtimestamp(data["last_update"])
if (last_update - cached_last_update) < allowed_delta:
return data["total"], data["values"]
except Exception as ex:
log.error(f"Error retrieving hyperparam cached values: {str(ex)}")
def get_hyperparam_distinct_values(
self,
company_id: str,
project_ids: Sequence[str],
section: str,
name: str,
allow_public: bool = True,
) -> HyperParamValues:
if allow_public:
company_constraint = {"company": {"$in": [None, "", company_id]}}
else:
company_constraint = {"company": company_id}
if project_ids:
project_constraint = {"project": {"$in": project_ids}}
else:
project_constraint = {}
key_path = f"hyperparams.{ParameterKeyEscaper.escape(section)}.{ParameterKeyEscaper.escape(name)}"
last_updated_task = (
Task.objects(
**company_constraint,
**project_constraint,
**{f"{key_path.replace('.', '__')}__exists": True},
)
.only("last_update")
.order_by("-last_update")
.limit(1)
.first()
)
if not last_updated_task:
return 0, []
redis_key = f"hyperparam_values_{company_id}_{'_'.join(project_ids)}_{section}_{name}_{allow_public}"
last_update = last_updated_task.last_update or datetime.utcnow()
cached_res = self._get_cached_hyperparam_values(
key=redis_key, last_update=last_update
)
if cached_res:
return cached_res
max_values = config.get("services.tasks.hyperparam_values.max_count", 100)
pipeline = [
{
"$match": {
**company_constraint,
**project_constraint,
key_path: {"$exists": True},
}
},
{"$project": {"value": f"${key_path}.value"}},
{"$group": {"_id": "$value"}},
{"$sort": {"_id": 1}},
{"$limit": max_values},
{
"$group": {
"_id": 1,
"total": {"$sum": 1},
"results": {"$push": "$$ROOT._id"},
}
},
]
result = next(Task.aggregate(pipeline, collation=Task._numeric_locale), None)
if not result:
return 0, []
total = int(result.get("total", 0))
values = result.get("results", [])
ttl = config.get("services.tasks.hyperparam_values.cache_ttl_sec", 86400)
cached = dict(last_update=last_update.timestamp(), total=total, values=values)
self.redis.setex(redis_key, ttl, json.dumps(cached))
return total, values
@classmethod
def dequeue_and_change_status(
cls, task: Task, company_id: str, status_message: str, status_reason: str,