Mirror of https://github.com/clearml/clearml-server
Support projects.get_hyperparam_values

This commit is contained in:
parent 4cd4b2914d
commit 7f4ad0d1ca
@@ -18,7 +18,15 @@ class ProjectTagsRequest(TagsRequest):
     projects = ListField(str)
 
 
-class ProjectTaskParentsRequest(ProjectReq):
-    projects = ListField(str)
+class MultiProjectReq(models.Base):
+    projects = fields.ListField(str)
+
+
+class ProjectTaskParentsRequest(MultiProjectReq):
     tasks_state = ActualEnumField(EntityVisibility)
+
+
+class ProjectHyperparamValuesRequest(MultiProjectReq):
+    section = fields.StringField(required=True)
+    name = fields.StringField(required=True)
+    allow_public = fields.BoolField(default=True)
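For illustration only, a payload matching the fields declared on the new ProjectHyperparamValuesRequest model (the project ID and parameter names are placeholders):

    # Illustrative request payload for projects.get_hyperparam_values,
    # mirroring the fields declared on ProjectHyperparamValuesRequest above.
    payload = {
        "projects": ["<project-id>"],  # optional, inherited from MultiProjectReq
        "section": "Args",             # required
        "name": "batch_size",          # required
        "allow_public": True,          # optional, defaults to True
    }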
@@ -1,10 +1,12 @@
+import json
 from collections import OrderedDict
-from datetime import datetime
+from datetime import datetime, timedelta
 from typing import Collection, Sequence, Tuple, Any, Optional, Dict
 
 import dpath
 import six
 from mongoengine import Q
+from redis import StrictRedis
 from six import string_types
 
 import apiserver.database.utils as dbutils
@@ -29,6 +31,7 @@ from apiserver.database.model.task.task import (
 from apiserver.database.model import EntityVisibility
 from apiserver.database.utils import get_company_or_none_constraint, id as create_id
 from apiserver.es_factory import es_factory
+from apiserver.redis_manager import redman
 from apiserver.service_repo import APICall
 from apiserver.services.utils import validate_tags
 from apiserver.timing_context import TimingContext
@@ -44,10 +47,9 @@ project_bll = ProjectBLL()
 
 
 class TaskBLL:
-    def __init__(self, events_es=None):
-        self.events_es = (
-            events_es if events_es is not None else es_factory.connect("events")
-        )
+    def __init__(self, events_es=None, redis=None):
+        self.events_es = events_es or es_factory.connect("events")
+        self.redis: StrictRedis = redis or redman.connection("apiserver")
 
     @classmethod
     def get_types(cls, company, project_ids: Optional[Sequence]) -> set:
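The widened constructor keeps both backends injectable. A minimal sketch with test doubles (the import path is an assumption, not part of this diff):

    # Sketch: events_es and redis are optional, so tests can inject doubles
    # instead of live Elasticsearch/Redis connections.
    from unittest import mock

    from apiserver.bll.task.task_bll import TaskBLL  # assumed import path

    bll = TaskBLL(events_es=mock.MagicMock(), redis=mock.MagicMock())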
@@ -240,7 +242,8 @@ class TaskBLL:
         return [
             tag
             for tag in input_tags
-            if tag not in [TaskSystemTags.development, EntityVisibility.archived.value]
+            if tag
+            not in [TaskSystemTags.development, EntityVisibility.archived.value]
         ]
 
         with TimingContext("mongo", "clone task"):
@@ -632,6 +635,8 @@ class TaskBLL:
             {"$unwind": "$names"},
             {"$group": {"_id": {"section": "$section", "name": "$names.k"}}},
             {"$sort": OrderedDict({"_id.section": 1, "_id.name": 1})},
+            {"$skip": page * page_size},
+            {"$limit": page_size},
             {
                 "$group": {
                     "_id": 1,
@@ -639,16 +644,9 @@ class TaskBLL:
                     "results": {"$push": "$$ROOT"},
                 }
             },
-            {
-                "$project": {
-                    "total": 1,
-                    "results": {"$slice": ["$results", page * page_size, page_size]},
-                }
-            },
         ]
 
-        with translate_errors_context():
-            result = next(Task.aggregate(pipeline), None)
+        result = next(Task.aggregate(pipeline), None)
 
         total = 0
         remaining = 0
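Note: the two hunks above move paging out of the post-grouping $slice projection into explicit $skip/$limit stages earlier in the pipeline; the page arithmetic itself is unchanged. A small illustrative sketch (page and page_size come from the calling code, which is not part of this diff):

    # Example of the paging arithmetic now applied by $skip/$limit
    # (values are illustrative only):
    page, page_size = 2, 500
    skip = page * page_size   # 1000 grouped section/name docs skipped
    limit = page_size         # at most 500 returned for this page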
@@ -669,6 +667,103 @@ class TaskBLL:
 
         return total, remaining, results
 
+    HyperParamValues = Tuple[int, Sequence[str]]
+
+    def _get_cached_hyperparam_values(
+        self, key: str, last_update: datetime
+    ) -> Optional[HyperParamValues]:
+        allowed_delta = timedelta(
+            seconds=config.get(
+                "services.tasks.hyperparam_values.cache_allowed_outdate_sec", 60
+            )
+        )
+        try:
+            cached = self.redis.get(key)
+            if not cached:
+                return
+
+            data = json.loads(cached)
+            cached_last_update = datetime.fromtimestamp(data["last_update"])
+            if (last_update - cached_last_update) < allowed_delta:
+                return data["total"], data["values"]
+        except Exception as ex:
+            log.error(f"Error retrieving hyperparam cached values: {str(ex)}")
+
+    def get_hyperparam_distinct_values(
+        self,
+        company_id: str,
+        project_ids: Sequence[str],
+        section: str,
+        name: str,
+        allow_public: bool = True,
+    ) -> HyperParamValues:
+        if allow_public:
+            company_constraint = {"company": {"$in": [None, "", company_id]}}
+        else:
+            company_constraint = {"company": company_id}
+        if project_ids:
+            project_constraint = {"project": {"$in": project_ids}}
+        else:
+            project_constraint = {}
+
+        key_path = f"hyperparams.{ParameterKeyEscaper.escape(section)}.{ParameterKeyEscaper.escape(name)}"
+        last_updated_task = (
+            Task.objects(
+                **company_constraint,
+                **project_constraint,
+                **{f"{key_path.replace('.', '__')}__exists": True},
+            )
+            .only("last_update")
+            .order_by("-last_update")
+            .limit(1)
+            .first()
+        )
+        if not last_updated_task:
+            return 0, []
+
+        redis_key = f"hyperparam_values_{company_id}_{'_'.join(project_ids)}_{section}_{name}_{allow_public}"
+        last_update = last_updated_task.last_update or datetime.utcnow()
+        cached_res = self._get_cached_hyperparam_values(
+            key=redis_key, last_update=last_update
+        )
+        if cached_res:
+            return cached_res
+
+        max_values = config.get("services.tasks.hyperparam_values.max_count", 100)
+        pipeline = [
+            {
+                "$match": {
+                    **company_constraint,
+                    **project_constraint,
+                    key_path: {"$exists": True},
+                }
+            },
+            {"$project": {"value": f"${key_path}.value"}},
+            {"$group": {"_id": "$value"}},
+            {"$sort": {"_id": 1}},
+            {"$limit": max_values},
+            {
+                "$group": {
+                    "_id": 1,
+                    "total": {"$sum": 1},
+                    "results": {"$push": "$$ROOT._id"},
+                }
+            },
+        ]
+
+        result = next(Task.aggregate(pipeline, collation=Task._numeric_locale), None)
+        if not result:
+            return 0, []
+
+        total = int(result.get("total", 0))
+        values = result.get("results", [])
+
+        ttl = config.get("services.tasks.hyperparam_values.cache_ttl_sec", 86400)
+        cached = dict(last_update=last_update.timestamp(), total=total, values=values)
+        self.redis.setex(redis_key, ttl, json.dumps(cached))
+
+        return total, values
+
     @classmethod
     def dequeue_and_change_status(
         cls, task: Task, company_id: str, status_message: str, status_reason: str,
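For reference, a sketch of the cache entry that the new get_hyperparam_distinct_values writes to Redis. The key and values below are placeholders; the real key embeds the company ID, project IDs, section, name and allow_public flag exactly as built above:

    # Example only: shape of the JSON value stored via redis.setex above.
    import json
    from datetime import datetime

    redis_key = "hyperparam_values_<company>_<project>_Args_batch_size_True"
    cached = {
        "last_update": datetime.utcnow().timestamp(),  # compared against the newest task's last_update
        "total": 3,
        "values": ["16", "32", "64"],
    }
    serialized = json.dumps(cached)  # stored with a cache_ttl_sec expiry (86400 by default)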
@@ -9,3 +9,14 @@ non_responsive_tasks_watchdog {
 }
 
 multi_task_histogram_limit: 100
+
+hyperparam_values {
+    # maximal amount of distinct hyperparam values to retrieve
+    max_count: 100
+
+    # max allowed outdate time for the cached result
+    cache_allowed_outdate_sec: 60
+
+    # cache ttl sec
+    cache_ttl_sec: 86400
+}
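These keys are consumed by the TaskBLL changes above through the same config accessor, with matching fallback defaults. A minimal sketch, assuming the config object already imported at the top of that module (its import is not part of this diff):

    # As used in the diff above (fallbacks match the config values):
    max_values = config.get("services.tasks.hyperparam_values.max_count", 100)
    allowed_outdate = config.get(
        "services.tasks.hyperparam_values.cache_allowed_outdate_sec", 60
    )
    cache_ttl = config.get("services.tasks.hyperparam_values.cache_ttl_sec", 86400)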
@@ -531,6 +531,48 @@ get_unique_metric_variants {
         }
     }
 }
+get_hyperparam_values {
+    "2.13" {
+        description: """Get a list of distinct values for the chosen hyperparameter"""
+        request {
+            type: object
+            required: [section, name]
+            properties {
+                projects {
+                    description: "Project IDs"
+                    type: array
+                    items {type: string}
+                }
+                section {
+                    description: "Hyperparameter section name"
+                    type: string
+                }
+                name {
+                    description: "Hyperparameter name"
+                    type: string
+                }
+                allow_public {
+                    description: "If set to 'true' then collect values from both company and public tasks, otherwise company tasks only. The default is 'true'"
+                    type: boolean
+                }
+            }
+        }
+        response {
+            type: object
+            properties {
+                total {
+                    description: "Total number of distinct parameter values"
+                    type: integer
+                }
+                values {
+                    description: "The list of the unique values for the parameter"
+                    type: array
+                    items {type: string}
+                }
+            }
+        }
+    }
+}
 get_hyper_parameters {
     "2.9" {
         description: """Get a list of all hyper parameter sections and names used in tasks within the given project."""
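An illustrative request/response pair for the new projects.get_hyperparam_values call, following the schema above (IDs and values are placeholders, not real output):

    # Example request body and a possible response (placeholders only).
    request_body = {
        "projects": ["<project-id>"],
        "section": "Args",
        "name": "learning_rate",
        "allow_public": True,
    }
    possible_response = {
        "total": 3,
        "values": ["0.001", "0.01", "0.1"],
    }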
@@ -14,6 +14,7 @@ from apiserver.apimodels.projects import (
     ProjectReq,
     ProjectTagsRequest,
     ProjectTaskParentsRequest,
+    ProjectHyperparamValuesRequest,
 )
 from apiserver.bll.organization import OrgBLL, Tags
 from apiserver.bll.project import ProjectBLL
@@ -397,6 +398,27 @@ def get_hyper_parameters(call: APICall, company_id: str, request: GetHyperParamR
     }
 
 
+@endpoint(
+    "projects.get_hyperparam_values",
+    min_version="2.13",
+    request_data_model=ProjectHyperparamValuesRequest,
+)
+def get_hyperparam_values(
+    call: APICall, company_id: str, request: ProjectHyperparamValuesRequest
+):
+    total, values = task_bll.get_hyperparam_distinct_values(
+        company_id,
+        project_ids=request.projects,
+        section=request.section,
+        name=request.name,
+        allow_public=request.allow_public,
+    )
+    call.result.data = {
+        "total": total,
+        "values": values,
+    }
+
+
 @endpoint(
     "projects.get_task_tags", min_version="2.8", request_data_model=ProjectTagsRequest
 )
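A hypothetical client-side call; the attribute-style access mirrors self.api in the test suite below, and the client object and IDs are placeholders:

    # Hypothetical usage through an attribute-style API client
    # (client setup not shown here).
    res = api.projects.get_hyperparam_values(
        projects=["<project-id>"], section="Args", name="batch_size"
    )
    print(res["total"], res["values"])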
apiserver/tests/automated/test_tasks_filtering.py (new file, 76 lines)
@@ -0,0 +1,76 @@
+from datetime import datetime, timedelta
+
+from apiserver.tests.automated import TestService
+
+
+class TestTasksFiltering(TestService):
+    def setUp(self, **kwargs):
+        super().setUp(version="2.13")
+
+    def test_hyperparam_values(self):
+        project = self.temp_project()
+        param1 = ("Se$tion1", "pa__ram1", True)
+        param2 = ("Section2", "param2", False)
+        task_count = 5
+        for p in (param1, param2):
+            for idx in range(task_count):
+                t = self.temp_task(project=project)
+                self.api.tasks.edit_hyper_params(
+                    task=t,
+                    hyperparams=[
+                        {
+                            "section": p[0],
+                            "name": p[1],
+                            "type": "str",
+                            "value": str(idx) if p[2] else "Constant",
+                        }
+                    ],
+                )
+        res = self.api.projects.get_hyperparam_values(
+            projects=[project], section=param1[0], name=param1[1]
+        )
+        self.assertEqual(res.total, task_count)
+        self.assertEqual(res["values"], [str(i) for i in range(task_count)])
+        res = self.api.projects.get_hyperparam_values(
+            projects=[project], section=param2[0], name=param2[1]
+        )
+        self.assertEqual(res.total, 1)
+        self.assertEqual(res["values"], ["Constant"])
+        res = self.api.projects.get_hyperparam_values(
+            projects=[project], section="missing", name="missing"
+        )
+        self.assertEqual(res.total, 0)
+        self.assertEqual(res["values"], [])
+
+    def test_range_queries(self):
+        tasks = [self.temp_task() for _ in range(5)]
+        now = datetime.utcnow()
+        for task in tasks:
+            self.api.tasks.started(task=task)
+
+        res = self.api.tasks.get_all_ex(started=[now.isoformat(), None]).tasks
+        self.assertTrue(set(tasks).issubset({t.id for t in res}))
+
+        res = self.api.tasks.get_all_ex(
+            started=[(now - timedelta(seconds=60)).isoformat(), now.isoformat()]
+        ).tasks
+        self.assertFalse(set(tasks).issubset({t.id for t in res}))
+
+    def temp_project(self, **kwargs) -> str:
+        self.update_missing(
+            kwargs,
+            name="Test tasks filtering",
+            description="test",
+            delete_params=dict(force=True),
+        )
+        return self.create_temp("projects", **kwargs)
+
+    def temp_task(self, **kwargs) -> str:
+        self.update_missing(
+            kwargs,
+            type="testing",
+            name="test tasks filtering",
+            input=dict(view=dict()),
+            delete_params=dict(force=True),
+        )
+        return self.create_temp("tasks", **kwargs)
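The new suite follows the existing automated-test conventions (TestService base class, temporary projects/tasks with forced cleanup). Assuming a configured test environment with a running apiserver, one plausible way to run only this file:

    # Assumption: the automated tests use the standard unittest runner against
    # a live apiserver instance; adjust to your environment.
    #   python -m unittest apiserver.tests.automated.test_tasks_filtering -v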