Move to ElasticSearch 7

Add initial support for project ordering
Add support for sortable task duration (used by the UI in the experiments table)
Add support for project name in worker's current task info
Add support for results and artifacts in pre-populated examples
Add demo server features
allegroai 2020-08-10 08:30:40 +03:00
parent 77397c4f21
commit baba8b5b73
52 changed files with 1655 additions and 1144 deletions
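
Several of the hunks below adapt to the same Elasticsearch 7 change in response shape: hits.total is now an object rather than a plain integer. A minimal compatibility sketch (not part of the diff, helper name is illustrative) of reading the count in a way that tolerates both shapes:

def total_hits(es_res: dict) -> int:
    # ES 7 returns {"value": N, "relation": "eq"} under hits.total;
    # ES 6 returned a plain integer. Support both shapes.
    total = es_res.get("hits", {}).get("total", 0)
    if isinstance(total, dict):
        return total.get("value", 0)
    return total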

View File

@ -1,7 +1,8 @@
from jsonmodels import models, fields from jsonmodels import models, fields
from jsonmodels.validators import Length
from mongoengine.base import BaseDocument from mongoengine.base import BaseDocument
from apimodels import DictField from apimodels import DictField, ListField
class MongoengineFieldsDict(DictField): class MongoengineFieldsDict(DictField):
@ -12,14 +13,14 @@ class MongoengineFieldsDict(DictField):
""" """
mongoengine_update_operators = ( mongoengine_update_operators = (
'inc', "inc",
'dec', "dec",
'push', "push",
'push_all', "push_all",
'pop', "pop",
'pull', "pull",
'pull_all', "pull_all",
'add_to_set', "add_to_set",
) )
@staticmethod @staticmethod
@ -30,16 +31,16 @@ class MongoengineFieldsDict(DictField):
@classmethod @classmethod
def _normalize_mongo_field_path(cls, path, value): def _normalize_mongo_field_path(cls, path, value):
parts = path.split('__') parts = path.split("__")
if len(parts) > 1: if len(parts) > 1:
if parts[0] == 'set': if parts[0] == "set":
parts = parts[1:] parts = parts[1:]
elif parts[0] == 'unset': elif parts[0] == "unset":
parts = parts[1:] parts = parts[1:]
value = None value = None
elif parts[0] in cls.mongoengine_update_operators: elif parts[0] in cls.mongoengine_update_operators:
return None, None return None, None
return '.'.join(parts), cls._normalize_mongo_value(value) return ".".join(parts), cls._normalize_mongo_value(value)
def parse_value(self, value): def parse_value(self, value):
value = super(MongoengineFieldsDict, self).parse_value(value) value = super(MongoengineFieldsDict, self).parse_value(value)
@ -62,3 +63,7 @@ class PagedRequest(models.Base):
class IdResponse(models.Base): class IdResponse(models.Base):
id = fields.StringField(required=True) id = fields.StringField(required=True)
class MakePublicRequest(models.Base):
ids = ListField(items_types=str, validators=[Length(minimum_value=1)])
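
A hypothetical usage sketch of the new MakePublicRequest model, assuming standard jsonmodels semantics (the ids values are made up): the Length(minimum_value=1) validator rejects an empty ids list when validate() is called.

from jsonmodels.errors import ValidationError

req = MakePublicRequest(ids=["4c8a6c2f0a7e4d2e"])
req.validate()  # passes: at least one id supplied

try:
    MakePublicRequest(ids=[]).validate()
except ValidationError as ex:
    # Length(minimum_value=1) fails on an empty list
    print(f"rejected: {ex}")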

View File

@ -3,7 +3,7 @@ from typing import Sequence, Optional
from jsonmodels import validators from jsonmodels import validators
from jsonmodels.fields import StringField, BoolField from jsonmodels.fields import StringField, BoolField
from jsonmodels.models import Base from jsonmodels.models import Base
from jsonmodels.validators import Length from jsonmodels.validators import Length, Min, Max
from apimodels import ListField, IntField, ActualEnumField from apimodels import ListField, IntField, ActualEnumField
from bll.event.event_metrics import EventType from bll.event.event_metrics import EventType
@ -11,7 +11,7 @@ from bll.event.scalar_key import ScalarKeyEnum
class HistogramRequestBase(Base): class HistogramRequestBase(Base):
samples: int = IntField(default=10000) samples: int = IntField(default=6000, validators=[Min(1), Max(6000)])
key: ScalarKeyEnum = ActualEnumField(ScalarKeyEnum, default=ScalarKeyEnum.iter) key: ScalarKeyEnum = ActualEnumField(ScalarKeyEnum, default=ScalarKeyEnum.iter)
@ -21,7 +21,7 @@ class ScalarMetricsIterHistogramRequest(HistogramRequestBase):
class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase): class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
tasks: Sequence[str] = ListField( tasks: Sequence[str] = ListField(
items_types=str, validators=[Length(minimum_value=1)] items_types=str, validators=[Length(minimum_value=1, maximum_value=10)]
) )
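
A similar validation sketch under the same jsonmodels assumptions (request objects are built directly here only to exercise the validators): samples is now bounded to 1..6000 and the compared task list to at most 10 entries, so out-of-range values fail at validate() time.

from jsonmodels.errors import ValidationError

req = MultiTaskScalarMetricsIterHistogramRequest(tasks=["t1", "t2"])
req.validate()  # default samples=6000 sits inside Min(1)/Max(6000)

try:
    MultiTaskScalarMetricsIterHistogramRequest(tasks=["t1"], samples=60000).validate()
except ValidationError as ex:
    print(f"rejected: {ex}")  # samples above the Max(6000) bound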

View File

@ -67,6 +67,7 @@ class WorkerEntry(Base, JsonSerializableMixin):
company = EmbeddedField(IdNameEntry) company = EmbeddedField(IdNameEntry)
ip = StringField() ip = StringField()
task = EmbeddedField(IdNameEntry) task = EmbeddedField(IdNameEntry)
project = EmbeddedField(IdNameEntry)
queue = StringField() # queue from which current task was taken queue = StringField() # queue from which current task was taken
queues = ListField(str) # list of queues this worker listens to queues = ListField(str) # list of queues this worker listens to
register_time = DateTimeField(required=True) register_time = DateTimeField(required=True)

View File

@ -208,7 +208,11 @@ class DebugImagesIterator:
"size": 0, "size": 0,
"query": { "query": {
"bool": { "bool": {
"must": [{"term": {"task": task}}, {"terms": {"metric": metrics}}] "must": [
{"term": {"task": task}},
{"terms": {"metric": metrics}},
{"exists": {"field": "url"}},
]
} }
}, },
"aggs": { "aggs": {
@ -251,7 +255,7 @@ class DebugImagesIterator:
} }
with translate_errors_context(), TimingContext("es", "_init_metric_states"): with translate_errors_context(), TimingContext("es", "_init_metric_states"):
es_res = self.es.search(index=es_index, body=es_req, routing=task) es_res = self.es.search(index=es_index, body=es_req)
if "aggregations" not in es_res: if "aggregations" not in es_res:
return [] return []
@ -298,6 +302,7 @@ class DebugImagesIterator:
must_conditions = [ must_conditions = [
{"term": {"task": metric.task}}, {"term": {"task": metric.task}},
{"term": {"metric": metric.name}}, {"term": {"metric": metric.name}},
{"exists": {"field": "url"}},
] ]
must_not_conditions = [] must_not_conditions = []
@ -368,7 +373,7 @@ class DebugImagesIterator:
"terms": { "terms": {
"field": "iter", "field": "iter",
"size": iter_count, "size": iter_count,
"order": {"_term": "desc" if navigate_earlier else "asc"}, "order": {"_key": "desc" if navigate_earlier else "asc"},
}, },
"aggs": { "aggs": {
"variants": { "variants": {
@ -387,7 +392,7 @@ class DebugImagesIterator:
}, },
} }
with translate_errors_context(), TimingContext("es", "get_debug_image_events"): with translate_errors_context(), TimingContext("es", "get_debug_image_events"):
es_res = self.es.search(index=es_index, body=es_req, routing=metric.task) es_res = self.es.search(index=es_index, body=es_req)
if "aggregations" not in es_res: if "aggregations" not in es_res:
return metric.task, metric.name, [] return metric.task, metric.name, []
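
A reduced request body (task id and sizes are illustrative, not the full query assembled above) showing the two query-level changes in this iterator: events with no url field are filtered out with an exists clause, and Elasticsearch 7 terms aggregations are ordered by the _key keyword, replacing the removed _term.

es_req = {
    "size": 0,
    "query": {
        "bool": {
            "must": [
                {"term": {"task": "some-task-id"}},
                {"exists": {"field": "url"}},  # skip debug image events without a url
            ]
        }
    },
    "aggs": {
        "iters": {
            "terms": {
                "field": "iter",
                "size": 10,
                "order": {"_key": "desc"},  # ES 7: "_term" ordering was removed
            }
        }
    },
}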

View File

@ -3,7 +3,7 @@ from collections import defaultdict
from contextlib import closing from contextlib import closing
from datetime import datetime from datetime import datetime
from operator import attrgetter from operator import attrgetter
from typing import Sequence, Set, Tuple from typing import Sequence, Set, Tuple, Optional
import six import six
from elasticsearch import helpers from elasticsearch import helpers
@ -22,6 +22,7 @@ from database.errors import translate_errors_context
from database.model.task.task import Task, TaskStatus from database.model.task.task import Task, TaskStatus
from redis_manager import redman from redis_manager import redman
from timing_context import TimingContext from timing_context import TimingContext
from tools import safe_get
from utilities.dicts import flatten_nested_items from utilities.dicts import flatten_nested_items
# noinspection PyTypeChecker # noinspection PyTypeChecker
@ -134,7 +135,6 @@ class EventBLL(object):
es_action = { es_action = {
"_op_type": "index", # overwrite if exists with same ID "_op_type": "index", # overwrite if exists with same ID
"_index": index_name, "_index": index_name,
"_type": "event",
"_source": event, "_source": event,
} }
@ -144,7 +144,6 @@ class EventBLL(object):
else: else:
es_action["_id"] = dbutils.id() es_action["_id"] = dbutils.id()
es_action["_routing"] = task_id
task_ids.add(task_id) task_ids.add(task_id)
if ( if (
iter is not None iter is not None
@ -342,14 +341,9 @@ class EventBLL(object):
} }
with translate_errors_context(), TimingContext("es", "scroll_task_events"): with translate_errors_context(), TimingContext("es", "scroll_task_events"):
es_res = self.es.search( es_res = self.es.search(index=es_index, body=es_req, scroll="1h")
index=es_index, body=es_req, scroll="1h", routing=task_id
)
events = [hit["_source"] for hit in es_res["hits"]["hits"]]
next_scroll_id = es_res["_scroll_id"]
total_events = es_res["hits"]["total"]
events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
return events, next_scroll_id, total_events return events, next_scroll_id, total_events
def get_last_iterations_per_event_metric_variant( def get_last_iterations_per_event_metric_variant(
@ -377,7 +371,7 @@ class EventBLL(object):
"terms": { "terms": {
"field": "iter", "field": "iter",
"size": num_last_iterations, "size": num_last_iterations,
"order": {"_term": "desc"}, "order": {"_key": "desc"},
} }
} }
}, },
@ -393,7 +387,7 @@ class EventBLL(object):
with translate_errors_context(), TimingContext( with translate_errors_context(), TimingContext(
"es", "task_last_iter_metric_variant" "es", "task_last_iter_metric_variant"
): ):
es_res = self.es.search(index=es_index, body=es_req, routing=task_id) es_res = self.es.search(index=es_index, body=es_req)
if "aggregations" not in es_res: if "aggregations" not in es_res:
return [] return []
@ -422,13 +416,11 @@ class EventBLL(object):
if not self.es.indices.exists(es_index): if not self.es.indices.exists(es_index):
return TaskEventsResult() return TaskEventsResult()
query = {"bool": defaultdict(list)} must = []
if last_iterations_per_plot is None: if last_iterations_per_plot is None:
must = query["bool"]["must"]
must.append({"terms": {"task": tasks}}) must.append({"terms": {"task": tasks}})
else: else:
should = query["bool"]["should"] should = []
for i, task_id in enumerate(tasks): for i, task_id in enumerate(tasks):
last_iters = self.get_last_iterations_per_event_metric_variant( last_iters = self.get_last_iterations_per_event_metric_variant(
es_index, task_id, last_iterations_per_plot, event_type es_index, task_id, last_iterations_per_plot, event_type
@ -451,32 +443,41 @@ class EventBLL(object):
) )
if not should: if not should:
return TaskEventsResult() return TaskEventsResult()
must.append({"bool": {"should": should}})
if sort is None: if sort is None:
sort = [{"timestamp": {"order": "asc"}}] sort = [{"timestamp": {"order": "asc"}}]
es_req = {"sort": sort, "size": min(size, 10000), "query": query} es_req = {
"sort": sort,
routing = ",".join(tasks) "size": min(size, 10000),
"query": {"bool": {"must": must}},
}
with translate_errors_context(), TimingContext("es", "get_task_plots"): with translate_errors_context(), TimingContext("es", "get_task_plots"):
es_res = self.es.search( es_res = self.es.search(
index=es_index, index=es_index, body=es_req, ignore=404, scroll="1h",
body=es_req,
ignore=404,
routing=routing,
scroll="1h",
) )
events = [doc["_source"] for doc in es_res.get("hits", {}).get("hits", [])] events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
# scroll id may be missing when queering a totally empty DB
next_scroll_id = es_res.get("_scroll_id")
total_events = es_res["hits"]["total"]
return TaskEventsResult( return TaskEventsResult(
events=events, next_scroll_id=next_scroll_id, total_events=total_events events=events, next_scroll_id=next_scroll_id, total_events=total_events
) )
def _get_events_from_es_res(self, es_res: dict) -> Tuple[list, int, Optional[str]]:
"""
Return events and next scroll id from the scrolled query
Release the scroll once it is exhausted
"""
total_events = safe_get(es_res, "hits/total/value", default=0)
events = [doc["_source"] for doc in safe_get(es_res, "hits/hits", default=[])]
next_scroll_id = es_res.get("_scroll_id")
if next_scroll_id and not events:
self.es.clear_scroll(scroll_id=next_scroll_id)
next_scroll_id = None
return events, total_events, next_scroll_id
def get_task_events( def get_task_events(
self, self,
company_id, company_id,
@ -502,20 +503,16 @@ class EventBLL(object):
if not self.es.indices.exists(es_index): if not self.es.indices.exists(es_index):
return TaskEventsResult() return TaskEventsResult()
query = {"bool": defaultdict(list)} must = []
if metric:
if metric or variant: must.append({"term": {"metric": metric}})
must = query["bool"]["must"] if variant:
if metric: must.append({"term": {"variant": variant}})
must.append({"term": {"metric": metric}})
if variant:
must.append({"term": {"variant": variant}})
if last_iter_count is None: if last_iter_count is None:
must = query["bool"]["must"]
must.append({"terms": {"task": task_ids}}) must.append({"terms": {"task": task_ids}})
else: else:
should = query["bool"]["should"] should = []
for i, task_id in enumerate(task_ids): for i, task_id in enumerate(task_ids):
last_iters = self.get_last_iters( last_iters = self.get_last_iters(
es_index, task_id, event_type, last_iter_count es_index, task_id, event_type, last_iter_count
@ -534,27 +531,23 @@ class EventBLL(object):
) )
if not should: if not should:
return TaskEventsResult() return TaskEventsResult()
must.append({"bool": {"should": should}})
if sort is None: if sort is None:
sort = [{"timestamp": {"order": "asc"}}] sort = [{"timestamp": {"order": "asc"}}]
es_req = {"sort": sort, "size": min(size, 10000), "query": query} es_req = {
"sort": sort,
routing = ",".join(task_ids) "size": min(size, 10000),
"query": {"bool": {"must": must}},
}
with translate_errors_context(), TimingContext("es", "get_task_events"): with translate_errors_context(), TimingContext("es", "get_task_events"):
es_res = self.es.search( es_res = self.es.search(
index=es_index, index=es_index, body=es_req, ignore=404, scroll="1h",
body=es_req,
ignore=404,
routing=routing,
scroll="1h",
) )
events = [doc["_source"] for doc in es_res.get("hits", {}).get("hits", [])] events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
next_scroll_id = es_res.get("_scroll_id")
total_events = es_res["hits"]["total"]
return TaskEventsResult( return TaskEventsResult(
events=events, next_scroll_id=next_scroll_id, total_events=total_events events=events, next_scroll_id=next_scroll_id, total_events=total_events
) )
@ -590,7 +583,7 @@ class EventBLL(object):
with translate_errors_context(), TimingContext( with translate_errors_context(), TimingContext(
"es", "events_get_metrics_and_variants" "es", "events_get_metrics_and_variants"
): ):
es_res = self.es.search(index=es_index, body=es_req, routing=task_id) es_res = self.es.search(index=es_index, body=es_req)
metrics = {} metrics = {}
for metric_bucket in es_res["aggregations"]["metrics"].get("buckets"): for metric_bucket in es_res["aggregations"]["metrics"].get("buckets"):
@ -622,14 +615,14 @@ class EventBLL(object):
"terms": { "terms": {
"field": "metric", "field": "metric",
"size": EventMetrics.MAX_METRICS_COUNT, "size": EventMetrics.MAX_METRICS_COUNT,
"order": {"_term": "asc"}, "order": {"_key": "asc"},
}, },
"aggs": { "aggs": {
"variants": { "variants": {
"terms": { "terms": {
"field": "variant", "field": "variant",
"size": EventMetrics.MAX_VARIANTS_COUNT, "size": EventMetrics.MAX_VARIANTS_COUNT,
"order": {"_term": "asc"}, "order": {"_key": "asc"},
}, },
"aggs": { "aggs": {
"last_value": { "last_value": {
@ -659,7 +652,7 @@ class EventBLL(object):
with translate_errors_context(), TimingContext( with translate_errors_context(), TimingContext(
"es", "events_get_metrics_and_variants" "es", "events_get_metrics_and_variants"
): ):
es_res = self.es.search(index=es_index, body=es_req, routing=task_id) es_res = self.es.search(index=es_index, body=es_req)
metrics = [] metrics = []
max_timestamp = 0 max_timestamp = 0
@ -706,7 +699,7 @@ class EventBLL(object):
"sort": ["iter"], "sort": ["iter"],
} }
with translate_errors_context(), TimingContext("es", "task_stats_vector"): with translate_errors_context(), TimingContext("es", "task_stats_vector"):
es_res = self.es.search(index=es_index, body=es_req, routing=task_id) es_res = self.es.search(index=es_index, body=es_req)
vectors = [] vectors = []
iterations = [] iterations = []
@ -727,7 +720,7 @@ class EventBLL(object):
"terms": { "terms": {
"field": "iter", "field": "iter",
"size": iters, "size": iters,
"order": {"_term": "desc"}, "order": {"_key": "desc"},
} }
} }
}, },
@ -737,7 +730,7 @@ class EventBLL(object):
es_req["query"]["bool"]["must"].append({"term": {"type": event_type}}) es_req["query"]["bool"]["must"].append({"term": {"type": event_type}})
with translate_errors_context(), TimingContext("es", "task_last_iter"): with translate_errors_context(), TimingContext("es", "task_last_iter"):
es_res = self.es.search(index=es_index, body=es_req, routing=task_id) es_res = self.es.search(index=es_index, body=es_req)
if "aggregations" not in es_res: if "aggregations" not in es_res:
return [] return []
@ -759,8 +752,6 @@ class EventBLL(object):
es_index = EventMetrics.get_index_name(company_id, "*") es_index = EventMetrics.get_index_name(company_id, "*")
es_req = {"query": {"term": {"task": task_id}}} es_req = {"query": {"term": {"task": task_id}}}
with translate_errors_context(), TimingContext("es", "delete_task_events"): with translate_errors_context(), TimingContext("es", "delete_task_events"):
es_res = self.es.delete_by_query( es_res = self.es.delete_by_query(index=es_index, body=es_req, refresh=True)
index=es_index, body=es_req, routing=task_id, refresh=True
)
return es_res.get("deleted", 0) return es_res.get("deleted", 0)
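
A standalone sketch, assuming a plain elasticsearch-py client named es (the loop itself is not in this file), of the scroll lifecycle that the new _get_events_from_es_res helper takes part in: each page is fetched with the returned scroll id, and the scroll context is released once an empty page signals exhaustion.

def iter_all_events(es, es_index: str, es_req: dict):
    # Open a scrolled search and yield every hit, clearing the scroll
    # context once it is exhausted.
    res = es.search(index=es_index, body=es_req, scroll="1h")
    while True:
        hits = res.get("hits", {}).get("hits", [])
        scroll_id = res.get("_scroll_id")
        if not hits:
            if scroll_id:
                es.clear_scroll(scroll_id=scroll_id)
            return
        yield from (hit["_source"] for hit in hits)
        res = es.scroll(scroll_id=scroll_id, scroll="1h")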

View File

@ -1,12 +1,11 @@
import itertools import itertools
from collections import defaultdict from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor from concurrent.futures.thread import ThreadPoolExecutor
from enum import Enum from enum import Enum
from functools import partial from functools import partial
from operator import itemgetter from operator import itemgetter
from typing import Sequence, Tuple, Callable, Iterable from typing import Sequence, Tuple
from boltons.iterutils import bucketize
from elasticsearch import Elasticsearch from elasticsearch import Elasticsearch
from mongoengine import Q from mongoengine import Q
@ -16,7 +15,7 @@ from config import config
from database.errors import translate_errors_context from database.errors import translate_errors_context
from database.model.task.task import Task from database.model.task.task import Task
from timing_context import TimingContext from timing_context import TimingContext
from utilities import safe_get from tools import safe_get
log = config.logger(__file__) log = config.logger(__file__)
@ -30,14 +29,18 @@ class EventType(Enum):
class EventMetrics: class EventMetrics:
MAX_TASKS_COUNT = 50 MAX_METRICS_COUNT = 100
MAX_METRICS_COUNT = 200 MAX_VARIANTS_COUNT = 100
MAX_VARIANTS_COUNT = 500
MAX_AGGS_ELEMENTS_COUNT = 50 MAX_AGGS_ELEMENTS_COUNT = 50
MAX_SAMPLE_BUCKETS = 6000
def __init__(self, es: Elasticsearch): def __init__(self, es: Elasticsearch):
self.es = es self.es = es
@property
def _max_concurrency(self):
return config.get("services.events.max_metrics_concurrency", 4)
@staticmethod @staticmethod
def get_index_name(company_id, event_type): def get_index_name(company_id, event_type):
event_type = event_type.lower().replace(" ", "_") event_type = event_type.lower().replace(" ", "_")
@ -51,15 +54,48 @@ class EventMetrics:
The amount of points in each histogram should not exceed The amount of points in each histogram should not exceed
the requested samples the requested samples
""" """
es_index = self.get_index_name(company_id, "training_stats_scalar")
if not self.es.indices.exists(es_index):
return {}
return self._run_get_scalar_metrics_as_parallel( return self._get_scalar_average_per_iter_core(
company_id, task_id, es_index, samples, ScalarKey.resolve(key)
task_ids=[task_id],
samples=samples,
key=ScalarKey.resolve(key),
get_func=self._get_scalar_average,
) )
def _get_scalar_average_per_iter_core(
self,
task_id: str,
es_index: str,
samples: int,
key: ScalarKey,
run_parallel: bool = True,
) -> dict:
intervals = self._get_task_metric_intervals(
es_index=es_index, task_id=task_id, samples=samples, field=key.field
)
if not intervals:
return {}
interval_groups = self._group_task_metric_intervals(intervals)
get_scalar_average = partial(
self._get_scalar_average, task_id=task_id, es_index=es_index, key=key
)
if run_parallel:
with ThreadPoolExecutor(max_workers=self._max_concurrency) as pool:
metrics = itertools.chain.from_iterable(
pool.map(get_scalar_average, interval_groups)
)
else:
metrics = itertools.chain.from_iterable(
get_scalar_average(group) for group in interval_groups
)
ret = defaultdict(dict)
for metric_key, metric_values in metrics:
ret[metric_key].update(metric_values)
return ret
def compare_scalar_metrics_average_per_iter( def compare_scalar_metrics_average_per_iter(
self, self,
company_id, company_id,
@ -72,12 +108,6 @@ class EventMetrics:
Compare scalar metrics for different tasks per metric and variant Compare scalar metrics for different tasks per metric and variant
The amount of points in each histogram should not exceed the requested samples The amount of points in each histogram should not exceed the requested samples
""" """
if len(task_ids) > self.MAX_TASKS_COUNT:
raise errors.BadRequest(
f"Up to {self.MAX_TASKS_COUNT} tasks supported for comparison",
len(task_ids),
)
task_name_by_id = {} task_name_by_id = {}
with translate_errors_context(): with translate_errors_context():
task_objs = Task.get_many( task_objs = Task.get_many(
@ -90,7 +120,6 @@ class EventMetrics:
if len(task_objs) < len(task_ids): if len(task_objs) < len(task_ids):
invalid = tuple(set(task_ids) - set(r.id for r in task_objs)) invalid = tuple(set(task_ids) - set(r.id for r in task_objs))
raise errors.bad_request.InvalidTaskId(company=company_id, ids=invalid) raise errors.bad_request.InvalidTaskId(company=company_id, ids=invalid)
task_name_by_id = {t.id: t.name for t in task_objs} task_name_by_id = {t.id: t.name for t in task_objs}
companies = {t.company for t in task_objs} companies = {t.company for t in task_objs}
@ -99,138 +128,95 @@ class EventMetrics:
"only tasks from the same company are supported" "only tasks from the same company are supported"
) )
ret = self._run_get_scalar_metrics_as_parallel( es_index = self.get_index_name(next(iter(companies)), "training_stats_scalar")
next(iter(companies)),
task_ids=task_ids,
samples=samples,
key=ScalarKey.resolve(key),
get_func=self._get_scalar_average_per_task,
)
for metric_data in ret.values():
for variant_data in metric_data.values():
for task_id, task_data in variant_data.items():
task_data["name"] = task_name_by_id[task_id]
return ret
TaskMetric = Tuple[str, str, str]
MetricInterval = Tuple[int, Sequence[TaskMetric]]
MetricData = Tuple[str, dict]
def _split_metrics_by_max_aggs_count(
self, task_metrics: Sequence[TaskMetric]
) -> Iterable[Sequence[TaskMetric]]:
"""
Return task metrics in groups where amount of task metrics in each group
is roughly limited by MAX_AGGS_ELEMENTS_COUNT. The split is done on metrics and
variants while always preserving all their tasks in the same group
"""
if len(task_metrics) < self.MAX_AGGS_ELEMENTS_COUNT:
yield task_metrics
return
tm_grouped = bucketize(task_metrics, key=itemgetter(1, 2))
groups = []
for group in tm_grouped.values():
groups.append(group)
if sum(map(len, groups)) >= self.MAX_AGGS_ELEMENTS_COUNT:
yield list(itertools.chain(*groups))
groups = []
if groups:
yield list(itertools.chain(*groups))
return
def _run_get_scalar_metrics_as_parallel(
self,
company_id: str,
task_ids: Sequence[str],
samples: int,
key: ScalarKey,
get_func: Callable[
[MetricInterval, Sequence[str], str, ScalarKey], Sequence[MetricData]
],
) -> dict:
"""
Group metrics per interval length and execute get_func for each group in parallel
:param company_id: id of the company
:params task_ids: ids of the tasks to collect data for
:param samples: maximum number of samples per metric
:param get_func: callable that given metric names for the same interval
performs histogram aggregation for the metrics and return the aggregated data
"""
es_index = self.get_index_name(company_id, "training_stats_scalar")
if not self.es.indices.exists(es_index): if not self.es.indices.exists(es_index):
return {} return {}
intervals = self._get_metric_intervals( get_scalar_average_per_iter = partial(
es_index=es_index, task_ids=task_ids, samples=samples, field=key.field self._get_scalar_average_per_iter_core,
es_index=es_index,
samples=samples,
key=ScalarKey.resolve(key),
run_parallel=False,
) )
with ThreadPoolExecutor(max_workers=self._max_concurrency) as pool:
if not intervals: task_metrics = zip(
return {} task_ids, pool.map(get_scalar_average_per_iter, task_ids)
intervals = list(
itertools.chain.from_iterable(
zip(itertools.repeat(i), self._split_metrics_by_max_aggs_count(tms))
for i, tms in intervals
)
)
max_concurrency = config.get("services.events.max_metrics_concurrency", 4)
with ThreadPoolExecutor(max_workers=max_concurrency) as pool:
metrics = itertools.chain.from_iterable(
pool.map(
partial(get_func, task_ids=task_ids, es_index=es_index, key=key),
intervals,
)
) )
ret = defaultdict(dict) res = defaultdict(lambda: defaultdict(dict))
for metric_key, metric_values in metrics: for task_id, task_data in task_metrics:
ret[metric_key].update(metric_values) task_name = task_name_by_id[task_id]
for metric_key, metric_data in task_data.items():
for variant_key, variant_data in metric_data.items():
variant_data["name"] = task_name
res[metric_key][variant_key][task_id] = variant_data
return ret return res
def _get_metric_intervals( MetricInterval = Tuple[str, str, int, int]
self, es_index, task_ids: Sequence[str], samples: int, field: str = "iter" MetricIntervalGroup = Tuple[int, Sequence[Tuple[str, str]]]
@classmethod
def _group_task_metric_intervals(
cls, intervals: Sequence[MetricInterval]
) -> Sequence[MetricIntervalGroup]:
"""
Group task metric intervals so that the following conditions are met:
- All the metrics in the same group have the same interval (with 10% rounding)
- The amount of metrics in the group does not exceed MAX_AGGS_ELEMENTS_COUNT
- The total count of samples in the group does not exceed MAX_SAMPLE_BUCKETS
"""
metric_interval_groups = []
interval_group = []
group_interval_upper_bound = 0
group_max_interval = 0
group_samples = 0
for metric, variant, interval, size in sorted(intervals, key=itemgetter(2)):
if (
interval > group_interval_upper_bound
or (group_samples + size) > cls.MAX_SAMPLE_BUCKETS
or len(interval_group) >= cls.MAX_AGGS_ELEMENTS_COUNT
):
if interval_group:
metric_interval_groups.append((group_max_interval, interval_group))
interval_group = []
group_max_interval = interval
group_interval_upper_bound = interval + int(interval * 0.1)
group_samples = 0
interval_group.append((metric, variant))
group_samples += size
group_max_interval = max(group_max_interval, interval)
if interval_group:
metric_interval_groups.append((group_max_interval, interval_group))
return metric_interval_groups
def _get_task_metric_intervals(
self, es_index, task_id: str, samples: int, field: str = "iter"
) -> Sequence[MetricInterval]: ) -> Sequence[MetricInterval]:
""" """
Calculate interval per task metric variant so that the resulting Calculate interval per task metric variant so that the resulting
amount of points does not exceed samples. amount of points does not exceed samples.
Return metric variants grouped by interval value with 10% rounding Return the list of metric variant intervals as the following tuple:
For samples==0 return empty list (metric, variant, interval, samples)
""" """
default_intervals = [(1, [])]
if not samples:
return default_intervals
es_req = { es_req = {
"size": 0, "size": 0,
"query": {"terms": {"task": task_ids}}, "query": {"term": {"task": task_id}},
"aggs": { "aggs": {
"tasks": { "metrics": {
"terms": {"field": "task", "size": self.MAX_TASKS_COUNT}, "terms": {"field": "metric", "size": self.MAX_METRICS_COUNT},
"aggs": { "aggs": {
"metrics": { "variants": {
"terms": { "terms": {
"field": "metric", "field": "variant",
"size": self.MAX_METRICS_COUNT, "size": self.MAX_VARIANTS_COUNT,
}, },
"aggs": { "aggs": {
"variants": { "count": {"value_count": {"field": field}},
"terms": { "min_index": {"min": {"field": field}},
"field": "variant", "max_index": {"max": {"field": field}},
"size": self.MAX_VARIANTS_COUNT,
},
"aggs": {
"count": {"value_count": {"field": field}},
"min_index": {"min": {"field": field}},
"max_index": {"max": {"field": field}},
},
}
}, },
} }
}, },
@ -239,88 +225,75 @@ class EventMetrics:
} }
with translate_errors_context(), TimingContext("es", "task_stats_get_interval"): with translate_errors_context(), TimingContext("es", "task_stats_get_interval"):
es_res = self.es.search( es_res = self.es.search(index=es_index, body=es_req)
index=es_index, body=es_req, routing=",".join(task_ids)
)
aggs_result = es_res.get("aggregations") aggs_result = es_res.get("aggregations")
if not aggs_result: if not aggs_result:
return default_intervals return []
intervals = [ return [
( self._build_metric_interval(metric["key"], variant["key"], variant, samples)
task["key"], for metric in aggs_result["metrics"]["buckets"]
metric["key"],
variant["key"],
self._calculate_metric_interval(variant, samples),
)
for task in aggs_result["tasks"]["buckets"]
for metric in task["metrics"]["buckets"]
for variant in metric["variants"]["buckets"] for variant in metric["variants"]["buckets"]
] ]
metric_intervals = []
upper_border = 0
interval_metrics = None
for task, metric, variant, interval in sorted(intervals, key=itemgetter(3)):
if not interval_metrics or interval > upper_border:
interval_metrics = []
metric_intervals.append((interval, interval_metrics))
upper_border = interval + int(interval * 0.1)
interval_metrics.append((task, metric, variant))
return metric_intervals
@staticmethod @staticmethod
def _calculate_metric_interval(metric_variant: dict, samples: int) -> int: def _build_metric_interval(
metric: str, variant: str, data: dict, samples: int
) -> Tuple[str, str, int, int]:
""" """
Calculate index interval per metric_variant variant so that the Calculate index interval per metric_variant variant so that the
total amount of intervals does not exceed the samples total amount of intervals does not exceed the samples
Return the interval and resulting amount of intervals
""" """
count = safe_get(metric_variant, "count/value") count = safe_get(data, "count/value", default=0)
if not count or count < samples: if count < samples:
return 1 return metric, variant, 1, count
min_index = safe_get(metric_variant, "min_index/value", default=0) min_index = safe_get(data, "min_index/value", default=0)
max_index = safe_get(metric_variant, "max_index/value", default=min_index) max_index = safe_get(data, "max_index/value", default=min_index)
return max(1, int(max_index - min_index + 1) // samples) return (
metric,
variant,
max(1, int(max_index - min_index + 1) // samples),
samples,
)
MetricData = Tuple[str, dict]
def _get_scalar_average( def _get_scalar_average(
self, self,
metrics_interval: MetricInterval, metrics_interval: MetricIntervalGroup,
task_ids: Sequence[str], task_id: str,
es_index: str, es_index: str,
key: ScalarKey, key: ScalarKey,
) -> Sequence[MetricData]: ) -> Sequence[MetricData]:
""" """
Retrieve scalar histograms per several metric variants that share the same interval Retrieve scalar histograms per several metric variants that share the same interval
Note: the function works with a single task only
""" """
interval, metrics = metrics_interval
assert len(task_ids) == 1
interval, task_metrics = metrics_interval
aggregation = self._add_aggregation_average(key.get_aggregation(interval)) aggregation = self._add_aggregation_average(key.get_aggregation(interval))
aggs = { aggs = {
"metrics": { "metrics": {
"terms": { "terms": {
"field": "metric", "field": "metric",
"size": self.MAX_METRICS_COUNT, "size": self.MAX_METRICS_COUNT,
"order": {"_term": "desc"}, "order": {"_key": "desc"},
}, },
"aggs": { "aggs": {
"variants": { "variants": {
"terms": { "terms": {
"field": "variant", "field": "variant",
"size": self.MAX_VARIANTS_COUNT, "size": self.MAX_VARIANTS_COUNT,
"order": {"_term": "desc"}, "order": {"_key": "desc"},
}, },
"aggs": aggregation, "aggs": aggregation,
} }
}, },
} }
} }
aggs_result = self._query_aggregation_for_metrics_and_tasks( aggs_result = self._query_aggregation_for_task_metrics(
es_index, aggs=aggs, task_ids=task_ids, task_metrics=task_metrics es_index, aggs=aggs, task_id=task_id, metrics=metrics
) )
if not aggs_result: if not aggs_result:
@ -341,61 +314,6 @@ class EventMetrics:
] ]
return metrics return metrics
def _get_scalar_average_per_task(
self,
metrics_interval: MetricInterval,
task_ids: Sequence[str],
es_index: str,
key: ScalarKey,
) -> Sequence[MetricData]:
"""
Retrieve scalar histograms per several metric variants that share the same interval
"""
interval, task_metrics = metrics_interval
aggregation = self._add_aggregation_average(key.get_aggregation(interval))
aggs = {
"metrics": {
"terms": {"field": "metric", "size": self.MAX_METRICS_COUNT},
"aggs": {
"variants": {
"terms": {"field": "variant", "size": self.MAX_VARIANTS_COUNT},
"aggs": {
"tasks": {
"terms": {
"field": "task",
"size": self.MAX_TASKS_COUNT,
},
"aggs": aggregation,
}
},
}
},
}
}
aggs_result = self._query_aggregation_for_metrics_and_tasks(
es_index, aggs=aggs, task_ids=task_ids, task_metrics=task_metrics
)
if not aggs_result:
return {}
metrics = [
(
metric["key"],
{
variant["key"]: {
task["key"]: key.get_iterations_data(task)
for task in variant["tasks"]["buckets"]
}
for variant in metric["variants"]["buckets"]
},
)
for metric in aggs_result["metrics"]["buckets"]
]
return metrics
@staticmethod @staticmethod
def _add_aggregation_average(aggregation): def _add_aggregation_average(aggregation):
average_agg = {"avg_val": {"avg": {"field": "value"}}} average_agg = {"avg_val": {"avg": {"field": "value"}}}
@ -404,69 +322,55 @@ class EventMetrics:
for key, value in aggregation.items() for key, value in aggregation.items()
} }
def _query_aggregation_for_metrics_and_tasks( def _query_aggregation_for_task_metrics(
self, self,
es_index: str, es_index: str,
aggs: dict, aggs: dict,
task_ids: Sequence[str], task_id: str,
task_metrics: Sequence[TaskMetric], metrics: Sequence[Tuple[str, str]],
) -> dict: ) -> dict:
""" """
Return the result of elastic search query for the given aggregation filtered Return the result of elastic search query for the given aggregation filtered
by the given task_ids and metrics by the given task_ids and metrics
""" """
if task_metrics: must = [{"term": {"task": task_id}}]
condition = { if metrics:
"should": [ should = [
self._build_metric_terms(task, metric, variant) {
for task, metric, variant in task_metrics "bool": {
] "must": [
} {"term": {"metric": metric}},
else: {"term": {"variant": variant}},
condition = {"must": [{"terms": {"task": task_ids}}]} ]
}
}
for metric, variant in metrics
]
must.append({"bool": {"should": should}})
es_req = { es_req = {
"size": 0, "size": 0,
"_source": {"excludes": []}, "query": {"bool": {"must": must}},
"query": {"bool": condition},
"aggs": aggs, "aggs": aggs,
"version": True,
} }
with translate_errors_context(), TimingContext("es", "task_stats_scalar"): with translate_errors_context(), TimingContext("es", "task_stats_scalar"):
es_res = self.es.search( es_res = self.es.search(index=es_index, body=es_req)
index=es_index, body=es_req, routing=",".join(task_ids)
)
return es_res.get("aggregations") return es_res.get("aggregations")
@staticmethod
def _build_metric_terms(task: str, metric: str, variant: str) -> dict:
"""
Build query term for a metric + variant
"""
return {
"bool": {
"must": [
{"term": {"task": task}},
{"term": {"metric": metric}},
{"term": {"variant": variant}},
]
}
}
def get_tasks_metrics( def get_tasks_metrics(
self, company_id, task_ids: Sequence, event_type: EventType self, company_id, task_ids: Sequence, event_type: EventType
) -> Sequence[Tuple]: ) -> Sequence:
""" """
For the requested tasks return all the metrics that For the requested tasks return all the metrics that
reported events of the requested types reported events of the requested types
""" """
es_index = EventMetrics.get_index_name(company_id, event_type.value) es_index = EventMetrics.get_index_name(company_id, event_type.value)
if not self.es.indices.exists(es_index): if not self.es.indices.exists(es_index):
return [(tid, []) for tid in task_ids] return {}
max_concurrency = config.get("services.events.max_metrics_concurrency", 4) with ThreadPoolExecutor(self._max_concurrency) as pool:
with ThreadPoolExecutor(max_concurrency) as pool:
res = pool.map( res = pool.map(
partial( partial(
self._get_task_metrics, es_index=es_index, event_type=event_type, self._get_task_metrics, es_index=es_index, event_type=event_type,
@ -494,7 +398,7 @@ class EventMetrics:
} }
with translate_errors_context(), TimingContext("es", "_get_task_metrics"): with translate_errors_context(), TimingContext("es", "_get_task_metrics"):
es_res = self.es.search(index=es_index, body=es_req, routing=task_id) es_res = self.es.search(index=es_index, body=es_req)
return [ return [
metric["key"] metric["key"]

View File

@ -71,9 +71,9 @@ class LogEventsIterator:
es_req["search_after"] = [from_timestamp] es_req["search_after"] = [from_timestamp]
with translate_errors_context(), TimingContext("es", "get_task_events"): with translate_errors_context(), TimingContext("es", "get_task_events"):
es_result = self.es.search(index=es_index, body=es_req, routing=task_id) es_result = self.es.search(index=es_index, body=es_req)
hits = es_result["hits"]["hits"] hits = es_result["hits"]["hits"]
hits_total = es_result["hits"]["total"] hits_total = es_result["hits"]["total"]["value"]
if not hits: if not hits:
return [], hits_total return [], hits_total
@ -92,7 +92,7 @@ class LogEventsIterator:
} }
}, },
} }
es_result = self.es.search(index=es_index, body=es_req, routing=task_id) es_result = self.es.search(index=es_index, body=es_req)
hits = es_result["hits"]["hits"] hits = es_result["hits"]["hits"]
if not hits or len(hits) < 2: if not hits or len(hits) < 2:
# if only one element is returned for the last timestamp # if only one element is returned for the last timestamp

View File

@ -111,7 +111,7 @@ class TimestampKey(ScalarKey):
self.name: { self.name: {
"date_histogram": { "date_histogram": {
"field": "timestamp", "field": "timestamp",
"interval": f"{interval}ms", "fixed_interval": f"{interval}ms",
"min_doc_count": 1, "min_doc_count": 1,
} }
} }
@ -150,7 +150,7 @@ class ISOTimeKey(ScalarKey):
self.name: { self.name: {
"date_histogram": { "date_histogram": {
"field": "timestamp", "field": "timestamp",
"interval": f"{interval}ms", "fixed_interval": f"{interval}ms",
"min_doc_count": 1, "min_doc_count": 1,
"format": "strict_date_time", "format": "strict_date_time",
} }
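
A reduced aggregation body (interval value is illustrative) for the Elasticsearch 7 rename applied in these hunks: date_histogram takes fixed_interval (or calendar_interval) instead of the removed interval parameter.

aggs = {
    "dates": {
        "date_histogram": {
            "field": "timestamp",
            "fixed_interval": "5000ms",  # ES 7; plain "interval" is no longer accepted
            "min_doc_count": 1,
        }
    }
}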

View File

@ -18,7 +18,6 @@ log = config.logger(__file__)
class QueueMetrics: class QueueMetrics:
class EsKeys: class EsKeys:
DOC_TYPE = "metrics"
WAITING_TIME_FIELD = "average_waiting_time" WAITING_TIME_FIELD = "average_waiting_time"
QUEUE_LENGTH_FIELD = "queue_length" QUEUE_LENGTH_FIELD = "queue_length"
TIMESTAMP_FIELD = "timestamp" TIMESTAMP_FIELD = "timestamp"
@ -66,7 +65,6 @@ class QueueMetrics:
entries = [e for e in queue.entries if e.added] entries = [e for e in queue.entries if e.added]
return dict( return dict(
_index=es_index, _index=es_index,
_type=self.EsKeys.DOC_TYPE,
_source={ _source={
self.EsKeys.TIMESTAMP_FIELD: timestamp, self.EsKeys.TIMESTAMP_FIELD: timestamp,
self.EsKeys.QUEUE_FIELD: queue.id, self.EsKeys.QUEUE_FIELD: queue.id,
@ -93,7 +91,6 @@ class QueueMetrics:
def _search_company_metrics(self, company_id: str, es_req: dict) -> dict: def _search_company_metrics(self, company_id: str, es_req: dict) -> dict:
return self.es.search( return self.es.search(
index=f"{self._queue_metrics_prefix_for_company(company_id)}*", index=f"{self._queue_metrics_prefix_for_company(company_id)}*",
doc_type=self.EsKeys.DOC_TYPE,
body=es_req, body=es_req,
) )
@ -109,7 +106,7 @@ class QueueMetrics:
"dates": { "dates": {
"date_histogram": { "date_histogram": {
"field": cls.EsKeys.TIMESTAMP_FIELD, "field": cls.EsKeys.TIMESTAMP_FIELD,
"interval": f"{interval}s", "fixed_interval": f"{interval}s",
"min_doc_count": 1, "min_doc_count": 1,
}, },
"aggs": { "aggs": {

View File

@ -237,7 +237,6 @@ class StatisticsReporter:
def _run_worker_stats_query(cls, company_id, es_req) -> dict: def _run_worker_stats_query(cls, company_id, es_req) -> dict:
return worker_bll.es_client.search( return worker_bll.es_client.search(
index=f"{WorkerStats.worker_stats_prefix_for_company(company_id)}*", index=f"{WorkerStats.worker_stats_prefix_for_company(company_id)}*",
doc_type="stat",
body=es_req, body=es_req,
) )

View File

@ -35,14 +35,21 @@ class SetFieldsResolver:
SET_MODIFIERS = ("min", "max") SET_MODIFIERS = ("min", "max")
def __init__(self, set_fields: Dict[str, Any]): def __init__(self, set_fields: Dict[str, Any]):
self.orig_fields = set_fields self.orig_fields = {}
self.fields = { self.fields = {}
f: fname self.add_fields(**set_fields)
for f, modifier, dunder, fname in (
(f,) + f.partition("__") for f in set_fields.keys() def add_fields(self, **set_fields: Any):
) self.orig_fields.update(set_fields)
if dunder and modifier in self.SET_MODIFIERS self.fields.update(
} {
f: fname
for f, modifier, dunder, fname in (
(f,) + f.partition("__") for f in set_fields.keys()
)
if dunder and modifier in self.SET_MODIFIERS
}
)
def _get_updated_name(self, doc: AttributedDocument, name: str) -> str: def _get_updated_name(self, doc: AttributedDocument, name: str) -> str:
if name in self.fields and doc.get_field_value(self.fields[name]) is None: if name in self.fields and doc.get_field_value(self.fields[name]) is None:
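
A quick illustration, using only standard string operations (the field names are made up), of how the resolver above maps modifier keys: an entry such as "min__started" is partitioned on "__" and kept only when its prefix is one of SET_MODIFIERS, recording the bare field name it refers to.

SET_MODIFIERS = ("min", "max")
set_fields = {"min__started": "2020-08-10T08:30:40", "last_update": "2020-08-10T08:30:40"}

# Keep only "<modifier>__<field>" keys whose modifier is min/max and remember
# which document field each one points at.
fields = {
    f: fname
    for f, modifier, dunder, fname in ((f,) + f.partition("__") for f in set_fields)
    if dunder and modifier in SET_MODIFIERS
}
print(fields)  # {'min__started': 'started'}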

View File

@ -21,6 +21,7 @@ from config import config
from database.errors import translate_errors_context from database.errors import translate_errors_context
from database.model.auth import User from database.model.auth import User
from database.model.company import Company from database.model.company import Company
from database.model.project import Project
from database.model.queue import Queue from database.model.queue import Queue
from database.model.task.task import Task from database.model.task.task import Task
from redis_manager import redman from redis_manager import redman
@ -146,6 +147,7 @@ class WorkerBLL:
if not report.task: if not report.task:
entry.task = None entry.task = None
entry.project = None
else: else:
with translate_errors_context(): with translate_errors_context():
query = dict(id=report.task, company=company_id) query = dict(id=report.task, company=company_id)
@ -160,6 +162,12 @@ class WorkerBLL:
raise bad_request.InvalidTaskId(**query) raise bad_request.InvalidTaskId(**query)
entry.task = IdNameEntry(id=task.id, name=task.name) entry.task = IdNameEntry(id=task.id, name=task.name)
entry.project = None
if task.project:
project = Project.objects(id=task.project).only("name").first()
if project:
entry.project = IdNameEntry(id=project.id, name=project.name)
entry.last_report_time = now entry.last_report_time = now
except APIError: except APIError:
raise raise
@ -369,7 +377,6 @@ class WorkerBLL:
def make_doc(category, metric, variant, value) -> dict: def make_doc(category, metric, variant, value) -> dict:
return dict( return dict(
_index=es_index, _index=es_index,
_type="stat",
_source=dict( _source=dict(
timestamp=timestamp, timestamp=timestamp,
worker=worker, worker=worker,

View File

@ -25,7 +25,6 @@ class WorkerStats:
def _search_company_stats(self, company_id: str, es_req: dict) -> dict: def _search_company_stats(self, company_id: str, es_req: dict) -> dict:
return self.es.search( return self.es.search(
index=f"{self.worker_stats_prefix_for_company(company_id)}*", index=f"{self.worker_stats_prefix_for_company(company_id)}*",
doc_type="stat",
body=es_req, body=es_req,
) )
@ -53,7 +52,7 @@ class WorkerStats:
res = self._search_company_stats(company_id, es_req) res = self._search_company_stats(company_id, es_req)
if not res["hits"]["total"]: if not res["hits"]["total"]["value"]:
raise bad_request.WorkerStatsNotFound( raise bad_request.WorkerStatsNotFound(
f"No statistic metrics found for the company {company_id} and workers {worker_ids}" f"No statistic metrics found for the company {company_id} and workers {worker_ids}"
) )
@ -87,7 +86,7 @@ class WorkerStats:
"dates": { "dates": {
"date_histogram": { "date_histogram": {
"field": "timestamp", "field": "timestamp",
"interval": f"{request.interval}s", "fixed_interval": f"{request.interval}s",
"min_doc_count": 1, "min_doc_count": 1,
}, },
"aggs": { "aggs": {
@ -216,7 +215,7 @@ class WorkerStats:
"dates": { "dates": {
"date_histogram": { "date_histogram": {
"field": "timestamp", "field": "timestamp",
"interval": f"{interval}s", "fixed_interval": f"{interval}s",
}, },
"aggs": {"workers_count": {"cardinality": {"field": "worker"}}}, "aggs": {"workers_count": {"cardinality": {"field": "worker"}}},
} }

View File

@ -30,7 +30,7 @@
enabled: false enabled: false
zip_files: ["/path/to/export.zip"] zip_files: ["/path/to/export.zip"]
fail_on_error: false fail_on_error: false
artifacts_path: "/mnt/fileserver" # artifacts_path: "/mnt/fileserver"
} }
# time in seconds to take an exclusive lock to init es and mongodb # time in seconds to take an exclusive lock to init es and mongodb

View File

@ -1,6 +1,6 @@
elastic { elastic {
events { events {
hosts: [{host: "127.0.0.1", port: 9200}] hosts: [{host: "127.0.0.1", port: 9211}]
args { args {
timeout: 60 timeout: 60
dead_timeout: 10 dead_timeout: 10
@ -11,7 +11,7 @@ elastic {
} }
workers { workers {
hosts: [{host:"127.0.0.1", port:9200}] hosts: [{host:"127.0.0.1", port:9211}]
args { args {
timeout: 60 timeout: 60
dead_timeout: 10 dead_timeout: 10

View File

@ -0,0 +1,16 @@
fixed_users {
guest {
enabled: false
default_company: "025315a9321f49f8be07f5ac48fbcf92"
name: "Guest"
username: "guest"
password: "guest"
# Allow access only to the following endpoints when using user/pass credentials
allow_endpoints: [
"auth.login"
]
}
}

View File

@ -0,0 +1,8 @@
# Order of featured projects, by name or ID
featured_order: [
# {id: "<project-id>"}
# OR
# {name: "<project-name>"}
# OR
# {name_regex: "<python-regex>"}
]
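
A hypothetical sketch, not the server's implementation, of how featured_order entries could be turned into a sort index for a project: a rule may match by id, exact name, or a Python regex, and projects matching no rule fall back to a large default (the Project model further down in this commit adds a featured field defaulting to 9999). All ids and names here are invented.

import re

featured_order = [
    {"id": "0123456789abcdef"},
    {"name": "Demo Project"},
    {"name_regex": r"^Tutorial.*"},
]

def featured_index(project_id: str, project_name: str, default: int = 9999) -> int:
    # Position of the first matching rule, or the default when nothing matches.
    for index, rule in enumerate(featured_order):
        if rule.get("id") == project_id:
            return index
        if rule.get("name") == project_name:
            return index
        pattern = rule.get("name_regex")
        if pattern and re.search(pattern, project_name):
            return index
    return default

print(featured_index("deadbeef", "Tutorial: Getting Started"))  # 2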

View File

@ -41,3 +41,6 @@ def get_deployment_type() -> str:
def get_default_company(): def get_default_company():
return config.get("apiserver.default_company") return config.get("apiserver.default_company")
missed_es_upgrade = False

View File

@ -32,6 +32,8 @@ class Role(object):
""" Company user """ """ Company user """
annotator = "annotator" annotator = "annotator"
""" Annotator with limited access""" """ Annotator with limited access"""
guest = "guest"
""" Guest user. Read Only."""
@classmethod @classmethod
def get_system_roles(cls) -> set: def get_system_roles(cls) -> set:

View File

@ -1,7 +1,7 @@
import re import re
from collections import namedtuple from collections import namedtuple
from functools import reduce from functools import reduce
from typing import Collection, Sequence, Union, Optional from typing import Collection, Sequence, Union, Optional, Type
from boltons.iterutils import first, bucketize from boltons.iterutils import first, bucketize
from dateutil.parser import parse as parse_datetime from dateutil.parser import parse as parse_datetime
@ -9,6 +9,7 @@ from mongoengine import Q, Document, ListField, StringField
from pymongo.command_cursor import CommandCursor from pymongo.command_cursor import CommandCursor
from apierrors import errors from apierrors import errors
from apierrors.base import BaseError
from config import config from config import config
from database.errors import MakeGetAllQueryError from database.errors import MakeGetAllQueryError
from database.projection import project_dict, ProjectionHelper from database.projection import project_dict, ProjectionHelper
@ -483,6 +484,21 @@ class GetMixin(PropsMixin):
query=_query, parameters=parameters, override_projection=override_projection query=_query, parameters=parameters, override_projection=override_projection
) )
@classmethod
def get_many_public(
cls, query: Q = None, projection: Collection[str] = None,
):
"""
Fetch all public documents matching a provided query.
:param query: Optional query object (mongoengine.Q).
:param projection: A list of projection fields.
:return: A list of documents matching the query.
"""
q = get_company_or_none_constraint()
_query = (q & query) if query else q
return cls._get_many_no_company(query=_query, override_projection=projection)
@classmethod @classmethod
def _get_many_no_company( def _get_many_no_company(
cls: Union["GetMixin", Document], cls: Union["GetMixin", Document],
@ -728,6 +744,31 @@ class DbModelMixin(GetMixin, ProperDictMixin, UpdateMixin):
) )
return cls.objects.aggregate(pipeline, **kwargs) return cls.objects.aggregate(pipeline, **kwargs)
@classmethod
def set_public(
cls: Type[Document],
company_id: str,
ids: Sequence[str],
invalid_cls: Type[BaseError],
enabled: bool = True,
):
if enabled:
items = list(cls.objects(id__in=ids, company=company_id).only("id"))
update = dict(set__company_origin=company_id, unset__company=1)
else:
items = list(
cls.objects(
id__in=ids, company__in=(None, ""), company_origin=company_id
).only("id")
)
update = dict(set__company=company_id, unset__company_origin=1)
if len(items) < len(ids):
missing = tuple(set(ids).difference(i.id for i in items))
raise invalid_cls(ids=missing)
return {"updated": cls.objects(id__in=ids).update(**update)}
def validate_id(cls, company, **kwargs): def validate_id(cls, company, **kwargs):
""" """

View File

@ -72,3 +72,4 @@ class Model(DbModelMixin, Document):
ui_cache = SafeDictField( ui_cache = SafeDictField(
default=dict, user_set_allowed=True, exclude_by_default=True default=dict, user_set_allowed=True, exclude_by_default=True
) )
company_origin = StringField(exclude_by_default=True)

View File

@ -1,4 +1,4 @@
from mongoengine import StringField, DateTimeField from mongoengine import StringField, DateTimeField, IntField
from database import Database, strict from database import Database, strict
from database.fields import StrippedStringField, SafeSortedListField from database.fields import StrippedStringField, SafeSortedListField
@ -40,3 +40,7 @@ class Project(AttributedDocument):
system_tags = SafeSortedListField(StringField(required=True)) system_tags = SafeSortedListField(StringField(required=True))
default_output_destination = StrippedStringField() default_output_destination = StrippedStringField()
last_update = DateTimeField() last_update = DateTimeField()
featured = IntField(default=9999)
logo_url = StringField()
logo_blob = StringField(exclude_by_default=True)
company_origin = StringField(exclude_by_default=True)

View File

@ -118,7 +118,7 @@ external_task_types = set(get_options(TaskType))
class Task(AttributedDocument): class Task(AttributedDocument):
_field_collation_overrides = { _field_collation_overrides = {
"execution.parameters.": {"locale": "en_US", "numericOrdering": True}, "execution.parameters.": {"locale": "en_US", "numericOrdering": True},
"last_metrics.": {"locale": "en_US", "numericOrdering": True} "last_metrics.": {"locale": "en_US", "numericOrdering": True},
} }
meta = { meta = {
@ -194,3 +194,5 @@ class Task(AttributedDocument):
last_iteration = IntField(default=DEFAULT_LAST_ITERATION) last_iteration = IntField(default=DEFAULT_LAST_ITERATION)
last_metrics = SafeMapField(field=SafeMapField(EmbeddedDocumentField(MetricEvent))) last_metrics = SafeMapField(field=SafeMapField(EmbeddedDocumentField(MetricEvent)))
metric_stats = SafeMapField(field=EmbeddedDocumentField(MetricEventStats)) metric_stats = SafeMapField(field=EmbeddedDocumentField(MetricEventStats))
company_origin = StringField(exclude_by_default=True)
duration = IntField() # task duration in seconds
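
The new duration field stores whole seconds so the UI can sort the experiments table on it; the code that fills it is not shown in this hunk, so the derivation below is only an assumed illustration based on a task's started/completed timestamps.

from datetime import datetime, timedelta

started = datetime(2020, 8, 10, 8, 30, 40)           # invented timestamps
completed = started + timedelta(hours=1, minutes=5)

duration = int((completed - started).total_seconds())  # assumed derivation, in seconds
print(duration)  # 3900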

View File

@ -4,9 +4,9 @@ Apply elasticsearch mappings to given hosts.
""" """
import argparse import argparse
import json import json
import requests
from pathlib import Path from pathlib import Path
import requests
from requests.adapters import HTTPAdapter from requests.adapters import HTTPAdapter
from requests.packages.urllib3.util.retry import Retry from requests.packages.urllib3.util.retry import Retry
@ -14,21 +14,24 @@ HERE = Path(__file__).resolve().parent
session = requests.Session() session = requests.Session()
adapter = HTTPAdapter(max_retries=Retry(5, backoff_factor=0.5)) adapter = HTTPAdapter(max_retries=Retry(5, backoff_factor=0.5))
session.mount('http://', adapter) session.mount("http://", adapter)
def get_template(host: str, template) -> dict:
url = f"{host}/_template/{template}"
res = session.get(url)
return res.json()
def apply_mappings_to_host(host: str): def apply_mappings_to_host(host: str):
def _send_mapping(f): def _send_mapping(f):
with f.open() as json_data: with f.open() as json_data:
data = json.load(json_data) data = json.load(json_data)
es_server = host url = f"{host}/_template/{f.stem}"
url = f"{es_server}/_template/{f.stem}"
session.delete(url) session.delete(url)
r = session.post( r = session.post(
url, url, headers={"Content-Type": "application/json"}, data=json.dumps(data)
headers={"Content-Type": "application/json"},
data=json.dumps(data),
) )
return {"mapping": f.stem, "result": r.text} return {"mapping": f.stem, "result": r.text}
@ -47,7 +50,8 @@ def parse_args():
def main(): def main():
for host in parse_args().hosts: args = parse_args()
for host in args.hosts:
print(">>>>> Applying mapping to " + host) print(">>>>> Applying mapping to " + host)
res = apply_mappings_to_host(host) res = apply_mappings_to_host(host)
print(res) print(res)

View File

@ -1,7 +1,7 @@
from furl import furl from furl import furl
from config import config from config import config
from elastic.apply_mappings import apply_mappings_to_host from elastic.apply_mappings import apply_mappings_to_host, get_template
from es_factory import get_cluster_config from es_factory import get_cluster_config
log = config.logger(__file__) log = config.logger(__file__)
@ -15,13 +15,22 @@ class MissingElasticConfiguration(Exception):
pass pass
def init_es_data(): def _url_from_host_conf(conf: dict) -> str:
return furl(scheme="http", host=conf["host"], port=conf["port"]).url
def init_es_data() -> bool:
"""Return True if the db was empty"""
hosts_config = get_cluster_config("events").get("hosts") hosts_config = get_cluster_config("events").get("hosts")
if not hosts_config: if not hosts_config:
raise MissingElasticConfiguration("for cluster 'events'") raise MissingElasticConfiguration("for cluster 'events'")
empty_db = not get_template(_url_from_host_conf(hosts_config[0]), "events*")
for conf in hosts_config: for conf in hosts_config:
host = furl(scheme="http", host=conf["host"], port=conf["port"]).url host = _url_from_host_conf(conf)
log.info(f"Applying mappings to host: {host}") log.info(f"Applying mappings to host: {host}")
res = apply_mappings_to_host(host) res = apply_mappings_to_host(host)
log.info(res) log.info(res)
return empty_db

View File

@ -1,26 +1,39 @@
{ {
"template": "events-*", "index_patterns": "events-*",
"settings": { "settings": {
"number_of_shards": 1 "number_of_shards": 1
}, },
"mappings": { "mappings": {
"_default_": { "_source": {
"_source": { "enabled": true
"enabled": true },
"properties": {
"@timestamp": {
"type": "date"
}, },
"_routing": { "task": {
"required": true "type": "keyword"
}, },
"properties": { "type": {
"@timestamp": { "type": "date" }, "type": "keyword"
"task": { "type": "keyword" }, },
"type": { "type": "keyword" }, "worker": {
"worker": { "type": "keyword" }, "type": "keyword"
"timestamp": { "type": "date" }, },
"iter": { "type": "long" }, "timestamp": {
"metric": { "type": "keyword" }, "type": "date"
"variant": { "type": "keyword" }, },
"value": { "type": "float" } "iter": {
"type": "long"
},
"metric": {
"type": "keyword"
},
"variant": {
"type": "keyword"
},
"value": {
"type": "float"
} }
} }
} }

View File

@ -1,11 +1,14 @@
{ {
"template": "events-log-*", "index_patterns": "events-log-*",
"order" : 1, "order": 1,
"mappings": { "mappings": {
"_default_": { "properties": {
"properties": { "msg": {
"msg": { "type":"text", "index": false }, "type": "text",
"level": { "type":"keyword" } "index": false
},
"level": {
"type": "keyword"
} }
} }
} }

View File

@ -1,10 +1,11 @@
{ {
"template": "events-plot-*", "index_patterns": "events-plot-*",
"order" : 1, "order": 1,
"mappings": { "mappings": {
"_default_": { "properties": {
"properties": { "plot_str": {
"plot_str": { "type":"text", "index": false } "type": "text",
"index": false
} }
} }
} }

View File

@ -1,11 +1,13 @@
{ {
"template": "events-training_debug_image-*", "index_patterns": "events-training_debug_image-*",
"order" : 1, "order": 1,
"mappings": { "mappings": {
"_default_": { "properties": {
"properties": { "key": {
"key": { "type": "keyword" }, "type": "keyword"
"url": { "type": "keyword" } },
"url": {
"type": "keyword"
} }
} }
} }

View File

@ -1,26 +1,24 @@
{ {
"template": "queue_metrics_*", "index_patterns": "queue_metrics_*",
"settings": { "settings": {
"number_of_shards": 1 "number_of_shards": 1
}, },
"mappings": { "mappings": {
"metrics": { "_source": {
"_source": { "enabled": true
"enabled": true },
"properties": {
"timestamp": {
"type": "date"
}, },
"properties": { "queue": {
"timestamp": { "type": "keyword"
"type": "date" },
}, "average_waiting_time": {
"queue": { "type": "float"
"type": "keyword" },
}, "queue_length": {
"average_waiting_time": { "type": "integer"
"type": "float"
},
"queue_length": {
"type": "integer"
}
} }
} }
} }

View File

@ -1,22 +1,36 @@
{ {
"template": "worker_stats_*", "index_patterns": "worker_stats_*",
"settings": { "settings": {
"number_of_shards": 1 "number_of_shards": 1
}, },
"mappings": { "mappings": {
"stat": { "_source": {
"_source": { "enabled": true
"enabled": true },
"properties": {
"timestamp": {
"type": "date"
}, },
"properties": { "worker": {
"timestamp": { "type": "date" }, "type": "keyword"
"worker": { "type": "keyword" }, },
"category": { "type": "keyword" }, "category": {
"metric": { "type": "keyword" }, "type": "keyword"
"variant": { "type": "keyword" }, },
"value": { "type": "float" }, "metric": {
"unit": { "type": "keyword" }, "type": "keyword"
"task": { "type": "keyword" } },
"variant": {
"type": "keyword"
},
"value": {
"type": "float"
},
"unit": {
"type": "keyword"
},
"task": {
"type": "keyword"
} }
} }
} }
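The queue_metrics and worker_stats templates drop their `metrics`/`stat` mapping types for the same reason: Elasticsearch 7 indices are typeless, so documents are indexed directly against the index. A hedged sketch (index name and client setup are assumptions; field names follow the template above):

from datetime import datetime
from elasticsearch import Elasticsearch

es = Elasticsearch(hosts=["http://localhost:9200"])  # assumption
es.index(
    index="worker_stats_somecompany",  # illustrative name matching worker_stats_*
    body={
        "timestamp": datetime.utcnow(),
        "worker": "worker-01",
        "category": "cpu",
        "metric": "usage",
        "variant": "0",
        "value": 42.0,
        "unit": "%",
    },
)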

View File

@ -24,14 +24,9 @@ def _pre_populate(company_id: str, zip_file: str):
else: else:
log.info(f"Pre-populating using {zip_file}") log.info(f"Pre-populating using {zip_file}")
user_id = _ensure_backend_user(
"__allegroai__", company_id, "Allegro.ai"
)
PrePopulate.import_from_zip( PrePopulate.import_from_zip(
zip_file, zip_file,
company_id="", company_id="",
user_id=user_id,
artifacts_path=config.get( artifacts_path=config.get(
"apiserver.pre_populate.artifacts_path", None "apiserver.pre_populate.artifacts_path", None
), ),
@ -60,7 +55,7 @@ def init_mongo_data() -> bool:
_ensure_uuid() _ensure_uuid()
company_id = _ensure_company(log) company_id = _ensure_company(get_default_company(), "trains", log)
_ensure_default_queue(company_id) _ensure_default_queue(company_id)
@ -82,9 +77,13 @@ def init_mongo_data() -> bool:
if fixed_mode: if fixed_mode:
log.info("Fixed users mode is enabled") log.info("Fixed users mode is enabled")
FixedUser.validate() FixedUser.validate()
if FixedUser.guest_enabled():
_ensure_company(FixedUser.get_guest_user().company, "guests", log)
for user in FixedUser.from_config(): for user in FixedUser.from_config():
try: try:
ensure_fixed_user(user, company_id, log=log) ensure_fixed_user(user, log=log)
except Exception as ex: except Exception as ex:
log.error(f"Failed creating fixed user {user.name}: {ex}") log.error(f"Failed creating fixed user {user.name}: {ex}")

View File

@ -1,30 +1,44 @@
import hashlib import hashlib
import importlib import importlib
import os import os
import re
from collections import defaultdict from collections import defaultdict
from datetime import datetime, timezone from datetime import datetime, timezone
from functools import partial
from io import BytesIO from io import BytesIO
from itertools import chain from itertools import chain
from operator import attrgetter from operator import attrgetter
from os.path import splitext from os.path import splitext
from pathlib import Path from pathlib import Path
from typing import Optional, Any, Type, Set, Dict, Sequence, Tuple, BinaryIO, Union from typing import (
Optional,
Any,
Type,
Set,
Dict,
Sequence,
Tuple,
BinaryIO,
Union,
Mapping,
)
from urllib.parse import unquote, urlparse from urllib.parse import unquote, urlparse
from zipfile import ZipFile, ZIP_BZIP2 from zipfile import ZipFile, ZIP_BZIP2
import attr
import mongoengine import mongoengine
from boltons.iterutils import chunked_iter from boltons.iterutils import chunked_iter
from furl import furl from furl import furl
from mongoengine import Q from mongoengine import Q
from bll.event import EventBLL from bll.event import EventBLL
from config import config
from database.model import EntityVisibility from database.model import EntityVisibility
from database.model.model import Model from database.model.model import Model
from database.model.project import Project from database.model.project import Project
from database.model.task.task import Task, ArtifactModes, TaskStatus from database.model.task.task import Task, ArtifactModes, TaskStatus
from database.utils import get_options from database.utils import get_options
from utilities import json from utilities import json
from .user import _ensure_backend_user
class PrePopulate: class PrePopulate:
@ -32,6 +46,7 @@ class PrePopulate:
events_file_suffix = "_events" events_file_suffix = "_events"
export_tag_prefix = "Exported:" export_tag_prefix = "Exported:"
export_tag = f"{export_tag_prefix} %Y-%m-%d %H:%M:%S" export_tag = f"{export_tag_prefix} %Y-%m-%d %H:%M:%S"
metadata_filename = "metadata.json"
class JsonLinesWriter: class JsonLinesWriter:
def __init__(self, file: BinaryIO): def __init__(self, file: BinaryIO):
@ -54,26 +69,21 @@ class PrePopulate:
self._write("\n" + line) self._write("\n" + line)
self.empty = False self.empty = False
@attr.s(auto_attribs=True)
class _MapData:
files: Sequence[str] = None
entities: Dict[str, datetime] = None
@staticmethod @staticmethod
def _get_last_update_time(entity) -> datetime: def _get_last_update_time(entity) -> datetime:
return getattr(entity, "last_update", None) or getattr(entity, "created") return getattr(entity, "last_update", None) or getattr(entity, "created")
@classmethod @classmethod
def _check_for_update( def _check_for_update(
cls, map_file: Path, entities: dict cls, map_file: Path, entities: dict, metadata_hash: str
) -> Tuple[bool, Sequence[str]]: ) -> Tuple[bool, Sequence[str]]:
if not map_file.is_file(): if not map_file.is_file():
return True, [] return True, []
files = [] files = []
try: try:
map_data = cls._MapData(**json.loads(map_file.read_text())) map_data = json.loads(map_file.read_text())
files = map_data.files files = map_data.get("files", [])
for file in files: for file in files:
if not Path(file).is_file(): if not Path(file).is_file():
return True, files return True, files
@ -82,7 +92,7 @@ class PrePopulate:
item.id: cls._get_last_update_time(item).replace(tzinfo=timezone.utc) item.id: cls._get_last_update_time(item).replace(tzinfo=timezone.utc)
for item in chain.from_iterable(entities.values()) for item in chain.from_iterable(entities.values())
} }
old_times = map_data.entities old_times = map_data.get("entities", {})
if set(new_times.keys()) != set(old_times.keys()): if set(new_times.keys()) != set(old_times.keys()):
return True, files return True, files
@ -90,6 +100,10 @@ class PrePopulate:
for id_, new_timestamp in new_times.items(): for id_, new_timestamp in new_times.items():
if new_timestamp != old_times[id_]: if new_timestamp != old_times[id_]:
return True, files return True, files
if metadata_hash != map_data.get("metadata_hash", ""):
return True, files
except Exception as ex: except Exception as ex:
print("Error reading map file. " + str(ex)) print("Error reading map file. " + str(ex))
return True, files return True, files
@ -98,16 +112,24 @@ class PrePopulate:
@classmethod @classmethod
def _write_update_file( def _write_update_file(
cls, map_file: Path, entities: dict, created_files: Sequence[str] cls,
map_file: Path,
entities: dict,
created_files: Sequence[str],
metadata_hash: str,
): ):
map_data = cls._MapData( map_file.write_text(
files=created_files, json.dumps(
entities={ dict(
entity.id: cls._get_last_update_time(entity) files=created_files,
for entity in chain.from_iterable(entities.values()) entities={
}, entity.id: cls._get_last_update_time(entity)
for entity in chain.from_iterable(entities.values())
},
metadata_hash=metadata_hash,
)
)
) )
map_file.write_text(json.dumps(attr.asdict(map_data)))
@staticmethod @staticmethod
def _filter_artifacts(artifacts: Sequence[str]) -> Sequence[str]: def _filter_artifacts(artifacts: Sequence[str]) -> Sequence[str]:
@ -117,7 +139,9 @@ class PrePopulate:
return True return True
if a.startswith("http"): if a.startswith("http"):
parsed = urlparse(a) parsed = urlparse(a)
if parsed.scheme in {"http", "https"} and parsed.port == 8081: if parsed.scheme in {"http", "https"} and parsed.netloc.endswith(
"8081"
):
return True return True
return False return False
@ -137,6 +161,7 @@ class PrePopulate:
artifacts_path: str = None, artifacts_path: str = None,
task_statuses: Sequence[str] = None, task_statuses: Sequence[str] = None,
tag_exported_entities: bool = False, tag_exported_entities: bool = False,
metadata: Mapping[str, Any] = None,
) -> Sequence[str]: ) -> Sequence[str]:
if task_statuses and not set(task_statuses).issubset(get_options(TaskStatus)): if task_statuses and not set(task_statuses).issubset(get_options(TaskStatus)):
raise ValueError("Invalid task statuses") raise ValueError("Invalid task statuses")
@ -146,11 +171,22 @@ class PrePopulate:
experiments=experiments, projects=projects, task_statuses=task_statuses experiments=experiments, projects=projects, task_statuses=task_statuses
) )
hash_ = hashlib.md5()
if metadata:
meta_str = json.dumps(metadata)
hash_.update(meta_str.encode())
metadata_hash = hash_.hexdigest()
else:
meta_str, metadata_hash = "", ""
map_file = file.with_suffix(".map") map_file = file.with_suffix(".map")
updated, old_files = cls._check_for_update(map_file, entities) updated, old_files = cls._check_for_update(
map_file, entities=entities, metadata_hash=metadata_hash
)
if not updated: if not updated:
print(f"There are no updates from the last export") print(f"There are no updates from the last export")
return old_files return old_files
for old in old_files: for old in old_files:
old_path = Path(old) old_path = Path(old)
if old_path.is_file(): if old_path.is_file():
@ -158,10 +194,16 @@ class PrePopulate:
zip_args = dict(mode="w", compression=ZIP_BZIP2) zip_args = dict(mode="w", compression=ZIP_BZIP2)
with ZipFile(file, **zip_args) as zfile: with ZipFile(file, **zip_args) as zfile:
artifacts, hash_ = cls._export( if metadata:
zfile, entities, tag_entities=tag_exported_entities zfile.writestr(cls.metadata_filename, meta_str)
artifacts = cls._export(
zfile,
entities=entities,
hash_=hash_,
tag_entities=tag_exported_entities,
) )
file_with_hash = file.with_name(f"{file.stem}_{hash_}{file.suffix}")
file_with_hash = file.with_name(f"{file.stem}_{hash_.hexdigest()}{file.suffix}")
file.replace(file_with_hash) file.replace(file_with_hash)
created_files = [str(file_with_hash)] created_files = [str(file_with_hash)]
@ -172,16 +214,43 @@ class PrePopulate:
cls._export_artifacts(zfile, artifacts, artifacts_path) cls._export_artifacts(zfile, artifacts, artifacts_path)
created_files.append(str(artifacts_file)) created_files.append(str(artifacts_file))
cls._write_update_file(map_file, entities, created_files) cls._write_update_file(
map_file,
entities=entities,
created_files=created_files,
metadata_hash=metadata_hash,
)
return created_files return created_files
@classmethod @classmethod
def import_from_zip( def import_from_zip(
cls, filename: str, company_id: str, user_id: str, artifacts_path: str cls,
filename: str,
company_id: str,
artifacts_path: str,
user_id: str = "",
user_name: str = "",
): ):
metadata = None
with ZipFile(filename) as zfile: with ZipFile(filename) as zfile:
cls._import(zfile, company_id, user_id) try:
with zfile.open(cls.metadata_filename) as f:
metadata = json.loads(f.read())
if not user_id:
meta_user_id = metadata.get("user_id", "")
meta_user_name = metadata.get("user_name", "")
user_id, user_name = meta_user_id, meta_user_name
except Exception:
pass
if not user_id:
user_id, user_name = "__allegroai__", "Allegro.ai"
user_id = _ensure_backend_user(user_id, company_id, user_name)
cls._import(zfile, company_id, user_id, metadata)
if artifacts_path and os.path.isdir(artifacts_path): if artifacts_path and os.path.isdir(artifacts_path):
artifacts_file = Path(filename).with_suffix(".artifacts") artifacts_file = Path(filename).with_suffix(".artifacts")
@ -190,6 +259,24 @@ class PrePopulate:
with ZipFile(artifacts_file) as zfile: with ZipFile(artifacts_file) as zfile:
zfile.extractall(artifacts_path) zfile.extractall(artifacts_path)
@classmethod
def update_featured_projects_order(cls):
featured_order = config.get("services.projects.featured_order", [])
def get_index(p: Project):
for index, entry in enumerate(featured_order):
if (
entry.get("id", None) == p.id
or entry.get("name", None) == p.name
or ("name_regex" in entry and re.match(entry["name_regex"], p.name))
):
return index
return 999
for project in Project.get_many_public(projection=["id", "name"]):
featured_index = get_index(project)
Project.objects(id=project.id).update(featured=featured_index)
@staticmethod @staticmethod
def _resolve_type( def _resolve_type(
cls: Type[mongoengine.Document], ids: Optional[Sequence[str]] cls: Type[mongoengine.Document], ids: Optional[Sequence[str]]
@ -389,15 +476,14 @@ class PrePopulate:
@classmethod @classmethod
def _export( def _export(
cls, writer: ZipFile, entities: dict, tag_entities: bool = False cls, writer: ZipFile, entities: dict, hash_, tag_entities: bool = False
) -> Tuple[Sequence[str], str]: ) -> Sequence[str]:
""" """
Export the requested experiments, projects and models and return the list of artifact files Export the requested experiments, projects and models and return the list of artifact files
Always do the export on sorted items since the order of items influence hash Always do the export on sorted items since the order of items influence hash
""" """
artifacts = [] artifacts = []
now = datetime.utcnow() now = datetime.utcnow()
hash_ = hashlib.md5()
for cls_ in sorted(entities, key=attrgetter("__name__")): for cls_ in sorted(entities, key=attrgetter("__name__")):
items = sorted(entities[cls_], key=attrgetter("id")) items = sorted(entities[cls_], key=attrgetter("id"))
if not items: if not items:
@ -423,7 +509,7 @@ class PrePopulate:
if tag_entities: if tag_entities:
cls._add_tag(items, now.strftime(cls.export_tag)) cls._add_tag(items, now.strftime(cls.export_tag))
return artifacts, hash_.hexdigest() return artifacts
@staticmethod @staticmethod
def json_lines(file: BinaryIO): def json_lines(file: BinaryIO):
@ -441,7 +527,13 @@ class PrePopulate:
yield clean yield clean
@classmethod @classmethod
def _import(cls, reader: ZipFile, company_id: str = "", user_id: str = None): def _import(
cls,
reader: ZipFile,
company_id: str = "",
user_id: str = None,
metadata: Mapping[str, Any] = None,
):
""" """
Import entities and events from the zip file Import entities and events from the zip file
Start from entities since event import will require the tasks already in DB Start from entities since event import will require the tasks already in DB
@ -451,12 +543,13 @@ class PrePopulate:
fi fi
for fi in reader.filelist for fi in reader.filelist
if not fi.orig_filename.endswith(event_file_ending) if not fi.orig_filename.endswith(event_file_ending)
and fi.orig_filename != cls.metadata_filename
) )
event_files = ( event_files = (
fi for fi in reader.filelist if fi.orig_filename.endswith(event_file_ending) fi for fi in reader.filelist if fi.orig_filename.endswith(event_file_ending)
) )
for files, reader_func in ( for files, reader_func in (
(entity_files, cls._import_entity), (entity_files, partial(cls._import_entity, metadata=metadata or {})),
(event_files, cls._import_events), (event_files, cls._import_events),
): ):
for file_info in files: for file_info in files:
@ -466,11 +559,20 @@ class PrePopulate:
reader_func(f, full_name, company_id, user_id) reader_func(f, full_name, company_id, user_id)
@classmethod @classmethod
def _import_entity(cls, f: BinaryIO, full_name: str, company_id: str, user_id: str): def _import_entity(
cls,
f: BinaryIO,
full_name: str,
company_id: str,
user_id: str,
metadata: Mapping[str, Any],
):
module_name, _, class_name = full_name.rpartition(".") module_name, _, class_name = full_name.rpartition(".")
module = importlib.import_module(module_name) module = importlib.import_module(module_name)
cls_: Type[mongoengine.Document] = getattr(module, class_name) cls_: Type[mongoengine.Document] = getattr(module, class_name)
print(f"Writing {cls_.__name__.lower()}s into database") print(f"Writing {cls_.__name__.lower()}s into database")
override_project_count = 0
for item in cls.json_lines(f): for item in cls.json_lines(f):
doc = cls_.from_json(item, created=True) doc = cls_.from_json(item, created=True)
if hasattr(doc, "user"): if hasattr(doc, "user"):
@ -478,10 +580,24 @@ class PrePopulate:
if hasattr(doc, "company"): if hasattr(doc, "company"):
doc.company = company_id doc.company = company_id
if isinstance(doc, Project): if isinstance(doc, Project):
override_project_name = metadata.get("project_name", None)
if override_project_name:
if override_project_count:
override_project_name = (
f"{override_project_name} {override_project_count + 1}"
)
override_project_count += 1
doc.name = override_project_name
doc.logo_url = metadata.get("logo_url", None)
doc.logo_blob = metadata.get("logo_blob", None)
cls_.objects(company=company_id, name=doc.name, id__ne=doc.id).update( cls_.objects(company=company_id, name=doc.name, id__ne=doc.id).update(
set__name=f"{doc.name}_{datetime.utcnow().strftime('%Y-%m-%d_%H-%M-%S')}" set__name=f"{doc.name}_{datetime.utcnow().strftime('%Y-%m-%d_%H-%M-%S')}"
) )
doc.save() doc.save()
if isinstance(doc, Task): if isinstance(doc, Task):
cls.event_bll.delete_task_events(company_id, doc.id, allow_locked=True) cls.event_bll.delete_task_events(company_id, doc.id, allow_locked=True)
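For clarity, the metadata file read above is a plain JSON document stored in the zip as metadata.json. A brief sketch of the keys this diff consumes from it (keys inferred from the code above; values are illustrative) and of the hash used for change detection during export:

import hashlib
import json

metadata = {
    "user_id": "demo_user",           # owner of the imported entities
    "user_name": "Demo User",
    "project_name": "Demo Project",   # overrides the name of imported projects
    "logo_url": None,
    "logo_blob": None,
}
meta_str = json.dumps(metadata)
metadata_hash = hashlib.md5(meta_str.encode()).hexdigest()  # stored in the .map file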

View File

@ -58,15 +58,15 @@ def _ensure_backend_user(user_id: str, company_id: str, user_name: str):
return user_id return user_id
def ensure_fixed_user(user: FixedUser, company_id: str, log: Logger): def ensure_fixed_user(user: FixedUser, log: Logger):
if User.objects(id=user.user_id).first(): if User.objects(company=user.company, id=user.user_id).first():
return return
data = attr.asdict(user) data = attr.asdict(user)
data["id"] = user.user_id data["id"] = user.user_id
data["email"] = f"{user.user_id}@example.com" data["email"] = f"{user.user_id}@example.com"
data["role"] = Role.user data["role"] = Role.guest if user.is_guest else Role.user
_ensure_auth_user(user_data=data, company_id=company_id, log=log) _ensure_auth_user(user_data=data, company_id=user.company, log=log)
return _ensure_backend_user(user.user_id, company_id, user.name) return _ensure_backend_user(user.user_id, user.company, user.name)

View File

@ -3,7 +3,6 @@ from uuid import uuid4
from bll.queue import QueueBLL from bll.queue import QueueBLL
from config import config from config import config
from config.info import get_default_company
from database.model.company import Company from database.model.company import Company
from database.model.queue import Queue from database.model.queue import Queue
from database.model.settings import Settings, SettingKeys from database.model.settings import Settings, SettingKeys
@ -11,13 +10,11 @@ from database.model.settings import Settings, SettingKeys
log = config.logger(__file__) log = config.logger(__file__)
def _ensure_company(log: Logger): def _ensure_company(company_id, company_name, log: Logger):
company_id = get_default_company()
company = Company.objects(id=company_id).only("id").first() company = Company.objects(id=company_id).only("id").first()
if company: if company:
return company_id return company_id
company_name = "trains"
log.info(f"Creating company: {company_name}") log.info(f"Creating company: {company_name}")
company = Company(id=company_id, name=company_name) company = Company(id=company_id, name=company_name)
company.save() company.save()

View File

@ -1,7 +1,8 @@
attrs>=19.1.0 attrs>=19.1.0
boltons>=19.1.0 boltons>=19.1.0
boto3==1.14.13
dpath>=1.4.2,<2.0 dpath>=1.4.2,<2.0
elasticsearch>=5.0.0,<6.0.0 elasticsearch>=7.0.0,<8.0.0
fastjsonschema>=2.8 fastjsonschema>=2.8
Flask-Compress>=1.4.0 Flask-Compress>=1.4.0
Flask-Cors>=3.0.5 Flask-Cors>=3.0.5
@ -24,7 +25,7 @@ python-rapidjson>=0.6.3
redis>=2.10.5 redis>=2.10.5
related>=0.7.2 related>=0.7.2
requests>=2.13.0 requests>=2.13.0
semantic_version>=2.8.0,<3 semantic_version>=2.8.3,<3
six six
tqdm tqdm
validators>=0.12.4 validators>=0.12.4

View File

@ -328,6 +328,9 @@ fixed_users_mode {
description: "Fixed users mode enabled" description: "Fixed users mode enabled"
type: boolean type: boolean
} }
migration_warning {
type: boolean
}
} }
} }
} }

View File

@ -848,7 +848,7 @@
description: "Task ID" description: "Task ID"
} }
samples { samples {
description: "The amount of histogram points to return (0 to return all the points). Optional, the default value is 10000." description: "The amount of histogram points to return (0 to return all the points). Optional, the default value is 6000."
type: integer type: integer
} }
key { key {
@ -886,7 +886,7 @@
] ]
properties { properties {
tasks { tasks {
description: "List of task Task IDs" description: "List of task Task IDs. Maximum amount of tasks is 10"
type: array type: array
items { items {
type: string type: string
@ -894,7 +894,7 @@
} }
} }
samples { samples {
description: "The amount of histogram points to return (0 to return all the points). Optional, the default value is 10000." description: "The amount of histogram points to return. Optional, the default value is 6000"
type: integer type: integer
} }
key { key {
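A hedged example of requesting a scalar histogram with the new samples cap; the apiserver URL, port, endpoint path and credentials are assumptions about a default deployment:

import requests

resp = requests.post(
    "http://localhost:8008/events.scalar_metrics_iter_histogram",  # path is an assumption
    json={"task": "<task-id>", "samples": 6000, "key": "iter"},
    auth=("<access_key>", "<secret_key>"),
)
print(resp.json())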

File diff suppressed because it is too large

View File

@ -573,6 +573,7 @@ get_hyper_parameters {
} }
} }
} }
get_task_tags { get_task_tags {
"2.8" { "2.8" {
description: "Get user and system tags used for the tasks under the specified projects" description: "Get user and system tags used for the tasks under the specified projects"
@ -580,6 +581,7 @@ get_task_tags {
response = ${_definitions.tags_response} response = ${_definitions.tags_response}
} }
} }
get_model_tags { get_model_tags {
"2.8" { "2.8" {
description: "Get user and system tags used for the models under the specified projects" description: "Get user and system tags used for the models under the specified projects"
@ -587,3 +589,53 @@ get_model_tags {
response = ${_definitions.tags_response} response = ${_definitions.tags_response}
} }
} }
make_public {
"2.9" {
description: """Convert company projects to public"""
request {
type: object
properties {
ids {
description: "Ids of the projects to convert"
type: array
items { type: string}
}
}
}
response {
type: object
properties {
updated {
description: "Number of projects updated"
type: integer
}
}
}
}
}
make_private {
"2.9" {
description: """Convert public projects to private"""
request {
type: object
properties {
ids {
description: "Ids of the projects to convert. Only the projects originated by the company can be converted"
type: array
items { type: string}
}
}
}
response {
type: object
properties {
updated {
description: "Number of projects updated"
type: integer
}
}
}
}
}
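A hedged example of calling the new endpoint over HTTP; the apiserver URL, port, path layout and credentials are assumptions about a default deployment:

import requests

resp = requests.post(
    "http://localhost:8008/projects.make_public",  # path is an assumption
    json={"ids": ["<project-id>"]},
    auth=("<access_key>", "<secret_key>"),
)
print(resp.json())  # the response data is expected to contain the updated count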

View File

@ -1442,3 +1442,53 @@ add_or_update_artifacts {
} }
} }
} }
make_public {
"2.9" {
description: """Convert company tasks to public"""
request {
type: object
properties {
ids {
description: "Ids of the tasks to convert"
type: array
items { type: string}
}
}
}
response {
type: object
properties {
updated {
description: "Number of tasks updated"
type: integer
}
}
}
}
}
make_private {
"2.9" {
description: """Convert public tasks to private"""
request {
type: object
properties {
ids {
description: "Ids of the tasks to convert. Only the tasks originated by the company can be converted"
type: array
items { type: string}
}
}
}
response {
type: object
properties {
updated {
description: "Number of tasks updated"
type: integer
}
}
}
}
}

View File

@ -135,6 +135,10 @@
description: "Task currently being run by the worker" description: "Task currently being run by the worker"
"$ref": "#/definitions/current_task_entry" "$ref": "#/definitions/current_task_entry"
} }
project {
description: "Project in which currently executing task resides"
"$ref": "#/definitions/id_name_entry"
}
queue { queue {
description: "Queue from which running task was taken" description: "Queue from which running task was taken"
"$ref": "#/definitions/queue_entry" "$ref": "#/definitions/queue_entry"
@ -151,11 +155,11 @@
type: object type: object
properties { properties {
id { id {
description: "Worker ID" description: "ID"
type: string type: string
} }
name { name {
description: "Worker name" description: "Name"
type: string type: string
} }
} }

View File

@ -10,7 +10,7 @@ from werkzeug.exceptions import BadRequest
import database import database
from apierrors.base import BaseError from apierrors.base import BaseError
from bll.statistics.stats_reporter import StatisticsReporter from bll.statistics.stats_reporter import StatisticsReporter
from config import config from config import config, info
from elastic.initialize import init_es_data from elastic.initialize import init_es_data
from mongo.initialize import init_mongo_data, pre_populate_data from mongo.initialize import init_mongo_data, pre_populate_data
from service_repo import ServiceRepo, APICall from service_repo import ServiceRepo, APICall
@ -39,9 +39,11 @@ database.initialize()
hosts_string = ";".join(sorted(database.get_hosts())) hosts_string = ";".join(sorted(database.get_hosts()))
key = "db_init_" + md5(hosts_string.encode()).hexdigest() key = "db_init_" + md5(hosts_string.encode()).hexdigest()
with distributed_lock(key, timeout=config.get("apiserver.db_init_timout", 30)): with distributed_lock(key, timeout=config.get("apiserver.db_init_timout", 30)):
print(key) empty_es = init_es_data()
init_es_data()
empty_db = init_mongo_data() empty_db = init_mongo_data()
if empty_es and not empty_db:
log.info(f"ES database seems not migrated")
info.missed_es_upgrade = True
if empty_db and config.get("apiserver.pre_populate.enabled", False): if empty_db and config.get("apiserver.pre_populate.enabled", False):
pre_populate_data() pre_populate_data()

View File

@ -69,6 +69,10 @@ def authorize_credentials(auth_data, service, action, call_data_items):
if fixed_user: if fixed_user:
if secret_key != fixed_user.password: if secret_key != fixed_user.password:
raise errors.unauthorized.InvalidCredentials('bad username or password') raise errors.unauthorized.InvalidCredentials('bad username or password')
if fixed_user.is_guest and not FixedUser.is_guest_endpoint(service, action):
raise errors.unauthorized.InvalidCredentials('endpoint not allowed for guest')
query = Q(id=fixed_user.user_id) query = Q(id=fixed_user.user_id)
with TimingContext("mongo", "user_by_cred"), translate_errors_context('authorizing request'): with TimingContext("mongo", "user_by_cred"), translate_errors_context('authorizing request'):

View File

@ -1,14 +1,12 @@
import hashlib import hashlib
from functools import lru_cache from functools import lru_cache
from typing import Sequence, TypeVar from typing import Sequence, Optional
import attr import attr
from config import config from config import config
from config.info import get_default_company from config.info import get_default_company
T = TypeVar("T", bound="FixedUser")
class FixedUsersError(Exception): class FixedUsersError(Exception):
pass pass
@ -21,6 +19,8 @@ class FixedUser:
name: str name: str
company: str = get_default_company() company: str = get_default_company()
is_guest: bool = False
def __attrs_post_init__(self): def __attrs_post_init__(self):
self.user_id = hashlib.md5(f"{self.company}:{self.username}".encode()).hexdigest() self.user_id = hashlib.md5(f"{self.company}:{self.username}".encode()).hexdigest()
@ -28,6 +28,10 @@ class FixedUser:
def enabled(cls): def enabled(cls):
return config.get("apiserver.auth.fixed_users.enabled", False) return config.get("apiserver.auth.fixed_users.enabled", False)
@classmethod
def guest_enabled(cls):
return cls.enabled() and config.get("services.auth.fixed_users.guest.enabled", False)
@classmethod @classmethod
def validate(cls): def validate(cls):
if not cls.enabled(): if not cls.enabled():
@ -39,18 +43,50 @@ class FixedUser:
) )
@classmethod @classmethod
@lru_cache() # @lru_cache()
def from_config(cls) -> Sequence[T]: def from_config(cls) -> Sequence["FixedUser"]:
return [ users = [
cls(**user) for user in config.get("apiserver.auth.fixed_users.users", []) cls(**user) for user in config.get("apiserver.auth.fixed_users.users", [])
] ]
if cls.guest_enabled():
users.insert(
0,
cls.get_guest_user()
)
return users
@classmethod @classmethod
@lru_cache() @lru_cache()
def get_by_username(cls, username) -> T: def get_by_username(cls, username) -> "FixedUser":
return next( return next(
(user for user in cls.from_config() if user.username == username), None (user for user in cls.from_config() if user.username == username), None
) )
@classmethod
@lru_cache()
def is_guest_endpoint(cls, service, action):
"""
Check whether the given service/action endpoint is allowed for the guest user,
i.e. whether it appears in the guest allow_endpoints configuration
"""
return any(
ep == ".".join((service, action))
for ep in config.get("services.auth.fixed_users.guest.allow_endpoints", [])
)
@classmethod
def get_guest_user(cls) -> Optional["FixedUser"]:
if cls.guest_enabled():
return cls(
is_guest=True,
username=config.get("services.auth.fixed_users.guest.username"),
password=config.get("services.auth.fixed_users.guest.password"),
name=config.get("services.auth.fixed_users.guest.name"),
company=config.get("services.auth.fixed_users.guest.default_company"),
)
def __hash__(self): def __hash__(self):
return hash(self.user_id) return hash(self.user_id)
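For reference, a minimal sketch of reading the configuration keys the guest support above relies on (the key names appear in this diff; the values and the example endpoint list are assumptions):

from config import config

guest_conf = {
    "enabled": config.get("services.auth.fixed_users.guest.enabled", False),
    "username": config.get("services.auth.fixed_users.guest.username"),
    "password": config.get("services.auth.fixed_users.guest.password"),
    "name": config.get("services.auth.fixed_users.guest.name"),
    "default_company": config.get("services.auth.fixed_users.guest.default_company"),
    # e.g. ["auth.login", "projects.get_all_ex"] -- illustrative values only
    "allow_endpoints": config.get("services.auth.fixed_users.guest.allow_endpoints", []),
}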

View File

@ -16,7 +16,7 @@ from apimodels.auth import (
) )
from apimodels.base import UpdateResponse from apimodels.base import UpdateResponse
from bll.auth import AuthBLL from bll.auth import AuthBLL
from config import config from config import config, info
from database.errors import translate_errors_context from database.errors import translate_errors_context
from database.model.auth import User from database.model.auth import User
from service_repo import APICall, endpoint from service_repo import APICall, endpoint
@ -176,4 +176,17 @@ def update(call, company_id, _):
@endpoint("auth.fixed_users_mode") @endpoint("auth.fixed_users_mode")
def fixed_users_mode(call: APICall, *_, **__): def fixed_users_mode(call: APICall, *_, **__):
call.result.data = dict(enabled=FixedUser.enabled()) data = {
"enabled": FixedUser.enabled(),
"migration_warning": info.missed_es_upgrade,
"guest": {
"enabled": FixedUser.guest_enabled(),
}
}
guest_user = FixedUser.get_guest_user()
if guest_user:
data["guest"]["name"] = guest_user.name
data["guest"]["username"] = guest_user.username
data["guest"]["password"] = guest_user.password
call.result.data = data

View File

@ -5,7 +5,8 @@ from mongoengine import Q, EmbeddedDocument
import database import database
from apierrors import errors from apierrors import errors
from apimodels.base import UpdateResponse from apierrors.errors.bad_request import InvalidModelId
from apimodels.base import UpdateResponse, MakePublicRequest
from apimodels.models import ( from apimodels.models import (
CreateModelRequest, CreateModelRequest,
CreateModelResponse, CreateModelResponse,
@ -467,3 +468,21 @@ def update(call: APICall, company_id, _):
if del_count: if del_count:
_reset_cached_tags(company_id, projects=[model.project]) _reset_cached_tags(company_id, projects=[model.project])
call.result.data = dict(deleted=del_count > 0) call.result.data = dict(deleted=del_count > 0)
@endpoint("models.make_public", min_version="2.9", request_data_model=MakePublicRequest)
def make_public(call: APICall, company_id, request: MakePublicRequest):
with translate_errors_context():
call.result.data = Model.set_public(
company_id, ids=request.ids, invalid_cls=InvalidModelId, enabled=True
)
@endpoint(
"models.make_private", min_version="2.9", request_data_model=MakePublicRequest
)
def make_private(call: APICall, company_id, request: MakePublicRequest):
with translate_errors_context():
call.result.data = Model.set_public(
company_id, request.ids, invalid_cls=InvalidModelId, enabled=False
)

View File

@ -8,7 +8,8 @@ from mongoengine import Q
import database import database
from apierrors import errors from apierrors import errors
from apimodels.base import UpdateResponse from apierrors.errors.bad_request import InvalidProjectId
from apimodels.base import UpdateResponse, MakePublicRequest
from apimodels.projects import ( from apimodels.projects import (
GetHyperParamReq, GetHyperParamReq,
GetHyperParamResp, GetHyperParamResp,
@ -422,3 +423,23 @@ def get_tags(call: APICall, company, request: ProjectTagsRequest):
projects=request.projects, projects=request.projects,
) )
call.result.data = get_tags_response(ret) call.result.data = get_tags_response(ret)
@endpoint(
"projects.make_public", min_version="2.9", request_data_model=MakePublicRequest
)
def make_public(call: APICall, company_id, request: MakePublicRequest):
with translate_errors_context():
call.result.data = Project.set_public(
company_id, ids=request.ids, invalid_cls=InvalidProjectId, enabled=True
)
@endpoint(
"projects.make_private", min_version="2.9", request_data_model=MakePublicRequest
)
def make_private(call: APICall, company_id, request: MakePublicRequest):
with translate_errors_context():
call.result.data = Project.set_public(
company_id, ids=request.ids, invalid_cls=InvalidProjectId, enabled=False
)

View File

@ -11,7 +11,8 @@ from mongoengine.queryset.transform import COMPARISON_OPERATORS
from pymongo import UpdateOne from pymongo import UpdateOne
from apierrors import errors, APIError from apierrors import errors, APIError
from apimodels.base import UpdateResponse, IdResponse from apierrors.errors.bad_request import InvalidTaskId
from apimodels.base import UpdateResponse, IdResponse, MakePublicRequest
from apimodels.tasks import ( from apimodels.tasks import (
StartedResponse, StartedResponse,
ResetResponse, ResetResponse,
@ -78,10 +79,24 @@ def set_task_status_from_call(
task = TaskBLL.get_task_with_access( task = TaskBLL.get_task_with_access(
request.task, request.task,
company_id=company_id, company_id=company_id,
only=tuple({"status", "project"} | fields_resolver.get_names()), only=tuple(
{"status", "project", "started", "duration"} | fields_resolver.get_names()
),
requires_write_access=True, requires_write_access=True,
) )
if "duration" not in fields_resolver.get_names():
if new_status == TaskStatus.started:
fields_resolver.add_fields(min__duration=max(0, task.duration or 0))
elif new_status in (
TaskStatus.completed,
TaskStatus.failed,
TaskStatus.stopped,
):
fields_resolver.add_fields(
duration=int((datetime.utcnow() - task.started).total_seconds())
)
status_reason = request.status_reason status_reason = request.status_reason
status_message = request.status_message status_message = request.status_message
force = request.force force = request.force
@ -354,9 +369,7 @@ def _update_cached_tags(company: str, project: str, fields: dict):
def _reset_cached_tags(company: str, projects: Sequence[str]): def _reset_cached_tags(company: str, projects: Sequence[str]):
org_bll.reset_tags( org_bll.reset_tags(company, Tags.Task, projects=projects)
company, Tags.Task, projects=projects
)
@endpoint( @endpoint(
@ -573,9 +586,7 @@ def edit(call: APICall, company_id, req_model: UpdateRequest):
if updated: if updated:
new_project = fixed_fields.get("project", task.project) new_project = fixed_fields.get("project", task.project)
if new_project != task.project: if new_project != task.project:
_reset_cached_tags( _reset_cached_tags(company_id, projects=[new_project, task.project])
company_id, projects=[new_project, task.project]
)
else: else:
_update_cached_tags( _update_cached_tags(
company_id, project=task.project, fields=fixed_fields company_id, project=task.project, fields=fixed_fields
@ -1005,3 +1016,19 @@ def add_or_update_artifacts(
task_id=request.task, company_id=company_id, artifacts=request.artifacts task_id=request.task, company_id=company_id, artifacts=request.artifacts
) )
call.result.data_model = AddOrUpdateArtifactsResponse(added=added, updated=updated) call.result.data_model = AddOrUpdateArtifactsResponse(added=added, updated=updated)
@endpoint("tasks.make_public", min_version="2.9", request_data_model=MakePublicRequest)
def make_public(call: APICall, company_id, request: MakePublicRequest):
with translate_errors_context():
call.result.data = Task.set_public(
company_id, request.ids, invalid_cls=InvalidTaskId, enabled=True
)
@endpoint("tasks.make_private", min_version="2.9", request_data_model=MakePublicRequest)
def make_private(call: APICall, company_id, request: MakePublicRequest):
with translate_errors_context():
call.result.data = Task.set_public(
company_id, request.ids, invalid_cls=InvalidTaskId, enabled=False
)
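A small illustration of the duration value recorded when a task reaches a terminal status (timestamps are illustrative; the started branch above seeds the duration field, and terminal statuses record the elapsed time since start):

from datetime import datetime, timedelta

started = datetime.utcnow() - timedelta(minutes=5)             # illustrative start time
duration = int((datetime.utcnow() - started).total_seconds())  # ~300 seconds at stop time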

View File

@ -1,3 +1,4 @@
from apierrors.errors.bad_request import InvalidModelId
from tests.automated import TestService from tests.automated import TestService
MODEL_CANNOT_BE_UPDATED_CODES = (400, 203) MODEL_CANNOT_BE_UPDATED_CODES = (400, 203)
@ -7,7 +8,7 @@ IN_PROGRESS = "in_progress"
class TestModelsService(TestService): class TestModelsService(TestService):
def setUp(self, version="2.8"): def setUp(self, version="2.9"):
super().setUp(version=version) super().setUp(version=version)
def test_publish_output_model_running_task(self): def test_publish_output_model_running_task(self):
@ -197,6 +198,28 @@ class TestModelsService(TestService):
res = self.api.models.get_frameworks(projects=[project]) res = self.api.models.get_frameworks(projects=[project])
self.assertEqual([], res.frameworks) self.assertEqual([], res.frameworks)
def test_make_public(self):
m1 = self._create_model(name="public model test")
# model with company_origin not set to the current company cannot be converted to private
with self.api.raises(InvalidModelId):
self.api.models.make_private(ids=[m1])
# public model can be retrieved but not updated
res = self.api.models.make_public(ids=[m1])
self.assertEqual(res.updated, 1)
res = self.api.models.get_all(id=[m1])
self.assertEqual([m.id for m in res.models], [m1])
with self.api.raises(InvalidModelId):
self.api.models.update(model=m1, name="public model test change 1")
# model made private again and can be both retrieved and updated
res = self.api.models.make_private(ids=[m1])
self.assertEqual(res.updated, 1)
res = self.api.models.get_all(id=[m1])
self.assertEqual([m.id for m in res.models], [m1])
self.api.models.update(model=m1, name="public model test change 2")
def _assert_task_status(self, task_id, status): def _assert_task_status(self, task_id, status):
task = self.api.tasks.get_by_id(task=task_id).task task = self.api.tasks.get_by_id(task=task_id).task
assert task.status == status assert task.status == status

View File

@ -0,0 +1,34 @@
from apierrors.errors.bad_request import InvalidProjectId
from apierrors.errors.forbidden import NoWritePermission
from config import config
from tests.automated import TestService
log = config.logger(__file__)
class TestProjectsEdit(TestService):
def setUp(self, **kwargs):
super().setUp(version="2.9")
def test_make_public(self):
p1 = self.create_temp("projects", name="Test public", description="test")
# project with company_origin not set to the current company cannot be converted to private
with self.api.raises(InvalidProjectId):
self.api.projects.make_private(ids=[p1])
# public project can be retrieved but not updated
res = self.api.projects.make_public(ids=[p1])
self.assertEqual(res.updated, 1)
res = self.api.projects.get_all(id=[p1])
self.assertEqual([p.id for p in res.projects], [p1])
with self.api.raises(NoWritePermission):
self.api.projects.update(project=p1, name="Test public change 1")
# project made private again and can be both retrieved and updated
res = self.api.projects.make_private(ids=[p1])
self.assertEqual(res.updated, 1)
res = self.api.projects.get_all(id=[p1])
self.assertEqual([p.id for p in res.projects], [p1])
self.api.projects.update(project=p1, name="Test public change 2")

View File

@ -1,4 +1,5 @@
from apierrors.errors.bad_request import InvalidModelId, ValidationError from apierrors.errors.bad_request import InvalidModelId, ValidationError, InvalidTaskId
from apierrors.errors.forbidden import NoWritePermission
from config import config from config import config
from tests.automated import TestService from tests.automated import TestService
@ -8,7 +9,7 @@ log = config.logger(__file__)
class TestTasksEdit(TestService): class TestTasksEdit(TestService):
def setUp(self, **kwargs): def setUp(self, **kwargs):
super().setUp(version=2.5) super().setUp(version="2.9")
def new_task(self, **kwargs): def new_task(self, **kwargs):
self.update_missing( self.update_missing(
@ -145,3 +146,28 @@ class TestTasksEdit(TestService):
self.api.tasks.delete, task=new_task, move_to_trash=False, force=True self.api.tasks.delete, task=new_task, move_to_trash=False, force=True
) )
return new_task return new_task
def test_make_public(self):
task = self.new_task()
# task is created as private and can be updated
self.api.tasks.started(task=task)
# task with company_origin not set to the current company cannot be converted to private
with self.api.raises(InvalidTaskId):
self.api.tasks.make_private(ids=[task])
# public task can be retrieved but not updated
res = self.api.tasks.make_public(ids=[task])
self.assertEqual(res.updated, 1)
res = self.api.tasks.get_all_ex(id=[task])
self.assertEqual([t.id for t in res.tasks], [task])
with self.api.raises(NoWritePermission):
self.api.tasks.stopped(task=task)
# task made private again and can be both retrieved and updated
res = self.api.tasks.make_private(ids=[task])
self.assertEqual(res.updated, 1)
res = self.api.tasks.get_all_ex(id=[task])
self.assertEqual([t.id for t in res.tasks], [task])
self.api.tasks.stopped(task=task)