From baba8b5b739e766e0254e0bf38e0f969eddaea59 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Mon, 10 Aug 2020 08:30:40 +0300 Subject: [PATCH] Move to ElasticSearch 7 Add initial support for project ordering Add support for sortable task duration (used by the UI in the experiment's table) Add support for project name in worker's current task info Add support for results and artifacts in pre-populates examples Add demo server features --- server/apimodels/base.py | 31 +- server/apimodels/events.py | 6 +- server/apimodels/workers.py | 1 + server/bll/event/debug_images_iterator.py | 13 +- server/bll/event/event_bll.py | 113 +- server/bll/event/event_metrics.py | 438 +++--- server/bll/event/log_events_iterator.py | 6 +- server/bll/event/scalar_key.py | 4 +- server/bll/queue/queue_metrics.py | 5 +- server/bll/statistics/stats_reporter.py | 1 - server/bll/util.py | 23 +- server/bll/workers/__init__.py | 9 +- server/bll/workers/stats.py | 7 +- server/config/default/apiserver.conf | 2 +- server/config/default/hosts.conf | 4 +- server/config/default/services/auth.conf | 16 + server/config/default/services/projects.conf | 8 + server/config/info.py | 3 + server/database/model/auth.py | 2 + server/database/model/base.py | 43 +- server/database/model/model.py | 1 + server/database/model/project.py | 6 +- server/database/model/task/task.py | 4 +- server/elastic/apply_mappings.py | 20 +- server/elastic/initialize.py | 15 +- server/elastic/mappings/events.json | 45 +- server/elastic/mappings/events_log.json | 15 +- server/elastic/mappings/events_plot.json | 11 +- .../mappings/events_training_debug_image.json | 14 +- server/elastic/mappings/queue_metrics.json | 32 +- server/elastic/mappings/worker_stats.json | 40 +- server/mongo/initialize/__init__.py | 13 +- server/mongo/initialize/pre_populate.py | 184 ++- server/mongo/initialize/user.py | 10 +- server/mongo/initialize/util.py | 5 +- server/requirements.txt | 5 +- server/schema/services/auth.conf | 3 + server/schema/services/events.conf | 6 +- server/schema/services/models.conf | 1270 +++++++++-------- server/schema/services/projects.conf | 52 + server/schema/services/tasks.conf | 50 + server/schema/services/workers.conf | 8 +- server/server.py | 8 +- server/service_repo/auth/auth.py | 4 + server/service_repo/auth/fixed_user.py | 50 +- server/services/auth.py | 17 +- server/services/models.py | 21 +- server/services/projects.py | 23 +- server/services/tasks.py | 43 +- server/tests/automated/test_models.py | 25 +- server/tests/automated/test_projects_edit.py | 34 + server/tests/automated/test_tasks_edit.py | 30 +- 52 files changed, 1655 insertions(+), 1144 deletions(-) create mode 100644 server/config/default/services/auth.conf create mode 100644 server/config/default/services/projects.conf create mode 100644 server/tests/automated/test_projects_edit.py diff --git a/server/apimodels/base.py b/server/apimodels/base.py index 51ba290..f1e8cb6 100644 --- a/server/apimodels/base.py +++ b/server/apimodels/base.py @@ -1,7 +1,8 @@ from jsonmodels import models, fields +from jsonmodels.validators import Length from mongoengine.base import BaseDocument -from apimodels import DictField +from apimodels import DictField, ListField class MongoengineFieldsDict(DictField): @@ -12,14 +13,14 @@ class MongoengineFieldsDict(DictField): """ mongoengine_update_operators = ( - 'inc', - 'dec', - 'push', - 'push_all', - 'pop', - 'pull', - 'pull_all', - 'add_to_set', + "inc", + "dec", + "push", + "push_all", + "pop", + "pull", + "pull_all", + "add_to_set", ) @staticmethod @@ 
-30,16 +31,16 @@ class MongoengineFieldsDict(DictField): @classmethod def _normalize_mongo_field_path(cls, path, value): - parts = path.split('__') + parts = path.split("__") if len(parts) > 1: - if parts[0] == 'set': + if parts[0] == "set": parts = parts[1:] - elif parts[0] == 'unset': + elif parts[0] == "unset": parts = parts[1:] value = None elif parts[0] in cls.mongoengine_update_operators: return None, None - return '.'.join(parts), cls._normalize_mongo_value(value) + return ".".join(parts), cls._normalize_mongo_value(value) def parse_value(self, value): value = super(MongoengineFieldsDict, self).parse_value(value) @@ -62,3 +63,7 @@ class PagedRequest(models.Base): class IdResponse(models.Base): id = fields.StringField(required=True) + + +class MakePublicRequest(models.Base): + ids = ListField(items_types=str, validators=[Length(minimum_value=1)]) diff --git a/server/apimodels/events.py b/server/apimodels/events.py index 83427ef..9ab55b8 100644 --- a/server/apimodels/events.py +++ b/server/apimodels/events.py @@ -3,7 +3,7 @@ from typing import Sequence, Optional from jsonmodels import validators from jsonmodels.fields import StringField, BoolField from jsonmodels.models import Base -from jsonmodels.validators import Length +from jsonmodels.validators import Length, Min, Max from apimodels import ListField, IntField, ActualEnumField from bll.event.event_metrics import EventType @@ -11,7 +11,7 @@ from bll.event.scalar_key import ScalarKeyEnum class HistogramRequestBase(Base): - samples: int = IntField(default=10000) + samples: int = IntField(default=6000, validators=[Min(1), Max(6000)]) key: ScalarKeyEnum = ActualEnumField(ScalarKeyEnum, default=ScalarKeyEnum.iter) @@ -21,7 +21,7 @@ class ScalarMetricsIterHistogramRequest(HistogramRequestBase): class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase): tasks: Sequence[str] = ListField( - items_types=str, validators=[Length(minimum_value=1)] + items_types=str, validators=[Length(minimum_value=1, maximum_value=10)] ) diff --git a/server/apimodels/workers.py b/server/apimodels/workers.py index 40d8f4d..7fbb950 100644 --- a/server/apimodels/workers.py +++ b/server/apimodels/workers.py @@ -67,6 +67,7 @@ class WorkerEntry(Base, JsonSerializableMixin): company = EmbeddedField(IdNameEntry) ip = StringField() task = EmbeddedField(IdNameEntry) + project = EmbeddedField(IdNameEntry) queue = StringField() # queue from which current task was taken queues = ListField(str) # list of queues this worker listens to register_time = DateTimeField(required=True) diff --git a/server/bll/event/debug_images_iterator.py b/server/bll/event/debug_images_iterator.py index ccedf09..a98a85e 100644 --- a/server/bll/event/debug_images_iterator.py +++ b/server/bll/event/debug_images_iterator.py @@ -208,7 +208,11 @@ class DebugImagesIterator: "size": 0, "query": { "bool": { - "must": [{"term": {"task": task}}, {"terms": {"metric": metrics}}] + "must": [ + {"term": {"task": task}}, + {"terms": {"metric": metrics}}, + {"exists": {"field": "url"}}, + ] } }, "aggs": { @@ -251,7 +255,7 @@ class DebugImagesIterator: } with translate_errors_context(), TimingContext("es", "_init_metric_states"): - es_res = self.es.search(index=es_index, body=es_req, routing=task) + es_res = self.es.search(index=es_index, body=es_req) if "aggregations" not in es_res: return [] @@ -298,6 +302,7 @@ class DebugImagesIterator: must_conditions = [ {"term": {"task": metric.task}}, {"term": {"metric": metric.name}}, + {"exists": {"field": "url"}}, ] must_not_conditions = [] @@ -368,7 
+373,7 @@ class DebugImagesIterator: "terms": { "field": "iter", "size": iter_count, - "order": {"_term": "desc" if navigate_earlier else "asc"}, + "order": {"_key": "desc" if navigate_earlier else "asc"}, }, "aggs": { "variants": { @@ -387,7 +392,7 @@ class DebugImagesIterator: }, } with translate_errors_context(), TimingContext("es", "get_debug_image_events"): - es_res = self.es.search(index=es_index, body=es_req, routing=metric.task) + es_res = self.es.search(index=es_index, body=es_req) if "aggregations" not in es_res: return metric.task, metric.name, [] diff --git a/server/bll/event/event_bll.py b/server/bll/event/event_bll.py index 428aff7..65e2667 100644 --- a/server/bll/event/event_bll.py +++ b/server/bll/event/event_bll.py @@ -3,7 +3,7 @@ from collections import defaultdict from contextlib import closing from datetime import datetime from operator import attrgetter -from typing import Sequence, Set, Tuple +from typing import Sequence, Set, Tuple, Optional import six from elasticsearch import helpers @@ -22,6 +22,7 @@ from database.errors import translate_errors_context from database.model.task.task import Task, TaskStatus from redis_manager import redman from timing_context import TimingContext +from tools import safe_get from utilities.dicts import flatten_nested_items # noinspection PyTypeChecker @@ -134,7 +135,6 @@ class EventBLL(object): es_action = { "_op_type": "index", # overwrite if exists with same ID "_index": index_name, - "_type": "event", "_source": event, } @@ -144,7 +144,6 @@ class EventBLL(object): else: es_action["_id"] = dbutils.id() - es_action["_routing"] = task_id task_ids.add(task_id) if ( iter is not None @@ -342,14 +341,9 @@ class EventBLL(object): } with translate_errors_context(), TimingContext("es", "scroll_task_events"): - es_res = self.es.search( - index=es_index, body=es_req, scroll="1h", routing=task_id - ) - - events = [hit["_source"] for hit in es_res["hits"]["hits"]] - next_scroll_id = es_res["_scroll_id"] - total_events = es_res["hits"]["total"] + es_res = self.es.search(index=es_index, body=es_req, scroll="1h") + events, total_events, next_scroll_id = self._get_events_from_es_res(es_res) return events, next_scroll_id, total_events def get_last_iterations_per_event_metric_variant( @@ -377,7 +371,7 @@ class EventBLL(object): "terms": { "field": "iter", "size": num_last_iterations, - "order": {"_term": "desc"}, + "order": {"_key": "desc"}, } } }, @@ -393,7 +387,7 @@ class EventBLL(object): with translate_errors_context(), TimingContext( "es", "task_last_iter_metric_variant" ): - es_res = self.es.search(index=es_index, body=es_req, routing=task_id) + es_res = self.es.search(index=es_index, body=es_req) if "aggregations" not in es_res: return [] @@ -422,13 +416,11 @@ class EventBLL(object): if not self.es.indices.exists(es_index): return TaskEventsResult() - query = {"bool": defaultdict(list)} - + must = [] if last_iterations_per_plot is None: - must = query["bool"]["must"] must.append({"terms": {"task": tasks}}) else: - should = query["bool"]["should"] + should = [] for i, task_id in enumerate(tasks): last_iters = self.get_last_iterations_per_event_metric_variant( es_index, task_id, last_iterations_per_plot, event_type @@ -451,32 +443,41 @@ class EventBLL(object): ) if not should: return TaskEventsResult() + must.append({"bool": {"should": should}}) if sort is None: sort = [{"timestamp": {"order": "asc"}}] - es_req = {"sort": sort, "size": min(size, 10000), "query": query} - - routing = ",".join(tasks) + es_req = { + "sort": sort, + "size": 
min(size, 10000), + "query": {"bool": {"must": must}}, + } with translate_errors_context(), TimingContext("es", "get_task_plots"): es_res = self.es.search( - index=es_index, - body=es_req, - ignore=404, - routing=routing, - scroll="1h", + index=es_index, body=es_req, ignore=404, scroll="1h", ) - events = [doc["_source"] for doc in es_res.get("hits", {}).get("hits", [])] - # scroll id may be missing when queering a totally empty DB - next_scroll_id = es_res.get("_scroll_id") - total_events = es_res["hits"]["total"] - + events, total_events, next_scroll_id = self._get_events_from_es_res(es_res) return TaskEventsResult( events=events, next_scroll_id=next_scroll_id, total_events=total_events ) + def _get_events_from_es_res(self, es_res: dict) -> Tuple[list, int, Optional[str]]: + """ + Return events and next scroll id from the scrolled query + Release the scroll once it is exhausted + """ + total_events = safe_get(es_res, "hits/total/value", default=0) + events = [doc["_source"] for doc in safe_get(es_res, "hits/hits", default=[])] + next_scroll_id = es_res.get("_scroll_id") + if next_scroll_id and not events: + self.es.clear_scroll(scroll_id=next_scroll_id) + next_scroll_id = None + + return events, total_events, next_scroll_id + def get_task_events( self, company_id, @@ -502,20 +503,16 @@ class EventBLL(object): if not self.es.indices.exists(es_index): return TaskEventsResult() - query = {"bool": defaultdict(list)} - - if metric or variant: - must = query["bool"]["must"] - if metric: - must.append({"term": {"metric": metric}}) - if variant: - must.append({"term": {"variant": variant}}) + must = [] + if metric: + must.append({"term": {"metric": metric}}) + if variant: + must.append({"term": {"variant": variant}}) if last_iter_count is None: - must = query["bool"]["must"] must.append({"terms": {"task": task_ids}}) else: - should = query["bool"]["should"] + should = [] for i, task_id in enumerate(task_ids): last_iters = self.get_last_iters( es_index, task_id, event_type, last_iter_count @@ -534,27 +531,23 @@ class EventBLL(object): ) if not should: return TaskEventsResult() + must.append({"bool": {"should": should}}) if sort is None: sort = [{"timestamp": {"order": "asc"}}] - es_req = {"sort": sort, "size": min(size, 10000), "query": query} - - routing = ",".join(task_ids) + es_req = { + "sort": sort, + "size": min(size, 10000), + "query": {"bool": {"must": must}}, + } with translate_errors_context(), TimingContext("es", "get_task_events"): es_res = self.es.search( - index=es_index, - body=es_req, - ignore=404, - routing=routing, - scroll="1h", + index=es_index, body=es_req, ignore=404, scroll="1h", ) - events = [doc["_source"] for doc in es_res.get("hits", {}).get("hits", [])] - next_scroll_id = es_res.get("_scroll_id") - total_events = es_res["hits"]["total"] - + events, total_events, next_scroll_id = self._get_events_from_es_res(es_res) return TaskEventsResult( events=events, next_scroll_id=next_scroll_id, total_events=total_events ) @@ -590,7 +583,7 @@ class EventBLL(object): with translate_errors_context(), TimingContext( "es", "events_get_metrics_and_variants" ): - es_res = self.es.search(index=es_index, body=es_req, routing=task_id) + es_res = self.es.search(index=es_index, body=es_req) metrics = {} for metric_bucket in es_res["aggregations"]["metrics"].get("buckets"): @@ -622,14 +615,14 @@ class EventBLL(object): "terms": { "field": "metric", "size": EventMetrics.MAX_METRICS_COUNT, - "order": {"_term": "asc"}, + "order": {"_key": "asc"}, }, "aggs": { "variants": { "terms": { "field": 
"variant", "size": EventMetrics.MAX_VARIANTS_COUNT, - "order": {"_term": "asc"}, + "order": {"_key": "asc"}, }, "aggs": { "last_value": { @@ -659,7 +652,7 @@ class EventBLL(object): with translate_errors_context(), TimingContext( "es", "events_get_metrics_and_variants" ): - es_res = self.es.search(index=es_index, body=es_req, routing=task_id) + es_res = self.es.search(index=es_index, body=es_req) metrics = [] max_timestamp = 0 @@ -706,7 +699,7 @@ class EventBLL(object): "sort": ["iter"], } with translate_errors_context(), TimingContext("es", "task_stats_vector"): - es_res = self.es.search(index=es_index, body=es_req, routing=task_id) + es_res = self.es.search(index=es_index, body=es_req) vectors = [] iterations = [] @@ -727,7 +720,7 @@ class EventBLL(object): "terms": { "field": "iter", "size": iters, - "order": {"_term": "desc"}, + "order": {"_key": "desc"}, } } }, @@ -737,7 +730,7 @@ class EventBLL(object): es_req["query"]["bool"]["must"].append({"term": {"type": event_type}}) with translate_errors_context(), TimingContext("es", "task_last_iter"): - es_res = self.es.search(index=es_index, body=es_req, routing=task_id) + es_res = self.es.search(index=es_index, body=es_req) if "aggregations" not in es_res: return [] @@ -759,8 +752,6 @@ class EventBLL(object): es_index = EventMetrics.get_index_name(company_id, "*") es_req = {"query": {"term": {"task": task_id}}} with translate_errors_context(), TimingContext("es", "delete_task_events"): - es_res = self.es.delete_by_query( - index=es_index, body=es_req, routing=task_id, refresh=True - ) + es_res = self.es.delete_by_query(index=es_index, body=es_req, refresh=True) return es_res.get("deleted", 0) diff --git a/server/bll/event/event_metrics.py b/server/bll/event/event_metrics.py index a331f2d..a930ab4 100644 --- a/server/bll/event/event_metrics.py +++ b/server/bll/event/event_metrics.py @@ -1,12 +1,11 @@ import itertools from collections import defaultdict -from concurrent.futures import ThreadPoolExecutor +from concurrent.futures.thread import ThreadPoolExecutor from enum import Enum from functools import partial from operator import itemgetter -from typing import Sequence, Tuple, Callable, Iterable +from typing import Sequence, Tuple -from boltons.iterutils import bucketize from elasticsearch import Elasticsearch from mongoengine import Q @@ -16,7 +15,7 @@ from config import config from database.errors import translate_errors_context from database.model.task.task import Task from timing_context import TimingContext -from utilities import safe_get +from tools import safe_get log = config.logger(__file__) @@ -30,14 +29,18 @@ class EventType(Enum): class EventMetrics: - MAX_TASKS_COUNT = 50 - MAX_METRICS_COUNT = 200 - MAX_VARIANTS_COUNT = 500 + MAX_METRICS_COUNT = 100 + MAX_VARIANTS_COUNT = 100 MAX_AGGS_ELEMENTS_COUNT = 50 + MAX_SAMPLE_BUCKETS = 6000 def __init__(self, es: Elasticsearch): self.es = es + @property + def _max_concurrency(self): + return config.get("services.events.max_metrics_concurrency", 4) + @staticmethod def get_index_name(company_id, event_type): event_type = event_type.lower().replace(" ", "_") @@ -51,15 +54,48 @@ class EventMetrics: The amount of points in each histogram should not exceed the requested samples """ + es_index = self.get_index_name(company_id, "training_stats_scalar") + if not self.es.indices.exists(es_index): + return {} - return self._run_get_scalar_metrics_as_parallel( - company_id, - task_ids=[task_id], - samples=samples, - key=ScalarKey.resolve(key), - get_func=self._get_scalar_average, + return 
self._get_scalar_average_per_iter_core( + task_id, es_index, samples, ScalarKey.resolve(key) ) + def _get_scalar_average_per_iter_core( + self, + task_id: str, + es_index: str, + samples: int, + key: ScalarKey, + run_parallel: bool = True, + ) -> dict: + intervals = self._get_task_metric_intervals( + es_index=es_index, task_id=task_id, samples=samples, field=key.field + ) + if not intervals: + return {} + interval_groups = self._group_task_metric_intervals(intervals) + + get_scalar_average = partial( + self._get_scalar_average, task_id=task_id, es_index=es_index, key=key + ) + if run_parallel: + with ThreadPoolExecutor(max_workers=self._max_concurrency) as pool: + metrics = itertools.chain.from_iterable( + pool.map(get_scalar_average, interval_groups) + ) + else: + metrics = itertools.chain.from_iterable( + get_scalar_average(group) for group in interval_groups + ) + + ret = defaultdict(dict) + for metric_key, metric_values in metrics: + ret[metric_key].update(metric_values) + + return ret + def compare_scalar_metrics_average_per_iter( self, company_id, @@ -72,12 +108,6 @@ class EventMetrics: Compare scalar metrics for different tasks per metric and variant The amount of points in each histogram should not exceed the requested samples """ - if len(task_ids) > self.MAX_TASKS_COUNT: - raise errors.BadRequest( - f"Up to {self.MAX_TASKS_COUNT} tasks supported for comparison", - len(task_ids), - ) - task_name_by_id = {} with translate_errors_context(): task_objs = Task.get_many( @@ -90,7 +120,6 @@ class EventMetrics: if len(task_objs) < len(task_ids): invalid = tuple(set(task_ids) - set(r.id for r in task_objs)) raise errors.bad_request.InvalidTaskId(company=company_id, ids=invalid) - task_name_by_id = {t.id: t.name for t in task_objs} companies = {t.company for t in task_objs} @@ -99,138 +128,95 @@ class EventMetrics: "only tasks from the same company are supported" ) - ret = self._run_get_scalar_metrics_as_parallel( - next(iter(companies)), - task_ids=task_ids, - samples=samples, - key=ScalarKey.resolve(key), - get_func=self._get_scalar_average_per_task, - ) - - for metric_data in ret.values(): - for variant_data in metric_data.values(): - for task_id, task_data in variant_data.items(): - task_data["name"] = task_name_by_id[task_id] - - return ret - - TaskMetric = Tuple[str, str, str] - - MetricInterval = Tuple[int, Sequence[TaskMetric]] - MetricData = Tuple[str, dict] - - def _split_metrics_by_max_aggs_count( - self, task_metrics: Sequence[TaskMetric] - ) -> Iterable[Sequence[TaskMetric]]: - """ - Return task metrics in groups where amount of task metrics in each group - is roughly limited by MAX_AGGS_ELEMENTS_COUNT. 
The split is done on metrics and
-        variants while always preserving all their tasks in the same group
-        """
-        if len(task_metrics) < self.MAX_AGGS_ELEMENTS_COUNT:
-            yield task_metrics
-            return
-
-        tm_grouped = bucketize(task_metrics, key=itemgetter(1, 2))
-        groups = []
-        for group in tm_grouped.values():
-            groups.append(group)
-            if sum(map(len, groups)) >= self.MAX_AGGS_ELEMENTS_COUNT:
-                yield list(itertools.chain(*groups))
-                groups = []
-
-        if groups:
-            yield list(itertools.chain(*groups))
-
-        return
-
-    def _run_get_scalar_metrics_as_parallel(
-        self,
-        company_id: str,
-        task_ids: Sequence[str],
-        samples: int,
-        key: ScalarKey,
-        get_func: Callable[
-            [MetricInterval, Sequence[str], str, ScalarKey], Sequence[MetricData]
-        ],
-    ) -> dict:
-        """
-        Group metrics per interval length and execute get_func for each group in parallel
-        :param company_id: id of the company
-        :params task_ids: ids of the tasks to collect data for
-        :param samples: maximum number of samples per metric
-        :param get_func: callable that given metric names for the same interval
-        performs histogram aggregation for the metrics and return the aggregated data
-        """
-        es_index = self.get_index_name(company_id, "training_stats_scalar")
+        es_index = self.get_index_name(next(iter(companies)), "training_stats_scalar")
         if not self.es.indices.exists(es_index):
             return {}
 
-        intervals = self._get_metric_intervals(
-            es_index=es_index, task_ids=task_ids, samples=samples, field=key.field
+        get_scalar_average_per_iter = partial(
+            self._get_scalar_average_per_iter_core,
+            es_index=es_index,
+            samples=samples,
+            key=ScalarKey.resolve(key),
+            run_parallel=False,
         )
-
-        if not intervals:
-            return {}
-
-        intervals = list(
-            itertools.chain.from_iterable(
-                zip(itertools.repeat(i), self._split_metrics_by_max_aggs_count(tms))
-                for i, tms in intervals
-            )
-        )
-        max_concurrency = config.get("services.events.max_metrics_concurrency", 4)
-        with ThreadPoolExecutor(max_workers=max_concurrency) as pool:
-            metrics = itertools.chain.from_iterable(
-                pool.map(
-                    partial(get_func, task_ids=task_ids, es_index=es_index, key=key),
-                    intervals,
-                )
+        with ThreadPoolExecutor(max_workers=self._max_concurrency) as pool:
+            task_metrics = zip(
+                task_ids, pool.map(get_scalar_average_per_iter, task_ids)
             )
 
-        ret = defaultdict(dict)
-        for metric_key, metric_values in metrics:
-            ret[metric_key].update(metric_values)
+        res = defaultdict(lambda: defaultdict(dict))
+        for task_id, task_data in task_metrics:
+            task_name = task_name_by_id[task_id]
+            for metric_key, metric_data in task_data.items():
+                for variant_key, variant_data in metric_data.items():
+                    variant_data["name"] = task_name
+                    res[metric_key][variant_key][task_id] = variant_data
 
-        return ret
+        return res
 
-    def _get_metric_intervals(
-        self, es_index, task_ids: Sequence[str], samples: int, field: str = "iter"
+    MetricInterval = Tuple[str, str, int, int]
+    MetricIntervalGroup = Tuple[int, Sequence[Tuple[str, str]]]
+
+    @classmethod
+    def _group_task_metric_intervals(
+        cls, intervals: Sequence[MetricInterval]
+    ) -> Sequence[MetricIntervalGroup]:
+        """
+        Group task metric intervals so that the following conditions are met:
+        - All the metrics in the same group have the same interval (with 10% rounding)
+        - The amount of metrics in the group does not exceed MAX_AGGS_ELEMENTS_COUNT
+        - The total count of samples in the group does not exceed MAX_SAMPLE_BUCKETS
+        """
+        metric_interval_groups = []
+        interval_group = []
+        group_interval_upper_bound = 0
+        group_max_interval = 0
+        group_samples = 0
+        for metric, 
variant, interval, size in sorted(intervals, key=itemgetter(2)):
+            if (
+                interval > group_interval_upper_bound
+                or (group_samples + size) > cls.MAX_SAMPLE_BUCKETS
+                or len(interval_group) >= cls.MAX_AGGS_ELEMENTS_COUNT
+            ):
+                if interval_group:
+                    metric_interval_groups.append((group_max_interval, interval_group))
+                    interval_group = []
+                group_max_interval = interval
+                group_interval_upper_bound = interval + int(interval * 0.1)
+                group_samples = 0
+            interval_group.append((metric, variant))
+            group_samples += size
+            group_max_interval = max(group_max_interval, interval)
+        if interval_group:
+            metric_interval_groups.append((group_max_interval, interval_group))
+
+        return metric_interval_groups
+
+    def _get_task_metric_intervals(
+        self, es_index, task_id: str, samples: int, field: str = "iter"
     ) -> Sequence[MetricInterval]:
         """
         Calculate interval per task metric variant so that the resulting
         amount of points does not exceed sample.
-        Return metric variants grouped by interval value with 10% rounding
-        For samples==0 return empty list
+        Return the list of metric variant intervals as the following tuple:
+        (metric, variant, interval, samples)
         """
-        default_intervals = [(1, [])]
-        if not samples:
-            return default_intervals
-
         es_req = {
             "size": 0,
-            "query": {"terms": {"task": task_ids}},
+            "query": {"term": {"task": task_id}},
             "aggs": {
-                "tasks": {
-                    "terms": {"field": "task", "size": self.MAX_TASKS_COUNT},
+                "metrics": {
+                    "terms": {"field": "metric", "size": self.MAX_METRICS_COUNT},
                     "aggs": {
-                        "metrics": {
+                        "variants": {
                             "terms": {
-                                "field": "metric",
-                                "size": self.MAX_METRICS_COUNT,
+                                "field": "variant",
+                                "size": self.MAX_VARIANTS_COUNT,
                             },
                             "aggs": {
-                                "variants": {
-                                    "terms": {
-                                        "field": "variant",
-                                        "size": self.MAX_VARIANTS_COUNT,
-                                    },
-                                    "aggs": {
-                                        "count": {"value_count": {"field": field}},
-                                        "min_index": {"min": {"field": field}},
-                                        "max_index": {"max": {"field": field}},
-                                    },
-                                }
+                                "count": {"value_count": {"field": field}},
+                                "min_index": {"min": {"field": field}},
+                                "max_index": {"max": {"field": field}},
                             },
                         }
                     },
@@ -239,88 +225,75 @@ class EventMetrics:
         }
 
         with translate_errors_context(), TimingContext("es", "task_stats_get_interval"):
-            es_res = self.es.search(
-                index=es_index, body=es_req, routing=",".join(task_ids)
-            )
+            es_res = self.es.search(index=es_index, body=es_req)
 
         aggs_result = es_res.get("aggregations")
         if not aggs_result:
-            return default_intervals
+            return []
 
-        intervals = [
-            (
-                task["key"],
-                metric["key"],
-                variant["key"],
-                self._calculate_metric_interval(variant, samples),
-            )
-            for task in aggs_result["tasks"]["buckets"]
-            for metric in task["metrics"]["buckets"]
+        return [
+            self._build_metric_interval(metric["key"], variant["key"], variant, samples)
+            for metric in aggs_result["metrics"]["buckets"]
             for variant in metric["variants"]["buckets"]
         ]
 
-        metric_intervals = []
-        upper_border = 0
-        interval_metrics = None
-        for task, metric, variant, interval in sorted(intervals, key=itemgetter(3)):
-            if not interval_metrics or interval > upper_border:
-                interval_metrics = []
-                metric_intervals.append((interval, interval_metrics))
-                upper_border = interval + int(interval * 0.1)
-            interval_metrics.append((task, metric, variant))
-
-        return metric_intervals
-
     @staticmethod
-    def _calculate_metric_interval(metric_variant: dict, samples: int) -> int:
+    def _build_metric_interval(
+        metric: str, variant: str, data: dict, samples: int
+    ) -> Tuple[str, str, int, int]:
        """
        Calculate index interval per metric_variant variant so that the total
        amount of intervals does 
not exceeds the samples + Return the interval and resulting amount of intervals """ - count = safe_get(metric_variant, "count/value") - if not count or count < samples: - return 1 + count = safe_get(data, "count/value", default=0) + if count < samples: + return metric, variant, 1, count - min_index = safe_get(metric_variant, "min_index/value", default=0) - max_index = safe_get(metric_variant, "max_index/value", default=min_index) - return max(1, int(max_index - min_index + 1) // samples) + min_index = safe_get(data, "min_index/value", default=0) + max_index = safe_get(data, "max_index/value", default=min_index) + return ( + metric, + variant, + max(1, int(max_index - min_index + 1) // samples), + samples, + ) + + MetricData = Tuple[str, dict] def _get_scalar_average( self, - metrics_interval: MetricInterval, - task_ids: Sequence[str], + metrics_interval: MetricIntervalGroup, + task_id: str, es_index: str, key: ScalarKey, ) -> Sequence[MetricData]: """ Retrieve scalar histograms per several metric variants that share the same interval - Note: the function works with a single task only """ - - assert len(task_ids) == 1 - interval, task_metrics = metrics_interval + interval, metrics = metrics_interval aggregation = self._add_aggregation_average(key.get_aggregation(interval)) aggs = { "metrics": { "terms": { "field": "metric", "size": self.MAX_METRICS_COUNT, - "order": {"_term": "desc"}, + "order": {"_key": "desc"}, }, "aggs": { "variants": { "terms": { "field": "variant", "size": self.MAX_VARIANTS_COUNT, - "order": {"_term": "desc"}, + "order": {"_key": "desc"}, }, "aggs": aggregation, } }, } } - aggs_result = self._query_aggregation_for_metrics_and_tasks( - es_index, aggs=aggs, task_ids=task_ids, task_metrics=task_metrics + aggs_result = self._query_aggregation_for_task_metrics( + es_index, aggs=aggs, task_id=task_id, metrics=metrics ) if not aggs_result: @@ -341,61 +314,6 @@ class EventMetrics: ] return metrics - def _get_scalar_average_per_task( - self, - metrics_interval: MetricInterval, - task_ids: Sequence[str], - es_index: str, - key: ScalarKey, - ) -> Sequence[MetricData]: - """ - Retrieve scalar histograms per several metric variants that share the same interval - """ - interval, task_metrics = metrics_interval - - aggregation = self._add_aggregation_average(key.get_aggregation(interval)) - aggs = { - "metrics": { - "terms": {"field": "metric", "size": self.MAX_METRICS_COUNT}, - "aggs": { - "variants": { - "terms": {"field": "variant", "size": self.MAX_VARIANTS_COUNT}, - "aggs": { - "tasks": { - "terms": { - "field": "task", - "size": self.MAX_TASKS_COUNT, - }, - "aggs": aggregation, - } - }, - } - }, - } - } - - aggs_result = self._query_aggregation_for_metrics_and_tasks( - es_index, aggs=aggs, task_ids=task_ids, task_metrics=task_metrics - ) - - if not aggs_result: - return {} - - metrics = [ - ( - metric["key"], - { - variant["key"]: { - task["key"]: key.get_iterations_data(task) - for task in variant["tasks"]["buckets"] - } - for variant in metric["variants"]["buckets"] - }, - ) - for metric in aggs_result["metrics"]["buckets"] - ] - return metrics - @staticmethod def _add_aggregation_average(aggregation): average_agg = {"avg_val": {"avg": {"field": "value"}}} @@ -404,69 +322,55 @@ class EventMetrics: for key, value in aggregation.items() } - def _query_aggregation_for_metrics_and_tasks( + def _query_aggregation_for_task_metrics( self, es_index: str, aggs: dict, - task_ids: Sequence[str], - task_metrics: Sequence[TaskMetric], + task_id: str, + metrics: Sequence[Tuple[str, str]], 
) -> dict: """ Return the result of elastic search query for the given aggregation filtered by the given task_ids and metrics """ - if task_metrics: - condition = { - "should": [ - self._build_metric_terms(task, metric, variant) - for task, metric, variant in task_metrics - ] - } - else: - condition = {"must": [{"terms": {"task": task_ids}}]} + must = [{"term": {"task": task_id}}] + if metrics: + should = [ + { + "bool": { + "must": [ + {"term": {"metric": metric}}, + {"term": {"variant": variant}}, + ] + } + } + for metric, variant in metrics + ] + must.append({"bool": {"should": should}}) + es_req = { "size": 0, - "_source": {"excludes": []}, - "query": {"bool": condition}, + "query": {"bool": {"must": must}}, "aggs": aggs, - "version": True, } with translate_errors_context(), TimingContext("es", "task_stats_scalar"): - es_res = self.es.search( - index=es_index, body=es_req, routing=",".join(task_ids) - ) + es_res = self.es.search(index=es_index, body=es_req) return es_res.get("aggregations") - @staticmethod - def _build_metric_terms(task: str, metric: str, variant: str) -> dict: - """ - Build query term for a metric + variant - """ - return { - "bool": { - "must": [ - {"term": {"task": task}}, - {"term": {"metric": metric}}, - {"term": {"variant": variant}}, - ] - } - } - def get_tasks_metrics( self, company_id, task_ids: Sequence, event_type: EventType - ) -> Sequence[Tuple]: + ) -> Sequence: """ For the requested tasks return all the metrics that reported events of the requested types """ es_index = EventMetrics.get_index_name(company_id, event_type.value) if not self.es.indices.exists(es_index): - return [(tid, []) for tid in task_ids] + return {} - max_concurrency = config.get("services.events.max_metrics_concurrency", 4) - with ThreadPoolExecutor(max_concurrency) as pool: + with ThreadPoolExecutor(self._max_concurrency) as pool: res = pool.map( partial( self._get_task_metrics, es_index=es_index, event_type=event_type, @@ -494,7 +398,7 @@ class EventMetrics: } with translate_errors_context(), TimingContext("es", "_get_task_metrics"): - es_res = self.es.search(index=es_index, body=es_req, routing=task_id) + es_res = self.es.search(index=es_index, body=es_req) return [ metric["key"] diff --git a/server/bll/event/log_events_iterator.py b/server/bll/event/log_events_iterator.py index 89a32e2..3160060 100644 --- a/server/bll/event/log_events_iterator.py +++ b/server/bll/event/log_events_iterator.py @@ -71,9 +71,9 @@ class LogEventsIterator: es_req["search_after"] = [from_timestamp] with translate_errors_context(), TimingContext("es", "get_task_events"): - es_result = self.es.search(index=es_index, body=es_req, routing=task_id) + es_result = self.es.search(index=es_index, body=es_req) hits = es_result["hits"]["hits"] - hits_total = es_result["hits"]["total"] + hits_total = es_result["hits"]["total"]["value"] if not hits: return [], hits_total @@ -92,7 +92,7 @@ class LogEventsIterator: } }, } - es_result = self.es.search(index=es_index, body=es_req, routing=task_id) + es_result = self.es.search(index=es_index, body=es_req) hits = es_result["hits"]["hits"] if not hits or len(hits) < 2: # if only one element is returned for the last timestamp diff --git a/server/bll/event/scalar_key.py b/server/bll/event/scalar_key.py index 18fe1b1..fb63ab5 100644 --- a/server/bll/event/scalar_key.py +++ b/server/bll/event/scalar_key.py @@ -111,7 +111,7 @@ class TimestampKey(ScalarKey): self.name: { "date_histogram": { "field": "timestamp", - "interval": f"{interval}ms", + "fixed_interval": f"{interval}ms", 
"min_doc_count": 1, } } @@ -150,7 +150,7 @@ class ISOTimeKey(ScalarKey): self.name: { "date_histogram": { "field": "timestamp", - "interval": f"{interval}ms", + "fixed_interval": f"{interval}ms", "min_doc_count": 1, "format": "strict_date_time", } diff --git a/server/bll/queue/queue_metrics.py b/server/bll/queue/queue_metrics.py index 41d7df1..ad76f9f 100644 --- a/server/bll/queue/queue_metrics.py +++ b/server/bll/queue/queue_metrics.py @@ -18,7 +18,6 @@ log = config.logger(__file__) class QueueMetrics: class EsKeys: - DOC_TYPE = "metrics" WAITING_TIME_FIELD = "average_waiting_time" QUEUE_LENGTH_FIELD = "queue_length" TIMESTAMP_FIELD = "timestamp" @@ -66,7 +65,6 @@ class QueueMetrics: entries = [e for e in queue.entries if e.added] return dict( _index=es_index, - _type=self.EsKeys.DOC_TYPE, _source={ self.EsKeys.TIMESTAMP_FIELD: timestamp, self.EsKeys.QUEUE_FIELD: queue.id, @@ -93,7 +91,6 @@ class QueueMetrics: def _search_company_metrics(self, company_id: str, es_req: dict) -> dict: return self.es.search( index=f"{self._queue_metrics_prefix_for_company(company_id)}*", - doc_type=self.EsKeys.DOC_TYPE, body=es_req, ) @@ -109,7 +106,7 @@ class QueueMetrics: "dates": { "date_histogram": { "field": cls.EsKeys.TIMESTAMP_FIELD, - "interval": f"{interval}s", + "fixed_interval": f"{interval}s", "min_doc_count": 1, }, "aggs": { diff --git a/server/bll/statistics/stats_reporter.py b/server/bll/statistics/stats_reporter.py index 5d9f17c..4ed33ec 100644 --- a/server/bll/statistics/stats_reporter.py +++ b/server/bll/statistics/stats_reporter.py @@ -237,7 +237,6 @@ class StatisticsReporter: def _run_worker_stats_query(cls, company_id, es_req) -> dict: return worker_bll.es_client.search( index=f"{WorkerStats.worker_stats_prefix_for_company(company_id)}*", - doc_type="stat", body=es_req, ) diff --git a/server/bll/util.py b/server/bll/util.py index 77c1b4d..0991f42 100644 --- a/server/bll/util.py +++ b/server/bll/util.py @@ -35,14 +35,21 @@ class SetFieldsResolver: SET_MODIFIERS = ("min", "max") def __init__(self, set_fields: Dict[str, Any]): - self.orig_fields = set_fields - self.fields = { - f: fname - for f, modifier, dunder, fname in ( - (f,) + f.partition("__") for f in set_fields.keys() - ) - if dunder and modifier in self.SET_MODIFIERS - } + self.orig_fields = {} + self.fields = {} + self.add_fields(**set_fields) + + def add_fields(self, **set_fields: Any): + self.orig_fields.update(set_fields) + self.fields.update( + { + f: fname + for f, modifier, dunder, fname in ( + (f,) + f.partition("__") for f in set_fields.keys() + ) + if dunder and modifier in self.SET_MODIFIERS + } + ) def _get_updated_name(self, doc: AttributedDocument, name: str) -> str: if name in self.fields and doc.get_field_value(self.fields[name]) is None: diff --git a/server/bll/workers/__init__.py b/server/bll/workers/__init__.py index d83cdb4..cf04985 100644 --- a/server/bll/workers/__init__.py +++ b/server/bll/workers/__init__.py @@ -21,6 +21,7 @@ from config import config from database.errors import translate_errors_context from database.model.auth import User from database.model.company import Company +from database.model.project import Project from database.model.queue import Queue from database.model.task.task import Task from redis_manager import redman @@ -146,6 +147,7 @@ class WorkerBLL: if not report.task: entry.task = None + entry.project = None else: with translate_errors_context(): query = dict(id=report.task, company=company_id) @@ -160,6 +162,12 @@ class WorkerBLL: raise bad_request.InvalidTaskId(**query) entry.task 
= IdNameEntry(id=task.id, name=task.name) + entry.project = None + if task.project: + project = Project.objects(id=task.project).only("name").first() + if project: + entry.project = IdNameEntry(id=project.id, name=project.name) + entry.last_report_time = now except APIError: raise @@ -369,7 +377,6 @@ class WorkerBLL: def make_doc(category, metric, variant, value) -> dict: return dict( _index=es_index, - _type="stat", _source=dict( timestamp=timestamp, worker=worker, diff --git a/server/bll/workers/stats.py b/server/bll/workers/stats.py index 5d55044..ecf7442 100644 --- a/server/bll/workers/stats.py +++ b/server/bll/workers/stats.py @@ -25,7 +25,6 @@ class WorkerStats: def _search_company_stats(self, company_id: str, es_req: dict) -> dict: return self.es.search( index=f"{self.worker_stats_prefix_for_company(company_id)}*", - doc_type="stat", body=es_req, ) @@ -53,7 +52,7 @@ class WorkerStats: res = self._search_company_stats(company_id, es_req) - if not res["hits"]["total"]: + if not res["hits"]["total"]["value"]: raise bad_request.WorkerStatsNotFound( f"No statistic metrics found for the company {company_id} and workers {worker_ids}" ) @@ -87,7 +86,7 @@ class WorkerStats: "dates": { "date_histogram": { "field": "timestamp", - "interval": f"{request.interval}s", + "fixed_interval": f"{request.interval}s", "min_doc_count": 1, }, "aggs": { @@ -216,7 +215,7 @@ class WorkerStats: "dates": { "date_histogram": { "field": "timestamp", - "interval": f"{interval}s", + "fixed_interval": f"{interval}s", }, "aggs": {"workers_count": {"cardinality": {"field": "worker"}}}, } diff --git a/server/config/default/apiserver.conf b/server/config/default/apiserver.conf index e248c35..198c1fa 100644 --- a/server/config/default/apiserver.conf +++ b/server/config/default/apiserver.conf @@ -30,7 +30,7 @@ enabled: false zip_files: ["/path/to/export.zip"] fail_on_error: false - artifacts_path: "/mnt/fileserver" + # artifacts_path: "/mnt/fileserver" } # time in seconds to take an exclusive lock to init es and mongodb diff --git a/server/config/default/hosts.conf b/server/config/default/hosts.conf index 17d9ab8..51aa77c 100644 --- a/server/config/default/hosts.conf +++ b/server/config/default/hosts.conf @@ -1,6 +1,6 @@ elastic { events { - hosts: [{host: "127.0.0.1", port: 9200}] + hosts: [{host: "127.0.0.1", port: 9211}] args { timeout: 60 dead_timeout: 10 @@ -11,7 +11,7 @@ elastic { } workers { - hosts: [{host:"127.0.0.1", port:9200}] + hosts: [{host:"127.0.0.1", port:9211}] args { timeout: 60 dead_timeout: 10 diff --git a/server/config/default/services/auth.conf b/server/config/default/services/auth.conf new file mode 100644 index 0000000..fe9c87e --- /dev/null +++ b/server/config/default/services/auth.conf @@ -0,0 +1,16 @@ +fixed_users { + guest { + enabled: false + + default_company: "025315a9321f49f8be07f5ac48fbcf92" + + name: "Guest" + username: "guest" + password: "guest" + + # Allow access only to the following endpoints when using user/pass credentials + allow_endpoints: [ + "auth.login" + ] + } +} \ No newline at end of file diff --git a/server/config/default/services/projects.conf b/server/config/default/services/projects.conf new file mode 100644 index 0000000..27a1c85 --- /dev/null +++ b/server/config/default/services/projects.conf @@ -0,0 +1,8 @@ +# Order of featured projects, by name or ID +featured_order: [ + # {id: ""} + # OR + # {name: ""} + # OR + # {name_regex: ""} +] diff --git a/server/config/info.py b/server/config/info.py index 529b77a..60439ad 100644 --- a/server/config/info.py +++ 
b/server/config/info.py @@ -41,3 +41,6 @@ def get_deployment_type() -> str: def get_default_company(): return config.get("apiserver.default_company") + + +missed_es_upgrade = False diff --git a/server/database/model/auth.py b/server/database/model/auth.py index 9dd0b39..406ae9d 100644 --- a/server/database/model/auth.py +++ b/server/database/model/auth.py @@ -32,6 +32,8 @@ class Role(object): """ Company user """ annotator = "annotator" """ Annotator with limited access""" + guest = "guest" + """ Guest user. Read Only.""" @classmethod def get_system_roles(cls) -> set: diff --git a/server/database/model/base.py b/server/database/model/base.py index 076a449..9075bf5 100644 --- a/server/database/model/base.py +++ b/server/database/model/base.py @@ -1,7 +1,7 @@ import re from collections import namedtuple from functools import reduce -from typing import Collection, Sequence, Union, Optional +from typing import Collection, Sequence, Union, Optional, Type from boltons.iterutils import first, bucketize from dateutil.parser import parse as parse_datetime @@ -9,6 +9,7 @@ from mongoengine import Q, Document, ListField, StringField from pymongo.command_cursor import CommandCursor from apierrors import errors +from apierrors.base import BaseError from config import config from database.errors import MakeGetAllQueryError from database.projection import project_dict, ProjectionHelper @@ -483,6 +484,21 @@ class GetMixin(PropsMixin): query=_query, parameters=parameters, override_projection=override_projection ) + @classmethod + def get_many_public( + cls, query: Q = None, projection: Collection[str] = None, + ): + """ + Fetch all public documents matching a provided query. + :param query: Optional query object (mongoengine.Q). + :param projection: A list of projection fields. + :return: A list of documents matching the query. 
+ """ + q = get_company_or_none_constraint() + _query = (q & query) if query else q + + return cls._get_many_no_company(query=_query, override_projection=projection) + @classmethod def _get_many_no_company( cls: Union["GetMixin", Document], @@ -728,6 +744,31 @@ class DbModelMixin(GetMixin, ProperDictMixin, UpdateMixin): ) return cls.objects.aggregate(pipeline, **kwargs) + @classmethod + def set_public( + cls: Type[Document], + company_id: str, + ids: Sequence[str], + invalid_cls: Type[BaseError], + enabled: bool = True, + ): + if enabled: + items = list(cls.objects(id__in=ids, company=company_id).only("id")) + update = dict(set__company_origin=company_id, unset__company=1) + else: + items = list( + cls.objects( + id__in=ids, company__in=(None, ""), company_origin=company_id + ).only("id") + ) + update = dict(set__company=company_id, unset__company_origin=1) + + if len(items) < len(ids): + missing = tuple(set(ids).difference(i.id for i in items)) + raise invalid_cls(ids=missing) + + return {"updated": cls.objects(id__in=ids).update(**update)} + def validate_id(cls, company, **kwargs): """ diff --git a/server/database/model/model.py b/server/database/model/model.py index b777efd..8de275c 100644 --- a/server/database/model/model.py +++ b/server/database/model/model.py @@ -72,3 +72,4 @@ class Model(DbModelMixin, Document): ui_cache = SafeDictField( default=dict, user_set_allowed=True, exclude_by_default=True ) + company_origin = StringField(exclude_by_default=True) diff --git a/server/database/model/project.py b/server/database/model/project.py index bde016f..6165811 100644 --- a/server/database/model/project.py +++ b/server/database/model/project.py @@ -1,4 +1,4 @@ -from mongoengine import StringField, DateTimeField +from mongoengine import StringField, DateTimeField, IntField from database import Database, strict from database.fields import StrippedStringField, SafeSortedListField @@ -40,3 +40,7 @@ class Project(AttributedDocument): system_tags = SafeSortedListField(StringField(required=True)) default_output_destination = StrippedStringField() last_update = DateTimeField() + featured = IntField(default=9999) + logo_url = StringField() + logo_blob = StringField(exclude_by_default=True) + company_origin = StringField(exclude_by_default=True) diff --git a/server/database/model/task/task.py b/server/database/model/task/task.py index 2ee57f9..0ee4714 100644 --- a/server/database/model/task/task.py +++ b/server/database/model/task/task.py @@ -118,7 +118,7 @@ external_task_types = set(get_options(TaskType)) class Task(AttributedDocument): _field_collation_overrides = { "execution.parameters.": {"locale": "en_US", "numericOrdering": True}, - "last_metrics.": {"locale": "en_US", "numericOrdering": True} + "last_metrics.": {"locale": "en_US", "numericOrdering": True}, } meta = { @@ -194,3 +194,5 @@ class Task(AttributedDocument): last_iteration = IntField(default=DEFAULT_LAST_ITERATION) last_metrics = SafeMapField(field=SafeMapField(EmbeddedDocumentField(MetricEvent))) metric_stats = SafeMapField(field=EmbeddedDocumentField(MetricEventStats)) + company_origin = StringField(exclude_by_default=True) + duration = IntField() # task duration in seconds diff --git a/server/elastic/apply_mappings.py b/server/elastic/apply_mappings.py index 65b8112..3590e3c 100755 --- a/server/elastic/apply_mappings.py +++ b/server/elastic/apply_mappings.py @@ -4,9 +4,9 @@ Apply elasticsearch mappings to given hosts. 
""" import argparse import json -import requests from pathlib import Path +import requests from requests.adapters import HTTPAdapter from requests.packages.urllib3.util.retry import Retry @@ -14,21 +14,24 @@ HERE = Path(__file__).resolve().parent session = requests.Session() adapter = HTTPAdapter(max_retries=Retry(5, backoff_factor=0.5)) -session.mount('http://', adapter) +session.mount("http://", adapter) + + +def get_template(host: str, template) -> dict: + url = f"{host}/_template/{template}" + res = session.get(url) + return res.json() def apply_mappings_to_host(host: str): def _send_mapping(f): with f.open() as json_data: data = json.load(json_data) - es_server = host - url = f"{es_server}/_template/{f.stem}" + url = f"{host}/_template/{f.stem}" session.delete(url) r = session.post( - url, - headers={"Content-Type": "application/json"}, - data=json.dumps(data), + url, headers={"Content-Type": "application/json"}, data=json.dumps(data) ) return {"mapping": f.stem, "result": r.text} @@ -47,7 +50,8 @@ def parse_args(): def main(): - for host in parse_args().hosts: + args = parse_args() + for host in args.hosts: print(">>>>> Applying mapping to " + host) res = apply_mappings_to_host(host) print(res) diff --git a/server/elastic/initialize.py b/server/elastic/initialize.py index bfa51bf..5f4c63a 100644 --- a/server/elastic/initialize.py +++ b/server/elastic/initialize.py @@ -1,7 +1,7 @@ from furl import furl from config import config -from elastic.apply_mappings import apply_mappings_to_host +from elastic.apply_mappings import apply_mappings_to_host, get_template from es_factory import get_cluster_config log = config.logger(__file__) @@ -15,13 +15,22 @@ class MissingElasticConfiguration(Exception): pass -def init_es_data(): +def _url_from_host_conf(conf: dict) -> str: + return furl(scheme="http", host=conf["host"], port=conf["port"]).url + + +def init_es_data() -> bool: + """Return True if the db was empty""" hosts_config = get_cluster_config("events").get("hosts") if not hosts_config: raise MissingElasticConfiguration("for cluster 'events'") + empty_db = not get_template(_url_from_host_conf(hosts_config[0]), "events*") + for conf in hosts_config: - host = furl(scheme="http", host=conf["host"], port=conf["port"]).url + host = _url_from_host_conf(conf) log.info(f"Applying mappings to host: {host}") res = apply_mappings_to_host(host) log.info(res) + + return empty_db diff --git a/server/elastic/mappings/events.json b/server/elastic/mappings/events.json index 74708a1..eb33863 100644 --- a/server/elastic/mappings/events.json +++ b/server/elastic/mappings/events.json @@ -1,26 +1,39 @@ { - "template": "events-*", + "index_patterns": "events-*", "settings": { "number_of_shards": 1 }, "mappings": { - "_default_": { - "_source": { - "enabled": true + "_source": { + "enabled": true + }, + "properties": { + "@timestamp": { + "type": "date" }, - "_routing": { - "required": true + "task": { + "type": "keyword" }, - "properties": { - "@timestamp": { "type": "date" }, - "task": { "type": "keyword" }, - "type": { "type": "keyword" }, - "worker": { "type": "keyword" }, - "timestamp": { "type": "date" }, - "iter": { "type": "long" }, - "metric": { "type": "keyword" }, - "variant": { "type": "keyword" }, - "value": { "type": "float" } + "type": { + "type": "keyword" + }, + "worker": { + "type": "keyword" + }, + "timestamp": { + "type": "date" + }, + "iter": { + "type": "long" + }, + "metric": { + "type": "keyword" + }, + "variant": { + "type": "keyword" + }, + "value": { + "type": "float" } } } diff --git 
a/server/elastic/mappings/events_log.json b/server/elastic/mappings/events_log.json index 62a7051..07565b5 100644 --- a/server/elastic/mappings/events_log.json +++ b/server/elastic/mappings/events_log.json @@ -1,11 +1,14 @@ { - "template": "events-log-*", - "order" : 1, + "index_patterns": "events-log-*", + "order": 1, "mappings": { - "_default_": { - "properties": { - "msg": { "type":"text", "index": false }, - "level": { "type":"keyword" } + "properties": { + "msg": { + "type": "text", + "index": false + }, + "level": { + "type": "keyword" } } } diff --git a/server/elastic/mappings/events_plot.json b/server/elastic/mappings/events_plot.json index f5d607d..260700a 100644 --- a/server/elastic/mappings/events_plot.json +++ b/server/elastic/mappings/events_plot.json @@ -1,10 +1,11 @@ { - "template": "events-plot-*", - "order" : 1, + "index_patterns": "events-plot-*", + "order": 1, "mappings": { - "_default_": { - "properties": { - "plot_str": { "type":"text", "index": false } + "properties": { + "plot_str": { + "type": "text", + "index": false } } } diff --git a/server/elastic/mappings/events_training_debug_image.json b/server/elastic/mappings/events_training_debug_image.json index 146fb62..2c0d1e7 100644 --- a/server/elastic/mappings/events_training_debug_image.json +++ b/server/elastic/mappings/events_training_debug_image.json @@ -1,11 +1,13 @@ { - "template": "events-training_debug_image-*", - "order" : 1, + "index_patterns": "events-training_debug_image-*", + "order": 1, "mappings": { - "_default_": { - "properties": { - "key": { "type": "keyword" }, - "url": { "type": "keyword" } + "properties": { + "key": { + "type": "keyword" + }, + "url": { + "type": "keyword" } } } diff --git a/server/elastic/mappings/queue_metrics.json b/server/elastic/mappings/queue_metrics.json index 9f69251..0506c65 100644 --- a/server/elastic/mappings/queue_metrics.json +++ b/server/elastic/mappings/queue_metrics.json @@ -1,26 +1,24 @@ { - "template": "queue_metrics_*", + "index_patterns": "queue_metrics_*", "settings": { "number_of_shards": 1 }, "mappings": { - "metrics": { - "_source": { - "enabled": true + "_source": { + "enabled": true + }, + "properties": { + "timestamp": { + "type": "date" }, - "properties": { - "timestamp": { - "type": "date" - }, - "queue": { - "type": "keyword" - }, - "average_waiting_time": { - "type": "float" - }, - "queue_length": { - "type": "integer" - } + "queue": { + "type": "keyword" + }, + "average_waiting_time": { + "type": "float" + }, + "queue_length": { + "type": "integer" } } } diff --git a/server/elastic/mappings/worker_stats.json b/server/elastic/mappings/worker_stats.json index 3c2437c..6e11a3d 100644 --- a/server/elastic/mappings/worker_stats.json +++ b/server/elastic/mappings/worker_stats.json @@ -1,22 +1,36 @@ { - "template": "worker_stats_*", + "index_patterns": "worker_stats_*", "settings": { "number_of_shards": 1 }, "mappings": { - "stat": { - "_source": { - "enabled": true + "_source": { + "enabled": true + }, + "properties": { + "timestamp": { + "type": "date" }, - "properties": { - "timestamp": { "type": "date" }, - "worker": { "type": "keyword" }, - "category": { "type": "keyword" }, - "metric": { "type": "keyword" }, - "variant": { "type": "keyword" }, - "value": { "type": "float" }, - "unit": { "type": "keyword" }, - "task": { "type": "keyword" } + "worker": { + "type": "keyword" + }, + "category": { + "type": "keyword" + }, + "metric": { + "type": "keyword" + }, + "variant": { + "type": "keyword" + }, + "value": { + "type": "float" + }, + "unit": { + 
"type": "keyword" + }, + "task": { + "type": "keyword" } } } diff --git a/server/mongo/initialize/__init__.py b/server/mongo/initialize/__init__.py index d5bd9ec..6607c33 100644 --- a/server/mongo/initialize/__init__.py +++ b/server/mongo/initialize/__init__.py @@ -24,14 +24,9 @@ def _pre_populate(company_id: str, zip_file: str): else: log.info(f"Pre-populating using {zip_file}") - user_id = _ensure_backend_user( - "__allegroai__", company_id, "Allegro.ai" - ) - PrePopulate.import_from_zip( zip_file, company_id="", - user_id=user_id, artifacts_path=config.get( "apiserver.pre_populate.artifacts_path", None ), @@ -60,7 +55,7 @@ def init_mongo_data() -> bool: _ensure_uuid() - company_id = _ensure_company(log) + company_id = _ensure_company(get_default_company(), "trains", log) _ensure_default_queue(company_id) @@ -82,9 +77,13 @@ def init_mongo_data() -> bool: if fixed_mode: log.info("Fixed users mode is enabled") FixedUser.validate() + + if FixedUser.guest_enabled(): + _ensure_company(FixedUser.get_guest_user().company, "guests", log) + for user in FixedUser.from_config(): try: - ensure_fixed_user(user, company_id, log=log) + ensure_fixed_user(user, log=log) except Exception as ex: log.error(f"Failed creating fixed user {user.name}: {ex}") diff --git a/server/mongo/initialize/pre_populate.py b/server/mongo/initialize/pre_populate.py index e69ed44..0f93c7a 100644 --- a/server/mongo/initialize/pre_populate.py +++ b/server/mongo/initialize/pre_populate.py @@ -1,30 +1,44 @@ import hashlib import importlib import os +import re from collections import defaultdict from datetime import datetime, timezone +from functools import partial from io import BytesIO from itertools import chain from operator import attrgetter from os.path import splitext from pathlib import Path -from typing import Optional, Any, Type, Set, Dict, Sequence, Tuple, BinaryIO, Union +from typing import ( + Optional, + Any, + Type, + Set, + Dict, + Sequence, + Tuple, + BinaryIO, + Union, + Mapping, +) from urllib.parse import unquote, urlparse from zipfile import ZipFile, ZIP_BZIP2 -import attr import mongoengine from boltons.iterutils import chunked_iter from furl import furl from mongoengine import Q from bll.event import EventBLL +from config import config from database.model import EntityVisibility from database.model.model import Model from database.model.project import Project from database.model.task.task import Task, ArtifactModes, TaskStatus from database.utils import get_options from utilities import json +from .user import _ensure_backend_user class PrePopulate: @@ -32,6 +46,7 @@ class PrePopulate: events_file_suffix = "_events" export_tag_prefix = "Exported:" export_tag = f"{export_tag_prefix} %Y-%m-%d %H:%M:%S" + metadata_filename = "metadata.json" class JsonLinesWriter: def __init__(self, file: BinaryIO): @@ -54,26 +69,21 @@ class PrePopulate: self._write("\n" + line) self.empty = False - @attr.s(auto_attribs=True) - class _MapData: - files: Sequence[str] = None - entities: Dict[str, datetime] = None - @staticmethod def _get_last_update_time(entity) -> datetime: return getattr(entity, "last_update", None) or getattr(entity, "created") @classmethod def _check_for_update( - cls, map_file: Path, entities: dict + cls, map_file: Path, entities: dict, metadata_hash: str ) -> Tuple[bool, Sequence[str]]: if not map_file.is_file(): return True, [] files = [] try: - map_data = cls._MapData(**json.loads(map_file.read_text())) - files = map_data.files + map_data = json.loads(map_file.read_text()) + files = map_data.get("files", 
[]) for file in files: if not Path(file).is_file(): return True, files @@ -82,7 +92,7 @@ class PrePopulate: item.id: cls._get_last_update_time(item).replace(tzinfo=timezone.utc) for item in chain.from_iterable(entities.values()) } - old_times = map_data.entities + old_times = map_data.get("entities", {}) if set(new_times.keys()) != set(old_times.keys()): return True, files @@ -90,6 +100,10 @@ class PrePopulate: for id_, new_timestamp in new_times.items(): if new_timestamp != old_times[id_]: return True, files + + if metadata_hash != map_data.get("metadata_hash", ""): + return True, files + except Exception as ex: print("Error reading map file. " + str(ex)) return True, files @@ -98,16 +112,24 @@ class PrePopulate: @classmethod def _write_update_file( - cls, map_file: Path, entities: dict, created_files: Sequence[str] + cls, + map_file: Path, + entities: dict, + created_files: Sequence[str], + metadata_hash: str, ): - map_data = cls._MapData( - files=created_files, - entities={ - entity.id: cls._get_last_update_time(entity) - for entity in chain.from_iterable(entities.values()) - }, + map_file.write_text( + json.dumps( + dict( + files=created_files, + entities={ + entity.id: cls._get_last_update_time(entity) + for entity in chain.from_iterable(entities.values()) + }, + metadata_hash=metadata_hash, + ) + ) ) - map_file.write_text(json.dumps(attr.asdict(map_data))) @staticmethod def _filter_artifacts(artifacts: Sequence[str]) -> Sequence[str]: @@ -117,7 +139,9 @@ class PrePopulate: return True if a.startswith("http"): parsed = urlparse(a) - if parsed.scheme in {"http", "https"} and parsed.port == 8081: + if parsed.scheme in {"http", "https"} and parsed.netloc.endswith( + "8081" + ): return True return False @@ -137,6 +161,7 @@ class PrePopulate: artifacts_path: str = None, task_statuses: Sequence[str] = None, tag_exported_entities: bool = False, + metadata: Mapping[str, Any] = None, ) -> Sequence[str]: if task_statuses and not set(task_statuses).issubset(get_options(TaskStatus)): raise ValueError("Invalid task statuses") @@ -146,11 +171,22 @@ class PrePopulate: experiments=experiments, projects=projects, task_statuses=task_statuses ) + hash_ = hashlib.md5() + if metadata: + meta_str = json.dumps(metadata) + hash_.update(meta_str.encode()) + metadata_hash = hash_.hexdigest() + else: + meta_str, metadata_hash = "", "" + map_file = file.with_suffix(".map") - updated, old_files = cls._check_for_update(map_file, entities) + updated, old_files = cls._check_for_update( + map_file, entities=entities, metadata_hash=metadata_hash + ) if not updated: print(f"There are no updates from the last export") return old_files + for old in old_files: old_path = Path(old) if old_path.is_file(): @@ -158,10 +194,16 @@ class PrePopulate: zip_args = dict(mode="w", compression=ZIP_BZIP2) with ZipFile(file, **zip_args) as zfile: - artifacts, hash_ = cls._export( - zfile, entities, tag_entities=tag_exported_entities + if metadata: + zfile.writestr(cls.metadata_filename, meta_str) + artifacts = cls._export( + zfile, + entities=entities, + hash_=hash_, + tag_entities=tag_exported_entities, ) - file_with_hash = file.with_name(f"{file.stem}_{hash_}{file.suffix}") + + file_with_hash = file.with_name(f"{file.stem}_{hash_.hexdigest()}{file.suffix}") file.replace(file_with_hash) created_files = [str(file_with_hash)] @@ -172,16 +214,43 @@ class PrePopulate: cls._export_artifacts(zfile, artifacts, artifacts_path) created_files.append(str(artifacts_file)) - cls._write_update_file(map_file, entities, created_files) + 
cls._write_update_file( + map_file, + entities=entities, + created_files=created_files, + metadata_hash=metadata_hash, + ) return created_files @classmethod def import_from_zip( - cls, filename: str, company_id: str, user_id: str, artifacts_path: str + cls, + filename: str, + company_id: str, + artifacts_path: str, + user_id: str = "", + user_name: str = "", ): + metadata = None + with ZipFile(filename) as zfile: - cls._import(zfile, company_id, user_id) + try: + with zfile.open(cls.metadata_filename) as f: + metadata = json.loads(f.read()) + if not user_id: + meta_user_id = metadata.get("user_id", "") + meta_user_name = metadata.get("user_name", "") + user_id, user_name = meta_user_id, meta_user_name + except Exception: + pass + + if not user_id: + user_id, user_name = "__allegroai__", "Allegro.ai" + + user_id = _ensure_backend_user(user_id, company_id, user_name) + + cls._import(zfile, company_id, user_id, metadata) if artifacts_path and os.path.isdir(artifacts_path): artifacts_file = Path(filename).with_suffix(".artifacts") @@ -190,6 +259,24 @@ class PrePopulate: with ZipFile(artifacts_file) as zfile: zfile.extractall(artifacts_path) + @classmethod + def update_featured_projects_order(cls): + featured_order = config.get("services.projects.featured_order", []) + + def get_index(p: Project): + for index, entry in enumerate(featured_order): + if ( + entry.get("id", None) == p.id + or entry.get("name", None) == p.name + or ("name_regex" in entry and re.match(entry["name_regex"], p.name)) + ): + return index + return 999 + + for project in Project.get_many_public(projection=["id", "name"]): + featured_index = get_index(project) + Project.objects(id=project.id).update(featured=featured_index) + @staticmethod def _resolve_type( cls: Type[mongoengine.Document], ids: Optional[Sequence[str]] @@ -389,15 +476,14 @@ class PrePopulate: @classmethod def _export( - cls, writer: ZipFile, entities: dict, tag_entities: bool = False - ) -> Tuple[Sequence[str], str]: + cls, writer: ZipFile, entities: dict, hash_, tag_entities: bool = False + ) -> Sequence[str]: """ Export the requested experiments, projects and models and return the list of artifact files Always do the export on sorted items since the order of items influence hash """ artifacts = [] now = datetime.utcnow() - hash_ = hashlib.md5() for cls_ in sorted(entities, key=attrgetter("__name__")): items = sorted(entities[cls_], key=attrgetter("id")) if not items: @@ -423,7 +509,7 @@ class PrePopulate: if tag_entities: cls._add_tag(items, now.strftime(cls.export_tag)) - return artifacts, hash_.hexdigest() + return artifacts @staticmethod def json_lines(file: BinaryIO): @@ -441,7 +527,13 @@ class PrePopulate: yield clean @classmethod - def _import(cls, reader: ZipFile, company_id: str = "", user_id: str = None): + def _import( + cls, + reader: ZipFile, + company_id: str = "", + user_id: str = None, + metadata: Mapping[str, Any] = None, + ): """ Import entities and events from the zip file Start from entities since event import will require the tasks already in DB @@ -451,12 +543,13 @@ class PrePopulate: fi for fi in reader.filelist if not fi.orig_filename.endswith(event_file_ending) + and fi.orig_filename != cls.metadata_filename ) event_files = ( fi for fi in reader.filelist if fi.orig_filename.endswith(event_file_ending) ) for files, reader_func in ( - (entity_files, cls._import_entity), + (entity_files, partial(cls._import_entity, metadata=metadata or {})), (event_files, cls._import_events), ): for file_info in files: @@ -466,11 +559,20 @@ class 
PrePopulate: reader_func(f, full_name, company_id, user_id) @classmethod - def _import_entity(cls, f: BinaryIO, full_name: str, company_id: str, user_id: str): + def _import_entity( + cls, + f: BinaryIO, + full_name: str, + company_id: str, + user_id: str, + metadata: Mapping[str, Any], + ): module_name, _, class_name = full_name.rpartition(".") module = importlib.import_module(module_name) cls_: Type[mongoengine.Document] = getattr(module, class_name) print(f"Writing {cls_.__name__.lower()}s into database") + + override_project_count = 0 for item in cls.json_lines(f): doc = cls_.from_json(item, created=True) if hasattr(doc, "user"): @@ -478,10 +580,24 @@ class PrePopulate: if hasattr(doc, "company"): doc.company = company_id if isinstance(doc, Project): + override_project_name = metadata.get("project_name", None) + if override_project_name: + if override_project_count: + override_project_name = ( + f"{override_project_name} {override_project_count + 1}" + ) + override_project_count += 1 + doc.name = override_project_name + + doc.logo_url = metadata.get("logo_url", None) + doc.logo_blob = metadata.get("logo_blob", None) + cls_.objects(company=company_id, name=doc.name, id__ne=doc.id).update( set__name=f"{doc.name}_{datetime.utcnow().strftime('%Y-%m-%d_%H-%M-%S')}" ) + doc.save() + if isinstance(doc, Task): cls.event_bll.delete_task_events(company_id, doc.id, allow_locked=True) diff --git a/server/mongo/initialize/user.py b/server/mongo/initialize/user.py index a42647d..6861108 100644 --- a/server/mongo/initialize/user.py +++ b/server/mongo/initialize/user.py @@ -58,15 +58,15 @@ def _ensure_backend_user(user_id: str, company_id: str, user_name: str): return user_id -def ensure_fixed_user(user: FixedUser, company_id: str, log: Logger): - if User.objects(id=user.user_id).first(): +def ensure_fixed_user(user: FixedUser, log: Logger): + if User.objects(company=user.company, id=user.user_id).first(): return data = attr.asdict(user) data["id"] = user.user_id data["email"] = f"{user.user_id}@example.com" - data["role"] = Role.user + data["role"] = Role.guest if user.is_guest else Role.user - _ensure_auth_user(user_data=data, company_id=company_id, log=log) + _ensure_auth_user(user_data=data, company_id=user.company, log=log) - return _ensure_backend_user(user.user_id, company_id, user.name) + return _ensure_backend_user(user.user_id, user.company, user.name) diff --git a/server/mongo/initialize/util.py b/server/mongo/initialize/util.py index 087176d..e90c8fc 100644 --- a/server/mongo/initialize/util.py +++ b/server/mongo/initialize/util.py @@ -3,7 +3,6 @@ from uuid import uuid4 from bll.queue import QueueBLL from config import config -from config.info import get_default_company from database.model.company import Company from database.model.queue import Queue from database.model.settings import Settings, SettingKeys @@ -11,13 +10,11 @@ from database.model.settings import Settings, SettingKeys log = config.logger(__file__) -def _ensure_company(log: Logger): - company_id = get_default_company() +def _ensure_company(company_id, company_name, log: Logger): company = Company.objects(id=company_id).only("id").first() if company: return company_id - company_name = "trains" log.info(f"Creating company: {company_name}") company = Company(id=company_id, name=company_name) company.save() diff --git a/server/requirements.txt b/server/requirements.txt index ef479c9..95f804c 100644 --- a/server/requirements.txt +++ b/server/requirements.txt @@ -1,7 +1,8 @@ attrs>=19.1.0 boltons>=19.1.0 +boto3==1.14.13 
dpath>=1.4.2,<2.0 -elasticsearch>=5.0.0,<6.0.0 +elasticsearch>=7.0.0,<8.0.0 fastjsonschema>=2.8 Flask-Compress>=1.4.0 Flask-Cors>=3.0.5 @@ -24,7 +25,7 @@ python-rapidjson>=0.6.3 redis>=2.10.5 related>=0.7.2 requests>=2.13.0 -semantic_version>=2.8.0,<3 +semantic_version>=2.8.3,<3 six tqdm validators>=0.12.4 \ No newline at end of file diff --git a/server/schema/services/auth.conf b/server/schema/services/auth.conf index 46fc509..3755dfb 100644 --- a/server/schema/services/auth.conf +++ b/server/schema/services/auth.conf @@ -328,6 +328,9 @@ fixed_users_mode { description: "Fixed users mode enabled" type: boolean } + migration_warning { + type: boolean + } } } } diff --git a/server/schema/services/events.conf b/server/schema/services/events.conf index 34cb11a..a073cd4 100644 --- a/server/schema/services/events.conf +++ b/server/schema/services/events.conf @@ -848,7 +848,7 @@ description: "Task ID" } samples { - description: "The amount of histogram points to return (0 to return all the points). Optional, the default value is 10000." + description: "The amount of histogram points to return (0 to return all the points). Optional, the default value is 6000." type: integer } key { @@ -886,7 +886,7 @@ ] properties { tasks { - description: "List of task Task IDs" + description: "List of task Task IDs. Maximum amount of tasks is 10" type: array items { type: string @@ -894,7 +894,7 @@ } } samples { - description: "The amount of histogram points to return (0 to return all the points). Optional, the default value is 10000." + description: "The amount of histogram points to return. Optional, the default value is 6000" type: integer } key { diff --git a/server/schema/services/models.conf b/server/schema/services/models.conf index 1e23c2b..646c699 100644 --- a/server/schema/services/models.conf +++ b/server/schema/services/models.conf @@ -1,54 +1,366 @@ -{ - _description: """This service provides a management interface for models (results of training tasks) stored in the system.""" - _definitions { - multi_field_pattern_data { - type: object - properties { - pattern { - description: "Pattern string (regex)" - type: string - } - fields { - description: "List of field names" - type: array - items { type: string } - } +_description: """This service provides a management interface for models (results of training tasks) stored in the system.""" +_definitions { + multi_field_pattern_data { + type: object + properties { + pattern { + description: "Pattern string (regex)" + type: string + } + fields { + description: "List of field names" + type: array + items { type: string } } } - model { + } + model { + type: object + properties { + id { + description: "Model id" + type: string + } + name { + description: "Model name" + type: string + } + user { + description: "Associated user id" + type: string + } + company { + description: "Company id" + type: string + } + created { + description: "Model creation time" + type: string + format: "date-time" + } + task { + description: "Task ID of task in which the model was created" + type: string + } + parent { + description: "Parent model ID" + type: string + } + project { + description: "Associated project ID" + type: string + } + comment { + description: "Model comment" + type: string + } + tags { + type: array + description: "User-defined tags" + items { type: string } + } + system_tags { + type: array + description: "System tags. This field is reserved for system use, please don't use it." 
+ items {type: string} + } + framework { + description: "Framework on which the model is based. Should be identical to the framework of the task which created the model" + type: string + } + design { + description: "Json object representing the model design. Should be identical to the network design of the task which created the model" + type: object + additionalProperties: true + } + labels { + description: "Json object representing the ids of the labels in the model. The keys are the layers' names and the values are the ids." + type: object + additionalProperties { type: integer } + } + uri { + description: "URI for the model, pointing to the destination storage." + type: string + } + ready { + description: "Indication if the model is final and can be used by other tasks" + type: boolean + } + ui_cache { + description: "UI cache for this model" + type: object + additionalProperties: true + } + } + } +} + +get_by_id { + "2.1" { + description: "Gets model information" + request { type: object + required: [ model ] properties { - id { + model { description: "Model id" type: string } + } + } + + response { + type: object + properties { + model { + description: "Model info" + "$ref": "#/definitions/model" + } + } + } + } +} + +get_by_task_id { + "2.1" { + description: "Gets model information" + request { + type: object + properties { + task { + description: "Task id" + type: string + } + } + } + response { + type: object + properties { + model { + description: "Model info" + "$ref": "#/definitions/model" + } + } + } + } +} +get_all_ex { + internal: true + "2.1": ${get_all."2.1"} +} +get_all { + "2.1" { + description: "Get all models" + request { + type: object + properties { name { - description: "Model name" + description: "Get only models whose name matches this pattern (python regular expression syntax)" type: string } user { - description: "Associated user id" + description: "List of user IDs used to filter results by the model's creating user" + type: array + items { type: string } + } + ready { + description: "Indication whether to retrieve only models that are marked ready If not supplied returns both ready and not-ready projects." + type: boolean + } + tags { + description: "User-defined tags list used to filter results. Prepend '-' to tag name to indicate exclusion" + type: array + items { type: string } + } + system_tags { + description: "System tags list used to filter results. Prepend '-' to system tag name to indicate exclusion" + type: array + items { type: string } + } + only_fields { + description: "List of model field names (if applicable, nesting is supported using '.'). If provided, this list defines the query's projection (only these fields will be returned for each result entry)" + type: array + items { type: string } + } + page { + description: "Page number, returns a specific page out of the resulting list of models" + type: integer + minimum: 0 + } + page_size { + description: "Page size, specifies the number of results returned in each page (last page may contain fewer results)" + type: integer + minimum: 1 + } + project { + description: "List of associated project IDs" + type: array + items { type: string } + } + order_by { + description: "List of field names to order by. When search_text is used, '@text_score' can be used as a field representing the text score of returned documents. Use '-' prefix to specify descending order. 
Optional, recommended when using page" + type: array + items { type: string } + } + task { + description: "List of associated task IDs" + type: array + items { type: string } + } + id { + description: "List of model IDs" + type: array + items { type: string } + } + search_text { + description: "Free text search query" type: string } - company { - description: "Company id" + framework { + description: "List of frameworks" + type: array + items { type: string } + } + uri { + description: "List of model URIs" + type: array + items { type: string } + } + _all_ { + description: "Multi-field pattern condition (all fields match pattern)" + "$ref": "#/definitions/multi_field_pattern_data" + } + _any_ { + description: "Multi-field pattern condition (any field matches pattern)" + "$ref": "#/definitions/multi_field_pattern_data" + } + } + dependencies { + page: [ page_size ] + } + } + response { + type: object + properties { + models: { + description: "Models list" + type: array + items { "$ref": "#/definitions/model" } + } + } + } + } +} +get_frameworks { + "2.8" { + description: "Get the list of frameworks used in the company models" + request { + type: object + properties { + projects { + description: "The list of projects which models will be analyzed. If not passed or empty then all the company and public models will be analyzed" + type: array + items: {type: string} + } + } + } + response { + type: object + properties { + frameworks { + description: "Unique list of the frameworks used in the company models" + type: array + items: {type: string} + } + } + } + } +} +update_for_task { + "2.1" { + description: "Create or update a new model for a task" + request { + type: object + required: [ + task + ] + properties { + task { + description: "Task id" + type: string + } + uri { + description: "URI for the model. Exactly one of uri or override_model_id is a required." + type: string + } + name { + description: "Model name Unique within the company." + type: string + } + comment { + description: "Model comment" + type: string + } + tags { + type: array + description: "User-defined tags" + items { type: string } + } + system_tags { + type: array + description: "System tags. This field is reserved for system use, please don't use it." + items {type: string} + } + override_model_id { + description: "Override model ID. If provided, this model is updated in the task. Exactly one of override_model_id or uri is required." + type: string + } + iteration { + description: "Iteration (used to update task statistics)" + type: integer + } + } + } + response { + type: object + properties { + id { + description: "ID of the model" type: string } created { - description: "Model creation time" - type: string - format: "date-time" + description: "Was the model created" + type: boolean } - task { - description: "Task ID of task in which the model was created" + updated { + description: "Number of models updated (0 or 1)" + type: integer + } + fields { + description: "Updated fields names and values" + type: object + additionalProperties: true + } + } + } + } +} +create { + "2.1" { + description: "Create a new model not associated with a task" + request { + type: object + required: [ + uri + name + ] + properties { + uri { + description: "URI for the model" type: string } - parent { - description: "Parent model ID" - type: string - } - project { - description: "Associated project ID" + name { + description: "Model name Unique within the company." 
type: string } comment { @@ -66,595 +378,281 @@ items {type: string} } framework { - description: "Framework on which the model is based. Should be identical to the framework of the task which created the model" + description: "Framework on which the model is based. Case insensitive. Should be identical to the framework of the task which created the model." type: string } design { - description: "Json object representing the model design. Should be identical to the network design of the task which created the model" + description: "Json[d] object representing the model design. Should be identical to the network design of the task which created the model" type: object additionalProperties: true } labels { - description: "Json object representing the ids of the labels in the model. The keys are the layers' names and the values are the ids." + description: "Json object" type: object additionalProperties { type: integer } } - uri { - description: "URI for the model, pointing to the destination storage." + ready { + description: "Indication if the model is final and can be used by other tasks. Default is false." + type: boolean + default: false + } + public { + description: "Create a public model Default is false." + type: boolean + default: false + } + project { + description: "Project to which to model belongs" type: string } + parent { + description: "Parent model" + type: string + } + task { + description: "Associated task ID" + type: string + } + } + } + response { + type: object + properties { + id { + description: "ID of the model" + type: string + } + created { + description: "Was the model created" + type: boolean + } + } + } + } +} +edit { + "2.1" { + description: "Edit an existing model" + request { + type: object + required: [ + model + ] + properties { + model { + description: "Model ID" + type: string + } + uri { + description: "URI for the model" + type: string + } + name { + description: "Model name Unique within the company." + type: string + } + comment { + description: "Model comment" + type: string + } + tags { + type: array + description: "User-defined tags" + items { type: string } + } + system_tags { + type: array + description: "System tags. This field is reserved for system use, please don't use it." + items {type: string} + } + framework { + description: "Framework on which the model is based. Case insensitive. Should be identical to the framework of the task which created the model." + type: string + } + design { + description: "Json[d] object representing the model design. 
Should be identical to the network design of the task which created the model" + type: object + additionalProperties: true + } + labels { + description: "Json object" + type: object + additionalProperties { type: integer } + } ready { description: "Indication if the model is final and can be used by other tasks" type: boolean } + project { + description: "Project to which to model belongs" + type: string + } + parent { + description: "Parent model" + type: string + } + task { + description: "Associated task ID" + type: string + } + iteration { + description: "Iteration (used to update task statistics)" + type: integer + } + } + } + response { + type: object + properties { + updated { + description: "Number of models updated (0 or 1)" + type: integer + enum: [0, 1] + } + fields { + description: "Updated fields names and values" + type: object + additionalProperties: true + } + } + } + } +} +update { + "2.1" { + description: "Update a model" + request { + type: object + required: [ model ] + properties { + model { + description: "Model id" + type: string + } + name { + description: "Model name Unique within the company." + type: string + } + comment { + description: "Model comment" + type: string + } + tags { + type: array + description: "User-defined tags" + items { type: string } + } + system_tags { + type: array + description: "System tags. This field is reserved for system use, please don't use it." + items {type: string} + } + ready { + description: "Indication if the model is final and can be used by other tasks Default is false." + type: boolean + default: false + } + created { + description: "Model creation time (UTC) " + type: string + format: "date-time" + } ui_cache { description: "UI cache for this model" type: object additionalProperties: true } - } - } - } + project { + description: "Project to which to model belongs" + type: string + } + task { + description: "Associated task ID" + type: "string" + } + iteration { + description: "Iteration (used to update task statistics if an associated task is reported)" + type: integer + } - get_by_id { - "2.1" { - description: "Gets model information" - request { - type: object - required: [ model ] - properties { - model { - description: "Model id" - type: string - } - } } - - response { - type: object - properties { - model { - description: "Model info" - "$ref": "#/definitions/model" - } + } + response { + type: object + properties { + updated { + description: "Number of models updated (0 or 1)" + type: integer + enum: [0, 1] + } + fields { + description: "Updated fields names and values" + type: object + additionalProperties: true } } } } - - get_by_task_id { - "2.1" { - description: "Gets model information" - request { - type: object - properties { - task { - description: "Task id" - type: string - } +} +set_ready { + "2.1" { + description: "Set the model ready flag to True. If the model is an output model of a task then try to publish the task." + request { + type: object + required: [ model ] + properties { + model { + description: "Model id" + type: string } - } - response { - type: object - properties { - model { - description: "Model info" - "$ref": "#/definitions/model" - } + force_publish_task { + description: "Publish the associated task (if exists) even if it is not in the 'stopped' state. Optional, the default value is False." + type: boolean + } + publish_task { + description: "Indicates that the associated task (if exists) should be published. Optional, the default value is True." 
+ type: boolean } } } - } - get_all_ex { - internal: true - "2.1": ${get_all."2.1"} - } - get_all { - "2.1" { - description: "Get all models" - request { - type: object - properties { - name { - description: "Get only models whose name matches this pattern (python regular expression syntax)" - type: string - } - user { - description: "List of user IDs used to filter results by the model's creating user" - type: array - items { type: string } - } - ready { - description: "Indication whether to retrieve only models that are marked ready If not supplied returns both ready and not-ready projects." - type: boolean - } - tags { - description: "User-defined tags list used to filter results. Prepend '-' to tag name to indicate exclusion" - type: array - items { type: string } - } - system_tags { - description: "System tags list used to filter results. Prepend '-' to system tag name to indicate exclusion" - type: array - items { type: string } - } - only_fields { - description: "List of model field names (if applicable, nesting is supported using '.'). If provided, this list defines the query's projection (only these fields will be returned for each result entry)" - type: array - items { type: string } - } - page { - description: "Page number, returns a specific page out of the resulting list of models" - type: integer - minimum: 0 - } - page_size { - description: "Page size, specifies the number of results returned in each page (last page may contain fewer results)" - type: integer - minimum: 1 - } - project { - description: "List of associated project IDs" - type: array - items { type: string } - } - order_by { - description: "List of field names to order by. When search_text is used, '@text_score' can be used as a field representing the text score of returned documents. Use '-' prefix to specify descending order. Optional, recommended when using page" - type: array - items { type: string } - } - task { - description: "List of associated task IDs" - type: array - items { type: string } - } - id { - description: "List of model IDs" - type: array - items { type: string } - } - search_text { - description: "Free text search query" - type: string - } - framework { - description: "List of frameworks" - type: array - items { type: string } - } - uri { - description: "List of model URIs" - type: array - items { type: string } - } - _all_ { - description: "Multi-field pattern condition (all fields match pattern)" - "$ref": "#/definitions/multi_field_pattern_data" - } - _any_ { - description: "Multi-field pattern condition (any field matches pattern)" - "$ref": "#/definitions/multi_field_pattern_data" - } + response { + type: object + properties { + updated { + description: "Number of models updated (0 or 1)" + type: integer + enum: [0, 1] } - dependencies { - page: [ page_size ] - } - } - response { - type: object - properties { - models: { - description: "Models list" - type: array - items { "$ref": "#/definitions/model" } - } - } - } - } - } - get_frameworks { - "2.8" { - description: "Get the list of frameworks used in the company models" - request { - type: object - properties { - projects { - description: "The list of projects which models will be analyzed. 
If not passed or empty then all the company and public models will be analyzed" - type: array - items: {type: string} - } - } - } - response { - type: object - properties { - frameworks { - description: "Unique list of the frameworks used in the company models" - type: array - items: {type: string} - } - } - } - } - } - update_for_task { - "2.1" { - description: "Create or update a new model for a task" - request { - type: object - required: [ - task - ] - properties { - task { - description: "Task id" - type: string - } - uri { - description: "URI for the model. Exactly one of uri or override_model_id is a required." - type: string - } - name { - description: "Model name Unique within the company." - type: string - } - comment { - description: "Model comment" - type: string - } - tags { - type: array - description: "User-defined tags" - items { type: string } - } - system_tags { - type: array - description: "System tags. This field is reserved for system use, please don't use it." - items {type: string} - } - override_model_id { - description: "Override model ID. If provided, this model is updated in the task. Exactly one of override_model_id or uri is required." - type: string - } - iteration { - description: "Iteration (used to update task statistics)" - type: integer - } - } - } - response { - type: object - properties { - id { - description: "ID of the model" - type: string - } - created { - description: "Was the model created" - type: boolean - } - updated { - description: "Number of models updated (0 or 1)" - type: integer - } - fields { - description: "Updated fields names and values" - type: object - additionalProperties: true - } - } - } - } - } - create { - "2.1" { - description: "Create a new model not associated with a task" - request { - type: object - required: [ - uri - name - ] - properties { - uri { - description: "URI for the model" - type: string - } - name { - description: "Model name Unique within the company." - type: string - } - comment { - description: "Model comment" - type: string - } - tags { - type: array - description: "User-defined tags" - items { type: string } - } - system_tags { - type: array - description: "System tags. This field is reserved for system use, please don't use it." - items {type: string} - } - framework { - description: "Framework on which the model is based. Case insensitive. Should be identical to the framework of the task which created the model." - type: string - } - design { - description: "Json[d] object representing the model design. Should be identical to the network design of the task which created the model" - type: object - additionalProperties: true - } - labels { - description: "Json object" - type: object - additionalProperties { type: integer } - } - ready { - description: "Indication if the model is final and can be used by other tasks. Default is false." - type: boolean - default: false - } - public { - description: "Create a public model Default is false." 
- type: boolean - default: false - } - project { - description: "Project to which to model belongs" - type: string - } - parent { - description: "Parent model" - type: string - } - task { - description: "Associated task ID" - type: string - } - } - } - response { - type: object - properties { - id { - description: "ID of the model" - type: string - } - created { - description: "Was the model created" - type: boolean - } - } - } - } - } - edit { - "2.1" { - description: "Edit an existing model" - request { - type: object - required: [ - model - ] - properties { - model { - description: "Model ID" - type: string - } - uri { - description: "URI for the model" - type: string - } - name { - description: "Model name Unique within the company." - type: string - } - comment { - description: "Model comment" - type: string - } - tags { - type: array - description: "User-defined tags" - items { type: string } - } - system_tags { - type: array - description: "System tags. This field is reserved for system use, please don't use it." - items {type: string} - } - framework { - description: "Framework on which the model is based. Case insensitive. Should be identical to the framework of the task which created the model." - type: string - } - design { - description: "Json[d] object representing the model design. Should be identical to the network design of the task which created the model" - type: object - additionalProperties: true - } - labels { - description: "Json object" - type: object - additionalProperties { type: integer } - } - ready { - description: "Indication if the model is final and can be used by other tasks" - type: boolean - } - project { - description: "Project to which to model belongs" - type: string - } - parent { - description: "Parent model" - type: string - } - task { - description: "Associated task ID" - type: string - } - iteration { - description: "Iteration (used to update task statistics)" - type: integer - } - } - } - response { - type: object - properties { - updated { - description: "Number of models updated (0 or 1)" - type: integer - enum: [0, 1] - } - fields { - description: "Updated fields names and values" - type: object - additionalProperties: true - } - } - } - } - } - update { - "2.1" { - description: "Update a model" - request { - type: object - required: [ model ] - properties { - model { - description: "Model id" - type: string - } - name { - description: "Model name Unique within the company." - type: string - } - comment { - description: "Model comment" - type: string - } - tags { - type: array - description: "User-defined tags" - items { type: string } - } - system_tags { - type: array - description: "System tags. This field is reserved for system use, please don't use it." - items {type: string} - } - ready { - description: "Indication if the model is final and can be used by other tasks Default is false." 
- type: boolean - default: false - } - created { - description: "Model creation time (UTC) " - type: string - format: "date-time" - } - ui_cache { - description: "UI cache for this model" - type: object - additionalProperties: true - } - project { - description: "Project to which to model belongs" - type: string - } - task { - description: "Associated task ID" - type: "string" - } - iteration { - description: "Iteration (used to update task statistics if an associated task is reported)" - type: integer - } - - } - } - response { - type: object - properties { - updated { - description: "Number of models updated (0 or 1)" - type: integer - enum: [0, 1] - } - fields { - description: "Updated fields names and values" - type: object - additionalProperties: true - } - } - } - } - } - set_ready { - "2.1" { - description: "Set the model ready flag to True. If the model is an output model of a task then try to publish the task." - request { - type: object - required: [ model ] - properties { - model { - description: "Model id" - type: string - } - force_publish_task { - description: "Publish the associated task (if exists) even if it is not in the 'stopped' state. Optional, the default value is False." - type: boolean - } - publish_task { - description: "Indicates that the associated task (if exists) should be published. Optional, the default value is True." - type: boolean - } - } - } - response { - type: object - properties { - updated { - description: "Number of models updated (0 or 1)" - type: integer - enum: [0, 1] - } - published_task { - description: "Result of publishing of the model's associated task (if exists). Returned only if the task was published successfully as part of the model publishing." - type: object - properties { - id { - description: "Task id" - type: string - } - data { - description: "Data returned from the task publishing operation." - type: object - properties { - committed_versions_results { - description: "Committed versions results" - type: array - items { - type: object - additionalProperties: true - } - } - updated { - description: "Number of tasks updated (0 or 1)" - type: integer - enum: [ 0, 1 ] - } - fields { - description: "Updated fields names and values" + published_task { + description: "Result of publishing of the model's associated task (if exists). Returned only if the task was published successfully as part of the model publishing." + type: object + properties { + id { + description: "Task id" + type: string + } + data { + description: "Data returned from the task publishing operation." + type: object + properties { + committed_versions_results { + description: "Committed versions results" + type: array + items { type: object additionalProperties: true } } + updated { + description: "Number of tasks updated (0 or 1)" + type: integer + enum: [ 0, 1 ] + } + fields { + description: "Updated fields names and values" + type: object + additionalProperties: true + } } } } @@ -662,37 +660,87 @@ } } } - delete { - "2.1" { - description: "Delete a model." - request { - required: [ - model - ] - type: object - properties { - model { - description: "Model ID" - type: string - } - force { - description: """Force. Required if there are tasks that use the model as an execution model, or if the model's creating task is published. - """ - type: boolean - } +} +delete { + "2.1" { + description: "Delete a model." + request { + required: [ + model + ] + type: object + properties { + model { + description: "Model ID" + type: string + } + force { + description: """Force. 
Required if there are tasks that use the model as an execution model, or if the model's creating task is published. + """ + type: boolean + } + } + } + response { + type: object + properties { + deleted { + description: "Indicates whether the model was deleted" + type: boolean + } + + } + } + } +} + +make_public { + "2.9" { + description: """Convert company models to public""" + request { + type: object + properties { + ids { + description: "Ids of the models to convert" + type: array + items { type: string} } } - response { - type: object - properties { - deleted { - description: "Indicates whether the model was deleted" - type: boolean - } - + } + response { + type: object + properties { + updated { + description: "Number of models updated" + type: integer } } } } } + +make_private { + "2.9" { + description: """Convert public models to private""" + request { + type: object + properties { + ids { + description: "Ids of the models to convert. Only the models originated by the company can be converted" + type: array + items { type: string} + } + } + } + response { + type: object + properties { + updated { + description: "Number of models updated" + type: integer + } + } + } + } +} \ No newline at end of file diff --git a/server/schema/services/projects.conf b/server/schema/services/projects.conf index 5cd5038..e3cb4fb 100644 --- a/server/schema/services/projects.conf +++ b/server/schema/services/projects.conf @@ -573,6 +573,7 @@ get_hyper_parameters { } } } + get_task_tags { "2.8" { description: "Get user and system tags used for the tasks under the specified projects" @@ -580,10 +581,61 @@ get_task_tags { response = ${_definitions.tags_response} } } + get_model_tags { "2.8" { description: "Get user and system tags used for the models under the specified projects" request = ${_definitions.tags_request} response = ${_definitions.tags_response} } +} + +make_public { + "2.9" { + description: """Convert company projects to public""" + request { + type: object + properties { + ids { + description: "Ids of the projects to convert" + type: array + items { type: string} + } + } + } + response { + type: object + properties { + updated { + description: "Number of projects updated" + type: integer + } + } + } + } +} + +make_private { + "2.9" { + description: """Convert public projects to private""" + request { + type: object + properties { + ids { + description: "Ids of the projects to convert. Only the projects originated by the company can be converted" + type: array + items { type: string} + } + } + } + response { + type: object + properties { + updated { + description: "Number of projects updated" + type: integer + } + } + } + } } \ No newline at end of file diff --git a/server/schema/services/tasks.conf b/server/schema/services/tasks.conf index 2f0d10b..7b4c747 100644 --- a/server/schema/services/tasks.conf +++ b/server/schema/services/tasks.conf @@ -1441,4 +1441,54 @@ add_or_update_artifacts { } } } +} + +make_public { + "2.9" { + description: """Convert company tasks to public""" + request { + type: object + properties { + ids { + description: "Ids of the tasks to convert" + type: array + items { type: string} + } + } + } + response { + type: object + properties { + updated { + description: "Number of tasks updated" + type: integer + } + } + } + } +} + +make_private { + "2.9" { + description: """Convert public tasks to private""" + request { + type: object + properties { + ids { + description: "Ids of the tasks to convert. 
Only the tasks originated by the company can be converted" + type: array + items { type: string} + } + } + } + response { + type: object + properties { + updated { + description: "Number of tasks updated" + type: integer + } + } + } + } } \ No newline at end of file diff --git a/server/schema/services/workers.conf b/server/schema/services/workers.conf index 1a18498..81ef288 100644 --- a/server/schema/services/workers.conf +++ b/server/schema/services/workers.conf @@ -135,6 +135,10 @@ description: "Task currently being run by the worker" "$ref": "#/definitions/current_task_entry" } + project { + description: "Project in which currently executing task resides" + "$ref": "#/definitions/id_name_entry" + } queue { description: "Queue from which running task was taken" "$ref": "#/definitions/queue_entry" @@ -151,11 +155,11 @@ type: object properties { id { - description: "Worker ID" + description: "ID" type: string } name { - description: "Worker name" + description: "Name" type: string } } diff --git a/server/server.py b/server/server.py index ec847f9..4458364 100644 --- a/server/server.py +++ b/server/server.py @@ -10,7 +10,7 @@ from werkzeug.exceptions import BadRequest import database from apierrors.base import BaseError from bll.statistics.stats_reporter import StatisticsReporter -from config import config +from config import config, info from elastic.initialize import init_es_data from mongo.initialize import init_mongo_data, pre_populate_data from service_repo import ServiceRepo, APICall @@ -39,9 +39,11 @@ database.initialize() hosts_string = ";".join(sorted(database.get_hosts())) key = "db_init_" + md5(hosts_string.encode()).hexdigest() with distributed_lock(key, timeout=config.get("apiserver.db_init_timout", 30)): - print(key) - init_es_data() + empty_es = init_es_data() empty_db = init_mongo_data() +if empty_es and not empty_db: + log.info(f"ES database seems not migrated") + info.missed_es_upgrade = True if empty_db and config.get("apiserver.pre_populate.enabled", False): pre_populate_data() diff --git a/server/service_repo/auth/auth.py b/server/service_repo/auth/auth.py index 9401d46..d2c2483 100644 --- a/server/service_repo/auth/auth.py +++ b/server/service_repo/auth/auth.py @@ -69,6 +69,10 @@ def authorize_credentials(auth_data, service, action, call_data_items): if fixed_user: if secret_key != fixed_user.password: raise errors.unauthorized.InvalidCredentials('bad username or password') + + if fixed_user.is_guest and not FixedUser.is_guest_endpoint(service, action): + raise errors.unauthorized.InvalidCredentials('endpoint not allowed for guest') + query = Q(id=fixed_user.user_id) with TimingContext("mongo", "user_by_cred"), translate_errors_context('authorizing request'): diff --git a/server/service_repo/auth/fixed_user.py b/server/service_repo/auth/fixed_user.py index a4188ba..b814864 100644 --- a/server/service_repo/auth/fixed_user.py +++ b/server/service_repo/auth/fixed_user.py @@ -1,14 +1,12 @@ import hashlib from functools import lru_cache -from typing import Sequence, TypeVar +from typing import Sequence, Optional import attr from config import config from config.info import get_default_company -T = TypeVar("T", bound="FixedUser") - class FixedUsersError(Exception): pass @@ -21,6 +19,8 @@ class FixedUser: name: str company: str = get_default_company() + is_guest: bool = False + def __attrs_post_init__(self): self.user_id = hashlib.md5(f"{self.company}:{self.username}".encode()).hexdigest() @@ -28,6 +28,10 @@ class FixedUser: def enabled(cls): return 
config.get("apiserver.auth.fixed_users.enabled", False) + @classmethod + def guest_enabled(cls): + return cls.enabled() and config.get("services.auth.fixed_users.guest.enabled", False) + @classmethod def validate(cls): if not cls.enabled(): @@ -39,18 +43,50 @@ class FixedUser: ) @classmethod - @lru_cache() - def from_config(cls) -> Sequence[T]: - return [ + # @lru_cache() + def from_config(cls) -> Sequence["FixedUser"]: + users = [ cls(**user) for user in config.get("apiserver.auth.fixed_users.users", []) ] + if cls.guest_enabled(): + users.insert( + 0, + cls.get_guest_user() + ) + + return users + @classmethod @lru_cache() - def get_by_username(cls, username) -> T: + def get_by_username(cls, username) -> "FixedUser": return next( (user for user in cls.from_config() if user.username == username), None ) + @classmethod + @lru_cache() + def is_guest_endpoint(cls, service, action): + """ + Validate a potential guest user, + This method will verify the user is indeed the guest user, + and that the guest user may access the service/action using its username/password + """ + return any( + ep == ".".join((service, action)) + for ep in config.get("services.auth.fixed_users.guest.allow_endpoints", []) + ) + + @classmethod + def get_guest_user(cls) -> Optional["FixedUser"]: + if cls.guest_enabled(): + return cls( + is_guest=True, + username=config.get("services.auth.fixed_users.guest.username"), + password=config.get("services.auth.fixed_users.guest.password"), + name=config.get("services.auth.fixed_users.guest.name"), + company=config.get("services.auth.fixed_users.guest.default_company"), + ) + def __hash__(self): return hash(self.user_id) diff --git a/server/services/auth.py b/server/services/auth.py index 26b5771..8e0eb26 100644 --- a/server/services/auth.py +++ b/server/services/auth.py @@ -16,7 +16,7 @@ from apimodels.auth import ( ) from apimodels.base import UpdateResponse from bll.auth import AuthBLL -from config import config +from config import config, info from database.errors import translate_errors_context from database.model.auth import User from service_repo import APICall, endpoint @@ -176,4 +176,17 @@ def update(call, company_id, _): @endpoint("auth.fixed_users_mode") def fixed_users_mode(call: APICall, *_, **__): - call.result.data = dict(enabled=FixedUser.enabled()) + data = { + "enabled": FixedUser.enabled(), + "migration_warning": info.missed_es_upgrade, + "guest": { + "enabled": FixedUser.guest_enabled(), + } + } + guest_user = FixedUser.get_guest_user() + if guest_user: + data["guest"]["name"] = guest_user.name + data["guest"]["username"] = guest_user.username + data["guest"]["password"] = guest_user.password + + call.result.data = data diff --git a/server/services/models.py b/server/services/models.py index 1a94762..5429866 100644 --- a/server/services/models.py +++ b/server/services/models.py @@ -5,7 +5,8 @@ from mongoengine import Q, EmbeddedDocument import database from apierrors import errors -from apimodels.base import UpdateResponse +from apierrors.errors.bad_request import InvalidModelId +from apimodels.base import UpdateResponse, MakePublicRequest from apimodels.models import ( CreateModelRequest, CreateModelResponse, @@ -467,3 +468,21 @@ def update(call: APICall, company_id, _): if del_count: _reset_cached_tags(company_id, projects=[model.project]) call.result.data = dict(deleted=del_count > 0) + + +@endpoint("models.make_public", min_version="2.9", request_data_model=MakePublicRequest) +def make_public(call: APICall, company_id, request: MakePublicRequest): + with 
translate_errors_context(): + call.result.data = Model.set_public( + company_id, ids=request.ids, invalid_cls=InvalidModelId, enabled=True + ) + + +@endpoint( + "models.make_private", min_version="2.9", request_data_model=MakePublicRequest +) +def make_public(call: APICall, company_id, request: MakePublicRequest): + with translate_errors_context(): + call.result.data = Model.set_public( + company_id, request.ids, invalid_cls=InvalidModelId, enabled=False + ) diff --git a/server/services/projects.py b/server/services/projects.py index e72a282..e51cda5 100644 --- a/server/services/projects.py +++ b/server/services/projects.py @@ -8,7 +8,8 @@ from mongoengine import Q import database from apierrors import errors -from apimodels.base import UpdateResponse +from apierrors.errors.bad_request import InvalidProjectId +from apimodels.base import UpdateResponse, MakePublicRequest from apimodels.projects import ( GetHyperParamReq, GetHyperParamResp, @@ -422,3 +423,23 @@ def get_tags(call: APICall, company, request: ProjectTagsRequest): projects=request.projects, ) call.result.data = get_tags_response(ret) + + +@endpoint( + "projects.make_public", min_version="2.9", request_data_model=MakePublicRequest +) +def make_public(call: APICall, company_id, request: MakePublicRequest): + with translate_errors_context(): + call.result.data = Project.set_public( + company_id, ids=request.ids, invalid_cls=InvalidProjectId, enabled=True + ) + + +@endpoint( + "projects.make_private", min_version="2.9", request_data_model=MakePublicRequest +) +def make_public(call: APICall, company_id, request: MakePublicRequest): + with translate_errors_context(): + call.result.data = Project.set_public( + company_id, ids=request.ids, invalid_cls=InvalidProjectId, enabled=False + ) diff --git a/server/services/tasks.py b/server/services/tasks.py index 79d15bb..3c40301 100644 --- a/server/services/tasks.py +++ b/server/services/tasks.py @@ -11,7 +11,8 @@ from mongoengine.queryset.transform import COMPARISON_OPERATORS from pymongo import UpdateOne from apierrors import errors, APIError -from apimodels.base import UpdateResponse, IdResponse +from apierrors.errors.bad_request import InvalidTaskId +from apimodels.base import UpdateResponse, IdResponse, MakePublicRequest from apimodels.tasks import ( StartedResponse, ResetResponse, @@ -78,10 +79,24 @@ def set_task_status_from_call( task = TaskBLL.get_task_with_access( request.task, company_id=company_id, - only=tuple({"status", "project"} | fields_resolver.get_names()), + only=tuple( + {"status", "project", "started", "duration"} | fields_resolver.get_names() + ), requires_write_access=True, ) + if "duration" not in fields_resolver.get_names(): + if new_status == Task.started: + fields_resolver.add_fields(min__duration=max(0, task.duration or 0)) + elif new_status in ( + TaskStatus.completed, + TaskStatus.failed, + TaskStatus.stopped, + ): + fields_resolver.add_fields( + duration=int((task.started - datetime.utcnow()).total_seconds()) + ) + status_reason = request.status_reason status_message = request.status_message force = request.force @@ -354,9 +369,7 @@ def _update_cached_tags(company: str, project: str, fields: dict): def _reset_cached_tags(company: str, projects: Sequence[str]): - org_bll.reset_tags( - company, Tags.Task, projects=projects - ) + org_bll.reset_tags(company, Tags.Task, projects=projects) @endpoint( @@ -573,9 +586,7 @@ def edit(call: APICall, company_id, req_model: UpdateRequest): if updated: new_project = fixed_fields.get("project", task.project) if new_project != 
task.project: - _reset_cached_tags( - company_id, projects=[new_project, task.project] - ) + _reset_cached_tags(company_id, projects=[new_project, task.project]) else: _update_cached_tags( company_id, project=task.project, fields=fixed_fields @@ -1005,3 +1016,19 @@ def add_or_update_artifacts( task_id=request.task, company_id=company_id, artifacts=request.artifacts ) call.result.data_model = AddOrUpdateArtifactsResponse(added=added, updated=updated) + + +@endpoint("tasks.make_public", min_version="2.9", request_data_model=MakePublicRequest) +def make_public(call: APICall, company_id, request: MakePublicRequest): + with translate_errors_context(): + call.result.data = Task.set_public( + company_id, request.ids, invalid_cls=InvalidTaskId, enabled=True + ) + + +@endpoint("tasks.make_private", min_version="2.9", request_data_model=MakePublicRequest) +def make_public(call: APICall, company_id, request: MakePublicRequest): + with translate_errors_context(): + call.result.data = Task.set_public( + company_id, request.ids, invalid_cls=InvalidTaskId, enabled=False + ) diff --git a/server/tests/automated/test_models.py b/server/tests/automated/test_models.py index e6d9747..126a5f4 100644 --- a/server/tests/automated/test_models.py +++ b/server/tests/automated/test_models.py @@ -1,3 +1,4 @@ +from apierrors.errors.bad_request import InvalidModelId from tests.automated import TestService MODEL_CANNOT_BE_UPDATED_CODES = (400, 203) @@ -7,7 +8,7 @@ IN_PROGRESS = "in_progress" class TestModelsService(TestService): - def setUp(self, version="2.8"): + def setUp(self, version="2.9"): super().setUp(version=version) def test_publish_output_model_running_task(self): @@ -197,6 +198,28 @@ class TestModelsService(TestService): res = self.api.models.get_frameworks(projects=[project]) self.assertEqual([], res.frameworks) + def test_make_public(self): + m1 = self._create_model(name="public model test") + + # model with company_origin not set to the current company cannot be converted to private + with self.api.raises(InvalidModelId): + self.api.models.make_private(ids=[m1]) + + # public model can be retrieved but not updated + res = self.api.models.make_public(ids=[m1]) + self.assertEqual(res.updated, 1) + res = self.api.models.get_all(id=[m1]) + self.assertEqual([m.id for m in res.models], [m1]) + with self.api.raises(InvalidModelId): + self.api.models.update(model=m1, name="public model test change 1") + + # task made private again and can be both retrieved and updated + res = self.api.models.make_private(ids=[m1]) + self.assertEqual(res.updated, 1) + res = self.api.models.get_all(id=[m1]) + self.assertEqual([m.id for m in res.models], [m1]) + self.api.models.update(model=m1, name="public model test change 2") + def _assert_task_status(self, task_id, status): task = self.api.tasks.get_by_id(task=task_id).task assert task.status == status diff --git a/server/tests/automated/test_projects_edit.py b/server/tests/automated/test_projects_edit.py new file mode 100644 index 0000000..49863cf --- /dev/null +++ b/server/tests/automated/test_projects_edit.py @@ -0,0 +1,34 @@ +from apierrors.errors.bad_request import InvalidProjectId +from apierrors.errors.forbidden import NoWritePermission +from config import config +from tests.automated import TestService + + +log = config.logger(__file__) + + +class TestProjectsEdit(TestService): + def setUp(self, **kwargs): + super().setUp(version="2.9") + + def test_make_public(self): + p1 = self.create_temp("projects", name="Test public", description="test") + + # project with 
company_origin not set to the current company cannot be converted to private + with self.api.raises(InvalidProjectId): + self.api.projects.make_private(ids=[p1]) + + # public project can be retrieved but not updated + res = self.api.projects.make_public(ids=[p1]) + self.assertEqual(res.updated, 1) + res = self.api.projects.get_all(id=[p1]) + self.assertEqual([p.id for p in res.projects], [p1]) + with self.api.raises(NoWritePermission): + self.api.projects.update(project=p1, name="Test public change 1") + + # task made private again and can be both retrieved and updated + res = self.api.projects.make_private(ids=[p1]) + self.assertEqual(res.updated, 1) + res = self.api.projects.get_all(id=[p1]) + self.assertEqual([p.id for p in res.projects], [p1]) + self.api.projects.update(project=p1, name="Test public change 2") diff --git a/server/tests/automated/test_tasks_edit.py b/server/tests/automated/test_tasks_edit.py index 0819f43..eec8c57 100644 --- a/server/tests/automated/test_tasks_edit.py +++ b/server/tests/automated/test_tasks_edit.py @@ -1,4 +1,5 @@ -from apierrors.errors.bad_request import InvalidModelId, ValidationError +from apierrors.errors.bad_request import InvalidModelId, ValidationError, InvalidTaskId +from apierrors.errors.forbidden import NoWritePermission from config import config from tests.automated import TestService @@ -8,7 +9,7 @@ log = config.logger(__file__) class TestTasksEdit(TestService): def setUp(self, **kwargs): - super().setUp(version=2.5) + super().setUp(version="2.9") def new_task(self, **kwargs): self.update_missing( @@ -145,3 +146,28 @@ class TestTasksEdit(TestService): self.api.tasks.delete, task=new_task, move_to_trash=False, force=True ) return new_task + + def test_make_public(self): + task = self.new_task() + + # task is created as private and can be updated + self.api.tasks.started(task=task) + + # task with company_origin not set to the current company cannot be converted to private + with self.api.raises(InvalidTaskId): + self.api.tasks.make_private(ids=[task]) + + # public task can be retrieved but not updated + res = self.api.tasks.make_public(ids=[task]) + self.assertEqual(res.updated, 1) + res = self.api.tasks.get_all_ex(id=[task]) + self.assertEqual([t.id for t in res.tasks], [task]) + with self.api.raises(NoWritePermission): + self.api.tasks.stopped(task=task) + + # task made private again and can be both retrieved and updated + res = self.api.tasks.make_private(ids=[task]) + self.assertEqual(res.updated, 1) + res = self.api.tasks.get_all_ex(id=[task]) + self.assertEqual([t.id for t in res.tasks], [task]) + self.api.tasks.stopped(task=task)
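
Note (illustrative sketch, not part of the patch): the project-ordering feature added in PrePopulate.update_featured_projects_order() resolves each public project's "featured" index from the services.projects.featured_order config, matching entries by id, by exact name, or by name_regex, and falling back to 999 for unlisted projects so they sort last. A minimal standalone sketch of that matching logic, using hypothetical config entries, might look like:

    import re
    from typing import Mapping, Sequence

    def featured_index(project_id: str, project_name: str,
                       featured_order: Sequence[Mapping]) -> int:
        # Mirrors the get_index() helper added above: the first matching entry wins,
        # unlisted projects fall back to 999 so they appear after all featured ones.
        for index, entry in enumerate(featured_order):
            if (
                entry.get("id") == project_id
                or entry.get("name") == project_name
                or ("name_regex" in entry and re.match(entry["name_regex"], project_name))
            ):
                return index
        return 999

    # Hypothetical entries for illustration only:
    order = [{"id": "abc123"}, {"name": "Trains examples"}, {"name_regex": r"^Demo .*"}]
    assert featured_index("abc123", "anything", order) == 0
    assert featured_index("xyz", "Demo project", order) == 2
    assert featured_index("xyz", "Other project", order) == 999

The new make_public/make_private endpoints for models, projects and tasks (min_version 2.9) each accept a list of ids and return the number of entities updated; their intended behavior is exercised end-to-end by the automated tests in this patch.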