mirror of
https://github.com/clearml/clearml-server
synced 2025-04-28 17:51:24 +00:00
Support projects.get_hyperparam_values
This commit is contained in:
parent
4cd4b2914d
commit
7f4ad0d1ca
@ -18,7 +18,15 @@ class ProjectTagsRequest(TagsRequest):
|
||||
projects = ListField(str)
|
||||
|
||||
|
||||
class ProjectTaskParentsRequest(ProjectReq):
|
||||
projects = ListField(str)
|
||||
class MultiProjectReq(models.Base):
    # IDs of the projects to operate on; an empty list means no project
    # filtering (see the empty project_constraint in TaskBLL)
    projects = fields.ListField(str)
|
||||
|
||||
|
||||
class ProjectTaskParentsRequest(MultiProjectReq):
    # task visibility state (e.g. EntityVisibility.archived) used to filter
    # the tasks considered — NOTE(review): exact filtering semantics live in
    # the handler; confirm there
    tasks_state = ActualEnumField(EntityVisibility)
|
||||
|
||||
|
||||
class ProjectHyperparamValuesRequest(MultiProjectReq):
    # hyperparam section to look in; required
    section = fields.StringField(required=True)
    # hyperparam name within the section; required
    name = fields.StringField(required=True)
    # when True, values are collected from public tasks (company None or "")
    # in addition to company tasks
    allow_public = fields.BoolField(default=True)
|
||||
|
@ -1,10 +1,12 @@
|
||||
import json
|
||||
from collections import OrderedDict
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Collection, Sequence, Tuple, Any, Optional, Dict
|
||||
|
||||
import dpath
|
||||
import six
|
||||
from mongoengine import Q
|
||||
from redis import StrictRedis
|
||||
from six import string_types
|
||||
|
||||
import apiserver.database.utils as dbutils
|
||||
@ -29,6 +31,7 @@ from apiserver.database.model.task.task import (
|
||||
from apiserver.database.model import EntityVisibility
|
||||
from apiserver.database.utils import get_company_or_none_constraint, id as create_id
|
||||
from apiserver.es_factory import es_factory
|
||||
from apiserver.redis_manager import redman
|
||||
from apiserver.service_repo import APICall
|
||||
from apiserver.services.utils import validate_tags
|
||||
from apiserver.timing_context import TimingContext
|
||||
@ -44,10 +47,9 @@ project_bll = ProjectBLL()
|
||||
|
||||
|
||||
class TaskBLL:
|
||||
def __init__(self, events_es=None):
|
||||
self.events_es = (
|
||||
events_es if events_es is not None else es_factory.connect("events")
|
||||
)
|
||||
def __init__(self, events_es=None, redis=None):
    """
    Build a TaskBLL instance.

    Falsy/None arguments fall back to the default project factories:
    an ES connection for "events" and the "apiserver" redis connection.
    """
    self.events_es = events_es if events_es else es_factory.connect("events")
    self.redis: StrictRedis = redis if redis else redman.connection("apiserver")
|
||||
|
||||
@classmethod
|
||||
def get_types(cls, company, project_ids: Optional[Sequence]) -> set:
|
||||
@ -240,7 +242,8 @@ class TaskBLL:
|
||||
return [
|
||||
tag
|
||||
for tag in input_tags
|
||||
if tag not in [TaskSystemTags.development, EntityVisibility.archived.value]
|
||||
if tag
|
||||
not in [TaskSystemTags.development, EntityVisibility.archived.value]
|
||||
]
|
||||
|
||||
with TimingContext("mongo", "clone task"):
|
||||
@ -632,6 +635,8 @@ class TaskBLL:
|
||||
{"$unwind": "$names"},
|
||||
{"$group": {"_id": {"section": "$section", "name": "$names.k"}}},
|
||||
{"$sort": OrderedDict({"_id.section": 1, "_id.name": 1})},
|
||||
{"$skip": page * page_size},
|
||||
{"$limit": page_size},
|
||||
{
|
||||
"$group": {
|
||||
"_id": 1,
|
||||
@ -639,16 +644,9 @@ class TaskBLL:
|
||||
"results": {"$push": "$$ROOT"},
|
||||
}
|
||||
},
|
||||
{
|
||||
"$project": {
|
||||
"total": 1,
|
||||
"results": {"$slice": ["$results", page * page_size, page_size]},
|
||||
}
|
||||
},
|
||||
]
|
||||
|
||||
with translate_errors_context():
|
||||
result = next(Task.aggregate(pipeline), None)
|
||||
result = next(Task.aggregate(pipeline), None)
|
||||
|
||||
total = 0
|
||||
remaining = 0
|
||||
@ -669,6 +667,103 @@ class TaskBLL:
|
||||
|
||||
return total, remaining, results
|
||||
|
||||
HyperParamValues = Tuple[int, Sequence[str]]
|
||||
|
||||
def _get_cached_hyperparam_values(
    self, key: str, last_update: datetime
) -> Optional[HyperParamValues]:
    """
    Return the cached (total, values) tuple stored in redis under `key`,
    or None when there is no cache entry, the entry is older than
    `last_update` by more than the configured allowed-outdate window, or
    any error occurs while reading/decoding it (cache reads are
    best-effort by design — errors are logged, never raised).
    """
    max_age = timedelta(
        seconds=config.get(
            "services.tasks.hyperparam_values.cache_allowed_outdate_sec", 60
        )
    )
    try:
        raw = self.redis.get(key)
        if raw:
            entry = json.loads(raw)
            cached_time = datetime.fromtimestamp(entry["last_update"])
            # accept the cached entry only while it is fresh enough
            # relative to the newest relevant task update
            if (last_update - cached_time) < max_age:
                return entry["total"], entry["values"]
    except Exception as ex:
        log.error(f"Error retrieving hyperparam cached values: {str(ex)}")
    return None
|
||||
|
||||
def get_hyperparam_distinct_values(
    self,
    company_id: str,
    project_ids: Sequence[str],
    section: str,
    name: str,
    allow_public: bool = True,
) -> HyperParamValues:
    """
    Return (total, values) for the given hyperparam section/name over the
    tasks of the given projects (all projects when project_ids is empty
    or None).

    `values` holds at most `services.tasks.hyperparam_values.max_count`
    distinct values, sorted with the numeric-aware collation. Note that
    `total` is computed AFTER the $limit stage, so it is also capped at
    max_count — it is "number of values returned", not a global count.

    Results are cached in redis under a key derived from all arguments;
    the cached entry is reused only while it is not older than the latest
    relevant task update (see _get_cached_hyperparam_values).
    """
    # company scoping: optionally include public tasks (company None or "")
    if allow_public:
        company_constraint = {"company": {"$in": [None, "", company_id]}}
    else:
        company_constraint = {"company": company_id}
    if project_ids:
        project_constraint = {"project": {"$in": project_ids}}
    else:
        project_constraint = {}

    # section/name are escaped the same way they are stored in mongo
    key_path = f"hyperparams.{ParameterKeyEscaper.escape(section)}.{ParameterKeyEscaper.escape(name)}"
    last_updated_task = (
        Task.objects(
            **company_constraint,
            **project_constraint,
            **{f"{key_path.replace('.', '__')}__exists": True},
        )
        .only("last_update")
        .order_by("-last_update")
        .limit(1)
        .first()
    )
    if not last_updated_task:
        # no task has this hyperparam in the requested scope
        return 0, []

    # robustness fix: project_ids may be None (the constraint branch above
    # already treats None like "all projects"); '_'.join(None) would raise
    redis_key = f"hyperparam_values_{company_id}_{'_'.join(project_ids or [])}_{section}_{name}_{allow_public}"
    last_update = last_updated_task.last_update or datetime.utcnow()
    cached_res = self._get_cached_hyperparam_values(
        key=redis_key, last_update=last_update
    )
    if cached_res:
        return cached_res

    max_values = config.get("services.tasks.hyperparam_values.max_count", 100)
    pipeline = [
        {
            "$match": {
                **company_constraint,
                **project_constraint,
                key_path: {"$exists": True},
            }
        },
        {"$project": {"value": f"${key_path}.value"}},
        # one document per distinct value, sorted, capped at max_values
        {"$group": {"_id": "$value"}},
        {"$sort": {"_id": 1}},
        {"$limit": max_values},
        # fold everything into a single {total, results} document
        {
            "$group": {
                "_id": 1,
                "total": {"$sum": 1},
                "results": {"$push": "$$ROOT._id"},
            }
        },
    ]

    result = next(Task.aggregate(pipeline, collation=Task._numeric_locale), None)
    if not result:
        return 0, []

    total = int(result.get("total", 0))
    values = result.get("results", [])

    ttl = config.get("services.tasks.hyperparam_values.cache_ttl_sec", 86400)
    cached = dict(last_update=last_update.timestamp(), total=total, values=values)
    self.redis.setex(redis_key, ttl, json.dumps(cached))

    return total, values
|
||||
|
||||
@classmethod
|
||||
def dequeue_and_change_status(
|
||||
cls, task: Task, company_id: str, status_message: str, status_reason: str,
|
||||
|
@ -9,3 +9,14 @@ non_responsive_tasks_watchdog {
|
||||
}
|
||||
|
||||
multi_task_histogram_limit: 100
|
||||
|
||||
hyperparam_values {
|
||||
# maximum number of distinct hyperparam values to retrieve
|
||||
max_count: 100
|
||||
|
||||
# max allowed staleness of the cached result
|
||||
cache_allowed_outdate_sec: 60
|
||||
|
||||
# cache ttl sec
|
||||
cache_ttl_sec: 86400
|
||||
}
|
@ -531,6 +531,48 @@ get_unique_metric_variants {
|
||||
}
|
||||
}
|
||||
}
|
||||
get_hyperparam_values {
|
||||
"2.13" {
|
||||
description: """Get a list of distinct values for the chosen hyperparameter"""
|
||||
request {
|
||||
type: object
|
||||
required: [section, name]
|
||||
properties {
|
||||
projects {
|
||||
description: "Project IDs"
|
||||
type: array
|
||||
items {type: string}
|
||||
}
|
||||
section {
|
||||
description: "Hyperparameter section name"
|
||||
type: string
|
||||
}
|
||||
name {
|
||||
description: "Hyperparameter name"
|
||||
type: string
|
||||
}
|
||||
allow_public {
|
||||
description: "If set to 'true' then collect values from both company and public tasks otherwise company tasks only. The default is 'true'"
|
||||
type: boolean
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
total {
|
||||
description: "Total number of distinct parameter values"
|
||||
type: integer
|
||||
}
|
||||
values {
|
||||
description: "The list of the unique values for the parameter"
|
||||
type: array
|
||||
items {type: string}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
get_hyper_parameters {
|
||||
"2.9" {
|
||||
description: """Get a list of all hyper parameter sections and names used in tasks within the given project."""
|
||||
|
@ -14,6 +14,7 @@ from apiserver.apimodels.projects import (
|
||||
ProjectReq,
|
||||
ProjectTagsRequest,
|
||||
ProjectTaskParentsRequest,
|
||||
ProjectHyperparamValuesRequest,
|
||||
)
|
||||
from apiserver.bll.organization import OrgBLL, Tags
|
||||
from apiserver.bll.project import ProjectBLL
|
||||
@ -397,6 +398,27 @@ def get_hyper_parameters(call: APICall, company_id: str, request: GetHyperParamR
|
||||
}
|
||||
|
||||
|
||||
@endpoint(
    "projects.get_hyperparam_values",
    min_version="2.13",
    request_data_model=ProjectHyperparamValuesRequest,
)
def get_hyperparam_values(
    call: APICall, company_id: str, request: ProjectHyperparamValuesRequest
):
    """Return the distinct values recorded for the requested hyperparam
    in the tasks of the requested projects."""
    total, values = task_bll.get_hyperparam_distinct_values(
        company_id,
        project_ids=request.projects,
        section=request.section,
        name=request.name,
        allow_public=request.allow_public,
    )
    call.result.data = dict(total=total, values=values)
|
||||
|
||||
|
||||
@endpoint(
|
||||
"projects.get_task_tags", min_version="2.8", request_data_model=ProjectTagsRequest
|
||||
)
|
||||
|
76
apiserver/tests/automated/test_tasks_filtering.py
Normal file
76
apiserver/tests/automated/test_tasks_filtering.py
Normal file
@ -0,0 +1,76 @@
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from apiserver.tests.automated import TestService
|
||||
|
||||
|
||||
class TestTasksFiltering(TestService):
    """Tests for task filtering and project hyperparam value retrieval."""

    def setUp(self, **kwargs):
        super().setUp(version="2.13")

    def test_hyperparam_values(self):
        project = self.temp_project()
        # (section, name, has distinct per-task values)
        varying = ("Se$tion1", "pa__ram1", True)
        constant = ("Section2", "param2", False)
        task_count = 5
        for section, name, distinct in (varying, constant):
            for idx in range(task_count):
                task = self.temp_task(project=project)
                self.api.tasks.edit_hyper_params(
                    task=task,
                    hyperparams=[
                        dict(
                            section=section,
                            name=name,
                            type="str",
                            value=str(idx) if distinct else "Constant",
                        )
                    ],
                )

        # one distinct value per task, returned sorted
        res = self.api.projects.get_hyperparam_values(
            projects=[project], section=varying[0], name=varying[1]
        )
        self.assertEqual(res.total, task_count)
        self.assertEqual(res["values"], [str(i) for i in range(task_count)])

        # all tasks share a single value
        res = self.api.projects.get_hyperparam_values(
            projects=[project], section=constant[0], name=constant[1]
        )
        self.assertEqual(res.total, 1)
        self.assertEqual(res["values"], ["Constant"])

        # unknown hyperparam yields an empty result
        res = self.api.projects.get_hyperparam_values(
            projects=[project], section="missing", name="missing"
        )
        self.assertEqual(res.total, 0)
        self.assertEqual(res["values"], [])

    def test_range_queries(self):
        tasks = [self.temp_task() for _ in range(5)]
        now = datetime.utcnow()
        for task in tasks:
            self.api.tasks.started(task=task)

        # open-ended range [now, ...) should contain all started tasks
        res = self.api.tasks.get_all_ex(started=[now.isoformat(), None]).tasks
        self.assertTrue(set(tasks).issubset({t.id for t in res}))

        # closed range ending at `now` should miss them
        res = self.api.tasks.get_all_ex(
            started=[(now - timedelta(seconds=60)).isoformat(), now.isoformat()]
        ).tasks
        self.assertFalse(set(tasks).issubset({t.id for t in res}))

    def temp_project(self, **kwargs) -> str:
        # create a temp project; auto-deleted (forced) on test teardown
        self.update_missing(
            kwargs,
            name="Test tasks filtering",
            description="test",
            delete_params=dict(force=True),
        )
        return self.create_temp("projects", **kwargs)

    def temp_task(self, **kwargs) -> str:
        # create a temp task; auto-deleted (forced) on test teardown
        self.update_missing(
            kwargs,
            type="testing",
            name="test tasks filtering",
            input=dict(view=dict()),
            delete_params=dict(force=True),
        )
        return self.create_temp("tasks", **kwargs)
|
Loading…
Reference in New Issue
Block a user