Add support for credentials label

Support no_scroll in events.get_task_plots
Support better project stats
Fix Redis required on mongodb initialization
Update tests
This commit is contained in:
allegroai 2022-02-13 19:59:58 +02:00
parent 92fd98d5ad
commit 447adb9090
26 changed files with 624 additions and 296 deletions

View File

@ -75,6 +75,7 @@ class CreateUserResponse(Base):
class Credentials(Base): class Credentials(Base):
access_key = StringField(required=True) access_key = StringField(required=True)
secret_key = StringField(required=True) secret_key = StringField(required=True)
label = StringField()
class CredentialsResponse(Credentials): class CredentialsResponse(Credentials):
@ -82,6 +83,10 @@ class CredentialsResponse(Credentials):
last_used = DateTimeField(default=None) last_used = DateTimeField(default=None)
class CreateCredentialsRequest(Base):
label = StringField()
class CreateCredentialsResponse(Base): class CreateCredentialsResponse(Base):
credentials = EmbeddedField(Credentials) credentials = EmbeddedField(Credentials)

View File

@ -135,4 +135,5 @@ class TaskPlotsRequest(Base):
task: str = StringField(required=True) task: str = StringField(required=True)
iters: int = IntField(default=1) iters: int = IntField(default=1)
scroll_id: str = StringField() scroll_id: str = StringField()
no_scroll: bool = BoolField(default=False)
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants) metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)

View File

@ -27,7 +27,7 @@ class ProjectOrNoneRequest(models.Base):
include_subprojects = fields.BoolField(default=True) include_subprojects = fields.BoolField(default=True)
class GetHyperParamRequest(ProjectOrNoneRequest): class GetParamsRequest(ProjectOrNoneRequest):
page = fields.IntField(default=0) page = fields.IntField(default=0)
page_size = fields.IntField(default=500) page_size = fields.IntField(default=500)

View File

@ -2,7 +2,11 @@ from datetime import datetime
from apiserver import database from apiserver import database
from apiserver.apierrors import errors from apiserver.apierrors import errors
from apiserver.apimodels.auth import GetTokenResponse, CreateUserRequest, Credentials as CredModel from apiserver.apimodels.auth import (
GetTokenResponse,
CreateUserRequest,
Credentials as CredModel,
)
from apiserver.apimodels.users import CreateRequest as Users_CreateRequest from apiserver.apimodels.users import CreateRequest as Users_CreateRequest
from apiserver.bll.user import UserBLL from apiserver.bll.user import UserBLL
from apiserver.config_repo import config from apiserver.config_repo import config
@ -145,7 +149,7 @@ class AuthBLL:
@classmethod @classmethod
def create_credentials( def create_credentials(
cls, user_id: str, company_id: str, role: str = None cls, user_id: str, company_id: str, role: str = None, label: str = None,
) -> CredModel: ) -> CredModel:
with translate_errors_context(): with translate_errors_context():
@ -154,7 +158,9 @@ class AuthBLL:
if not user: if not user:
raise errors.bad_request.InvalidUserId(**query) raise errors.bad_request.InvalidUserId(**query)
cred = CredModel(access_key=get_client_id(), secret_key=get_secret_key()) cred = CredModel(
access_key=get_client_id(), secret_key=get_secret_key(), label=label
)
user.credentials.append( user.credentials.append(
Credentials(key=cred.access_key, secret=cred.secret_key) Credentials(key=cred.access_key, secret=cred.secret_key)
) )

View File

@ -534,6 +534,7 @@ class EventBLL(object):
sort=None, sort=None,
size: int = 500, size: int = 500,
scroll_id: str = None, scroll_id: str = None,
no_scroll: bool = False,
metric_variants: MetricVariants = None, metric_variants: MetricVariants = None,
): ):
if scroll_id == self.empty_scroll: if scroll_id == self.empty_scroll:
@ -611,7 +612,7 @@ class EventBLL(object):
event_type=event_type, event_type=event_type,
body=es_req, body=es_req,
ignore=404, ignore=404,
scroll="1h", **({} if no_scroll else {"scroll": "1h"}),
) )
events, total_events, next_scroll_id = self._get_events_from_es_res(es_res) events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
@ -680,6 +681,7 @@ class EventBLL(object):
sort=None, sort=None,
size=500, size=500,
scroll_id=None, scroll_id=None,
no_scroll=False,
) -> TaskEventsResult: ) -> TaskEventsResult:
if scroll_id == self.empty_scroll: if scroll_id == self.empty_scroll:
return TaskEventsResult() return TaskEventsResult()
@ -740,7 +742,7 @@ class EventBLL(object):
event_type=event_type, event_type=event_type,
body=es_req, body=es_req,
ignore=404, ignore=404,
scroll="1h", **({} if no_scroll else {"scroll": "1h"}),
) )
events, total_events, next_scroll_id = self._get_events_from_es_res(es_res) events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)

View File

@ -1,2 +1,3 @@
from .project_bll import ProjectBLL from .project_bll import ProjectBLL
from .project_queries import ProjectQueries
from .sub_projects import _ids_with_children as project_ids_with_children from .sub_projects import _ids_with_children as project_ids_with_children

View File

@ -1,6 +1,6 @@
import itertools import itertools
from collections import defaultdict from collections import defaultdict
from datetime import datetime from datetime import datetime, timedelta
from functools import reduce from functools import reduce
from itertools import groupby from itertools import groupby
from operator import itemgetter from operator import itemgetter
@ -306,6 +306,7 @@ class ProjectBLL:
return project return project
archived_tasks_cond = {"$in": [EntityVisibility.archived.value, "$system_tags"]} archived_tasks_cond = {"$in": [EntityVisibility.archived.value, "$system_tags"]}
visibility_states = [EntityVisibility.archived, EntityVisibility.active]
@classmethod @classmethod
def make_projects_get_all_pipelines( def make_projects_get_all_pipelines(
@ -367,6 +368,26 @@ class ProjectBLL:
}, },
] ]
def completed_after_subquery(additional_cond, time_thresh: datetime):
return {
# the sum of
"$sum": {
# for each task
"$cond": {
# if completed after the time_thresh
"if": {
"$and": [
"$completed",
{"$gt": ["$completed", time_thresh]},
additional_cond,
]
},
"then": 1,
"else": 0,
}
}
}
def runtime_subquery(additional_cond): def runtime_subquery(additional_cond):
return { return {
# the sum of # the sum of
@ -397,16 +418,19 @@ class ProjectBLL:
} }
group_step = {"_id": "$project"} group_step = {"_id": "$project"}
time_thresh = datetime.utcnow() - timedelta(hours=24)
for state in EntityVisibility: for state in cls.visibility_states:
if specific_state and state != specific_state: if specific_state and state != specific_state:
continue continue
if state == EntityVisibility.active: cond = (
group_step[state.value] = runtime_subquery( cls.archived_tasks_cond
{"$not": cls.archived_tasks_cond} if state == EntityVisibility.archived
) else {"$not": cls.archived_tasks_cond}
elif state == EntityVisibility.archived: )
group_step[state.value] = runtime_subquery(cls.archived_tasks_cond) group_step[state.value] = runtime_subquery(cond)
group_step[f"{state.value}_recently_completed"] = completed_after_subquery(
cond, time_thresh=time_thresh
)
runtime_pipeline = [ runtime_pipeline = [
# only count run time for these types of tasks # only count run time for these types of tasks
@ -534,15 +558,24 @@ class ProjectBLL:
) )
def get_status_counts(project_id, section): def get_status_counts(project_id, section):
project_runtime = runtime.get(project_id, {})
project_section_statuses = nested_get(
status_count, (project_id, section), default=default_counts
)
return { return {
"total_runtime": nested_get(runtime, (project_id, section), default=0), "status_count": project_section_statuses,
"status_count": nested_get( "running_tasks": project_section_statuses.get(TaskStatus.in_progress),
status_count, (project_id, section), default=default_counts "total_tasks": sum(project_section_statuses.values()),
"total_runtime": project_runtime.get(section, 0),
"completed_tasks": project_runtime.get(
f"{section}_recently_completed", 0
), ),
} }
report_for_states = [ report_for_states = [
s for s in EntityVisibility if not specific_state or specific_state == s s
for s in cls.visibility_states
if not specific_state or specific_state == s
] ]
stats = { stats = {

View File

@ -0,0 +1,289 @@
import json
from collections import OrderedDict
from datetime import datetime, timedelta
from typing import (
Sequence,
Optional,
Tuple,
)
from redis import StrictRedis
from apiserver.config_repo import config
from apiserver.redis_manager import redman
from apiserver.utilities.dicts import nested_get
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
from .sub_projects import _ids_with_children
from ...database.model.model import Model
from ...database.model.task.task import Task
log = config.logger(__file__)
class ProjectQueries:
def __init__(self, redis=None):
self.redis: StrictRedis = redis or redman.connection("apiserver")
@staticmethod
def _get_project_constraint(
project_ids: Sequence[str], include_subprojects: bool
) -> dict:
if include_subprojects:
if project_ids is None:
return {}
project_ids = _ids_with_children(project_ids)
return {"project": {"$in": project_ids if project_ids is not None else [None]}}
@staticmethod
def _get_company_constraint(company_id: str, allow_public: bool = True) -> dict:
if allow_public:
return {"company": {"$in": [None, "", company_id]}}
return {"company": company_id}
@classmethod
def get_aggregated_project_parameters(
cls,
company_id,
project_ids: Sequence[str],
include_subprojects: bool,
page: int = 0,
page_size: int = 500,
) -> Tuple[int, int, Sequence[dict]]:
page = max(0, page)
page_size = max(1, page_size)
pipeline = [
{
"$match": {
**cls._get_company_constraint(company_id),
**cls._get_project_constraint(project_ids, include_subprojects),
"hyperparams": {"$exists": True, "$gt": {}},
}
},
{"$project": {"sections": {"$objectToArray": "$hyperparams"}}},
{"$unwind": "$sections"},
{
"$project": {
"section": "$sections.k",
"names": {"$objectToArray": "$sections.v"},
}
},
{"$unwind": "$names"},
{"$group": {"_id": {"section": "$section", "name": "$names.k"}}},
{"$sort": OrderedDict({"_id.section": 1, "_id.name": 1})},
{"$skip": page * page_size},
{"$limit": page_size},
{
"$group": {
"_id": 1,
"total": {"$sum": 1},
"results": {"$push": "$$ROOT"},
}
},
]
result = next(Task.aggregate(pipeline), None)
total = 0
remaining = 0
results = []
if result:
total = int(result.get("total", -1))
results = [
{
"section": ParameterKeyEscaper.unescape(
nested_get(r, ("_id", "section"))
),
"name": ParameterKeyEscaper.unescape(
nested_get(r, ("_id", "name"))
),
}
for r in result.get("results", [])
]
remaining = max(0, total - (len(results) + page * page_size))
return total, remaining, results
HyperParamValues = Tuple[int, Sequence[str]]
def _get_cached_hyperparam_values(
self, key: str, last_update: datetime
) -> Optional[HyperParamValues]:
allowed_delta = timedelta(
seconds=config.get(
"services.tasks.hyperparam_values.cache_allowed_outdate_sec", 60
)
)
try:
cached = self.redis.get(key)
if not cached:
return
data = json.loads(cached)
cached_last_update = datetime.fromtimestamp(data["last_update"])
if (last_update - cached_last_update) < allowed_delta:
return data["total"], data["values"]
except Exception as ex:
log.error(f"Error retrieving hyperparam cached values: {str(ex)}")
def get_hyperparam_distinct_values(
self,
company_id: str,
project_ids: Sequence[str],
section: str,
name: str,
include_subprojects: bool,
allow_public: bool = True,
) -> HyperParamValues:
company_constraint = self._get_company_constraint(company_id, allow_public)
project_constraint = self._get_project_constraint(
project_ids, include_subprojects
)
key_path = f"hyperparams.{ParameterKeyEscaper.escape(section)}.{ParameterKeyEscaper.escape(name)}"
last_updated_task = (
Task.objects(
**company_constraint,
**project_constraint,
**{f"{key_path.replace('.', '__')}__exists": True},
)
.only("last_update")
.order_by("-last_update")
.limit(1)
.first()
)
if not last_updated_task:
return 0, []
redis_key = f"hyperparam_values_{company_id}_{'_'.join(project_ids)}_{section}_{name}_{allow_public}"
last_update = last_updated_task.last_update or datetime.utcnow()
cached_res = self._get_cached_hyperparam_values(
key=redis_key, last_update=last_update
)
if cached_res:
return cached_res
max_values = config.get("services.tasks.hyperparam_values.max_count", 100)
pipeline = [
{
"$match": {
**company_constraint,
**project_constraint,
key_path: {"$exists": True},
}
},
{"$project": {"value": f"${key_path}.value"}},
{"$group": {"_id": "$value"}},
{"$sort": {"_id": 1}},
{"$limit": max_values},
{
"$group": {
"_id": 1,
"total": {"$sum": 1},
"results": {"$push": "$$ROOT._id"},
}
},
]
result = next(Task.aggregate(pipeline, collation=Task._numeric_locale), None)
if not result:
return 0, []
total = int(result.get("total", 0))
values = result.get("results", [])
ttl = config.get("services.tasks.hyperparam_values.cache_ttl_sec", 86400)
cached = dict(last_update=last_update.timestamp(), total=total, values=values)
self.redis.setex(redis_key, ttl, json.dumps(cached))
return total, values
@classmethod
def get_unique_metric_variants(
cls, company_id, project_ids: Sequence[str], include_subprojects: bool
):
pipeline = [
{
"$match": {
**cls._get_company_constraint(company_id),
**cls._get_project_constraint(project_ids, include_subprojects),
}
},
{"$project": {"metrics": {"$objectToArray": "$last_metrics"}}},
{"$unwind": "$metrics"},
{
"$project": {
"metric": "$metrics.k",
"variants": {"$objectToArray": "$metrics.v"},
}
},
{"$unwind": "$variants"},
{
"$group": {
"_id": {
"metric": "$variants.v.metric",
"variant": "$variants.v.variant",
},
"metrics": {
"$addToSet": {
"metric": "$variants.v.metric",
"metric_hash": "$metric",
"variant": "$variants.v.variant",
"variant_hash": "$variants.k",
}
},
}
},
{"$sort": OrderedDict({"_id.metric": 1, "_id.variant": 1})},
]
result = Task.aggregate(pipeline)
return [r["metrics"][0] for r in result]
@classmethod
def get_model_metadata_keys(
cls,
company_id,
project_ids: Sequence[str],
include_subprojects: bool,
page: int = 0,
page_size: int = 500,
) -> Tuple[int, int, Sequence[dict]]:
page = max(0, page)
page_size = max(1, page_size)
pipeline = [
{
"$match": {
**cls._get_company_constraint(company_id),
**cls._get_project_constraint(project_ids, include_subprojects),
"metadata": {"$exists": True, "$ne": []},
}
},
{"$project": {"metadata": 1}},
{"$unwind": "$metadata"},
{"$group": {"_id": "$metadata.key"}},
{"$sort": {"_id": 1}},
{"$skip": page * page_size},
{"$limit": page_size},
{
"$group": {
"_id": 1,
"total": {"$sum": 1},
"results": {"$push": "$$ROOT"},
}
},
]
result = next(Model.aggregate(pipeline), None)
total = 0
remaining = 0
results = []
if result:
total = int(result.get("total", -1))
results = [r.get("_id") for r in result.get("results", [])]
remaining = max(0, total - (len(results) + page * page_size))
return total, remaining, results

View File

@ -1,9 +1,6 @@
import json from datetime import datetime
from collections import OrderedDict
from datetime import datetime, timedelta
from typing import Collection, Sequence, Tuple, Any, Optional, Dict from typing import Collection, Sequence, Tuple, Any, Optional, Dict
import dpath
import six import six
from mongoengine import Q from mongoengine import Q
from redis import StrictRedis from redis import StrictRedis
@ -14,7 +11,7 @@ from apiserver.apierrors import errors
from apiserver.apimodels.tasks import TaskInputModel from apiserver.apimodels.tasks import TaskInputModel
from apiserver.bll.queue import QueueBLL from apiserver.bll.queue import QueueBLL
from apiserver.bll.organization import OrgBLL, Tags from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL, project_ids_with_children from apiserver.bll.project import ProjectBLL
from apiserver.config_repo import config from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context from apiserver.database.errors import translate_errors_context
from apiserver.database.model.model import Model from apiserver.database.model.model import Model
@ -39,7 +36,6 @@ from apiserver.redis_manager import redman
from apiserver.service_repo import APICall from apiserver.service_repo import APICall
from apiserver.services.utils import validate_tags, escape_dict_field, escape_dict from apiserver.services.utils import validate_tags, escape_dict_field, escape_dict
from apiserver.timing_context import TimingContext from apiserver.timing_context import TimingContext
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
from .artifacts import artifacts_prepare_for_save from .artifacts import artifacts_prepare_for_save
from .param_utils import params_prepare_for_save from .param_utils import params_prepare_for_save
from .utils import ( from .utils import (
@ -350,54 +346,6 @@ class TaskBLL:
if validate_models: if validate_models:
cls.validate_input_models(task) cls.validate_input_models(task)
@staticmethod
def get_unique_metric_variants(
company_id, project_ids: Sequence[str], include_subprojects: bool
):
if project_ids:
if include_subprojects:
project_ids = project_ids_with_children(project_ids)
project_constraint = {"project": {"$in": project_ids}}
else:
project_constraint = {}
pipeline = [
{
"$match": dict(
company={"$in": [None, "", company_id]}, **project_constraint,
)
},
{"$project": {"metrics": {"$objectToArray": "$last_metrics"}}},
{"$unwind": "$metrics"},
{
"$project": {
"metric": "$metrics.k",
"variants": {"$objectToArray": "$metrics.v"},
}
},
{"$unwind": "$variants"},
{
"$group": {
"_id": {
"metric": "$variants.v.metric",
"variant": "$variants.v.variant",
},
"metrics": {
"$addToSet": {
"metric": "$variants.v.metric",
"metric_hash": "$metric",
"variant": "$variants.v.variant",
"variant_hash": "$variants.k",
}
},
}
},
{"$sort": OrderedDict({"_id.metric": 1, "_id.variant": 1})},
]
with translate_errors_context():
result = Task.aggregate(pipeline)
return [r["metrics"][0] for r in result]
@staticmethod @staticmethod
def set_last_update( def set_last_update(
task_ids: Collection[str], task_ids: Collection[str],
@ -494,173 +442,6 @@ class TaskBLL:
**extra_updates, **extra_updates,
) )
@staticmethod
def get_aggregated_project_parameters(
company_id,
project_ids: Sequence[str],
include_subprojects: bool,
page: int = 0,
page_size: int = 500,
) -> Tuple[int, int, Sequence[dict]]:
if project_ids:
if include_subprojects:
project_ids = project_ids_with_children(project_ids)
project_constraint = {"project": {"$in": project_ids}}
else:
project_constraint = {}
page = max(0, page)
page_size = max(1, page_size)
pipeline = [
{
"$match": {
"company": {"$in": [None, "", company_id]},
"hyperparams": {"$exists": True, "$gt": {}},
**project_constraint,
}
},
{"$project": {"sections": {"$objectToArray": "$hyperparams"}}},
{"$unwind": "$sections"},
{
"$project": {
"section": "$sections.k",
"names": {"$objectToArray": "$sections.v"},
}
},
{"$unwind": "$names"},
{"$group": {"_id": {"section": "$section", "name": "$names.k"}}},
{"$sort": OrderedDict({"_id.section": 1, "_id.name": 1})},
{"$skip": page * page_size},
{"$limit": page_size},
{
"$group": {
"_id": 1,
"total": {"$sum": 1},
"results": {"$push": "$$ROOT"},
}
},
]
result = next(Task.aggregate(pipeline), None)
total = 0
remaining = 0
results = []
if result:
total = int(result.get("total", -1))
results = [
{
"section": ParameterKeyEscaper.unescape(
dpath.get(r, "_id/section")
),
"name": ParameterKeyEscaper.unescape(dpath.get(r, "_id/name")),
}
for r in result.get("results", [])
]
remaining = max(0, total - (len(results) + page * page_size))
return total, remaining, results
HyperParamValues = Tuple[int, Sequence[str]]
def _get_cached_hyperparam_values(
self, key: str, last_update: datetime
) -> Optional[HyperParamValues]:
allowed_delta = timedelta(
seconds=config.get(
"services.tasks.hyperparam_values.cache_allowed_outdate_sec", 60
)
)
try:
cached = self.redis.get(key)
if not cached:
return
data = json.loads(cached)
cached_last_update = datetime.fromtimestamp(data["last_update"])
if (last_update - cached_last_update) < allowed_delta:
return data["total"], data["values"]
except Exception as ex:
log.error(f"Error retrieving hyperparam cached values: {str(ex)}")
def get_hyperparam_distinct_values(
self,
company_id: str,
project_ids: Sequence[str],
section: str,
name: str,
include_subprojects: bool,
allow_public: bool = True,
) -> HyperParamValues:
if allow_public:
company_constraint = {"company": {"$in": [None, "", company_id]}}
else:
company_constraint = {"company": company_id}
if project_ids:
if include_subprojects:
project_ids = project_ids_with_children(project_ids)
project_constraint = {"project": {"$in": project_ids}}
else:
project_constraint = {}
key_path = f"hyperparams.{ParameterKeyEscaper.escape(section)}.{ParameterKeyEscaper.escape(name)}"
last_updated_task = (
Task.objects(
**company_constraint,
**project_constraint,
**{f"{key_path.replace('.', '__')}__exists": True},
)
.only("last_update")
.order_by("-last_update")
.limit(1)
.first()
)
if not last_updated_task:
return 0, []
redis_key = f"hyperparam_values_{company_id}_{'_'.join(project_ids)}_{section}_{name}_{allow_public}"
last_update = last_updated_task.last_update or datetime.utcnow()
cached_res = self._get_cached_hyperparam_values(
key=redis_key, last_update=last_update
)
if cached_res:
return cached_res
max_values = config.get("services.tasks.hyperparam_values.max_count", 100)
pipeline = [
{
"$match": {
**company_constraint,
**project_constraint,
key_path: {"$exists": True},
}
},
{"$project": {"value": f"${key_path}.value"}},
{"$group": {"_id": "$value"}},
{"$sort": {"_id": 1}},
{"$limit": max_values},
{
"$group": {
"_id": 1,
"total": {"$sum": 1},
"results": {"$push": "$$ROOT._id"},
}
},
]
result = next(Task.aggregate(pipeline, collation=Task._numeric_locale), None)
if not result:
return 0, []
total = int(result.get("total", 0))
values = result.get("results", [])
ttl = config.get("services.tasks.hyperparam_values.cache_ttl_sec", 86400)
cached = dict(last_update=last_update.timestamp(), total=total, values=values)
self.redis.setex(redis_key, ttl, json.dumps(cached))
return total, values
@classmethod @classmethod
def dequeue_and_change_status( def dequeue_and_change_status(
cls, task: Task, company_id: str, status_message: str, status_reason: str, cls, task: Task, company_id: str, status_message: str, status_reason: str,

View File

@ -48,6 +48,7 @@ class Credentials(EmbeddedDocument):
meta = {"strict": False} meta = {"strict": False}
key = StringField(required=True) key = StringField(required=True)
secret = StringField(required=True) secret = StringField(required=True)
label = StringField()
last_used = DateTimeField() last_used = DateTimeField()

View File

@ -142,7 +142,9 @@ class GetMixin(PropsMixin):
self.allow_empty = False self.allow_empty = False
def _get_op(self, v: str, translate: bool = False) -> Optional[str]: def _get_op(self, v: str, translate: bool = False) -> Optional[str]:
op = v[len(self.op_prefix):] if v and v.startswith(self.op_prefix) else None op = (
v[len(self.op_prefix) :] if v and v.startswith(self.op_prefix) else None
)
if translate: if translate:
tup = self._ops.get(op, None) tup = self._ops.get(op, None)
return tup[0] if tup else None return tup[0] if tup else None
@ -177,7 +179,9 @@ class GetMixin(PropsMixin):
"all": Q.AND, "all": Q.AND,
} }
data = (x for x in data if x is not None) data = (x for x in data if x is not None)
first_op = self._get_op(next(data, ""), translate=True) or self.default_mongo_op first_op = (
self._get_op(next(data, ""), translate=True) or self.default_mongo_op
)
return op_to_res.get(first_op, self.default_mongo_op) return op_to_res.get(first_op, self.default_mongo_op)
def get_actions(self, data: Sequence[str]) -> Dict[str, List[Union[str, None]]]: def get_actions(self, data: Sequence[str]) -> Dict[str, List[Union[str, None]]]:
@ -202,13 +206,20 @@ class GetMixin(PropsMixin):
id = StringField(primary_key=True) id = StringField(primary_key=True)
position = IntField(default=0) position = IntField(default=0)
cache_manager = RedisCacheManager( _cache_manager = None
state_class=GetManyScrollState,
redis=redman.connection("apiserver"), @classmethod
expiration_interval=config.get( def get_cache_manager(cls):
"services._mongo.scroll_state_expiration_seconds", 600 if not cls._cache_manager:
), cls._cache_manager = RedisCacheManager(
) state_class=cls.GetManyScrollState,
redis=redman.connection("apiserver"),
expiration_interval=config.get(
"services._mongo.scroll_state_expiration_seconds", 600
),
)
return cls._cache_manager
@classmethod @classmethod
def get( def get(
@ -463,10 +474,7 @@ class GetMixin(PropsMixin):
if not queries: if not queries:
q = RegexQ() q = RegexQ()
else: else:
q = RegexQCombination( q = RegexQCombination(operation=global_op, children=queries)
operation=global_op,
children=queries
)
if not helper.allow_empty: if not helper.allow_empty:
return q return q
@ -609,7 +617,7 @@ class GetMixin(PropsMixin):
state: Optional[cls.GetManyScrollState] = None state: Optional[cls.GetManyScrollState] = None
if "scroll_id" in query_dict: if "scroll_id" in query_dict:
size = cls.validate_scroll_size(query_dict) size = cls.validate_scroll_size(query_dict)
state = cls.cache_manager.get_or_create_state_core( state = cls.get_cache_manager().get_or_create_state_core(
query_dict.get("scroll_id") query_dict.get("scroll_id")
) )
if query_dict.get("refresh_scroll"): if query_dict.get("refresh_scroll"):
@ -625,7 +633,7 @@ class GetMixin(PropsMixin):
if not state: if not state:
return return
state.position = query_dict[cls._start_key] state.position = query_dict[cls._start_key]
cls.cache_manager.set_state(state) cls.get_cache_manager().set_state(state)
if ret_params is not None: if ret_params is not None:
ret_params["scroll_id"] = state.id ret_params["scroll_id"] = state.id

View File

@ -4,7 +4,12 @@ from threading import Lock
from typing import Sequence from typing import Sequence
import six import six
from mongoengine import EmbeddedDocumentField, EmbeddedDocumentListField from mongoengine import (
EmbeddedDocumentField,
EmbeddedDocumentListField,
EmbeddedDocument,
Document,
)
from mongoengine.base import get_document from mongoengine.base import get_document
from apiserver.database.fields import ( from apiserver.database.fields import (
@ -25,6 +30,13 @@ class PropsMixin(object):
__cached_dpath_computed_fields_lock = Lock() __cached_dpath_computed_fields_lock = Lock()
__cached_dpath_computed_fields = None __cached_dpath_computed_fields = None
_document_classes = {}
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
if issubclass(cls, (Document, EmbeddedDocument)):
PropsMixin._document_classes[cls._class_name] = cls
@classmethod @classmethod
def get_fields(cls): def get_fields(cls):
if cls.__cached_fields is None: if cls.__cached_fields is None:
@ -57,8 +69,14 @@ class PropsMixin(object):
def resolve_doc(v): def resolve_doc(v):
if not isinstance(v, six.string_types): if not isinstance(v, six.string_types):
return v return v
if v == 'self':
if v == "self":
return cls_.owner_document return cls_.owner_document
doc_cls = PropsMixin._document_classes.get(v)
if doc_cls:
return doc_cls
return get_document(v) return get_document(v)
fields = {k: resolve_doc(v) for k, v in res.items()} fields = {k: resolve_doc(v) for k, v in res.items()}
@ -72,7 +90,7 @@ class PropsMixin(object):
).document_type ).document_type
fields.update( fields.update(
{ {
'.'.join((field, subfield)): doc ".".join((field, subfield)): doc
for subfield, doc in PropsMixin._get_fields_with_attr( for subfield, doc in PropsMixin._get_fields_with_attr(
embedded_doc_cls, attr embedded_doc_cls, attr
).items() ).items()
@ -80,10 +98,10 @@ class PropsMixin(object):
) )
collect_embedded_docs(EmbeddedDocumentField, lambda x: x) collect_embedded_docs(EmbeddedDocumentField, lambda x: x)
collect_embedded_docs(EmbeddedDocumentListField, attrgetter('field')) collect_embedded_docs(EmbeddedDocumentListField, attrgetter("field"))
collect_embedded_docs(LengthRangeEmbeddedDocumentListField, attrgetter('field')) collect_embedded_docs(LengthRangeEmbeddedDocumentListField, attrgetter("field"))
collect_embedded_docs(UniqueEmbeddedDocumentListField, attrgetter('field')) collect_embedded_docs(UniqueEmbeddedDocumentListField, attrgetter("field"))
collect_embedded_docs(EmbeddedDocumentSortedListField, attrgetter('field')) collect_embedded_docs(EmbeddedDocumentSortedListField, attrgetter("field"))
return fields return fields
@ -94,7 +112,7 @@ class PropsMixin(object):
for depth, part in enumerate(parts): for depth, part in enumerate(parts):
if current_cls is None: if current_cls is None:
raise ValueError( raise ValueError(
'Invalid path (non-document encountered at %s)' % parts[: depth - 1] "Invalid path (non-document encountered at %s)" % parts[: depth - 1]
) )
try: try:
field_name, field = next( field_name, field = next(
@ -103,7 +121,7 @@ class PropsMixin(object):
if k == part if k == part
) )
except StopIteration: except StopIteration:
raise ValueError('Invalid field path %s' % parts[:depth]) raise ValueError("Invalid field path %s" % parts[:depth])
translated_parts.append(part) translated_parts.append(part)
@ -119,7 +137,7 @@ class PropsMixin(object):
), ),
): ):
current_cls = field.field.document_type current_cls = field.field.document_type
translated_parts.append('*') translated_parts.append("*")
else: else:
current_cls = None current_cls = None
@ -128,7 +146,7 @@ class PropsMixin(object):
@classmethod @classmethod
def get_reference_fields(cls): def get_reference_fields(cls):
if cls.__cached_reference_fields is None: if cls.__cached_reference_fields is None:
fields = cls._get_fields_with_attr(cls, 'reference_field') fields = cls._get_fields_with_attr(cls, "reference_field")
cls.__cached_reference_fields = OrderedDict(sorted(fields.items())) cls.__cached_reference_fields = OrderedDict(sorted(fields.items()))
return cls.__cached_reference_fields return cls.__cached_reference_fields
@ -143,12 +161,12 @@ class PropsMixin(object):
@classmethod @classmethod
def get_exclude_fields(cls): def get_exclude_fields(cls):
if cls.__cached_exclude_fields is None: if cls.__cached_exclude_fields is None:
fields = cls._get_fields_with_attr(cls, 'exclude_by_default') fields = cls._get_fields_with_attr(cls, "exclude_by_default")
cls.__cached_exclude_fields = OrderedDict(sorted(fields.items())) cls.__cached_exclude_fields = OrderedDict(sorted(fields.items()))
return cls.__cached_exclude_fields return cls.__cached_exclude_fields
@classmethod @classmethod
def get_dpath_translated_path(cls, path, separator='.'): def get_dpath_translated_path(cls, path, separator="."):
if cls.__cached_dpath_computed_fields is None: if cls.__cached_dpath_computed_fields is None:
cls.__cached_dpath_computed_fields = {} cls.__cached_dpath_computed_fields = {}
if path not in cls.__cached_dpath_computed_fields: if path not in cls.__cached_dpath_computed_fields:

View File

@ -1,7 +1,8 @@
{ {
"index_patterns": "events-*", "index_patterns": "events-*",
"settings": { "settings": {
"number_of_shards": 1 "number_of_shards": 1,
"number_of_replicas": 0
}, },
"mappings": { "mappings": {
"_source": { "_source": {

View File

@ -1,7 +1,8 @@
{ {
"index_patterns": "queue_metrics_*", "index_patterns": "queue_metrics_*",
"settings": { "settings": {
"number_of_shards": 1 "number_of_shards": 1,
"number_of_replicas": 0
}, },
"mappings": { "mappings": {
"_source": { "_source": {

View File

@ -1,7 +1,8 @@
{ {
"index_patterns": "worker_stats_*", "index_patterns": "worker_stats_*",
"settings": { "settings": {
"number_of_shards": 1 "number_of_shards": 1,
"number_of_replicas": 0
}, },
"mappings": { "mappings": {
"_source": { "_source": {

View File

@ -26,6 +26,10 @@ credentials {
type: string type: string
description: Credentials secret key description: Credentials secret key
} }
label {
type: string
description: Optional credentials label
}
} }
} }
batch_operation { batch_operation {

View File

@ -15,6 +15,10 @@ _definitions {
type: string type: string
description: "" description: ""
} }
label {
type: string
description: Optional credentials label
}
last_used { last_used {
type: string type: string
description: "" description: ""
@ -222,6 +226,12 @@ create_credentials {
} }
} }
} }
"999.0": ${create_credentials."2.1"} {
request.properties.label {
type: string
description: Optional credentials label
}
}
} }
get_credentials { get_credentials {

View File

@ -889,6 +889,13 @@ get_task_plots {
} }
} }
} }
"2.16": ${get_task_plots."2.14"} {
request.properties.no_scroll {
description: If true then no scroll is created. Suitable for one time calls
type: boolean
default: false
}
}
} }
get_multi_task_plots { get_multi_task_plots {
"2.1" { "2.1" {
@ -939,6 +946,13 @@ get_multi_task_plots {
} }
} }
} }
"2.16": ${get_multi_task_plots."2.1"} {
request.properties.no_scroll {
description: If true then no scroll is created. Suitable for one time calls
type: boolean
default: false
}
}
} }
get_vector_metrics_and_variants { get_vector_metrics_and_variants {
"2.1" { "2.1" {
@ -1218,6 +1232,13 @@ get_scalar_metric_data {
} }
} }
} }
"2.16": ${get_scalar_metric_data."2.1"} {
request.properties.no_scroll {
description: If true then no scroll is created. Suitable for one time calls
type: boolean
default: false
}
}
} }
scalar_metrics_iter_raw { scalar_metrics_iter_raw {
"2.16" { "2.16" {

View File

@ -545,6 +545,42 @@ get_all_ex {
type: boolean type: boolean
default: true default: true
} }
response {
properties {
stats {
properties {
active.properties {
total_tasks {
description: "Number of tasks"
type: integer
}
completed_tasks {
description: "Number of tasks completed in the last 24 hours"
type: integer
}
running_tasks {
description: "Number of running tasks"
type: integer
}
}
archived.properties {
total_tasks {
description: "Number of tasks"
type: integer
}
completed_tasks {
description: "Number of tasks completed in the last 24 hours"
type: integer
}
running_tasks {
description: "Number of running tasks"
type: integer
}
}
}
}
}
}
} }
} }
update { update {
@ -893,7 +929,55 @@ get_hyper_parameters {
} }
} }
} }
get_model_metadata_keys {
"999.0" {
description: """Get a list of all metadata keys used in models within the given project."""
request {
type: object
required: [project]
properties {
project {
description: "Project ID"
type: string
}
include_subprojects {
description: "If set to 'true' and the project field is set then the result includes metadate keys from the subproject models"
type: boolean
default: true
}
page {
description: "Page number"
default: 0
type: integer
}
page_size {
description: "Page size"
default: 500
type: integer
}
}
}
response {
type: object
properties {
keys {
description: "A list of model keys"
type: array
items {type: string}
}
remaining {
description: "Remaining results"
type: integer
}
total {
description: "Total number of results"
type: integer
}
}
}
}
}
get_task_tags { get_task_tags {
"2.8" { "2.8" {
description: "Get user and system tags used for the tasks under the specified projects" description: "Get user and system tags used for the tasks under the specified projects"

View File

@ -534,7 +534,7 @@ _definitions {
container { container {
description: "Docker container parameters" description: "Docker container parameters"
type: object type: object
additionalProperties { type: string } additionalProperties { type: [string, null] }
} }
models { models {
description: "Task models" description: "Task models"
@ -981,7 +981,7 @@ clone {
new_task_container { new_task_container {
description: "The docker container properties for the new task. If not provided then taken from the original task" description: "The docker container properties for the new task. If not provided then taken from the original task"
type: object type: object
additionalProperties { type: string } additionalProperties { type: [string, null] }
} }
} }
} }
@ -1159,7 +1159,7 @@ create {
container { container {
description: "Docker container parameters" description: "Docker container parameters"
type: object type: object
additionalProperties { type: string } additionalProperties { type: [string, null] }
} }
} }
} }
@ -1248,7 +1248,7 @@ validate {
container { container {
description: "Docker container parameters" description: "Docker container parameters"
type: object type: object
additionalProperties { type: string } additionalProperties { type: [string, null] }
} }
} }
} }
@ -1410,7 +1410,7 @@ edit {
container { container {
description: "Docker container parameters" description: "Docker container parameters"
type: object type: object
additionalProperties { type: string } additionalProperties { type: [string, null] }
} }
runtime { runtime {
description: "Task runtime mapping" description: "Task runtime mapping"

View File

@ -13,6 +13,7 @@ from apiserver.apimodels.auth import (
CredentialsResponse, CredentialsResponse,
RevokeCredentialsRequest, RevokeCredentialsRequest,
EditUserReq, EditUserReq,
CreateCredentialsRequest,
) )
from apiserver.apimodels.base import UpdateResponse from apiserver.apimodels.base import UpdateResponse
from apiserver.bll.auth import AuthBLL from apiserver.bll.auth import AuthBLL
@ -58,9 +59,13 @@ def get_token_for_user(call: APICall, _: str, request: GetTokenForUserRequest):
""" Generates a token based on a requested user and company. INTERNAL. """ """ Generates a token based on a requested user and company. INTERNAL. """
if call.identity.role not in Role.get_system_roles(): if call.identity.role not in Role.get_system_roles():
if call.identity.role != Role.admin and call.identity.user != request.user: if call.identity.role != Role.admin and call.identity.user != request.user:
raise errors.bad_request.InvalidUserId("cannot generate token for another user") raise errors.bad_request.InvalidUserId(
"cannot generate token for another user"
)
if call.identity.company != request.company: if call.identity.company != request.company:
raise errors.bad_request.InvalidId("cannot generate token in another company") raise errors.bad_request.InvalidId(
"cannot generate token in another company"
)
call.result.data_model = AuthBLL.get_token_for_user( call.result.data_model = AuthBLL.get_token_for_user(
user_id=request.user, user_id=request.user,
@ -93,7 +98,10 @@ def validate_token_endpoint(call: APICall, _, __):
) )
def create_user(call: APICall, _, request: CreateUserRequest): def create_user(call: APICall, _, request: CreateUserRequest):
""" Create a user from. INTERNAL. """ """ Create a user from. INTERNAL. """
if call.identity.role not in Role.get_system_roles() and request.company != call.identity.company: if (
call.identity.role not in Role.get_system_roles()
and request.company != call.identity.company
):
raise errors.bad_request.InvalidId("cannot create user in another company") raise errors.bad_request.InvalidId("cannot create user in another company")
user_id = AuthBLL.create_user(request=request, call=call) user_id = AuthBLL.create_user(request=request, call=call)
@ -101,7 +109,7 @@ def create_user(call: APICall, _, request: CreateUserRequest):
@endpoint("auth.create_credentials", response_data_model=CreateCredentialsResponse) @endpoint("auth.create_credentials", response_data_model=CreateCredentialsResponse)
def create_credentials(call: APICall, _, __): def create_credentials(call: APICall, _, request: CreateCredentialsRequest):
if _is_protected_user(call.identity.user): if _is_protected_user(call.identity.user):
raise errors.bad_request.InvalidUserId("protected identity") raise errors.bad_request.InvalidUserId("protected identity")
@ -109,6 +117,7 @@ def create_credentials(call: APICall, _, __):
user_id=call.identity.user, user_id=call.identity.user,
company_id=call.identity.company, company_id=call.identity.company,
role=call.identity.role, role=call.identity.role,
label=request.label,
) )
call.result.data_model = CreateCredentialsResponse(credentials=credentials) call.result.data_model = CreateCredentialsResponse(credentials=credentials)
@ -151,7 +160,9 @@ def get_credentials(call: APICall, _, __):
# we return ONLY the key IDs, never the secrets (want a secret? create new credentials) # we return ONLY the key IDs, never the secrets (want a secret? create new credentials)
call.result.data_model = GetCredentialsResponse( call.result.data_model = GetCredentialsResponse(
credentials=[ credentials=[
CredentialsResponse(access_key=c.key, last_used=c.last_used) CredentialsResponse(
access_key=c.key, last_used=c.last_used, label=c.label
)
for c in user.credentials for c in user.credentials
] ]
) )

View File

@ -361,6 +361,7 @@ def get_scalar_metric_data(call, company_id, _):
task_id = call.data["task"] task_id = call.data["task"]
metric = call.data["metric"] metric = call.data["metric"]
scroll_id = call.data.get("scroll_id") scroll_id = call.data.get("scroll_id")
no_scroll = call.data.get("no_scroll", False)
task = task_bll.assert_exists( task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company", "company_origin") company_id, task_id, allow_public=True, only=("company", "company_origin")
@ -372,6 +373,7 @@ def get_scalar_metric_data(call, company_id, _):
sort=[{"iter": {"order": "desc"}}], sort=[{"iter": {"order": "desc"}}],
metric=metric, metric=metric,
scroll_id=scroll_id, scroll_id=scroll_id,
no_scroll=no_scroll,
) )
call.result.data = dict( call.result.data = dict(
@ -494,6 +496,7 @@ def get_multi_task_plots(call, company_id, req_model):
task_ids = call.data["tasks"] task_ids = call.data["tasks"]
iters = call.data.get("iters", 1) iters = call.data.get("iters", 1)
scroll_id = call.data.get("scroll_id") scroll_id = call.data.get("scroll_id")
no_scroll = call.data.get("no_scroll", False)
tasks = task_bll.assert_exists( tasks = task_bll.assert_exists(
company_id=call.identity.company, company_id=call.identity.company,
@ -515,6 +518,7 @@ def get_multi_task_plots(call, company_id, req_model):
sort=[{"iter": {"order": "desc"}}], sort=[{"iter": {"order": "desc"}}],
last_iter_count=iters, last_iter_count=iters,
scroll_id=scroll_id, scroll_id=scroll_id,
no_scroll=no_scroll,
) )
tasks = {t.id: t.name for t in tasks} tasks = {t.id: t.name for t in tasks}
@ -593,6 +597,7 @@ def get_task_plots(call, company_id, request: TaskPlotsRequest):
sort=[{"iter": {"order": "desc"}}], sort=[{"iter": {"order": "desc"}}],
last_iterations_per_plot=iters, last_iterations_per_plot=iters,
scroll_id=scroll_id, scroll_id=scroll_id,
no_scroll=request.no_scroll,
metric_variants=_get_metric_variants_from_request(request.metrics), metric_variants=_get_metric_variants_from_request(request.metrics),
) )

View File

@ -7,7 +7,7 @@ from apiserver.apierrors import errors
from apiserver.apierrors.errors.bad_request import InvalidProjectId from apiserver.apierrors.errors.bad_request import InvalidProjectId
from apiserver.apimodels.base import UpdateResponse, MakePublicRequest, IdResponse from apiserver.apimodels.base import UpdateResponse, MakePublicRequest, IdResponse
from apiserver.apimodels.projects import ( from apiserver.apimodels.projects import (
GetHyperParamRequest, GetParamsRequest,
ProjectTagsRequest, ProjectTagsRequest,
ProjectTaskParentsRequest, ProjectTaskParentsRequest,
ProjectHyperparamValuesRequest, ProjectHyperparamValuesRequest,
@ -19,12 +19,11 @@ from apiserver.apimodels.projects import (
ProjectRequest, ProjectRequest,
) )
from apiserver.bll.organization import OrgBLL, Tags from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL from apiserver.bll.project import ProjectBLL, ProjectQueries
from apiserver.bll.project.project_cleanup import ( from apiserver.bll.project.project_cleanup import (
delete_project, delete_project,
validate_project_delete, validate_project_delete,
) )
from apiserver.bll.task import TaskBLL
from apiserver.database.errors import translate_errors_context from apiserver.database.errors import translate_errors_context
from apiserver.database.model.project import Project from apiserver.database.model.project import Project
from apiserver.database.utils import ( from apiserver.database.utils import (
@ -41,8 +40,8 @@ from apiserver.services.utils import (
from apiserver.timing_context import TimingContext from apiserver.timing_context import TimingContext
org_bll = OrgBLL() org_bll = OrgBLL()
task_bll = TaskBLL()
project_bll = ProjectBLL() project_bll = ProjectBLL()
project_queries = ProjectQueries()
create_fields = { create_fields = {
"name": None, "name": None,
@ -267,7 +266,7 @@ def get_unique_metric_variants(
call: APICall, company_id: str, request: ProjectOrNoneRequest call: APICall, company_id: str, request: ProjectOrNoneRequest
): ):
metrics = task_bll.get_unique_metric_variants( metrics = project_queries.get_unique_metric_variants(
company_id, company_id,
[request.project] if request.project else None, [request.project] if request.project else None,
include_subprojects=request.include_subprojects, include_subprojects=request.include_subprojects,
@ -276,14 +275,31 @@ def get_unique_metric_variants(
call.result.data = {"metrics": metrics} call.result.data = {"metrics": metrics}
@endpoint("projects.get_model_metadata_keys",)
def get_model_metadata_keys(call: APICall, company_id: str, request: GetParamsRequest):
total, remaining, keys = project_queries.get_model_metadata_keys(
company_id,
project_ids=[request.project] if request.project else None,
include_subprojects=request.include_subprojects,
page=request.page,
page_size=request.page_size,
)
call.result.data = {
"total": total,
"remaining": remaining,
"keys": keys,
}
@endpoint( @endpoint(
"projects.get_hyper_parameters", "projects.get_hyper_parameters",
min_version="2.9", min_version="2.9",
request_data_model=GetHyperParamRequest, request_data_model=GetParamsRequest,
) )
def get_hyper_parameters(call: APICall, company_id: str, request: GetHyperParamRequest): def get_hyper_parameters(call: APICall, company_id: str, request: GetParamsRequest):
total, remaining, parameters = TaskBLL.get_aggregated_project_parameters( total, remaining, parameters = project_queries.get_aggregated_project_parameters(
company_id, company_id,
project_ids=[request.project] if request.project else None, project_ids=[request.project] if request.project else None,
include_subprojects=request.include_subprojects, include_subprojects=request.include_subprojects,
@ -306,7 +322,7 @@ def get_hyper_parameters(call: APICall, company_id: str, request: GetHyperParamR
def get_hyperparam_values( def get_hyperparam_values(
call: APICall, company_id: str, request: ProjectHyperparamValuesRequest call: APICall, company_id: str, request: ProjectHyperparamValuesRequest
): ):
total, values = task_bll.get_hyperparam_distinct_values( total, values = project_queries.get_hyperparam_distinct_values(
company_id, company_id,
project_ids=request.projects, project_ids=request.projects,
section=request.section, section=request.section,

View File

@ -6,9 +6,6 @@ from apiserver.tests.automated import TestService
class TestQueueAndModelMetadata(TestService): class TestQueueAndModelMetadata(TestService):
def setUp(self, version="2.13"):
super().setUp(version=version)
meta1 = [{"key": "test_key", "type": "str", "value": "test_value"}] meta1 = [{"key": "test_key", "type": "str", "value": "test_value"}]
def test_queue_metas(self): def test_queue_metas(self):
@ -29,6 +26,23 @@ class TestQueueAndModelMetadata(TestService):
self.api.models.edit(model=model_id, metadata=[self.meta1[0]]) self.api.models.edit(model=model_id, metadata=[self.meta1[0]])
self._assertMeta(service=service, entity=entity, _id=model_id, meta=self.meta1) self._assertMeta(service=service, entity=entity, _id=model_id, meta=self.meta1)
def test_project_meta_query(self):
self._temp_model("TestMetadata", metadata=self.meta1)
project = self.temp_project(name="MetaParent")
self._temp_model(
"TestMetadata2",
project=project,
metadata=[
{"key": "test_key", "type": "str", "value": "test_value"},
{"key": "test_key2", "type": "str", "value": "test_value"},
],
)
res = self.api.projects.get_model_metadata_keys()
self.assertTrue({"test_key", "test_key2"}.issubset(set(res["keys"])))
res = self.api.projects.get_model_metadata_keys(include_subprojects=False)
self.assertTrue("test_key" in res["keys"])
self.assertFalse("test_key2" in res["keys"])
def _test_meta_operations( def _test_meta_operations(
self, service: APIClient.Service, entity: str, _id: str, self, service: APIClient.Service, entity: str, _id: str,
): ):
@ -72,3 +86,12 @@ class TestQueueAndModelMetadata(TestService):
return self.create_temp( return self.create_temp(
"models", uri="file://test", name=name, labels={}, **kwargs "models", uri="file://test", name=name, labels={}, **kwargs
) )
def temp_project(self, **kwargs) -> str:
self.update_missing(
kwargs,
name="Test models meta",
description="test",
delete_params=dict(force=True),
)
return self.create_temp("projects", **kwargs)

View File

@ -12,9 +12,6 @@ from apiserver.tests.automated import TestService
class TestSubProjects(TestService): class TestSubProjects(TestService):
def setUp(self, **kwargs):
super().setUp(version="2.13")
def test_project_aggregations(self): def test_project_aggregations(self):
"""This test requires user with user_auth_only... credentials in db""" """This test requires user with user_auth_only... credentials in db"""
user2_client = APIClient( user2_client = APIClient(
@ -203,6 +200,9 @@ class TestSubProjects(TestService):
self.assertEqual(res1.stats["active"]["status_count"]["created"], 0) self.assertEqual(res1.stats["active"]["status_count"]["created"], 0)
self.assertEqual(res1.stats["active"]["status_count"]["stopped"], 2) self.assertEqual(res1.stats["active"]["status_count"]["stopped"], 2)
self.assertEqual(res1.stats["active"]["total_runtime"], 2) self.assertEqual(res1.stats["active"]["total_runtime"], 2)
self.assertEqual(res1.stats["active"]["completed_tasks"], 2)
self.assertEqual(res1.stats["active"]["total_tasks"], 2)
self.assertEqual(res1.stats["active"]["running_tasks"], 0)
self.assertEqual( self.assertEqual(
{sp.name for sp in res1.sub_projects}, {sp.name for sp in res1.sub_projects},
{ {
@ -215,6 +215,9 @@ class TestSubProjects(TestService):
self.assertEqual(res2.stats["active"]["status_count"]["created"], 0) self.assertEqual(res2.stats["active"]["status_count"]["created"], 0)
self.assertEqual(res2.stats["active"]["status_count"]["stopped"], 0) self.assertEqual(res2.stats["active"]["status_count"]["stopped"], 0)
self.assertEqual(res2.stats["active"]["total_runtime"], 0) self.assertEqual(res2.stats["active"]["total_runtime"], 0)
self.assertEqual(res2.stats["active"]["completed_tasks"], 0)
self.assertEqual(res2.stats["active"]["total_tasks"], 0)
self.assertEqual(res2.stats["active"]["running_tasks"], 0)
self.assertEqual(res2.sub_projects, []) self.assertEqual(res2.sub_projects, [])
def _run_tasks(self, *tasks): def _run_tasks(self, *tasks):

View File

@ -198,6 +198,9 @@ class TestTags(TestService):
def assertProjectStats(self, project: AttrDict): def assertProjectStats(self, project: AttrDict):
self.assertEqual(set(project.stats.keys()), {"active"}) self.assertEqual(set(project.stats.keys()), {"active"})
self.assertAlmostEqual(project.stats.active.total_runtime, 1, places=0) self.assertAlmostEqual(project.stats.active.total_runtime, 1, places=0)
self.assertEqual(project.stats.active.completed_tasks, 1)
self.assertEqual(project.stats.active.total_tasks, 1)
self.assertEqual(project.stats.active.running_tasks, 0)
for status, count in project.stats.active.status_count.items(): for status, count in project.stats.active.status_count.items():
self.assertEqual(count, 1 if status == "stopped" else 0) self.assertEqual(count, 1 if status == "stopped" else 0)