mirror of
https://github.com/clearml/clearml-server
synced 2025-06-26 23:15:47 +00:00
Compare commits
13 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c5d6ce3e65 | ||
|
|
694dbc31c4 | ||
|
|
6488dc54e6 | ||
|
|
158da9b480 | ||
|
|
ec2e071ab7 | ||
|
|
465e270342 | ||
|
|
6705aff56f | ||
|
|
9069cfe1da | ||
|
|
677bb3ba6d | ||
|
|
cb253cff9e | ||
|
|
39ceb5ac5c | ||
|
|
d4edeaaf1b | ||
|
|
56aea1ffb8 |
@@ -14,12 +14,18 @@ from apiserver.utilities.stringenum import StringEnum
|
||||
|
||||
|
||||
class HistogramRequestBase(Base):
|
||||
samples: int = IntField(default=6000, validators=[Min(1), Max(6000)])
|
||||
samples: int = IntField(default=2000, validators=[Min(1), Max(6000)])
|
||||
key: ScalarKeyEnum = ActualEnumField(ScalarKeyEnum, default=ScalarKeyEnum.iter)
|
||||
|
||||
|
||||
class MetricVariants(Base):
|
||||
metric: str = StringField(required=True)
|
||||
variants: Sequence[str] = ListField(items_types=str)
|
||||
|
||||
|
||||
class ScalarMetricsIterHistogramRequest(HistogramRequestBase):
|
||||
task: str = StringField(required=True)
|
||||
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
|
||||
|
||||
|
||||
class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
|
||||
@@ -39,6 +45,7 @@ class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
|
||||
class TaskMetric(Base):
|
||||
task: str = StringField(required=True)
|
||||
metric: str = StringField(default=None)
|
||||
variants: Sequence[str] = ListField(items_types=str)
|
||||
|
||||
|
||||
class DebugImagesRequest(Base):
|
||||
@@ -59,8 +66,8 @@ class TaskMetricVariant(Base):
|
||||
|
||||
class GetDebugImageSampleRequest(TaskMetricVariant):
|
||||
iteration: Optional[int] = IntField()
|
||||
scroll_id: Optional[str] = StringField()
|
||||
refresh: bool = BoolField(default=False)
|
||||
scroll_id: Optional[str] = StringField()
|
||||
|
||||
|
||||
class NextDebugImageSampleRequest(Base):
|
||||
@@ -102,3 +109,10 @@ class TaskMetricsRequest(Base):
|
||||
items_types=str, validators=[Length(minimum_value=1)]
|
||||
)
|
||||
event_type: EventType = ActualEnumField(EventType, required=True)
|
||||
|
||||
|
||||
class TaskPlotsRequest(Base):
|
||||
task: str = StringField(required=True)
|
||||
iters: int = IntField(default=1)
|
||||
scroll_id: str = StringField()
|
||||
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
|
||||
|
||||
@@ -57,6 +57,7 @@ class AuthBLL:
|
||||
api_version=str(ServiceRepo.max_endpoint_version()),
|
||||
server_version=str(get_version()),
|
||||
server_build=str(get_build_number()),
|
||||
feature_set="basic",
|
||||
)
|
||||
|
||||
return GetTokenResponse(token=token.decode("ascii"))
|
||||
|
||||
@@ -2,7 +2,7 @@ from concurrent.futures.thread import ThreadPoolExecutor
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
from operator import itemgetter
|
||||
from typing import Sequence, Tuple, Optional, Mapping, Set
|
||||
from typing import Sequence, Tuple, Optional, Mapping
|
||||
|
||||
import attr
|
||||
import dpath
|
||||
@@ -18,6 +18,7 @@ from apiserver.bll.event.event_common import (
|
||||
check_empty_data,
|
||||
search_company_events,
|
||||
EventType,
|
||||
get_metric_variants_condition,
|
||||
)
|
||||
from apiserver.bll.redis_cache_manager import RedisCacheManager
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
@@ -74,7 +75,7 @@ class DebugImagesIterator:
|
||||
def get_task_events(
|
||||
self,
|
||||
company_id: str,
|
||||
task_metrics: Mapping[str, Set[str]],
|
||||
task_metrics: Mapping[str, dict],
|
||||
iter_count: int,
|
||||
navigate_earlier: bool = True,
|
||||
refresh: bool = False,
|
||||
@@ -118,7 +119,7 @@ class DebugImagesIterator:
|
||||
self,
|
||||
company_id,
|
||||
state: DebugImageEventsScrollState,
|
||||
task_metrics: Mapping[str, Set[str]],
|
||||
task_metrics: Mapping[str, dict],
|
||||
):
|
||||
"""
|
||||
Determine the metrics for which new debug image events were added
|
||||
@@ -158,11 +159,11 @@ class DebugImagesIterator:
|
||||
task_metrics_to_recalc = {}
|
||||
for task, metrics_times in update_times.items():
|
||||
old_metric_states = task_metric_states[task]
|
||||
metrics_to_recalc = set(
|
||||
m
|
||||
metrics_to_recalc = {
|
||||
m: task_metrics[task].get(m)
|
||||
for m, t in metrics_times.items()
|
||||
if m not in old_metric_states or old_metric_states[m].timestamp < t
|
||||
)
|
||||
}
|
||||
if metrics_to_recalc:
|
||||
task_metrics_to_recalc[task] = metrics_to_recalc
|
||||
|
||||
@@ -196,7 +197,7 @@ class DebugImagesIterator:
|
||||
]
|
||||
|
||||
def _init_task_states(
|
||||
self, company_id: str, task_metrics: Mapping[str, Set[str]]
|
||||
self, company_id: str, task_metrics: Mapping[str, dict]
|
||||
) -> Sequence[TaskScrollState]:
|
||||
"""
|
||||
Returned initialized metric scroll stated for the requested task metrics
|
||||
@@ -213,7 +214,7 @@ class DebugImagesIterator:
|
||||
]
|
||||
|
||||
def _init_metric_states_for_task(
|
||||
self, task_metrics: Tuple[str, Set[str]], company_id: str
|
||||
self, task_metrics: Tuple[str, dict], company_id: str
|
||||
) -> Sequence[MetricState]:
|
||||
"""
|
||||
Return metric scroll states for the task filled with the variant states
|
||||
@@ -222,10 +223,11 @@ class DebugImagesIterator:
|
||||
task, metrics = task_metrics
|
||||
must = [{"term": {"task": task}}, {"exists": {"field": "url"}}]
|
||||
if metrics:
|
||||
must.append({"terms": {"metric": list(metrics)}})
|
||||
must.append(get_metric_variants_condition(metrics))
|
||||
query = {"bool": {"must": must}}
|
||||
es_req: dict = {
|
||||
"size": 0,
|
||||
"query": {"bool": {"must": must}},
|
||||
"query": query,
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"terms": {
|
||||
|
||||
@@ -6,9 +6,8 @@ from collections import defaultdict
|
||||
from contextlib import closing
|
||||
from datetime import datetime
|
||||
from operator import attrgetter
|
||||
from typing import Sequence, Set, Tuple, Optional, Dict
|
||||
from typing import Sequence, Set, Tuple, Optional, List, Mapping, Union
|
||||
|
||||
import six
|
||||
from elasticsearch import helpers
|
||||
from elasticsearch.helpers import BulkIndexError
|
||||
from mongoengine import Q
|
||||
@@ -22,6 +21,8 @@ from apiserver.bll.event.event_common import (
|
||||
check_empty_data,
|
||||
search_company_events,
|
||||
delete_company_events,
|
||||
MetricVariants,
|
||||
get_metric_variants_condition,
|
||||
)
|
||||
from apiserver.bll.util import parallel_chunked_decorator
|
||||
from apiserver.database import utils as dbutils
|
||||
@@ -43,8 +44,8 @@ from apiserver.utilities.json import loads
|
||||
# noinspection PyTypeChecker
|
||||
EVENT_TYPES: Set[str] = set(map(attrgetter("value"), EventType))
|
||||
LOCKED_TASK_STATUSES = (TaskStatus.publishing, TaskStatus.published)
|
||||
MAX_LONG = 2**63 - 1
|
||||
MIN_LONG = -2**63
|
||||
MAX_LONG = 2 ** 63 - 1
|
||||
MIN_LONG = -(2 ** 63)
|
||||
|
||||
|
||||
class PlotFields:
|
||||
@@ -94,7 +95,7 @@ class EventBLL(object):
|
||||
def add_events(
|
||||
self, company_id, events, worker, allow_locked_tasks=False
|
||||
) -> Tuple[int, int, dict]:
|
||||
actions = []
|
||||
actions: List[dict] = []
|
||||
task_ids = set()
|
||||
task_iteration = defaultdict(lambda: 0)
|
||||
task_last_scalar_events = nested_dict(
|
||||
@@ -197,7 +198,6 @@ class EventBLL(object):
|
||||
|
||||
actions.append(es_action)
|
||||
|
||||
action: Dict[dict]
|
||||
plot_actions = [
|
||||
action["_source"]
|
||||
for action in actions
|
||||
@@ -260,7 +260,8 @@ class EventBLL(object):
|
||||
invalid_iterations_count = errors_per_type.get(invalid_iteration_error)
|
||||
if invalid_iterations_count:
|
||||
raise BulkIndexError(
|
||||
f"{invalid_iterations_count} document(s) failed to index.", [invalid_iteration_error]
|
||||
f"{invalid_iterations_count} document(s) failed to index.",
|
||||
[invalid_iteration_error],
|
||||
)
|
||||
|
||||
if not added:
|
||||
@@ -466,10 +467,16 @@ class EventBLL(object):
|
||||
task_id: str,
|
||||
num_last_iterations: int,
|
||||
event_type: EventType,
|
||||
metric_variants: MetricVariants = None,
|
||||
):
|
||||
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
|
||||
return []
|
||||
|
||||
must = [{"term": {"task": task_id}}]
|
||||
if metric_variants:
|
||||
must.append(get_metric_variants_condition(metric_variants))
|
||||
query = {"bool": {"must": must}}
|
||||
|
||||
es_req: dict = {
|
||||
"size": 0,
|
||||
"aggs": {
|
||||
@@ -499,7 +506,7 @@ class EventBLL(object):
|
||||
},
|
||||
}
|
||||
},
|
||||
"query": {"bool": {"must": [{"term": {"task": task_id}}]}},
|
||||
"query": query,
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext(
|
||||
@@ -527,6 +534,7 @@ class EventBLL(object):
|
||||
sort=None,
|
||||
size: int = 500,
|
||||
scroll_id: str = None,
|
||||
metric_variants: MetricVariants = None,
|
||||
):
|
||||
if scroll_id == self.empty_scroll:
|
||||
return TaskEventsResult()
|
||||
@@ -555,6 +563,8 @@ class EventBLL(object):
|
||||
|
||||
if last_iterations_per_plot is None:
|
||||
must.append({"terms": {"task": tasks}})
|
||||
if metric_variants:
|
||||
must.append(get_metric_variants_condition(metric_variants))
|
||||
else:
|
||||
should = []
|
||||
for i, task_id in enumerate(tasks):
|
||||
@@ -563,6 +573,7 @@ class EventBLL(object):
|
||||
task_id=task_id,
|
||||
num_last_iterations=last_iterations_per_plot,
|
||||
event_type=event_type,
|
||||
metric_variants=metric_variants,
|
||||
)
|
||||
if not last_iters:
|
||||
continue
|
||||
@@ -669,19 +680,19 @@ class EventBLL(object):
|
||||
sort=None,
|
||||
size=500,
|
||||
scroll_id=None,
|
||||
):
|
||||
) -> TaskEventsResult:
|
||||
if scroll_id == self.empty_scroll:
|
||||
return [], scroll_id, 0
|
||||
return TaskEventsResult()
|
||||
|
||||
if scroll_id:
|
||||
with translate_errors_context(), TimingContext("es", "get_task_events"):
|
||||
es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
|
||||
else:
|
||||
task_ids = [task_id] if isinstance(task_id, six.string_types) else task_id
|
||||
|
||||
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
|
||||
return TaskEventsResult()
|
||||
|
||||
task_ids = [task_id] if isinstance(task_id, str) else task_id
|
||||
|
||||
must = []
|
||||
if metric:
|
||||
must.append({"term": {"metric": metric}})
|
||||
@@ -691,26 +702,24 @@ class EventBLL(object):
|
||||
if last_iter_count is None:
|
||||
must.append({"terms": {"task": task_ids}})
|
||||
else:
|
||||
should = []
|
||||
for i, task_id in enumerate(task_ids):
|
||||
last_iters = self.get_last_iters(
|
||||
company_id=company_id,
|
||||
event_type=event_type,
|
||||
task_id=task_id,
|
||||
iters=last_iter_count,
|
||||
)
|
||||
if not last_iters:
|
||||
continue
|
||||
should.append(
|
||||
{
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"task": task_id}},
|
||||
{"terms": {"iter": last_iters}},
|
||||
]
|
||||
}
|
||||
tasks_iters = self.get_last_iters(
|
||||
company_id=company_id,
|
||||
event_type=event_type,
|
||||
task_id=task_ids,
|
||||
iters=last_iter_count,
|
||||
)
|
||||
should = [
|
||||
{
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"task": task}},
|
||||
{"terms": {"iter": last_iters}},
|
||||
]
|
||||
}
|
||||
)
|
||||
}
|
||||
for task, last_iters in tasks_iters.items()
|
||||
if last_iters
|
||||
]
|
||||
if not should:
|
||||
return TaskEventsResult()
|
||||
must.append({"bool": {"should": should}})
|
||||
@@ -748,6 +757,7 @@ class EventBLL(object):
|
||||
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
|
||||
return {}
|
||||
|
||||
query = {"bool": {"must": [{"term": {"task": task_id}}]}}
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"aggs": {
|
||||
@@ -768,7 +778,7 @@ class EventBLL(object):
|
||||
},
|
||||
}
|
||||
},
|
||||
"query": {"bool": {"must": [{"term": {"task": task_id}}]}},
|
||||
"query": query,
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext(
|
||||
@@ -787,21 +797,24 @@ class EventBLL(object):
|
||||
|
||||
return metrics
|
||||
|
||||
def get_task_latest_scalar_values(self, company_id: str, task_id: str):
|
||||
def get_task_latest_scalar_values(
|
||||
self, company_id, task_id
|
||||
) -> Tuple[Sequence[dict], int]:
|
||||
event_type = EventType.metrics_scalar
|
||||
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
|
||||
return {}
|
||||
return [], 0
|
||||
|
||||
query = {
|
||||
"bool": {
|
||||
"must": [
|
||||
{"query_string": {"query": "value:>0"}},
|
||||
{"term": {"task": task_id}},
|
||||
]
|
||||
}
|
||||
}
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"query": {
|
||||
"bool": {
|
||||
"must": [
|
||||
{"query_string": {"query": "value:>0"}},
|
||||
{"term": {"task": task_id}},
|
||||
]
|
||||
}
|
||||
},
|
||||
"query": query,
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"terms": {
|
||||
@@ -905,34 +918,47 @@ class EventBLL(object):
|
||||
return iterations, vectors
|
||||
|
||||
def get_last_iters(
|
||||
self, company_id: str, event_type: EventType, task_id: str, iters: int
|
||||
):
|
||||
self,
|
||||
company_id: str,
|
||||
event_type: EventType,
|
||||
task_id: Union[str, Sequence[str]],
|
||||
iters: int,
|
||||
) -> Mapping[str, Sequence]:
|
||||
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
|
||||
return []
|
||||
return {}
|
||||
|
||||
task_ids = [task_id] if isinstance(task_id, str) else task_id
|
||||
es_req: dict = {
|
||||
"size": 0,
|
||||
"aggs": {
|
||||
"iters": {
|
||||
"terms": {
|
||||
"field": "iter",
|
||||
"size": iters,
|
||||
"order": {"_key": "desc"},
|
||||
}
|
||||
"tasks": {
|
||||
"terms": {"field": "task"},
|
||||
"aggs": {
|
||||
"iters": {
|
||||
"terms": {
|
||||
"field": "iter",
|
||||
"size": iters,
|
||||
"order": {"_key": "desc"},
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
"query": {"bool": {"must": [{"term": {"task": task_id}}]}},
|
||||
"query": {"bool": {"must": [{"terms": {"task": task_ids}}]}},
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "task_last_iter"):
|
||||
es_res = search_company_events(
|
||||
self.es, company_id=company_id, event_type=event_type, body=es_req
|
||||
self.es, company_id=company_id, event_type=event_type, body=es_req,
|
||||
)
|
||||
|
||||
if "aggregations" not in es_res:
|
||||
return []
|
||||
return {}
|
||||
|
||||
return [b["key"] for b in es_res["aggregations"]["iters"]["buckets"]]
|
||||
return {
|
||||
tb["key"]: [ib["key"] for ib in tb["iters"]["buckets"]]
|
||||
for tb in es_res["aggregations"]["tasks"]["buckets"]
|
||||
}
|
||||
|
||||
def delete_task_events(self, company_id, task_id, allow_locked=False):
|
||||
with translate_errors_context():
|
||||
@@ -965,7 +991,9 @@ class EventBLL(object):
|
||||
so it should be checked by the calling code
|
||||
"""
|
||||
es_req = {"query": {"terms": {"task": task_ids}}}
|
||||
with translate_errors_context(), TimingContext("es", "delete_multi_tasks_events"):
|
||||
with translate_errors_context(), TimingContext(
|
||||
"es", "delete_multi_tasks_events"
|
||||
):
|
||||
es_res = delete_company_events(
|
||||
es=self.es,
|
||||
company_id=company_id,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from enum import Enum
|
||||
from typing import Union, Sequence
|
||||
from typing import Union, Sequence, Mapping
|
||||
|
||||
from boltons.typeutils import classproperty
|
||||
from elasticsearch import Elasticsearch
|
||||
@@ -16,6 +16,9 @@ class EventType(Enum):
|
||||
all = "*"
|
||||
|
||||
|
||||
MetricVariants = Mapping[str, Sequence[str]]
|
||||
|
||||
|
||||
class EventSettings:
|
||||
@classproperty
|
||||
def max_workers(self):
|
||||
@@ -64,3 +67,23 @@ def delete_company_events(
|
||||
) -> dict:
|
||||
es_index = get_index_name(company_id, event_type.value)
|
||||
return es.delete_by_query(index=es_index, body=body, **kwargs)
|
||||
|
||||
|
||||
def get_metric_variants_condition(
|
||||
metric_variants: MetricVariants,
|
||||
) -> Sequence:
|
||||
conditions = [
|
||||
{
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"metric": metric}},
|
||||
{"terms": {"variant": variants}},
|
||||
]
|
||||
}
|
||||
}
|
||||
if variants
|
||||
else {"term": {"metric": metric}}
|
||||
for metric, variants in metric_variants.items()
|
||||
]
|
||||
|
||||
return {"bool": {"should": conditions}}
|
||||
|
||||
@@ -15,6 +15,8 @@ from apiserver.bll.event.event_common import (
|
||||
EventSettings,
|
||||
search_company_events,
|
||||
check_empty_data,
|
||||
MetricVariants,
|
||||
get_metric_variants_condition,
|
||||
)
|
||||
from apiserver.bll.event.scalar_key import ScalarKey, ScalarKeyEnum
|
||||
from apiserver.config_repo import config
|
||||
@@ -34,7 +36,12 @@ class EventMetrics:
|
||||
self.es = es
|
||||
|
||||
def get_scalar_metrics_average_per_iter(
|
||||
self, company_id: str, task_id: str, samples: int, key: ScalarKeyEnum
|
||||
self,
|
||||
company_id: str,
|
||||
task_id: str,
|
||||
samples: int,
|
||||
key: ScalarKeyEnum,
|
||||
metric_variants: MetricVariants = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Get scalar metric histogram per metric and variant
|
||||
@@ -46,7 +53,12 @@ class EventMetrics:
|
||||
return {}
|
||||
|
||||
return self._get_scalar_average_per_iter_core(
|
||||
task_id, company_id, event_type, samples, ScalarKey.resolve(key)
|
||||
task_id=task_id,
|
||||
company_id=company_id,
|
||||
event_type=event_type,
|
||||
samples=samples,
|
||||
key=ScalarKey.resolve(key),
|
||||
metric_variants=metric_variants,
|
||||
)
|
||||
|
||||
def _get_scalar_average_per_iter_core(
|
||||
@@ -57,6 +69,7 @@ class EventMetrics:
|
||||
samples: int,
|
||||
key: ScalarKey,
|
||||
run_parallel: bool = True,
|
||||
metric_variants: MetricVariants = None,
|
||||
) -> dict:
|
||||
intervals = self._get_task_metric_intervals(
|
||||
company_id=company_id,
|
||||
@@ -64,6 +77,7 @@ class EventMetrics:
|
||||
task_id=task_id,
|
||||
samples=samples,
|
||||
field=key.field,
|
||||
metric_variants=metric_variants,
|
||||
)
|
||||
if not intervals:
|
||||
return {}
|
||||
@@ -197,6 +211,7 @@ class EventMetrics:
|
||||
task_id: str,
|
||||
samples: int,
|
||||
field: str = "iter",
|
||||
metric_variants: MetricVariants = None,
|
||||
) -> Sequence[MetricInterval]:
|
||||
"""
|
||||
Calculate interval per task metric variant so that the resulting
|
||||
@@ -204,9 +219,14 @@ class EventMetrics:
|
||||
Return the list og metric variant intervals as the following tuple:
|
||||
(metric, variant, interval, samples)
|
||||
"""
|
||||
must = [{"term": {"task": task_id}}]
|
||||
if metric_variants:
|
||||
must.append(get_metric_variants_condition(metric_variants))
|
||||
query = {"bool": {"must": must}}
|
||||
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"query": {"term": {"task": task_id}},
|
||||
"query": query,
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"terms": {
|
||||
|
||||
@@ -554,7 +554,7 @@ class ProjectBLL:
|
||||
user_ids: Optional[Sequence[str]] = None,
|
||||
) -> Set[str]:
|
||||
"""
|
||||
Get the set of user ids that created tasks/models/dataviews in the given projects
|
||||
Get the set of user ids that created tasks/models in the given projects
|
||||
If project_ids is empty then all projects are examined
|
||||
If user_ids are passed then only subset of these users is returned
|
||||
"""
|
||||
@@ -676,8 +676,8 @@ class ProjectBLL:
|
||||
@classmethod
|
||||
def calc_own_contents(cls, company: str, project_ids: Sequence[str]) -> Dict[str, dict]:
|
||||
"""
|
||||
Returns the amount of task/dataviews/models per requested project
|
||||
Use separate aggregation calls on Task/Dataview/Model instead of lookup
|
||||
Returns the amount of task/models per requested project
|
||||
Use separate aggregation calls on Task/Model instead of lookup
|
||||
aggregation on projects in order not to hit memory limits on large tasks
|
||||
"""
|
||||
if not project_ids:
|
||||
|
||||
@@ -30,6 +30,28 @@ class DeleteProjectResult:
|
||||
urls: TaskUrls = None
|
||||
|
||||
|
||||
def validate_project_delete(company: str, project_id: str):
|
||||
project = Project.get_for_writing(
|
||||
company=company, id=project_id, _only=("id", "path")
|
||||
)
|
||||
if not project:
|
||||
raise errors.bad_request.InvalidProjectId(id=project_id)
|
||||
|
||||
project_ids = _ids_with_children([project_id])
|
||||
ret = {}
|
||||
for cls in (Task, Model):
|
||||
ret[f"{cls.__name__.lower()}s"] = cls.objects(
|
||||
project__in=project_ids,
|
||||
).count()
|
||||
for cls in (Task, Model):
|
||||
ret[f"non_archived_{cls.__name__.lower()}s"] = cls.objects(
|
||||
project__in=project_ids,
|
||||
system_tags__nin=[EntityVisibility.archived.value],
|
||||
).count()
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
def delete_project(
|
||||
company: str, project_id: str, force: bool, delete_contents: bool
|
||||
) -> Tuple[DeleteProjectResult, Set[str]]:
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
import itertools
|
||||
from typing import Sequence, Tuple
|
||||
from typing import Sequence, Tuple, Optional
|
||||
|
||||
import dpath
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.database.model.task.task import Task
|
||||
from apiserver.tools import safe_get
|
||||
from apiserver.utilities.dicts import nested_get, nested_delete, nested_set
|
||||
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
|
||||
|
||||
|
||||
@@ -14,7 +13,7 @@ hyperparams_legacy_type = "legacy"
|
||||
tf_define_section = "TF_DEFINE"
|
||||
|
||||
|
||||
def split_param_name(full_name: str, default_section: str) -> Tuple[str, str]:
|
||||
def split_param_name(full_name: str, default_section: str) -> Tuple[Optional[str], str]:
|
||||
"""
|
||||
Return parameter section and name. The section is either TF_DEFINE or the default one
|
||||
"""
|
||||
@@ -62,7 +61,7 @@ def _remove_legacy_params(data: dict, with_sections: bool = False) -> int:
|
||||
return removed
|
||||
|
||||
|
||||
def _get_legacy_params(data: dict, with_sections: bool = False) -> Sequence[str]:
|
||||
def _get_legacy_params(data: dict, with_sections: bool = False) -> Sequence[dict]:
|
||||
"""
|
||||
Remove the legacy params from the data dict and return the number of removed params
|
||||
If the path not found then return 0
|
||||
@@ -71,8 +70,10 @@ def _get_legacy_params(data: dict, with_sections: bool = False) -> Sequence[str]
|
||||
return []
|
||||
|
||||
if with_sections:
|
||||
return itertools.chain.from_iterable(
|
||||
_get_legacy_params(section_data) for section_data in data.values()
|
||||
return list(
|
||||
itertools.chain.from_iterable(
|
||||
_get_legacy_params(section_data) for section_data in data.values()
|
||||
)
|
||||
)
|
||||
|
||||
return [
|
||||
@@ -86,15 +87,15 @@ def params_prepare_for_save(fields: dict, previous_task: Task = None):
|
||||
Escape all the section and param names for hyper params and configuration to make it mongo sage
|
||||
"""
|
||||
for old_params_field, new_params_field, default_section in (
|
||||
("execution/parameters", "hyperparams", hyperparams_default_section),
|
||||
("execution/model_desc", "configuration", None),
|
||||
(("execution", "parameters"), "hyperparams", hyperparams_default_section),
|
||||
(("execution", "model_desc"), "configuration", None),
|
||||
):
|
||||
legacy_params = safe_get(fields, old_params_field)
|
||||
legacy_params = nested_get(fields, old_params_field)
|
||||
if legacy_params is None:
|
||||
continue
|
||||
|
||||
if (
|
||||
not safe_get(fields, new_params_field)
|
||||
not fields.get(new_params_field)
|
||||
and previous_task
|
||||
and previous_task[new_params_field]
|
||||
):
|
||||
@@ -117,11 +118,11 @@ def params_prepare_for_save(fields: dict, previous_task: Task = None):
|
||||
new_param = dict(name=name, type=hyperparams_legacy_type, value=str(value))
|
||||
if section is not None:
|
||||
new_param["section"] = section
|
||||
dpath.new(fields, new_path, new_param)
|
||||
dpath.delete(fields, old_params_field)
|
||||
nested_set(fields, new_path, new_param)
|
||||
nested_delete(fields, old_params_field)
|
||||
|
||||
for param_field in ("hyperparams", "configuration"):
|
||||
params = safe_get(fields, param_field)
|
||||
params = fields.get(param_field)
|
||||
if params:
|
||||
escaped_params = {
|
||||
ParameterKeyEscaper.escape(key): {
|
||||
@@ -131,7 +132,7 @@ def params_prepare_for_save(fields: dict, previous_task: Task = None):
|
||||
else value
|
||||
for key, value in params.items()
|
||||
}
|
||||
dpath.set(fields, param_field, escaped_params)
|
||||
fields[param_field] = escaped_params
|
||||
|
||||
|
||||
def params_unprepare_from_saved(fields, copy_to_legacy=False):
|
||||
@@ -140,7 +141,7 @@ def params_unprepare_from_saved(fields, copy_to_legacy=False):
|
||||
If copy_to_legacy is set then copy hyperparams and configuration data to the legacy location for the old clients
|
||||
"""
|
||||
for param_field in ("hyperparams", "configuration"):
|
||||
params = safe_get(fields, param_field)
|
||||
params = fields.get(param_field)
|
||||
if params:
|
||||
unescaped_params = {
|
||||
ParameterKeyEscaper.unescape(key): {
|
||||
@@ -150,18 +151,18 @@ def params_unprepare_from_saved(fields, copy_to_legacy=False):
|
||||
else value
|
||||
for key, value in params.items()
|
||||
}
|
||||
dpath.set(fields, param_field, unescaped_params)
|
||||
fields[param_field] = unescaped_params
|
||||
|
||||
if copy_to_legacy:
|
||||
for new_params_field, old_params_field, use_sections in (
|
||||
(f"hyperparams", "execution/parameters", True),
|
||||
(f"configuration", "execution/model_desc", False),
|
||||
("hyperparams", ("execution", "parameters"), True),
|
||||
("configuration", ("execution", "model_desc"), False),
|
||||
):
|
||||
legacy_params = _get_legacy_params(
|
||||
safe_get(fields, new_params_field), with_sections=use_sections
|
||||
fields.get(new_params_field), with_sections=use_sections
|
||||
)
|
||||
if legacy_params:
|
||||
dpath.new(
|
||||
nested_set(
|
||||
fields,
|
||||
old_params_field,
|
||||
{_get_full_param_name(p): p["value"] for p in legacy_params},
|
||||
@@ -174,7 +175,7 @@ def _process_path(path: str):
|
||||
Need to unescape and apply a full mongo escaping
|
||||
"""
|
||||
parts = path.split(".")
|
||||
if len(parts) < 2 or len(parts) > 3:
|
||||
if len(parts) < 2 or len(parts) > 4:
|
||||
raise errors.bad_request.ValidationError("invalid task field", path=path)
|
||||
return ".".join(
|
||||
ParameterKeyEscaper.escape(ParameterKeyEscaper.unescape(p)) for p in parts
|
||||
@@ -184,7 +185,7 @@ def _process_path(path: str):
|
||||
def escape_paths(paths: Sequence[str]) -> Sequence[str]:
|
||||
for old_prefix, new_prefix in (
|
||||
("execution.parameters", f"hyperparams.{hyperparams_default_section}"),
|
||||
("execution.model_desc", f"configuration"),
|
||||
("execution.model_desc", "configuration"),
|
||||
("execution.docker_cmd", "container")
|
||||
):
|
||||
path: str
|
||||
|
||||
@@ -130,14 +130,14 @@ def collect_debug_image_urls(company: str, task: str) -> Set[str]:
|
||||
if not metrics:
|
||||
return set()
|
||||
|
||||
task_metrics = {task: set(metrics)}
|
||||
task_metrics = {task: {m: [] for m in metrics}}
|
||||
scroll_id = None
|
||||
urls = set()
|
||||
while True:
|
||||
res = event_bll.debug_images_iterator.get_task_events(
|
||||
company_id=company,
|
||||
task_metrics=task_metrics,
|
||||
iter_count=100,
|
||||
iter_count=10,
|
||||
state_id=scroll_id,
|
||||
)
|
||||
if not res.metric_events or not any(
|
||||
|
||||
@@ -109,6 +109,7 @@ def enqueue_task(
|
||||
status_message: str,
|
||||
status_reason: str,
|
||||
validate: bool = False,
|
||||
force: bool = False,
|
||||
) -> Tuple[int, dict]:
|
||||
if not queue_id:
|
||||
# try to get default queue
|
||||
@@ -128,6 +129,7 @@ def enqueue_task(
|
||||
status_reason=status_reason,
|
||||
status_message=status_message,
|
||||
allow_same_state_transition=False,
|
||||
force=force,
|
||||
).execute(enqueue_status=task.status)
|
||||
|
||||
try:
|
||||
@@ -365,7 +367,21 @@ def stop_task(
|
||||
and (datetime.utcnow() - t.last_update).total_seconds() < update_timeout
|
||||
)
|
||||
|
||||
if TaskSystemTags.development in task.system_tags or not is_run_by_worker(task):
|
||||
is_queued = task.status == TaskStatus.queued
|
||||
set_stopped = (
|
||||
is_queued
|
||||
or TaskSystemTags.development in task.system_tags
|
||||
or not is_run_by_worker(task)
|
||||
)
|
||||
|
||||
if set_stopped:
|
||||
if is_queued:
|
||||
try:
|
||||
TaskBLL.dequeue(task, company_id=company_id, silent_fail=True)
|
||||
except APIError:
|
||||
# dequeue may fail if the task was not enqueued
|
||||
pass
|
||||
|
||||
new_status = TaskStatus.stopped
|
||||
status_message = f"Stopped by {user_name}"
|
||||
else:
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
debug: false # Debug mode
|
||||
pretty_json: false # prettify json response
|
||||
return_stack: true # return stack trace on error
|
||||
log_calls: true # Log API Calls
|
||||
return_stack_to_caller: true # top-level control on whether to return stack trace in an API response
|
||||
|
||||
# if 'return_stack' is true and error contains a status code, return stack trace only for these status codes
|
||||
# valid values are:
|
||||
|
||||
@@ -117,7 +117,7 @@ class GetMixin(PropsMixin):
|
||||
def __init__(self, legacy=False):
|
||||
self._legacy = legacy
|
||||
|
||||
def key(self, v):
|
||||
def key(self, v) -> Optional[str]:
|
||||
if v is None:
|
||||
self._next = self._default
|
||||
return self._default
|
||||
@@ -133,6 +133,7 @@ class GetMixin(PropsMixin):
|
||||
next_ = self._next
|
||||
if not self._sticky:
|
||||
self._next = self._default
|
||||
|
||||
return next_
|
||||
|
||||
def value_transform(self, v):
|
||||
@@ -273,10 +274,13 @@ class GetMixin(PropsMixin):
|
||||
).items():
|
||||
query &= cls.get_range_field_query(field, data)
|
||||
|
||||
for field in opts.fields or []:
|
||||
data = parameters.pop(field, None)
|
||||
if data is not None:
|
||||
dict_query[field] = data
|
||||
for field, data in cls._pop_matching_params(
|
||||
patterns=opts.fields or [], parameters=parameters
|
||||
).items():
|
||||
if "._" in field or "_." in field:
|
||||
query &= Q(__raw__={field: data})
|
||||
else:
|
||||
dict_query[field.replace(".", "__")] = data
|
||||
|
||||
for field in opts.datetime_fields or []:
|
||||
data = parameters.pop(field, None)
|
||||
|
||||
@@ -219,6 +219,7 @@ class Task(AttributedDocument):
|
||||
"status",
|
||||
"project",
|
||||
"parent",
|
||||
"hyperparams.*",
|
||||
),
|
||||
range_fields=("started", "active_duration", "last_metrics.*", "last_iteration"),
|
||||
datetime_fields=("status_changed", "last_update"),
|
||||
@@ -233,7 +234,7 @@ class Task(AttributedDocument):
|
||||
type = StringField(required=True, choices=get_options(TaskType))
|
||||
status = StringField(default=TaskStatus.created, choices=get_options(TaskStatus))
|
||||
status_reason = StringField()
|
||||
status_message = StringField()
|
||||
status_message = StringField(user_set_allowed=True)
|
||||
status_changed = DateTimeField()
|
||||
comment = StringField(user_set_allowed=True)
|
||||
created = DateTimeField(required=True, user_set_allowed=True)
|
||||
|
||||
@@ -298,8 +298,9 @@ class PrePopulate:
|
||||
if company_id is None:
|
||||
company_id = ""
|
||||
|
||||
# Always use a public user for pre-populated data
|
||||
cls.user_cls(id=user_id, name=user_name, company="").save()
|
||||
existing_user = cls.user_cls.objects(id=user_id).only("id").first()
|
||||
if not existing_user:
|
||||
cls.user_cls(id=user_id, name=user_name, company=company_id).save()
|
||||
|
||||
cls._import(zfile, company_id, user_id, metadata)
|
||||
|
||||
|
||||
@@ -1,6 +1,18 @@
|
||||
{
|
||||
_description : "Provides an API for running tasks to report events collected by the system."
|
||||
_definitions {
|
||||
metric_variants {
|
||||
type: object
|
||||
metric {
|
||||
description: The metric name
|
||||
type: string
|
||||
}
|
||||
variants {
|
||||
type: array
|
||||
description: The names of the metric variants
|
||||
items {type: string}
|
||||
}
|
||||
}
|
||||
metrics_scalar_event {
|
||||
description: "Used for reporting scalar metrics during training task"
|
||||
type: object
|
||||
@@ -193,6 +205,29 @@
|
||||
description: "Task ID"
|
||||
type: string
|
||||
}
|
||||
metric {
|
||||
description: "Metric name"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
task_metric_variants {
|
||||
type: object
|
||||
required: [task]
|
||||
properties {
|
||||
task {
|
||||
description: "Task ID"
|
||||
type: string
|
||||
}
|
||||
metric {
|
||||
description: "Metric name"
|
||||
type: string
|
||||
}
|
||||
variants {
|
||||
description: Metric variant names
|
||||
type: array
|
||||
items {type: string}
|
||||
}
|
||||
}
|
||||
}
|
||||
task_log_event {
|
||||
@@ -376,7 +411,7 @@
|
||||
metrics {
|
||||
type: array
|
||||
items { "$ref": "#/definitions/task_metric" }
|
||||
description: "List metrics for which the envents will be retreived"
|
||||
description: "List of task metrics for which the envents will be retreived"
|
||||
}
|
||||
iters {
|
||||
type: integer
|
||||
@@ -411,6 +446,17 @@
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.14": ${debug_images."2.7"} {
|
||||
request {
|
||||
properties {
|
||||
metrics {
|
||||
type: array
|
||||
description: List of metrics and variants
|
||||
items { "$ref": "#/definitions/task_metric_variants" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
get_debug_image_sample {
|
||||
"2.12": {
|
||||
@@ -804,6 +850,17 @@
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.14": ${get_task_plots."2.1"} {
|
||||
request {
|
||||
properties {
|
||||
metrics {
|
||||
type: array
|
||||
description: List of metrics and variants
|
||||
items { "$ref": "#/definitions/metric_variants" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
get_multi_task_plots {
|
||||
"2.1" {
|
||||
@@ -962,6 +1019,17 @@
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.14": ${scalar_metrics_iter_histogram."2.1"} {
|
||||
request {
|
||||
properties {
|
||||
metrics {
|
||||
type: array
|
||||
description: List of metrics and variants
|
||||
items { "$ref": "#/definitions/metric_variants" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
multi_task_scalar_metrics_iter_histogram {
|
||||
"2.1" {
|
||||
|
||||
@@ -379,7 +379,7 @@ get_all {
|
||||
items { type: string }
|
||||
}
|
||||
page {
|
||||
description: "Page number, returns a specific page out of the resulting list of dataviews"
|
||||
description: "Page number, returns a specific page out of the resulting list of projects"
|
||||
type: integer
|
||||
minimum: 0
|
||||
}
|
||||
@@ -469,7 +469,7 @@ get_all_ex {
|
||||
default: false
|
||||
}
|
||||
check_own_contents {
|
||||
description: "If set to 'true' and project ids are passed to the query then for these projects their own tasks, models and dataviews are counted"
|
||||
description: "If set to 'true' and project ids are passed to the query then for these projects their own tasks and models are counted"
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
@@ -594,7 +594,7 @@ merge {
|
||||
type: object
|
||||
properties {
|
||||
moved_entities {
|
||||
description: "The number of tasks, models and dataviews moved from the merged project into the destination"
|
||||
description: "The number of tasks and models moved from the merged project into the destination"
|
||||
type: integer
|
||||
}
|
||||
moved_projects {
|
||||
@@ -605,6 +605,42 @@ merge {
|
||||
}
|
||||
}
|
||||
}
|
||||
validate_delete {
|
||||
"2.14" {
|
||||
description: "Validates that the project existis and can be deleted"
|
||||
request {
|
||||
type: object
|
||||
required: [ project ]
|
||||
properties {
|
||||
project {
|
||||
description: "Project ID"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
tasks {
|
||||
description: "The total number of tasks under the project and all its children"
|
||||
type: integer
|
||||
}
|
||||
non_archived_tasks {
|
||||
description: "The total number of non-archived tasks under the project and all its children"
|
||||
type: integer
|
||||
}
|
||||
models {
|
||||
description: "The total number of models under the project and all its children"
|
||||
type: integer
|
||||
}
|
||||
non_archived_models {
|
||||
description: "The total number of non-archived models under the project and all its children"
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
delete {
|
||||
"2.1" {
|
||||
description: "Deletes a project"
|
||||
@@ -613,7 +649,7 @@ delete {
|
||||
required: [ project ]
|
||||
properties {
|
||||
project {
|
||||
description: "Project id"
|
||||
description: "Project ID"
|
||||
type: string
|
||||
}
|
||||
force {
|
||||
|
||||
@@ -588,13 +588,18 @@ class APICall(DataContainer):
|
||||
self._end_ts = time.time()
|
||||
self._duration = int((self._end_ts - self._start_ts) * 1000)
|
||||
|
||||
def get_response(self, include_stack: bool = False) -> Tuple[Union[dict, str], str]:
|
||||
def get_response(self, include_stack: bool = None) -> Tuple[Union[dict, str], str]:
|
||||
"""
|
||||
Get the response for this call.
|
||||
:param include_stack: If True, stack trace stored in this call's result should
|
||||
be included in the response (default is False)
|
||||
be included in the response (default follows configuration)
|
||||
:return: Response data (encoded according to self.content_type) and the data's content type
|
||||
"""
|
||||
include_stack = (
|
||||
include_stack
|
||||
if include_stack is not None
|
||||
else config.get("apiserver.return_stack_to_caller", False)
|
||||
)
|
||||
|
||||
def make_version_number(version: PartialVersion) -> Union[None, float, str]:
|
||||
"""
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
import random
|
||||
import string
|
||||
|
||||
sys_random = random.SystemRandom()
|
||||
|
||||
|
||||
def get_random_string(length=12, allowed_chars='abcdefghijklmnopqrstuvwxyz'
|
||||
'ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'):
|
||||
def get_random_string(
|
||||
length: int = 12, allowed_chars: str = string.ascii_letters + string.digits
|
||||
) -> str:
|
||||
"""
|
||||
Returns a securely generated random string.
|
||||
|
||||
@@ -12,20 +15,20 @@ def get_random_string(length=12, allowed_chars='abcdefghijklmnopqrstuvwxyz'
|
||||
|
||||
Taken from the django.utils.crypto module.
|
||||
"""
|
||||
return ''.join(sys_random.choice(allowed_chars) for _ in range(length))
|
||||
return "".join(sys_random.choice(allowed_chars) for _ in range(length))
|
||||
|
||||
|
||||
def get_client_id(length=20):
|
||||
def get_client_id(length: int = 20) -> str:
|
||||
"""
|
||||
Create a random secret key.
|
||||
|
||||
Taken from the Django project.
|
||||
"""
|
||||
chars = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'
|
||||
chars = string.ascii_uppercase + string.digits
|
||||
return get_random_string(length, chars)
|
||||
|
||||
|
||||
def get_secret_key(length=50):
|
||||
def get_secret_key(length: int = 50) -> str:
|
||||
"""
|
||||
Create a random secret key.
|
||||
|
||||
@@ -33,5 +36,5 @@ def get_secret_key(length=50):
|
||||
NOTE: asterisk is not supported due to issues with environment variables containing
|
||||
asterisks (in case the secret key is stored in an environment variable)
|
||||
"""
|
||||
chars = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!@#$%^&(-_=+)'
|
||||
chars = string.ascii_letters + string.digits
|
||||
return get_random_string(length, chars)
|
||||
|
||||
@@ -37,7 +37,7 @@ class ServiceRepo(object):
|
||||
"""If the check is set, parsing will fail for endpoint request with the version that is grater than the current
|
||||
maximum """
|
||||
|
||||
_max_version = PartialVersion("2.13")
|
||||
_max_version = PartialVersion("2.14")
|
||||
""" Maximum version number (the highest min_version value across all endpoints) """
|
||||
|
||||
_endpoint_exp = (
|
||||
|
||||
@@ -3,6 +3,7 @@ from collections import defaultdict
|
||||
from operator import itemgetter
|
||||
|
||||
import attr
|
||||
from typing import Sequence, Optional
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.apimodels.events import (
|
||||
@@ -17,9 +18,11 @@ from apiserver.apimodels.events import (
|
||||
LogOrderEnum,
|
||||
GetDebugImageSampleRequest,
|
||||
NextDebugImageSampleRequest,
|
||||
MetricVariants as ApiMetrics,
|
||||
TaskPlotsRequest,
|
||||
)
|
||||
from apiserver.bll.event import EventBLL
|
||||
from apiserver.bll.event.event_common import EventType
|
||||
from apiserver.bll.event.event_common import EventType, MetricVariants
|
||||
from apiserver.bll.task import TaskBLL
|
||||
from apiserver.service_repo import APICall, endpoint
|
||||
from apiserver.utilities import json
|
||||
@@ -321,7 +324,7 @@ def get_task_latest_scalar_values(call, company_id, _):
|
||||
)
|
||||
last_iters = event_bll.get_last_iters(
|
||||
company_id=company_id, event_type=EventType.all, task_id=task_id, iters=1
|
||||
)
|
||||
).get(task_id)
|
||||
call.result.data = dict(
|
||||
metrics=metrics,
|
||||
last_iter=last_iters[0] if last_iters else 0,
|
||||
@@ -494,11 +497,22 @@ def get_task_plots_v1_7(call, company_id, _):
|
||||
)
|
||||
|
||||
|
||||
@endpoint("events.get_task_plots", min_version="1.8", required_fields=["task"])
|
||||
def get_task_plots(call, company_id, _):
|
||||
task_id = call.data["task"]
|
||||
iters = call.data.get("iters", 1)
|
||||
scroll_id = call.data.get("scroll_id")
|
||||
def _get_metric_variants_from_request(
|
||||
req_metrics: Sequence[ApiMetrics],
|
||||
) -> Optional[MetricVariants]:
|
||||
if not req_metrics:
|
||||
return None
|
||||
|
||||
return {m.metric: m.variants for m in req_metrics}
|
||||
|
||||
|
||||
@endpoint(
|
||||
"events.get_task_plots", min_version="1.8", request_data_model=TaskPlotsRequest
|
||||
)
|
||||
def get_task_plots(call, company_id, request: TaskPlotsRequest):
|
||||
task_id = request.task
|
||||
iters = request.iters
|
||||
scroll_id = request.scroll_id
|
||||
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_id, allow_public=True, only=("company", "company_origin")
|
||||
@@ -509,6 +523,7 @@ def get_task_plots(call, company_id, _):
|
||||
sort=[{"iter": {"order": "desc"}}],
|
||||
last_iterations_per_plot=iters,
|
||||
scroll_id=scroll_id,
|
||||
metric_variants=_get_metric_variants_from_request(request.metrics),
|
||||
)
|
||||
|
||||
return_events = result.events
|
||||
@@ -594,9 +609,9 @@ def get_debug_images_v1_8(call, company_id, _):
|
||||
response_data_model=DebugImageResponse,
|
||||
)
|
||||
def get_debug_images(call, company_id, request: DebugImagesRequest):
|
||||
task_metrics = defaultdict(set)
|
||||
task_metrics = defaultdict(dict)
|
||||
for tm in request.metrics:
|
||||
task_metrics[tm.task].add(tm.metric)
|
||||
task_metrics[tm.task][tm.metric] = tm.variants
|
||||
for metrics in task_metrics.values():
|
||||
if None in metrics:
|
||||
metrics.clear()
|
||||
@@ -734,11 +749,11 @@ def _get_top_iter_unique_events_per_task(events, max_iters, tasks):
|
||||
|
||||
def _get_top_iter_unique_events(events, max_iters):
|
||||
top_unique_events = defaultdict(lambda: [])
|
||||
for e in events:
|
||||
key = e.get("metric", "") + e.get("variant", "")
|
||||
for ev in events:
|
||||
key = ev.get("metric", "") + ev.get("variant", "")
|
||||
evs = top_unique_events[key]
|
||||
if len(evs) < max_iters:
|
||||
evs.append(e)
|
||||
evs.append(ev)
|
||||
unique_events = list(
|
||||
itertools.chain.from_iterable(list(top_unique_events.values()))
|
||||
)
|
||||
|
||||
@@ -16,10 +16,14 @@ from apiserver.apimodels.projects import (
|
||||
MoveRequest,
|
||||
MergeRequest,
|
||||
ProjectOrNoneRequest,
|
||||
ProjectRequest,
|
||||
)
|
||||
from apiserver.bll.organization import OrgBLL, Tags
|
||||
from apiserver.bll.project import ProjectBLL
|
||||
from apiserver.bll.project.project_cleanup import delete_project
|
||||
from apiserver.bll.project.project_cleanup import (
|
||||
delete_project,
|
||||
validate_project_delete,
|
||||
)
|
||||
from apiserver.bll.task import TaskBLL
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model.project import Project
|
||||
@@ -230,6 +234,13 @@ def merge(call: APICall, company: str, request: MergeRequest):
|
||||
}
|
||||
|
||||
|
||||
@endpoint("projects.validate_delete")
|
||||
def validate_delete(call: APICall, company_id: str, request: ProjectRequest):
|
||||
call.result.data = validate_project_delete(
|
||||
company=company_id, project_id=request.project
|
||||
)
|
||||
|
||||
|
||||
@endpoint("projects.delete", request_data_model=DeleteRequest)
|
||||
def delete(call: APICall, company_id: str, request: DeleteRequest):
|
||||
res, affected_projects = delete_project(
|
||||
|
||||
@@ -4,7 +4,6 @@ from functools import partial
|
||||
from typing import Sequence, Union, Tuple
|
||||
|
||||
import attr
|
||||
import dpath
|
||||
from mongoengine import EmbeddedDocument, Q
|
||||
from mongoengine.queryset.transform import COMPARISON_OPERATORS
|
||||
from pymongo import UpdateOne
|
||||
@@ -220,14 +219,13 @@ def get_all_ex(call: APICall, company_id, _):
|
||||
|
||||
call_data = escape_execution_parameters(call)
|
||||
|
||||
with translate_errors_context():
|
||||
with TimingContext("mongo", "task_get_all_ex"):
|
||||
_process_include_subprojects(call_data)
|
||||
tasks = Task.get_many_with_join(
|
||||
company=company_id, query_dict=call_data, allow_public=True,
|
||||
)
|
||||
unprepare_from_saved(call, tasks)
|
||||
call.result.data = {"tasks": tasks}
|
||||
with TimingContext("mongo", "task_get_all_ex"):
|
||||
_process_include_subprojects(call_data)
|
||||
tasks = Task.get_many_with_join(
|
||||
company=company_id, query_dict=call_data, allow_public=True,
|
||||
)
|
||||
unprepare_from_saved(call, tasks)
|
||||
call.result.data = {"tasks": tasks}
|
||||
|
||||
|
||||
@endpoint("tasks.get_by_id_ex", required_fields=["id"])
|
||||
@@ -236,14 +234,13 @@ def get_by_id_ex(call: APICall, company_id, _):
|
||||
|
||||
call_data = escape_execution_parameters(call)
|
||||
|
||||
with translate_errors_context():
|
||||
with TimingContext("mongo", "task_get_by_id_ex"):
|
||||
tasks = Task.get_many_with_join(
|
||||
company=company_id, query_dict=call_data, allow_public=True,
|
||||
)
|
||||
with TimingContext("mongo", "task_get_by_id_ex"):
|
||||
tasks = Task.get_many_with_join(
|
||||
company=company_id, query_dict=call_data, allow_public=True,
|
||||
)
|
||||
|
||||
unprepare_from_saved(call, tasks)
|
||||
call.result.data = {"tasks": tasks}
|
||||
unprepare_from_saved(call, tasks)
|
||||
call.result.data = {"tasks": tasks}
|
||||
|
||||
|
||||
@endpoint("tasks.get_all", required_fields=[])
|
||||
@@ -252,16 +249,15 @@ def get_all(call: APICall, company_id, _):
|
||||
|
||||
call_data = escape_execution_parameters(call)
|
||||
|
||||
with translate_errors_context():
|
||||
with TimingContext("mongo", "task_get_all"):
|
||||
tasks = Task.get_many(
|
||||
company=company_id,
|
||||
parameters=call_data,
|
||||
query_dict=call_data,
|
||||
allow_public=True,
|
||||
)
|
||||
unprepare_from_saved(call, tasks)
|
||||
call.result.data = {"tasks": tasks}
|
||||
with TimingContext("mongo", "task_get_all"):
|
||||
tasks = Task.get_many(
|
||||
company=company_id,
|
||||
parameters=call_data,
|
||||
query_dict=call_data,
|
||||
allow_public=True,
|
||||
)
|
||||
unprepare_from_saved(call, tasks)
|
||||
call.result.data = {"tasks": tasks}
|
||||
|
||||
|
||||
@endpoint("tasks.get_types", request_data_model=GetTypesRequest)
|
||||
@@ -403,15 +399,12 @@ def prepare_for_save(call: APICall, fields: dict, previous_task: Task = None):
|
||||
escape_dict_field(fields, path)
|
||||
|
||||
# Strip all script fields (remove leading and trailing whitespace chars) to avoid unusable names and paths
|
||||
for field in task_script_stripped_fields:
|
||||
try:
|
||||
path = f"script/{field}"
|
||||
value = dpath.get(fields, path)
|
||||
script = fields.get("script")
|
||||
if script:
|
||||
for field in task_script_stripped_fields:
|
||||
value = script.get(field)
|
||||
if isinstance(value, str):
|
||||
value = value.strip()
|
||||
dpath.set(fields, path, value)
|
||||
except KeyError:
|
||||
pass
|
||||
script[field] = value.strip()
|
||||
|
||||
return fields
|
||||
|
||||
@@ -546,10 +539,12 @@ def clone_task(call: APICall, company_id, request: CloneRequest):
|
||||
}
|
||||
|
||||
|
||||
def prepare_update_fields(call: APICall, task, call_data):
|
||||
def prepare_update_fields(call: APICall, call_data):
|
||||
valid_fields = deepcopy(Task.user_set_allowed())
|
||||
update_fields = {k: v for k, v in create_fields.items() if k in valid_fields}
|
||||
update_fields["output__error"] = None
|
||||
update_fields.update(
|
||||
status=None, status_reason=None, status_message=None, output__error=None
|
||||
)
|
||||
t_fields = task_fields
|
||||
t_fields.add("output__error")
|
||||
fields = parse_from_call(call_data, update_fields, t_fields)
|
||||
@@ -569,7 +564,7 @@ def update(call: APICall, company_id, req_model: UpdateRequest):
|
||||
if not task:
|
||||
raise errors.bad_request.InvalidTaskId(id=task_id)
|
||||
|
||||
partial_update_dict, valid_fields = prepare_update_fields(call, task, call.data)
|
||||
partial_update_dict, valid_fields = prepare_update_fields(call, call.data)
|
||||
|
||||
if not partial_update_dict:
|
||||
return UpdateResponse(updated=0)
|
||||
@@ -642,7 +637,7 @@ def update_batch(call: APICall, company_id, _):
|
||||
updated_projects = set()
|
||||
for id, data in items.items():
|
||||
task = tasks[id]
|
||||
fields, valid_fields = prepare_update_fields(call, task, data)
|
||||
fields, valid_fields = prepare_update_fields(call, data)
|
||||
partial_update_dict = Task.get_safe_update_dict(fields)
|
||||
if not partial_update_dict:
|
||||
continue
|
||||
@@ -744,8 +739,7 @@ def edit(call: APICall, company_id, req_model: UpdateRequest):
|
||||
"tasks.get_hyper_params", request_data_model=GetHyperParamsRequest,
|
||||
)
|
||||
def get_hyper_params(call: APICall, company_id, request: GetHyperParamsRequest):
|
||||
with translate_errors_context():
|
||||
tasks_params = HyperParams.get_params(company_id, task_ids=request.tasks)
|
||||
tasks_params = HyperParams.get_params(company_id, task_ids=request.tasks)
|
||||
|
||||
call.result.data = {
|
||||
"params": [{"task": task, **data} for task, data in tasks_params.items()]
|
||||
@@ -754,39 +748,36 @@ def get_hyper_params(call: APICall, company_id, request: GetHyperParamsRequest):
|
||||
|
||||
@endpoint("tasks.edit_hyper_params", request_data_model=EditHyperParamsRequest)
|
||||
def edit_hyper_params(call: APICall, company_id, request: EditHyperParamsRequest):
|
||||
with translate_errors_context():
|
||||
call.result.data = {
|
||||
"updated": HyperParams.edit_params(
|
||||
company_id,
|
||||
task_id=request.task,
|
||||
hyperparams=request.hyperparams,
|
||||
replace_hyperparams=request.replace_hyperparams,
|
||||
force=request.force,
|
||||
)
|
||||
}
|
||||
call.result.data = {
|
||||
"updated": HyperParams.edit_params(
|
||||
company_id,
|
||||
task_id=request.task,
|
||||
hyperparams=request.hyperparams,
|
||||
replace_hyperparams=request.replace_hyperparams,
|
||||
force=request.force,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@endpoint("tasks.delete_hyper_params", request_data_model=DeleteHyperParamsRequest)
|
||||
def delete_hyper_params(call: APICall, company_id, request: DeleteHyperParamsRequest):
|
||||
with translate_errors_context():
|
||||
call.result.data = {
|
||||
"deleted": HyperParams.delete_params(
|
||||
company_id,
|
||||
task_id=request.task,
|
||||
hyperparams=request.hyperparams,
|
||||
force=request.force,
|
||||
)
|
||||
}
|
||||
call.result.data = {
|
||||
"deleted": HyperParams.delete_params(
|
||||
company_id,
|
||||
task_id=request.task,
|
||||
hyperparams=request.hyperparams,
|
||||
force=request.force,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@endpoint(
|
||||
"tasks.get_configurations", request_data_model=GetConfigurationsRequest,
|
||||
)
|
||||
def get_configurations(call: APICall, company_id, request: GetConfigurationsRequest):
|
||||
with translate_errors_context():
|
||||
tasks_params = HyperParams.get_configurations(
|
||||
company_id, task_ids=request.tasks, names=request.names
|
||||
)
|
||||
tasks_params = HyperParams.get_configurations(
|
||||
company_id, task_ids=request.tasks, names=request.names
|
||||
)
|
||||
|
||||
call.result.data = {
|
||||
"configurations": [
|
||||
@@ -801,10 +792,9 @@ def get_configurations(call: APICall, company_id, request: GetConfigurationsRequ
|
||||
def get_configuration_names(
|
||||
call: APICall, company_id, request: GetConfigurationNamesRequest
|
||||
):
|
||||
with translate_errors_context():
|
||||
tasks_params = HyperParams.get_configuration_names(
|
||||
company_id, task_ids=request.tasks, skip_empty=request.skip_empty
|
||||
)
|
||||
tasks_params = HyperParams.get_configuration_names(
|
||||
company_id, task_ids=request.tasks, skip_empty=request.skip_empty
|
||||
)
|
||||
|
||||
call.result.data = {
|
||||
"configurations": [
|
||||
@@ -815,31 +805,29 @@ def get_configuration_names(
|
||||
|
||||
@endpoint("tasks.edit_configuration", request_data_model=EditConfigurationRequest)
|
||||
def edit_configuration(call: APICall, company_id, request: EditConfigurationRequest):
|
||||
with translate_errors_context():
|
||||
call.result.data = {
|
||||
"updated": HyperParams.edit_configuration(
|
||||
company_id,
|
||||
task_id=request.task,
|
||||
configuration=request.configuration,
|
||||
replace_configuration=request.replace_configuration,
|
||||
force=request.force,
|
||||
)
|
||||
}
|
||||
call.result.data = {
|
||||
"updated": HyperParams.edit_configuration(
|
||||
company_id,
|
||||
task_id=request.task,
|
||||
configuration=request.configuration,
|
||||
replace_configuration=request.replace_configuration,
|
||||
force=request.force,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@endpoint("tasks.delete_configuration", request_data_model=DeleteConfigurationRequest)
|
||||
def delete_configuration(
|
||||
call: APICall, company_id, request: DeleteConfigurationRequest
|
||||
):
|
||||
with translate_errors_context():
|
||||
call.result.data = {
|
||||
"deleted": HyperParams.delete_configuration(
|
||||
company_id,
|
||||
task_id=request.task,
|
||||
configuration=request.configuration,
|
||||
force=request.force,
|
||||
)
|
||||
}
|
||||
call.result.data = {
|
||||
"deleted": HyperParams.delete_configuration(
|
||||
company_id,
|
||||
task_id=request.task,
|
||||
configuration=request.configuration,
|
||||
force=request.force,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@endpoint(
|
||||
@@ -854,6 +842,7 @@ def enqueue(call: APICall, company_id, request: EnqueueRequest):
|
||||
queue_id=request.queue,
|
||||
status_message=request.status_message,
|
||||
status_reason=request.status_reason,
|
||||
force=request.force,
|
||||
)
|
||||
call.result.data_model = EnqueueResponse(queued=queued, **res)
|
||||
|
||||
@@ -1169,15 +1158,14 @@ def ping(_, company_id, request: PingRequest):
|
||||
def add_or_update_artifacts(
|
||||
call: APICall, company_id, request: AddOrUpdateArtifactsRequest
|
||||
):
|
||||
with translate_errors_context():
|
||||
call.result.data = {
|
||||
"updated": Artifacts.add_or_update_artifacts(
|
||||
company_id=company_id,
|
||||
task_id=request.task,
|
||||
artifacts=request.artifacts,
|
||||
force=request.force,
|
||||
)
|
||||
}
|
||||
call.result.data = {
|
||||
"updated": Artifacts.add_or_update_artifacts(
|
||||
company_id=company_id,
|
||||
task_id=request.task,
|
||||
artifacts=request.artifacts,
|
||||
force=request.force,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@endpoint(
|
||||
@@ -1186,31 +1174,28 @@ def add_or_update_artifacts(
|
||||
request_data_model=DeleteArtifactsRequest,
|
||||
)
|
||||
def delete_artifacts(call: APICall, company_id, request: DeleteArtifactsRequest):
|
||||
with translate_errors_context():
|
||||
call.result.data = {
|
||||
"deleted": Artifacts.delete_artifacts(
|
||||
company_id=company_id,
|
||||
task_id=request.task,
|
||||
artifact_ids=request.artifacts,
|
||||
force=request.force,
|
||||
)
|
||||
}
|
||||
call.result.data = {
|
||||
"deleted": Artifacts.delete_artifacts(
|
||||
company_id=company_id,
|
||||
task_id=request.task,
|
||||
artifact_ids=request.artifacts,
|
||||
force=request.force,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@endpoint("tasks.make_public", min_version="2.9", request_data_model=MakePublicRequest)
|
||||
def make_public(call: APICall, company_id, request: MakePublicRequest):
|
||||
with translate_errors_context():
|
||||
call.result.data = Task.set_public(
|
||||
company_id, request.ids, invalid_cls=InvalidTaskId, enabled=True
|
||||
)
|
||||
call.result.data = Task.set_public(
|
||||
company_id, request.ids, invalid_cls=InvalidTaskId, enabled=True
|
||||
)
|
||||
|
||||
|
||||
@endpoint("tasks.make_private", min_version="2.9", request_data_model=MakePublicRequest)
|
||||
def make_public(call: APICall, company_id, request: MakePublicRequest):
|
||||
with translate_errors_context():
|
||||
call.result.data = Task.set_public(
|
||||
company_id, request.ids, invalid_cls=InvalidTaskId, enabled=False
|
||||
)
|
||||
call.result.data = Task.set_public(
|
||||
company_id, request.ids, invalid_cls=InvalidTaskId, enabled=False
|
||||
)
|
||||
|
||||
|
||||
@endpoint("tasks.move", request_data_model=MoveRequest)
|
||||
|
||||
54
apiserver/tests/automated/test_project_delete.py
Normal file
54
apiserver/tests/automated/test_project_delete.py
Normal file
@@ -0,0 +1,54 @@
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.database.model import EntityVisibility
|
||||
from apiserver.tests.automated import TestService
|
||||
from apiserver.database.utils import id as db_id
|
||||
|
||||
|
||||
class TestProjectsDelete(TestService):
|
||||
def setUp(self, version="2.14"):
|
||||
super().setUp(version=version)
|
||||
|
||||
def new_task(self, **kwargs):
|
||||
return self.create_temp(
|
||||
"tasks", type="testing", name=db_id(), input=dict(view=dict()), **kwargs
|
||||
)
|
||||
|
||||
def new_model(self, **kwargs):
|
||||
return self.create_temp("models", uri="file:///a/b", name=db_id(), labels={}, **kwargs)
|
||||
|
||||
def new_project(self, **kwargs):
|
||||
return self.create_temp("projects", name=db_id(), description="", **kwargs)
|
||||
|
||||
def test_delete_fails_with_active_task(self):
|
||||
project = self.new_project()
|
||||
self.new_task(project=project)
|
||||
res = self.api.projects.validate_delete(project=project)
|
||||
self.assertEqual(res.tasks, 1)
|
||||
self.assertEqual(res.non_archived_tasks, 1)
|
||||
with self.api.raises(errors.bad_request.ProjectHasTasks):
|
||||
self.api.projects.delete(project=project)
|
||||
|
||||
def test_delete_with_archived_task(self):
|
||||
project = self.new_project()
|
||||
self.new_task(project=project, system_tags=[EntityVisibility.archived.value])
|
||||
res = self.api.projects.validate_delete(project=project)
|
||||
self.assertEqual(res.tasks, 1)
|
||||
self.assertEqual(res.non_archived_tasks, 0)
|
||||
self.api.projects.delete(project=project)
|
||||
|
||||
def test_delete_fails_with_active_model(self):
|
||||
project = self.new_project()
|
||||
self.new_model(project=project)
|
||||
res = self.api.projects.validate_delete(project=project)
|
||||
self.assertEqual(res.models, 1)
|
||||
self.assertEqual(res.non_archived_models, 1)
|
||||
with self.api.raises(errors.bad_request.ProjectHasModels):
|
||||
self.api.projects.delete(project=project)
|
||||
|
||||
def test_delete_with_archived_model(self):
|
||||
project = self.new_project()
|
||||
self.new_model(project=project, system_tags=[EntityVisibility.archived.value])
|
||||
res = self.api.projects.validate_delete(project=project)
|
||||
self.assertEqual(res.models, 1)
|
||||
self.assertEqual(res.non_archived_models, 0)
|
||||
self.api.projects.delete(project=project)
|
||||
@@ -10,6 +10,7 @@ def extract_properties_to_lists(
|
||||
key_names: Sequence[str],
|
||||
data: Sequence[dict],
|
||||
extract_func: Optional[Callable[[dict], Tuple]] = None,
|
||||
target_keys: Optional[Sequence[str]] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Given a list of dictionaries and names of dictionary keys
|
||||
@@ -20,9 +21,10 @@ def extract_properties_to_lists(
|
||||
:param extract_func: the optional callable that extracts properties
|
||||
from a dictionary and put them in a tuple in the order corresponding to
|
||||
key_names. If not specified then properties are extracted according to key_names
|
||||
:param target_keys: optional alternative keys to use in the target dictionary. must be equal in length to key_names.
|
||||
"""
|
||||
if not data:
|
||||
return {k: [] for k in key_names}
|
||||
|
||||
value_sequences = zip(*map(extract_func or itemgetter(*key_names), data))
|
||||
return dict(zip(key_names, map(list, value_sequences)))
|
||||
return dict(zip((target_keys or key_names), map(list, value_sequences)))
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = "1.0.2"
|
||||
__version__ = "1.1.0"
|
||||
|
||||
Reference in New Issue
Block a user