mirror of
https://github.com/clearml/clearml-server
synced 2025-04-24 16:14:42 +00:00
Add support for credentials label
Support no_scroll in events.get_task_plots Support better project stats Fix Redis required on mongodb initialization Update tests
This commit is contained in:
parent
92fd98d5ad
commit
447adb9090
apiserver
apimodels
bll
database
elastic/mappings
schema/services
services
tests/automated
@ -75,6 +75,7 @@ class CreateUserResponse(Base):
|
||||
class Credentials(Base):
|
||||
access_key = StringField(required=True)
|
||||
secret_key = StringField(required=True)
|
||||
label = StringField()
|
||||
|
||||
|
||||
class CredentialsResponse(Credentials):
|
||||
@ -82,6 +83,10 @@ class CredentialsResponse(Credentials):
|
||||
last_used = DateTimeField(default=None)
|
||||
|
||||
|
||||
class CreateCredentialsRequest(Base):
|
||||
label = StringField()
|
||||
|
||||
|
||||
class CreateCredentialsResponse(Base):
|
||||
credentials = EmbeddedField(Credentials)
|
||||
|
||||
|
@ -135,4 +135,5 @@ class TaskPlotsRequest(Base):
|
||||
task: str = StringField(required=True)
|
||||
iters: int = IntField(default=1)
|
||||
scroll_id: str = StringField()
|
||||
no_scroll: bool = BoolField(default=False)
|
||||
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
|
||||
|
@ -27,7 +27,7 @@ class ProjectOrNoneRequest(models.Base):
|
||||
include_subprojects = fields.BoolField(default=True)
|
||||
|
||||
|
||||
class GetHyperParamRequest(ProjectOrNoneRequest):
|
||||
class GetParamsRequest(ProjectOrNoneRequest):
|
||||
page = fields.IntField(default=0)
|
||||
page_size = fields.IntField(default=500)
|
||||
|
||||
|
@ -2,7 +2,11 @@ from datetime import datetime
|
||||
|
||||
from apiserver import database
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.apimodels.auth import GetTokenResponse, CreateUserRequest, Credentials as CredModel
|
||||
from apiserver.apimodels.auth import (
|
||||
GetTokenResponse,
|
||||
CreateUserRequest,
|
||||
Credentials as CredModel,
|
||||
)
|
||||
from apiserver.apimodels.users import CreateRequest as Users_CreateRequest
|
||||
from apiserver.bll.user import UserBLL
|
||||
from apiserver.config_repo import config
|
||||
@ -145,7 +149,7 @@ class AuthBLL:
|
||||
|
||||
@classmethod
|
||||
def create_credentials(
|
||||
cls, user_id: str, company_id: str, role: str = None
|
||||
cls, user_id: str, company_id: str, role: str = None, label: str = None,
|
||||
) -> CredModel:
|
||||
|
||||
with translate_errors_context():
|
||||
@ -154,7 +158,9 @@ class AuthBLL:
|
||||
if not user:
|
||||
raise errors.bad_request.InvalidUserId(**query)
|
||||
|
||||
cred = CredModel(access_key=get_client_id(), secret_key=get_secret_key())
|
||||
cred = CredModel(
|
||||
access_key=get_client_id(), secret_key=get_secret_key(), label=label
|
||||
)
|
||||
user.credentials.append(
|
||||
Credentials(key=cred.access_key, secret=cred.secret_key)
|
||||
)
|
||||
|
@ -534,6 +534,7 @@ class EventBLL(object):
|
||||
sort=None,
|
||||
size: int = 500,
|
||||
scroll_id: str = None,
|
||||
no_scroll: bool = False,
|
||||
metric_variants: MetricVariants = None,
|
||||
):
|
||||
if scroll_id == self.empty_scroll:
|
||||
@ -611,7 +612,7 @@ class EventBLL(object):
|
||||
event_type=event_type,
|
||||
body=es_req,
|
||||
ignore=404,
|
||||
scroll="1h",
|
||||
**({} if no_scroll else {"scroll": "1h"}),
|
||||
)
|
||||
|
||||
events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
|
||||
@ -680,6 +681,7 @@ class EventBLL(object):
|
||||
sort=None,
|
||||
size=500,
|
||||
scroll_id=None,
|
||||
no_scroll=False,
|
||||
) -> TaskEventsResult:
|
||||
if scroll_id == self.empty_scroll:
|
||||
return TaskEventsResult()
|
||||
@ -740,7 +742,7 @@ class EventBLL(object):
|
||||
event_type=event_type,
|
||||
body=es_req,
|
||||
ignore=404,
|
||||
scroll="1h",
|
||||
**({} if no_scroll else {"scroll": "1h"}),
|
||||
)
|
||||
|
||||
events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
|
||||
|
@ -1,2 +1,3 @@
|
||||
from .project_bll import ProjectBLL
|
||||
from .project_queries import ProjectQueries
|
||||
from .sub_projects import _ids_with_children as project_ids_with_children
|
||||
|
@ -1,6 +1,6 @@
|
||||
import itertools
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta
|
||||
from functools import reduce
|
||||
from itertools import groupby
|
||||
from operator import itemgetter
|
||||
@ -306,6 +306,7 @@ class ProjectBLL:
|
||||
return project
|
||||
|
||||
archived_tasks_cond = {"$in": [EntityVisibility.archived.value, "$system_tags"]}
|
||||
visibility_states = [EntityVisibility.archived, EntityVisibility.active]
|
||||
|
||||
@classmethod
|
||||
def make_projects_get_all_pipelines(
|
||||
@ -367,6 +368,26 @@ class ProjectBLL:
|
||||
},
|
||||
]
|
||||
|
||||
def completed_after_subquery(additional_cond, time_thresh: datetime):
|
||||
return {
|
||||
# the sum of
|
||||
"$sum": {
|
||||
# for each task
|
||||
"$cond": {
|
||||
# if completed after the time_thresh
|
||||
"if": {
|
||||
"$and": [
|
||||
"$completed",
|
||||
{"$gt": ["$completed", time_thresh]},
|
||||
additional_cond,
|
||||
]
|
||||
},
|
||||
"then": 1,
|
||||
"else": 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def runtime_subquery(additional_cond):
|
||||
return {
|
||||
# the sum of
|
||||
@ -397,16 +418,19 @@ class ProjectBLL:
|
||||
}
|
||||
|
||||
group_step = {"_id": "$project"}
|
||||
|
||||
for state in EntityVisibility:
|
||||
time_thresh = datetime.utcnow() - timedelta(hours=24)
|
||||
for state in cls.visibility_states:
|
||||
if specific_state and state != specific_state:
|
||||
continue
|
||||
if state == EntityVisibility.active:
|
||||
group_step[state.value] = runtime_subquery(
|
||||
{"$not": cls.archived_tasks_cond}
|
||||
)
|
||||
elif state == EntityVisibility.archived:
|
||||
group_step[state.value] = runtime_subquery(cls.archived_tasks_cond)
|
||||
cond = (
|
||||
cls.archived_tasks_cond
|
||||
if state == EntityVisibility.archived
|
||||
else {"$not": cls.archived_tasks_cond}
|
||||
)
|
||||
group_step[state.value] = runtime_subquery(cond)
|
||||
group_step[f"{state.value}_recently_completed"] = completed_after_subquery(
|
||||
cond, time_thresh=time_thresh
|
||||
)
|
||||
|
||||
runtime_pipeline = [
|
||||
# only count run time for these types of tasks
|
||||
@ -534,15 +558,24 @@ class ProjectBLL:
|
||||
)
|
||||
|
||||
def get_status_counts(project_id, section):
|
||||
project_runtime = runtime.get(project_id, {})
|
||||
project_section_statuses = nested_get(
|
||||
status_count, (project_id, section), default=default_counts
|
||||
)
|
||||
return {
|
||||
"total_runtime": nested_get(runtime, (project_id, section), default=0),
|
||||
"status_count": nested_get(
|
||||
status_count, (project_id, section), default=default_counts
|
||||
"status_count": project_section_statuses,
|
||||
"running_tasks": project_section_statuses.get(TaskStatus.in_progress),
|
||||
"total_tasks": sum(project_section_statuses.values()),
|
||||
"total_runtime": project_runtime.get(section, 0),
|
||||
"completed_tasks": project_runtime.get(
|
||||
f"{section}_recently_completed", 0
|
||||
),
|
||||
}
|
||||
|
||||
report_for_states = [
|
||||
s for s in EntityVisibility if not specific_state or specific_state == s
|
||||
s
|
||||
for s in cls.visibility_states
|
||||
if not specific_state or specific_state == s
|
||||
]
|
||||
|
||||
stats = {
|
||||
|
289
apiserver/bll/project/project_queries.py
Normal file
289
apiserver/bll/project/project_queries.py
Normal file
@ -0,0 +1,289 @@
|
||||
import json
|
||||
from collections import OrderedDict
|
||||
from datetime import datetime, timedelta
|
||||
from typing import (
|
||||
Sequence,
|
||||
Optional,
|
||||
Tuple,
|
||||
)
|
||||
|
||||
from redis import StrictRedis
|
||||
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.redis_manager import redman
|
||||
from apiserver.utilities.dicts import nested_get
|
||||
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
|
||||
from .sub_projects import _ids_with_children
|
||||
from ...database.model.model import Model
|
||||
from ...database.model.task.task import Task
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
class ProjectQueries:
|
||||
def __init__(self, redis=None):
|
||||
self.redis: StrictRedis = redis or redman.connection("apiserver")
|
||||
|
||||
@staticmethod
|
||||
def _get_project_constraint(
|
||||
project_ids: Sequence[str], include_subprojects: bool
|
||||
) -> dict:
|
||||
if include_subprojects:
|
||||
if project_ids is None:
|
||||
return {}
|
||||
project_ids = _ids_with_children(project_ids)
|
||||
|
||||
return {"project": {"$in": project_ids if project_ids is not None else [None]}}
|
||||
|
||||
@staticmethod
|
||||
def _get_company_constraint(company_id: str, allow_public: bool = True) -> dict:
|
||||
if allow_public:
|
||||
return {"company": {"$in": [None, "", company_id]}}
|
||||
|
||||
return {"company": company_id}
|
||||
|
||||
@classmethod
|
||||
def get_aggregated_project_parameters(
|
||||
cls,
|
||||
company_id,
|
||||
project_ids: Sequence[str],
|
||||
include_subprojects: bool,
|
||||
page: int = 0,
|
||||
page_size: int = 500,
|
||||
) -> Tuple[int, int, Sequence[dict]]:
|
||||
page = max(0, page)
|
||||
page_size = max(1, page_size)
|
||||
pipeline = [
|
||||
{
|
||||
"$match": {
|
||||
**cls._get_company_constraint(company_id),
|
||||
**cls._get_project_constraint(project_ids, include_subprojects),
|
||||
"hyperparams": {"$exists": True, "$gt": {}},
|
||||
}
|
||||
},
|
||||
{"$project": {"sections": {"$objectToArray": "$hyperparams"}}},
|
||||
{"$unwind": "$sections"},
|
||||
{
|
||||
"$project": {
|
||||
"section": "$sections.k",
|
||||
"names": {"$objectToArray": "$sections.v"},
|
||||
}
|
||||
},
|
||||
{"$unwind": "$names"},
|
||||
{"$group": {"_id": {"section": "$section", "name": "$names.k"}}},
|
||||
{"$sort": OrderedDict({"_id.section": 1, "_id.name": 1})},
|
||||
{"$skip": page * page_size},
|
||||
{"$limit": page_size},
|
||||
{
|
||||
"$group": {
|
||||
"_id": 1,
|
||||
"total": {"$sum": 1},
|
||||
"results": {"$push": "$$ROOT"},
|
||||
}
|
||||
},
|
||||
]
|
||||
|
||||
result = next(Task.aggregate(pipeline), None)
|
||||
|
||||
total = 0
|
||||
remaining = 0
|
||||
results = []
|
||||
|
||||
if result:
|
||||
total = int(result.get("total", -1))
|
||||
results = [
|
||||
{
|
||||
"section": ParameterKeyEscaper.unescape(
|
||||
nested_get(r, ("_id", "section"))
|
||||
),
|
||||
"name": ParameterKeyEscaper.unescape(
|
||||
nested_get(r, ("_id", "name"))
|
||||
),
|
||||
}
|
||||
for r in result.get("results", [])
|
||||
]
|
||||
remaining = max(0, total - (len(results) + page * page_size))
|
||||
|
||||
return total, remaining, results
|
||||
|
||||
HyperParamValues = Tuple[int, Sequence[str]]
|
||||
|
||||
def _get_cached_hyperparam_values(
|
||||
self, key: str, last_update: datetime
|
||||
) -> Optional[HyperParamValues]:
|
||||
allowed_delta = timedelta(
|
||||
seconds=config.get(
|
||||
"services.tasks.hyperparam_values.cache_allowed_outdate_sec", 60
|
||||
)
|
||||
)
|
||||
try:
|
||||
cached = self.redis.get(key)
|
||||
if not cached:
|
||||
return
|
||||
|
||||
data = json.loads(cached)
|
||||
cached_last_update = datetime.fromtimestamp(data["last_update"])
|
||||
if (last_update - cached_last_update) < allowed_delta:
|
||||
return data["total"], data["values"]
|
||||
except Exception as ex:
|
||||
log.error(f"Error retrieving hyperparam cached values: {str(ex)}")
|
||||
|
||||
def get_hyperparam_distinct_values(
|
||||
self,
|
||||
company_id: str,
|
||||
project_ids: Sequence[str],
|
||||
section: str,
|
||||
name: str,
|
||||
include_subprojects: bool,
|
||||
allow_public: bool = True,
|
||||
) -> HyperParamValues:
|
||||
company_constraint = self._get_company_constraint(company_id, allow_public)
|
||||
project_constraint = self._get_project_constraint(
|
||||
project_ids, include_subprojects
|
||||
)
|
||||
key_path = f"hyperparams.{ParameterKeyEscaper.escape(section)}.{ParameterKeyEscaper.escape(name)}"
|
||||
last_updated_task = (
|
||||
Task.objects(
|
||||
**company_constraint,
|
||||
**project_constraint,
|
||||
**{f"{key_path.replace('.', '__')}__exists": True},
|
||||
)
|
||||
.only("last_update")
|
||||
.order_by("-last_update")
|
||||
.limit(1)
|
||||
.first()
|
||||
)
|
||||
if not last_updated_task:
|
||||
return 0, []
|
||||
|
||||
redis_key = f"hyperparam_values_{company_id}_{'_'.join(project_ids)}_{section}_{name}_{allow_public}"
|
||||
last_update = last_updated_task.last_update or datetime.utcnow()
|
||||
cached_res = self._get_cached_hyperparam_values(
|
||||
key=redis_key, last_update=last_update
|
||||
)
|
||||
if cached_res:
|
||||
return cached_res
|
||||
|
||||
max_values = config.get("services.tasks.hyperparam_values.max_count", 100)
|
||||
pipeline = [
|
||||
{
|
||||
"$match": {
|
||||
**company_constraint,
|
||||
**project_constraint,
|
||||
key_path: {"$exists": True},
|
||||
}
|
||||
},
|
||||
{"$project": {"value": f"${key_path}.value"}},
|
||||
{"$group": {"_id": "$value"}},
|
||||
{"$sort": {"_id": 1}},
|
||||
{"$limit": max_values},
|
||||
{
|
||||
"$group": {
|
||||
"_id": 1,
|
||||
"total": {"$sum": 1},
|
||||
"results": {"$push": "$$ROOT._id"},
|
||||
}
|
||||
},
|
||||
]
|
||||
|
||||
result = next(Task.aggregate(pipeline, collation=Task._numeric_locale), None)
|
||||
if not result:
|
||||
return 0, []
|
||||
|
||||
total = int(result.get("total", 0))
|
||||
values = result.get("results", [])
|
||||
|
||||
ttl = config.get("services.tasks.hyperparam_values.cache_ttl_sec", 86400)
|
||||
cached = dict(last_update=last_update.timestamp(), total=total, values=values)
|
||||
self.redis.setex(redis_key, ttl, json.dumps(cached))
|
||||
|
||||
return total, values
|
||||
|
||||
@classmethod
|
||||
def get_unique_metric_variants(
|
||||
cls, company_id, project_ids: Sequence[str], include_subprojects: bool
|
||||
):
|
||||
pipeline = [
|
||||
{
|
||||
"$match": {
|
||||
**cls._get_company_constraint(company_id),
|
||||
**cls._get_project_constraint(project_ids, include_subprojects),
|
||||
}
|
||||
},
|
||||
{"$project": {"metrics": {"$objectToArray": "$last_metrics"}}},
|
||||
{"$unwind": "$metrics"},
|
||||
{
|
||||
"$project": {
|
||||
"metric": "$metrics.k",
|
||||
"variants": {"$objectToArray": "$metrics.v"},
|
||||
}
|
||||
},
|
||||
{"$unwind": "$variants"},
|
||||
{
|
||||
"$group": {
|
||||
"_id": {
|
||||
"metric": "$variants.v.metric",
|
||||
"variant": "$variants.v.variant",
|
||||
},
|
||||
"metrics": {
|
||||
"$addToSet": {
|
||||
"metric": "$variants.v.metric",
|
||||
"metric_hash": "$metric",
|
||||
"variant": "$variants.v.variant",
|
||||
"variant_hash": "$variants.k",
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
{"$sort": OrderedDict({"_id.metric": 1, "_id.variant": 1})},
|
||||
]
|
||||
|
||||
result = Task.aggregate(pipeline)
|
||||
return [r["metrics"][0] for r in result]
|
||||
|
||||
@classmethod
|
||||
def get_model_metadata_keys(
|
||||
cls,
|
||||
company_id,
|
||||
project_ids: Sequence[str],
|
||||
include_subprojects: bool,
|
||||
page: int = 0,
|
||||
page_size: int = 500,
|
||||
) -> Tuple[int, int, Sequence[dict]]:
|
||||
page = max(0, page)
|
||||
page_size = max(1, page_size)
|
||||
pipeline = [
|
||||
{
|
||||
"$match": {
|
||||
**cls._get_company_constraint(company_id),
|
||||
**cls._get_project_constraint(project_ids, include_subprojects),
|
||||
"metadata": {"$exists": True, "$ne": []},
|
||||
}
|
||||
},
|
||||
{"$project": {"metadata": 1}},
|
||||
{"$unwind": "$metadata"},
|
||||
{"$group": {"_id": "$metadata.key"}},
|
||||
{"$sort": {"_id": 1}},
|
||||
{"$skip": page * page_size},
|
||||
{"$limit": page_size},
|
||||
{
|
||||
"$group": {
|
||||
"_id": 1,
|
||||
"total": {"$sum": 1},
|
||||
"results": {"$push": "$$ROOT"},
|
||||
}
|
||||
},
|
||||
]
|
||||
|
||||
result = next(Model.aggregate(pipeline), None)
|
||||
|
||||
total = 0
|
||||
remaining = 0
|
||||
results = []
|
||||
|
||||
if result:
|
||||
total = int(result.get("total", -1))
|
||||
results = [r.get("_id") for r in result.get("results", [])]
|
||||
remaining = max(0, total - (len(results) + page * page_size))
|
||||
|
||||
return total, remaining, results
|
@ -1,9 +1,6 @@
|
||||
import json
|
||||
from collections import OrderedDict
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import datetime
|
||||
from typing import Collection, Sequence, Tuple, Any, Optional, Dict
|
||||
|
||||
import dpath
|
||||
import six
|
||||
from mongoengine import Q
|
||||
from redis import StrictRedis
|
||||
@ -14,7 +11,7 @@ from apiserver.apierrors import errors
|
||||
from apiserver.apimodels.tasks import TaskInputModel
|
||||
from apiserver.bll.queue import QueueBLL
|
||||
from apiserver.bll.organization import OrgBLL, Tags
|
||||
from apiserver.bll.project import ProjectBLL, project_ids_with_children
|
||||
from apiserver.bll.project import ProjectBLL
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model.model import Model
|
||||
@ -39,7 +36,6 @@ from apiserver.redis_manager import redman
|
||||
from apiserver.service_repo import APICall
|
||||
from apiserver.services.utils import validate_tags, escape_dict_field, escape_dict
|
||||
from apiserver.timing_context import TimingContext
|
||||
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
|
||||
from .artifacts import artifacts_prepare_for_save
|
||||
from .param_utils import params_prepare_for_save
|
||||
from .utils import (
|
||||
@ -350,54 +346,6 @@ class TaskBLL:
|
||||
if validate_models:
|
||||
cls.validate_input_models(task)
|
||||
|
||||
@staticmethod
|
||||
def get_unique_metric_variants(
|
||||
company_id, project_ids: Sequence[str], include_subprojects: bool
|
||||
):
|
||||
if project_ids:
|
||||
if include_subprojects:
|
||||
project_ids = project_ids_with_children(project_ids)
|
||||
project_constraint = {"project": {"$in": project_ids}}
|
||||
else:
|
||||
project_constraint = {}
|
||||
pipeline = [
|
||||
{
|
||||
"$match": dict(
|
||||
company={"$in": [None, "", company_id]}, **project_constraint,
|
||||
)
|
||||
},
|
||||
{"$project": {"metrics": {"$objectToArray": "$last_metrics"}}},
|
||||
{"$unwind": "$metrics"},
|
||||
{
|
||||
"$project": {
|
||||
"metric": "$metrics.k",
|
||||
"variants": {"$objectToArray": "$metrics.v"},
|
||||
}
|
||||
},
|
||||
{"$unwind": "$variants"},
|
||||
{
|
||||
"$group": {
|
||||
"_id": {
|
||||
"metric": "$variants.v.metric",
|
||||
"variant": "$variants.v.variant",
|
||||
},
|
||||
"metrics": {
|
||||
"$addToSet": {
|
||||
"metric": "$variants.v.metric",
|
||||
"metric_hash": "$metric",
|
||||
"variant": "$variants.v.variant",
|
||||
"variant_hash": "$variants.k",
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
{"$sort": OrderedDict({"_id.metric": 1, "_id.variant": 1})},
|
||||
]
|
||||
|
||||
with translate_errors_context():
|
||||
result = Task.aggregate(pipeline)
|
||||
return [r["metrics"][0] for r in result]
|
||||
|
||||
@staticmethod
|
||||
def set_last_update(
|
||||
task_ids: Collection[str],
|
||||
@ -494,173 +442,6 @@ class TaskBLL:
|
||||
**extra_updates,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_aggregated_project_parameters(
|
||||
company_id,
|
||||
project_ids: Sequence[str],
|
||||
include_subprojects: bool,
|
||||
page: int = 0,
|
||||
page_size: int = 500,
|
||||
) -> Tuple[int, int, Sequence[dict]]:
|
||||
if project_ids:
|
||||
if include_subprojects:
|
||||
project_ids = project_ids_with_children(project_ids)
|
||||
project_constraint = {"project": {"$in": project_ids}}
|
||||
else:
|
||||
project_constraint = {}
|
||||
page = max(0, page)
|
||||
page_size = max(1, page_size)
|
||||
pipeline = [
|
||||
{
|
||||
"$match": {
|
||||
"company": {"$in": [None, "", company_id]},
|
||||
"hyperparams": {"$exists": True, "$gt": {}},
|
||||
**project_constraint,
|
||||
}
|
||||
},
|
||||
{"$project": {"sections": {"$objectToArray": "$hyperparams"}}},
|
||||
{"$unwind": "$sections"},
|
||||
{
|
||||
"$project": {
|
||||
"section": "$sections.k",
|
||||
"names": {"$objectToArray": "$sections.v"},
|
||||
}
|
||||
},
|
||||
{"$unwind": "$names"},
|
||||
{"$group": {"_id": {"section": "$section", "name": "$names.k"}}},
|
||||
{"$sort": OrderedDict({"_id.section": 1, "_id.name": 1})},
|
||||
{"$skip": page * page_size},
|
||||
{"$limit": page_size},
|
||||
{
|
||||
"$group": {
|
||||
"_id": 1,
|
||||
"total": {"$sum": 1},
|
||||
"results": {"$push": "$$ROOT"},
|
||||
}
|
||||
},
|
||||
]
|
||||
|
||||
result = next(Task.aggregate(pipeline), None)
|
||||
|
||||
total = 0
|
||||
remaining = 0
|
||||
results = []
|
||||
|
||||
if result:
|
||||
total = int(result.get("total", -1))
|
||||
results = [
|
||||
{
|
||||
"section": ParameterKeyEscaper.unescape(
|
||||
dpath.get(r, "_id/section")
|
||||
),
|
||||
"name": ParameterKeyEscaper.unescape(dpath.get(r, "_id/name")),
|
||||
}
|
||||
for r in result.get("results", [])
|
||||
]
|
||||
remaining = max(0, total - (len(results) + page * page_size))
|
||||
|
||||
return total, remaining, results
|
||||
|
||||
HyperParamValues = Tuple[int, Sequence[str]]
|
||||
|
||||
def _get_cached_hyperparam_values(
|
||||
self, key: str, last_update: datetime
|
||||
) -> Optional[HyperParamValues]:
|
||||
allowed_delta = timedelta(
|
||||
seconds=config.get(
|
||||
"services.tasks.hyperparam_values.cache_allowed_outdate_sec", 60
|
||||
)
|
||||
)
|
||||
try:
|
||||
cached = self.redis.get(key)
|
||||
if not cached:
|
||||
return
|
||||
|
||||
data = json.loads(cached)
|
||||
cached_last_update = datetime.fromtimestamp(data["last_update"])
|
||||
if (last_update - cached_last_update) < allowed_delta:
|
||||
return data["total"], data["values"]
|
||||
except Exception as ex:
|
||||
log.error(f"Error retrieving hyperparam cached values: {str(ex)}")
|
||||
|
||||
def get_hyperparam_distinct_values(
|
||||
self,
|
||||
company_id: str,
|
||||
project_ids: Sequence[str],
|
||||
section: str,
|
||||
name: str,
|
||||
include_subprojects: bool,
|
||||
allow_public: bool = True,
|
||||
) -> HyperParamValues:
|
||||
if allow_public:
|
||||
company_constraint = {"company": {"$in": [None, "", company_id]}}
|
||||
else:
|
||||
company_constraint = {"company": company_id}
|
||||
if project_ids:
|
||||
if include_subprojects:
|
||||
project_ids = project_ids_with_children(project_ids)
|
||||
project_constraint = {"project": {"$in": project_ids}}
|
||||
else:
|
||||
project_constraint = {}
|
||||
|
||||
key_path = f"hyperparams.{ParameterKeyEscaper.escape(section)}.{ParameterKeyEscaper.escape(name)}"
|
||||
last_updated_task = (
|
||||
Task.objects(
|
||||
**company_constraint,
|
||||
**project_constraint,
|
||||
**{f"{key_path.replace('.', '__')}__exists": True},
|
||||
)
|
||||
.only("last_update")
|
||||
.order_by("-last_update")
|
||||
.limit(1)
|
||||
.first()
|
||||
)
|
||||
if not last_updated_task:
|
||||
return 0, []
|
||||
|
||||
redis_key = f"hyperparam_values_{company_id}_{'_'.join(project_ids)}_{section}_{name}_{allow_public}"
|
||||
last_update = last_updated_task.last_update or datetime.utcnow()
|
||||
cached_res = self._get_cached_hyperparam_values(
|
||||
key=redis_key, last_update=last_update
|
||||
)
|
||||
if cached_res:
|
||||
return cached_res
|
||||
|
||||
max_values = config.get("services.tasks.hyperparam_values.max_count", 100)
|
||||
pipeline = [
|
||||
{
|
||||
"$match": {
|
||||
**company_constraint,
|
||||
**project_constraint,
|
||||
key_path: {"$exists": True},
|
||||
}
|
||||
},
|
||||
{"$project": {"value": f"${key_path}.value"}},
|
||||
{"$group": {"_id": "$value"}},
|
||||
{"$sort": {"_id": 1}},
|
||||
{"$limit": max_values},
|
||||
{
|
||||
"$group": {
|
||||
"_id": 1,
|
||||
"total": {"$sum": 1},
|
||||
"results": {"$push": "$$ROOT._id"},
|
||||
}
|
||||
},
|
||||
]
|
||||
|
||||
result = next(Task.aggregate(pipeline, collation=Task._numeric_locale), None)
|
||||
if not result:
|
||||
return 0, []
|
||||
|
||||
total = int(result.get("total", 0))
|
||||
values = result.get("results", [])
|
||||
|
||||
ttl = config.get("services.tasks.hyperparam_values.cache_ttl_sec", 86400)
|
||||
cached = dict(last_update=last_update.timestamp(), total=total, values=values)
|
||||
self.redis.setex(redis_key, ttl, json.dumps(cached))
|
||||
|
||||
return total, values
|
||||
|
||||
@classmethod
|
||||
def dequeue_and_change_status(
|
||||
cls, task: Task, company_id: str, status_message: str, status_reason: str,
|
||||
|
@ -48,6 +48,7 @@ class Credentials(EmbeddedDocument):
|
||||
meta = {"strict": False}
|
||||
key = StringField(required=True)
|
||||
secret = StringField(required=True)
|
||||
label = StringField()
|
||||
last_used = DateTimeField()
|
||||
|
||||
|
||||
|
@ -142,7 +142,9 @@ class GetMixin(PropsMixin):
|
||||
self.allow_empty = False
|
||||
|
||||
def _get_op(self, v: str, translate: bool = False) -> Optional[str]:
|
||||
op = v[len(self.op_prefix):] if v and v.startswith(self.op_prefix) else None
|
||||
op = (
|
||||
v[len(self.op_prefix) :] if v and v.startswith(self.op_prefix) else None
|
||||
)
|
||||
if translate:
|
||||
tup = self._ops.get(op, None)
|
||||
return tup[0] if tup else None
|
||||
@ -177,7 +179,9 @@ class GetMixin(PropsMixin):
|
||||
"all": Q.AND,
|
||||
}
|
||||
data = (x for x in data if x is not None)
|
||||
first_op = self._get_op(next(data, ""), translate=True) or self.default_mongo_op
|
||||
first_op = (
|
||||
self._get_op(next(data, ""), translate=True) or self.default_mongo_op
|
||||
)
|
||||
return op_to_res.get(first_op, self.default_mongo_op)
|
||||
|
||||
def get_actions(self, data: Sequence[str]) -> Dict[str, List[Union[str, None]]]:
|
||||
@ -202,13 +206,20 @@ class GetMixin(PropsMixin):
|
||||
id = StringField(primary_key=True)
|
||||
position = IntField(default=0)
|
||||
|
||||
cache_manager = RedisCacheManager(
|
||||
state_class=GetManyScrollState,
|
||||
redis=redman.connection("apiserver"),
|
||||
expiration_interval=config.get(
|
||||
"services._mongo.scroll_state_expiration_seconds", 600
|
||||
),
|
||||
)
|
||||
_cache_manager = None
|
||||
|
||||
@classmethod
|
||||
def get_cache_manager(cls):
|
||||
if not cls._cache_manager:
|
||||
cls._cache_manager = RedisCacheManager(
|
||||
state_class=cls.GetManyScrollState,
|
||||
redis=redman.connection("apiserver"),
|
||||
expiration_interval=config.get(
|
||||
"services._mongo.scroll_state_expiration_seconds", 600
|
||||
),
|
||||
)
|
||||
|
||||
return cls._cache_manager
|
||||
|
||||
@classmethod
|
||||
def get(
|
||||
@ -463,10 +474,7 @@ class GetMixin(PropsMixin):
|
||||
if not queries:
|
||||
q = RegexQ()
|
||||
else:
|
||||
q = RegexQCombination(
|
||||
operation=global_op,
|
||||
children=queries
|
||||
)
|
||||
q = RegexQCombination(operation=global_op, children=queries)
|
||||
|
||||
if not helper.allow_empty:
|
||||
return q
|
||||
@ -609,7 +617,7 @@ class GetMixin(PropsMixin):
|
||||
state: Optional[cls.GetManyScrollState] = None
|
||||
if "scroll_id" in query_dict:
|
||||
size = cls.validate_scroll_size(query_dict)
|
||||
state = cls.cache_manager.get_or_create_state_core(
|
||||
state = cls.get_cache_manager().get_or_create_state_core(
|
||||
query_dict.get("scroll_id")
|
||||
)
|
||||
if query_dict.get("refresh_scroll"):
|
||||
@ -625,7 +633,7 @@ class GetMixin(PropsMixin):
|
||||
if not state:
|
||||
return
|
||||
state.position = query_dict[cls._start_key]
|
||||
cls.cache_manager.set_state(state)
|
||||
cls.get_cache_manager().set_state(state)
|
||||
if ret_params is not None:
|
||||
ret_params["scroll_id"] = state.id
|
||||
|
||||
|
@ -4,7 +4,12 @@ from threading import Lock
|
||||
from typing import Sequence
|
||||
|
||||
import six
|
||||
from mongoengine import EmbeddedDocumentField, EmbeddedDocumentListField
|
||||
from mongoengine import (
|
||||
EmbeddedDocumentField,
|
||||
EmbeddedDocumentListField,
|
||||
EmbeddedDocument,
|
||||
Document,
|
||||
)
|
||||
from mongoengine.base import get_document
|
||||
|
||||
from apiserver.database.fields import (
|
||||
@ -25,6 +30,13 @@ class PropsMixin(object):
|
||||
__cached_dpath_computed_fields_lock = Lock()
|
||||
__cached_dpath_computed_fields = None
|
||||
|
||||
_document_classes = {}
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
super().__init_subclass__(**kwargs)
|
||||
if issubclass(cls, (Document, EmbeddedDocument)):
|
||||
PropsMixin._document_classes[cls._class_name] = cls
|
||||
|
||||
@classmethod
|
||||
def get_fields(cls):
|
||||
if cls.__cached_fields is None:
|
||||
@ -57,8 +69,14 @@ class PropsMixin(object):
|
||||
def resolve_doc(v):
|
||||
if not isinstance(v, six.string_types):
|
||||
return v
|
||||
if v == 'self':
|
||||
|
||||
if v == "self":
|
||||
return cls_.owner_document
|
||||
|
||||
doc_cls = PropsMixin._document_classes.get(v)
|
||||
if doc_cls:
|
||||
return doc_cls
|
||||
|
||||
return get_document(v)
|
||||
|
||||
fields = {k: resolve_doc(v) for k, v in res.items()}
|
||||
@ -72,7 +90,7 @@ class PropsMixin(object):
|
||||
).document_type
|
||||
fields.update(
|
||||
{
|
||||
'.'.join((field, subfield)): doc
|
||||
".".join((field, subfield)): doc
|
||||
for subfield, doc in PropsMixin._get_fields_with_attr(
|
||||
embedded_doc_cls, attr
|
||||
).items()
|
||||
@ -80,10 +98,10 @@ class PropsMixin(object):
|
||||
)
|
||||
|
||||
collect_embedded_docs(EmbeddedDocumentField, lambda x: x)
|
||||
collect_embedded_docs(EmbeddedDocumentListField, attrgetter('field'))
|
||||
collect_embedded_docs(LengthRangeEmbeddedDocumentListField, attrgetter('field'))
|
||||
collect_embedded_docs(UniqueEmbeddedDocumentListField, attrgetter('field'))
|
||||
collect_embedded_docs(EmbeddedDocumentSortedListField, attrgetter('field'))
|
||||
collect_embedded_docs(EmbeddedDocumentListField, attrgetter("field"))
|
||||
collect_embedded_docs(LengthRangeEmbeddedDocumentListField, attrgetter("field"))
|
||||
collect_embedded_docs(UniqueEmbeddedDocumentListField, attrgetter("field"))
|
||||
collect_embedded_docs(EmbeddedDocumentSortedListField, attrgetter("field"))
|
||||
|
||||
return fields
|
||||
|
||||
@ -94,7 +112,7 @@ class PropsMixin(object):
|
||||
for depth, part in enumerate(parts):
|
||||
if current_cls is None:
|
||||
raise ValueError(
|
||||
'Invalid path (non-document encountered at %s)' % parts[: depth - 1]
|
||||
"Invalid path (non-document encountered at %s)" % parts[: depth - 1]
|
||||
)
|
||||
try:
|
||||
field_name, field = next(
|
||||
@ -103,7 +121,7 @@ class PropsMixin(object):
|
||||
if k == part
|
||||
)
|
||||
except StopIteration:
|
||||
raise ValueError('Invalid field path %s' % parts[:depth])
|
||||
raise ValueError("Invalid field path %s" % parts[:depth])
|
||||
|
||||
translated_parts.append(part)
|
||||
|
||||
@ -119,7 +137,7 @@ class PropsMixin(object):
|
||||
),
|
||||
):
|
||||
current_cls = field.field.document_type
|
||||
translated_parts.append('*')
|
||||
translated_parts.append("*")
|
||||
else:
|
||||
current_cls = None
|
||||
|
||||
@ -128,7 +146,7 @@ class PropsMixin(object):
|
||||
@classmethod
|
||||
def get_reference_fields(cls):
|
||||
if cls.__cached_reference_fields is None:
|
||||
fields = cls._get_fields_with_attr(cls, 'reference_field')
|
||||
fields = cls._get_fields_with_attr(cls, "reference_field")
|
||||
cls.__cached_reference_fields = OrderedDict(sorted(fields.items()))
|
||||
return cls.__cached_reference_fields
|
||||
|
||||
@ -143,12 +161,12 @@ class PropsMixin(object):
|
||||
@classmethod
|
||||
def get_exclude_fields(cls):
|
||||
if cls.__cached_exclude_fields is None:
|
||||
fields = cls._get_fields_with_attr(cls, 'exclude_by_default')
|
||||
fields = cls._get_fields_with_attr(cls, "exclude_by_default")
|
||||
cls.__cached_exclude_fields = OrderedDict(sorted(fields.items()))
|
||||
return cls.__cached_exclude_fields
|
||||
|
||||
@classmethod
|
||||
def get_dpath_translated_path(cls, path, separator='.'):
|
||||
def get_dpath_translated_path(cls, path, separator="."):
|
||||
if cls.__cached_dpath_computed_fields is None:
|
||||
cls.__cached_dpath_computed_fields = {}
|
||||
if path not in cls.__cached_dpath_computed_fields:
|
||||
|
@ -1,7 +1,8 @@
|
||||
{
|
||||
"index_patterns": "events-*",
|
||||
"settings": {
|
||||
"number_of_shards": 1
|
||||
"number_of_shards": 1,
|
||||
"number_of_replicas": 0
|
||||
},
|
||||
"mappings": {
|
||||
"_source": {
|
||||
|
@ -1,7 +1,8 @@
|
||||
{
|
||||
"index_patterns": "queue_metrics_*",
|
||||
"settings": {
|
||||
"number_of_shards": 1
|
||||
"number_of_shards": 1,
|
||||
"number_of_replicas": 0
|
||||
},
|
||||
"mappings": {
|
||||
"_source": {
|
||||
|
@ -1,7 +1,8 @@
|
||||
{
|
||||
"index_patterns": "worker_stats_*",
|
||||
"settings": {
|
||||
"number_of_shards": 1
|
||||
"number_of_shards": 1,
|
||||
"number_of_replicas": 0
|
||||
},
|
||||
"mappings": {
|
||||
"_source": {
|
||||
|
@ -26,6 +26,10 @@ credentials {
|
||||
type: string
|
||||
description: Credentials secret key
|
||||
}
|
||||
label {
|
||||
type: string
|
||||
description: Optional credentials label
|
||||
}
|
||||
}
|
||||
}
|
||||
batch_operation {
|
||||
|
@ -15,6 +15,10 @@ _definitions {
|
||||
type: string
|
||||
description: ""
|
||||
}
|
||||
label {
|
||||
type: string
|
||||
description: Optional credentials label
|
||||
}
|
||||
last_used {
|
||||
type: string
|
||||
description: ""
|
||||
@ -222,6 +226,12 @@ create_credentials {
|
||||
}
|
||||
}
|
||||
}
|
||||
"999.0": ${create_credentials."2.1"} {
|
||||
request.properties.label {
|
||||
type: string
|
||||
description: Optional credentials label
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
get_credentials {
|
||||
|
@ -889,6 +889,13 @@ get_task_plots {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.16": ${get_task_plots."2.14"} {
|
||||
request.properties.no_scroll {
|
||||
description: If true then no scroll is created. Suitable for one time calls
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
}
|
||||
}
|
||||
get_multi_task_plots {
|
||||
"2.1" {
|
||||
@ -939,6 +946,13 @@ get_multi_task_plots {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.16": ${get_multi_task_plots."2.1"} {
|
||||
request.properties.no_scroll {
|
||||
description: If true then no scroll is created. Suitable for one time calls
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
}
|
||||
}
|
||||
get_vector_metrics_and_variants {
|
||||
"2.1" {
|
||||
@ -1218,6 +1232,13 @@ get_scalar_metric_data {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.16": ${get_scalar_metric_data."2.1"} {
|
||||
request.properties.no_scroll {
|
||||
description: If true then no scroll is created. Suitable for one time calls
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
}
|
||||
}
|
||||
scalar_metrics_iter_raw {
|
||||
"2.16" {
|
||||
|
@ -545,6 +545,42 @@ get_all_ex {
|
||||
type: boolean
|
||||
default: true
|
||||
}
|
||||
response {
|
||||
properties {
|
||||
stats {
|
||||
properties {
|
||||
active.properties {
|
||||
total_tasks {
|
||||
description: "Number of tasks"
|
||||
type: integer
|
||||
}
|
||||
completed_tasks {
|
||||
description: "Number of tasks completed in the last 24 hours"
|
||||
type: integer
|
||||
}
|
||||
running_tasks {
|
||||
description: "Number of running tasks"
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
archived.properties {
|
||||
total_tasks {
|
||||
description: "Number of tasks"
|
||||
type: integer
|
||||
}
|
||||
completed_tasks {
|
||||
description: "Number of tasks completed in the last 24 hours"
|
||||
type: integer
|
||||
}
|
||||
running_tasks {
|
||||
description: "Number of running tasks"
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
update {
|
||||
@ -893,7 +929,55 @@ get_hyper_parameters {
|
||||
}
|
||||
}
|
||||
}
|
||||
get_model_metadata_keys {
|
||||
"999.0" {
|
||||
description: """Get a list of all metadata keys used in models within the given project."""
|
||||
request {
|
||||
type: object
|
||||
required: [project]
|
||||
properties {
|
||||
project {
|
||||
description: "Project ID"
|
||||
type: string
|
||||
}
|
||||
include_subprojects {
|
||||
description: "If set to 'true' and the project field is set then the result includes metadate keys from the subproject models"
|
||||
type: boolean
|
||||
default: true
|
||||
}
|
||||
|
||||
page {
|
||||
description: "Page number"
|
||||
default: 0
|
||||
type: integer
|
||||
}
|
||||
page_size {
|
||||
description: "Page size"
|
||||
default: 500
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
keys {
|
||||
description: "A list of model keys"
|
||||
type: array
|
||||
items {type: string}
|
||||
}
|
||||
remaining {
|
||||
description: "Remaining results"
|
||||
type: integer
|
||||
}
|
||||
total {
|
||||
description: "Total number of results"
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
get_task_tags {
|
||||
"2.8" {
|
||||
description: "Get user and system tags used for the tasks under the specified projects"
|
||||
|
@ -534,7 +534,7 @@ _definitions {
|
||||
container {
|
||||
description: "Docker container parameters"
|
||||
type: object
|
||||
additionalProperties { type: string }
|
||||
additionalProperties { type: [string, null] }
|
||||
}
|
||||
models {
|
||||
description: "Task models"
|
||||
@ -981,7 +981,7 @@ clone {
|
||||
new_task_container {
|
||||
description: "The docker container properties for the new task. If not provided then taken from the original task"
|
||||
type: object
|
||||
additionalProperties { type: string }
|
||||
additionalProperties { type: [string, null] }
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1159,7 +1159,7 @@ create {
|
||||
container {
|
||||
description: "Docker container parameters"
|
||||
type: object
|
||||
additionalProperties { type: string }
|
||||
additionalProperties { type: [string, null] }
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1248,7 +1248,7 @@ validate {
|
||||
container {
|
||||
description: "Docker container parameters"
|
||||
type: object
|
||||
additionalProperties { type: string }
|
||||
additionalProperties { type: [string, null] }
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1410,7 +1410,7 @@ edit {
|
||||
container {
|
||||
description: "Docker container parameters"
|
||||
type: object
|
||||
additionalProperties { type: string }
|
||||
additionalProperties { type: [string, null] }
|
||||
}
|
||||
runtime {
|
||||
description: "Task runtime mapping"
|
||||
|
@ -13,6 +13,7 @@ from apiserver.apimodels.auth import (
|
||||
CredentialsResponse,
|
||||
RevokeCredentialsRequest,
|
||||
EditUserReq,
|
||||
CreateCredentialsRequest,
|
||||
)
|
||||
from apiserver.apimodels.base import UpdateResponse
|
||||
from apiserver.bll.auth import AuthBLL
|
||||
@ -58,9 +59,13 @@ def get_token_for_user(call: APICall, _: str, request: GetTokenForUserRequest):
|
||||
""" Generates a token based on a requested user and company. INTERNAL. """
|
||||
if call.identity.role not in Role.get_system_roles():
|
||||
if call.identity.role != Role.admin and call.identity.user != request.user:
|
||||
raise errors.bad_request.InvalidUserId("cannot generate token for another user")
|
||||
raise errors.bad_request.InvalidUserId(
|
||||
"cannot generate token for another user"
|
||||
)
|
||||
if call.identity.company != request.company:
|
||||
raise errors.bad_request.InvalidId("cannot generate token in another company")
|
||||
raise errors.bad_request.InvalidId(
|
||||
"cannot generate token in another company"
|
||||
)
|
||||
|
||||
call.result.data_model = AuthBLL.get_token_for_user(
|
||||
user_id=request.user,
|
||||
@ -93,7 +98,10 @@ def validate_token_endpoint(call: APICall, _, __):
|
||||
)
|
||||
def create_user(call: APICall, _, request: CreateUserRequest):
|
||||
""" Create a user from. INTERNAL. """
|
||||
if call.identity.role not in Role.get_system_roles() and request.company != call.identity.company:
|
||||
if (
|
||||
call.identity.role not in Role.get_system_roles()
|
||||
and request.company != call.identity.company
|
||||
):
|
||||
raise errors.bad_request.InvalidId("cannot create user in another company")
|
||||
|
||||
user_id = AuthBLL.create_user(request=request, call=call)
|
||||
@ -101,7 +109,7 @@ def create_user(call: APICall, _, request: CreateUserRequest):
|
||||
|
||||
|
||||
@endpoint("auth.create_credentials", response_data_model=CreateCredentialsResponse)
|
||||
def create_credentials(call: APICall, _, __):
|
||||
def create_credentials(call: APICall, _, request: CreateCredentialsRequest):
|
||||
if _is_protected_user(call.identity.user):
|
||||
raise errors.bad_request.InvalidUserId("protected identity")
|
||||
|
||||
@ -109,6 +117,7 @@ def create_credentials(call: APICall, _, __):
|
||||
user_id=call.identity.user,
|
||||
company_id=call.identity.company,
|
||||
role=call.identity.role,
|
||||
label=request.label,
|
||||
)
|
||||
call.result.data_model = CreateCredentialsResponse(credentials=credentials)
|
||||
|
||||
@ -151,7 +160,9 @@ def get_credentials(call: APICall, _, __):
|
||||
# we return ONLY the key IDs, never the secrets (want a secret? create new credentials)
|
||||
call.result.data_model = GetCredentialsResponse(
|
||||
credentials=[
|
||||
CredentialsResponse(access_key=c.key, last_used=c.last_used)
|
||||
CredentialsResponse(
|
||||
access_key=c.key, last_used=c.last_used, label=c.label
|
||||
)
|
||||
for c in user.credentials
|
||||
]
|
||||
)
|
||||
|
@ -361,6 +361,7 @@ def get_scalar_metric_data(call, company_id, _):
|
||||
task_id = call.data["task"]
|
||||
metric = call.data["metric"]
|
||||
scroll_id = call.data.get("scroll_id")
|
||||
no_scroll = call.data.get("no_scroll", False)
|
||||
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_id, allow_public=True, only=("company", "company_origin")
|
||||
@ -372,6 +373,7 @@ def get_scalar_metric_data(call, company_id, _):
|
||||
sort=[{"iter": {"order": "desc"}}],
|
||||
metric=metric,
|
||||
scroll_id=scroll_id,
|
||||
no_scroll=no_scroll,
|
||||
)
|
||||
|
||||
call.result.data = dict(
|
||||
@ -494,6 +496,7 @@ def get_multi_task_plots(call, company_id, req_model):
|
||||
task_ids = call.data["tasks"]
|
||||
iters = call.data.get("iters", 1)
|
||||
scroll_id = call.data.get("scroll_id")
|
||||
no_scroll = call.data.get("no_scroll", False)
|
||||
|
||||
tasks = task_bll.assert_exists(
|
||||
company_id=call.identity.company,
|
||||
@ -515,6 +518,7 @@ def get_multi_task_plots(call, company_id, req_model):
|
||||
sort=[{"iter": {"order": "desc"}}],
|
||||
last_iter_count=iters,
|
||||
scroll_id=scroll_id,
|
||||
no_scroll=no_scroll,
|
||||
)
|
||||
|
||||
tasks = {t.id: t.name for t in tasks}
|
||||
@ -593,6 +597,7 @@ def get_task_plots(call, company_id, request: TaskPlotsRequest):
|
||||
sort=[{"iter": {"order": "desc"}}],
|
||||
last_iterations_per_plot=iters,
|
||||
scroll_id=scroll_id,
|
||||
no_scroll=request.no_scroll,
|
||||
metric_variants=_get_metric_variants_from_request(request.metrics),
|
||||
)
|
||||
|
||||
|
@ -7,7 +7,7 @@ from apiserver.apierrors import errors
|
||||
from apiserver.apierrors.errors.bad_request import InvalidProjectId
|
||||
from apiserver.apimodels.base import UpdateResponse, MakePublicRequest, IdResponse
|
||||
from apiserver.apimodels.projects import (
|
||||
GetHyperParamRequest,
|
||||
GetParamsRequest,
|
||||
ProjectTagsRequest,
|
||||
ProjectTaskParentsRequest,
|
||||
ProjectHyperparamValuesRequest,
|
||||
@ -19,12 +19,11 @@ from apiserver.apimodels.projects import (
|
||||
ProjectRequest,
|
||||
)
|
||||
from apiserver.bll.organization import OrgBLL, Tags
|
||||
from apiserver.bll.project import ProjectBLL
|
||||
from apiserver.bll.project import ProjectBLL, ProjectQueries
|
||||
from apiserver.bll.project.project_cleanup import (
|
||||
delete_project,
|
||||
validate_project_delete,
|
||||
)
|
||||
from apiserver.bll.task import TaskBLL
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model.project import Project
|
||||
from apiserver.database.utils import (
|
||||
@ -41,8 +40,8 @@ from apiserver.services.utils import (
|
||||
from apiserver.timing_context import TimingContext
|
||||
|
||||
org_bll = OrgBLL()
|
||||
task_bll = TaskBLL()
|
||||
project_bll = ProjectBLL()
|
||||
project_queries = ProjectQueries()
|
||||
|
||||
create_fields = {
|
||||
"name": None,
|
||||
@ -267,7 +266,7 @@ def get_unique_metric_variants(
|
||||
call: APICall, company_id: str, request: ProjectOrNoneRequest
|
||||
):
|
||||
|
||||
metrics = task_bll.get_unique_metric_variants(
|
||||
metrics = project_queries.get_unique_metric_variants(
|
||||
company_id,
|
||||
[request.project] if request.project else None,
|
||||
include_subprojects=request.include_subprojects,
|
||||
@ -276,14 +275,31 @@ def get_unique_metric_variants(
|
||||
call.result.data = {"metrics": metrics}
|
||||
|
||||
|
||||
@endpoint("projects.get_model_metadata_keys",)
|
||||
def get_model_metadata_keys(call: APICall, company_id: str, request: GetParamsRequest):
|
||||
total, remaining, keys = project_queries.get_model_metadata_keys(
|
||||
company_id,
|
||||
project_ids=[request.project] if request.project else None,
|
||||
include_subprojects=request.include_subprojects,
|
||||
page=request.page,
|
||||
page_size=request.page_size,
|
||||
)
|
||||
|
||||
call.result.data = {
|
||||
"total": total,
|
||||
"remaining": remaining,
|
||||
"keys": keys,
|
||||
}
|
||||
|
||||
|
||||
@endpoint(
|
||||
"projects.get_hyper_parameters",
|
||||
min_version="2.9",
|
||||
request_data_model=GetHyperParamRequest,
|
||||
request_data_model=GetParamsRequest,
|
||||
)
|
||||
def get_hyper_parameters(call: APICall, company_id: str, request: GetHyperParamRequest):
|
||||
def get_hyper_parameters(call: APICall, company_id: str, request: GetParamsRequest):
|
||||
|
||||
total, remaining, parameters = TaskBLL.get_aggregated_project_parameters(
|
||||
total, remaining, parameters = project_queries.get_aggregated_project_parameters(
|
||||
company_id,
|
||||
project_ids=[request.project] if request.project else None,
|
||||
include_subprojects=request.include_subprojects,
|
||||
@ -306,7 +322,7 @@ def get_hyper_parameters(call: APICall, company_id: str, request: GetHyperParamR
|
||||
def get_hyperparam_values(
|
||||
call: APICall, company_id: str, request: ProjectHyperparamValuesRequest
|
||||
):
|
||||
total, values = task_bll.get_hyperparam_distinct_values(
|
||||
total, values = project_queries.get_hyperparam_distinct_values(
|
||||
company_id,
|
||||
project_ids=request.projects,
|
||||
section=request.section,
|
||||
|
@ -6,9 +6,6 @@ from apiserver.tests.automated import TestService
|
||||
|
||||
|
||||
class TestQueueAndModelMetadata(TestService):
|
||||
def setUp(self, version="2.13"):
|
||||
super().setUp(version=version)
|
||||
|
||||
meta1 = [{"key": "test_key", "type": "str", "value": "test_value"}]
|
||||
|
||||
def test_queue_metas(self):
|
||||
@ -29,6 +26,23 @@ class TestQueueAndModelMetadata(TestService):
|
||||
self.api.models.edit(model=model_id, metadata=[self.meta1[0]])
|
||||
self._assertMeta(service=service, entity=entity, _id=model_id, meta=self.meta1)
|
||||
|
||||
def test_project_meta_query(self):
|
||||
self._temp_model("TestMetadata", metadata=self.meta1)
|
||||
project = self.temp_project(name="MetaParent")
|
||||
self._temp_model(
|
||||
"TestMetadata2",
|
||||
project=project,
|
||||
metadata=[
|
||||
{"key": "test_key", "type": "str", "value": "test_value"},
|
||||
{"key": "test_key2", "type": "str", "value": "test_value"},
|
||||
],
|
||||
)
|
||||
res = self.api.projects.get_model_metadata_keys()
|
||||
self.assertTrue({"test_key", "test_key2"}.issubset(set(res["keys"])))
|
||||
res = self.api.projects.get_model_metadata_keys(include_subprojects=False)
|
||||
self.assertTrue("test_key" in res["keys"])
|
||||
self.assertFalse("test_key2" in res["keys"])
|
||||
|
||||
def _test_meta_operations(
|
||||
self, service: APIClient.Service, entity: str, _id: str,
|
||||
):
|
||||
@ -72,3 +86,12 @@ class TestQueueAndModelMetadata(TestService):
|
||||
return self.create_temp(
|
||||
"models", uri="file://test", name=name, labels={}, **kwargs
|
||||
)
|
||||
|
||||
def temp_project(self, **kwargs) -> str:
|
||||
self.update_missing(
|
||||
kwargs,
|
||||
name="Test models meta",
|
||||
description="test",
|
||||
delete_params=dict(force=True),
|
||||
)
|
||||
return self.create_temp("projects", **kwargs)
|
||||
|
@ -12,9 +12,6 @@ from apiserver.tests.automated import TestService
|
||||
|
||||
|
||||
class TestSubProjects(TestService):
|
||||
def setUp(self, **kwargs):
|
||||
super().setUp(version="2.13")
|
||||
|
||||
def test_project_aggregations(self):
|
||||
"""This test requires user with user_auth_only... credentials in db"""
|
||||
user2_client = APIClient(
|
||||
@ -203,6 +200,9 @@ class TestSubProjects(TestService):
|
||||
self.assertEqual(res1.stats["active"]["status_count"]["created"], 0)
|
||||
self.assertEqual(res1.stats["active"]["status_count"]["stopped"], 2)
|
||||
self.assertEqual(res1.stats["active"]["total_runtime"], 2)
|
||||
self.assertEqual(res1.stats["active"]["completed_tasks"], 2)
|
||||
self.assertEqual(res1.stats["active"]["total_tasks"], 2)
|
||||
self.assertEqual(res1.stats["active"]["running_tasks"], 0)
|
||||
self.assertEqual(
|
||||
{sp.name for sp in res1.sub_projects},
|
||||
{
|
||||
@ -215,6 +215,9 @@ class TestSubProjects(TestService):
|
||||
self.assertEqual(res2.stats["active"]["status_count"]["created"], 0)
|
||||
self.assertEqual(res2.stats["active"]["status_count"]["stopped"], 0)
|
||||
self.assertEqual(res2.stats["active"]["total_runtime"], 0)
|
||||
self.assertEqual(res2.stats["active"]["completed_tasks"], 0)
|
||||
self.assertEqual(res2.stats["active"]["total_tasks"], 0)
|
||||
self.assertEqual(res2.stats["active"]["running_tasks"], 0)
|
||||
self.assertEqual(res2.sub_projects, [])
|
||||
|
||||
def _run_tasks(self, *tasks):
|
||||
|
@ -198,6 +198,9 @@ class TestTags(TestService):
|
||||
def assertProjectStats(self, project: AttrDict):
|
||||
self.assertEqual(set(project.stats.keys()), {"active"})
|
||||
self.assertAlmostEqual(project.stats.active.total_runtime, 1, places=0)
|
||||
self.assertEqual(project.stats.active.completed_tasks, 1)
|
||||
self.assertEqual(project.stats.active.total_tasks, 1)
|
||||
self.assertEqual(project.stats.active.running_tasks, 0)
|
||||
for status, count in project.stats.active.status_count.items():
|
||||
self.assertEqual(count, 1 if status == "stopped" else 0)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user