Add support for credentials label

Support no_scroll in events.get_task_plots, get_multi_task_plots and get_scalar_metric_data
Improve project stats (running, total and recently completed task counts)
Fix Redis connection being required on MongoDB document initialization
Update tests
allegroai 2022-02-13 19:59:58 +02:00
parent 92fd98d5ad
commit 447adb9090
26 changed files with 624 additions and 296 deletions

View File

@ -75,6 +75,7 @@ class CreateUserResponse(Base):
class Credentials(Base):
access_key = StringField(required=True)
secret_key = StringField(required=True)
label = StringField()
class CredentialsResponse(Credentials):
@ -82,6 +83,10 @@ class CredentialsResponse(Credentials):
last_used = DateTimeField(default=None)
class CreateCredentialsRequest(Base):
label = StringField()
class CreateCredentialsResponse(Base):
credentials = EmbeddedField(Credentials)

View File

@ -135,4 +135,5 @@ class TaskPlotsRequest(Base):
task: str = StringField(required=True)
iters: int = IntField(default=1)
scroll_id: str = StringField()
no_scroll: bool = BoolField(default=False)
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)

View File

@ -27,7 +27,7 @@ class ProjectOrNoneRequest(models.Base):
include_subprojects = fields.BoolField(default=True)
class GetHyperParamRequest(ProjectOrNoneRequest):
class GetParamsRequest(ProjectOrNoneRequest):
page = fields.IntField(default=0)
page_size = fields.IntField(default=500)

View File

@ -2,7 +2,11 @@ from datetime import datetime
from apiserver import database
from apiserver.apierrors import errors
from apiserver.apimodels.auth import GetTokenResponse, CreateUserRequest, Credentials as CredModel
from apiserver.apimodels.auth import (
GetTokenResponse,
CreateUserRequest,
Credentials as CredModel,
)
from apiserver.apimodels.users import CreateRequest as Users_CreateRequest
from apiserver.bll.user import UserBLL
from apiserver.config_repo import config
@ -145,7 +149,7 @@ class AuthBLL:
@classmethod
def create_credentials(
cls, user_id: str, company_id: str, role: str = None
cls, user_id: str, company_id: str, role: str = None, label: str = None,
) -> CredModel:
with translate_errors_context():
@ -154,7 +158,9 @@ class AuthBLL:
if not user:
raise errors.bad_request.InvalidUserId(**query)
cred = CredModel(access_key=get_client_id(), secret_key=get_secret_key())
cred = CredModel(
access_key=get_client_id(), secret_key=get_secret_key(), label=label
)
user.credentials.append(
Credentials(key=cred.access_key, secret=cred.secret_key)
)
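A minimal sketch of how a client could create labeled credentials once this change is deployed, assuming a local apiserver on port 8008 that accepts HTTP basic auth with an existing access/secret key pair (host, port and the "ci-token" label are illustrative):

import requests

# Assumed local apiserver address and an existing access/secret key pair.
APISERVER = "http://localhost:8008"
ACCESS_KEY, SECRET_KEY = "EXISTING_ACCESS_KEY", "EXISTING_SECRET_KEY"

# Request new credentials with an optional human-readable label.
resp = requests.post(
    f"{APISERVER}/auth.create_credentials",
    json={"label": "ci-token"},
    auth=(ACCESS_KEY, SECRET_KEY),
)
resp.raise_for_status()
print(resp.json()["data"]["credentials"])  # access_key, secret_key and label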

View File

@ -534,6 +534,7 @@ class EventBLL(object):
sort=None,
size: int = 500,
scroll_id: str = None,
no_scroll: bool = False,
metric_variants: MetricVariants = None,
):
if scroll_id == self.empty_scroll:
@ -611,7 +612,7 @@ class EventBLL(object):
event_type=event_type,
body=es_req,
ignore=404,
scroll="1h",
**({} if no_scroll else {"scroll": "1h"}),
)
events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
@ -680,6 +681,7 @@ class EventBLL(object):
sort=None,
size=500,
scroll_id=None,
no_scroll=False,
) -> TaskEventsResult:
if scroll_id == self.empty_scroll:
return TaskEventsResult()
@ -740,7 +742,7 @@ class EventBLL(object):
event_type=event_type,
body=es_req,
ignore=404,
scroll="1h",
**({} if no_scroll else {"scroll": "1h"}),
)
events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
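The scroll parameter is now passed conditionally through dict unpacking; a self-contained sketch of the same idiom, independent of Elasticsearch:

def search(index: str, body: dict, **kwargs) -> dict:
    # Stand-in for es.search(); just echoes what it would receive.
    return {"index": index, "body": body, **kwargs}

def run_query(no_scroll: bool) -> dict:
    # Only include the "scroll" keyword when a scroll context is wanted,
    # mirroring **({} if no_scroll else {"scroll": "1h"}) in EventBLL.
    return search(
        "events-training_stats_scalar",
        {"query": {"match_all": {}}},
        **({} if no_scroll else {"scroll": "1h"}),
    )

assert "scroll" not in run_query(no_scroll=True)
assert run_query(no_scroll=False)["scroll"] == "1h"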

View File

@ -1,2 +1,3 @@
from .project_bll import ProjectBLL
from .project_queries import ProjectQueries
from .sub_projects import _ids_with_children as project_ids_with_children

View File

@ -1,6 +1,6 @@
import itertools
from collections import defaultdict
from datetime import datetime
from datetime import datetime, timedelta
from functools import reduce
from itertools import groupby
from operator import itemgetter
@ -306,6 +306,7 @@ class ProjectBLL:
return project
archived_tasks_cond = {"$in": [EntityVisibility.archived.value, "$system_tags"]}
visibility_states = [EntityVisibility.archived, EntityVisibility.active]
@classmethod
def make_projects_get_all_pipelines(
@ -367,6 +368,26 @@ class ProjectBLL:
},
]
def completed_after_subquery(additional_cond, time_thresh: datetime):
return {
# the sum of
"$sum": {
# for each task
"$cond": {
# if completed after the time_thresh
"if": {
"$and": [
"$completed",
{"$gt": ["$completed", time_thresh]},
additional_cond,
]
},
"then": 1,
"else": 0,
}
}
}
def runtime_subquery(additional_cond):
return {
# the sum of
@ -397,16 +418,19 @@ class ProjectBLL:
}
group_step = {"_id": "$project"}
for state in EntityVisibility:
time_thresh = datetime.utcnow() - timedelta(hours=24)
for state in cls.visibility_states:
if specific_state and state != specific_state:
continue
if state == EntityVisibility.active:
group_step[state.value] = runtime_subquery(
{"$not": cls.archived_tasks_cond}
)
elif state == EntityVisibility.archived:
group_step[state.value] = runtime_subquery(cls.archived_tasks_cond)
cond = (
cls.archived_tasks_cond
if state == EntityVisibility.archived
else {"$not": cls.archived_tasks_cond}
)
group_step[state.value] = runtime_subquery(cond)
group_step[f"{state.value}_recently_completed"] = completed_after_subquery(
cond, time_thresh=time_thresh
)
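For reference, completed_after_subquery counts tasks whose completed timestamp falls within the last 24 hours; a pure-Python sketch of the same $cond/$sum logic over plain dicts (the sample data is made up):

from datetime import datetime, timedelta

def count_recently_completed(tasks, time_thresh):
    # 1 for every task completed after time_thresh, 0 otherwise - summed,
    # which is what the $sum/$cond stage computes per project group.
    return sum(
        1 if t.get("completed") and t["completed"] > time_thresh else 0
        for t in tasks
    )

time_thresh = datetime.utcnow() - timedelta(hours=24)
tasks = [
    {"completed": datetime.utcnow() - timedelta(hours=1)},  # counted
    {"completed": datetime.utcnow() - timedelta(days=2)},   # too old
    {"completed": None},                                     # never completed
]
assert count_recently_completed(tasks, time_thresh) == 1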
runtime_pipeline = [
# only count run time for these types of tasks
@ -534,15 +558,24 @@ class ProjectBLL:
)
def get_status_counts(project_id, section):
project_runtime = runtime.get(project_id, {})
project_section_statuses = nested_get(
status_count, (project_id, section), default=default_counts
)
return {
"total_runtime": nested_get(runtime, (project_id, section), default=0),
"status_count": nested_get(
status_count, (project_id, section), default=default_counts
"status_count": project_section_statuses,
"running_tasks": project_section_statuses.get(TaskStatus.in_progress),
"total_tasks": sum(project_section_statuses.values()),
"total_runtime": project_runtime.get(section, 0),
"completed_tasks": project_runtime.get(
f"{section}_recently_completed", 0
),
}
report_for_states = [
s for s in EntityVisibility if not specific_state or specific_state == s
s
for s in cls.visibility_states
if not specific_state or specific_state == s
]
stats = {

View File

@ -0,0 +1,289 @@
import json
from collections import OrderedDict
from datetime import datetime, timedelta
from typing import (
Sequence,
Optional,
Tuple,
)
from redis import StrictRedis
from apiserver.config_repo import config
from apiserver.redis_manager import redman
from apiserver.utilities.dicts import nested_get
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
from .sub_projects import _ids_with_children
from ...database.model.model import Model
from ...database.model.task.task import Task
log = config.logger(__file__)
class ProjectQueries:
def __init__(self, redis=None):
self.redis: StrictRedis = redis or redman.connection("apiserver")
@staticmethod
def _get_project_constraint(
project_ids: Sequence[str], include_subprojects: bool
) -> dict:
if include_subprojects:
if project_ids is None:
return {}
project_ids = _ids_with_children(project_ids)
return {"project": {"$in": project_ids if project_ids is not None else [None]}}
@staticmethod
def _get_company_constraint(company_id: str, allow_public: bool = True) -> dict:
if allow_public:
return {"company": {"$in": [None, "", company_id]}}
return {"company": company_id}
@classmethod
def get_aggregated_project_parameters(
cls,
company_id,
project_ids: Sequence[str],
include_subprojects: bool,
page: int = 0,
page_size: int = 500,
) -> Tuple[int, int, Sequence[dict]]:
page = max(0, page)
page_size = max(1, page_size)
pipeline = [
{
"$match": {
**cls._get_company_constraint(company_id),
**cls._get_project_constraint(project_ids, include_subprojects),
"hyperparams": {"$exists": True, "$gt": {}},
}
},
{"$project": {"sections": {"$objectToArray": "$hyperparams"}}},
{"$unwind": "$sections"},
{
"$project": {
"section": "$sections.k",
"names": {"$objectToArray": "$sections.v"},
}
},
{"$unwind": "$names"},
{"$group": {"_id": {"section": "$section", "name": "$names.k"}}},
{"$sort": OrderedDict({"_id.section": 1, "_id.name": 1})},
{"$skip": page * page_size},
{"$limit": page_size},
{
"$group": {
"_id": 1,
"total": {"$sum": 1},
"results": {"$push": "$$ROOT"},
}
},
]
result = next(Task.aggregate(pipeline), None)
total = 0
remaining = 0
results = []
if result:
total = int(result.get("total", -1))
results = [
{
"section": ParameterKeyEscaper.unescape(
nested_get(r, ("_id", "section"))
),
"name": ParameterKeyEscaper.unescape(
nested_get(r, ("_id", "name"))
),
}
for r in result.get("results", [])
]
remaining = max(0, total - (len(results) + page * page_size))
return total, remaining, results
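The remaining figure returned here (and by get_model_metadata_keys below) is what a caller uses to keep paging; a small sketch of the arithmetic, assuming total=7 results and page_size=3:

def remaining_after(page: int, page_size: int, returned: int, total: int) -> int:
    # Same formula as above: whatever is left after this page and all pages before it.
    return max(0, total - (returned + page * page_size))

# total=7, page_size=3 -> pages return 3, 3, 1 results and remaining 4, 1, 0.
assert remaining_after(page=0, page_size=3, returned=3, total=7) == 4
assert remaining_after(page=1, page_size=3, returned=3, total=7) == 1
assert remaining_after(page=2, page_size=3, returned=1, total=7) == 0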
HyperParamValues = Tuple[int, Sequence[str]]
def _get_cached_hyperparam_values(
self, key: str, last_update: datetime
) -> Optional[HyperParamValues]:
allowed_delta = timedelta(
seconds=config.get(
"services.tasks.hyperparam_values.cache_allowed_outdate_sec", 60
)
)
try:
cached = self.redis.get(key)
if not cached:
return
data = json.loads(cached)
cached_last_update = datetime.fromtimestamp(data["last_update"])
if (last_update - cached_last_update) < allowed_delta:
return data["total"], data["values"]
except Exception as ex:
log.error(f"Error retrieving hyperparam cached values: {str(ex)}")
def get_hyperparam_distinct_values(
self,
company_id: str,
project_ids: Sequence[str],
section: str,
name: str,
include_subprojects: bool,
allow_public: bool = True,
) -> HyperParamValues:
company_constraint = self._get_company_constraint(company_id, allow_public)
project_constraint = self._get_project_constraint(
project_ids, include_subprojects
)
key_path = f"hyperparams.{ParameterKeyEscaper.escape(section)}.{ParameterKeyEscaper.escape(name)}"
last_updated_task = (
Task.objects(
**company_constraint,
**project_constraint,
**{f"{key_path.replace('.', '__')}__exists": True},
)
.only("last_update")
.order_by("-last_update")
.limit(1)
.first()
)
if not last_updated_task:
return 0, []
redis_key = f"hyperparam_values_{company_id}_{'_'.join(project_ids)}_{section}_{name}_{allow_public}"
last_update = last_updated_task.last_update or datetime.utcnow()
cached_res = self._get_cached_hyperparam_values(
key=redis_key, last_update=last_update
)
if cached_res:
return cached_res
max_values = config.get("services.tasks.hyperparam_values.max_count", 100)
pipeline = [
{
"$match": {
**company_constraint,
**project_constraint,
key_path: {"$exists": True},
}
},
{"$project": {"value": f"${key_path}.value"}},
{"$group": {"_id": "$value"}},
{"$sort": {"_id": 1}},
{"$limit": max_values},
{
"$group": {
"_id": 1,
"total": {"$sum": 1},
"results": {"$push": "$$ROOT._id"},
}
},
]
result = next(Task.aggregate(pipeline, collation=Task._numeric_locale), None)
if not result:
return 0, []
total = int(result.get("total", 0))
values = result.get("results", [])
ttl = config.get("services.tasks.hyperparam_values.cache_ttl_sec", 86400)
cached = dict(last_update=last_update.timestamp(), total=total, values=values)
self.redis.setex(redis_key, ttl, json.dumps(cached))
return total, values
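A sketch of the cache-freshness rule used above, with a plain dict standing in for Redis (the 60-second window mirrors cache_allowed_outdate_sec; key and payload names are illustrative):

import json
from datetime import datetime, timedelta

ALLOWED_DELTA = timedelta(seconds=60)
fake_redis = {}  # stands in for StrictRedis.get/setex

def cache_values(key, last_update, total, values):
    fake_redis[key] = json.dumps(
        dict(last_update=last_update.timestamp(), total=total, values=values)
    )

def cached_values(key, last_update):
    # Return the cached payload only if no task was updated more than
    # ALLOWED_DELTA after the cached snapshot was taken.
    raw = fake_redis.get(key)
    if not raw:
        return None
    data = json.loads(raw)
    if last_update - datetime.fromtimestamp(data["last_update"]) < ALLOWED_DELTA:
        return data["total"], data["values"]
    return None

now = datetime.utcnow()
cache_values("hyperparam_values_demo", now, 2, ["0.1", "0.01"])
assert cached_values("hyperparam_values_demo", now) == (2, ["0.1", "0.01"])
assert cached_values("hyperparam_values_demo", now + timedelta(minutes=5)) is None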
@classmethod
def get_unique_metric_variants(
cls, company_id, project_ids: Sequence[str], include_subprojects: bool
):
pipeline = [
{
"$match": {
**cls._get_company_constraint(company_id),
**cls._get_project_constraint(project_ids, include_subprojects),
}
},
{"$project": {"metrics": {"$objectToArray": "$last_metrics"}}},
{"$unwind": "$metrics"},
{
"$project": {
"metric": "$metrics.k",
"variants": {"$objectToArray": "$metrics.v"},
}
},
{"$unwind": "$variants"},
{
"$group": {
"_id": {
"metric": "$variants.v.metric",
"variant": "$variants.v.variant",
},
"metrics": {
"$addToSet": {
"metric": "$variants.v.metric",
"metric_hash": "$metric",
"variant": "$variants.v.variant",
"variant_hash": "$variants.k",
}
},
}
},
{"$sort": OrderedDict({"_id.metric": 1, "_id.variant": 1})},
]
result = Task.aggregate(pipeline)
return [r["metrics"][0] for r in result]
@classmethod
def get_model_metadata_keys(
cls,
company_id,
project_ids: Sequence[str],
include_subprojects: bool,
page: int = 0,
page_size: int = 500,
) -> Tuple[int, int, Sequence[dict]]:
page = max(0, page)
page_size = max(1, page_size)
pipeline = [
{
"$match": {
**cls._get_company_constraint(company_id),
**cls._get_project_constraint(project_ids, include_subprojects),
"metadata": {"$exists": True, "$ne": []},
}
},
{"$project": {"metadata": 1}},
{"$unwind": "$metadata"},
{"$group": {"_id": "$metadata.key"}},
{"$sort": {"_id": 1}},
{"$skip": page * page_size},
{"$limit": page_size},
{
"$group": {
"_id": 1,
"total": {"$sum": 1},
"results": {"$push": "$$ROOT"},
}
},
]
result = next(Model.aggregate(pipeline), None)
total = 0
remaining = 0
results = []
if result:
total = int(result.get("total", -1))
results = [r.get("_id") for r in result.get("results", [])]
remaining = max(0, total - (len(results) + page * page_size))
return total, remaining, results

View File

@ -1,9 +1,6 @@
import json
from collections import OrderedDict
from datetime import datetime, timedelta
from datetime import datetime
from typing import Collection, Sequence, Tuple, Any, Optional, Dict
import dpath
import six
from mongoengine import Q
from redis import StrictRedis
@ -14,7 +11,7 @@ from apiserver.apierrors import errors
from apiserver.apimodels.tasks import TaskInputModel
from apiserver.bll.queue import QueueBLL
from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL, project_ids_with_children
from apiserver.bll.project import ProjectBLL
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.model import Model
@ -39,7 +36,6 @@ from apiserver.redis_manager import redman
from apiserver.service_repo import APICall
from apiserver.services.utils import validate_tags, escape_dict_field, escape_dict
from apiserver.timing_context import TimingContext
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
from .artifacts import artifacts_prepare_for_save
from .param_utils import params_prepare_for_save
from .utils import (
@ -350,54 +346,6 @@ class TaskBLL:
if validate_models:
cls.validate_input_models(task)
@staticmethod
def get_unique_metric_variants(
company_id, project_ids: Sequence[str], include_subprojects: bool
):
if project_ids:
if include_subprojects:
project_ids = project_ids_with_children(project_ids)
project_constraint = {"project": {"$in": project_ids}}
else:
project_constraint = {}
pipeline = [
{
"$match": dict(
company={"$in": [None, "", company_id]}, **project_constraint,
)
},
{"$project": {"metrics": {"$objectToArray": "$last_metrics"}}},
{"$unwind": "$metrics"},
{
"$project": {
"metric": "$metrics.k",
"variants": {"$objectToArray": "$metrics.v"},
}
},
{"$unwind": "$variants"},
{
"$group": {
"_id": {
"metric": "$variants.v.metric",
"variant": "$variants.v.variant",
},
"metrics": {
"$addToSet": {
"metric": "$variants.v.metric",
"metric_hash": "$metric",
"variant": "$variants.v.variant",
"variant_hash": "$variants.k",
}
},
}
},
{"$sort": OrderedDict({"_id.metric": 1, "_id.variant": 1})},
]
with translate_errors_context():
result = Task.aggregate(pipeline)
return [r["metrics"][0] for r in result]
@staticmethod
def set_last_update(
task_ids: Collection[str],
@ -494,173 +442,6 @@ class TaskBLL:
**extra_updates,
)
@staticmethod
def get_aggregated_project_parameters(
company_id,
project_ids: Sequence[str],
include_subprojects: bool,
page: int = 0,
page_size: int = 500,
) -> Tuple[int, int, Sequence[dict]]:
if project_ids:
if include_subprojects:
project_ids = project_ids_with_children(project_ids)
project_constraint = {"project": {"$in": project_ids}}
else:
project_constraint = {}
page = max(0, page)
page_size = max(1, page_size)
pipeline = [
{
"$match": {
"company": {"$in": [None, "", company_id]},
"hyperparams": {"$exists": True, "$gt": {}},
**project_constraint,
}
},
{"$project": {"sections": {"$objectToArray": "$hyperparams"}}},
{"$unwind": "$sections"},
{
"$project": {
"section": "$sections.k",
"names": {"$objectToArray": "$sections.v"},
}
},
{"$unwind": "$names"},
{"$group": {"_id": {"section": "$section", "name": "$names.k"}}},
{"$sort": OrderedDict({"_id.section": 1, "_id.name": 1})},
{"$skip": page * page_size},
{"$limit": page_size},
{
"$group": {
"_id": 1,
"total": {"$sum": 1},
"results": {"$push": "$$ROOT"},
}
},
]
result = next(Task.aggregate(pipeline), None)
total = 0
remaining = 0
results = []
if result:
total = int(result.get("total", -1))
results = [
{
"section": ParameterKeyEscaper.unescape(
dpath.get(r, "_id/section")
),
"name": ParameterKeyEscaper.unescape(dpath.get(r, "_id/name")),
}
for r in result.get("results", [])
]
remaining = max(0, total - (len(results) + page * page_size))
return total, remaining, results
HyperParamValues = Tuple[int, Sequence[str]]
def _get_cached_hyperparam_values(
self, key: str, last_update: datetime
) -> Optional[HyperParamValues]:
allowed_delta = timedelta(
seconds=config.get(
"services.tasks.hyperparam_values.cache_allowed_outdate_sec", 60
)
)
try:
cached = self.redis.get(key)
if not cached:
return
data = json.loads(cached)
cached_last_update = datetime.fromtimestamp(data["last_update"])
if (last_update - cached_last_update) < allowed_delta:
return data["total"], data["values"]
except Exception as ex:
log.error(f"Error retrieving hyperparam cached values: {str(ex)}")
def get_hyperparam_distinct_values(
self,
company_id: str,
project_ids: Sequence[str],
section: str,
name: str,
include_subprojects: bool,
allow_public: bool = True,
) -> HyperParamValues:
if allow_public:
company_constraint = {"company": {"$in": [None, "", company_id]}}
else:
company_constraint = {"company": company_id}
if project_ids:
if include_subprojects:
project_ids = project_ids_with_children(project_ids)
project_constraint = {"project": {"$in": project_ids}}
else:
project_constraint = {}
key_path = f"hyperparams.{ParameterKeyEscaper.escape(section)}.{ParameterKeyEscaper.escape(name)}"
last_updated_task = (
Task.objects(
**company_constraint,
**project_constraint,
**{f"{key_path.replace('.', '__')}__exists": True},
)
.only("last_update")
.order_by("-last_update")
.limit(1)
.first()
)
if not last_updated_task:
return 0, []
redis_key = f"hyperparam_values_{company_id}_{'_'.join(project_ids)}_{section}_{name}_{allow_public}"
last_update = last_updated_task.last_update or datetime.utcnow()
cached_res = self._get_cached_hyperparam_values(
key=redis_key, last_update=last_update
)
if cached_res:
return cached_res
max_values = config.get("services.tasks.hyperparam_values.max_count", 100)
pipeline = [
{
"$match": {
**company_constraint,
**project_constraint,
key_path: {"$exists": True},
}
},
{"$project": {"value": f"${key_path}.value"}},
{"$group": {"_id": "$value"}},
{"$sort": {"_id": 1}},
{"$limit": max_values},
{
"$group": {
"_id": 1,
"total": {"$sum": 1},
"results": {"$push": "$$ROOT._id"},
}
},
]
result = next(Task.aggregate(pipeline, collation=Task._numeric_locale), None)
if not result:
return 0, []
total = int(result.get("total", 0))
values = result.get("results", [])
ttl = config.get("services.tasks.hyperparam_values.cache_ttl_sec", 86400)
cached = dict(last_update=last_update.timestamp(), total=total, values=values)
self.redis.setex(redis_key, ttl, json.dumps(cached))
return total, values
@classmethod
def dequeue_and_change_status(
cls, task: Task, company_id: str, status_message: str, status_reason: str,

View File

@ -48,6 +48,7 @@ class Credentials(EmbeddedDocument):
meta = {"strict": False}
key = StringField(required=True)
secret = StringField(required=True)
label = StringField()
last_used = DateTimeField()

View File

@ -142,7 +142,9 @@ class GetMixin(PropsMixin):
self.allow_empty = False
def _get_op(self, v: str, translate: bool = False) -> Optional[str]:
op = v[len(self.op_prefix):] if v and v.startswith(self.op_prefix) else None
op = (
v[len(self.op_prefix) :] if v and v.startswith(self.op_prefix) else None
)
if translate:
tup = self._ops.get(op, None)
return tup[0] if tup else None
@ -177,7 +179,9 @@ class GetMixin(PropsMixin):
"all": Q.AND,
}
data = (x for x in data if x is not None)
first_op = self._get_op(next(data, ""), translate=True) or self.default_mongo_op
first_op = (
self._get_op(next(data, ""), translate=True) or self.default_mongo_op
)
return op_to_res.get(first_op, self.default_mongo_op)
def get_actions(self, data: Sequence[str]) -> Dict[str, List[Union[str, None]]]:
@ -202,13 +206,20 @@ class GetMixin(PropsMixin):
id = StringField(primary_key=True)
position = IntField(default=0)
cache_manager = RedisCacheManager(
state_class=GetManyScrollState,
redis=redman.connection("apiserver"),
expiration_interval=config.get(
"services._mongo.scroll_state_expiration_seconds", 600
),
)
_cache_manager = None
@classmethod
def get_cache_manager(cls):
if not cls._cache_manager:
cls._cache_manager = RedisCacheManager(
state_class=cls.GetManyScrollState,
redis=redman.connection("apiserver"),
expiration_interval=config.get(
"services._mongo.scroll_state_expiration_seconds", 600
),
)
return cls._cache_manager
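Moving the RedisCacheManager construction out of the class body means no Redis connection is attempted when the document classes are merely imported; a minimal sketch of the lazy class-level singleton pattern (ExpensiveClient is a placeholder, not the real manager):

class ExpensiveClient:
    def __init__(self):
        print("connecting...")  # would only run on first use

class Store:
    _client = None

    @classmethod
    def get_client(cls):
        # Created once, on first access, instead of at class-definition time.
        if cls._client is None:
            cls._client = ExpensiveClient()
        return cls._client

# Defining/importing Store prints nothing; only the first call connects.
assert Store.get_client() is Store.get_client()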
@classmethod
def get(
@ -463,10 +474,7 @@ class GetMixin(PropsMixin):
if not queries:
q = RegexQ()
else:
q = RegexQCombination(
operation=global_op,
children=queries
)
q = RegexQCombination(operation=global_op, children=queries)
if not helper.allow_empty:
return q
@ -609,7 +617,7 @@ class GetMixin(PropsMixin):
state: Optional[cls.GetManyScrollState] = None
if "scroll_id" in query_dict:
size = cls.validate_scroll_size(query_dict)
state = cls.cache_manager.get_or_create_state_core(
state = cls.get_cache_manager().get_or_create_state_core(
query_dict.get("scroll_id")
)
if query_dict.get("refresh_scroll"):
@ -625,7 +633,7 @@ class GetMixin(PropsMixin):
if not state:
return
state.position = query_dict[cls._start_key]
cls.cache_manager.set_state(state)
cls.get_cache_manager().set_state(state)
if ret_params is not None:
ret_params["scroll_id"] = state.id

View File

@ -4,7 +4,12 @@ from threading import Lock
from typing import Sequence
import six
from mongoengine import EmbeddedDocumentField, EmbeddedDocumentListField
from mongoengine import (
EmbeddedDocumentField,
EmbeddedDocumentListField,
EmbeddedDocument,
Document,
)
from mongoengine.base import get_document
from apiserver.database.fields import (
@ -25,6 +30,13 @@ class PropsMixin(object):
__cached_dpath_computed_fields_lock = Lock()
__cached_dpath_computed_fields = None
_document_classes = {}
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
if issubclass(cls, (Document, EmbeddedDocument)):
PropsMixin._document_classes[cls._class_name] = cls
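The new __init_subclass__ hook keeps a registry of document classes so resolve_doc can look them up by name before falling back to mongoengine's get_document; a standalone sketch of the same pattern:

class Registered:
    _registry = {}

    def __init_subclass__(cls, **kwargs):
        # Runs automatically for every subclass as it is defined.
        super().__init_subclass__(**kwargs)
        Registered._registry[cls.__name__] = cls

class Task(Registered):
    pass

class Model(Registered):
    pass

assert Registered._registry["Task"] is Task
assert Registered._registry.get("Missing") is None  # fall back to another lookup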
@classmethod
def get_fields(cls):
if cls.__cached_fields is None:
@ -57,8 +69,14 @@ class PropsMixin(object):
def resolve_doc(v):
if not isinstance(v, six.string_types):
return v
if v == 'self':
if v == "self":
return cls_.owner_document
doc_cls = PropsMixin._document_classes.get(v)
if doc_cls:
return doc_cls
return get_document(v)
fields = {k: resolve_doc(v) for k, v in res.items()}
@ -72,7 +90,7 @@ class PropsMixin(object):
).document_type
fields.update(
{
'.'.join((field, subfield)): doc
".".join((field, subfield)): doc
for subfield, doc in PropsMixin._get_fields_with_attr(
embedded_doc_cls, attr
).items()
@ -80,10 +98,10 @@ class PropsMixin(object):
)
collect_embedded_docs(EmbeddedDocumentField, lambda x: x)
collect_embedded_docs(EmbeddedDocumentListField, attrgetter('field'))
collect_embedded_docs(LengthRangeEmbeddedDocumentListField, attrgetter('field'))
collect_embedded_docs(UniqueEmbeddedDocumentListField, attrgetter('field'))
collect_embedded_docs(EmbeddedDocumentSortedListField, attrgetter('field'))
collect_embedded_docs(EmbeddedDocumentListField, attrgetter("field"))
collect_embedded_docs(LengthRangeEmbeddedDocumentListField, attrgetter("field"))
collect_embedded_docs(UniqueEmbeddedDocumentListField, attrgetter("field"))
collect_embedded_docs(EmbeddedDocumentSortedListField, attrgetter("field"))
return fields
@ -94,7 +112,7 @@ class PropsMixin(object):
for depth, part in enumerate(parts):
if current_cls is None:
raise ValueError(
'Invalid path (non-document encountered at %s)' % parts[: depth - 1]
"Invalid path (non-document encountered at %s)" % parts[: depth - 1]
)
try:
field_name, field = next(
@ -103,7 +121,7 @@ class PropsMixin(object):
if k == part
)
except StopIteration:
raise ValueError('Invalid field path %s' % parts[:depth])
raise ValueError("Invalid field path %s" % parts[:depth])
translated_parts.append(part)
@ -119,7 +137,7 @@ class PropsMixin(object):
),
):
current_cls = field.field.document_type
translated_parts.append('*')
translated_parts.append("*")
else:
current_cls = None
@ -128,7 +146,7 @@ class PropsMixin(object):
@classmethod
def get_reference_fields(cls):
if cls.__cached_reference_fields is None:
fields = cls._get_fields_with_attr(cls, 'reference_field')
fields = cls._get_fields_with_attr(cls, "reference_field")
cls.__cached_reference_fields = OrderedDict(sorted(fields.items()))
return cls.__cached_reference_fields
@ -143,12 +161,12 @@ class PropsMixin(object):
@classmethod
def get_exclude_fields(cls):
if cls.__cached_exclude_fields is None:
fields = cls._get_fields_with_attr(cls, 'exclude_by_default')
fields = cls._get_fields_with_attr(cls, "exclude_by_default")
cls.__cached_exclude_fields = OrderedDict(sorted(fields.items()))
return cls.__cached_exclude_fields
@classmethod
def get_dpath_translated_path(cls, path, separator='.'):
def get_dpath_translated_path(cls, path, separator="."):
if cls.__cached_dpath_computed_fields is None:
cls.__cached_dpath_computed_fields = {}
if path not in cls.__cached_dpath_computed_fields:

View File

@ -1,7 +1,8 @@
{
"index_patterns": "events-*",
"settings": {
"number_of_shards": 1
"number_of_shards": 1,
"number_of_replicas": 0
},
"mappings": {
"_source": {

View File

@ -1,7 +1,8 @@
{
"index_patterns": "queue_metrics_*",
"settings": {
"number_of_shards": 1
"number_of_shards": 1,
"number_of_replicas": 0
},
"mappings": {
"_source": {

View File

@ -1,7 +1,8 @@
{
"index_patterns": "worker_stats_*",
"settings": {
"number_of_shards": 1
"number_of_shards": 1,
"number_of_replicas": 0
},
"mappings": {
"_source": {

View File

@ -26,6 +26,10 @@ credentials {
type: string
description: Credentials secret key
}
label {
type: string
description: Optional credentials label
}
}
}
batch_operation {

View File

@ -15,6 +15,10 @@ _definitions {
type: string
description: ""
}
label {
type: string
description: Optional credentials label
}
last_used {
type: string
description: ""
@ -222,6 +226,12 @@ create_credentials {
}
}
}
"999.0": ${create_credentials."2.1"} {
request.properties.label {
type: string
description: Optional credentials label
}
}
}
get_credentials {

View File

@ -889,6 +889,13 @@ get_task_plots {
}
}
}
"2.16": ${get_task_plots."2.14"} {
request.properties.no_scroll {
description: If true then no scroll is created. Suitable for one-time calls
type: boolean
default: false
}
}
}
get_multi_task_plots {
"2.1" {
@ -939,6 +946,13 @@ get_multi_task_plots {
}
}
}
"2.16": ${get_multi_task_plots."2.1"} {
request.properties.no_scroll {
description: If true then no scroll is created. Suitable for one-time calls
type: boolean
default: false
}
}
}
get_vector_metrics_and_variants {
"2.1" {
@ -1218,6 +1232,13 @@ get_scalar_metric_data {
}
}
}
"2.16": ${get_scalar_metric_data."2.1"} {
request.properties.no_scroll {
description: If true then no scroll is created. Suitable for one-time calls
type: boolean
default: false
}
}
}
scalar_metrics_iter_raw {
"2.16" {

View File

@ -545,6 +545,42 @@ get_all_ex {
type: boolean
default: true
}
response {
properties {
stats {
properties {
active.properties {
total_tasks {
description: "Number of tasks"
type: integer
}
completed_tasks {
description: "Number of tasks completed in the last 24 hours"
type: integer
}
running_tasks {
description: "Number of running tasks"
type: integer
}
}
archived.properties {
total_tasks {
description: "Number of tasks"
type: integer
}
completed_tasks {
description: "Number of tasks completed in the last 24 hours"
type: integer
}
running_tasks {
description: "Number of running tasks"
type: integer
}
}
}
}
}
}
}
}
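Roughly, the per-project stats block returned by projects.get_all_ex now carries the new counters alongside the existing ones; an illustrative payload (values are made up, status_count keys abbreviated):

stats = {
    "active": {
        "status_count": {"created": 0, "in_progress": 1, "stopped": 2},
        "running_tasks": 1,
        "total_tasks": 3,
        "completed_tasks": 2,   # completed in the last 24 hours
        "total_runtime": 3600,  # seconds
    },
    "archived": {
        "status_count": {"created": 0, "in_progress": 0, "stopped": 0},
        "running_tasks": 0,
        "total_tasks": 0,
        "completed_tasks": 0,
        "total_runtime": 0,
    },
}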
update {
@ -893,7 +929,55 @@ get_hyper_parameters {
}
}
}
get_model_metadata_keys {
"999.0" {
description: """Get a list of all metadata keys used in models within the given project."""
request {
type: object
required: [project]
properties {
project {
description: "Project ID"
type: string
}
include_subprojects {
description: "If set to 'true' and the project field is set then the result includes metadate keys from the subproject models"
type: boolean
default: true
}
page {
description: "Page number"
default: 0
type: integer
}
page_size {
description: "Page size"
default: 500
type: integer
}
}
}
response {
type: object
properties {
keys {
description: "A list of model keys"
type: array
items {type: string}
}
remaining {
description: "Remaining results"
type: integer
}
total {
description: "Total number of results"
type: integer
}
}
}
}
}
get_task_tags {
"2.8" {
description: "Get user and system tags used for the tasks under the specified projects"

View File

@ -534,7 +534,7 @@ _definitions {
container {
description: "Docker container parameters"
type: object
additionalProperties { type: string }
additionalProperties { type: [string, null] }
}
models {
description: "Task models"
@ -981,7 +981,7 @@ clone {
new_task_container {
description: "The docker container properties for the new task. If not provided then taken from the original task"
type: object
additionalProperties { type: string }
additionalProperties { type: [string, null] }
}
}
}
@ -1159,7 +1159,7 @@ create {
container {
description: "Docker container parameters"
type: object
additionalProperties { type: string }
additionalProperties { type: [string, null] }
}
}
}
@ -1248,7 +1248,7 @@ validate {
container {
description: "Docker container parameters"
type: object
additionalProperties { type: string }
additionalProperties { type: [string, null] }
}
}
}
@ -1410,7 +1410,7 @@ edit {
container {
description: "Docker container parameters"
type: object
additionalProperties { type: string }
additionalProperties { type: [string, null] }
}
runtime {
description: "Task runtime mapping"

View File

@ -13,6 +13,7 @@ from apiserver.apimodels.auth import (
CredentialsResponse,
RevokeCredentialsRequest,
EditUserReq,
CreateCredentialsRequest,
)
from apiserver.apimodels.base import UpdateResponse
from apiserver.bll.auth import AuthBLL
@ -58,9 +59,13 @@ def get_token_for_user(call: APICall, _: str, request: GetTokenForUserRequest):
""" Generates a token based on a requested user and company. INTERNAL. """
if call.identity.role not in Role.get_system_roles():
if call.identity.role != Role.admin and call.identity.user != request.user:
raise errors.bad_request.InvalidUserId("cannot generate token for another user")
raise errors.bad_request.InvalidUserId(
"cannot generate token for another user"
)
if call.identity.company != request.company:
raise errors.bad_request.InvalidId("cannot generate token in another company")
raise errors.bad_request.InvalidId(
"cannot generate token in another company"
)
call.result.data_model = AuthBLL.get_token_for_user(
user_id=request.user,
@ -93,7 +98,10 @@ def validate_token_endpoint(call: APICall, _, __):
)
def create_user(call: APICall, _, request: CreateUserRequest):
""" Create a user from. INTERNAL. """
if call.identity.role not in Role.get_system_roles() and request.company != call.identity.company:
if (
call.identity.role not in Role.get_system_roles()
and request.company != call.identity.company
):
raise errors.bad_request.InvalidId("cannot create user in another company")
user_id = AuthBLL.create_user(request=request, call=call)
@ -101,7 +109,7 @@ def create_user(call: APICall, _, request: CreateUserRequest):
@endpoint("auth.create_credentials", response_data_model=CreateCredentialsResponse)
def create_credentials(call: APICall, _, __):
def create_credentials(call: APICall, _, request: CreateCredentialsRequest):
if _is_protected_user(call.identity.user):
raise errors.bad_request.InvalidUserId("protected identity")
@ -109,6 +117,7 @@ def create_credentials(call: APICall, _, __):
user_id=call.identity.user,
company_id=call.identity.company,
role=call.identity.role,
label=request.label,
)
call.result.data_model = CreateCredentialsResponse(credentials=credentials)
@ -151,7 +160,9 @@ def get_credentials(call: APICall, _, __):
# we return ONLY the key IDs, never the secrets (want a secret? create new credentials)
call.result.data_model = GetCredentialsResponse(
credentials=[
CredentialsResponse(access_key=c.key, last_used=c.last_used)
CredentialsResponse(
access_key=c.key, last_used=c.last_used, label=c.label
)
for c in user.credentials
]
)

View File

@ -361,6 +361,7 @@ def get_scalar_metric_data(call, company_id, _):
task_id = call.data["task"]
metric = call.data["metric"]
scroll_id = call.data.get("scroll_id")
no_scroll = call.data.get("no_scroll", False)
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company", "company_origin")
@ -372,6 +373,7 @@ def get_scalar_metric_data(call, company_id, _):
sort=[{"iter": {"order": "desc"}}],
metric=metric,
scroll_id=scroll_id,
no_scroll=no_scroll,
)
call.result.data = dict(
@ -494,6 +496,7 @@ def get_multi_task_plots(call, company_id, req_model):
task_ids = call.data["tasks"]
iters = call.data.get("iters", 1)
scroll_id = call.data.get("scroll_id")
no_scroll = call.data.get("no_scroll", False)
tasks = task_bll.assert_exists(
company_id=call.identity.company,
@ -515,6 +518,7 @@ def get_multi_task_plots(call, company_id, req_model):
sort=[{"iter": {"order": "desc"}}],
last_iter_count=iters,
scroll_id=scroll_id,
no_scroll=no_scroll,
)
tasks = {t.id: t.name for t in tasks}
@ -593,6 +597,7 @@ def get_task_plots(call, company_id, request: TaskPlotsRequest):
sort=[{"iter": {"order": "desc"}}],
last_iterations_per_plot=iters,
scroll_id=scroll_id,
no_scroll=request.no_scroll,
metric_variants=_get_metric_variants_from_request(request.metrics),
)
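A sketch of a one-time plots query that skips scroll creation, assuming a local apiserver and existing credentials for HTTP basic auth (host, keys and task id are placeholders):

import requests

APISERVER = "http://localhost:8008"
AUTH = ("ACCESS_KEY", "SECRET_KEY")  # assumed existing credentials

resp = requests.post(
    f"{APISERVER}/events.get_task_plots",
    json={"task": "TASK_ID", "iters": 1, "no_scroll": True},
    auth=AUTH,
)
resp.raise_for_status()
data = resp.json()["data"]  # no scroll_id is kept server-side for this call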

View File

@ -7,7 +7,7 @@ from apiserver.apierrors import errors
from apiserver.apierrors.errors.bad_request import InvalidProjectId
from apiserver.apimodels.base import UpdateResponse, MakePublicRequest, IdResponse
from apiserver.apimodels.projects import (
GetHyperParamRequest,
GetParamsRequest,
ProjectTagsRequest,
ProjectTaskParentsRequest,
ProjectHyperparamValuesRequest,
@ -19,12 +19,11 @@ from apiserver.apimodels.projects import (
ProjectRequest,
)
from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL
from apiserver.bll.project import ProjectBLL, ProjectQueries
from apiserver.bll.project.project_cleanup import (
delete_project,
validate_project_delete,
)
from apiserver.bll.task import TaskBLL
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.project import Project
from apiserver.database.utils import (
@ -41,8 +40,8 @@ from apiserver.services.utils import (
from apiserver.timing_context import TimingContext
org_bll = OrgBLL()
task_bll = TaskBLL()
project_bll = ProjectBLL()
project_queries = ProjectQueries()
create_fields = {
"name": None,
@ -267,7 +266,7 @@ def get_unique_metric_variants(
call: APICall, company_id: str, request: ProjectOrNoneRequest
):
metrics = task_bll.get_unique_metric_variants(
metrics = project_queries.get_unique_metric_variants(
company_id,
[request.project] if request.project else None,
include_subprojects=request.include_subprojects,
@ -276,14 +275,31 @@ def get_unique_metric_variants(
call.result.data = {"metrics": metrics}
@endpoint("projects.get_model_metadata_keys",)
def get_model_metadata_keys(call: APICall, company_id: str, request: GetParamsRequest):
total, remaining, keys = project_queries.get_model_metadata_keys(
company_id,
project_ids=[request.project] if request.project else None,
include_subprojects=request.include_subprojects,
page=request.page,
page_size=request.page_size,
)
call.result.data = {
"total": total,
"remaining": remaining,
"keys": keys,
}
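A sketch of calling the new endpoint and paging until remaining drops to zero, assuming a local apiserver and basic-auth credentials (all values are placeholders):

import requests

APISERVER = "http://localhost:8008"
AUTH = ("ACCESS_KEY", "SECRET_KEY")  # assumed existing credentials

keys, page = [], 0
while True:
    resp = requests.post(
        f"{APISERVER}/projects.get_model_metadata_keys",
        json={"project": "PROJECT_ID", "page": page, "page_size": 500},
        auth=AUTH,
    )
    resp.raise_for_status()
    data = resp.json()["data"]
    keys.extend(data["keys"])
    if not data["remaining"]:
        break
    page += 1
print(keys)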
@endpoint(
"projects.get_hyper_parameters",
min_version="2.9",
request_data_model=GetHyperParamRequest,
request_data_model=GetParamsRequest,
)
def get_hyper_parameters(call: APICall, company_id: str, request: GetHyperParamRequest):
def get_hyper_parameters(call: APICall, company_id: str, request: GetParamsRequest):
total, remaining, parameters = TaskBLL.get_aggregated_project_parameters(
total, remaining, parameters = project_queries.get_aggregated_project_parameters(
company_id,
project_ids=[request.project] if request.project else None,
include_subprojects=request.include_subprojects,
@ -306,7 +322,7 @@ def get_hyper_parameters(call: APICall, company_id: str, request: GetHyperParamR
def get_hyperparam_values(
call: APICall, company_id: str, request: ProjectHyperparamValuesRequest
):
total, values = task_bll.get_hyperparam_distinct_values(
total, values = project_queries.get_hyperparam_distinct_values(
company_id,
project_ids=request.projects,
section=request.section,

View File

@ -6,9 +6,6 @@ from apiserver.tests.automated import TestService
class TestQueueAndModelMetadata(TestService):
def setUp(self, version="2.13"):
super().setUp(version=version)
meta1 = [{"key": "test_key", "type": "str", "value": "test_value"}]
def test_queue_metas(self):
@ -29,6 +26,23 @@ class TestQueueAndModelMetadata(TestService):
self.api.models.edit(model=model_id, metadata=[self.meta1[0]])
self._assertMeta(service=service, entity=entity, _id=model_id, meta=self.meta1)
def test_project_meta_query(self):
self._temp_model("TestMetadata", metadata=self.meta1)
project = self.temp_project(name="MetaParent")
self._temp_model(
"TestMetadata2",
project=project,
metadata=[
{"key": "test_key", "type": "str", "value": "test_value"},
{"key": "test_key2", "type": "str", "value": "test_value"},
],
)
res = self.api.projects.get_model_metadata_keys()
self.assertTrue({"test_key", "test_key2"}.issubset(set(res["keys"])))
res = self.api.projects.get_model_metadata_keys(include_subprojects=False)
self.assertTrue("test_key" in res["keys"])
self.assertFalse("test_key2" in res["keys"])
def _test_meta_operations(
self, service: APIClient.Service, entity: str, _id: str,
):
@ -72,3 +86,12 @@ class TestQueueAndModelMetadata(TestService):
return self.create_temp(
"models", uri="file://test", name=name, labels={}, **kwargs
)
def temp_project(self, **kwargs) -> str:
self.update_missing(
kwargs,
name="Test models meta",
description="test",
delete_params=dict(force=True),
)
return self.create_temp("projects", **kwargs)

View File

@ -12,9 +12,6 @@ from apiserver.tests.automated import TestService
class TestSubProjects(TestService):
def setUp(self, **kwargs):
super().setUp(version="2.13")
def test_project_aggregations(self):
"""This test requires user with user_auth_only... credentials in db"""
user2_client = APIClient(
@ -203,6 +200,9 @@ class TestSubProjects(TestService):
self.assertEqual(res1.stats["active"]["status_count"]["created"], 0)
self.assertEqual(res1.stats["active"]["status_count"]["stopped"], 2)
self.assertEqual(res1.stats["active"]["total_runtime"], 2)
self.assertEqual(res1.stats["active"]["completed_tasks"], 2)
self.assertEqual(res1.stats["active"]["total_tasks"], 2)
self.assertEqual(res1.stats["active"]["running_tasks"], 0)
self.assertEqual(
{sp.name for sp in res1.sub_projects},
{
@ -215,6 +215,9 @@ class TestSubProjects(TestService):
self.assertEqual(res2.stats["active"]["status_count"]["created"], 0)
self.assertEqual(res2.stats["active"]["status_count"]["stopped"], 0)
self.assertEqual(res2.stats["active"]["total_runtime"], 0)
self.assertEqual(res2.stats["active"]["completed_tasks"], 0)
self.assertEqual(res2.stats["active"]["total_tasks"], 0)
self.assertEqual(res2.stats["active"]["running_tasks"], 0)
self.assertEqual(res2.sub_projects, [])
def _run_tasks(self, *tasks):

View File

@ -198,6 +198,9 @@ class TestTags(TestService):
def assertProjectStats(self, project: AttrDict):
self.assertEqual(set(project.stats.keys()), {"active"})
self.assertAlmostEqual(project.stats.active.total_runtime, 1, places=0)
self.assertEqual(project.stats.active.completed_tasks, 1)
self.assertEqual(project.stats.active.total_tasks, 1)
self.assertEqual(project.stats.active.running_tasks, 0)
for status, count in project.stats.active.status_count.items():
self.assertEqual(count, 1 if status == "stopped" else 0)