Mirror of https://github.com/clearml/clearml-server (synced 2025-05-03 03:32:25 +00:00)
Add support for credentials label

Support no_scroll in events.get_task_plots
Support better project stats
Fix Redis required on mongodb initialization
Update tests
parent 92fd98d5ad
commit 447adb9090
Changed paths:
  apiserver/apimodels
  apiserver/bll
  apiserver/database
  apiserver/elastic/mappings
  apiserver/schema/services
  apiserver/services
  apiserver/tests/automated
@@ -75,6 +75,7 @@ class CreateUserResponse(Base):
 class Credentials(Base):
     access_key = StringField(required=True)
     secret_key = StringField(required=True)
+    label = StringField()


 class CredentialsResponse(Credentials):
@@ -82,6 +83,10 @@ class CredentialsResponse(Credentials):
     last_used = DateTimeField(default=None)


+class CreateCredentialsRequest(Base):
+    label = StringField()
+
+
 class CreateCredentialsResponse(Base):
     credentials = EmbeddedField(Credentials)

@@ -135,4 +135,5 @@ class TaskPlotsRequest(Base):
     task: str = StringField(required=True)
     iters: int = IntField(default=1)
     scroll_id: str = StringField()
+    no_scroll: bool = BoolField(default=False)
     metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
@@ -27,7 +27,7 @@ class ProjectOrNoneRequest(models.Base):
     include_subprojects = fields.BoolField(default=True)


-class GetHyperParamRequest(ProjectOrNoneRequest):
+class GetParamsRequest(ProjectOrNoneRequest):
     page = fields.IntField(default=0)
     page_size = fields.IntField(default=500)

@@ -2,7 +2,11 @@ from datetime import datetime

 from apiserver import database
 from apiserver.apierrors import errors
-from apiserver.apimodels.auth import GetTokenResponse, CreateUserRequest, Credentials as CredModel
+from apiserver.apimodels.auth import (
+    GetTokenResponse,
+    CreateUserRequest,
+    Credentials as CredModel,
+)
 from apiserver.apimodels.users import CreateRequest as Users_CreateRequest
 from apiserver.bll.user import UserBLL
 from apiserver.config_repo import config
@@ -145,7 +149,7 @@ class AuthBLL:

     @classmethod
     def create_credentials(
-        cls, user_id: str, company_id: str, role: str = None
+        cls, user_id: str, company_id: str, role: str = None, label: str = None,
     ) -> CredModel:

         with translate_errors_context():
@@ -154,7 +158,9 @@ class AuthBLL:
         if not user:
             raise errors.bad_request.InvalidUserId(**query)

-        cred = CredModel(access_key=get_client_id(), secret_key=get_secret_key())
+        cred = CredModel(
+            access_key=get_client_id(), secret_key=get_secret_key(), label=label
+        )
         user.credentials.append(
             Credentials(key=cred.access_key, secret=cred.secret_key)
         )
@@ -534,6 +534,7 @@ class EventBLL(object):
         sort=None,
         size: int = 500,
         scroll_id: str = None,
+        no_scroll: bool = False,
         metric_variants: MetricVariants = None,
     ):
         if scroll_id == self.empty_scroll:
@@ -611,7 +612,7 @@ class EventBLL(object):
             event_type=event_type,
             body=es_req,
             ignore=404,
-            scroll="1h",
+            **({} if no_scroll else {"scroll": "1h"}),
         )

         events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
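Note: the `**({} if no_scroll else {"scroll": "1h"})` splat omits the scroll keyword entirely instead of passing scroll=None, which an Elasticsearch-style client would treat as a real value. A minimal standalone sketch of the idiom (the `search` function below is a stand-in, not clearml code):

    # Stand-in for an ES-style client call: an absent "scroll" kwarg behaves
    # differently from scroll=None, so the kwarg must be omitted, not nulled.
    def search(index, body, **kwargs):
        return {"index": index, "kwargs": kwargs}

    def query(index, body, no_scroll=False):
        return search(
            index,
            body,
            # Splatting an empty dict adds nothing; otherwise request a
            # scroll context with a one-hour keep-alive.
            **({} if no_scroll else {"scroll": "1h"}),
        )

    print(query("events-training", {}))                  # scroll="1h" is passed
    print(query("events-training", {}, no_scroll=True))  # no scroll kwarg at all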
@@ -680,6 +681,7 @@ class EventBLL(object):
         sort=None,
         size=500,
         scroll_id=None,
+        no_scroll=False,
     ) -> TaskEventsResult:
         if scroll_id == self.empty_scroll:
             return TaskEventsResult()
@@ -740,7 +742,7 @@ class EventBLL(object):
             event_type=event_type,
             body=es_req,
             ignore=404,
-            scroll="1h",
+            **({} if no_scroll else {"scroll": "1h"}),
         )

         events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
@@ -1,2 +1,3 @@
 from .project_bll import ProjectBLL
+from .project_queries import ProjectQueries
 from .sub_projects import _ids_with_children as project_ids_with_children
@@ -1,6 +1,6 @@
 import itertools
 from collections import defaultdict
-from datetime import datetime
+from datetime import datetime, timedelta
 from functools import reduce
 from itertools import groupby
 from operator import itemgetter
@@ -306,6 +306,7 @@ class ProjectBLL:
         return project

     archived_tasks_cond = {"$in": [EntityVisibility.archived.value, "$system_tags"]}
+    visibility_states = [EntityVisibility.archived, EntityVisibility.active]

     @classmethod
     def make_projects_get_all_pipelines(
@@ -367,6 +368,26 @@ class ProjectBLL:
             },
         ]

+        def completed_after_subquery(additional_cond, time_thresh: datetime):
+            return {
+                # the sum of
+                "$sum": {
+                    # for each task
+                    "$cond": {
+                        # if completed after the time_thresh
+                        "if": {
+                            "$and": [
+                                "$completed",
+                                {"$gt": ["$completed", time_thresh]},
+                                additional_cond,
+                            ]
+                        },
+                        "then": 1,
+                        "else": 0,
+                    }
+                }
+            }
+
         def runtime_subquery(additional_cond):
             return {
                 # the sum of
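Note: completed_after_subquery builds the standard MongoDB counting accumulator: $sum over a $cond that yields 1 or 0 per grouped document. A self-contained pymongo sketch of the same idiom (the connection URI, database and collection names are illustrative):

    from datetime import datetime, timedelta
    from pymongo import MongoClient

    tasks = MongoClient("mongodb://localhost:27017")["example"]["tasks"]
    time_thresh = datetime.utcnow() - timedelta(hours=24)

    pipeline = [
        {
            "$group": {
                "_id": "$project",
                # Each document contributes 1 if it has a "completed"
                # timestamp newer than the threshold, else 0, so $sum counts.
                "recently_completed": {
                    "$sum": {
                        "$cond": {
                            "if": {
                                "$and": [
                                    "$completed",
                                    {"$gt": ["$completed", time_thresh]},
                                ]
                            },
                            "then": 1,
                            "else": 0,
                        }
                    }
                },
            }
        }
    ]
    for row in tasks.aggregate(pipeline):
        print(row)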
@@ -397,16 +418,19 @@ class ProjectBLL:
         }

         group_step = {"_id": "$project"}
-        for state in EntityVisibility:
+        time_thresh = datetime.utcnow() - timedelta(hours=24)
+        for state in cls.visibility_states:
             if specific_state and state != specific_state:
                 continue
-            if state == EntityVisibility.active:
-                group_step[state.value] = runtime_subquery(
-                    {"$not": cls.archived_tasks_cond}
-                )
-            elif state == EntityVisibility.archived:
-                group_step[state.value] = runtime_subquery(cls.archived_tasks_cond)
+            cond = (
+                cls.archived_tasks_cond
+                if state == EntityVisibility.archived
+                else {"$not": cls.archived_tasks_cond}
+            )
+            group_step[state.value] = runtime_subquery(cond)
+            group_step[f"{state.value}_recently_completed"] = completed_after_subquery(
+                cond, time_thresh=time_thresh
+            )

         runtime_pipeline = [
             # only count run time for these types of tasks
@@ -534,15 +558,24 @@ class ProjectBLL:
         )

         def get_status_counts(project_id, section):
+            project_runtime = runtime.get(project_id, {})
+            project_section_statuses = nested_get(
+                status_count, (project_id, section), default=default_counts
+            )
             return {
-                "total_runtime": nested_get(runtime, (project_id, section), default=0),
-                "status_count": nested_get(
-                    status_count, (project_id, section), default=default_counts
-                ),
+                "status_count": project_section_statuses,
+                "running_tasks": project_section_statuses.get(TaskStatus.in_progress),
+                "total_tasks": sum(project_section_statuses.values()),
+                "total_runtime": project_runtime.get(section, 0),
+                "completed_tasks": project_runtime.get(
+                    f"{section}_recently_completed", 0
+                ),
             }

         report_for_states = [
-            s for s in EntityVisibility if not specific_state or specific_state == s
+            s
+            for s in cls.visibility_states
+            if not specific_state or specific_state == s
         ]

         stats = {
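Note: each visibility section of a project's stats now carries precomputed totals rather than only the raw status map. An illustrative shape of one entry (the numbers are invented; only the keys mirror get_status_counts above):

    stats_example = {
        "active": {
            "status_count": {"created": 2, "in_progress": 1, "completed": 7},
            "running_tasks": 1,     # status_count[TaskStatus.in_progress]
            "total_tasks": 10,      # sum of all status counts
            "total_runtime": 5243,  # runtime aggregation for the section
            "completed_tasks": 3,   # completed in the last 24 hours
        },
    }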
apiserver/bll/project/project_queries.py (new file, 289 lines)
@@ -0,0 +1,289 @@
+import json
+from collections import OrderedDict
+from datetime import datetime, timedelta
+from typing import (
+    Sequence,
+    Optional,
+    Tuple,
+)
+
+from redis import StrictRedis
+
+from apiserver.config_repo import config
+from apiserver.redis_manager import redman
+from apiserver.utilities.dicts import nested_get
+from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
+from .sub_projects import _ids_with_children
+from ...database.model.model import Model
+from ...database.model.task.task import Task
+
+log = config.logger(__file__)
+
+
+class ProjectQueries:
+    def __init__(self, redis=None):
+        self.redis: StrictRedis = redis or redman.connection("apiserver")
+
+    @staticmethod
+    def _get_project_constraint(
+        project_ids: Sequence[str], include_subprojects: bool
+    ) -> dict:
+        if include_subprojects:
+            if project_ids is None:
+                return {}
+            project_ids = _ids_with_children(project_ids)
+
+        return {"project": {"$in": project_ids if project_ids is not None else [None]}}
+
+    @staticmethod
+    def _get_company_constraint(company_id: str, allow_public: bool = True) -> dict:
+        if allow_public:
+            return {"company": {"$in": [None, "", company_id]}}
+
+        return {"company": company_id}
+
+    @classmethod
+    def get_aggregated_project_parameters(
+        cls,
+        company_id,
+        project_ids: Sequence[str],
+        include_subprojects: bool,
+        page: int = 0,
+        page_size: int = 500,
+    ) -> Tuple[int, int, Sequence[dict]]:
+        page = max(0, page)
+        page_size = max(1, page_size)
+        pipeline = [
+            {
+                "$match": {
+                    **cls._get_company_constraint(company_id),
+                    **cls._get_project_constraint(project_ids, include_subprojects),
+                    "hyperparams": {"$exists": True, "$gt": {}},
+                }
+            },
+            {"$project": {"sections": {"$objectToArray": "$hyperparams"}}},
+            {"$unwind": "$sections"},
+            {
+                "$project": {
+                    "section": "$sections.k",
+                    "names": {"$objectToArray": "$sections.v"},
+                }
+            },
+            {"$unwind": "$names"},
+            {"$group": {"_id": {"section": "$section", "name": "$names.k"}}},
+            {"$sort": OrderedDict({"_id.section": 1, "_id.name": 1})},
+            {"$skip": page * page_size},
+            {"$limit": page_size},
+            {
+                "$group": {
+                    "_id": 1,
+                    "total": {"$sum": 1},
+                    "results": {"$push": "$$ROOT"},
+                }
+            },
+        ]
+
+        result = next(Task.aggregate(pipeline), None)
+
+        total = 0
+        remaining = 0
+        results = []
+
+        if result:
+            total = int(result.get("total", -1))
+            results = [
+                {
+                    "section": ParameterKeyEscaper.unescape(
+                        nested_get(r, ("_id", "section"))
+                    ),
+                    "name": ParameterKeyEscaper.unescape(
+                        nested_get(r, ("_id", "name"))
+                    ),
+                }
+                for r in result.get("results", [])
+            ]
+            remaining = max(0, total - (len(results) + page * page_size))
+
+        return total, remaining, results
+
+    HyperParamValues = Tuple[int, Sequence[str]]
+
+    def _get_cached_hyperparam_values(
+        self, key: str, last_update: datetime
+    ) -> Optional[HyperParamValues]:
+        allowed_delta = timedelta(
+            seconds=config.get(
+                "services.tasks.hyperparam_values.cache_allowed_outdate_sec", 60
+            )
+        )
+        try:
+            cached = self.redis.get(key)
+            if not cached:
+                return
+
+            data = json.loads(cached)
+            cached_last_update = datetime.fromtimestamp(data["last_update"])
+            if (last_update - cached_last_update) < allowed_delta:
+                return data["total"], data["values"]
+        except Exception as ex:
+            log.error(f"Error retrieving hyperparam cached values: {str(ex)}")
+
+    def get_hyperparam_distinct_values(
+        self,
+        company_id: str,
+        project_ids: Sequence[str],
+        section: str,
+        name: str,
+        include_subprojects: bool,
+        allow_public: bool = True,
+    ) -> HyperParamValues:
+        company_constraint = self._get_company_constraint(company_id, allow_public)
+        project_constraint = self._get_project_constraint(
+            project_ids, include_subprojects
+        )
+        key_path = f"hyperparams.{ParameterKeyEscaper.escape(section)}.{ParameterKeyEscaper.escape(name)}"
+        last_updated_task = (
+            Task.objects(
+                **company_constraint,
+                **project_constraint,
+                **{f"{key_path.replace('.', '__')}__exists": True},
+            )
+            .only("last_update")
+            .order_by("-last_update")
+            .limit(1)
+            .first()
+        )
+        if not last_updated_task:
+            return 0, []
+
+        redis_key = f"hyperparam_values_{company_id}_{'_'.join(project_ids)}_{section}_{name}_{allow_public}"
+        last_update = last_updated_task.last_update or datetime.utcnow()
+        cached_res = self._get_cached_hyperparam_values(
+            key=redis_key, last_update=last_update
+        )
+        if cached_res:
+            return cached_res
+
+        max_values = config.get("services.tasks.hyperparam_values.max_count", 100)
+        pipeline = [
+            {
+                "$match": {
+                    **company_constraint,
+                    **project_constraint,
+                    key_path: {"$exists": True},
+                }
+            },
+            {"$project": {"value": f"${key_path}.value"}},
+            {"$group": {"_id": "$value"}},
+            {"$sort": {"_id": 1}},
+            {"$limit": max_values},
+            {
+                "$group": {
+                    "_id": 1,
+                    "total": {"$sum": 1},
+                    "results": {"$push": "$$ROOT._id"},
+                }
+            },
+        ]
+
+        result = next(Task.aggregate(pipeline, collation=Task._numeric_locale), None)
+        if not result:
+            return 0, []
+
+        total = int(result.get("total", 0))
+        values = result.get("results", [])
+
+        ttl = config.get("services.tasks.hyperparam_values.cache_ttl_sec", 86400)
+        cached = dict(last_update=last_update.timestamp(), total=total, values=values)
+        self.redis.setex(redis_key, ttl, json.dumps(cached))
+
+        return total, values
+
+    @classmethod
+    def get_unique_metric_variants(
+        cls, company_id, project_ids: Sequence[str], include_subprojects: bool
+    ):
+        pipeline = [
+            {
+                "$match": {
+                    **cls._get_company_constraint(company_id),
+                    **cls._get_project_constraint(project_ids, include_subprojects),
+                }
+            },
+            {"$project": {"metrics": {"$objectToArray": "$last_metrics"}}},
+            {"$unwind": "$metrics"},
+            {
+                "$project": {
+                    "metric": "$metrics.k",
+                    "variants": {"$objectToArray": "$metrics.v"},
+                }
+            },
+            {"$unwind": "$variants"},
+            {
+                "$group": {
+                    "_id": {
+                        "metric": "$variants.v.metric",
+                        "variant": "$variants.v.variant",
+                    },
+                    "metrics": {
+                        "$addToSet": {
+                            "metric": "$variants.v.metric",
+                            "metric_hash": "$metric",
+                            "variant": "$variants.v.variant",
+                            "variant_hash": "$variants.k",
+                        }
+                    },
+                }
+            },
+            {"$sort": OrderedDict({"_id.metric": 1, "_id.variant": 1})},
+        ]
+
+        result = Task.aggregate(pipeline)
+        return [r["metrics"][0] for r in result]
+
+    @classmethod
+    def get_model_metadata_keys(
+        cls,
+        company_id,
+        project_ids: Sequence[str],
+        include_subprojects: bool,
+        page: int = 0,
+        page_size: int = 500,
+    ) -> Tuple[int, int, Sequence[dict]]:
+        page = max(0, page)
+        page_size = max(1, page_size)
+        pipeline = [
+            {
+                "$match": {
+                    **cls._get_company_constraint(company_id),
+                    **cls._get_project_constraint(project_ids, include_subprojects),
+                    "metadata": {"$exists": True, "$ne": []},
+                }
+            },
+            {"$project": {"metadata": 1}},
+            {"$unwind": "$metadata"},
+            {"$group": {"_id": "$metadata.key"}},
+            {"$sort": {"_id": 1}},
+            {"$skip": page * page_size},
+            {"$limit": page_size},
+            {
+                "$group": {
+                    "_id": 1,
+                    "total": {"$sum": 1},
+                    "results": {"$push": "$$ROOT"},
+                }
+            },
+        ]
+
+        result = next(Model.aggregate(pipeline), None)
+
+        total = 0
+        remaining = 0
+        results = []
+
+        if result:
+            total = int(result.get("total", -1))
+            results = [r.get("_id") for r in result.get("results", [])]
+            remaining = max(0, total - (len(results) + page * page_size))
+
+        return total, remaining, results
@@ -1,9 +1,6 @@
-import json
-from collections import OrderedDict
-from datetime import datetime, timedelta
+from datetime import datetime
 from typing import Collection, Sequence, Tuple, Any, Optional, Dict

-import dpath
 import six
 from mongoengine import Q
 from redis import StrictRedis
@@ -14,7 +11,7 @@ from apiserver.apierrors import errors
 from apiserver.apimodels.tasks import TaskInputModel
 from apiserver.bll.queue import QueueBLL
 from apiserver.bll.organization import OrgBLL, Tags
-from apiserver.bll.project import ProjectBLL, project_ids_with_children
+from apiserver.bll.project import ProjectBLL
 from apiserver.config_repo import config
 from apiserver.database.errors import translate_errors_context
 from apiserver.database.model.model import Model
@@ -39,7 +36,6 @@ from apiserver.redis_manager import redman
 from apiserver.service_repo import APICall
 from apiserver.services.utils import validate_tags, escape_dict_field, escape_dict
 from apiserver.timing_context import TimingContext
-from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
 from .artifacts import artifacts_prepare_for_save
 from .param_utils import params_prepare_for_save
 from .utils import (
@@ -350,54 +346,6 @@ class TaskBLL:
         if validate_models:
             cls.validate_input_models(task)

-    @staticmethod
-    def get_unique_metric_variants(
-        company_id, project_ids: Sequence[str], include_subprojects: bool
-    ):
-        if project_ids:
-            if include_subprojects:
-                project_ids = project_ids_with_children(project_ids)
-            project_constraint = {"project": {"$in": project_ids}}
-        else:
-            project_constraint = {}
-        pipeline = [
-            {
-                "$match": dict(
-                    company={"$in": [None, "", company_id]}, **project_constraint,
-                )
-            },
-            {"$project": {"metrics": {"$objectToArray": "$last_metrics"}}},
-            {"$unwind": "$metrics"},
-            {
-                "$project": {
-                    "metric": "$metrics.k",
-                    "variants": {"$objectToArray": "$metrics.v"},
-                }
-            },
-            {"$unwind": "$variants"},
-            {
-                "$group": {
-                    "_id": {
-                        "metric": "$variants.v.metric",
-                        "variant": "$variants.v.variant",
-                    },
-                    "metrics": {
-                        "$addToSet": {
-                            "metric": "$variants.v.metric",
-                            "metric_hash": "$metric",
-                            "variant": "$variants.v.variant",
-                            "variant_hash": "$variants.k",
-                        }
-                    },
-                }
-            },
-            {"$sort": OrderedDict({"_id.metric": 1, "_id.variant": 1})},
-        ]
-
-        with translate_errors_context():
-            result = Task.aggregate(pipeline)
-            return [r["metrics"][0] for r in result]
-
     @staticmethod
     def set_last_update(
         task_ids: Collection[str],
@@ -494,173 +442,6 @@ class TaskBLL:
             **extra_updates,
         )

-    @staticmethod
-    def get_aggregated_project_parameters(
-        company_id,
-        project_ids: Sequence[str],
-        include_subprojects: bool,
-        page: int = 0,
-        page_size: int = 500,
-    ) -> Tuple[int, int, Sequence[dict]]:
-        if project_ids:
-            if include_subprojects:
-                project_ids = project_ids_with_children(project_ids)
-            project_constraint = {"project": {"$in": project_ids}}
-        else:
-            project_constraint = {}
-        page = max(0, page)
-        page_size = max(1, page_size)
-        pipeline = [
-            {
-                "$match": {
-                    "company": {"$in": [None, "", company_id]},
-                    "hyperparams": {"$exists": True, "$gt": {}},
-                    **project_constraint,
-                }
-            },
-            {"$project": {"sections": {"$objectToArray": "$hyperparams"}}},
-            {"$unwind": "$sections"},
-            {
-                "$project": {
-                    "section": "$sections.k",
-                    "names": {"$objectToArray": "$sections.v"},
-                }
-            },
-            {"$unwind": "$names"},
-            {"$group": {"_id": {"section": "$section", "name": "$names.k"}}},
-            {"$sort": OrderedDict({"_id.section": 1, "_id.name": 1})},
-            {"$skip": page * page_size},
-            {"$limit": page_size},
-            {
-                "$group": {
-                    "_id": 1,
-                    "total": {"$sum": 1},
-                    "results": {"$push": "$$ROOT"},
-                }
-            },
-        ]
-
-        result = next(Task.aggregate(pipeline), None)
-
-        total = 0
-        remaining = 0
-        results = []
-
-        if result:
-            total = int(result.get("total", -1))
-            results = [
-                {
-                    "section": ParameterKeyEscaper.unescape(
-                        dpath.get(r, "_id/section")
-                    ),
-                    "name": ParameterKeyEscaper.unescape(dpath.get(r, "_id/name")),
-                }
-                for r in result.get("results", [])
-            ]
-            remaining = max(0, total - (len(results) + page * page_size))
-
-        return total, remaining, results
-
-    HyperParamValues = Tuple[int, Sequence[str]]
-
-    def _get_cached_hyperparam_values(
-        self, key: str, last_update: datetime
-    ) -> Optional[HyperParamValues]:
-        allowed_delta = timedelta(
-            seconds=config.get(
-                "services.tasks.hyperparam_values.cache_allowed_outdate_sec", 60
-            )
-        )
-        try:
-            cached = self.redis.get(key)
-            if not cached:
-                return
-
-            data = json.loads(cached)
-            cached_last_update = datetime.fromtimestamp(data["last_update"])
-            if (last_update - cached_last_update) < allowed_delta:
-                return data["total"], data["values"]
-        except Exception as ex:
-            log.error(f"Error retrieving hyperparam cached values: {str(ex)}")
-
-    def get_hyperparam_distinct_values(
-        self,
-        company_id: str,
-        project_ids: Sequence[str],
-        section: str,
-        name: str,
-        include_subprojects: bool,
-        allow_public: bool = True,
-    ) -> HyperParamValues:
-        if allow_public:
-            company_constraint = {"company": {"$in": [None, "", company_id]}}
-        else:
-            company_constraint = {"company": company_id}
-        if project_ids:
-            if include_subprojects:
-                project_ids = project_ids_with_children(project_ids)
-            project_constraint = {"project": {"$in": project_ids}}
-        else:
-            project_constraint = {}
-
-        key_path = f"hyperparams.{ParameterKeyEscaper.escape(section)}.{ParameterKeyEscaper.escape(name)}"
-        last_updated_task = (
-            Task.objects(
-                **company_constraint,
-                **project_constraint,
-                **{f"{key_path.replace('.', '__')}__exists": True},
-            )
-            .only("last_update")
-            .order_by("-last_update")
-            .limit(1)
-            .first()
-        )
-        if not last_updated_task:
-            return 0, []
-
-        redis_key = f"hyperparam_values_{company_id}_{'_'.join(project_ids)}_{section}_{name}_{allow_public}"
-        last_update = last_updated_task.last_update or datetime.utcnow()
-        cached_res = self._get_cached_hyperparam_values(
-            key=redis_key, last_update=last_update
-        )
-        if cached_res:
-            return cached_res
-
-        max_values = config.get("services.tasks.hyperparam_values.max_count", 100)
-        pipeline = [
-            {
-                "$match": {
-                    **company_constraint,
-                    **project_constraint,
-                    key_path: {"$exists": True},
-                }
-            },
-            {"$project": {"value": f"${key_path}.value"}},
-            {"$group": {"_id": "$value"}},
-            {"$sort": {"_id": 1}},
-            {"$limit": max_values},
-            {
-                "$group": {
-                    "_id": 1,
-                    "total": {"$sum": 1},
-                    "results": {"$push": "$$ROOT._id"},
-                }
-            },
-        ]
-
-        result = next(Task.aggregate(pipeline, collation=Task._numeric_locale), None)
-        if not result:
-            return 0, []
-
-        total = int(result.get("total", 0))
-        values = result.get("results", [])
-
-        ttl = config.get("services.tasks.hyperparam_values.cache_ttl_sec", 86400)
-        cached = dict(last_update=last_update.timestamp(), total=total, values=values)
-        self.redis.setex(redis_key, ttl, json.dumps(cached))
-
-        return total, values
-
     @classmethod
     def dequeue_and_change_status(
         cls, task: Task, company_id: str, status_message: str, status_reason: str,
@@ -48,6 +48,7 @@ class Credentials(EmbeddedDocument):
     meta = {"strict": False}
     key = StringField(required=True)
     secret = StringField(required=True)
+    label = StringField()
     last_used = DateTimeField()

@@ -142,7 +142,9 @@ class GetMixin(PropsMixin):
         self.allow_empty = False

     def _get_op(self, v: str, translate: bool = False) -> Optional[str]:
-        op = v[len(self.op_prefix):] if v and v.startswith(self.op_prefix) else None
+        op = (
+            v[len(self.op_prefix) :] if v and v.startswith(self.op_prefix) else None
+        )
         if translate:
             tup = self._ops.get(op, None)
             return tup[0] if tup else None
@@ -177,7 +179,9 @@ class GetMixin(PropsMixin):
             "all": Q.AND,
         }
         data = (x for x in data if x is not None)
-        first_op = self._get_op(next(data, ""), translate=True) or self.default_mongo_op
+        first_op = (
+            self._get_op(next(data, ""), translate=True) or self.default_mongo_op
+        )
         return op_to_res.get(first_op, self.default_mongo_op)

     def get_actions(self, data: Sequence[str]) -> Dict[str, List[Union[str, None]]]:
@@ -202,13 +206,20 @@ class GetMixin(PropsMixin):
         id = StringField(primary_key=True)
         position = IntField(default=0)

-    cache_manager = RedisCacheManager(
-        state_class=GetManyScrollState,
-        redis=redman.connection("apiserver"),
-        expiration_interval=config.get(
-            "services._mongo.scroll_state_expiration_seconds", 600
-        ),
-    )
+    _cache_manager = None
+
+    @classmethod
+    def get_cache_manager(cls):
+        if not cls._cache_manager:
+            cls._cache_manager = RedisCacheManager(
+                state_class=cls.GetManyScrollState,
+                redis=redman.connection("apiserver"),
+                expiration_interval=config.get(
+                    "services._mongo.scroll_state_expiration_seconds", 600
+                ),
+            )
+
+        return cls._cache_manager

     @classmethod
     def get(
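Note: this hunk is the "Fix Redis required on mongodb initialization" part of the commit. The old class attribute opened a Redis connection as a side effect of importing the Mongo model base class; deferring construction to the first call breaks that import-time dependency. A standalone sketch of the pattern (ExpensiveClient stands in for the Redis-backed cache manager):

    class ExpensiveClient:
        def __init__(self):
            print("connecting...")  # now happens on first use, not at import

    class Mixin:
        _client = None  # was: _client = ExpensiveClient(), built at class definition

        @classmethod
        def get_client(cls):
            if not cls._client:
                cls._client = ExpensiveClient()
            return cls._client

    # Defining/importing Mixin no longer connects; the first call does.
    Mixin.get_client()
    Mixin.get_client()  # reuses the cached instance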
@@ -463,10 +474,7 @@ class GetMixin(PropsMixin):
         if not queries:
             q = RegexQ()
         else:
-            q = RegexQCombination(
-                operation=global_op,
-                children=queries
-            )
+            q = RegexQCombination(operation=global_op, children=queries)

         if not helper.allow_empty:
             return q
@@ -609,7 +617,7 @@ class GetMixin(PropsMixin):
         state: Optional[cls.GetManyScrollState] = None
         if "scroll_id" in query_dict:
             size = cls.validate_scroll_size(query_dict)
-            state = cls.cache_manager.get_or_create_state_core(
+            state = cls.get_cache_manager().get_or_create_state_core(
                 query_dict.get("scroll_id")
             )
             if query_dict.get("refresh_scroll"):
@@ -625,7 +633,7 @@ class GetMixin(PropsMixin):
         if not state:
             return
         state.position = query_dict[cls._start_key]
-        cls.cache_manager.set_state(state)
+        cls.get_cache_manager().set_state(state)
         if ret_params is not None:
             ret_params["scroll_id"] = state.id

@@ -4,7 +4,12 @@ from threading import Lock
 from typing import Sequence

 import six
-from mongoengine import EmbeddedDocumentField, EmbeddedDocumentListField
+from mongoengine import (
+    EmbeddedDocumentField,
+    EmbeddedDocumentListField,
+    EmbeddedDocument,
+    Document,
+)
 from mongoengine.base import get_document

 from apiserver.database.fields import (
@@ -25,6 +30,13 @@ class PropsMixin(object):
     __cached_dpath_computed_fields_lock = Lock()
     __cached_dpath_computed_fields = None

+    _document_classes = {}
+
+    def __init_subclass__(cls, **kwargs):
+        super().__init_subclass__(**kwargs)
+        if issubclass(cls, (Document, EmbeddedDocument)):
+            PropsMixin._document_classes[cls._class_name] = cls
+
     @classmethod
     def get_fields(cls):
         if cls.__cached_fields is None:
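Note: __init_subclass__ runs once for every subclass as it is defined, so the _document_classes registry is fully populated before resolve_doc (next hunk) consults it, leaving mongoengine's get_document as the fallback. A self-contained sketch of the idiom (Base plays the role of PropsMixin):

    class Base:
        _registry = {}

        def __init_subclass__(cls, **kwargs):
            # Invoked at class-definition time for every subclass.
            super().__init_subclass__(**kwargs)
            Base._registry[cls.__name__] = cls

    class Task(Base):
        pass

    class Model(Base):
        pass

    assert Base._registry["Task"] is Task
    print(sorted(Base._registry))  # ['Model', 'Task']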
@@ -57,8 +69,14 @@ class PropsMixin(object):
         def resolve_doc(v):
             if not isinstance(v, six.string_types):
                 return v
-            if v == 'self':
+
+            if v == "self":
                 return cls_.owner_document
+
+            doc_cls = PropsMixin._document_classes.get(v)
+            if doc_cls:
+                return doc_cls
+
             return get_document(v)

         fields = {k: resolve_doc(v) for k, v in res.items()}
@@ -72,7 +90,7 @@ class PropsMixin(object):
             ).document_type
             fields.update(
                 {
-                    '.'.join((field, subfield)): doc
+                    ".".join((field, subfield)): doc
                     for subfield, doc in PropsMixin._get_fields_with_attr(
                         embedded_doc_cls, attr
                     ).items()
@@ -80,10 +98,10 @@ class PropsMixin(object):
             )

         collect_embedded_docs(EmbeddedDocumentField, lambda x: x)
-        collect_embedded_docs(EmbeddedDocumentListField, attrgetter('field'))
-        collect_embedded_docs(LengthRangeEmbeddedDocumentListField, attrgetter('field'))
-        collect_embedded_docs(UniqueEmbeddedDocumentListField, attrgetter('field'))
-        collect_embedded_docs(EmbeddedDocumentSortedListField, attrgetter('field'))
+        collect_embedded_docs(EmbeddedDocumentListField, attrgetter("field"))
+        collect_embedded_docs(LengthRangeEmbeddedDocumentListField, attrgetter("field"))
+        collect_embedded_docs(UniqueEmbeddedDocumentListField, attrgetter("field"))
+        collect_embedded_docs(EmbeddedDocumentSortedListField, attrgetter("field"))

         return fields
@@ -94,7 +112,7 @@ class PropsMixin(object):
         for depth, part in enumerate(parts):
             if current_cls is None:
                 raise ValueError(
-                    'Invalid path (non-document encountered at %s)' % parts[: depth - 1]
+                    "Invalid path (non-document encountered at %s)" % parts[: depth - 1]
                 )
             try:
                 field_name, field = next(
@@ -103,7 +121,7 @@ class PropsMixin(object):
                     if k == part
                 )
             except StopIteration:
-                raise ValueError('Invalid field path %s' % parts[:depth])
+                raise ValueError("Invalid field path %s" % parts[:depth])

             translated_parts.append(part)

@@ -119,7 +137,7 @@ class PropsMixin(object):
             ),
         ):
             current_cls = field.field.document_type
-            translated_parts.append('*')
+            translated_parts.append("*")
         else:
             current_cls = None

@@ -128,7 +146,7 @@ class PropsMixin(object):
     @classmethod
     def get_reference_fields(cls):
         if cls.__cached_reference_fields is None:
-            fields = cls._get_fields_with_attr(cls, 'reference_field')
+            fields = cls._get_fields_with_attr(cls, "reference_field")
             cls.__cached_reference_fields = OrderedDict(sorted(fields.items()))
         return cls.__cached_reference_fields

@@ -143,12 +161,12 @@ class PropsMixin(object):
     @classmethod
     def get_exclude_fields(cls):
         if cls.__cached_exclude_fields is None:
-            fields = cls._get_fields_with_attr(cls, 'exclude_by_default')
+            fields = cls._get_fields_with_attr(cls, "exclude_by_default")
             cls.__cached_exclude_fields = OrderedDict(sorted(fields.items()))
         return cls.__cached_exclude_fields

     @classmethod
-    def get_dpath_translated_path(cls, path, separator='.'):
+    def get_dpath_translated_path(cls, path, separator="."):
         if cls.__cached_dpath_computed_fields is None:
             cls.__cached_dpath_computed_fields = {}
         if path not in cls.__cached_dpath_computed_fields:
@@ -1,7 +1,8 @@
 {
   "index_patterns": "events-*",
   "settings": {
-    "number_of_shards": 1
+    "number_of_shards": 1,
+    "number_of_replicas": 0
   },
   "mappings": {
     "_source": {
@@ -1,7 +1,8 @@
 {
   "index_patterns": "queue_metrics_*",
   "settings": {
-    "number_of_shards": 1
+    "number_of_shards": 1,
+    "number_of_replicas": 0
   },
   "mappings": {
     "_source": {
@@ -1,7 +1,8 @@
 {
   "index_patterns": "worker_stats_*",
   "settings": {
-    "number_of_shards": 1
+    "number_of_shards": 1,
+    "number_of_replicas": 0
   },
   "mappings": {
     "_source": {
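Note: with the Elasticsearch default of one replica, a single-node cluster can never assign the replica shards and its health stays yellow; pinning number_of_replicas to 0 in the templates keeps such deployments green. A hedged sketch of applying a template like the ones above with the elasticsearch-py 7.x client (the host and template name are placeholders):

    from elasticsearch import Elasticsearch

    es = Elasticsearch(["http://localhost:9200"])
    es.indices.put_template(
        name="events",
        body={
            "index_patterns": "events-*",
            "settings": {
                "number_of_shards": 1,
                # Single-node friendly: no forever-unassigned replica shards.
                "number_of_replicas": 0,
            },
        },
    )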
@@ -26,6 +26,10 @@ credentials {
             type: string
             description: Credentials secret key
         }
+        label {
+            type: string
+            description: Optional credentials label
+        }
     }
 }
 batch_operation {
@@ -15,6 +15,10 @@ _definitions {
             type: string
             description: ""
         }
+        label {
+            type: string
+            description: Optional credentials label
+        }
         last_used {
             type: string
             description: ""
@@ -222,6 +226,12 @@ create_credentials {
             }
         }
     }
+    "999.0": ${create_credentials."2.1"} {
+        request.properties.label {
+            type: string
+            description: Optional credentials label
+        }
+    }
 }

 get_credentials {
@@ -889,6 +889,13 @@ get_task_plots {
             }
         }
     }
+    "2.16": ${get_task_plots."2.14"} {
+        request.properties.no_scroll {
+            description: If true then no scroll is created. Suitable for one time calls
+            type: boolean
+            default: false
+        }
+    }
 }
 get_multi_task_plots {
     "2.1" {
@@ -939,6 +946,13 @@ get_multi_task_plots {
             }
         }
     }
+    "2.16": ${get_multi_task_plots."2.1"} {
+        request.properties.no_scroll {
+            description: If true then no scroll is created. Suitable for one time calls
+            type: boolean
+            default: false
+        }
+    }
 }
 get_vector_metrics_and_variants {
     "2.1" {
@@ -1218,6 +1232,13 @@ get_scalar_metric_data {
             }
         }
     }
+    "2.16": ${get_scalar_metric_data."2.1"} {
+        request.properties.no_scroll {
+            description: If true then no scroll is created. Suitable for one time calls
+            type: boolean
+            default: false
+        }
+    }
 }
 scalar_metrics_iter_raw {
     "2.16" {
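Note: with no_scroll, one-shot consumers can skip server-side scroll-context creation. A hedged example of exercising the new flag over HTTP (the server URL, credentials and task ID are placeholders; apiserver endpoints accept POSTed JSON bodies):

    import requests

    resp = requests.post(
        "http://localhost:8008/events.get_task_plots",
        json={"task": "<task-id>", "iters": 1, "no_scroll": True},
        auth=("<access_key>", "<secret_key>"),
    )
    resp.raise_for_status()
    # No scroll_id needs to be passed back for a follow-up call.
    print(resp.json()["data"].keys())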
@@ -545,6 +545,42 @@ get_all_ex {
                 type: boolean
                 default: true
             }
+            response {
+                properties {
+                    stats {
+                        properties {
+                            active.properties {
+                                total_tasks {
+                                    description: "Number of tasks"
+                                    type: integer
+                                }
+                                completed_tasks {
+                                    description: "Number of tasks completed in the last 24 hours"
+                                    type: integer
+                                }
+                                running_tasks {
+                                    description: "Number of running tasks"
+                                    type: integer
+                                }
+                            }
+                            archived.properties {
+                                total_tasks {
+                                    description: "Number of tasks"
+                                    type: integer
+                                }
+                                completed_tasks {
+                                    description: "Number of tasks completed in the last 24 hours"
+                                    type: integer
+                                }
+                                running_tasks {
+                                    description: "Number of running tasks"
+                                    type: integer
+                                }
+                            }
+                        }
+                    }
+                }
+            }
         }
     }
 update {
@@ -893,7 +929,55 @@ get_hyper_parameters {
         }
     }
 }
+get_model_metadata_keys {
+    "999.0" {
+        description: """Get a list of all metadata keys used in models within the given project."""
+        request {
+            type: object
+            required: [project]
+            properties {
+                project {
+                    description: "Project ID"
+                    type: string
+                }
+                include_subprojects {
+                    description: "If set to 'true' and the project field is set then the result includes metadata keys from the subproject models"
+                    type: boolean
+                    default: true
+                }
+
+                page {
+                    description: "Page number"
+                    default: 0
+                    type: integer
+                }
+                page_size {
+                    description: "Page size"
+                    default: 500
+                    type: integer
+                }
+            }
+        }
+        response {
+            type: object
+            properties {
+                keys {
+                    description: "A list of model keys"
+                    type: array
+                    items {type: string}
+                }
+                remaining {
+                    description: "Remaining results"
+                    type: integer
+                }
+                total {
+                    description: "Total number of results"
+                    type: integer
+                }
+            }
+        }
+    }
+}
 get_task_tags {
     "2.8" {
         description: "Get user and system tags used for the tasks under the specified projects"
@@ -534,7 +534,7 @@ _definitions {
     container {
         description: "Docker container parameters"
         type: object
-        additionalProperties { type: string }
+        additionalProperties { type: [string, null] }
     }
     models {
         description: "Task models"
@@ -981,7 +981,7 @@ clone {
     new_task_container {
         description: "The docker container properties for the new task. If not provided then taken from the original task"
         type: object
-        additionalProperties { type: string }
+        additionalProperties { type: [string, null] }
     }
 }
 }
@@ -1159,7 +1159,7 @@ create {
     container {
         description: "Docker container parameters"
         type: object
-        additionalProperties { type: string }
+        additionalProperties { type: [string, null] }
     }
 }
 }
@@ -1248,7 +1248,7 @@ validate {
     container {
         description: "Docker container parameters"
         type: object
-        additionalProperties { type: string }
+        additionalProperties { type: [string, null] }
     }
 }
 }
@@ -1410,7 +1410,7 @@ edit {
     container {
         description: "Docker container parameters"
         type: object
-        additionalProperties { type: string }
+        additionalProperties { type: [string, null] }
     }
     runtime {
         description: "Task runtime mapping"
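Note: widening additionalProperties to [string, null] lets a caller null out a single container field instead of resending the whole object. An illustrative payload that the previous string-only schema would have rejected (the task ID is a placeholder):

    payload = {
        "task": "<task-id>",
        "container": {
            "image": "nvidia/cuda:11.0-base",
            "arguments": None,  # now valid: explicitly clears this field
        },
    }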
@@ -13,6 +13,7 @@ from apiserver.apimodels.auth import (
     CredentialsResponse,
     RevokeCredentialsRequest,
     EditUserReq,
+    CreateCredentialsRequest,
 )
 from apiserver.apimodels.base import UpdateResponse
 from apiserver.bll.auth import AuthBLL
@@ -58,9 +59,13 @@ def get_token_for_user(call: APICall, _: str, request: GetTokenForUserRequest):
     """ Generates a token based on a requested user and company. INTERNAL. """
     if call.identity.role not in Role.get_system_roles():
         if call.identity.role != Role.admin and call.identity.user != request.user:
-            raise errors.bad_request.InvalidUserId("cannot generate token for another user")
+            raise errors.bad_request.InvalidUserId(
+                "cannot generate token for another user"
+            )
         if call.identity.company != request.company:
-            raise errors.bad_request.InvalidId("cannot generate token in another company")
+            raise errors.bad_request.InvalidId(
+                "cannot generate token in another company"
+            )

     call.result.data_model = AuthBLL.get_token_for_user(
         user_id=request.user,
@@ -93,7 +98,10 @@ def validate_token_endpoint(call: APICall, _, __):
 )
 def create_user(call: APICall, _, request: CreateUserRequest):
     """ Create a user from. INTERNAL. """
-    if call.identity.role not in Role.get_system_roles() and request.company != call.identity.company:
+    if (
+        call.identity.role not in Role.get_system_roles()
+        and request.company != call.identity.company
+    ):
         raise errors.bad_request.InvalidId("cannot create user in another company")

     user_id = AuthBLL.create_user(request=request, call=call)
@@ -101,7 +109,7 @@ def create_user(call: APICall, _, request: CreateUserRequest):


 @endpoint("auth.create_credentials", response_data_model=CreateCredentialsResponse)
-def create_credentials(call: APICall, _, __):
+def create_credentials(call: APICall, _, request: CreateCredentialsRequest):
     if _is_protected_user(call.identity.user):
         raise errors.bad_request.InvalidUserId("protected identity")

@@ -109,6 +117,7 @@ def create_credentials(call: APICall, _, __):
         user_id=call.identity.user,
         company_id=call.identity.company,
         role=call.identity.role,
+        label=request.label,
     )
     call.result.data_model = CreateCredentialsResponse(credentials=credentials)

@@ -151,7 +160,9 @@ def get_credentials(call: APICall, _, __):
     # we return ONLY the key IDs, never the secrets (want a secret? create new credentials)
     call.result.data_model = GetCredentialsResponse(
         credentials=[
-            CredentialsResponse(access_key=c.key, last_used=c.last_used)
+            CredentialsResponse(
+                access_key=c.key, last_used=c.last_used, label=c.label
+            )
             for c in user.credentials
         ]
     )
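Note: end to end, the label is accepted by auth.create_credentials, stored on the user's credentials, and echoed back by auth.get_credentials (which still never returns secrets). A hedged HTTP example (the server URL and the bootstrap key pair are placeholders):

    import requests

    server = "http://localhost:8008"
    auth = ("<access_key>", "<secret_key>")

    created = requests.post(
        f"{server}/auth.create_credentials", json={"label": "ci-agent"}, auth=auth
    ).json()

    listed = requests.post(f"{server}/auth.get_credentials", json={}, auth=auth).json()
    # Each entry now carries access_key, last_used and label.
    print(listed["data"]["credentials"])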
@@ -361,6 +361,7 @@ def get_scalar_metric_data(call, company_id, _):
     task_id = call.data["task"]
     metric = call.data["metric"]
     scroll_id = call.data.get("scroll_id")
+    no_scroll = call.data.get("no_scroll", False)

     task = task_bll.assert_exists(
         company_id, task_id, allow_public=True, only=("company", "company_origin")
@@ -372,6 +373,7 @@ def get_scalar_metric_data(call, company_id, _):
         sort=[{"iter": {"order": "desc"}}],
         metric=metric,
         scroll_id=scroll_id,
+        no_scroll=no_scroll,
     )

     call.result.data = dict(
@@ -494,6 +496,7 @@ def get_multi_task_plots(call, company_id, req_model):
     task_ids = call.data["tasks"]
     iters = call.data.get("iters", 1)
     scroll_id = call.data.get("scroll_id")
+    no_scroll = call.data.get("no_scroll", False)

     tasks = task_bll.assert_exists(
         company_id=call.identity.company,
@@ -515,6 +518,7 @@ def get_multi_task_plots(call, company_id, req_model):
         sort=[{"iter": {"order": "desc"}}],
         last_iter_count=iters,
         scroll_id=scroll_id,
+        no_scroll=no_scroll,
     )

     tasks = {t.id: t.name for t in tasks}
@@ -593,6 +597,7 @@ def get_task_plots(call, company_id, request: TaskPlotsRequest):
         sort=[{"iter": {"order": "desc"}}],
         last_iterations_per_plot=iters,
         scroll_id=scroll_id,
+        no_scroll=request.no_scroll,
         metric_variants=_get_metric_variants_from_request(request.metrics),
     )

@@ -7,7 +7,7 @@ from apiserver.apierrors import errors
 from apiserver.apierrors.errors.bad_request import InvalidProjectId
 from apiserver.apimodels.base import UpdateResponse, MakePublicRequest, IdResponse
 from apiserver.apimodels.projects import (
-    GetHyperParamRequest,
+    GetParamsRequest,
     ProjectTagsRequest,
     ProjectTaskParentsRequest,
     ProjectHyperparamValuesRequest,
@@ -19,12 +19,11 @@ from apiserver.apimodels.projects import (
     ProjectRequest,
 )
 from apiserver.bll.organization import OrgBLL, Tags
-from apiserver.bll.project import ProjectBLL
+from apiserver.bll.project import ProjectBLL, ProjectQueries
 from apiserver.bll.project.project_cleanup import (
     delete_project,
     validate_project_delete,
 )
-from apiserver.bll.task import TaskBLL
 from apiserver.database.errors import translate_errors_context
 from apiserver.database.model.project import Project
 from apiserver.database.utils import (
@@ -41,8 +40,8 @@ from apiserver.services.utils import (
 from apiserver.timing_context import TimingContext
 
 org_bll = OrgBLL()
-task_bll = TaskBLL()
 project_bll = ProjectBLL()
+project_queries = ProjectQueries()
 
 create_fields = {
     "name": None,
@@ -267,7 +266,7 @@ def get_unique_metric_variants(
     call: APICall, company_id: str, request: ProjectOrNoneRequest
 ):
 
-    metrics = task_bll.get_unique_metric_variants(
+    metrics = project_queries.get_unique_metric_variants(
         company_id,
         [request.project] if request.project else None,
         include_subprojects=request.include_subprojects,
@@ -276,14 +275,31 @@ def get_unique_metric_variants(
     call.result.data = {"metrics": metrics}
 
 
+@endpoint("projects.get_model_metadata_keys",)
+def get_model_metadata_keys(call: APICall, company_id: str, request: GetParamsRequest):
+    total, remaining, keys = project_queries.get_model_metadata_keys(
+        company_id,
+        project_ids=[request.project] if request.project else None,
+        include_subprojects=request.include_subprojects,
+        page=request.page,
+        page_size=request.page_size,
+    )
+
+    call.result.data = {
+        "total": total,
+        "remaining": remaining,
+        "keys": keys,
+    }
+
+
 @endpoint(
     "projects.get_hyper_parameters",
     min_version="2.9",
-    request_data_model=GetHyperParamRequest,
+    request_data_model=GetParamsRequest,
 )
-def get_hyper_parameters(call: APICall, company_id: str, request: GetHyperParamRequest):
+def get_hyper_parameters(call: APICall, company_id: str, request: GetParamsRequest):
 
-    total, remaining, parameters = TaskBLL.get_aggregated_project_parameters(
+    total, remaining, parameters = project_queries.get_aggregated_project_parameters(
         company_id,
         project_ids=[request.project] if request.project else None,
         include_subprojects=request.include_subprojects,
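The new endpoint reuses the paging fields of GetParamsRequest (page, page_size), so large key sets can be walked the same way hyperparameters are paged. A sketch, with api and project_id as placeholders:

# page through the distinct model metadata keys under a project subtree
page = 0
while True:
    res = api.projects.get_model_metadata_keys(
        project=project_id, include_subprojects=True, page=page, page_size=100
    )
    for key in res.keys:
        print(key)
    if not res.remaining:  # remaining == 0 means this was the last page
        break
    page += 1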
@@ -306,7 +322,7 @@ def get_hyper_parameters(call: APICall, company_id: str, request: GetParamsRequest):
 def get_hyperparam_values(
     call: APICall, company_id: str, request: ProjectHyperparamValuesRequest
 ):
-    total, values = task_bll.get_hyperparam_distinct_values(
+    total, values = project_queries.get_hyperparam_distinct_values(
         company_id,
         project_ids=request.projects,
         section=request.section,
@@ -6,9 +6,6 @@ from apiserver.tests.automated import TestService
 
 
 class TestQueueAndModelMetadata(TestService):
-    def setUp(self, version="2.13"):
-        super().setUp(version=version)
-
     meta1 = [{"key": "test_key", "type": "str", "value": "test_value"}]
 
     def test_queue_metas(self):
@@ -29,6 +26,23 @@ class TestQueueAndModelMetadata(TestService):
         self.api.models.edit(model=model_id, metadata=[self.meta1[0]])
         self._assertMeta(service=service, entity=entity, _id=model_id, meta=self.meta1)
 
+    def test_project_meta_query(self):
+        self._temp_model("TestMetadata", metadata=self.meta1)
+        project = self.temp_project(name="MetaParent")
+        self._temp_model(
+            "TestMetadata2",
+            project=project,
+            metadata=[
+                {"key": "test_key", "type": "str", "value": "test_value"},
+                {"key": "test_key2", "type": "str", "value": "test_value"},
+            ],
+        )
+        res = self.api.projects.get_model_metadata_keys()
+        self.assertTrue({"test_key", "test_key2"}.issubset(set(res["keys"])))
+        res = self.api.projects.get_model_metadata_keys(include_subprojects=False)
+        self.assertTrue("test_key" in res["keys"])
+        self.assertFalse("test_key2" in res["keys"])
+
     def _test_meta_operations(
         self, service: APIClient.Service, entity: str, _id: str,
     ):
@@ -72,3 +86,12 @@ class TestQueueAndModelMetadata(TestService):
         return self.create_temp(
             "models", uri="file://test", name=name, labels={}, **kwargs
         )
+
+    def temp_project(self, **kwargs) -> str:
+        self.update_missing(
+            kwargs,
+            name="Test models meta",
+            description="test",
+            delete_params=dict(force=True),
+        )
+        return self.create_temp("projects", **kwargs)
@@ -12,9 +12,6 @@ from apiserver.tests.automated import TestService
 
 
 class TestSubProjects(TestService):
-    def setUp(self, **kwargs):
-        super().setUp(version="2.13")
-
     def test_project_aggregations(self):
         """This test requires user with user_auth_only... credentials in db"""
         user2_client = APIClient(
@@ -203,6 +200,9 @@ class TestSubProjects(TestService):
         self.assertEqual(res1.stats["active"]["status_count"]["created"], 0)
         self.assertEqual(res1.stats["active"]["status_count"]["stopped"], 2)
         self.assertEqual(res1.stats["active"]["total_runtime"], 2)
+        self.assertEqual(res1.stats["active"]["completed_tasks"], 2)
+        self.assertEqual(res1.stats["active"]["total_tasks"], 2)
+        self.assertEqual(res1.stats["active"]["running_tasks"], 0)
         self.assertEqual(
             {sp.name for sp in res1.sub_projects},
             {
@@ -215,6 +215,9 @@ class TestSubProjects(TestService):
         self.assertEqual(res2.stats["active"]["status_count"]["created"], 0)
         self.assertEqual(res2.stats["active"]["status_count"]["stopped"], 0)
         self.assertEqual(res2.stats["active"]["total_runtime"], 0)
+        self.assertEqual(res2.stats["active"]["completed_tasks"], 0)
+        self.assertEqual(res2.stats["active"]["total_tasks"], 0)
+        self.assertEqual(res2.stats["active"]["running_tasks"], 0)
         self.assertEqual(res2.sub_projects, [])
 
     def _run_tasks(self, *tasks):
@@ -198,6 +198,9 @@ class TestTags(TestService):
     def assertProjectStats(self, project: AttrDict):
         self.assertEqual(set(project.stats.keys()), {"active"})
         self.assertAlmostEqual(project.stats.active.total_runtime, 1, places=0)
+        self.assertEqual(project.stats.active.completed_tasks, 1)
+        self.assertEqual(project.stats.active.total_tasks, 1)
+        self.assertEqual(project.stats.active.running_tasks, 0)
         for status, count in project.stats.active.status_count.items():
             self.assertEqual(count, 1 if status == "stopped" else 0)
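The assertions above pin down the richer per-project stats: completed_tasks, total_tasks and running_tasks now appear next to the existing status_count and total_runtime. A sketch of reading them from a client, assuming projects.get_all_ex accepts an include_stats flag as in other server versions (api and project_id are placeholders):

res = api.projects.get_all_ex(id=[project_id], include_stats=True)
stats = res.projects[0].stats["active"]  # stats are split by active/archived tasks

# counters added by this commit, next to the existing fields
print(stats["total_tasks"], stats["completed_tasks"], stats["running_tasks"])
print(stats["status_count"], stats["total_runtime"])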