diff --git a/server/apimodels/base.py b/server/apimodels/base.py
index 51ba290..f1e8cb6 100644
--- a/server/apimodels/base.py
+++ b/server/apimodels/base.py
@@ -1,7 +1,8 @@
 from jsonmodels import models, fields
+from jsonmodels.validators import Length
 from mongoengine.base import BaseDocument
 
-from apimodels import DictField
+from apimodels import DictField, ListField
 
 
 class MongoengineFieldsDict(DictField):
@@ -12,14 +13,14 @@ class MongoengineFieldsDict(DictField):
     """
 
     mongoengine_update_operators = (
-        'inc',
-        'dec',
-        'push',
-        'push_all',
-        'pop',
-        'pull',
-        'pull_all',
-        'add_to_set',
+        "inc",
+        "dec",
+        "push",
+        "push_all",
+        "pop",
+        "pull",
+        "pull_all",
+        "add_to_set",
     )
 
     @staticmethod
@@ -30,16 +31,16 @@ class MongoengineFieldsDict(DictField):
 
     @classmethod
     def _normalize_mongo_field_path(cls, path, value):
-        parts = path.split('__')
+        parts = path.split("__")
         if len(parts) > 1:
-            if parts[0] == 'set':
+            if parts[0] == "set":
                 parts = parts[1:]
-            elif parts[0] == 'unset':
+            elif parts[0] == "unset":
                 parts = parts[1:]
                 value = None
             elif parts[0] in cls.mongoengine_update_operators:
                 return None, None
-        return '.'.join(parts), cls._normalize_mongo_value(value)
+        return ".".join(parts), cls._normalize_mongo_value(value)
 
     def parse_value(self, value):
         value = super(MongoengineFieldsDict, self).parse_value(value)
@@ -62,3 +63,7 @@ class PagedRequest(models.Base):
 
 class IdResponse(models.Base):
     id = fields.StringField(required=True)
+
+
+class MakePublicRequest(models.Base):
+    ids = ListField(items_types=str, validators=[Length(minimum_value=1)])
diff --git a/server/apimodels/events.py b/server/apimodels/events.py
index 83427ef..9ab55b8 100644
--- a/server/apimodels/events.py
+++ b/server/apimodels/events.py
@@ -3,7 +3,7 @@ from typing import Sequence, Optional
 from jsonmodels import validators
 from jsonmodels.fields import StringField, BoolField
 from jsonmodels.models import Base
-from jsonmodels.validators import Length
+from jsonmodels.validators import Length, Min, Max
 
 from apimodels import ListField, IntField, ActualEnumField
 from bll.event.event_metrics import EventType
@@ -11,7 +11,7 @@ from bll.event.scalar_key import ScalarKeyEnum
 
 
 class HistogramRequestBase(Base):
-    samples: int = IntField(default=10000)
+    samples: int = IntField(default=6000, validators=[Min(1), Max(6000)])
     key: ScalarKeyEnum = ActualEnumField(ScalarKeyEnum, default=ScalarKeyEnum.iter)
 
 
@@ -21,7 +21,7 @@ class ScalarMetricsIterHistogramRequest(HistogramRequestBase):
 
 class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
     tasks: Sequence[str] = ListField(
-        items_types=str, validators=[Length(minimum_value=1)]
+        items_types=str, validators=[Length(minimum_value=1, maximum_value=10)]
     )
 
 
diff --git a/server/apimodels/workers.py b/server/apimodels/workers.py
index 40d8f4d..7fbb950 100644
--- a/server/apimodels/workers.py
+++ b/server/apimodels/workers.py
@@ -67,6 +67,7 @@ class WorkerEntry(Base, JsonSerializableMixin):
     company = EmbeddedField(IdNameEntry)
     ip = StringField()
     task = EmbeddedField(IdNameEntry)
+    project = EmbeddedField(IdNameEntry)
     queue = StringField()  # queue from which current task was taken
     queues = ListField(str)  # list of queues this worker listens to
     register_time = DateTimeField(required=True)
diff --git a/server/bll/event/debug_images_iterator.py b/server/bll/event/debug_images_iterator.py
index ccedf09..a98a85e 100644
--- a/server/bll/event/debug_images_iterator.py
+++ b/server/bll/event/debug_images_iterator.py
@@ -208,7 +208,11 @@ class DebugImagesIterator:
             "size": 0,
"query": { "bool": { - "must": [{"term": {"task": task}}, {"terms": {"metric": metrics}}] + "must": [ + {"term": {"task": task}}, + {"terms": {"metric": metrics}}, + {"exists": {"field": "url"}}, + ] } }, "aggs": { @@ -251,7 +255,7 @@ class DebugImagesIterator: } with translate_errors_context(), TimingContext("es", "_init_metric_states"): - es_res = self.es.search(index=es_index, body=es_req, routing=task) + es_res = self.es.search(index=es_index, body=es_req) if "aggregations" not in es_res: return [] @@ -298,6 +302,7 @@ class DebugImagesIterator: must_conditions = [ {"term": {"task": metric.task}}, {"term": {"metric": metric.name}}, + {"exists": {"field": "url"}}, ] must_not_conditions = [] @@ -368,7 +373,7 @@ class DebugImagesIterator: "terms": { "field": "iter", "size": iter_count, - "order": {"_term": "desc" if navigate_earlier else "asc"}, + "order": {"_key": "desc" if navigate_earlier else "asc"}, }, "aggs": { "variants": { @@ -387,7 +392,7 @@ class DebugImagesIterator: }, } with translate_errors_context(), TimingContext("es", "get_debug_image_events"): - es_res = self.es.search(index=es_index, body=es_req, routing=metric.task) + es_res = self.es.search(index=es_index, body=es_req) if "aggregations" not in es_res: return metric.task, metric.name, [] diff --git a/server/bll/event/event_bll.py b/server/bll/event/event_bll.py index 428aff7..65e2667 100644 --- a/server/bll/event/event_bll.py +++ b/server/bll/event/event_bll.py @@ -3,7 +3,7 @@ from collections import defaultdict from contextlib import closing from datetime import datetime from operator import attrgetter -from typing import Sequence, Set, Tuple +from typing import Sequence, Set, Tuple, Optional import six from elasticsearch import helpers @@ -22,6 +22,7 @@ from database.errors import translate_errors_context from database.model.task.task import Task, TaskStatus from redis_manager import redman from timing_context import TimingContext +from tools import safe_get from utilities.dicts import flatten_nested_items # noinspection PyTypeChecker @@ -134,7 +135,6 @@ class EventBLL(object): es_action = { "_op_type": "index", # overwrite if exists with same ID "_index": index_name, - "_type": "event", "_source": event, } @@ -144,7 +144,6 @@ class EventBLL(object): else: es_action["_id"] = dbutils.id() - es_action["_routing"] = task_id task_ids.add(task_id) if ( iter is not None @@ -342,14 +341,9 @@ class EventBLL(object): } with translate_errors_context(), TimingContext("es", "scroll_task_events"): - es_res = self.es.search( - index=es_index, body=es_req, scroll="1h", routing=task_id - ) - - events = [hit["_source"] for hit in es_res["hits"]["hits"]] - next_scroll_id = es_res["_scroll_id"] - total_events = es_res["hits"]["total"] + es_res = self.es.search(index=es_index, body=es_req, scroll="1h") + events, total_events, next_scroll_id = self._get_events_from_es_res(es_res) return events, next_scroll_id, total_events def get_last_iterations_per_event_metric_variant( @@ -377,7 +371,7 @@ class EventBLL(object): "terms": { "field": "iter", "size": num_last_iterations, - "order": {"_term": "desc"}, + "order": {"_key": "desc"}, } } }, @@ -393,7 +387,7 @@ class EventBLL(object): with translate_errors_context(), TimingContext( "es", "task_last_iter_metric_variant" ): - es_res = self.es.search(index=es_index, body=es_req, routing=task_id) + es_res = self.es.search(index=es_index, body=es_req) if "aggregations" not in es_res: return [] @@ -422,13 +416,11 @@ class EventBLL(object): if not self.es.indices.exists(es_index): return 
 
-        query = {"bool": defaultdict(list)}
-
+        must = []
         if last_iterations_per_plot is None:
-            must = query["bool"]["must"]
             must.append({"terms": {"task": tasks}})
         else:
-            should = query["bool"]["should"]
+            should = []
             for i, task_id in enumerate(tasks):
                 last_iters = self.get_last_iterations_per_event_metric_variant(
                     es_index, task_id, last_iterations_per_plot, event_type
@@ -451,32 +443,41 @@ class EventBLL(object):
             )
             if not should:
                 return TaskEventsResult()
+            must.append({"bool": {"should": should}})
 
         if sort is None:
             sort = [{"timestamp": {"order": "asc"}}]
 
-        es_req = {"sort": sort, "size": min(size, 10000), "query": query}
-
-        routing = ",".join(tasks)
+        es_req = {
+            "sort": sort,
+            "size": min(size, 10000),
+            "query": {"bool": {"must": must}},
+        }
 
         with translate_errors_context(), TimingContext("es", "get_task_plots"):
             es_res = self.es.search(
-                index=es_index,
-                body=es_req,
-                ignore=404,
-                routing=routing,
-                scroll="1h",
+                index=es_index, body=es_req, ignore=404, scroll="1h",
             )
 
-        events = [doc["_source"] for doc in es_res.get("hits", {}).get("hits", [])]
-        # scroll id may be missing when queering a totally empty DB
-        next_scroll_id = es_res.get("_scroll_id")
-        total_events = es_res["hits"]["total"]
-
+        events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
         return TaskEventsResult(
             events=events, next_scroll_id=next_scroll_id, total_events=total_events
         )
 
+    def _get_events_from_es_res(self, es_res: dict) -> Tuple[list, int, Optional[str]]:
+        """
+        Return events and next scroll id from the scrolled query
+        Release the scroll once it is exhausted
+        """
+        total_events = safe_get(es_res, "hits/total/value", default=0)
+        events = [doc["_source"] for doc in safe_get(es_res, "hits/hits", default=[])]
+        next_scroll_id = es_res.get("_scroll_id")
+        if next_scroll_id and not events:
+            self.es.clear_scroll(scroll_id=next_scroll_id)
+            next_scroll_id = None
+
+        return events, total_events, next_scroll_id
+
     def get_task_events(
         self,
         company_id,
@@ -502,20 +503,16 @@ class EventBLL(object):
         if not self.es.indices.exists(es_index):
             return TaskEventsResult()
 
-        query = {"bool": defaultdict(list)}
-
-        if metric or variant:
-            must = query["bool"]["must"]
-            if metric:
-                must.append({"term": {"metric": metric}})
-            if variant:
-                must.append({"term": {"variant": variant}})
+        must = []
+        if metric:
+            must.append({"term": {"metric": metric}})
+        if variant:
+            must.append({"term": {"variant": variant}})
 
         if last_iter_count is None:
-            must = query["bool"]["must"]
             must.append({"terms": {"task": task_ids}})
         else:
-            should = query["bool"]["should"]
+            should = []
            for i, task_id in enumerate(task_ids):
                 last_iters = self.get_last_iters(
                     es_index, task_id, event_type, last_iter_count
@@ -534,27 +531,23 @@ class EventBLL(object):
                )
             if not should:
                 return TaskEventsResult()
+            must.append({"bool": {"should": should}})
 
         if sort is None:
             sort = [{"timestamp": {"order": "asc"}}]
 
-        es_req = {"sort": sort, "size": min(size, 10000), "query": query}
-
-        routing = ",".join(task_ids)
+        es_req = {
+            "sort": sort,
+            "size": min(size, 10000),
+            "query": {"bool": {"must": must}},
+        }
 
         with translate_errors_context(), TimingContext("es", "get_task_events"):
             es_res = self.es.search(
-                index=es_index,
-                body=es_req,
-                ignore=404,
-                routing=routing,
-                scroll="1h",
+                index=es_index, body=es_req, ignore=404, scroll="1h",
             )
 
-        events = [doc["_source"] for doc in es_res.get("hits", {}).get("hits", [])]
-        next_scroll_id = es_res.get("_scroll_id")
-        total_events = es_res["hits"]["total"]
-
+        events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
         return TaskEventsResult(
             events=events, next_scroll_id=next_scroll_id, total_events=total_events
         )
@@ -590,7 +583,7 @@ class EventBLL(object):
         with translate_errors_context(), TimingContext(
             "es", "events_get_metrics_and_variants"
         ):
-            es_res = self.es.search(index=es_index, body=es_req, routing=task_id)
+            es_res = self.es.search(index=es_index, body=es_req)
 
         metrics = {}
         for metric_bucket in es_res["aggregations"]["metrics"].get("buckets"):
@@ -622,14 +615,14 @@ class EventBLL(object):
                 "terms": {
                     "field": "metric",
                     "size": EventMetrics.MAX_METRICS_COUNT,
-                    "order": {"_term": "asc"},
+                    "order": {"_key": "asc"},
                 },
                 "aggs": {
                     "variants": {
                         "terms": {
                             "field": "variant",
                             "size": EventMetrics.MAX_VARIANTS_COUNT,
-                            "order": {"_term": "asc"},
+                            "order": {"_key": "asc"},
                         },
                         "aggs": {
                             "last_value": {
@@ -659,7 +652,7 @@ class EventBLL(object):
         with translate_errors_context(), TimingContext(
             "es", "events_get_metrics_and_variants"
         ):
-            es_res = self.es.search(index=es_index, body=es_req, routing=task_id)
+            es_res = self.es.search(index=es_index, body=es_req)
 
         metrics = []
         max_timestamp = 0
@@ -706,7 +699,7 @@ class EventBLL(object):
             "sort": ["iter"],
         }
         with translate_errors_context(), TimingContext("es", "task_stats_vector"):
-            es_res = self.es.search(index=es_index, body=es_req, routing=task_id)
+            es_res = self.es.search(index=es_index, body=es_req)
 
         vectors = []
         iterations = []
@@ -727,7 +720,7 @@ class EventBLL(object):
                     "terms": {
                         "field": "iter",
                         "size": iters,
-                        "order": {"_term": "desc"},
+                        "order": {"_key": "desc"},
                     }
                 }
             },
@@ -737,7 +730,7 @@ class EventBLL(object):
             es_req["query"]["bool"]["must"].append({"term": {"type": event_type}})
 
         with translate_errors_context(), TimingContext("es", "task_last_iter"):
-            es_res = self.es.search(index=es_index, body=es_req, routing=task_id)
+            es_res = self.es.search(index=es_index, body=es_req)
 
         if "aggregations" not in es_res:
             return []
@@ -759,8 +752,6 @@ class EventBLL(object):
         es_index = EventMetrics.get_index_name(company_id, "*")
         es_req = {"query": {"term": {"task": task_id}}}
         with translate_errors_context(), TimingContext("es", "delete_task_events"):
-            es_res = self.es.delete_by_query(
-                index=es_index, body=es_req, routing=task_id, refresh=True
-            )
+            es_res = self.es.delete_by_query(index=es_index, body=es_req, refresh=True)
 
         return es_res.get("deleted", 0)
diff --git a/server/bll/event/event_metrics.py b/server/bll/event/event_metrics.py
index a331f2d..a930ab4 100644
--- a/server/bll/event/event_metrics.py
+++ b/server/bll/event/event_metrics.py
@@ -1,12 +1,11 @@
 import itertools
 from collections import defaultdict
-from concurrent.futures import ThreadPoolExecutor
+from concurrent.futures.thread import ThreadPoolExecutor
 from enum import Enum
 from functools import partial
 from operator import itemgetter
-from typing import Sequence, Tuple, Callable, Iterable
+from typing import Sequence, Tuple
 
-from boltons.iterutils import bucketize
 from elasticsearch import Elasticsearch
 from mongoengine import Q
 
@@ -16,7 +15,7 @@ from config import config
 from database.errors import translate_errors_context
 from database.model.task.task import Task
 from timing_context import TimingContext
-from utilities import safe_get
+from tools import safe_get
 
 log = config.logger(__file__)
 
@@ -30,14 +29,18 @@ class EventType(Enum):
 
 
 class EventMetrics:
-    MAX_TASKS_COUNT = 50
-    MAX_METRICS_COUNT = 200
-    MAX_VARIANTS_COUNT = 500
+    MAX_METRICS_COUNT = 100
+    MAX_VARIANTS_COUNT = 100
     MAX_AGGS_ELEMENTS_COUNT = 50
+    MAX_SAMPLE_BUCKETS = 6000
 
     def __init__(self, es: Elasticsearch):
         self.es = es
 
+    @property
+    def _max_concurrency(self):
+        return config.get("services.events.max_metrics_concurrency", 4)
+
     @staticmethod
     def get_index_name(company_id, event_type):
         event_type = event_type.lower().replace(" ", "_")
@@ -51,15 +54,48 @@ class EventMetrics:
         The amount of points in each histogram should not exceed
         the requested samples
         """
+        es_index = self.get_index_name(company_id, "training_stats_scalar")
+        if not self.es.indices.exists(es_index):
+            return {}
 
-        return self._run_get_scalar_metrics_as_parallel(
-            company_id,
-            task_ids=[task_id],
-            samples=samples,
-            key=ScalarKey.resolve(key),
-            get_func=self._get_scalar_average,
+        return self._get_scalar_average_per_iter_core(
+            task_id, es_index, samples, ScalarKey.resolve(key)
         )
 
+    def _get_scalar_average_per_iter_core(
+        self,
+        task_id: str,
+        es_index: str,
+        samples: int,
+        key: ScalarKey,
+        run_parallel: bool = True,
+    ) -> dict:
+        intervals = self._get_task_metric_intervals(
+            es_index=es_index, task_id=task_id, samples=samples, field=key.field
+        )
+        if not intervals:
+            return {}
+        interval_groups = self._group_task_metric_intervals(intervals)
+
+        get_scalar_average = partial(
+            self._get_scalar_average, task_id=task_id, es_index=es_index, key=key
+        )
+        if run_parallel:
+            with ThreadPoolExecutor(max_workers=self._max_concurrency) as pool:
+                metrics = itertools.chain.from_iterable(
+                    pool.map(get_scalar_average, interval_groups)
+                )
+        else:
+            metrics = itertools.chain.from_iterable(
+                get_scalar_average(group) for group in interval_groups
+            )
+
+        ret = defaultdict(dict)
+        for metric_key, metric_values in metrics:
+            ret[metric_key].update(metric_values)
+
+        return ret
+
     def compare_scalar_metrics_average_per_iter(
         self,
         company_id,
@@ -72,12 +108,6 @@ class EventMetrics:
         Compare scalar metrics for different tasks per metric and variant
         The amount of points in each histogram should not exceed the requested samples
         """
-        if len(task_ids) > self.MAX_TASKS_COUNT:
-            raise errors.BadRequest(
-                f"Up to {self.MAX_TASKS_COUNT} tasks supported for comparison",
-                len(task_ids),
-            )
-
         task_name_by_id = {}
         with translate_errors_context():
             task_objs = Task.get_many(
@@ -90,7 +120,6 @@ class EventMetrics:
             if len(task_objs) < len(task_ids):
                 invalid = tuple(set(task_ids) - set(r.id for r in task_objs))
                 raise errors.bad_request.InvalidTaskId(company=company_id, ids=invalid)
-
             task_name_by_id = {t.id: t.name for t in task_objs}
 
         companies = {t.company for t in task_objs}
@@ -99,138 +128,95 @@ class EventMetrics:
                 "only tasks from the same company are supported"
             )
 
-        ret = self._run_get_scalar_metrics_as_parallel(
-            next(iter(companies)),
-            task_ids=task_ids,
-            samples=samples,
-            key=ScalarKey.resolve(key),
-            get_func=self._get_scalar_average_per_task,
-        )
-
-        for metric_data in ret.values():
-            for variant_data in metric_data.values():
-                for task_id, task_data in variant_data.items():
-                    task_data["name"] = task_name_by_id[task_id]
-
-        return ret
-
-    TaskMetric = Tuple[str, str, str]
-
-    MetricInterval = Tuple[int, Sequence[TaskMetric]]
-    MetricData = Tuple[str, dict]
-
-    def _split_metrics_by_max_aggs_count(
-        self, task_metrics: Sequence[TaskMetric]
-    ) -> Iterable[Sequence[TaskMetric]]:
-        """
-        Return task metrics in groups where amount of task metrics in each group
-        is roughly limited by MAX_AGGS_ELEMENTS_COUNT. The split is done on metrics and
-        variants while always preserving all their tasks in the same group
-        """
-        if len(task_metrics) < self.MAX_AGGS_ELEMENTS_COUNT:
-            yield task_metrics
-            return
-
-        tm_grouped = bucketize(task_metrics, key=itemgetter(1, 2))
-        groups = []
-        for group in tm_grouped.values():
-            groups.append(group)
-            if sum(map(len, groups)) >= self.MAX_AGGS_ELEMENTS_COUNT:
-                yield list(itertools.chain(*groups))
-                groups = []
-
-        if groups:
-            yield list(itertools.chain(*groups))
-
-        return
-
-    def _run_get_scalar_metrics_as_parallel(
-        self,
-        company_id: str,
-        task_ids: Sequence[str],
-        samples: int,
-        key: ScalarKey,
-        get_func: Callable[
-            [MetricInterval, Sequence[str], str, ScalarKey], Sequence[MetricData]
-        ],
-    ) -> dict:
-        """
-        Group metrics per interval length and execute get_func for each group in parallel
-        :param company_id: id of the company
-        :params task_ids: ids of the tasks to collect data for
-        :param samples: maximum number of samples per metric
-        :param get_func: callable that given metric names for the same interval
-        performs histogram aggregation for the metrics and return the aggregated data
-        """
-        es_index = self.get_index_name(company_id, "training_stats_scalar")
+        es_index = self.get_index_name(next(iter(companies)), "training_stats_scalar")
         if not self.es.indices.exists(es_index):
             return {}
 
-        intervals = self._get_metric_intervals(
-            es_index=es_index, task_ids=task_ids, samples=samples, field=key.field
+        get_scalar_average_per_iter = partial(
+            self._get_scalar_average_per_iter_core,
+            es_index=es_index,
+            samples=samples,
+            key=ScalarKey.resolve(key),
+            run_parallel=False,
         )
-
-        if not intervals:
-            return {}
-
-        intervals = list(
-            itertools.chain.from_iterable(
-                zip(itertools.repeat(i), self._split_metrics_by_max_aggs_count(tms))
-                for i, tms in intervals
-            )
-        )
-        max_concurrency = config.get("services.events.max_metrics_concurrency", 4)
-        with ThreadPoolExecutor(max_workers=max_concurrency) as pool:
-            metrics = itertools.chain.from_iterable(
-                pool.map(
-                    partial(get_func, task_ids=task_ids, es_index=es_index, key=key),
-                    intervals,
-                )
+        with ThreadPoolExecutor(max_workers=self._max_concurrency) as pool:
+            task_metrics = zip(
+                task_ids, pool.map(get_scalar_average_per_iter, task_ids)
             )
 
-        ret = defaultdict(dict)
-        for metric_key, metric_values in metrics:
-            ret[metric_key].update(metric_values)
+        res = defaultdict(lambda: defaultdict(dict))
+        for task_id, task_data in task_metrics:
+            task_name = task_name_by_id[task_id]
+            for metric_key, metric_data in task_data.items():
+                for variant_key, variant_data in metric_data.items():
+                    variant_data["name"] = task_name
+                    res[metric_key][variant_key][task_id] = variant_data
 
-        return ret
+        return res
 
-    def _get_metric_intervals(
-        self, es_index, task_ids: Sequence[str], samples: int, field: str = "iter"
+    MetricInterval = Tuple[str, str, int, int]
+    MetricIntervalGroup = Tuple[int, Sequence[Tuple[str, str]]]
+
+    @classmethod
+    def _group_task_metric_intervals(
+        cls, intervals: Sequence[MetricInterval]
+    ) -> Sequence[MetricIntervalGroup]:
+        """
+        Group task metric intervals so that the following conditions are met:
+        - All the metrics in the same group have the same interval (with 10% rounding)
+        - The amount of metrics in the group does not exceed MAX_AGGS_ELEMENTS_COUNT
+        - The total count of samples in the group does not exceed MAX_SAMPLE_BUCKETS
+        """
+        metric_interval_groups = []
+        interval_group = []
+        group_interval_upper_bound = 0
+        group_max_interval = 0
+        group_samples = 0
+        for metric, variant, interval, size in sorted(intervals, key=itemgetter(2)):
+            if (
+                interval > group_interval_upper_bound
+                or (group_samples + size) > cls.MAX_SAMPLE_BUCKETS
+                or len(interval_group) >= cls.MAX_AGGS_ELEMENTS_COUNT
+            ):
+                if interval_group:
+                    metric_interval_groups.append((group_max_interval, interval_group))
+                    interval_group = []
+                group_max_interval = interval
+                group_interval_upper_bound = interval + int(interval * 0.1)
+                group_samples = 0
+            interval_group.append((metric, variant))
+            group_samples += size
+            group_max_interval = max(group_max_interval, interval)
+        if interval_group:
+            metric_interval_groups.append((group_max_interval, interval_group))
+
+        return metric_interval_groups
+
+    def _get_task_metric_intervals(
+        self, es_index, task_id: str, samples: int, field: str = "iter"
     ) -> Sequence[MetricInterval]:
         """
         Calculate interval per task metric variant so that the resulting
         amount of points does not exceed sample.
-        Return metric variants grouped by interval value with 10% rounding
-        For samples==0 return empty list
+        Return the list of metric variant intervals as the following tuple:
+        (metric, variant, interval, samples)
         """
-        default_intervals = [(1, [])]
-        if not samples:
-            return default_intervals
-
         es_req = {
             "size": 0,
-            "query": {"terms": {"task": task_ids}},
+            "query": {"term": {"task": task_id}},
             "aggs": {
-                "tasks": {
-                    "terms": {"field": "task", "size": self.MAX_TASKS_COUNT},
+                "metrics": {
+                    "terms": {"field": "metric", "size": self.MAX_METRICS_COUNT},
                     "aggs": {
-                        "metrics": {
+                        "variants": {
                             "terms": {
-                                "field": "metric",
-                                "size": self.MAX_METRICS_COUNT,
+                                "field": "variant",
+                                "size": self.MAX_VARIANTS_COUNT,
                             },
                             "aggs": {
-                                "variants": {
-                                    "terms": {
-                                        "field": "variant",
-                                        "size": self.MAX_VARIANTS_COUNT,
-                                    },
-                                    "aggs": {
-                                        "count": {"value_count": {"field": field}},
-                                        "min_index": {"min": {"field": field}},
-                                        "max_index": {"max": {"field": field}},
-                                    },
-                                }
+                                "count": {"value_count": {"field": field}},
+                                "min_index": {"min": {"field": field}},
+                                "max_index": {"max": {"field": field}},
                             },
                         }
                     },
@@ -239,88 +225,75 @@ class EventMetrics:
        }
 
         with translate_errors_context(), TimingContext("es", "task_stats_get_interval"):
-            es_res = self.es.search(
-                index=es_index, body=es_req, routing=",".join(task_ids)
-            )
+            es_res = self.es.search(index=es_index, body=es_req)
 
         aggs_result = es_res.get("aggregations")
         if not aggs_result:
-            return default_intervals
+            return []
 
-        intervals = [
-            (
-                task["key"],
-                metric["key"],
-                variant["key"],
-                self._calculate_metric_interval(variant, samples),
-            )
-            for task in aggs_result["tasks"]["buckets"]
-            for metric in task["metrics"]["buckets"]
+        return [
+            self._build_metric_interval(metric["key"], variant["key"], variant, samples)
+            for metric in aggs_result["metrics"]["buckets"]
             for variant in metric["variants"]["buckets"]
         ]
 
-        metric_intervals = []
-        upper_border = 0
-        interval_metrics = None
-        for task, metric, variant, interval in sorted(intervals, key=itemgetter(3)):
-            if not interval_metrics or interval > upper_border:
-                interval_metrics = []
-                metric_intervals.append((interval, interval_metrics))
-                upper_border = interval + int(interval * 0.1)
-            interval_metrics.append((task, metric, variant))
-
-        return metric_intervals
-
     @staticmethod
-    def _calculate_metric_interval(metric_variant: dict, samples: int) -> int:
+    def _build_metric_interval(
+        metric: str, variant: str, data: dict, samples: int
+    ) -> Tuple[str, str, int, int]:
         """
         Calculate index interval per metric_variant variant so that the total
         amount of intervals does not exceeds the samples
+        Return the interval and resulting amount of intervals
         """
-        count = safe_get(metric_variant, "count/value")
-        if not count or count < samples:
-            return 1
+        count = safe_get(data, "count/value", default=0)
+        if count < samples:
+            return metric, variant, 1, count
 
-        min_index = safe_get(metric_variant, "min_index/value", default=0)
-        max_index = safe_get(metric_variant, "max_index/value", default=min_index)
-        return max(1, int(max_index - min_index + 1) // samples)
+        min_index = safe_get(data, "min_index/value", default=0)
+        max_index = safe_get(data, "max_index/value", default=min_index)
+        return (
+            metric,
+            variant,
+            max(1, int(max_index - min_index + 1) // samples),
+            samples,
+        )
+
+    MetricData = Tuple[str, dict]
 
     def _get_scalar_average(
         self,
-        metrics_interval: MetricInterval,
-        task_ids: Sequence[str],
+        metrics_interval: MetricIntervalGroup,
+        task_id: str,
         es_index: str,
         key: ScalarKey,
     ) -> Sequence[MetricData]:
         """
         Retrieve scalar histograms per several metric variants that share the same interval
-        Note: the function works with a single task only
         """
-
-        assert len(task_ids) == 1
-        interval, task_metrics = metrics_interval
+        interval, metrics = metrics_interval
 
         aggregation = self._add_aggregation_average(key.get_aggregation(interval))
         aggs = {
             "metrics": {
                 "terms": {
                     "field": "metric",
                     "size": self.MAX_METRICS_COUNT,
-                    "order": {"_term": "desc"},
+                    "order": {"_key": "desc"},
                 },
                 "aggs": {
                     "variants": {
                         "terms": {
                             "field": "variant",
                             "size": self.MAX_VARIANTS_COUNT,
-                            "order": {"_term": "desc"},
+                            "order": {"_key": "desc"},
                         },
                         "aggs": aggregation,
                     }
                 },
             }
         }
-        aggs_result = self._query_aggregation_for_metrics_and_tasks(
-            es_index, aggs=aggs, task_ids=task_ids, task_metrics=task_metrics
+        aggs_result = self._query_aggregation_for_task_metrics(
+            es_index, aggs=aggs, task_id=task_id, metrics=metrics
         )
 
         if not aggs_result:
@@ -341,61 +314,6 @@ class EventMetrics:
         ]
         return metrics
 
-    def _get_scalar_average_per_task(
-        self,
-        metrics_interval: MetricInterval,
-        task_ids: Sequence[str],
-        es_index: str,
-        key: ScalarKey,
-    ) -> Sequence[MetricData]:
-        """
-        Retrieve scalar histograms per several metric variants that share the same interval
-        """
-        interval, task_metrics = metrics_interval
-
-        aggregation = self._add_aggregation_average(key.get_aggregation(interval))
-        aggs = {
-            "metrics": {
-                "terms": {"field": "metric", "size": self.MAX_METRICS_COUNT},
-                "aggs": {
-                    "variants": {
-                        "terms": {"field": "variant", "size": self.MAX_VARIANTS_COUNT},
-                        "aggs": {
-                            "tasks": {
-                                "terms": {
-                                    "field": "task",
-                                    "size": self.MAX_TASKS_COUNT,
-                                },
-                                "aggs": aggregation,
-                            }
-                        },
-                    }
-                },
-            }
-        }
-
-        aggs_result = self._query_aggregation_for_metrics_and_tasks(
-            es_index, aggs=aggs, task_ids=task_ids, task_metrics=task_metrics
-        )
-
-        if not aggs_result:
-            return {}
-
-        metrics = [
-            (
-                metric["key"],
-                {
-                    variant["key"]: {
-                        task["key"]: key.get_iterations_data(task)
-                        for task in variant["tasks"]["buckets"]
-                    }
-                    for variant in metric["variants"]["buckets"]
-                },
-            )
-            for metric in aggs_result["metrics"]["buckets"]
-        ]
-        return metrics
-
     @staticmethod
     def _add_aggregation_average(aggregation):
         average_agg = {"avg_val": {"avg": {"field": "value"}}}
@@ -404,69 +322,55 @@ class EventMetrics:
             for key, value in aggregation.items()
         }
 
-    def _query_aggregation_for_metrics_and_tasks(
+    def _query_aggregation_for_task_metrics(
         self,
         es_index: str,
         aggs: dict,
-        task_ids: Sequence[str],
-        task_metrics: Sequence[TaskMetric],
+        task_id: str,
+        metrics: Sequence[Tuple[str, str]],
     ) -> dict:
         """
         Return the result of elastic search query for the given aggregation
        filtered by the given task_ids and metrics
         """
-        if task_metrics:
-            condition = {
-                "should": [
-                    self._build_metric_terms(task, metric, variant)
-                    for task, metric, variant in task_metrics
-                ]
-            }
-        else:
-            condition = {"must": [{"terms": {"task": task_ids}}]}
+        must = [{"term": {"task": task_id}}]
+        if metrics:
+            should = [
+                {
+                    "bool": {
+                        "must": [
+                            {"term": {"metric": metric}},
+                            {"term": {"variant": variant}},
+                        ]
+                    }
+                }
+                for metric, variant in metrics
+            ]
+            must.append({"bool": {"should": should}})
+
         es_req = {
             "size": 0,
-            "_source": {"excludes": []},
-            "query": {"bool": condition},
+            "query": {"bool": {"must": must}},
             "aggs": aggs,
-            "version": True,
         }
 
         with translate_errors_context(), TimingContext("es", "task_stats_scalar"):
-            es_res = self.es.search(
-                index=es_index, body=es_req, routing=",".join(task_ids)
-            )
+            es_res = self.es.search(index=es_index, body=es_req)
 
         return es_res.get("aggregations")
 
-    @staticmethod
-    def _build_metric_terms(task: str, metric: str, variant: str) -> dict:
-        """
-        Build query term for a metric + variant
-        """
-        return {
-            "bool": {
-                "must": [
-                    {"term": {"task": task}},
-                    {"term": {"metric": metric}},
-                    {"term": {"variant": variant}},
-                ]
-            }
-        }
-
     def get_tasks_metrics(
         self, company_id, task_ids: Sequence, event_type: EventType
-    ) -> Sequence[Tuple]:
+    ) -> Sequence:
         """
         For the requested tasks return all the metrics that
         reported events of the requested types
         """
         es_index = EventMetrics.get_index_name(company_id, event_type.value)
         if not self.es.indices.exists(es_index):
-            return [(tid, []) for tid in task_ids]
+            return {}
 
-        max_concurrency = config.get("services.events.max_metrics_concurrency", 4)
-        with ThreadPoolExecutor(max_concurrency) as pool:
+        with ThreadPoolExecutor(self._max_concurrency) as pool:
            res = pool.map(
                partial(
                    self._get_task_metrics, es_index=es_index, event_type=event_type,
@@ -494,7 +398,7 @@ class EventMetrics:
         }
 
         with translate_errors_context(), TimingContext("es", "_get_task_metrics"):
-            es_res = self.es.search(index=es_index, body=es_req, routing=task_id)
+            es_res = self.es.search(index=es_index, body=es_req)
 
         return [
             metric["key"]
diff --git a/server/bll/event/log_events_iterator.py b/server/bll/event/log_events_iterator.py
index 89a32e2..3160060 100644
--- a/server/bll/event/log_events_iterator.py
+++ b/server/bll/event/log_events_iterator.py
@@ -71,9 +71,9 @@ class LogEventsIterator:
             es_req["search_after"] = [from_timestamp]
 
         with translate_errors_context(), TimingContext("es", "get_task_events"):
-            es_result = self.es.search(index=es_index, body=es_req, routing=task_id)
+            es_result = self.es.search(index=es_index, body=es_req)
             hits = es_result["hits"]["hits"]
-            hits_total = es_result["hits"]["total"]
+            hits_total = es_result["hits"]["total"]["value"]
             if not hits:
                 return [], hits_total
 
@@ -92,7 +92,7 @@ class LogEventsIterator:
                     }
                 },
             }
-            es_result = self.es.search(index=es_index, body=es_req, routing=task_id)
+            es_result = self.es.search(index=es_index, body=es_req)
             hits = es_result["hits"]["hits"]
             if not hits or len(hits) < 2:
                 # if only one element is returned for the last timestamp
diff --git a/server/bll/event/scalar_key.py b/server/bll/event/scalar_key.py
index 18fe1b1..fb63ab5 100644
--- a/server/bll/event/scalar_key.py
+++ b/server/bll/event/scalar_key.py
@@ -111,7 +111,7 @@ class TimestampKey(ScalarKey):
             self.name: {
                 "date_histogram": {
                     "field": "timestamp",
-                    "interval": f"{interval}ms",
+                    "fixed_interval": f"{interval}ms",
"min_doc_count": 1, } } @@ -150,7 +150,7 @@ class ISOTimeKey(ScalarKey): self.name: { "date_histogram": { "field": "timestamp", - "interval": f"{interval}ms", + "fixed_interval": f"{interval}ms", "min_doc_count": 1, "format": "strict_date_time", } diff --git a/server/bll/queue/queue_metrics.py b/server/bll/queue/queue_metrics.py index 41d7df1..ad76f9f 100644 --- a/server/bll/queue/queue_metrics.py +++ b/server/bll/queue/queue_metrics.py @@ -18,7 +18,6 @@ log = config.logger(__file__) class QueueMetrics: class EsKeys: - DOC_TYPE = "metrics" WAITING_TIME_FIELD = "average_waiting_time" QUEUE_LENGTH_FIELD = "queue_length" TIMESTAMP_FIELD = "timestamp" @@ -66,7 +65,6 @@ class QueueMetrics: entries = [e for e in queue.entries if e.added] return dict( _index=es_index, - _type=self.EsKeys.DOC_TYPE, _source={ self.EsKeys.TIMESTAMP_FIELD: timestamp, self.EsKeys.QUEUE_FIELD: queue.id, @@ -93,7 +91,6 @@ class QueueMetrics: def _search_company_metrics(self, company_id: str, es_req: dict) -> dict: return self.es.search( index=f"{self._queue_metrics_prefix_for_company(company_id)}*", - doc_type=self.EsKeys.DOC_TYPE, body=es_req, ) @@ -109,7 +106,7 @@ class QueueMetrics: "dates": { "date_histogram": { "field": cls.EsKeys.TIMESTAMP_FIELD, - "interval": f"{interval}s", + "fixed_interval": f"{interval}s", "min_doc_count": 1, }, "aggs": { diff --git a/server/bll/statistics/stats_reporter.py b/server/bll/statistics/stats_reporter.py index 5d9f17c..4ed33ec 100644 --- a/server/bll/statistics/stats_reporter.py +++ b/server/bll/statistics/stats_reporter.py @@ -237,7 +237,6 @@ class StatisticsReporter: def _run_worker_stats_query(cls, company_id, es_req) -> dict: return worker_bll.es_client.search( index=f"{WorkerStats.worker_stats_prefix_for_company(company_id)}*", - doc_type="stat", body=es_req, ) diff --git a/server/bll/util.py b/server/bll/util.py index 77c1b4d..0991f42 100644 --- a/server/bll/util.py +++ b/server/bll/util.py @@ -35,14 +35,21 @@ class SetFieldsResolver: SET_MODIFIERS = ("min", "max") def __init__(self, set_fields: Dict[str, Any]): - self.orig_fields = set_fields - self.fields = { - f: fname - for f, modifier, dunder, fname in ( - (f,) + f.partition("__") for f in set_fields.keys() - ) - if dunder and modifier in self.SET_MODIFIERS - } + self.orig_fields = {} + self.fields = {} + self.add_fields(**set_fields) + + def add_fields(self, **set_fields: Any): + self.orig_fields.update(set_fields) + self.fields.update( + { + f: fname + for f, modifier, dunder, fname in ( + (f,) + f.partition("__") for f in set_fields.keys() + ) + if dunder and modifier in self.SET_MODIFIERS + } + ) def _get_updated_name(self, doc: AttributedDocument, name: str) -> str: if name in self.fields and doc.get_field_value(self.fields[name]) is None: diff --git a/server/bll/workers/__init__.py b/server/bll/workers/__init__.py index d83cdb4..cf04985 100644 --- a/server/bll/workers/__init__.py +++ b/server/bll/workers/__init__.py @@ -21,6 +21,7 @@ from config import config from database.errors import translate_errors_context from database.model.auth import User from database.model.company import Company +from database.model.project import Project from database.model.queue import Queue from database.model.task.task import Task from redis_manager import redman @@ -146,6 +147,7 @@ class WorkerBLL: if not report.task: entry.task = None + entry.project = None else: with translate_errors_context(): query = dict(id=report.task, company=company_id) @@ -160,6 +162,12 @@ class WorkerBLL: raise bad_request.InvalidTaskId(**query) entry.task 
 
+                entry.project = None
+                if task.project:
+                    project = Project.objects(id=task.project).only("name").first()
+                    if project:
+                        entry.project = IdNameEntry(id=project.id, name=project.name)
+
             entry.last_report_time = now
         except APIError:
             raise
@@ -369,7 +377,6 @@ class WorkerBLL:
         def make_doc(category, metric, variant, value) -> dict:
             return dict(
                 _index=es_index,
-                _type="stat",
                 _source=dict(
                     timestamp=timestamp,
                     worker=worker,
diff --git a/server/bll/workers/stats.py b/server/bll/workers/stats.py
index 5d55044..ecf7442 100644
--- a/server/bll/workers/stats.py
+++ b/server/bll/workers/stats.py
@@ -25,7 +25,6 @@ class WorkerStats:
     def _search_company_stats(self, company_id: str, es_req: dict) -> dict:
         return self.es.search(
             index=f"{self.worker_stats_prefix_for_company(company_id)}*",
-            doc_type="stat",
             body=es_req,
         )
 
@@ -53,7 +52,7 @@ class WorkerStats:
 
         res = self._search_company_stats(company_id, es_req)
 
-        if not res["hits"]["total"]:
+        if not res["hits"]["total"]["value"]:
             raise bad_request.WorkerStatsNotFound(
                 f"No statistic metrics found for the company {company_id} and workers {worker_ids}"
             )
@@ -87,7 +86,7 @@ class WorkerStats:
                 "dates": {
                     "date_histogram": {
                         "field": "timestamp",
-                        "interval": f"{request.interval}s",
+                        "fixed_interval": f"{request.interval}s",
                         "min_doc_count": 1,
                     },
                     "aggs": {
@@ -216,7 +215,7 @@ class WorkerStats:
                 "dates": {
                     "date_histogram": {
                         "field": "timestamp",
-                        "interval": f"{interval}s",
+                        "fixed_interval": f"{interval}s",
                     },
                     "aggs": {"workers_count": {"cardinality": {"field": "worker"}}},
                }
diff --git a/server/config/default/apiserver.conf b/server/config/default/apiserver.conf
index e248c35..198c1fa 100644
--- a/server/config/default/apiserver.conf
+++ b/server/config/default/apiserver.conf
@@ -30,7 +30,7 @@
         enabled: false
         zip_files: ["/path/to/export.zip"]
         fail_on_error: false
-        artifacts_path: "/mnt/fileserver"
+        # artifacts_path: "/mnt/fileserver"
     }
 
     # time in seconds to take an exclusive lock to init es and mongodb
diff --git a/server/config/default/hosts.conf b/server/config/default/hosts.conf
index 17d9ab8..51aa77c 100644
--- a/server/config/default/hosts.conf
+++ b/server/config/default/hosts.conf
@@ -1,6 +1,6 @@
 elastic {
     events {
-        hosts: [{host: "127.0.0.1", port: 9200}]
+        hosts: [{host: "127.0.0.1", port: 9211}]
         args {
             timeout: 60
             dead_timeout: 10
@@ -11,7 +11,7 @@ elastic {
     }
 
     workers {
-        hosts: [{host:"127.0.0.1", port:9200}]
+        hosts: [{host:"127.0.0.1", port:9211}]
         args {
             timeout: 60
             dead_timeout: 10
diff --git a/server/config/default/services/auth.conf b/server/config/default/services/auth.conf
new file mode 100644
index 0000000..fe9c87e
--- /dev/null
+++ b/server/config/default/services/auth.conf
@@ -0,0 +1,16 @@
+fixed_users {
+    guest {
+        enabled: false
+
+        default_company: "025315a9321f49f8be07f5ac48fbcf92"
+
+        name: "Guest"
+        username: "guest"
+        password: "guest"
+
+        # Allow access only to the following endpoints when using user/pass credentials
+        allow_endpoints: [
+            "auth.login"
+        ]
+    }
+}
\ No newline at end of file
diff --git a/server/config/default/services/projects.conf b/server/config/default/services/projects.conf
new file mode 100644
index 0000000..27a1c85
--- /dev/null
+++ b/server/config/default/services/projects.conf
@@ -0,0 +1,8 @@
+# Order of featured projects, by name or ID
+featured_order: [
+    # {id: ""}
+    # OR
+    # {name: ""}
+    # OR
+    # {name_regex: ""}
+]
diff --git a/server/config/info.py b/server/config/info.py
index 529b77a..60439ad 100644
--- a/server/config/info.py
+++ b/server/config/info.py
@@ -41,3 +41,6 @@ def get_deployment_type() -> str:
 
 def get_default_company():
     return config.get("apiserver.default_company")
+
+
+missed_es_upgrade = False
diff --git a/server/database/model/auth.py b/server/database/model/auth.py
index 9dd0b39..406ae9d 100644
--- a/server/database/model/auth.py
+++ b/server/database/model/auth.py
@@ -32,6 +32,8 @@ class Role(object):
     """ Company user """
     annotator = "annotator"
     """ Annotator with limited access"""
+    guest = "guest"
+    """ Guest user. Read Only."""
 
     @classmethod
     def get_system_roles(cls) -> set:
diff --git a/server/database/model/base.py b/server/database/model/base.py
index 076a449..9075bf5 100644
--- a/server/database/model/base.py
+++ b/server/database/model/base.py
@@ -1,7 +1,7 @@
 import re
 from collections import namedtuple
 from functools import reduce
-from typing import Collection, Sequence, Union, Optional
+from typing import Collection, Sequence, Union, Optional, Type
 
 from boltons.iterutils import first, bucketize
 from dateutil.parser import parse as parse_datetime
@@ -9,6 +9,7 @@ from mongoengine import Q, Document, ListField, StringField
 from pymongo.command_cursor import CommandCursor
 
 from apierrors import errors
+from apierrors.base import BaseError
 from config import config
 from database.errors import MakeGetAllQueryError
 from database.projection import project_dict, ProjectionHelper
@@ -483,6 +484,21 @@ class GetMixin(PropsMixin):
             query=_query, parameters=parameters, override_projection=override_projection
         )
 
+    @classmethod
+    def get_many_public(
+        cls, query: Q = None, projection: Collection[str] = None,
+    ):
+        """
+        Fetch all public documents matching a provided query.
+        :param query: Optional query object (mongoengine.Q).
+        :param projection: A list of projection fields.
+        :return: A list of documents matching the query.
+        """
+        q = get_company_or_none_constraint()
+        _query = (q & query) if query else q
+
+        return cls._get_many_no_company(query=_query, override_projection=projection)
+
     @classmethod
     def _get_many_no_company(
         cls: Union["GetMixin", Document],
@@ -728,6 +744,31 @@ class DbModelMixin(GetMixin, ProperDictMixin, UpdateMixin):
         )
         return cls.objects.aggregate(pipeline, **kwargs)
 
+    @classmethod
+    def set_public(
+        cls: Type[Document],
+        company_id: str,
+        ids: Sequence[str],
+        invalid_cls: Type[BaseError],
+        enabled: bool = True,
+    ):
+        if enabled:
+            items = list(cls.objects(id__in=ids, company=company_id).only("id"))
+            update = dict(set__company_origin=company_id, unset__company=1)
+        else:
+            items = list(
+                cls.objects(
+                    id__in=ids, company__in=(None, ""), company_origin=company_id
+                ).only("id")
+            )
+            update = dict(set__company=company_id, unset__company_origin=1)
+
+        if len(items) < len(ids):
+            missing = tuple(set(ids).difference(i.id for i in items))
+            raise invalid_cls(ids=missing)
+
+        return {"updated": cls.objects(id__in=ids).update(**update)}
+
 
 def validate_id(cls, company, **kwargs):
     """
diff --git a/server/database/model/model.py b/server/database/model/model.py
index b777efd..8de275c 100644
--- a/server/database/model/model.py
+++ b/server/database/model/model.py
@@ -72,3 +72,4 @@ class Model(DbModelMixin, Document):
     ui_cache = SafeDictField(
         default=dict, user_set_allowed=True, exclude_by_default=True
     )
+    company_origin = StringField(exclude_by_default=True)
diff --git a/server/database/model/project.py b/server/database/model/project.py
index bde016f..6165811 100644
--- a/server/database/model/project.py
+++ b/server/database/model/project.py
@@ -1,4 +1,4 @@
-from mongoengine import StringField, DateTimeField
+from mongoengine import StringField, DateTimeField, IntField
 
 from database import Database, strict
 from database.fields import StrippedStringField, SafeSortedListField
@@ -40,3 +40,7 @@ class Project(AttributedDocument):
     system_tags = SafeSortedListField(StringField(required=True))
     default_output_destination = StrippedStringField()
     last_update = DateTimeField()
+    featured = IntField(default=9999)
+    logo_url = StringField()
+    logo_blob = StringField(exclude_by_default=True)
+    company_origin = StringField(exclude_by_default=True)
diff --git a/server/database/model/task/task.py b/server/database/model/task/task.py
index 2ee57f9..0ee4714 100644
--- a/server/database/model/task/task.py
+++ b/server/database/model/task/task.py
@@ -118,7 +118,7 @@ external_task_types = set(get_options(TaskType))
 class Task(AttributedDocument):
     _field_collation_overrides = {
         "execution.parameters.": {"locale": "en_US", "numericOrdering": True},
-        "last_metrics.": {"locale": "en_US", "numericOrdering": True}
+        "last_metrics.": {"locale": "en_US", "numericOrdering": True},
     }
 
     meta = {
@@ -194,3 +194,5 @@ class Task(AttributedDocument):
     last_iteration = IntField(default=DEFAULT_LAST_ITERATION)
     last_metrics = SafeMapField(field=SafeMapField(EmbeddedDocumentField(MetricEvent)))
     metric_stats = SafeMapField(field=EmbeddedDocumentField(MetricEventStats))
+    company_origin = StringField(exclude_by_default=True)
+    duration = IntField()  # task duration in seconds
diff --git a/server/elastic/apply_mappings.py b/server/elastic/apply_mappings.py
index 65b8112..3590e3c 100755
--- a/server/elastic/apply_mappings.py
+++ b/server/elastic/apply_mappings.py
@@ -4,9 +4,9 @@ Apply elasticsearch mappings to given hosts.
""" import argparse import json -import requests from pathlib import Path +import requests from requests.adapters import HTTPAdapter from requests.packages.urllib3.util.retry import Retry @@ -14,21 +14,24 @@ HERE = Path(__file__).resolve().parent session = requests.Session() adapter = HTTPAdapter(max_retries=Retry(5, backoff_factor=0.5)) -session.mount('http://', adapter) +session.mount("http://", adapter) + + +def get_template(host: str, template) -> dict: + url = f"{host}/_template/{template}" + res = session.get(url) + return res.json() def apply_mappings_to_host(host: str): def _send_mapping(f): with f.open() as json_data: data = json.load(json_data) - es_server = host - url = f"{es_server}/_template/{f.stem}" + url = f"{host}/_template/{f.stem}" session.delete(url) r = session.post( - url, - headers={"Content-Type": "application/json"}, - data=json.dumps(data), + url, headers={"Content-Type": "application/json"}, data=json.dumps(data) ) return {"mapping": f.stem, "result": r.text} @@ -47,7 +50,8 @@ def parse_args(): def main(): - for host in parse_args().hosts: + args = parse_args() + for host in args.hosts: print(">>>>> Applying mapping to " + host) res = apply_mappings_to_host(host) print(res) diff --git a/server/elastic/initialize.py b/server/elastic/initialize.py index bfa51bf..5f4c63a 100644 --- a/server/elastic/initialize.py +++ b/server/elastic/initialize.py @@ -1,7 +1,7 @@ from furl import furl from config import config -from elastic.apply_mappings import apply_mappings_to_host +from elastic.apply_mappings import apply_mappings_to_host, get_template from es_factory import get_cluster_config log = config.logger(__file__) @@ -15,13 +15,22 @@ class MissingElasticConfiguration(Exception): pass -def init_es_data(): +def _url_from_host_conf(conf: dict) -> str: + return furl(scheme="http", host=conf["host"], port=conf["port"]).url + + +def init_es_data() -> bool: + """Return True if the db was empty""" hosts_config = get_cluster_config("events").get("hosts") if not hosts_config: raise MissingElasticConfiguration("for cluster 'events'") + empty_db = not get_template(_url_from_host_conf(hosts_config[0]), "events*") + for conf in hosts_config: - host = furl(scheme="http", host=conf["host"], port=conf["port"]).url + host = _url_from_host_conf(conf) log.info(f"Applying mappings to host: {host}") res = apply_mappings_to_host(host) log.info(res) + + return empty_db diff --git a/server/elastic/mappings/events.json b/server/elastic/mappings/events.json index 74708a1..eb33863 100644 --- a/server/elastic/mappings/events.json +++ b/server/elastic/mappings/events.json @@ -1,26 +1,39 @@ { - "template": "events-*", + "index_patterns": "events-*", "settings": { "number_of_shards": 1 }, "mappings": { - "_default_": { - "_source": { - "enabled": true + "_source": { + "enabled": true + }, + "properties": { + "@timestamp": { + "type": "date" }, - "_routing": { - "required": true + "task": { + "type": "keyword" }, - "properties": { - "@timestamp": { "type": "date" }, - "task": { "type": "keyword" }, - "type": { "type": "keyword" }, - "worker": { "type": "keyword" }, - "timestamp": { "type": "date" }, - "iter": { "type": "long" }, - "metric": { "type": "keyword" }, - "variant": { "type": "keyword" }, - "value": { "type": "float" } + "type": { + "type": "keyword" + }, + "worker": { + "type": "keyword" + }, + "timestamp": { + "type": "date" + }, + "iter": { + "type": "long" + }, + "metric": { + "type": "keyword" + }, + "variant": { + "type": "keyword" + }, + "value": { + "type": "float" } } } diff --git 
diff --git a/server/elastic/mappings/events_log.json b/server/elastic/mappings/events_log.json
index 62a7051..07565b5 100644
--- a/server/elastic/mappings/events_log.json
+++ b/server/elastic/mappings/events_log.json
@@ -1,11 +1,14 @@
 {
-  "template": "events-log-*",
-  "order" : 1,
+  "index_patterns": "events-log-*",
+  "order": 1,
   "mappings": {
-    "_default_": {
-      "properties": {
-        "msg": { "type":"text", "index": false },
-        "level": { "type":"keyword" }
+    "properties": {
+      "msg": {
+        "type": "text",
+        "index": false
+      },
+      "level": {
+        "type": "keyword"
      }
    }
  }
diff --git a/server/elastic/mappings/events_plot.json b/server/elastic/mappings/events_plot.json
index f5d607d..260700a 100644
--- a/server/elastic/mappings/events_plot.json
+++ b/server/elastic/mappings/events_plot.json
@@ -1,10 +1,11 @@
 {
-  "template": "events-plot-*",
-  "order" : 1,
+  "index_patterns": "events-plot-*",
+  "order": 1,
   "mappings": {
-    "_default_": {
-      "properties": {
-        "plot_str": { "type":"text", "index": false }
+    "properties": {
+      "plot_str": {
+        "type": "text",
+        "index": false
      }
    }
  }
diff --git a/server/elastic/mappings/events_training_debug_image.json b/server/elastic/mappings/events_training_debug_image.json
index 146fb62..2c0d1e7 100644
--- a/server/elastic/mappings/events_training_debug_image.json
+++ b/server/elastic/mappings/events_training_debug_image.json
@@ -1,11 +1,13 @@
 {
-  "template": "events-training_debug_image-*",
-  "order" : 1,
+  "index_patterns": "events-training_debug_image-*",
+  "order": 1,
   "mappings": {
-    "_default_": {
-      "properties": {
-        "key": { "type": "keyword" },
-        "url": { "type": "keyword" }
+    "properties": {
+      "key": {
+        "type": "keyword"
+      },
+      "url": {
+        "type": "keyword"
      }
    }
  }
diff --git a/server/elastic/mappings/queue_metrics.json b/server/elastic/mappings/queue_metrics.json
index 9f69251..0506c65 100644
--- a/server/elastic/mappings/queue_metrics.json
+++ b/server/elastic/mappings/queue_metrics.json
@@ -1,26 +1,24 @@
 {
-  "template": "queue_metrics_*",
+  "index_patterns": "queue_metrics_*",
   "settings": {
     "number_of_shards": 1
   },
   "mappings": {
-    "metrics": {
-      "_source": {
-        "enabled": true
+    "_source": {
+      "enabled": true
+    },
+    "properties": {
+      "timestamp": {
+        "type": "date"
       },
-      "properties": {
-        "timestamp": {
-          "type": "date"
-        },
-        "queue": {
-          "type": "keyword"
-        },
-        "average_waiting_time": {
-          "type": "float"
-        },
-        "queue_length": {
-          "type": "integer"
-        }
+      "queue": {
+        "type": "keyword"
+      },
+      "average_waiting_time": {
+        "type": "float"
+      },
+      "queue_length": {
+        "type": "integer"
      }
    }
  }
diff --git a/server/elastic/mappings/worker_stats.json b/server/elastic/mappings/worker_stats.json
index 3c2437c..6e11a3d 100644
--- a/server/elastic/mappings/worker_stats.json
+++ b/server/elastic/mappings/worker_stats.json
@@ -1,22 +1,36 @@
 {
-  "template": "worker_stats_*",
+  "index_patterns": "worker_stats_*",
   "settings": {
     "number_of_shards": 1
  },
  "mappings": {
-    "stat": {
-      "_source": {
-        "enabled": true
+    "_source": {
+      "enabled": true
+    },
+    "properties": {
+      "timestamp": {
+        "type": "date"
      },
-      "properties": {
-        "timestamp": { "type": "date" },
-        "worker": { "type": "keyword" },
-        "category": { "type": "keyword" },
-        "metric": { "type": "keyword" },
-        "variant": { "type": "keyword" },
-        "value": { "type": "float" },
-        "unit": { "type": "keyword" },
-        "task": { "type": "keyword" }
+      "worker": {
+        "type": "keyword"
+      },
+      "category": {
+        "type": "keyword"
+      },
+      "metric": {
+        "type": "keyword"
+      },
+      "variant": {
+        "type": "keyword"
+      },
+      "value": {
+        "type": "float"
+      },
+      "unit": {
"type": "keyword" + }, + "task": { + "type": "keyword" } } } diff --git a/server/mongo/initialize/__init__.py b/server/mongo/initialize/__init__.py index d5bd9ec..6607c33 100644 --- a/server/mongo/initialize/__init__.py +++ b/server/mongo/initialize/__init__.py @@ -24,14 +24,9 @@ def _pre_populate(company_id: str, zip_file: str): else: log.info(f"Pre-populating using {zip_file}") - user_id = _ensure_backend_user( - "__allegroai__", company_id, "Allegro.ai" - ) - PrePopulate.import_from_zip( zip_file, company_id="", - user_id=user_id, artifacts_path=config.get( "apiserver.pre_populate.artifacts_path", None ), @@ -60,7 +55,7 @@ def init_mongo_data() -> bool: _ensure_uuid() - company_id = _ensure_company(log) + company_id = _ensure_company(get_default_company(), "trains", log) _ensure_default_queue(company_id) @@ -82,9 +77,13 @@ def init_mongo_data() -> bool: if fixed_mode: log.info("Fixed users mode is enabled") FixedUser.validate() + + if FixedUser.guest_enabled(): + _ensure_company(FixedUser.get_guest_user().company, "guests", log) + for user in FixedUser.from_config(): try: - ensure_fixed_user(user, company_id, log=log) + ensure_fixed_user(user, log=log) except Exception as ex: log.error(f"Failed creating fixed user {user.name}: {ex}") diff --git a/server/mongo/initialize/pre_populate.py b/server/mongo/initialize/pre_populate.py index e69ed44..0f93c7a 100644 --- a/server/mongo/initialize/pre_populate.py +++ b/server/mongo/initialize/pre_populate.py @@ -1,30 +1,44 @@ import hashlib import importlib import os +import re from collections import defaultdict from datetime import datetime, timezone +from functools import partial from io import BytesIO from itertools import chain from operator import attrgetter from os.path import splitext from pathlib import Path -from typing import Optional, Any, Type, Set, Dict, Sequence, Tuple, BinaryIO, Union +from typing import ( + Optional, + Any, + Type, + Set, + Dict, + Sequence, + Tuple, + BinaryIO, + Union, + Mapping, +) from urllib.parse import unquote, urlparse from zipfile import ZipFile, ZIP_BZIP2 -import attr import mongoengine from boltons.iterutils import chunked_iter from furl import furl from mongoengine import Q from bll.event import EventBLL +from config import config from database.model import EntityVisibility from database.model.model import Model from database.model.project import Project from database.model.task.task import Task, ArtifactModes, TaskStatus from database.utils import get_options from utilities import json +from .user import _ensure_backend_user class PrePopulate: @@ -32,6 +46,7 @@ class PrePopulate: events_file_suffix = "_events" export_tag_prefix = "Exported:" export_tag = f"{export_tag_prefix} %Y-%m-%d %H:%M:%S" + metadata_filename = "metadata.json" class JsonLinesWriter: def __init__(self, file: BinaryIO): @@ -54,26 +69,21 @@ class PrePopulate: self._write("\n" + line) self.empty = False - @attr.s(auto_attribs=True) - class _MapData: - files: Sequence[str] = None - entities: Dict[str, datetime] = None - @staticmethod def _get_last_update_time(entity) -> datetime: return getattr(entity, "last_update", None) or getattr(entity, "created") @classmethod def _check_for_update( - cls, map_file: Path, entities: dict + cls, map_file: Path, entities: dict, metadata_hash: str ) -> Tuple[bool, Sequence[str]]: if not map_file.is_file(): return True, [] files = [] try: - map_data = cls._MapData(**json.loads(map_file.read_text())) - files = map_data.files + map_data = json.loads(map_file.read_text()) + files = map_data.get("files", 
             for file in files:
                 if not Path(file).is_file():
                     return True, files
@@ -82,7 +92,7 @@ class PrePopulate:
                 item.id: cls._get_last_update_time(item).replace(tzinfo=timezone.utc)
                 for item in chain.from_iterable(entities.values())
             }
-            old_times = map_data.entities
+            old_times = map_data.get("entities", {})
 
             if set(new_times.keys()) != set(old_times.keys()):
                 return True, files
@@ -90,6 +100,10 @@ class PrePopulate:
             for id_, new_timestamp in new_times.items():
                 if new_timestamp != old_times[id_]:
                     return True, files
+
+            if metadata_hash != map_data.get("metadata_hash", ""):
+                return True, files
+
         except Exception as ex:
             print("Error reading map file. " + str(ex))
             return True, files
@@ -98,16 +112,24 @@ class PrePopulate:
 
     @classmethod
     def _write_update_file(
-        cls, map_file: Path, entities: dict, created_files: Sequence[str]
+        cls,
+        map_file: Path,
+        entities: dict,
+        created_files: Sequence[str],
+        metadata_hash: str,
     ):
-        map_data = cls._MapData(
-            files=created_files,
-            entities={
-                entity.id: cls._get_last_update_time(entity)
-                for entity in chain.from_iterable(entities.values())
-            },
+        map_file.write_text(
+            json.dumps(
+                dict(
+                    files=created_files,
+                    entities={
+                        entity.id: cls._get_last_update_time(entity)
+                        for entity in chain.from_iterable(entities.values())
+                    },
+                    metadata_hash=metadata_hash,
+                )
+            )
         )
-        map_file.write_text(json.dumps(attr.asdict(map_data)))
 
     @staticmethod
     def _filter_artifacts(artifacts: Sequence[str]) -> Sequence[str]:
@@ -117,7 +139,9 @@ class PrePopulate:
                 return True
             if a.startswith("http"):
                 parsed = urlparse(a)
-                if parsed.scheme in {"http", "https"} and parsed.port == 8081:
+                if parsed.scheme in {"http", "https"} and parsed.netloc.endswith(
+                    "8081"
+                ):
                     return True
             return False
 
@@ -137,6 +161,7 @@ class PrePopulate:
         artifacts_path: str = None,
         task_statuses: Sequence[str] = None,
         tag_exported_entities: bool = False,
+        metadata: Mapping[str, Any] = None,
     ) -> Sequence[str]:
         if task_statuses and not set(task_statuses).issubset(get_options(TaskStatus)):
             raise ValueError("Invalid task statuses")
@@ -146,11 +171,22 @@ class PrePopulate:
             experiments=experiments, projects=projects, task_statuses=task_statuses
         )
 
+        hash_ = hashlib.md5()
+        if metadata:
+            meta_str = json.dumps(metadata)
+            hash_.update(meta_str.encode())
+            metadata_hash = hash_.hexdigest()
+        else:
+            meta_str, metadata_hash = "", ""
+
         map_file = file.with_suffix(".map")
-        updated, old_files = cls._check_for_update(map_file, entities)
+        updated, old_files = cls._check_for_update(
+            map_file, entities=entities, metadata_hash=metadata_hash
+        )
         if not updated:
             print(f"There are no updates from the last export")
             return old_files
+
         for old in old_files:
             old_path = Path(old)
             if old_path.is_file():
@@ -158,10 +194,16 @@ class PrePopulate:
 
         zip_args = dict(mode="w", compression=ZIP_BZIP2)
         with ZipFile(file, **zip_args) as zfile:
-            artifacts, hash_ = cls._export(
-                zfile, entities, tag_entities=tag_exported_entities
+            if metadata:
+                zfile.writestr(cls.metadata_filename, meta_str)
+            artifacts = cls._export(
+                zfile,
+                entities=entities,
+                hash_=hash_,
+                tag_entities=tag_exported_entities,
             )
-        file_with_hash = file.with_name(f"{file.stem}_{hash_}{file.suffix}")
+
+        file_with_hash = file.with_name(f"{file.stem}_{hash_.hexdigest()}{file.suffix}")
         file.replace(file_with_hash)
         created_files = [str(file_with_hash)]
 
@@ -172,16 +214,43 @@ class PrePopulate:
                 cls._export_artifacts(zfile, artifacts, artifacts_path)
             created_files.append(str(artifacts_file))
 
-        cls._write_update_file(map_file, entities, created_files)
cls._write_update_file( + map_file, + entities=entities, + created_files=created_files, + metadata_hash=metadata_hash, + ) return created_files @classmethod def import_from_zip( - cls, filename: str, company_id: str, user_id: str, artifacts_path: str + cls, + filename: str, + company_id: str, + artifacts_path: str, + user_id: str = "", + user_name: str = "", ): + metadata = None + with ZipFile(filename) as zfile: - cls._import(zfile, company_id, user_id) + try: + with zfile.open(cls.metadata_filename) as f: + metadata = json.loads(f.read()) + if not user_id: + meta_user_id = metadata.get("user_id", "") + meta_user_name = metadata.get("user_name", "") + user_id, user_name = meta_user_id, meta_user_name + except Exception: + pass + + if not user_id: + user_id, user_name = "__allegroai__", "Allegro.ai" + + user_id = _ensure_backend_user(user_id, company_id, user_name) + + cls._import(zfile, company_id, user_id, metadata) if artifacts_path and os.path.isdir(artifacts_path): artifacts_file = Path(filename).with_suffix(".artifacts") @@ -190,6 +259,24 @@ class PrePopulate: with ZipFile(artifacts_file) as zfile: zfile.extractall(artifacts_path) + @classmethod + def update_featured_projects_order(cls): + featured_order = config.get("services.projects.featured_order", []) + + def get_index(p: Project): + for index, entry in enumerate(featured_order): + if ( + entry.get("id", None) == p.id + or entry.get("name", None) == p.name + or ("name_regex" in entry and re.match(entry["name_regex"], p.name)) + ): + return index + return 999 + + for project in Project.get_many_public(projection=["id", "name"]): + featured_index = get_index(project) + Project.objects(id=project.id).update(featured=featured_index) + @staticmethod def _resolve_type( cls: Type[mongoengine.Document], ids: Optional[Sequence[str]] @@ -389,15 +476,14 @@ class PrePopulate: @classmethod def _export( - cls, writer: ZipFile, entities: dict, tag_entities: bool = False - ) -> Tuple[Sequence[str], str]: + cls, writer: ZipFile, entities: dict, hash_, tag_entities: bool = False + ) -> Sequence[str]: """ Export the requested experiments, projects and models and return the list of artifact files Always do the export on sorted items since the order of items influence hash """ artifacts = [] now = datetime.utcnow() - hash_ = hashlib.md5() for cls_ in sorted(entities, key=attrgetter("__name__")): items = sorted(entities[cls_], key=attrgetter("id")) if not items: @@ -423,7 +509,7 @@ class PrePopulate: if tag_entities: cls._add_tag(items, now.strftime(cls.export_tag)) - return artifacts, hash_.hexdigest() + return artifacts @staticmethod def json_lines(file: BinaryIO): @@ -441,7 +527,13 @@ class PrePopulate: yield clean @classmethod - def _import(cls, reader: ZipFile, company_id: str = "", user_id: str = None): + def _import( + cls, + reader: ZipFile, + company_id: str = "", + user_id: str = None, + metadata: Mapping[str, Any] = None, + ): """ Import entities and events from the zip file Start from entities since event import will require the tasks already in DB @@ -451,12 +543,13 @@ class PrePopulate: fi for fi in reader.filelist if not fi.orig_filename.endswith(event_file_ending) + and fi.orig_filename != cls.metadata_filename ) event_files = ( fi for fi in reader.filelist if fi.orig_filename.endswith(event_file_ending) ) for files, reader_func in ( - (entity_files, cls._import_entity), + (entity_files, partial(cls._import_entity, metadata=metadata or {})), (event_files, cls._import_events), ): for file_info in files: @@ -466,11 +559,20 @@ class 
PrePopulate: reader_func(f, full_name, company_id, user_id) @classmethod - def _import_entity(cls, f: BinaryIO, full_name: str, company_id: str, user_id: str): + def _import_entity( + cls, + f: BinaryIO, + full_name: str, + company_id: str, + user_id: str, + metadata: Mapping[str, Any], + ): module_name, _, class_name = full_name.rpartition(".") module = importlib.import_module(module_name) cls_: Type[mongoengine.Document] = getattr(module, class_name) print(f"Writing {cls_.__name__.lower()}s into database") + + override_project_count = 0 for item in cls.json_lines(f): doc = cls_.from_json(item, created=True) if hasattr(doc, "user"): @@ -478,10 +580,24 @@ class PrePopulate: if hasattr(doc, "company"): doc.company = company_id if isinstance(doc, Project): + override_project_name = metadata.get("project_name", None) + if override_project_name: + if override_project_count: + override_project_name = ( + f"{override_project_name} {override_project_count + 1}" + ) + override_project_count += 1 + doc.name = override_project_name + + doc.logo_url = metadata.get("logo_url", None) + doc.logo_blob = metadata.get("logo_blob", None) + cls_.objects(company=company_id, name=doc.name, id__ne=doc.id).update( set__name=f"{doc.name}_{datetime.utcnow().strftime('%Y-%m-%d_%H-%M-%S')}" ) + doc.save() + if isinstance(doc, Task): cls.event_bll.delete_task_events(company_id, doc.id, allow_locked=True) diff --git a/server/mongo/initialize/user.py b/server/mongo/initialize/user.py index a42647d..6861108 100644 --- a/server/mongo/initialize/user.py +++ b/server/mongo/initialize/user.py @@ -58,15 +58,15 @@ def _ensure_backend_user(user_id: str, company_id: str, user_name: str): return user_id -def ensure_fixed_user(user: FixedUser, company_id: str, log: Logger): - if User.objects(id=user.user_id).first(): +def ensure_fixed_user(user: FixedUser, log: Logger): + if User.objects(company=user.company, id=user.user_id).first(): return data = attr.asdict(user) data["id"] = user.user_id data["email"] = f"{user.user_id}@example.com" - data["role"] = Role.user + data["role"] = Role.guest if user.is_guest else Role.user - _ensure_auth_user(user_data=data, company_id=company_id, log=log) + _ensure_auth_user(user_data=data, company_id=user.company, log=log) - return _ensure_backend_user(user.user_id, company_id, user.name) + return _ensure_backend_user(user.user_id, user.company, user.name) diff --git a/server/mongo/initialize/util.py b/server/mongo/initialize/util.py index 087176d..e90c8fc 100644 --- a/server/mongo/initialize/util.py +++ b/server/mongo/initialize/util.py @@ -3,7 +3,6 @@ from uuid import uuid4 from bll.queue import QueueBLL from config import config -from config.info import get_default_company from database.model.company import Company from database.model.queue import Queue from database.model.settings import Settings, SettingKeys @@ -11,13 +10,11 @@ from database.model.settings import Settings, SettingKeys log = config.logger(__file__) -def _ensure_company(log: Logger): - company_id = get_default_company() +def _ensure_company(company_id, company_name, log: Logger): company = Company.objects(id=company_id).only("id").first() if company: return company_id - company_name = "trains" log.info(f"Creating company: {company_name}") company = Company(id=company_id, name=company_name) company.save() diff --git a/server/requirements.txt b/server/requirements.txt index ef479c9..95f804c 100644 --- a/server/requirements.txt +++ b/server/requirements.txt @@ -1,7 +1,8 @@ attrs>=19.1.0 boltons>=19.1.0 +boto3==1.14.13 
dpath>=1.4.2,<2.0
-elasticsearch>=5.0.0,<6.0.0
+elasticsearch>=7.0.0,<8.0.0
fastjsonschema>=2.8
Flask-Compress>=1.4.0
Flask-Cors>=3.0.5
@@ -24,7 +25,7 @@ python-rapidjson>=0.6.3
 redis>=2.10.5
 related>=0.7.2
 requests>=2.13.0
-semantic_version>=2.8.0,<3
+semantic_version>=2.8.3,<3
 six
 tqdm
 validators>=0.12.4
\ No newline at end of file
diff --git a/server/schema/services/auth.conf b/server/schema/services/auth.conf
index 46fc509..3755dfb 100644
--- a/server/schema/services/auth.conf
+++ b/server/schema/services/auth.conf
@@ -328,6 +328,9 @@ fixed_users_mode {
                 description: "Fixed users mode enabled"
                 type: boolean
             }
+            migration_warning {
+                description: "Indicates that the ES database appears not to be migrated and a warning should be displayed"
+                type: boolean
+            }
         }
     }
 }
diff --git a/server/schema/services/events.conf b/server/schema/services/events.conf
index 34cb11a..a073cd4 100644
--- a/server/schema/services/events.conf
+++ b/server/schema/services/events.conf
@@ -848,7 +848,7 @@
                 description: "Task ID"
             }
             samples {
-                description: "The amount of histogram points to return (0 to return all the points). Optional, the default value is 10000."
+                description: "The number of histogram points to return (between 1 and 6000). Optional, the default value is 6000."
                 type: integer
             }
             key {
@@ -886,7 +886,7 @@
             ]
             properties {
                 tasks {
-                    description: "List of task Task IDs"
+                    description: "List of task IDs. The maximum number of tasks is 10"
                     type: array
                     items {
                         type: string
                     }
                 }
                 samples {
-                    description: "The amount of histogram points to return (0 to return all the points). Optional, the default value is 10000."
+                    description: "The number of histogram points to return (between 1 and 6000). Optional, the default value is 6000."
                     type: integer
                 }
                 key {
diff --git a/server/schema/services/models.conf b/server/schema/services/models.conf
index 1e23c2b..646c699 100644
--- a/server/schema/services/models.conf
+++ b/server/schema/services/models.conf
@@ -1,54 +1,366 @@
-{
-    _description: """This service provides a management interface for models (results of training tasks) stored in the system."""
-    _definitions {
-        multi_field_pattern_data {
-            type: object
-            properties {
-                pattern {
-                    description: "Pattern string (regex)"
-                    type: string
-                }
-                fields {
-                    description: "List of field names"
-                    type: array
-                    items { type: string }
-                }
+_description: """This service provides a management interface for models (results of training tasks) stored in the system."""
+_definitions {
+    multi_field_pattern_data {
+        type: object
+        properties {
+            pattern {
+                description: "Pattern string (regex)"
+                type: string
+            }
+            fields {
+                description: "List of field names"
+                type: array
+                items { type: string }
             }
         }
-        model {
+    }
+    model {
         type: object
+        properties {
+            id {
+                description: "Model id"
+                type: string
+            }
+            name {
+                description: "Model name"
+                type: string
+            }
+            user {
+                description: "Associated user id"
+                type: string
+            }
+            company {
+                description: "Company id"
+                type: string
+            }
+            created {
+                description: "Model creation time"
+                type: string
+                format: "date-time"
+            }
+            task {
+                description: "ID of the task in which the model was created"
+                type: string
+            }
+            parent {
+                description: "Parent model ID"
+                type: string
+            }
+            project {
+                description: "Associated project ID"
+                type: string
+            }
+            comment {
+                description: "Model comment"
+                type: string
+            }
+            tags {
+                type: array
+                description: "User-defined tags"
+                items { type: string }
+            }
+            system_tags {
+                type: array
+                description: "System tags. This field is reserved for system use, please don't use it."
+ items {type: string} + } + framework { + description: "Framework on which the model is based. Should be identical to the framework of the task which created the model" + type: string + } + design { + description: "Json object representing the model design. Should be identical to the network design of the task which created the model" + type: object + additionalProperties: true + } + labels { + description: "Json object representing the ids of the labels in the model. The keys are the layers' names and the values are the ids." + type: object + additionalProperties { type: integer } + } + uri { + description: "URI for the model, pointing to the destination storage." + type: string + } + ready { + description: "Indication if the model is final and can be used by other tasks" + type: boolean + } + ui_cache { + description: "UI cache for this model" + type: object + additionalProperties: true + } + } + } +} + +get_by_id { + "2.1" { + description: "Gets model information" + request { type: object + required: [ model ] properties { - id { + model { description: "Model id" type: string } + } + } + + response { + type: object + properties { + model { + description: "Model info" + "$ref": "#/definitions/model" + } + } + } + } +} + +get_by_task_id { + "2.1" { + description: "Gets model information" + request { + type: object + properties { + task { + description: "Task id" + type: string + } + } + } + response { + type: object + properties { + model { + description: "Model info" + "$ref": "#/definitions/model" + } + } + } + } +} +get_all_ex { + internal: true + "2.1": ${get_all."2.1"} +} +get_all { + "2.1" { + description: "Get all models" + request { + type: object + properties { name { - description: "Model name" + description: "Get only models whose name matches this pattern (python regular expression syntax)" type: string } user { - description: "Associated user id" + description: "List of user IDs used to filter results by the model's creating user" + type: array + items { type: string } + } + ready { + description: "Indication whether to retrieve only models that are marked ready If not supplied returns both ready and not-ready projects." + type: boolean + } + tags { + description: "User-defined tags list used to filter results. Prepend '-' to tag name to indicate exclusion" + type: array + items { type: string } + } + system_tags { + description: "System tags list used to filter results. Prepend '-' to system tag name to indicate exclusion" + type: array + items { type: string } + } + only_fields { + description: "List of model field names (if applicable, nesting is supported using '.'). If provided, this list defines the query's projection (only these fields will be returned for each result entry)" + type: array + items { type: string } + } + page { + description: "Page number, returns a specific page out of the resulting list of models" + type: integer + minimum: 0 + } + page_size { + description: "Page size, specifies the number of results returned in each page (last page may contain fewer results)" + type: integer + minimum: 1 + } + project { + description: "List of associated project IDs" + type: array + items { type: string } + } + order_by { + description: "List of field names to order by. When search_text is used, '@text_score' can be used as a field representing the text score of returned documents. Use '-' prefix to specify descending order. 
Optional, recommended when using page" + type: array + items { type: string } + } + task { + description: "List of associated task IDs" + type: array + items { type: string } + } + id { + description: "List of model IDs" + type: array + items { type: string } + } + search_text { + description: "Free text search query" type: string } - company { - description: "Company id" + framework { + description: "List of frameworks" + type: array + items { type: string } + } + uri { + description: "List of model URIs" + type: array + items { type: string } + } + _all_ { + description: "Multi-field pattern condition (all fields match pattern)" + "$ref": "#/definitions/multi_field_pattern_data" + } + _any_ { + description: "Multi-field pattern condition (any field matches pattern)" + "$ref": "#/definitions/multi_field_pattern_data" + } + } + dependencies { + page: [ page_size ] + } + } + response { + type: object + properties { + models: { + description: "Models list" + type: array + items { "$ref": "#/definitions/model" } + } + } + } + } +} +get_frameworks { + "2.8" { + description: "Get the list of frameworks used in the company models" + request { + type: object + properties { + projects { + description: "The list of projects which models will be analyzed. If not passed or empty then all the company and public models will be analyzed" + type: array + items: {type: string} + } + } + } + response { + type: object + properties { + frameworks { + description: "Unique list of the frameworks used in the company models" + type: array + items: {type: string} + } + } + } + } +} +update_for_task { + "2.1" { + description: "Create or update a new model for a task" + request { + type: object + required: [ + task + ] + properties { + task { + description: "Task id" + type: string + } + uri { + description: "URI for the model. Exactly one of uri or override_model_id is a required." + type: string + } + name { + description: "Model name Unique within the company." + type: string + } + comment { + description: "Model comment" + type: string + } + tags { + type: array + description: "User-defined tags" + items { type: string } + } + system_tags { + type: array + description: "System tags. This field is reserved for system use, please don't use it." + items {type: string} + } + override_model_id { + description: "Override model ID. If provided, this model is updated in the task. Exactly one of override_model_id or uri is required." + type: string + } + iteration { + description: "Iteration (used to update task statistics)" + type: integer + } + } + } + response { + type: object + properties { + id { + description: "ID of the model" type: string } created { - description: "Model creation time" - type: string - format: "date-time" + description: "Was the model created" + type: boolean } - task { - description: "Task ID of task in which the model was created" + updated { + description: "Number of models updated (0 or 1)" + type: integer + } + fields { + description: "Updated fields names and values" + type: object + additionalProperties: true + } + } + } + } +} +create { + "2.1" { + description: "Create a new model not associated with a task" + request { + type: object + required: [ + uri + name + ] + properties { + uri { + description: "URI for the model" type: string } - parent { - description: "Parent model ID" - type: string - } - project { - description: "Associated project ID" + name { + description: "Model name Unique within the company." 
type: string } comment { @@ -66,595 +378,281 @@ items {type: string} } framework { - description: "Framework on which the model is based. Should be identical to the framework of the task which created the model" + description: "Framework on which the model is based. Case insensitive. Should be identical to the framework of the task which created the model." type: string } design { - description: "Json object representing the model design. Should be identical to the network design of the task which created the model" + description: "Json[d] object representing the model design. Should be identical to the network design of the task which created the model" type: object additionalProperties: true } labels { - description: "Json object representing the ids of the labels in the model. The keys are the layers' names and the values are the ids." + description: "Json object" type: object additionalProperties { type: integer } } - uri { - description: "URI for the model, pointing to the destination storage." + ready { + description: "Indication if the model is final and can be used by other tasks. Default is false." + type: boolean + default: false + } + public { + description: "Create a public model Default is false." + type: boolean + default: false + } + project { + description: "Project to which to model belongs" type: string } + parent { + description: "Parent model" + type: string + } + task { + description: "Associated task ID" + type: string + } + } + } + response { + type: object + properties { + id { + description: "ID of the model" + type: string + } + created { + description: "Was the model created" + type: boolean + } + } + } + } +} +edit { + "2.1" { + description: "Edit an existing model" + request { + type: object + required: [ + model + ] + properties { + model { + description: "Model ID" + type: string + } + uri { + description: "URI for the model" + type: string + } + name { + description: "Model name Unique within the company." + type: string + } + comment { + description: "Model comment" + type: string + } + tags { + type: array + description: "User-defined tags" + items { type: string } + } + system_tags { + type: array + description: "System tags. This field is reserved for system use, please don't use it." + items {type: string} + } + framework { + description: "Framework on which the model is based. Case insensitive. Should be identical to the framework of the task which created the model." + type: string + } + design { + description: "Json[d] object representing the model design. 
Should be identical to the network design of the task which created the model" + type: object + additionalProperties: true + } + labels { + description: "Json object" + type: object + additionalProperties { type: integer } + } ready { description: "Indication if the model is final and can be used by other tasks" type: boolean } + project { + description: "Project to which to model belongs" + type: string + } + parent { + description: "Parent model" + type: string + } + task { + description: "Associated task ID" + type: string + } + iteration { + description: "Iteration (used to update task statistics)" + type: integer + } + } + } + response { + type: object + properties { + updated { + description: "Number of models updated (0 or 1)" + type: integer + enum: [0, 1] + } + fields { + description: "Updated fields names and values" + type: object + additionalProperties: true + } + } + } + } +} +update { + "2.1" { + description: "Update a model" + request { + type: object + required: [ model ] + properties { + model { + description: "Model id" + type: string + } + name { + description: "Model name Unique within the company." + type: string + } + comment { + description: "Model comment" + type: string + } + tags { + type: array + description: "User-defined tags" + items { type: string } + } + system_tags { + type: array + description: "System tags. This field is reserved for system use, please don't use it." + items {type: string} + } + ready { + description: "Indication if the model is final and can be used by other tasks Default is false." + type: boolean + default: false + } + created { + description: "Model creation time (UTC) " + type: string + format: "date-time" + } ui_cache { description: "UI cache for this model" type: object additionalProperties: true } - } - } - } + project { + description: "Project to which to model belongs" + type: string + } + task { + description: "Associated task ID" + type: "string" + } + iteration { + description: "Iteration (used to update task statistics if an associated task is reported)" + type: integer + } - get_by_id { - "2.1" { - description: "Gets model information" - request { - type: object - required: [ model ] - properties { - model { - description: "Model id" - type: string - } - } } - - response { - type: object - properties { - model { - description: "Model info" - "$ref": "#/definitions/model" - } + } + response { + type: object + properties { + updated { + description: "Number of models updated (0 or 1)" + type: integer + enum: [0, 1] + } + fields { + description: "Updated fields names and values" + type: object + additionalProperties: true } } } } - - get_by_task_id { - "2.1" { - description: "Gets model information" - request { - type: object - properties { - task { - description: "Task id" - type: string - } +} +set_ready { + "2.1" { + description: "Set the model ready flag to True. If the model is an output model of a task then try to publish the task." + request { + type: object + required: [ model ] + properties { + model { + description: "Model id" + type: string } - } - response { - type: object - properties { - model { - description: "Model info" - "$ref": "#/definitions/model" - } + force_publish_task { + description: "Publish the associated task (if exists) even if it is not in the 'stopped' state. Optional, the default value is False." + type: boolean + } + publish_task { + description: "Indicates that the associated task (if exists) should be published. Optional, the default value is True." 
+ type: boolean } } } - } - get_all_ex { - internal: true - "2.1": ${get_all."2.1"} - } - get_all { - "2.1" { - description: "Get all models" - request { - type: object - properties { - name { - description: "Get only models whose name matches this pattern (python regular expression syntax)" - type: string - } - user { - description: "List of user IDs used to filter results by the model's creating user" - type: array - items { type: string } - } - ready { - description: "Indication whether to retrieve only models that are marked ready If not supplied returns both ready and not-ready projects." - type: boolean - } - tags { - description: "User-defined tags list used to filter results. Prepend '-' to tag name to indicate exclusion" - type: array - items { type: string } - } - system_tags { - description: "System tags list used to filter results. Prepend '-' to system tag name to indicate exclusion" - type: array - items { type: string } - } - only_fields { - description: "List of model field names (if applicable, nesting is supported using '.'). If provided, this list defines the query's projection (only these fields will be returned for each result entry)" - type: array - items { type: string } - } - page { - description: "Page number, returns a specific page out of the resulting list of models" - type: integer - minimum: 0 - } - page_size { - description: "Page size, specifies the number of results returned in each page (last page may contain fewer results)" - type: integer - minimum: 1 - } - project { - description: "List of associated project IDs" - type: array - items { type: string } - } - order_by { - description: "List of field names to order by. When search_text is used, '@text_score' can be used as a field representing the text score of returned documents. Use '-' prefix to specify descending order. Optional, recommended when using page" - type: array - items { type: string } - } - task { - description: "List of associated task IDs" - type: array - items { type: string } - } - id { - description: "List of model IDs" - type: array - items { type: string } - } - search_text { - description: "Free text search query" - type: string - } - framework { - description: "List of frameworks" - type: array - items { type: string } - } - uri { - description: "List of model URIs" - type: array - items { type: string } - } - _all_ { - description: "Multi-field pattern condition (all fields match pattern)" - "$ref": "#/definitions/multi_field_pattern_data" - } - _any_ { - description: "Multi-field pattern condition (any field matches pattern)" - "$ref": "#/definitions/multi_field_pattern_data" - } + response { + type: object + properties { + updated { + description: "Number of models updated (0 or 1)" + type: integer + enum: [0, 1] } - dependencies { - page: [ page_size ] - } - } - response { - type: object - properties { - models: { - description: "Models list" - type: array - items { "$ref": "#/definitions/model" } - } - } - } - } - } - get_frameworks { - "2.8" { - description: "Get the list of frameworks used in the company models" - request { - type: object - properties { - projects { - description: "The list of projects which models will be analyzed. 
If not passed or empty then all the company and public models will be analyzed" - type: array - items: {type: string} - } - } - } - response { - type: object - properties { - frameworks { - description: "Unique list of the frameworks used in the company models" - type: array - items: {type: string} - } - } - } - } - } - update_for_task { - "2.1" { - description: "Create or update a new model for a task" - request { - type: object - required: [ - task - ] - properties { - task { - description: "Task id" - type: string - } - uri { - description: "URI for the model. Exactly one of uri or override_model_id is a required." - type: string - } - name { - description: "Model name Unique within the company." - type: string - } - comment { - description: "Model comment" - type: string - } - tags { - type: array - description: "User-defined tags" - items { type: string } - } - system_tags { - type: array - description: "System tags. This field is reserved for system use, please don't use it." - items {type: string} - } - override_model_id { - description: "Override model ID. If provided, this model is updated in the task. Exactly one of override_model_id or uri is required." - type: string - } - iteration { - description: "Iteration (used to update task statistics)" - type: integer - } - } - } - response { - type: object - properties { - id { - description: "ID of the model" - type: string - } - created { - description: "Was the model created" - type: boolean - } - updated { - description: "Number of models updated (0 or 1)" - type: integer - } - fields { - description: "Updated fields names and values" - type: object - additionalProperties: true - } - } - } - } - } - create { - "2.1" { - description: "Create a new model not associated with a task" - request { - type: object - required: [ - uri - name - ] - properties { - uri { - description: "URI for the model" - type: string - } - name { - description: "Model name Unique within the company." - type: string - } - comment { - description: "Model comment" - type: string - } - tags { - type: array - description: "User-defined tags" - items { type: string } - } - system_tags { - type: array - description: "System tags. This field is reserved for system use, please don't use it." - items {type: string} - } - framework { - description: "Framework on which the model is based. Case insensitive. Should be identical to the framework of the task which created the model." - type: string - } - design { - description: "Json[d] object representing the model design. Should be identical to the network design of the task which created the model" - type: object - additionalProperties: true - } - labels { - description: "Json object" - type: object - additionalProperties { type: integer } - } - ready { - description: "Indication if the model is final and can be used by other tasks. Default is false." - type: boolean - default: false - } - public { - description: "Create a public model Default is false." 
- type: boolean - default: false - } - project { - description: "Project to which to model belongs" - type: string - } - parent { - description: "Parent model" - type: string - } - task { - description: "Associated task ID" - type: string - } - } - } - response { - type: object - properties { - id { - description: "ID of the model" - type: string - } - created { - description: "Was the model created" - type: boolean - } - } - } - } - } - edit { - "2.1" { - description: "Edit an existing model" - request { - type: object - required: [ - model - ] - properties { - model { - description: "Model ID" - type: string - } - uri { - description: "URI for the model" - type: string - } - name { - description: "Model name Unique within the company." - type: string - } - comment { - description: "Model comment" - type: string - } - tags { - type: array - description: "User-defined tags" - items { type: string } - } - system_tags { - type: array - description: "System tags. This field is reserved for system use, please don't use it." - items {type: string} - } - framework { - description: "Framework on which the model is based. Case insensitive. Should be identical to the framework of the task which created the model." - type: string - } - design { - description: "Json[d] object representing the model design. Should be identical to the network design of the task which created the model" - type: object - additionalProperties: true - } - labels { - description: "Json object" - type: object - additionalProperties { type: integer } - } - ready { - description: "Indication if the model is final and can be used by other tasks" - type: boolean - } - project { - description: "Project to which to model belongs" - type: string - } - parent { - description: "Parent model" - type: string - } - task { - description: "Associated task ID" - type: string - } - iteration { - description: "Iteration (used to update task statistics)" - type: integer - } - } - } - response { - type: object - properties { - updated { - description: "Number of models updated (0 or 1)" - type: integer - enum: [0, 1] - } - fields { - description: "Updated fields names and values" - type: object - additionalProperties: true - } - } - } - } - } - update { - "2.1" { - description: "Update a model" - request { - type: object - required: [ model ] - properties { - model { - description: "Model id" - type: string - } - name { - description: "Model name Unique within the company." - type: string - } - comment { - description: "Model comment" - type: string - } - tags { - type: array - description: "User-defined tags" - items { type: string } - } - system_tags { - type: array - description: "System tags. This field is reserved for system use, please don't use it." - items {type: string} - } - ready { - description: "Indication if the model is final and can be used by other tasks Default is false." 
- type: boolean - default: false - } - created { - description: "Model creation time (UTC) " - type: string - format: "date-time" - } - ui_cache { - description: "UI cache for this model" - type: object - additionalProperties: true - } - project { - description: "Project to which to model belongs" - type: string - } - task { - description: "Associated task ID" - type: "string" - } - iteration { - description: "Iteration (used to update task statistics if an associated task is reported)" - type: integer - } - - } - } - response { - type: object - properties { - updated { - description: "Number of models updated (0 or 1)" - type: integer - enum: [0, 1] - } - fields { - description: "Updated fields names and values" - type: object - additionalProperties: true - } - } - } - } - } - set_ready { - "2.1" { - description: "Set the model ready flag to True. If the model is an output model of a task then try to publish the task." - request { - type: object - required: [ model ] - properties { - model { - description: "Model id" - type: string - } - force_publish_task { - description: "Publish the associated task (if exists) even if it is not in the 'stopped' state. Optional, the default value is False." - type: boolean - } - publish_task { - description: "Indicates that the associated task (if exists) should be published. Optional, the default value is True." - type: boolean - } - } - } - response { - type: object - properties { - updated { - description: "Number of models updated (0 or 1)" - type: integer - enum: [0, 1] - } - published_task { - description: "Result of publishing of the model's associated task (if exists). Returned only if the task was published successfully as part of the model publishing." - type: object - properties { - id { - description: "Task id" - type: string - } - data { - description: "Data returned from the task publishing operation." - type: object - properties { - committed_versions_results { - description: "Committed versions results" - type: array - items { - type: object - additionalProperties: true - } - } - updated { - description: "Number of tasks updated (0 or 1)" - type: integer - enum: [ 0, 1 ] - } - fields { - description: "Updated fields names and values" + published_task { + description: "Result of publishing of the model's associated task (if exists). Returned only if the task was published successfully as part of the model publishing." + type: object + properties { + id { + description: "Task id" + type: string + } + data { + description: "Data returned from the task publishing operation." + type: object + properties { + committed_versions_results { + description: "Committed versions results" + type: array + items { type: object additionalProperties: true } } + updated { + description: "Number of tasks updated (0 or 1)" + type: integer + enum: [ 0, 1 ] + } + fields { + description: "Updated fields names and values" + type: object + additionalProperties: true + } } } } @@ -662,37 +660,87 @@ } } } - delete { - "2.1" { - description: "Delete a model." - request { - required: [ - model - ] - type: object - properties { - model { - description: "Model ID" - type: string - } - force { - description: """Force. Required if there are tasks that use the model as an execution model, or if the model's creating task is published. - """ - type: boolean - } +} +delete { + "2.1" { + description: "Delete a model." + request { + required: [ + model + ] + type: object + properties { + model { + description: "Model ID" + type: string + } + force { + description: """Force. 
Required if there are tasks that use the model as an execution model, or if the model's creating task is published. + """ + type: boolean + } + } + } + response { + type: object + properties { + deleted { + description: "Indicates whether the model was deleted" + type: boolean + } + + } + } + } +} + +make_public { + "2.9" { + description: """Convert company models to public""" + request { + type: object + properties { + ids { + description: "Ids of the models to convert" + type: array + items { type: string} } } - response { - type: object - properties { - deleted { - description: "Indicates whether the model was deleted" - type: boolean - } - + } + response { + type: object + properties { + updated { + description: "Number of models updated" + type: integer } } } } } + +make_private { + "2.9" { + description: """Convert public models to private""" + request { + type: object + properties { + ids { + description: "Ids of the models to convert. Only the models originated by the company can be converted" + type: array + items { type: string} + } + } + } + response { + type: object + properties { + updated { + description: "Number of models updated" + type: integer + } + } + } + } +} \ No newline at end of file diff --git a/server/schema/services/projects.conf b/server/schema/services/projects.conf index 5cd5038..e3cb4fb 100644 --- a/server/schema/services/projects.conf +++ b/server/schema/services/projects.conf @@ -573,6 +573,7 @@ get_hyper_parameters { } } } + get_task_tags { "2.8" { description: "Get user and system tags used for the tasks under the specified projects" @@ -580,10 +581,61 @@ get_task_tags { response = ${_definitions.tags_response} } } + get_model_tags { "2.8" { description: "Get user and system tags used for the models under the specified projects" request = ${_definitions.tags_request} response = ${_definitions.tags_response} } +} + +make_public { + "2.9" { + description: """Convert company projects to public""" + request { + type: object + properties { + ids { + description: "Ids of the projects to convert" + type: array + items { type: string} + } + } + } + response { + type: object + properties { + updated { + description: "Number of projects updated" + type: integer + } + } + } + } +} + +make_private { + "2.9" { + description: """Convert public projects to private""" + request { + type: object + properties { + ids { + description: "Ids of the projects to convert. Only the projects originated by the company can be converted" + type: array + items { type: string} + } + } + } + response { + type: object + properties { + updated { + description: "Number of projects updated" + type: integer + } + } + } + } } \ No newline at end of file diff --git a/server/schema/services/tasks.conf b/server/schema/services/tasks.conf index 2f0d10b..7b4c747 100644 --- a/server/schema/services/tasks.conf +++ b/server/schema/services/tasks.conf @@ -1441,4 +1441,54 @@ add_or_update_artifacts { } } } +} + +make_public { + "2.9" { + description: """Convert company tasks to public""" + request { + type: object + properties { + ids { + description: "Ids of the tasks to convert" + type: array + items { type: string} + } + } + } + response { + type: object + properties { + updated { + description: "Number of tasks updated" + type: integer + } + } + } + } +} + +make_private { + "2.9" { + description: """Convert public tasks to private""" + request { + type: object + properties { + ids { + description: "Ids of the tasks to convert. 
Only the tasks originated by the company can be converted"
+                    type: array
+                    items { type: string }
+                }
+            }
+        }
+        response {
+            type: object
+            properties {
+                updated {
+                    description: "Number of tasks updated"
+                    type: integer
+                }
+            }
+        }
+    }
+}
\ No newline at end of file
diff --git a/server/schema/services/workers.conf b/server/schema/services/workers.conf
index 1a18498..81ef288 100644
--- a/server/schema/services/workers.conf
+++ b/server/schema/services/workers.conf
@@ -135,6 +135,10 @@
             description: "Task currently being run by the worker"
             "$ref": "#/definitions/current_task_entry"
         }
+        project {
+            description: "Project in which the currently executing task resides"
+            "$ref": "#/definitions/id_name_entry"
+        }
         queue {
             description: "Queue from which running task was taken"
             "$ref": "#/definitions/queue_entry"
@@ -151,11 +155,11 @@
         type: object
         properties {
             id {
-                description: "Worker ID"
+                description: "ID"
                 type: string
             }
             name {
-                description: "Worker name"
+                description: "Name"
                 type: string
             }
         }
diff --git a/server/server.py b/server/server.py
index ec847f9..4458364 100644
--- a/server/server.py
+++ b/server/server.py
@@ -10,7 +10,7 @@ from werkzeug.exceptions import BadRequest
 import database
 from apierrors.base import BaseError
 from bll.statistics.stats_reporter import StatisticsReporter
-from config import config
+from config import config, info
 from elastic.initialize import init_es_data
 from mongo.initialize import init_mongo_data, pre_populate_data
 from service_repo import ServiceRepo, APICall
@@ -39,9 +39,11 @@ database.initialize()
 hosts_string = ";".join(sorted(database.get_hosts()))
 key = "db_init_" + md5(hosts_string.encode()).hexdigest()
 with distributed_lock(key, timeout=config.get("apiserver.db_init_timout", 30)):
-    print(key)
-    init_es_data()
+    empty_es = init_es_data()
     empty_db = init_mongo_data()

+if empty_es and not empty_db:
+    log.info("ES database does not appear to be migrated")
+    info.missed_es_upgrade = True
 if empty_db and config.get("apiserver.pre_populate.enabled", False):
     pre_populate_data()
diff --git a/server/service_repo/auth/auth.py b/server/service_repo/auth/auth.py
index 9401d46..d2c2483 100644
--- a/server/service_repo/auth/auth.py
+++ b/server/service_repo/auth/auth.py
@@ -69,6 +69,10 @@ def authorize_credentials(auth_data, service, action, call_data_items):
     if fixed_user:
         if secret_key != fixed_user.password:
             raise errors.unauthorized.InvalidCredentials('bad username or password')
+
+        if fixed_user.is_guest and not FixedUser.is_guest_endpoint(service, action):
+            raise errors.unauthorized.InvalidCredentials('endpoint not allowed for guest')
+
         query = Q(id=fixed_user.user_id)
         with TimingContext("mongo", "user_by_cred"), translate_errors_context('authorizing request'):
diff --git a/server/service_repo/auth/fixed_user.py b/server/service_repo/auth/fixed_user.py
index a4188ba..b814864 100644
--- a/server/service_repo/auth/fixed_user.py
+++ b/server/service_repo/auth/fixed_user.py
@@ -1,14 +1,12 @@
 import hashlib
 from functools import lru_cache
-from typing import Sequence, TypeVar
+from typing import Sequence, Optional

 import attr

 from config import config
 from config.info import get_default_company

-T = TypeVar("T", bound="FixedUser")

 class FixedUsersError(Exception):
     pass
@@ -21,6 +19,8 @@ class FixedUser:
     name: str
     company: str = get_default_company()

+    is_guest: bool = False
+
     def __attrs_post_init__(self):
         self.user_id = hashlib.md5(f"{self.company}:{self.username}".encode()).hexdigest()

@@ -28,6 +28,10 @@
     def enabled(cls):
         return config.get("apiserver.auth.fixed_users.enabled", False)

+    @classmethod
+    def guest_enabled(cls):
+        return cls.enabled() and config.get("services.auth.fixed_users.guest.enabled", False)
+
     @classmethod
     def validate(cls):
         if not cls.enabled():
@@ -39,18 +43,50 @@ class FixedUser:
             )

     @classmethod
-    @lru_cache()
-    def from_config(cls) -> Sequence[T]:
-        return [
+    # @lru_cache()
+    def from_config(cls) -> Sequence["FixedUser"]:
+        users = [
             cls(**user) for user in config.get("apiserver.auth.fixed_users.users", [])
         ]

+        if cls.guest_enabled():
+            users.insert(0, cls.get_guest_user())
+
+        return users
+
     @classmethod
     @lru_cache()
-    def get_by_username(cls, username) -> T:
+    def get_by_username(cls, username) -> "FixedUser":
         return next(
             (user for user in cls.from_config() if user.username == username), None
         )

+    @classmethod
+    @lru_cache()
+    def is_guest_endpoint(cls, service, action):
+        """
+        Return True if the given service/action is one of the endpoints
+        that the guest user is allowed to access, per configuration
+        """
+        return any(
+            ep == ".".join((service, action))
+            for ep in config.get("services.auth.fixed_users.guest.allow_endpoints", [])
+        )
+
+    @classmethod
+    def get_guest_user(cls) -> Optional["FixedUser"]:
+        if cls.guest_enabled():
+            return cls(
+                is_guest=True,
+                username=config.get("services.auth.fixed_users.guest.username"),
+                password=config.get("services.auth.fixed_users.guest.password"),
+                name=config.get("services.auth.fixed_users.guest.name"),
+                company=config.get("services.auth.fixed_users.guest.default_company"),
+            )
+
     def __hash__(self):
         return hash(self.user_id)
diff --git a/server/services/auth.py b/server/services/auth.py
index 26b5771..8e0eb26 100644
--- a/server/services/auth.py
+++ b/server/services/auth.py
@@ -16,7 +16,7 @@ from apimodels.auth import (
 )
 from apimodels.base import UpdateResponse
 from bll.auth import AuthBLL
-from config import config
+from config import config, info
 from database.errors import translate_errors_context
 from database.model.auth import User
 from service_repo import APICall, endpoint
@@ -176,4 +176,17 @@ def update(call, company_id, _):

 @endpoint("auth.fixed_users_mode")
 def fixed_users_mode(call: APICall, *_, **__):
-    call.result.data = dict(enabled=FixedUser.enabled())
+    data = {
+        "enabled": FixedUser.enabled(),
+        "migration_warning": info.missed_es_upgrade,
+        "guest": {
+            "enabled": FixedUser.guest_enabled(),
+        }
+    }
+    guest_user = FixedUser.get_guest_user()
+    if guest_user:
+        data["guest"]["name"] = guest_user.name
+        data["guest"]["username"] = guest_user.username
+        data["guest"]["password"] = guest_user.password
+
+    call.result.data = data
diff --git a/server/services/models.py b/server/services/models.py
index 1a94762..5429866 100644
--- a/server/services/models.py
+++ b/server/services/models.py
@@ -5,7 +5,8 @@ from mongoengine import Q, EmbeddedDocument

 import database
 from apierrors import errors
-from apimodels.base import UpdateResponse
+from apierrors.errors.bad_request import InvalidModelId
+from apimodels.base import UpdateResponse, MakePublicRequest
 from apimodels.models import (
     CreateModelRequest,
     CreateModelResponse,
@@ -467,3 +468,21 @@ def update(call: APICall, company_id, _):
     if del_count:
         _reset_cached_tags(company_id, projects=[model.project])
     call.result.data = dict(deleted=del_count > 0)
+
+
+@endpoint("models.make_public", min_version="2.9", request_data_model=MakePublicRequest)
+def make_public(call: APICall, company_id, request: MakePublicRequest):
+    with translate_errors_context():
+        call.result.data = Model.set_public(
+            company_id, ids=request.ids, invalid_cls=InvalidModelId, enabled=True
+        )
+
+
+@endpoint(
+    "models.make_private", min_version="2.9", request_data_model=MakePublicRequest
+)
+def make_private(call: APICall, company_id, request: MakePublicRequest):
+    with translate_errors_context():
+        call.result.data = Model.set_public(
+            company_id, request.ids, invalid_cls=InvalidModelId, enabled=False
+        )
diff --git a/server/services/projects.py b/server/services/projects.py
index e72a282..e51cda5 100644
--- a/server/services/projects.py
+++ b/server/services/projects.py
@@ -8,7 +8,8 @@ from mongoengine import Q

 import database
 from apierrors import errors
-from apimodels.base import UpdateResponse
+from apierrors.errors.bad_request import InvalidProjectId
+from apimodels.base import UpdateResponse, MakePublicRequest
 from apimodels.projects import (
     GetHyperParamReq,
     GetHyperParamResp,
@@ -422,3 +423,23 @@ def get_tags(call: APICall, company, request: ProjectTagsRequest):
         projects=request.projects,
     )
     call.result.data = get_tags_response(ret)
+
+
+@endpoint(
+    "projects.make_public", min_version="2.9", request_data_model=MakePublicRequest
+)
+def make_public(call: APICall, company_id, request: MakePublicRequest):
+    with translate_errors_context():
+        call.result.data = Project.set_public(
+            company_id, ids=request.ids, invalid_cls=InvalidProjectId, enabled=True
+        )
+
+
+@endpoint(
+    "projects.make_private", min_version="2.9", request_data_model=MakePublicRequest
+)
+def make_private(call: APICall, company_id, request: MakePublicRequest):
+    with translate_errors_context():
+        call.result.data = Project.set_public(
+            company_id, ids=request.ids, invalid_cls=InvalidProjectId, enabled=False
+        )
diff --git a/server/services/tasks.py b/server/services/tasks.py
index 79d15bb..3c40301 100644
--- a/server/services/tasks.py
+++ b/server/services/tasks.py
@@ -11,7 +11,8 @@ from mongoengine.queryset.transform import COMPARISON_OPERATORS
 from pymongo import UpdateOne

 from apierrors import errors, APIError
-from apimodels.base import UpdateResponse, IdResponse
+from apierrors.errors.bad_request import InvalidTaskId
+from apimodels.base import UpdateResponse, IdResponse, MakePublicRequest
 from apimodels.tasks import (
     StartedResponse,
     ResetResponse,
@@ -78,10 +79,24 @@ def set_task_status_from_call(
     task = TaskBLL.get_task_with_access(
         request.task,
         company_id=company_id,
-        only=tuple({"status", "project"} | fields_resolver.get_names()),
+        only=tuple(
+            {"status", "project", "started", "duration"} | fields_resolver.get_names()
+        ),
         requires_write_access=True,
     )

+    if "duration" not in fields_resolver.get_names():
+        if new_status == TaskStatus.in_progress:
+            fields_resolver.add_fields(min__duration=max(0, task.duration or 0))
+        elif new_status in (
+            TaskStatus.completed,
+            TaskStatus.failed,
+            TaskStatus.stopped,
+        ):
+            fields_resolver.add_fields(
+                duration=int((datetime.utcnow() - task.started).total_seconds())
+            )
+
     status_reason = request.status_reason
     status_message = request.status_message
     force = request.force
@@ -354,9 +369,7 @@ def _update_cached_tags(company: str, project: str, fields: dict):


 def _reset_cached_tags(company: str, projects: Sequence[str]):
-    org_bll.reset_tags(
-        company, Tags.Task, projects=projects
-    )
+    org_bll.reset_tags(company, Tags.Task, projects=projects)


 @endpoint(
@@ -573,9 +586,7 @@ def edit(call: APICall, company_id, req_model: UpdateRequest):
         if updated:
             new_project = fixed_fields.get("project", task.project)
             if new_project != task.project:
-                _reset_cached_tags(
-                    company_id, projects=[new_project, task.project]
-                )
+                _reset_cached_tags(company_id, projects=[new_project, task.project])
             else:
                 _update_cached_tags(
                     company_id, project=task.project, fields=fixed_fields
@@ -1005,3 +1016,19 @@ def add_or_update_artifacts(
         task_id=request.task, company_id=company_id, artifacts=request.artifacts
     )
     call.result.data_model = AddOrUpdateArtifactsResponse(added=added, updated=updated)
+
+
+@endpoint("tasks.make_public", min_version="2.9", request_data_model=MakePublicRequest)
+def make_public(call: APICall, company_id, request: MakePublicRequest):
+    with translate_errors_context():
+        call.result.data = Task.set_public(
+            company_id, request.ids, invalid_cls=InvalidTaskId, enabled=True
+        )
+
+
+@endpoint("tasks.make_private", min_version="2.9", request_data_model=MakePublicRequest)
+def make_private(call: APICall, company_id, request: MakePublicRequest):
+    with translate_errors_context():
+        call.result.data = Task.set_public(
+            company_id, request.ids, invalid_cls=InvalidTaskId, enabled=False
+        )
diff --git a/server/tests/automated/test_models.py b/server/tests/automated/test_models.py
index e6d9747..126a5f4 100644
--- a/server/tests/automated/test_models.py
+++ b/server/tests/automated/test_models.py
@@ -1,3 +1,4 @@
+from apierrors.errors.bad_request import InvalidModelId
 from tests.automated import TestService

 MODEL_CANNOT_BE_UPDATED_CODES = (400, 203)
@@ -7,7 +8,7 @@ IN_PROGRESS = "in_progress"

 class TestModelsService(TestService):
-    def setUp(self, version="2.8"):
+    def setUp(self, version="2.9"):
         super().setUp(version=version)

     def test_publish_output_model_running_task(self):
@@ -197,6 +198,28 @@ class TestModelsService(TestService):
         res = self.api.models.get_frameworks(projects=[project])
         self.assertEqual([], res.frameworks)

+    def test_make_public(self):
+        m1 = self._create_model(name="public model test")
+
+        # model with company_origin not set to the current company cannot be converted to private
+        with self.api.raises(InvalidModelId):
+            self.api.models.make_private(ids=[m1])
+
+        # public model can be retrieved but not updated
+        res = self.api.models.make_public(ids=[m1])
+        self.assertEqual(res.updated, 1)
+        res = self.api.models.get_all(id=[m1])
+        self.assertEqual([m.id for m in res.models], [m1])
+        with self.api.raises(InvalidModelId):
+            self.api.models.update(model=m1, name="public model test change 1")
+
+        # model made private again and can be both retrieved and updated
+        res = self.api.models.make_private(ids=[m1])
+        self.assertEqual(res.updated, 1)
+        res = self.api.models.get_all(id=[m1])
+        self.assertEqual([m.id for m in res.models], [m1])
+        self.api.models.update(model=m1, name="public model test change 2")
+
     def _assert_task_status(self, task_id, status):
         task = self.api.tasks.get_by_id(task=task_id).task
         assert task.status == status
diff --git a/server/tests/automated/test_projects_edit.py b/server/tests/automated/test_projects_edit.py
new file mode 100644
index 0000000..49863cf
--- /dev/null
+++ b/server/tests/automated/test_projects_edit.py
@@ -0,0 +1,34 @@
+from apierrors.errors.bad_request import InvalidProjectId
+from apierrors.errors.forbidden import NoWritePermission
+from config import config
+from tests.automated import TestService
+
+
+log = config.logger(__file__)
+
+
+class TestProjectsEdit(TestService):
+    def setUp(self, **kwargs):
+        super().setUp(version="2.9")
+
+    def test_make_public(self):
+        p1 = self.create_temp("projects", name="Test public", description="test")
+
+        # project with company_origin not set to the current company cannot be converted to private
+        with self.api.raises(InvalidProjectId):
+            self.api.projects.make_private(ids=[p1])
+
+        # public project can be retrieved but not updated
+        res = self.api.projects.make_public(ids=[p1])
+        self.assertEqual(res.updated, 1)
+        res = self.api.projects.get_all(id=[p1])
+        self.assertEqual([p.id for p in res.projects], [p1])
+        with self.api.raises(NoWritePermission):
+            self.api.projects.update(project=p1, name="Test public change 1")
+
+        # project made private again and can be both retrieved and updated
+        res = self.api.projects.make_private(ids=[p1])
+        self.assertEqual(res.updated, 1)
+        res = self.api.projects.get_all(id=[p1])
+        self.assertEqual([p.id for p in res.projects], [p1])
+        self.api.projects.update(project=p1, name="Test public change 2")
diff --git a/server/tests/automated/test_tasks_edit.py b/server/tests/automated/test_tasks_edit.py
index 0819f43..eec8c57 100644
--- a/server/tests/automated/test_tasks_edit.py
+++ b/server/tests/automated/test_tasks_edit.py
@@ -1,4 +1,5 @@
-from apierrors.errors.bad_request import InvalidModelId, ValidationError
+from apierrors.errors.bad_request import InvalidModelId, ValidationError, InvalidTaskId
+from apierrors.errors.forbidden import NoWritePermission
 from config import config
 from tests.automated import TestService

@@ -8,7 +9,7 @@ log = config.logger(__file__)

 class TestTasksEdit(TestService):
     def setUp(self, **kwargs):
-        super().setUp(version=2.5)
+        super().setUp(version="2.9")

     def new_task(self, **kwargs):
         self.update_missing(
@@ -145,3 +146,28 @@ class TestTasksEdit(TestService):
             self.api.tasks.delete, task=new_task, move_to_trash=False, force=True
         )
         return new_task
+
+    def test_make_public(self):
+        task = self.new_task()
+
+        # task is created as private and can be updated
+        self.api.tasks.started(task=task)
+
+        # task with company_origin not set to the current company cannot be converted to private
+        with self.api.raises(InvalidTaskId):
+            self.api.tasks.make_private(ids=[task])
+
+        # public task can be retrieved but not updated
+        res = self.api.tasks.make_public(ids=[task])
+        self.assertEqual(res.updated, 1)
+        res = self.api.tasks.get_all_ex(id=[task])
+        self.assertEqual([t.id for t in res.tasks], [task])
+        with self.api.raises(NoWritePermission):
+            self.api.tasks.stopped(task=task)
+
+        # task made private again and can be both retrieved and updated
+        res = self.api.tasks.make_private(ids=[task])
+        self.assertEqual(res.updated, 1)
+        res = self.api.tasks.get_all_ex(id=[task])
+        self.assertEqual([t.id for t in res.tasks], [task])
+        self.api.tasks.stopped(task=task)
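
Note on the guest-user wiring above: FixedUser.guest_enabled requires both apiserver.auth.fixed_users.enabled and services.auth.fixed_users.guest.enabled to be set, and get_guest_user / is_guest_endpoint read the remaining keys from the same services.auth.fixed_users.guest branch. A minimal sketch of that configuration subtree, in the HOCON format the server's conf files already use; every value below is an illustrative placeholder, not a shipped default:

    # sketch of the services.auth.fixed_users.guest subtree (values are placeholders)
    fixed_users {
        guest {
            enabled: true
            name: "Guest"
            username: "guest"
            password: "guest"
            # company id the guest user belongs to; when guest mode is enabled,
            # init_mongo_data() calls _ensure_company() for this company
            default_company: "4f2d02d1bf2c4d2bbf1d0af1c0cc56ff"
            # exact "service.action" names the guest credentials may call;
            # any other endpoint is rejected in authorize_credentials()
            allow_endpoints: [
                "auth.login"
            ]
        }
    }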
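Likewise, PrePopulate.update_featured_projects_order ranks public projects against services.projects.featured_order: an entry can match a project by exact id, by exact name, or by name_regex, and the entry's position in the list becomes the project's featured index (projects matching no entry fall back to 999). An illustrative sketch of that list, with placeholder ids and names:

    # sketch of services.projects.featured_order (ids/names are placeholders)
    featured_order: [
        {id: "2af93e94d10f4f5ea3bcb5f6ab2a2b6b"},
        {name: "Sample project"},
        {name_regex: "^Tutorial .*"}
    ]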