Mirror of https://github.com/clearml/clearml-server (synced 2025-06-22 15:15:28 +00:00)
Move to ElasticSearch 7
Add initial support for project ordering
Add support for sortable task duration (used by the UI in the experiments table)
Add support for project name in the worker's current task info
Add support for results and artifacts in pre-populated examples
Add demo server features
This commit is contained in:
parent 77397c4f21
commit baba8b5b73
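Most of the hunks below adapt the server's Elasticsearch usage from the 6.x to the 7.x API: terms aggregations order buckets by `_key` instead of `_term`, date histograms use `fixed_interval` instead of `interval`, `hits.total` becomes an object whose count sits under `value`, per-task `routing` and `doc_type` arguments are dropped from search/delete calls, and index templates declare `index_patterns` instead of `template`. As a rough illustration of the first two response/request differences, here is a small standalone sketch; the helper functions are illustrative only and are not part of the commit (only the field names, such as `iter`, come from the diff):

```python
# Illustrative helpers mirroring two recurring ES 6.x -> 7.x changes in this commit.

def iter_buckets_agg(size: int, navigate_earlier: bool = True) -> dict:
    """Terms aggregation over the "iter" field; ES 7.x orders buckets with "_key"."""
    return {
        "iters": {
            "terms": {
                "field": "iter",
                "size": size,
                # ES 6.x accepted {"_term": ...}; ES 7.x only accepts {"_key": ...}
                "order": {"_key": "desc" if navigate_earlier else "asc"},
            }
        }
    }


def total_hits(es_res: dict) -> int:
    """ES 7.x returns hits.total as {"value": N, "relation": ...} rather than a number."""
    total = es_res.get("hits", {}).get("total", 0)
    return total["value"] if isinstance(total, dict) else total


if __name__ == "__main__":
    print(iter_buckets_agg(10))
    print(total_hits({"hits": {"total": {"value": 42, "relation": "eq"}}}))  # -> 42
```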
@ -1,7 +1,8 @@
|
||||
from jsonmodels import models, fields
|
||||
from jsonmodels.validators import Length
|
||||
from mongoengine.base import BaseDocument
|
||||
|
||||
from apimodels import DictField
|
||||
from apimodels import DictField, ListField
|
||||
|
||||
|
||||
class MongoengineFieldsDict(DictField):
|
||||
@ -12,14 +13,14 @@ class MongoengineFieldsDict(DictField):
|
||||
"""
|
||||
|
||||
mongoengine_update_operators = (
|
||||
'inc',
|
||||
'dec',
|
||||
'push',
|
||||
'push_all',
|
||||
'pop',
|
||||
'pull',
|
||||
'pull_all',
|
||||
'add_to_set',
|
||||
"inc",
|
||||
"dec",
|
||||
"push",
|
||||
"push_all",
|
||||
"pop",
|
||||
"pull",
|
||||
"pull_all",
|
||||
"add_to_set",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@ -30,16 +31,16 @@ class MongoengineFieldsDict(DictField):
|
||||
|
||||
@classmethod
|
||||
def _normalize_mongo_field_path(cls, path, value):
|
||||
parts = path.split('__')
|
||||
parts = path.split("__")
|
||||
if len(parts) > 1:
|
||||
if parts[0] == 'set':
|
||||
if parts[0] == "set":
|
||||
parts = parts[1:]
|
||||
elif parts[0] == 'unset':
|
||||
elif parts[0] == "unset":
|
||||
parts = parts[1:]
|
||||
value = None
|
||||
elif parts[0] in cls.mongoengine_update_operators:
|
||||
return None, None
|
||||
return '.'.join(parts), cls._normalize_mongo_value(value)
|
||||
return ".".join(parts), cls._normalize_mongo_value(value)
|
||||
|
||||
def parse_value(self, value):
|
||||
value = super(MongoengineFieldsDict, self).parse_value(value)
|
||||
@ -62,3 +63,7 @@ class PagedRequest(models.Base):
|
||||
|
||||
class IdResponse(models.Base):
|
||||
id = fields.StringField(required=True)
|
||||
|
||||
|
||||
class MakePublicRequest(models.Base):
|
||||
ids = ListField(items_types=str, validators=[Length(minimum_value=1)])
|
||||
|
@ -3,7 +3,7 @@ from typing import Sequence, Optional
|
||||
from jsonmodels import validators
|
||||
from jsonmodels.fields import StringField, BoolField
|
||||
from jsonmodels.models import Base
|
||||
from jsonmodels.validators import Length
|
||||
from jsonmodels.validators import Length, Min, Max
|
||||
|
||||
from apimodels import ListField, IntField, ActualEnumField
|
||||
from bll.event.event_metrics import EventType
|
||||
@ -11,7 +11,7 @@ from bll.event.scalar_key import ScalarKeyEnum
|
||||
|
||||
|
||||
class HistogramRequestBase(Base):
|
||||
samples: int = IntField(default=10000)
|
||||
samples: int = IntField(default=6000, validators=[Min(1), Max(6000)])
|
||||
key: ScalarKeyEnum = ActualEnumField(ScalarKeyEnum, default=ScalarKeyEnum.iter)
|
||||
|
||||
|
||||
@ -21,7 +21,7 @@ class ScalarMetricsIterHistogramRequest(HistogramRequestBase):
|
||||
|
||||
class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
|
||||
tasks: Sequence[str] = ListField(
|
||||
items_types=str, validators=[Length(minimum_value=1)]
|
||||
items_types=str, validators=[Length(minimum_value=1, maximum_value=10)]
|
||||
)
|
||||
|
||||
|
||||
|
@ -67,6 +67,7 @@ class WorkerEntry(Base, JsonSerializableMixin):
|
||||
company = EmbeddedField(IdNameEntry)
|
||||
ip = StringField()
|
||||
task = EmbeddedField(IdNameEntry)
|
||||
project = EmbeddedField(IdNameEntry)
|
||||
queue = StringField() # queue from which current task was taken
|
||||
queues = ListField(str) # list of queues this worker listens to
|
||||
register_time = DateTimeField(required=True)
|
||||
|
@ -208,7 +208,11 @@ class DebugImagesIterator:
|
||||
"size": 0,
|
||||
"query": {
|
||||
"bool": {
|
||||
"must": [{"term": {"task": task}}, {"terms": {"metric": metrics}}]
|
||||
"must": [
|
||||
{"term": {"task": task}},
|
||||
{"terms": {"metric": metrics}},
|
||||
{"exists": {"field": "url"}},
|
||||
]
|
||||
}
|
||||
},
|
||||
"aggs": {
|
||||
@ -251,7 +255,7 @@ class DebugImagesIterator:
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "_init_metric_states"):
|
||||
es_res = self.es.search(index=es_index, body=es_req, routing=task)
|
||||
es_res = self.es.search(index=es_index, body=es_req)
|
||||
if "aggregations" not in es_res:
|
||||
return []
|
||||
|
||||
@ -298,6 +302,7 @@ class DebugImagesIterator:
|
||||
must_conditions = [
|
||||
{"term": {"task": metric.task}},
|
||||
{"term": {"metric": metric.name}},
|
||||
{"exists": {"field": "url"}},
|
||||
]
|
||||
must_not_conditions = []
|
||||
|
||||
@ -368,7 +373,7 @@ class DebugImagesIterator:
|
||||
"terms": {
|
||||
"field": "iter",
|
||||
"size": iter_count,
|
||||
"order": {"_term": "desc" if navigate_earlier else "asc"},
|
||||
"order": {"_key": "desc" if navigate_earlier else "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
@ -387,7 +392,7 @@ class DebugImagesIterator:
|
||||
},
|
||||
}
|
||||
with translate_errors_context(), TimingContext("es", "get_debug_image_events"):
|
||||
es_res = self.es.search(index=es_index, body=es_req, routing=metric.task)
|
||||
es_res = self.es.search(index=es_index, body=es_req)
|
||||
if "aggregations" not in es_res:
|
||||
return metric.task, metric.name, []
|
||||
|
||||
|
@ -3,7 +3,7 @@ from collections import defaultdict
|
||||
from contextlib import closing
|
||||
from datetime import datetime
|
||||
from operator import attrgetter
|
||||
from typing import Sequence, Set, Tuple
|
||||
from typing import Sequence, Set, Tuple, Optional
|
||||
|
||||
import six
|
||||
from elasticsearch import helpers
|
||||
@ -22,6 +22,7 @@ from database.errors import translate_errors_context
|
||||
from database.model.task.task import Task, TaskStatus
|
||||
from redis_manager import redman
|
||||
from timing_context import TimingContext
|
||||
from tools import safe_get
|
||||
from utilities.dicts import flatten_nested_items
|
||||
|
||||
# noinspection PyTypeChecker
|
||||
@ -134,7 +135,6 @@ class EventBLL(object):
|
||||
es_action = {
|
||||
"_op_type": "index", # overwrite if exists with same ID
|
||||
"_index": index_name,
|
||||
"_type": "event",
|
||||
"_source": event,
|
||||
}
|
||||
|
||||
@ -144,7 +144,6 @@ class EventBLL(object):
|
||||
else:
|
||||
es_action["_id"] = dbutils.id()
|
||||
|
||||
es_action["_routing"] = task_id
|
||||
task_ids.add(task_id)
|
||||
if (
|
||||
iter is not None
|
||||
@ -342,14 +341,9 @@ class EventBLL(object):
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "scroll_task_events"):
|
||||
es_res = self.es.search(
|
||||
index=es_index, body=es_req, scroll="1h", routing=task_id
|
||||
)
|
||||
|
||||
events = [hit["_source"] for hit in es_res["hits"]["hits"]]
|
||||
next_scroll_id = es_res["_scroll_id"]
|
||||
total_events = es_res["hits"]["total"]
|
||||
es_res = self.es.search(index=es_index, body=es_req, scroll="1h")
|
||||
|
||||
events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
|
||||
return events, next_scroll_id, total_events
|
||||
|
||||
def get_last_iterations_per_event_metric_variant(
|
||||
@ -377,7 +371,7 @@ class EventBLL(object):
|
||||
"terms": {
|
||||
"field": "iter",
|
||||
"size": num_last_iterations,
|
||||
"order": {"_term": "desc"},
|
||||
"order": {"_key": "desc"},
|
||||
}
|
||||
}
|
||||
},
|
||||
@ -393,7 +387,7 @@ class EventBLL(object):
|
||||
with translate_errors_context(), TimingContext(
|
||||
"es", "task_last_iter_metric_variant"
|
||||
):
|
||||
es_res = self.es.search(index=es_index, body=es_req, routing=task_id)
|
||||
es_res = self.es.search(index=es_index, body=es_req)
|
||||
if "aggregations" not in es_res:
|
||||
return []
|
||||
|
||||
@ -422,13 +416,11 @@ class EventBLL(object):
|
||||
if not self.es.indices.exists(es_index):
|
||||
return TaskEventsResult()
|
||||
|
||||
query = {"bool": defaultdict(list)}
|
||||
|
||||
must = []
|
||||
if last_iterations_per_plot is None:
|
||||
must = query["bool"]["must"]
|
||||
must.append({"terms": {"task": tasks}})
|
||||
else:
|
||||
should = query["bool"]["should"]
|
||||
should = []
|
||||
for i, task_id in enumerate(tasks):
|
||||
last_iters = self.get_last_iterations_per_event_metric_variant(
|
||||
es_index, task_id, last_iterations_per_plot, event_type
|
||||
@ -451,32 +443,41 @@ class EventBLL(object):
|
||||
)
|
||||
if not should:
|
||||
return TaskEventsResult()
|
||||
must.append({"bool": {"should": should}})
|
||||
|
||||
if sort is None:
|
||||
sort = [{"timestamp": {"order": "asc"}}]
|
||||
|
||||
es_req = {"sort": sort, "size": min(size, 10000), "query": query}
|
||||
|
||||
routing = ",".join(tasks)
|
||||
es_req = {
|
||||
"sort": sort,
|
||||
"size": min(size, 10000),
|
||||
"query": {"bool": {"must": must}},
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "get_task_plots"):
|
||||
es_res = self.es.search(
|
||||
index=es_index,
|
||||
body=es_req,
|
||||
ignore=404,
|
||||
routing=routing,
|
||||
scroll="1h",
|
||||
index=es_index, body=es_req, ignore=404, scroll="1h",
|
||||
)
|
||||
|
||||
events = [doc["_source"] for doc in es_res.get("hits", {}).get("hits", [])]
|
||||
# scroll id may be missing when querying a totally empty DB
|
||||
next_scroll_id = es_res.get("_scroll_id")
|
||||
total_events = es_res["hits"]["total"]
|
||||
|
||||
events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
|
||||
return TaskEventsResult(
|
||||
events=events, next_scroll_id=next_scroll_id, total_events=total_events
|
||||
)
|
||||
|
||||
def _get_events_from_es_res(self, es_res: dict) -> Tuple[list, int, Optional[str]]:
|
||||
"""
|
||||
Return events and next scroll id from the scrolled query
|
||||
Release the scroll once it is exhausted
|
||||
"""
|
||||
total_events = safe_get(es_res, "hits/total/value", default=0)
|
||||
events = [doc["_source"] for doc in safe_get(es_res, "hits/hits", default=[])]
|
||||
next_scroll_id = es_res.get("_scroll_id")
|
||||
if next_scroll_id and not events:
|
||||
self.es.clear_scroll(scroll_id=next_scroll_id)
|
||||
next_scroll_id = None
|
||||
|
||||
return events, total_events, next_scroll_id
|
||||
|
||||
def get_task_events(
|
||||
self,
|
||||
company_id,
|
||||
@ -502,20 +503,16 @@ class EventBLL(object):
|
||||
if not self.es.indices.exists(es_index):
|
||||
return TaskEventsResult()
|
||||
|
||||
query = {"bool": defaultdict(list)}
|
||||
|
||||
if metric or variant:
|
||||
must = query["bool"]["must"]
|
||||
if metric:
|
||||
must.append({"term": {"metric": metric}})
|
||||
if variant:
|
||||
must.append({"term": {"variant": variant}})
|
||||
must = []
|
||||
if metric:
|
||||
must.append({"term": {"metric": metric}})
|
||||
if variant:
|
||||
must.append({"term": {"variant": variant}})
|
||||
|
||||
if last_iter_count is None:
|
||||
must = query["bool"]["must"]
|
||||
must.append({"terms": {"task": task_ids}})
|
||||
else:
|
||||
should = query["bool"]["should"]
|
||||
should = []
|
||||
for i, task_id in enumerate(task_ids):
|
||||
last_iters = self.get_last_iters(
|
||||
es_index, task_id, event_type, last_iter_count
|
||||
@ -534,27 +531,23 @@ class EventBLL(object):
|
||||
)
|
||||
if not should:
|
||||
return TaskEventsResult()
|
||||
must.append({"bool": {"should": should}})
|
||||
|
||||
if sort is None:
|
||||
sort = [{"timestamp": {"order": "asc"}}]
|
||||
|
||||
es_req = {"sort": sort, "size": min(size, 10000), "query": query}
|
||||
|
||||
routing = ",".join(task_ids)
|
||||
es_req = {
|
||||
"sort": sort,
|
||||
"size": min(size, 10000),
|
||||
"query": {"bool": {"must": must}},
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "get_task_events"):
|
||||
es_res = self.es.search(
|
||||
index=es_index,
|
||||
body=es_req,
|
||||
ignore=404,
|
||||
routing=routing,
|
||||
scroll="1h",
|
||||
index=es_index, body=es_req, ignore=404, scroll="1h",
|
||||
)
|
||||
|
||||
events = [doc["_source"] for doc in es_res.get("hits", {}).get("hits", [])]
|
||||
next_scroll_id = es_res.get("_scroll_id")
|
||||
total_events = es_res["hits"]["total"]
|
||||
|
||||
events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
|
||||
return TaskEventsResult(
|
||||
events=events, next_scroll_id=next_scroll_id, total_events=total_events
|
||||
)
|
||||
@ -590,7 +583,7 @@ class EventBLL(object):
|
||||
with translate_errors_context(), TimingContext(
|
||||
"es", "events_get_metrics_and_variants"
|
||||
):
|
||||
es_res = self.es.search(index=es_index, body=es_req, routing=task_id)
|
||||
es_res = self.es.search(index=es_index, body=es_req)
|
||||
|
||||
metrics = {}
|
||||
for metric_bucket in es_res["aggregations"]["metrics"].get("buckets"):
|
||||
@ -622,14 +615,14 @@ class EventBLL(object):
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": EventMetrics.MAX_METRICS_COUNT,
|
||||
"order": {"_term": "asc"},
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": EventMetrics.MAX_VARIANTS_COUNT,
|
||||
"order": {"_term": "asc"},
|
||||
"order": {"_key": "asc"},
|
||||
},
|
||||
"aggs": {
|
||||
"last_value": {
|
||||
@ -659,7 +652,7 @@ class EventBLL(object):
|
||||
with translate_errors_context(), TimingContext(
|
||||
"es", "events_get_metrics_and_variants"
|
||||
):
|
||||
es_res = self.es.search(index=es_index, body=es_req, routing=task_id)
|
||||
es_res = self.es.search(index=es_index, body=es_req)
|
||||
|
||||
metrics = []
|
||||
max_timestamp = 0
|
||||
@ -706,7 +699,7 @@ class EventBLL(object):
|
||||
"sort": ["iter"],
|
||||
}
|
||||
with translate_errors_context(), TimingContext("es", "task_stats_vector"):
|
||||
es_res = self.es.search(index=es_index, body=es_req, routing=task_id)
|
||||
es_res = self.es.search(index=es_index, body=es_req)
|
||||
|
||||
vectors = []
|
||||
iterations = []
|
||||
@ -727,7 +720,7 @@ class EventBLL(object):
|
||||
"terms": {
|
||||
"field": "iter",
|
||||
"size": iters,
|
||||
"order": {"_term": "desc"},
|
||||
"order": {"_key": "desc"},
|
||||
}
|
||||
}
|
||||
},
|
||||
@ -737,7 +730,7 @@ class EventBLL(object):
|
||||
es_req["query"]["bool"]["must"].append({"term": {"type": event_type}})
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "task_last_iter"):
|
||||
es_res = self.es.search(index=es_index, body=es_req, routing=task_id)
|
||||
es_res = self.es.search(index=es_index, body=es_req)
|
||||
if "aggregations" not in es_res:
|
||||
return []
|
||||
|
||||
@ -759,8 +752,6 @@ class EventBLL(object):
|
||||
es_index = EventMetrics.get_index_name(company_id, "*")
|
||||
es_req = {"query": {"term": {"task": task_id}}}
|
||||
with translate_errors_context(), TimingContext("es", "delete_task_events"):
|
||||
es_res = self.es.delete_by_query(
|
||||
index=es_index, body=es_req, routing=task_id, refresh=True
|
||||
)
|
||||
es_res = self.es.delete_by_query(index=es_index, body=es_req, refresh=True)
|
||||
|
||||
return es_res.get("deleted", 0)
|
||||
|
@ -1,12 +1,11 @@
|
||||
import itertools
|
||||
from collections import defaultdict
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from concurrent.futures.thread import ThreadPoolExecutor
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
from operator import itemgetter
|
||||
from typing import Sequence, Tuple, Callable, Iterable
|
||||
from typing import Sequence, Tuple
|
||||
|
||||
from boltons.iterutils import bucketize
|
||||
from elasticsearch import Elasticsearch
|
||||
from mongoengine import Q
|
||||
|
||||
@ -16,7 +15,7 @@ from config import config
|
||||
from database.errors import translate_errors_context
|
||||
from database.model.task.task import Task
|
||||
from timing_context import TimingContext
|
||||
from utilities import safe_get
|
||||
from tools import safe_get
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
@ -30,14 +29,18 @@ class EventType(Enum):
|
||||
|
||||
|
||||
class EventMetrics:
|
||||
MAX_TASKS_COUNT = 50
|
||||
MAX_METRICS_COUNT = 200
|
||||
MAX_VARIANTS_COUNT = 500
|
||||
MAX_METRICS_COUNT = 100
|
||||
MAX_VARIANTS_COUNT = 100
|
||||
MAX_AGGS_ELEMENTS_COUNT = 50
|
||||
MAX_SAMPLE_BUCKETS = 6000
|
||||
|
||||
def __init__(self, es: Elasticsearch):
|
||||
self.es = es
|
||||
|
||||
@property
|
||||
def _max_concurrency(self):
|
||||
return config.get("services.events.max_metrics_concurrency", 4)
|
||||
|
||||
@staticmethod
|
||||
def get_index_name(company_id, event_type):
|
||||
event_type = event_type.lower().replace(" ", "_")
|
||||
@ -51,15 +54,48 @@ class EventMetrics:
|
||||
The amount of points in each histogram should not exceed
|
||||
the requested samples
|
||||
"""
|
||||
es_index = self.get_index_name(company_id, "training_stats_scalar")
|
||||
if not self.es.indices.exists(es_index):
|
||||
return {}
|
||||
|
||||
return self._run_get_scalar_metrics_as_parallel(
|
||||
company_id,
|
||||
task_ids=[task_id],
|
||||
samples=samples,
|
||||
key=ScalarKey.resolve(key),
|
||||
get_func=self._get_scalar_average,
|
||||
return self._get_scalar_average_per_iter_core(
|
||||
task_id, es_index, samples, ScalarKey.resolve(key)
|
||||
)
|
||||
|
||||
def _get_scalar_average_per_iter_core(
|
||||
self,
|
||||
task_id: str,
|
||||
es_index: str,
|
||||
samples: int,
|
||||
key: ScalarKey,
|
||||
run_parallel: bool = True,
|
||||
) -> dict:
|
||||
intervals = self._get_task_metric_intervals(
|
||||
es_index=es_index, task_id=task_id, samples=samples, field=key.field
|
||||
)
|
||||
if not intervals:
|
||||
return {}
|
||||
interval_groups = self._group_task_metric_intervals(intervals)
|
||||
|
||||
get_scalar_average = partial(
|
||||
self._get_scalar_average, task_id=task_id, es_index=es_index, key=key
|
||||
)
|
||||
if run_parallel:
|
||||
with ThreadPoolExecutor(max_workers=self._max_concurrency) as pool:
|
||||
metrics = itertools.chain.from_iterable(
|
||||
pool.map(get_scalar_average, interval_groups)
|
||||
)
|
||||
else:
|
||||
metrics = itertools.chain.from_iterable(
|
||||
get_scalar_average(group) for group in interval_groups
|
||||
)
|
||||
|
||||
ret = defaultdict(dict)
|
||||
for metric_key, metric_values in metrics:
|
||||
ret[metric_key].update(metric_values)
|
||||
|
||||
return ret
|
||||
|
||||
def compare_scalar_metrics_average_per_iter(
|
||||
self,
|
||||
company_id,
|
||||
@ -72,12 +108,6 @@ class EventMetrics:
|
||||
Compare scalar metrics for different tasks per metric and variant
|
||||
The amount of points in each histogram should not exceed the requested samples
|
||||
"""
|
||||
if len(task_ids) > self.MAX_TASKS_COUNT:
|
||||
raise errors.BadRequest(
|
||||
f"Up to {self.MAX_TASKS_COUNT} tasks supported for comparison",
|
||||
len(task_ids),
|
||||
)
|
||||
|
||||
task_name_by_id = {}
|
||||
with translate_errors_context():
|
||||
task_objs = Task.get_many(
|
||||
@ -90,7 +120,6 @@ class EventMetrics:
|
||||
if len(task_objs) < len(task_ids):
|
||||
invalid = tuple(set(task_ids) - set(r.id for r in task_objs))
|
||||
raise errors.bad_request.InvalidTaskId(company=company_id, ids=invalid)
|
||||
|
||||
task_name_by_id = {t.id: t.name for t in task_objs}
|
||||
|
||||
companies = {t.company for t in task_objs}
|
||||
@ -99,138 +128,95 @@ class EventMetrics:
|
||||
"only tasks from the same company are supported"
|
||||
)
|
||||
|
||||
ret = self._run_get_scalar_metrics_as_parallel(
|
||||
next(iter(companies)),
|
||||
task_ids=task_ids,
|
||||
samples=samples,
|
||||
key=ScalarKey.resolve(key),
|
||||
get_func=self._get_scalar_average_per_task,
|
||||
)
|
||||
|
||||
for metric_data in ret.values():
|
||||
for variant_data in metric_data.values():
|
||||
for task_id, task_data in variant_data.items():
|
||||
task_data["name"] = task_name_by_id[task_id]
|
||||
|
||||
return ret
|
||||
|
||||
TaskMetric = Tuple[str, str, str]
|
||||
|
||||
MetricInterval = Tuple[int, Sequence[TaskMetric]]
|
||||
MetricData = Tuple[str, dict]
|
||||
|
||||
def _split_metrics_by_max_aggs_count(
|
||||
self, task_metrics: Sequence[TaskMetric]
|
||||
) -> Iterable[Sequence[TaskMetric]]:
|
||||
"""
|
||||
Return task metrics in groups where amount of task metrics in each group
|
||||
is roughly limited by MAX_AGGS_ELEMENTS_COUNT. The split is done on metrics and
|
||||
variants while always preserving all their tasks in the same group
|
||||
"""
|
||||
if len(task_metrics) < self.MAX_AGGS_ELEMENTS_COUNT:
|
||||
yield task_metrics
|
||||
return
|
||||
|
||||
tm_grouped = bucketize(task_metrics, key=itemgetter(1, 2))
|
||||
groups = []
|
||||
for group in tm_grouped.values():
|
||||
groups.append(group)
|
||||
if sum(map(len, groups)) >= self.MAX_AGGS_ELEMENTS_COUNT:
|
||||
yield list(itertools.chain(*groups))
|
||||
groups = []
|
||||
|
||||
if groups:
|
||||
yield list(itertools.chain(*groups))
|
||||
|
||||
return
|
||||
|
||||
def _run_get_scalar_metrics_as_parallel(
|
||||
self,
|
||||
company_id: str,
|
||||
task_ids: Sequence[str],
|
||||
samples: int,
|
||||
key: ScalarKey,
|
||||
get_func: Callable[
|
||||
[MetricInterval, Sequence[str], str, ScalarKey], Sequence[MetricData]
|
||||
],
|
||||
) -> dict:
|
||||
"""
|
||||
Group metrics per interval length and execute get_func for each group in parallel
|
||||
:param company_id: id of the company
|
||||
:params task_ids: ids of the tasks to collect data for
|
||||
:param samples: maximum number of samples per metric
|
||||
:param get_func: callable that given metric names for the same interval
|
||||
performs histogram aggregation for the metrics and return the aggregated data
|
||||
"""
|
||||
es_index = self.get_index_name(company_id, "training_stats_scalar")
|
||||
es_index = self.get_index_name(next(iter(companies)), "training_stats_scalar")
|
||||
if not self.es.indices.exists(es_index):
|
||||
return {}
|
||||
|
||||
intervals = self._get_metric_intervals(
|
||||
es_index=es_index, task_ids=task_ids, samples=samples, field=key.field
|
||||
get_scalar_average_per_iter = partial(
|
||||
self._get_scalar_average_per_iter_core,
|
||||
es_index=es_index,
|
||||
samples=samples,
|
||||
key=ScalarKey.resolve(key),
|
||||
run_parallel=False,
|
||||
)
|
||||
|
||||
if not intervals:
|
||||
return {}
|
||||
|
||||
intervals = list(
|
||||
itertools.chain.from_iterable(
|
||||
zip(itertools.repeat(i), self._split_metrics_by_max_aggs_count(tms))
|
||||
for i, tms in intervals
|
||||
)
|
||||
)
|
||||
max_concurrency = config.get("services.events.max_metrics_concurrency", 4)
|
||||
with ThreadPoolExecutor(max_workers=max_concurrency) as pool:
|
||||
metrics = itertools.chain.from_iterable(
|
||||
pool.map(
|
||||
partial(get_func, task_ids=task_ids, es_index=es_index, key=key),
|
||||
intervals,
|
||||
)
|
||||
with ThreadPoolExecutor(max_workers=self._max_concurrency) as pool:
|
||||
task_metrics = zip(
|
||||
task_ids, pool.map(get_scalar_average_per_iter, task_ids)
|
||||
)
|
||||
|
||||
ret = defaultdict(dict)
|
||||
for metric_key, metric_values in metrics:
|
||||
ret[metric_key].update(metric_values)
|
||||
res = defaultdict(lambda: defaultdict(dict))
|
||||
for task_id, task_data in task_metrics:
|
||||
task_name = task_name_by_id[task_id]
|
||||
for metric_key, metric_data in task_data.items():
|
||||
for variant_key, variant_data in metric_data.items():
|
||||
variant_data["name"] = task_name
|
||||
res[metric_key][variant_key][task_id] = variant_data
|
||||
|
||||
return ret
|
||||
return res
|
||||
|
||||
def _get_metric_intervals(
|
||||
self, es_index, task_ids: Sequence[str], samples: int, field: str = "iter"
|
||||
MetricInterval = Tuple[str, str, int, int]
|
||||
MetricIntervalGroup = Tuple[int, Sequence[Tuple[str, str]]]
|
||||
|
||||
@classmethod
|
||||
def _group_task_metric_intervals(
|
||||
cls, intervals: Sequence[MetricInterval]
|
||||
) -> Sequence[MetricIntervalGroup]:
|
||||
"""
|
||||
Group task metric intervals so that the following conditions are met:
|
||||
- All the metrics in the same group have the same interval (with 10% rounding)
|
||||
- The amount of metrics in the group does not exceed MAX_AGGS_ELEMENTS_COUNT
|
||||
- The total count of samples in the group does not exceed MAX_SAMPLE_BUCKETS
|
||||
"""
|
||||
metric_interval_groups = []
|
||||
interval_group = []
|
||||
group_interval_upper_bound = 0
|
||||
group_max_interval = 0
|
||||
group_samples = 0
|
||||
for metric, variant, interval, size in sorted(intervals, key=itemgetter(2)):
|
||||
if (
|
||||
interval > group_interval_upper_bound
|
||||
or (group_samples + size) > cls.MAX_SAMPLE_BUCKETS
|
||||
or len(interval_group) >= cls.MAX_AGGS_ELEMENTS_COUNT
|
||||
):
|
||||
if interval_group:
|
||||
metric_interval_groups.append((group_max_interval, interval_group))
|
||||
interval_group = []
|
||||
group_max_interval = interval
|
||||
group_interval_upper_bound = interval + int(interval * 0.1)
|
||||
group_samples = 0
|
||||
interval_group.append((metric, variant))
|
||||
group_samples += size
|
||||
group_max_interval = max(group_max_interval, interval)
|
||||
if interval_group:
|
||||
metric_interval_groups.append((group_max_interval, interval_group))
|
||||
|
||||
return metric_interval_groups
|
||||
|
||||
def _get_task_metric_intervals(
|
||||
self, es_index, task_id: str, samples: int, field: str = "iter"
|
||||
) -> Sequence[MetricInterval]:
|
||||
"""
|
||||
Calculate interval per task metric variant so that the resulting
|
||||
amount of points does not exceed samples.
|
||||
Return metric variants grouped by interval value with 10% rounding
|
||||
For samples==0 return empty list
|
||||
Return the list of metric variant intervals as the following tuple:
|
||||
(metric, variant, interval, samples)
|
||||
"""
|
||||
default_intervals = [(1, [])]
|
||||
if not samples:
|
||||
return default_intervals
|
||||
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"query": {"terms": {"task": task_ids}},
|
||||
"query": {"term": {"task": task_id}},
|
||||
"aggs": {
|
||||
"tasks": {
|
||||
"terms": {"field": "task", "size": self.MAX_TASKS_COUNT},
|
||||
"metrics": {
|
||||
"terms": {"field": "metric", "size": self.MAX_METRICS_COUNT},
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": self.MAX_METRICS_COUNT,
|
||||
"field": "variant",
|
||||
"size": self.MAX_VARIANTS_COUNT,
|
||||
},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": self.MAX_VARIANTS_COUNT,
|
||||
},
|
||||
"aggs": {
|
||||
"count": {"value_count": {"field": field}},
|
||||
"min_index": {"min": {"field": field}},
|
||||
"max_index": {"max": {"field": field}},
|
||||
},
|
||||
}
|
||||
"count": {"value_count": {"field": field}},
|
||||
"min_index": {"min": {"field": field}},
|
||||
"max_index": {"max": {"field": field}},
|
||||
},
|
||||
}
|
||||
},
|
||||
@ -239,88 +225,75 @@ class EventMetrics:
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "task_stats_get_interval"):
|
||||
es_res = self.es.search(
|
||||
index=es_index, body=es_req, routing=",".join(task_ids)
|
||||
)
|
||||
es_res = self.es.search(index=es_index, body=es_req)
|
||||
|
||||
aggs_result = es_res.get("aggregations")
|
||||
if not aggs_result:
|
||||
return default_intervals
|
||||
return []
|
||||
|
||||
intervals = [
|
||||
(
|
||||
task["key"],
|
||||
metric["key"],
|
||||
variant["key"],
|
||||
self._calculate_metric_interval(variant, samples),
|
||||
)
|
||||
for task in aggs_result["tasks"]["buckets"]
|
||||
for metric in task["metrics"]["buckets"]
|
||||
return [
|
||||
self._build_metric_interval(metric["key"], variant["key"], variant, samples)
|
||||
for metric in aggs_result["metrics"]["buckets"]
|
||||
for variant in metric["variants"]["buckets"]
|
||||
]
|
||||
|
||||
metric_intervals = []
|
||||
upper_border = 0
|
||||
interval_metrics = None
|
||||
for task, metric, variant, interval in sorted(intervals, key=itemgetter(3)):
|
||||
if not interval_metrics or interval > upper_border:
|
||||
interval_metrics = []
|
||||
metric_intervals.append((interval, interval_metrics))
|
||||
upper_border = interval + int(interval * 0.1)
|
||||
interval_metrics.append((task, metric, variant))
|
||||
|
||||
return metric_intervals
|
||||
|
||||
@staticmethod
|
||||
def _calculate_metric_interval(metric_variant: dict, samples: int) -> int:
|
||||
def _build_metric_interval(
|
||||
metric: str, variant: str, data: dict, samples: int
|
||||
) -> Tuple[str, str, int, int]:
|
||||
"""
|
||||
Calculate index interval per metric_variant variant so that the
|
||||
total amount of intervals does not exceed the samples
|
||||
Return the interval and resulting amount of intervals
|
||||
"""
|
||||
count = safe_get(metric_variant, "count/value")
|
||||
if not count or count < samples:
|
||||
return 1
|
||||
count = safe_get(data, "count/value", default=0)
|
||||
if count < samples:
|
||||
return metric, variant, 1, count
|
||||
|
||||
min_index = safe_get(metric_variant, "min_index/value", default=0)
|
||||
max_index = safe_get(metric_variant, "max_index/value", default=min_index)
|
||||
return max(1, int(max_index - min_index + 1) // samples)
|
||||
min_index = safe_get(data, "min_index/value", default=0)
|
||||
max_index = safe_get(data, "max_index/value", default=min_index)
|
||||
return (
|
||||
metric,
|
||||
variant,
|
||||
max(1, int(max_index - min_index + 1) // samples),
|
||||
samples,
|
||||
)
|
||||
|
||||
MetricData = Tuple[str, dict]
|
||||
|
||||
def _get_scalar_average(
|
||||
self,
|
||||
metrics_interval: MetricInterval,
|
||||
task_ids: Sequence[str],
|
||||
metrics_interval: MetricIntervalGroup,
|
||||
task_id: str,
|
||||
es_index: str,
|
||||
key: ScalarKey,
|
||||
) -> Sequence[MetricData]:
|
||||
"""
|
||||
Retrieve scalar histograms per several metric variants that share the same interval
|
||||
Note: the function works with a single task only
|
||||
"""
|
||||
|
||||
assert len(task_ids) == 1
|
||||
interval, task_metrics = metrics_interval
|
||||
interval, metrics = metrics_interval
|
||||
aggregation = self._add_aggregation_average(key.get_aggregation(interval))
|
||||
aggs = {
|
||||
"metrics": {
|
||||
"terms": {
|
||||
"field": "metric",
|
||||
"size": self.MAX_METRICS_COUNT,
|
||||
"order": {"_term": "desc"},
|
||||
"order": {"_key": "desc"},
|
||||
},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
"terms": {
|
||||
"field": "variant",
|
||||
"size": self.MAX_VARIANTS_COUNT,
|
||||
"order": {"_term": "desc"},
|
||||
"order": {"_key": "desc"},
|
||||
},
|
||||
"aggs": aggregation,
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
aggs_result = self._query_aggregation_for_metrics_and_tasks(
|
||||
es_index, aggs=aggs, task_ids=task_ids, task_metrics=task_metrics
|
||||
aggs_result = self._query_aggregation_for_task_metrics(
|
||||
es_index, aggs=aggs, task_id=task_id, metrics=metrics
|
||||
)
|
||||
|
||||
if not aggs_result:
|
||||
@ -341,61 +314,6 @@ class EventMetrics:
|
||||
]
|
||||
return metrics
|
||||
|
||||
def _get_scalar_average_per_task(
|
||||
self,
|
||||
metrics_interval: MetricInterval,
|
||||
task_ids: Sequence[str],
|
||||
es_index: str,
|
||||
key: ScalarKey,
|
||||
) -> Sequence[MetricData]:
|
||||
"""
|
||||
Retrieve scalar histograms per several metric variants that share the same interval
|
||||
"""
|
||||
interval, task_metrics = metrics_interval
|
||||
|
||||
aggregation = self._add_aggregation_average(key.get_aggregation(interval))
|
||||
aggs = {
|
||||
"metrics": {
|
||||
"terms": {"field": "metric", "size": self.MAX_METRICS_COUNT},
|
||||
"aggs": {
|
||||
"variants": {
|
||||
"terms": {"field": "variant", "size": self.MAX_VARIANTS_COUNT},
|
||||
"aggs": {
|
||||
"tasks": {
|
||||
"terms": {
|
||||
"field": "task",
|
||||
"size": self.MAX_TASKS_COUNT,
|
||||
},
|
||||
"aggs": aggregation,
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
aggs_result = self._query_aggregation_for_metrics_and_tasks(
|
||||
es_index, aggs=aggs, task_ids=task_ids, task_metrics=task_metrics
|
||||
)
|
||||
|
||||
if not aggs_result:
|
||||
return {}
|
||||
|
||||
metrics = [
|
||||
(
|
||||
metric["key"],
|
||||
{
|
||||
variant["key"]: {
|
||||
task["key"]: key.get_iterations_data(task)
|
||||
for task in variant["tasks"]["buckets"]
|
||||
}
|
||||
for variant in metric["variants"]["buckets"]
|
||||
},
|
||||
)
|
||||
for metric in aggs_result["metrics"]["buckets"]
|
||||
]
|
||||
return metrics
|
||||
|
||||
@staticmethod
|
||||
def _add_aggregation_average(aggregation):
|
||||
average_agg = {"avg_val": {"avg": {"field": "value"}}}
|
||||
@ -404,69 +322,55 @@ class EventMetrics:
|
||||
for key, value in aggregation.items()
|
||||
}
|
||||
|
||||
def _query_aggregation_for_metrics_and_tasks(
|
||||
def _query_aggregation_for_task_metrics(
|
||||
self,
|
||||
es_index: str,
|
||||
aggs: dict,
|
||||
task_ids: Sequence[str],
|
||||
task_metrics: Sequence[TaskMetric],
|
||||
task_id: str,
|
||||
metrics: Sequence[Tuple[str, str]],
|
||||
) -> dict:
|
||||
"""
|
||||
Return the result of elastic search query for the given aggregation filtered
|
||||
by the given task_ids and metrics
|
||||
"""
|
||||
if task_metrics:
|
||||
condition = {
|
||||
"should": [
|
||||
self._build_metric_terms(task, metric, variant)
|
||||
for task, metric, variant in task_metrics
|
||||
]
|
||||
}
|
||||
else:
|
||||
condition = {"must": [{"terms": {"task": task_ids}}]}
|
||||
must = [{"term": {"task": task_id}}]
|
||||
if metrics:
|
||||
should = [
|
||||
{
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"metric": metric}},
|
||||
{"term": {"variant": variant}},
|
||||
]
|
||||
}
|
||||
}
|
||||
for metric, variant in metrics
|
||||
]
|
||||
must.append({"bool": {"should": should}})
|
||||
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"_source": {"excludes": []},
|
||||
"query": {"bool": condition},
|
||||
"query": {"bool": {"must": must}},
|
||||
"aggs": aggs,
|
||||
"version": True,
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "task_stats_scalar"):
|
||||
es_res = self.es.search(
|
||||
index=es_index, body=es_req, routing=",".join(task_ids)
|
||||
)
|
||||
es_res = self.es.search(index=es_index, body=es_req)
|
||||
|
||||
return es_res.get("aggregations")
|
||||
|
||||
@staticmethod
|
||||
def _build_metric_terms(task: str, metric: str, variant: str) -> dict:
|
||||
"""
|
||||
Build query term for a metric + variant
|
||||
"""
|
||||
return {
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"task": task}},
|
||||
{"term": {"metric": metric}},
|
||||
{"term": {"variant": variant}},
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
def get_tasks_metrics(
|
||||
self, company_id, task_ids: Sequence, event_type: EventType
|
||||
) -> Sequence[Tuple]:
|
||||
) -> Sequence:
|
||||
"""
|
||||
For the requested tasks return all the metrics that
|
||||
reported events of the requested types
|
||||
"""
|
||||
es_index = EventMetrics.get_index_name(company_id, event_type.value)
|
||||
if not self.es.indices.exists(es_index):
|
||||
return [(tid, []) for tid in task_ids]
|
||||
return {}
|
||||
|
||||
max_concurrency = config.get("services.events.max_metrics_concurrency", 4)
|
||||
with ThreadPoolExecutor(max_concurrency) as pool:
|
||||
with ThreadPoolExecutor(self._max_concurrency) as pool:
|
||||
res = pool.map(
|
||||
partial(
|
||||
self._get_task_metrics, es_index=es_index, event_type=event_type,
|
||||
@ -494,7 +398,7 @@ class EventMetrics:
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "_get_task_metrics"):
|
||||
es_res = self.es.search(index=es_index, body=es_req, routing=task_id)
|
||||
es_res = self.es.search(index=es_index, body=es_req)
|
||||
|
||||
return [
|
||||
metric["key"]
|
||||
|
@ -71,9 +71,9 @@ class LogEventsIterator:
|
||||
es_req["search_after"] = [from_timestamp]
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "get_task_events"):
|
||||
es_result = self.es.search(index=es_index, body=es_req, routing=task_id)
|
||||
es_result = self.es.search(index=es_index, body=es_req)
|
||||
hits = es_result["hits"]["hits"]
|
||||
hits_total = es_result["hits"]["total"]
|
||||
hits_total = es_result["hits"]["total"]["value"]
|
||||
if not hits:
|
||||
return [], hits_total
|
||||
|
||||
@ -92,7 +92,7 @@ class LogEventsIterator:
|
||||
}
|
||||
},
|
||||
}
|
||||
es_result = self.es.search(index=es_index, body=es_req, routing=task_id)
|
||||
es_result = self.es.search(index=es_index, body=es_req)
|
||||
hits = es_result["hits"]["hits"]
|
||||
if not hits or len(hits) < 2:
|
||||
# if only one element is returned for the last timestamp
|
||||
|
@ -111,7 +111,7 @@ class TimestampKey(ScalarKey):
|
||||
self.name: {
|
||||
"date_histogram": {
|
||||
"field": "timestamp",
|
||||
"interval": f"{interval}ms",
|
||||
"fixed_interval": f"{interval}ms",
|
||||
"min_doc_count": 1,
|
||||
}
|
||||
}
|
||||
@ -150,7 +150,7 @@ class ISOTimeKey(ScalarKey):
|
||||
self.name: {
|
||||
"date_histogram": {
|
||||
"field": "timestamp",
|
||||
"interval": f"{interval}ms",
|
||||
"fixed_interval": f"{interval}ms",
|
||||
"min_doc_count": 1,
|
||||
"format": "strict_date_time",
|
||||
}
|
||||
|
@ -18,7 +18,6 @@ log = config.logger(__file__)
|
||||
|
||||
class QueueMetrics:
|
||||
class EsKeys:
|
||||
DOC_TYPE = "metrics"
|
||||
WAITING_TIME_FIELD = "average_waiting_time"
|
||||
QUEUE_LENGTH_FIELD = "queue_length"
|
||||
TIMESTAMP_FIELD = "timestamp"
|
||||
@ -66,7 +65,6 @@ class QueueMetrics:
|
||||
entries = [e for e in queue.entries if e.added]
|
||||
return dict(
|
||||
_index=es_index,
|
||||
_type=self.EsKeys.DOC_TYPE,
|
||||
_source={
|
||||
self.EsKeys.TIMESTAMP_FIELD: timestamp,
|
||||
self.EsKeys.QUEUE_FIELD: queue.id,
|
||||
@ -93,7 +91,6 @@ class QueueMetrics:
|
||||
def _search_company_metrics(self, company_id: str, es_req: dict) -> dict:
|
||||
return self.es.search(
|
||||
index=f"{self._queue_metrics_prefix_for_company(company_id)}*",
|
||||
doc_type=self.EsKeys.DOC_TYPE,
|
||||
body=es_req,
|
||||
)
|
||||
|
||||
@ -109,7 +106,7 @@ class QueueMetrics:
|
||||
"dates": {
|
||||
"date_histogram": {
|
||||
"field": cls.EsKeys.TIMESTAMP_FIELD,
|
||||
"interval": f"{interval}s",
|
||||
"fixed_interval": f"{interval}s",
|
||||
"min_doc_count": 1,
|
||||
},
|
||||
"aggs": {
|
||||
|
@ -237,7 +237,6 @@ class StatisticsReporter:
|
||||
def _run_worker_stats_query(cls, company_id, es_req) -> dict:
|
||||
return worker_bll.es_client.search(
|
||||
index=f"{WorkerStats.worker_stats_prefix_for_company(company_id)}*",
|
||||
doc_type="stat",
|
||||
body=es_req,
|
||||
)
|
||||
|
||||
|
@ -35,14 +35,21 @@ class SetFieldsResolver:
|
||||
SET_MODIFIERS = ("min", "max")
|
||||
|
||||
def __init__(self, set_fields: Dict[str, Any]):
|
||||
self.orig_fields = set_fields
|
||||
self.fields = {
|
||||
f: fname
|
||||
for f, modifier, dunder, fname in (
|
||||
(f,) + f.partition("__") for f in set_fields.keys()
|
||||
)
|
||||
if dunder and modifier in self.SET_MODIFIERS
|
||||
}
|
||||
self.orig_fields = {}
|
||||
self.fields = {}
|
||||
self.add_fields(**set_fields)
|
||||
|
||||
def add_fields(self, **set_fields: Any):
|
||||
self.orig_fields.update(set_fields)
|
||||
self.fields.update(
|
||||
{
|
||||
f: fname
|
||||
for f, modifier, dunder, fname in (
|
||||
(f,) + f.partition("__") for f in set_fields.keys()
|
||||
)
|
||||
if dunder and modifier in self.SET_MODIFIERS
|
||||
}
|
||||
)
|
||||
|
||||
def _get_updated_name(self, doc: AttributedDocument, name: str) -> str:
|
||||
if name in self.fields and doc.get_field_value(self.fields[name]) is None:
|
||||
|
@ -21,6 +21,7 @@ from config import config
|
||||
from database.errors import translate_errors_context
|
||||
from database.model.auth import User
|
||||
from database.model.company import Company
|
||||
from database.model.project import Project
|
||||
from database.model.queue import Queue
|
||||
from database.model.task.task import Task
|
||||
from redis_manager import redman
|
||||
@ -146,6 +147,7 @@ class WorkerBLL:
|
||||
|
||||
if not report.task:
|
||||
entry.task = None
|
||||
entry.project = None
|
||||
else:
|
||||
with translate_errors_context():
|
||||
query = dict(id=report.task, company=company_id)
|
||||
@ -160,6 +162,12 @@ class WorkerBLL:
|
||||
raise bad_request.InvalidTaskId(**query)
|
||||
entry.task = IdNameEntry(id=task.id, name=task.name)
|
||||
|
||||
entry.project = None
|
||||
if task.project:
|
||||
project = Project.objects(id=task.project).only("name").first()
|
||||
if project:
|
||||
entry.project = IdNameEntry(id=project.id, name=project.name)
|
||||
|
||||
entry.last_report_time = now
|
||||
except APIError:
|
||||
raise
|
||||
@ -369,7 +377,6 @@ class WorkerBLL:
|
||||
def make_doc(category, metric, variant, value) -> dict:
|
||||
return dict(
|
||||
_index=es_index,
|
||||
_type="stat",
|
||||
_source=dict(
|
||||
timestamp=timestamp,
|
||||
worker=worker,
|
||||
|
@ -25,7 +25,6 @@ class WorkerStats:
|
||||
def _search_company_stats(self, company_id: str, es_req: dict) -> dict:
|
||||
return self.es.search(
|
||||
index=f"{self.worker_stats_prefix_for_company(company_id)}*",
|
||||
doc_type="stat",
|
||||
body=es_req,
|
||||
)
|
||||
|
||||
@ -53,7 +52,7 @@ class WorkerStats:
|
||||
|
||||
res = self._search_company_stats(company_id, es_req)
|
||||
|
||||
if not res["hits"]["total"]:
|
||||
if not res["hits"]["total"]["value"]:
|
||||
raise bad_request.WorkerStatsNotFound(
|
||||
f"No statistic metrics found for the company {company_id} and workers {worker_ids}"
|
||||
)
|
||||
@ -87,7 +86,7 @@ class WorkerStats:
|
||||
"dates": {
|
||||
"date_histogram": {
|
||||
"field": "timestamp",
|
||||
"interval": f"{request.interval}s",
|
||||
"fixed_interval": f"{request.interval}s",
|
||||
"min_doc_count": 1,
|
||||
},
|
||||
"aggs": {
|
||||
@ -216,7 +215,7 @@ class WorkerStats:
|
||||
"dates": {
|
||||
"date_histogram": {
|
||||
"field": "timestamp",
|
||||
"interval": f"{interval}s",
|
||||
"fixed_interval": f"{interval}s",
|
||||
},
|
||||
"aggs": {"workers_count": {"cardinality": {"field": "worker"}}},
|
||||
}
|
||||
|
@ -30,7 +30,7 @@
|
||||
enabled: false
|
||||
zip_files: ["/path/to/export.zip"]
|
||||
fail_on_error: false
|
||||
artifacts_path: "/mnt/fileserver"
|
||||
# artifacts_path: "/mnt/fileserver"
|
||||
}
|
||||
|
||||
# time in seconds to take an exclusive lock to init es and mongodb
|
||||
|
@ -1,6 +1,6 @@
|
||||
elastic {
|
||||
events {
|
||||
hosts: [{host: "127.0.0.1", port: 9200}]
|
||||
hosts: [{host: "127.0.0.1", port: 9211}]
|
||||
args {
|
||||
timeout: 60
|
||||
dead_timeout: 10
|
||||
@ -11,7 +11,7 @@ elastic {
|
||||
}
|
||||
|
||||
workers {
|
||||
hosts: [{host:"127.0.0.1", port:9200}]
|
||||
hosts: [{host:"127.0.0.1", port:9211}]
|
||||
args {
|
||||
timeout: 60
|
||||
dead_timeout: 10
|
||||
|
server/config/default/services/auth.conf (new file, 16 lines)
@ -0,0 +1,16 @@
|
||||
fixed_users {
|
||||
guest {
|
||||
enabled: false
|
||||
|
||||
default_company: "025315a9321f49f8be07f5ac48fbcf92"
|
||||
|
||||
name: "Guest"
|
||||
username: "guest"
|
||||
password: "guest"
|
||||
|
||||
# Allow access only to the following endpoints when using user/pass credentials
|
||||
allow_endpoints: [
|
||||
"auth.login"
|
||||
]
|
||||
}
|
||||
}
|
server/config/default/services/projects.conf (new file, 8 lines)
@ -0,0 +1,8 @@
|
||||
# Order of featured projects, by name or ID
|
||||
featured_order: [
|
||||
# {id: "<project-id>"}
|
||||
# OR
|
||||
# {name: "<project-name>"}
|
||||
# OR
|
||||
# {name_regex: "<python-regex>"}
|
||||
]
|
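The new projects.conf above only documents the accepted entry formats (`id`, `name`, or `name_regex`). As an assumed, illustrative Python sketch (not code from this commit) of how a project could be matched against such entries, the fallback value mirrors the default of the `featured` field added to the Project model later in this diff:

```python
import re

def featured_index(project: dict, featured_order: list) -> int:
    """Return the position of the first featured_order entry matching the project."""
    for i, entry in enumerate(featured_order):
        if "id" in entry and entry["id"] == project.get("id"):
            return i
        if "name" in entry and entry["name"] == project.get("name"):
            return i
        if "name_regex" in entry and re.search(entry["name_regex"], project.get("name", "")):
            return i
    return 9999  # same default as the new `featured` IntField on Project


if __name__ == "__main__":
    order = [{"name": "Examples"}, {"name_regex": r"^Demo"}]
    print(featured_index({"id": "p1", "name": "Demo project"}, order))  # -> 1
```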
@ -41,3 +41,6 @@ def get_deployment_type() -> str:
|
||||
|
||||
def get_default_company():
|
||||
return config.get("apiserver.default_company")
|
||||
|
||||
|
||||
missed_es_upgrade = False
|
||||
|
@ -32,6 +32,8 @@ class Role(object):
|
||||
""" Company user """
|
||||
annotator = "annotator"
|
||||
""" Annotator with limited access"""
|
||||
guest = "guest"
|
||||
""" Guest user. Read Only."""
|
||||
|
||||
@classmethod
|
||||
def get_system_roles(cls) -> set:
|
||||
|
@ -1,7 +1,7 @@
|
||||
import re
|
||||
from collections import namedtuple
|
||||
from functools import reduce
|
||||
from typing import Collection, Sequence, Union, Optional
|
||||
from typing import Collection, Sequence, Union, Optional, Type
|
||||
|
||||
from boltons.iterutils import first, bucketize
|
||||
from dateutil.parser import parse as parse_datetime
|
||||
@ -9,6 +9,7 @@ from mongoengine import Q, Document, ListField, StringField
|
||||
from pymongo.command_cursor import CommandCursor
|
||||
|
||||
from apierrors import errors
|
||||
from apierrors.base import BaseError
|
||||
from config import config
|
||||
from database.errors import MakeGetAllQueryError
|
||||
from database.projection import project_dict, ProjectionHelper
|
||||
@ -483,6 +484,21 @@ class GetMixin(PropsMixin):
|
||||
query=_query, parameters=parameters, override_projection=override_projection
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_many_public(
|
||||
cls, query: Q = None, projection: Collection[str] = None,
|
||||
):
|
||||
"""
|
||||
Fetch all public documents matching a provided query.
|
||||
:param query: Optional query object (mongoengine.Q).
|
||||
:param projection: A list of projection fields.
|
||||
:return: A list of documents matching the query.
|
||||
"""
|
||||
q = get_company_or_none_constraint()
|
||||
_query = (q & query) if query else q
|
||||
|
||||
return cls._get_many_no_company(query=_query, override_projection=projection)
|
||||
|
||||
@classmethod
|
||||
def _get_many_no_company(
|
||||
cls: Union["GetMixin", Document],
|
||||
@ -728,6 +744,31 @@ class DbModelMixin(GetMixin, ProperDictMixin, UpdateMixin):
|
||||
)
|
||||
return cls.objects.aggregate(pipeline, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def set_public(
|
||||
cls: Type[Document],
|
||||
company_id: str,
|
||||
ids: Sequence[str],
|
||||
invalid_cls: Type[BaseError],
|
||||
enabled: bool = True,
|
||||
):
|
||||
if enabled:
|
||||
items = list(cls.objects(id__in=ids, company=company_id).only("id"))
|
||||
update = dict(set__company_origin=company_id, unset__company=1)
|
||||
else:
|
||||
items = list(
|
||||
cls.objects(
|
||||
id__in=ids, company__in=(None, ""), company_origin=company_id
|
||||
).only("id")
|
||||
)
|
||||
update = dict(set__company=company_id, unset__company_origin=1)
|
||||
|
||||
if len(items) < len(ids):
|
||||
missing = tuple(set(ids).difference(i.id for i in items))
|
||||
raise invalid_cls(ids=missing)
|
||||
|
||||
return {"updated": cls.objects(id__in=ids).update(**update)}
|
||||
|
||||
|
||||
def validate_id(cls, company, **kwargs):
|
||||
"""
|
||||
|
@ -72,3 +72,4 @@ class Model(DbModelMixin, Document):
|
||||
ui_cache = SafeDictField(
|
||||
default=dict, user_set_allowed=True, exclude_by_default=True
|
||||
)
|
||||
company_origin = StringField(exclude_by_default=True)
|
||||
|
@ -1,4 +1,4 @@
|
||||
from mongoengine import StringField, DateTimeField
|
||||
from mongoengine import StringField, DateTimeField, IntField
|
||||
|
||||
from database import Database, strict
|
||||
from database.fields import StrippedStringField, SafeSortedListField
|
||||
@ -40,3 +40,7 @@ class Project(AttributedDocument):
|
||||
system_tags = SafeSortedListField(StringField(required=True))
|
||||
default_output_destination = StrippedStringField()
|
||||
last_update = DateTimeField()
|
||||
featured = IntField(default=9999)
|
||||
logo_url = StringField()
|
||||
logo_blob = StringField(exclude_by_default=True)
|
||||
company_origin = StringField(exclude_by_default=True)
|
||||
|
@ -118,7 +118,7 @@ external_task_types = set(get_options(TaskType))
|
||||
class Task(AttributedDocument):
|
||||
_field_collation_overrides = {
|
||||
"execution.parameters.": {"locale": "en_US", "numericOrdering": True},
|
||||
"last_metrics.": {"locale": "en_US", "numericOrdering": True}
|
||||
"last_metrics.": {"locale": "en_US", "numericOrdering": True},
|
||||
}
|
||||
|
||||
meta = {
|
||||
@ -194,3 +194,5 @@ class Task(AttributedDocument):
|
||||
last_iteration = IntField(default=DEFAULT_LAST_ITERATION)
|
||||
last_metrics = SafeMapField(field=SafeMapField(EmbeddedDocumentField(MetricEvent)))
|
||||
metric_stats = SafeMapField(field=EmbeddedDocumentField(MetricEventStats))
|
||||
company_origin = StringField(exclude_by_default=True)
|
||||
duration = IntField() # task duration in seconds
|
||||
|
@ -4,9 +4,9 @@ Apply elasticsearch mappings to given hosts.
|
||||
"""
|
||||
import argparse
|
||||
import json
|
||||
import requests
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
from requests.adapters import HTTPAdapter
|
||||
from requests.packages.urllib3.util.retry import Retry
|
||||
|
||||
@ -14,21 +14,24 @@ HERE = Path(__file__).resolve().parent
|
||||
|
||||
session = requests.Session()
|
||||
adapter = HTTPAdapter(max_retries=Retry(5, backoff_factor=0.5))
|
||||
session.mount('http://', adapter)
|
||||
session.mount("http://", adapter)
|
||||
|
||||
|
||||
def get_template(host: str, template) -> dict:
|
||||
url = f"{host}/_template/{template}"
|
||||
res = session.get(url)
|
||||
return res.json()
|
||||
|
||||
|
||||
def apply_mappings_to_host(host: str):
|
||||
def _send_mapping(f):
|
||||
with f.open() as json_data:
|
||||
data = json.load(json_data)
|
||||
es_server = host
|
||||
url = f"{es_server}/_template/{f.stem}"
|
||||
url = f"{host}/_template/{f.stem}"
|
||||
|
||||
session.delete(url)
|
||||
r = session.post(
|
||||
url,
|
||||
headers={"Content-Type": "application/json"},
|
||||
data=json.dumps(data),
|
||||
url, headers={"Content-Type": "application/json"}, data=json.dumps(data)
|
||||
)
|
||||
return {"mapping": f.stem, "result": r.text}
|
||||
|
||||
@ -47,7 +50,8 @@ def parse_args():
|
||||
|
||||
|
||||
def main():
|
||||
for host in parse_args().hosts:
|
||||
args = parse_args()
|
||||
for host in args.hosts:
|
||||
print(">>>>> Applying mapping to " + host)
|
||||
res = apply_mappings_to_host(host)
|
||||
print(res)
|
||||
|
@ -1,7 +1,7 @@
|
||||
from furl import furl
|
||||
|
||||
from config import config
|
||||
from elastic.apply_mappings import apply_mappings_to_host
|
||||
from elastic.apply_mappings import apply_mappings_to_host, get_template
|
||||
from es_factory import get_cluster_config
|
||||
|
||||
log = config.logger(__file__)
|
||||
@ -15,13 +15,22 @@ class MissingElasticConfiguration(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def init_es_data():
|
||||
def _url_from_host_conf(conf: dict) -> str:
|
||||
return furl(scheme="http", host=conf["host"], port=conf["port"]).url
|
||||
|
||||
|
||||
def init_es_data() -> bool:
|
||||
"""Return True if the db was empty"""
|
||||
hosts_config = get_cluster_config("events").get("hosts")
|
||||
if not hosts_config:
|
||||
raise MissingElasticConfiguration("for cluster 'events'")
|
||||
|
||||
empty_db = not get_template(_url_from_host_conf(hosts_config[0]), "events*")
|
||||
|
||||
for conf in hosts_config:
|
||||
host = furl(scheme="http", host=conf["host"], port=conf["port"]).url
|
||||
host = _url_from_host_conf(conf)
|
||||
log.info(f"Applying mappings to host: {host}")
|
||||
res = apply_mappings_to_host(host)
|
||||
log.info(res)
|
||||
|
||||
return empty_db
|
||||
|
@ -1,26 +1,39 @@
|
||||
{
|
||||
"template": "events-*",
|
||||
"index_patterns": "events-*",
|
||||
"settings": {
|
||||
"number_of_shards": 1
|
||||
},
|
||||
"mappings": {
|
||||
"_default_": {
|
||||
"_source": {
|
||||
"enabled": true
|
||||
"_source": {
|
||||
"enabled": true
|
||||
},
|
||||
"properties": {
|
||||
"@timestamp": {
|
||||
"type": "date"
|
||||
},
|
||||
"_routing": {
|
||||
"required": true
|
||||
"task": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"properties": {
|
||||
"@timestamp": { "type": "date" },
|
||||
"task": { "type": "keyword" },
|
||||
"type": { "type": "keyword" },
|
||||
"worker": { "type": "keyword" },
|
||||
"timestamp": { "type": "date" },
|
||||
"iter": { "type": "long" },
|
||||
"metric": { "type": "keyword" },
|
||||
"variant": { "type": "keyword" },
|
||||
"value": { "type": "float" }
|
||||
"type": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"worker": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"timestamp": {
|
||||
"type": "date"
|
||||
},
|
||||
"iter": {
|
||||
"type": "long"
|
||||
},
|
||||
"metric": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"variant": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"value": {
|
||||
"type": "float"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,11 +1,14 @@
|
||||
{
|
||||
"template": "events-log-*",
|
||||
"order" : 1,
|
||||
"index_patterns": "events-log-*",
|
||||
"order": 1,
|
||||
"mappings": {
|
||||
"_default_": {
|
||||
"properties": {
|
||||
"msg": { "type":"text", "index": false },
|
||||
"level": { "type":"keyword" }
|
||||
"properties": {
|
||||
"msg": {
|
||||
"type": "text",
|
||||
"index": false
|
||||
},
|
||||
"level": {
|
||||
"type": "keyword"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,10 +1,11 @@
|
||||
{
|
||||
"template": "events-plot-*",
|
||||
"order" : 1,
|
||||
"index_patterns": "events-plot-*",
|
||||
"order": 1,
|
||||
"mappings": {
|
||||
"_default_": {
|
||||
"properties": {
|
||||
"plot_str": { "type":"text", "index": false }
|
||||
"properties": {
|
||||
"plot_str": {
|
||||
"type": "text",
|
||||
"index": false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,11 +1,13 @@
|
||||
{
|
||||
"template": "events-training_debug_image-*",
|
||||
"order" : 1,
|
||||
"index_patterns": "events-training_debug_image-*",
|
||||
"order": 1,
|
||||
"mappings": {
|
||||
"_default_": {
|
||||
"properties": {
|
||||
"key": { "type": "keyword" },
|
||||
"url": { "type": "keyword" }
|
||||
"properties": {
|
||||
"key": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"url": {
|
||||
"type": "keyword"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,26 +1,24 @@
|
||||
{
|
||||
"template": "queue_metrics_*",
|
||||
"index_patterns": "queue_metrics_*",
|
||||
"settings": {
|
||||
"number_of_shards": 1
|
||||
},
|
||||
"mappings": {
|
||||
"metrics": {
|
||||
"_source": {
|
||||
"enabled": true
|
||||
"_source": {
|
||||
"enabled": true
|
||||
},
|
||||
"properties": {
|
||||
"timestamp": {
|
||||
"type": "date"
|
||||
},
|
||||
"properties": {
|
||||
"timestamp": {
|
||||
"type": "date"
|
||||
},
|
||||
"queue": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"average_waiting_time": {
|
||||
"type": "float"
|
||||
},
|
||||
"queue_length": {
|
||||
"type": "integer"
|
||||
}
|
||||
"queue": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"average_waiting_time": {
|
||||
"type": "float"
|
||||
},
|
||||
"queue_length": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,22 +1,36 @@
|
||||
{
|
||||
"template": "worker_stats_*",
|
||||
"index_patterns": "worker_stats_*",
|
||||
"settings": {
|
||||
"number_of_shards": 1
|
||||
},
|
||||
"mappings": {
|
||||
"stat": {
|
||||
"_source": {
|
||||
"enabled": true
|
||||
"_source": {
|
||||
"enabled": true
|
||||
},
|
||||
"properties": {
|
||||
"timestamp": {
|
||||
"type": "date"
|
||||
},
|
||||
"properties": {
|
||||
"timestamp": { "type": "date" },
|
||||
"worker": { "type": "keyword" },
|
||||
"category": { "type": "keyword" },
|
||||
"metric": { "type": "keyword" },
|
||||
"variant": { "type": "keyword" },
|
||||
"value": { "type": "float" },
|
||||
"unit": { "type": "keyword" },
|
||||
"task": { "type": "keyword" }
|
||||
"worker": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"category": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"metric": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"variant": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"value": {
|
||||
"type": "float"
|
||||
},
|
||||
"unit": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"task": {
|
||||
"type": "keyword"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -24,14 +24,9 @@ def _pre_populate(company_id: str, zip_file: str):
else:
log.info(f"Pre-populating using {zip_file}")

user_id = _ensure_backend_user(
"__allegroai__", company_id, "Allegro.ai"
)

PrePopulate.import_from_zip(
zip_file,
company_id="",
user_id=user_id,
artifacts_path=config.get(
"apiserver.pre_populate.artifacts_path", None
),
@ -60,7 +55,7 @@ def init_mongo_data() -> bool:

_ensure_uuid()

company_id = _ensure_company(log)
company_id = _ensure_company(get_default_company(), "trains", log)

_ensure_default_queue(company_id)

@ -82,9 +77,13 @@ def init_mongo_data() -> bool:
if fixed_mode:
log.info("Fixed users mode is enabled")
FixedUser.validate()

if FixedUser.guest_enabled():
_ensure_company(FixedUser.get_guest_user().company, "guests", log)

for user in FixedUser.from_config():
try:
ensure_fixed_user(user, company_id, log=log)
ensure_fixed_user(user, log=log)
except Exception as ex:
log.error(f"Failed creating fixed user {user.name}: {ex}")

@ -1,30 +1,44 @@
import hashlib
import importlib
import os
import re
from collections import defaultdict
from datetime import datetime, timezone
from functools import partial
from io import BytesIO
from itertools import chain
from operator import attrgetter
from os.path import splitext
from pathlib import Path
from typing import Optional, Any, Type, Set, Dict, Sequence, Tuple, BinaryIO, Union
from typing import (
Optional,
Any,
Type,
Set,
Dict,
Sequence,
Tuple,
BinaryIO,
Union,
Mapping,
)
from urllib.parse import unquote, urlparse
from zipfile import ZipFile, ZIP_BZIP2

import attr
import mongoengine
from boltons.iterutils import chunked_iter
from furl import furl
from mongoengine import Q

from bll.event import EventBLL
from config import config
from database.model import EntityVisibility
from database.model.model import Model
from database.model.project import Project
from database.model.task.task import Task, ArtifactModes, TaskStatus
from database.utils import get_options
from utilities import json
from .user import _ensure_backend_user


class PrePopulate:
@ -32,6 +46,7 @@ class PrePopulate:
events_file_suffix = "_events"
export_tag_prefix = "Exported:"
export_tag = f"{export_tag_prefix} %Y-%m-%d %H:%M:%S"
metadata_filename = "metadata.json"

class JsonLinesWriter:
def __init__(self, file: BinaryIO):
@ -54,26 +69,21 @@ class PrePopulate:
self._write("\n" + line)
self.empty = False

@attr.s(auto_attribs=True)
class _MapData:
files: Sequence[str] = None
entities: Dict[str, datetime] = None

@staticmethod
def _get_last_update_time(entity) -> datetime:
return getattr(entity, "last_update", None) or getattr(entity, "created")

@classmethod
def _check_for_update(
cls, map_file: Path, entities: dict
cls, map_file: Path, entities: dict, metadata_hash: str
) -> Tuple[bool, Sequence[str]]:
if not map_file.is_file():
return True, []

files = []
try:
map_data = cls._MapData(**json.loads(map_file.read_text()))
files = map_data.files
map_data = json.loads(map_file.read_text())
files = map_data.get("files", [])
for file in files:
if not Path(file).is_file():
return True, files
@ -82,7 +92,7 @@ class PrePopulate:
item.id: cls._get_last_update_time(item).replace(tzinfo=timezone.utc)
for item in chain.from_iterable(entities.values())
}
old_times = map_data.entities
old_times = map_data.get("entities", {})

if set(new_times.keys()) != set(old_times.keys()):
return True, files
@ -90,6 +100,10 @@ class PrePopulate:
for id_, new_timestamp in new_times.items():
if new_timestamp != old_times[id_]:
return True, files

if metadata_hash != map_data.get("metadata_hash", ""):
return True, files

except Exception as ex:
print("Error reading map file. " + str(ex))
return True, files
@ -98,16 +112,24 @@ class PrePopulate:

@classmethod
def _write_update_file(
cls, map_file: Path, entities: dict, created_files: Sequence[str]
cls,
map_file: Path,
entities: dict,
created_files: Sequence[str],
metadata_hash: str,
):
map_data = cls._MapData(
files=created_files,
entities={
entity.id: cls._get_last_update_time(entity)
for entity in chain.from_iterable(entities.values())
},
map_file.write_text(
json.dumps(
dict(
files=created_files,
entities={
entity.id: cls._get_last_update_time(entity)
for entity in chain.from_iterable(entities.values())
},
metadata_hash=metadata_hash,
)
)
)
map_file.write_text(json.dumps(attr.asdict(map_data)))

@staticmethod
def _filter_artifacts(artifacts: Sequence[str]) -> Sequence[str]:
@ -117,7 +139,9 @@ class PrePopulate:
return True
if a.startswith("http"):
parsed = urlparse(a)
if parsed.scheme in {"http", "https"} and parsed.port == 8081:
if parsed.scheme in {"http", "https"} and parsed.netloc.endswith(
"8081"
):
return True
return False

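With the change above, the ".map" sidecar that _write_update_file produces and _check_for_update reads back is a plain JSON document rather than a serialized _MapData instance, and it now carries the metadata hash. A rough illustration of its shape, with invented values:

# Hedged illustration of the ".map" sidecar contents; all values are invented.
example_map = dict(
    files=["/opt/exports/pre_populate_1a2b3c.zip"],               # previously created export files
    entities={"4f1c9d4e8e1a4b5f": "2020-05-01T12:34:56+00:00"},   # entity id -> last update time
    metadata_hash="d41d8cd98f00b204e9800998ecf8427e",             # md5 of the export metadata
)
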
@ -137,6 +161,7 @@ class PrePopulate:
artifacts_path: str = None,
task_statuses: Sequence[str] = None,
tag_exported_entities: bool = False,
metadata: Mapping[str, Any] = None,
) -> Sequence[str]:
if task_statuses and not set(task_statuses).issubset(get_options(TaskStatus)):
raise ValueError("Invalid task statuses")
@ -146,11 +171,22 @@ class PrePopulate:
experiments=experiments, projects=projects, task_statuses=task_statuses
)

hash_ = hashlib.md5()
if metadata:
meta_str = json.dumps(metadata)
hash_.update(meta_str.encode())
metadata_hash = hash_.hexdigest()
else:
meta_str, metadata_hash = "", ""

map_file = file.with_suffix(".map")
updated, old_files = cls._check_for_update(map_file, entities)
updated, old_files = cls._check_for_update(
map_file, entities=entities, metadata_hash=metadata_hash
)
if not updated:
print(f"There are no updates from the last export")
return old_files

for old in old_files:
old_path = Path(old)
if old_path.is_file():
@ -158,10 +194,16 @@ class PrePopulate:

zip_args = dict(mode="w", compression=ZIP_BZIP2)
with ZipFile(file, **zip_args) as zfile:
artifacts, hash_ = cls._export(
zfile, entities, tag_entities=tag_exported_entities
if metadata:
zfile.writestr(cls.metadata_filename, meta_str)
artifacts = cls._export(
zfile,
entities=entities,
hash_=hash_,
tag_entities=tag_exported_entities,
)
file_with_hash = file.with_name(f"{file.stem}_{hash_}{file.suffix}")

file_with_hash = file.with_name(f"{file.stem}_{hash_.hexdigest()}{file.suffix}")
file.replace(file_with_hash)
created_files = [str(file_with_hash)]

@ -172,16 +214,43 @@ class PrePopulate:
cls._export_artifacts(zfile, artifacts, artifacts_path)
created_files.append(str(artifacts_file))

cls._write_update_file(map_file, entities, created_files)
cls._write_update_file(
map_file,
entities=entities,
created_files=created_files,
metadata_hash=metadata_hash,
)

return created_files

@classmethod
def import_from_zip(
cls, filename: str, company_id: str, user_id: str, artifacts_path: str
cls,
filename: str,
company_id: str,
artifacts_path: str,
user_id: str = "",
user_name: str = "",
):
metadata = None

with ZipFile(filename) as zfile:
cls._import(zfile, company_id, user_id)
try:
with zfile.open(cls.metadata_filename) as f:
metadata = json.loads(f.read())
if not user_id:
meta_user_id = metadata.get("user_id", "")
meta_user_name = metadata.get("user_name", "")
user_id, user_name = meta_user_id, meta_user_name
except Exception:
pass

if not user_id:
user_id, user_name = "__allegroai__", "Allegro.ai"

user_id = _ensure_backend_user(user_id, company_id, user_name)

cls._import(zfile, company_id, user_id, metadata)

if artifacts_path and os.path.isdir(artifacts_path):
artifacts_file = Path(filename).with_suffix(".artifacts")
@ -190,6 +259,24 @@ class PrePopulate:
with ZipFile(artifacts_file) as zfile:
zfile.extractall(artifacts_path)

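import_from_zip now looks for a metadata.json entry inside the archive and, when no explicit user is given, takes the importing user (and later the project overrides) from it. A hedged example of such a metadata payload; only the keys appear in the code, the values are invented:

# Hedged example of a metadata.json payload consumed by import_from_zip/_import_entity.
example_metadata = {
    "user_id": "__allegroai__",
    "user_name": "Allegro.ai",
    "project_name": "Imported examples",          # overrides the imported project name
    "logo_url": "https://example.com/logo.png",   # applied to imported Project documents
    "logo_blob": None,
}
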
@classmethod
def update_featured_projects_order(cls):
featured_order = config.get("services.projects.featured_order", [])

def get_index(p: Project):
for index, entry in enumerate(featured_order):
if (
entry.get("id", None) == p.id
or entry.get("name", None) == p.name
or ("name_regex" in entry and re.match(entry["name_regex"], p.name))
):
return index
return 999

for project in Project.get_many_public(projection=["id", "name"]):
featured_index = get_index(project)
Project.objects(id=project.id).update(featured=featured_index)

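update_featured_projects_order ranks public projects by the services.projects.featured_order setting: an entry can match a project by id, by exact name, or by a name_regex, and projects that match nothing fall back to index 999. A hedged example of what such a configuration value could contain (the entries are invented):

# Hedged example of a "services.projects.featured_order" value as consumed by get_index().
featured_order_example = [
    {"id": "a1b2c3d4e5f6"},              # pin one specific project first
    {"name": "Getting Started"},         # then an exact name match
    {"name_regex": r"^examples?(/|$)"},  # then anything that looks like an examples project
]
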
@staticmethod
def _resolve_type(
cls: Type[mongoengine.Document], ids: Optional[Sequence[str]]
@ -389,15 +476,14 @@ class PrePopulate:

@classmethod
def _export(
cls, writer: ZipFile, entities: dict, tag_entities: bool = False
) -> Tuple[Sequence[str], str]:
cls, writer: ZipFile, entities: dict, hash_, tag_entities: bool = False
) -> Sequence[str]:
"""
Export the requested experiments, projects and models and return the list of artifact files
Always do the export on sorted items since the order of items influence hash
"""
artifacts = []
now = datetime.utcnow()
hash_ = hashlib.md5()
for cls_ in sorted(entities, key=attrgetter("__name__")):
items = sorted(entities[cls_], key=attrgetter("id"))
if not items:
@ -423,7 +509,7 @@ class PrePopulate:
if tag_entities:
cls._add_tag(items, now.strftime(cls.export_tag))

return artifacts, hash_.hexdigest()
return artifacts

@staticmethod
def json_lines(file: BinaryIO):
@ -441,7 +527,13 @@ class PrePopulate:
yield clean

@classmethod
def _import(cls, reader: ZipFile, company_id: str = "", user_id: str = None):
def _import(
cls,
reader: ZipFile,
company_id: str = "",
user_id: str = None,
metadata: Mapping[str, Any] = None,
):
"""
Import entities and events from the zip file
Start from entities since event import will require the tasks already in DB
@ -451,12 +543,13 @@ class PrePopulate:
fi
for fi in reader.filelist
if not fi.orig_filename.endswith(event_file_ending)
and fi.orig_filename != cls.metadata_filename
)
event_files = (
fi for fi in reader.filelist if fi.orig_filename.endswith(event_file_ending)
)
for files, reader_func in (
(entity_files, cls._import_entity),
(entity_files, partial(cls._import_entity, metadata=metadata or {})),
(event_files, cls._import_events),
):
for file_info in files:
@ -466,11 +559,20 @@ class PrePopulate:
reader_func(f, full_name, company_id, user_id)

@classmethod
def _import_entity(cls, f: BinaryIO, full_name: str, company_id: str, user_id: str):
def _import_entity(
cls,
f: BinaryIO,
full_name: str,
company_id: str,
user_id: str,
metadata: Mapping[str, Any],
):
module_name, _, class_name = full_name.rpartition(".")
module = importlib.import_module(module_name)
cls_: Type[mongoengine.Document] = getattr(module, class_name)
print(f"Writing {cls_.__name__.lower()}s into database")

override_project_count = 0
for item in cls.json_lines(f):
doc = cls_.from_json(item, created=True)
if hasattr(doc, "user"):
@ -478,10 +580,24 @@ class PrePopulate:
if hasattr(doc, "company"):
doc.company = company_id
if isinstance(doc, Project):
override_project_name = metadata.get("project_name", None)
if override_project_name:
if override_project_count:
override_project_name = (
f"{override_project_name} {override_project_count + 1}"
)
override_project_count += 1
doc.name = override_project_name

doc.logo_url = metadata.get("logo_url", None)
doc.logo_blob = metadata.get("logo_blob", None)

cls_.objects(company=company_id, name=doc.name, id__ne=doc.id).update(
set__name=f"{doc.name}_{datetime.utcnow().strftime('%Y-%m-%d_%H-%M-%S')}"
)

doc.save()

if isinstance(doc, Task):
cls.event_bll.delete_task_events(company_id, doc.id, allow_locked=True)

@ -58,15 +58,15 @@ def _ensure_backend_user(user_id: str, company_id: str, user_name: str):
return user_id


def ensure_fixed_user(user: FixedUser, company_id: str, log: Logger):
if User.objects(id=user.user_id).first():
def ensure_fixed_user(user: FixedUser, log: Logger):
if User.objects(company=user.company, id=user.user_id).first():
return

data = attr.asdict(user)
data["id"] = user.user_id
data["email"] = f"{user.user_id}@example.com"
data["role"] = Role.user
data["role"] = Role.guest if user.is_guest else Role.user

_ensure_auth_user(user_data=data, company_id=company_id, log=log)
_ensure_auth_user(user_data=data, company_id=user.company, log=log)

return _ensure_backend_user(user.user_id, company_id, user.name)
return _ensure_backend_user(user.user_id, user.company, user.name)

@ -3,7 +3,6 @@ from uuid import uuid4

from bll.queue import QueueBLL
from config import config
from config.info import get_default_company
from database.model.company import Company
from database.model.queue import Queue
from database.model.settings import Settings, SettingKeys
@ -11,13 +10,11 @@ from database.model.settings import Settings, SettingKeys
log = config.logger(__file__)


def _ensure_company(log: Logger):
company_id = get_default_company()
def _ensure_company(company_id, company_name, log: Logger):
company = Company.objects(id=company_id).only("id").first()
if company:
return company_id

company_name = "trains"
log.info(f"Creating company: {company_name}")
company = Company(id=company_id, name=company_name)
company.save()

@ -1,7 +1,8 @@
|
attrs>=19.1.0
boltons>=19.1.0
boto3==1.14.13
dpath>=1.4.2,<2.0
elasticsearch>=5.0.0,<6.0.0
elasticsearch>=7.0.0,<8.0.0
fastjsonschema>=2.8
Flask-Compress>=1.4.0
Flask-Cors>=3.0.5
@ -24,7 +25,7 @@ python-rapidjson>=0.6.3
redis>=2.10.5
related>=0.7.2
requests>=2.13.0
semantic_version>=2.8.0,<3
semantic_version>=2.8.3,<3
six
tqdm
validators>=0.12.4

@ -328,6 +328,9 @@ fixed_users_mode {
description: "Fixed users mode enabled"
type: boolean
}
migration_warning {
type: boolean
}
}
}
}

@ -848,7 +848,7 @@
description: "Task ID"
}
samples {
description: "The amount of histogram points to return (0 to return all the points). Optional, the default value is 10000."
description: "The amount of histogram points to return (0 to return all the points). Optional, the default value is 6000."
type: integer
}
key {
@ -886,7 +886,7 @@
]
properties {
tasks {
description: "List of task Task IDs"
description: "List of task Task IDs. Maximum amount of tasks is 10"
type: array
items {
type: string
@ -894,7 +894,7 @@
}
}
samples {
description: "The amount of histogram points to return (0 to return all the points). Optional, the default value is 10000."
description: "The amount of histogram points to return. Optional, the default value is 6000"
type: integer
}
key {

File diff suppressed because it is too large
@ -573,6 +573,7 @@ get_hyper_parameters {
}
}
}

get_task_tags {
"2.8" {
description: "Get user and system tags used for the tasks under the specified projects"
@ -580,10 +581,61 @@ get_task_tags {
response = ${_definitions.tags_response}
}
}

get_model_tags {
"2.8" {
description: "Get user and system tags used for the models under the specified projects"
request = ${_definitions.tags_request}
response = ${_definitions.tags_response}
}
}

make_public {
"2.9" {
description: """Convert company projects to public"""
request {
type: object
properties {
ids {
description: "Ids of the projects to convert"
type: array
items { type: string}
}
}
}
response {
type: object
properties {
updated {
description: "Number of projects updated"
type: integer
}
}
}
}
}

make_private {
"2.9" {
description: """Convert public projects to private"""
request {
type: object
properties {
ids {
description: "Ids of the projects to convert. Only the projects originated by the company can be converted"
type: array
items { type: string}
}
}
}
response {
type: object
properties {
updated {
description: "Number of projects updated"
type: integer
}
}
}
}
}

|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
make_public {
|
||||
"2.9" {
|
||||
description: """Convert company tasks to public"""
|
||||
request {
|
||||
type: object
|
||||
properties {
|
||||
ids {
|
||||
description: "Ids of the tasks to convert"
|
||||
type: array
|
||||
items { type: string}
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
updated {
|
||||
description: "Number of tasks updated"
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
make_private {
|
||||
"2.9" {
|
||||
description: """Convert public tasks to private"""
|
||||
request {
|
||||
type: object
|
||||
properties {
|
||||
ids {
|
||||
description: "Ids of the tasks to convert. Only the tasks originated by the company can be converted"
|
||||
type: array
|
||||
items { type: string}
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
updated {
|
||||
description: "Number of tasks updated"
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -135,6 +135,10 @@
description: "Task currently being run by the worker"
"$ref": "#/definitions/current_task_entry"
}
project {
description: "Project in which currently executing task resides"
"$ref": "#/definitions/id_name_entry"
}
queue {
description: "Queue from which running task was taken"
"$ref": "#/definitions/queue_entry"
@ -151,11 +155,11 @@
type: object
properties {
id {
description: "Worker ID"
description: "ID"
type: string
}
name {
description: "Worker name"
description: "Name"
type: string
}
}

@ -10,7 +10,7 @@ from werkzeug.exceptions import BadRequest
import database
from apierrors.base import BaseError
from bll.statistics.stats_reporter import StatisticsReporter
from config import config
from config import config, info
from elastic.initialize import init_es_data
from mongo.initialize import init_mongo_data, pre_populate_data
from service_repo import ServiceRepo, APICall
@ -39,9 +39,11 @@ database.initialize()
hosts_string = ";".join(sorted(database.get_hosts()))
key = "db_init_" + md5(hosts_string.encode()).hexdigest()
with distributed_lock(key, timeout=config.get("apiserver.db_init_timout", 30)):
print(key)
init_es_data()
empty_es = init_es_data()
empty_db = init_mongo_data()
if empty_es and not empty_db:
log.info(f"ES database seems not migrated")
info.missed_es_upgrade = True
if empty_db and config.get("apiserver.pre_populate.enabled", False):
pre_populate_data()

@ -69,6 +69,10 @@ def authorize_credentials(auth_data, service, action, call_data_items):
if fixed_user:
if secret_key != fixed_user.password:
raise errors.unauthorized.InvalidCredentials('bad username or password')

if fixed_user.is_guest and not FixedUser.is_guest_endpoint(service, action):
raise errors.unauthorized.InvalidCredentials('endpoint not allowed for guest')

query = Q(id=fixed_user.user_id)

with TimingContext("mongo", "user_by_cred"), translate_errors_context('authorizing request'):

@ -1,14 +1,12 @@
import hashlib
from functools import lru_cache
from typing import Sequence, TypeVar
from typing import Sequence, Optional

import attr

from config import config
from config.info import get_default_company

T = TypeVar("T", bound="FixedUser")


class FixedUsersError(Exception):
pass
@ -21,6 +19,8 @@ class FixedUser:
name: str
company: str = get_default_company()

is_guest: bool = False

def __attrs_post_init__(self):
self.user_id = hashlib.md5(f"{self.company}:{self.username}".encode()).hexdigest()

@ -28,6 +28,10 @@ class FixedUser:
def enabled(cls):
return config.get("apiserver.auth.fixed_users.enabled", False)

@classmethod
def guest_enabled(cls):
return cls.enabled() and config.get("services.auth.fixed_users.guest.enabled", False)

@classmethod
def validate(cls):
if not cls.enabled():
@ -39,18 +43,50 @@ class FixedUser:
)

@classmethod
@lru_cache()
def from_config(cls) -> Sequence[T]:
return [
# @lru_cache()
def from_config(cls) -> Sequence["FixedUser"]:
users = [
cls(**user) for user in config.get("apiserver.auth.fixed_users.users", [])
]

if cls.guest_enabled():
users.insert(
0,
cls.get_guest_user()
)

return users

@classmethod
@lru_cache()
def get_by_username(cls, username) -> T:
def get_by_username(cls, username) -> "FixedUser":
return next(
(user for user in cls.from_config() if user.username == username), None
)

@classmethod
@lru_cache()
def is_guest_endpoint(cls, service, action):
"""
Validate a potential guest user,
This method will verify the user is indeed the guest user,
and that the guest user may access the service/action using its username/password
"""
return any(
ep == ".".join((service, action))
for ep in config.get("services.auth.fixed_users.guest.allow_endpoints", [])
)

@classmethod
def get_guest_user(cls) -> Optional["FixedUser"]:
if cls.guest_enabled():
return cls(
is_guest=True,
username=config.get("services.auth.fixed_users.guest.username"),
password=config.get("services.auth.fixed_users.guest.password"),
name=config.get("services.auth.fixed_users.guest.name"),
company=config.get("services.auth.fixed_users.guest.default_company"),
)

def __hash__(self):
return hash(self.user_id)

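guest_enabled(), get_guest_user() and is_guest_endpoint() read everything they need from the services.auth.fixed_users.guest subtree. A hedged example of the keys the class looks up; the values and the endpoint names are invented for illustration:

# Hedged example of the configuration keys read by the guest-user support above.
guest_config_example = {
    "services.auth.fixed_users.guest.enabled": True,
    "services.auth.fixed_users.guest.username": "guest",
    "services.auth.fixed_users.guest.password": "guest",
    "services.auth.fixed_users.guest.name": "Guest",
    "services.auth.fixed_users.guest.default_company": "00000000-0000-0000-0000-000000000000",
    "services.auth.fixed_users.guest.allow_endpoints": ["projects.get_all_ex", "tasks.get_all_ex"],
}
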
@ -16,7 +16,7 @@ from apimodels.auth import (
)
from apimodels.base import UpdateResponse
from bll.auth import AuthBLL
from config import config
from config import config, info
from database.errors import translate_errors_context
from database.model.auth import User
from service_repo import APICall, endpoint
@ -176,4 +176,17 @@ def update(call, company_id, _):

@endpoint("auth.fixed_users_mode")
def fixed_users_mode(call: APICall, *_, **__):
call.result.data = dict(enabled=FixedUser.enabled())
data = {
"enabled": FixedUser.enabled(),
"migration_warning": info.missed_es_upgrade,
"guest": {
"enabled": FixedUser.guest_enabled(),
}
}
guest_user = FixedUser.get_guest_user()
if guest_user:
data["guest"]["name"] = guest_user.name
data["guest"]["username"] = guest_user.username
data["guest"]["password"] = guest_user.password

call.result.data = data

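With this change auth.fixed_users_mode reports, besides the enabled flag, the ES migration warning and the guest account details. A hedged example of the returned data (values invented):

# Hedged example of the payload assembled by the updated fixed_users_mode endpoint.
example_fixed_users_mode = {
    "enabled": True,
    "migration_warning": False,  # info.missed_es_upgrade
    "guest": {
        "enabled": True,
        "name": "Guest",
        "username": "guest",
        "password": "guest",
    },
}
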
@ -5,7 +5,8 @@ from mongoengine import Q, EmbeddedDocument

import database
from apierrors import errors
from apimodels.base import UpdateResponse
from apierrors.errors.bad_request import InvalidModelId
from apimodels.base import UpdateResponse, MakePublicRequest
from apimodels.models import (
CreateModelRequest,
CreateModelResponse,
@ -467,3 +468,21 @@ def update(call: APICall, company_id, _):
if del_count:
_reset_cached_tags(company_id, projects=[model.project])
call.result.data = dict(deleted=del_count > 0)


@endpoint("models.make_public", min_version="2.9", request_data_model=MakePublicRequest)
def make_public(call: APICall, company_id, request: MakePublicRequest):
with translate_errors_context():
call.result.data = Model.set_public(
company_id, ids=request.ids, invalid_cls=InvalidModelId, enabled=True
)


@endpoint(
"models.make_private", min_version="2.9", request_data_model=MakePublicRequest
)
def make_public(call: APICall, company_id, request: MakePublicRequest):
with translate_errors_context():
call.result.data = Model.set_public(
company_id, request.ids, invalid_cls=InvalidModelId, enabled=False
)

@ -8,7 +8,8 @@ from mongoengine import Q

import database
from apierrors import errors
from apimodels.base import UpdateResponse
from apierrors.errors.bad_request import InvalidProjectId
from apimodels.base import UpdateResponse, MakePublicRequest
from apimodels.projects import (
GetHyperParamReq,
GetHyperParamResp,
@ -422,3 +423,23 @@ def get_tags(call: APICall, company, request: ProjectTagsRequest):
projects=request.projects,
)
call.result.data = get_tags_response(ret)


@endpoint(
"projects.make_public", min_version="2.9", request_data_model=MakePublicRequest
)
def make_public(call: APICall, company_id, request: MakePublicRequest):
with translate_errors_context():
call.result.data = Project.set_public(
company_id, ids=request.ids, invalid_cls=InvalidProjectId, enabled=True
)


@endpoint(
"projects.make_private", min_version="2.9", request_data_model=MakePublicRequest
)
def make_public(call: APICall, company_id, request: MakePublicRequest):
with translate_errors_context():
call.result.data = Project.set_public(
company_id, ids=request.ids, invalid_cls=InvalidProjectId, enabled=False
)

@ -11,7 +11,8 @@ from mongoengine.queryset.transform import COMPARISON_OPERATORS
from pymongo import UpdateOne

from apierrors import errors, APIError
from apimodels.base import UpdateResponse, IdResponse
from apierrors.errors.bad_request import InvalidTaskId
from apimodels.base import UpdateResponse, IdResponse, MakePublicRequest
from apimodels.tasks import (
StartedResponse,
ResetResponse,
@ -78,10 +79,24 @@ def set_task_status_from_call(
task = TaskBLL.get_task_with_access(
request.task,
company_id=company_id,
only=tuple({"status", "project"} | fields_resolver.get_names()),
only=tuple(
{"status", "project", "started", "duration"} | fields_resolver.get_names()
),
requires_write_access=True,
)

if "duration" not in fields_resolver.get_names():
if new_status == Task.started:
fields_resolver.add_fields(min__duration=max(0, task.duration or 0))
elif new_status in (
TaskStatus.completed,
TaskStatus.failed,
TaskStatus.stopped,
):
fields_resolver.add_fields(
duration=int((task.started - datetime.utcnow()).total_seconds())
)

status_reason = request.status_reason
status_message = request.status_message
force = request.force
@ -354,9 +369,7 @@ def _update_cached_tags(company: str, project: str, fields: dict):


def _reset_cached_tags(company: str, projects: Sequence[str]):
org_bll.reset_tags(
company, Tags.Task, projects=projects
)
org_bll.reset_tags(company, Tags.Task, projects=projects)


@endpoint(
@ -573,9 +586,7 @@ def edit(call: APICall, company_id, req_model: UpdateRequest):
if updated:
new_project = fixed_fields.get("project", task.project)
if new_project != task.project:
_reset_cached_tags(
company_id, projects=[new_project, task.project]
)
_reset_cached_tags(company_id, projects=[new_project, task.project])
else:
_update_cached_tags(
company_id, project=task.project, fields=fixed_fields
@ -1005,3 +1016,19 @@ def add_or_update_artifacts(
task_id=request.task, company_id=company_id, artifacts=request.artifacts
)
call.result.data_model = AddOrUpdateArtifactsResponse(added=added, updated=updated)


@endpoint("tasks.make_public", min_version="2.9", request_data_model=MakePublicRequest)
def make_public(call: APICall, company_id, request: MakePublicRequest):
with translate_errors_context():
call.result.data = Task.set_public(
company_id, request.ids, invalid_cls=InvalidTaskId, enabled=True
)


@endpoint("tasks.make_private", min_version="2.9", request_data_model=MakePublicRequest)
def make_public(call: APICall, company_id, request: MakePublicRequest):
with translate_errors_context():
call.result.data = Task.set_public(
company_id, request.ids, invalid_cls=InvalidTaskId, enabled=False
)

@ -1,3 +1,4 @@
from apierrors.errors.bad_request import InvalidModelId
from tests.automated import TestService

MODEL_CANNOT_BE_UPDATED_CODES = (400, 203)
@ -7,7 +8,7 @@ IN_PROGRESS = "in_progress"


class TestModelsService(TestService):
def setUp(self, version="2.8"):
def setUp(self, version="2.9"):
super().setUp(version=version)

def test_publish_output_model_running_task(self):
@ -197,6 +198,28 @@ class TestModelsService(TestService):
res = self.api.models.get_frameworks(projects=[project])
self.assertEqual([], res.frameworks)

def test_make_public(self):
m1 = self._create_model(name="public model test")

# model with company_origin not set to the current company cannot be converted to private
with self.api.raises(InvalidModelId):
self.api.models.make_private(ids=[m1])

# public model can be retrieved but not updated
res = self.api.models.make_public(ids=[m1])
self.assertEqual(res.updated, 1)
res = self.api.models.get_all(id=[m1])
self.assertEqual([m.id for m in res.models], [m1])
with self.api.raises(InvalidModelId):
self.api.models.update(model=m1, name="public model test change 1")

# task made private again and can be both retrieved and updated
res = self.api.models.make_private(ids=[m1])
self.assertEqual(res.updated, 1)
res = self.api.models.get_all(id=[m1])
self.assertEqual([m.id for m in res.models], [m1])
self.api.models.update(model=m1, name="public model test change 2")

def _assert_task_status(self, task_id, status):
task = self.api.tasks.get_by_id(task=task_id).task
assert task.status == status

server/tests/automated/test_projects_edit.py (new file, 34 lines)
@ -0,0 +1,34 @@
from apierrors.errors.bad_request import InvalidProjectId
from apierrors.errors.forbidden import NoWritePermission
from config import config
from tests.automated import TestService


log = config.logger(__file__)


class TestProjectsEdit(TestService):
def setUp(self, **kwargs):
super().setUp(version="2.9")

def test_make_public(self):
p1 = self.create_temp("projects", name="Test public", description="test")

# project with company_origin not set to the current company cannot be converted to private
with self.api.raises(InvalidProjectId):
self.api.projects.make_private(ids=[p1])

# public project can be retrieved but not updated
res = self.api.projects.make_public(ids=[p1])
self.assertEqual(res.updated, 1)
res = self.api.projects.get_all(id=[p1])
self.assertEqual([p.id for p in res.projects], [p1])
with self.api.raises(NoWritePermission):
self.api.projects.update(project=p1, name="Test public change 1")

# task made private again and can be both retrieved and updated
res = self.api.projects.make_private(ids=[p1])
self.assertEqual(res.updated, 1)
res = self.api.projects.get_all(id=[p1])
self.assertEqual([p.id for p in res.projects], [p1])
self.api.projects.update(project=p1, name="Test public change 2")
@ -1,4 +1,5 @@
from apierrors.errors.bad_request import InvalidModelId, ValidationError
from apierrors.errors.bad_request import InvalidModelId, ValidationError, InvalidTaskId
from apierrors.errors.forbidden import NoWritePermission
from config import config
from tests.automated import TestService

@ -8,7 +9,7 @@ log = config.logger(__file__)

class TestTasksEdit(TestService):
def setUp(self, **kwargs):
super().setUp(version=2.5)
super().setUp(version="2.9")

def new_task(self, **kwargs):
self.update_missing(
@ -145,3 +146,28 @@ class TestTasksEdit(TestService):
self.api.tasks.delete, task=new_task, move_to_trash=False, force=True
)
return new_task

def test_make_public(self):
task = self.new_task()

# task is created as private and can be updated
self.api.tasks.started(task=task)

# task with company_origin not set to the current company cannot be converted to private
with self.api.raises(InvalidTaskId):
self.api.tasks.make_private(ids=[task])

# public task can be retrieved but not updated
res = self.api.tasks.make_public(ids=[task])
self.assertEqual(res.updated, 1)
res = self.api.tasks.get_all_ex(id=[task])
self.assertEqual([t.id for t in res.tasks], [task])
with self.api.raises(NoWritePermission):
self.api.tasks.stopped(task=task)

# task made private again and can be both retrieved and updated
res = self.api.tasks.make_private(ids=[task])
self.assertEqual(res.updated, 1)
res = self.api.tasks.get_all_ex(id=[task])
self.assertEqual([t.id for t in res.tasks], [task])
self.api.tasks.stopped(task=task)