From aa22170ab41bb5332e175286b7369412c0fd919c Mon Sep 17 00:00:00 2001
From: allegroai <>
Date: Mon, 6 Jul 2020 22:06:42 +0300
Subject: [PATCH] Fix support for example projects and experiments in demo server

---
 server/bll/event/event_metrics.py    |  10 +-
 server/bll/model/__init__.py         |  18 +++
 server/bll/task/task_bll.py          |  53 ++++----
 server/config/default/apiserver.conf |   4 +
 server/database/__init__.py          |   4 +
 server/schema/services/projects.conf |   5 +
 server/services/events.py            | 181 +++++++++++++++++----------
 7 files changed, 181 insertions(+), 94 deletions(-)
 create mode 100644 server/bll/model/__init__.py

diff --git a/server/bll/event/event_metrics.py b/server/bll/event/event_metrics.py
index e42565d..a331f2d 100644
--- a/server/bll/event/event_metrics.py
+++ b/server/bll/event/event_metrics.py
@@ -84,7 +84,7 @@ class EventMetrics:
             company=company_id,
             query=Q(id__in=task_ids),
             allow_public=allow_public,
-            override_projection=("id", "name"),
+            override_projection=("id", "name", "company"),
             return_dicts=False,
         )
         if len(task_objs) < len(task_ids):
@@ -93,8 +93,14 @@ class EventMetrics:
 
         task_name_by_id = {t.id: t.name for t in task_objs}
 
+        companies = {t.company for t in task_objs}
+        if len(companies) > 1:
+            raise errors.bad_request.InvalidTaskId(
+                "only tasks from the same company are supported"
+            )
+
         ret = self._run_get_scalar_metrics_as_parallel(
-            company_id,
+            next(iter(companies)),
             task_ids=task_ids,
             samples=samples,
             key=ScalarKey.resolve(key),
diff --git a/server/bll/model/__init__.py b/server/bll/model/__init__.py
new file mode 100644
index 0000000..6fa58d4
--- /dev/null
+++ b/server/bll/model/__init__.py
@@ -0,0 +1,18 @@
+from typing import Optional, Sequence
+
+from mongoengine import Q
+
+from database.model.model import Model
+from database.utils import get_company_or_none_constraint
+
+
+class ModelBLL:
+    def get_frameworks(self, company, project_ids: Optional[Sequence]) -> Sequence:
+        """
+        Return the list of unique frameworks used by company and public models.
+        If project ids are passed then only models from these projects are considered
+        """
+        query = get_company_or_none_constraint(company)
+        if project_ids:
+            query &= Q(project__in=project_ids)
+        return Model.objects(query).distinct(field="framework")
diff --git a/server/bll/task/task_bll.py b/server/bll/task/task_bll.py
index ef80fa5..88a6276 100644
--- a/server/bll/task/task_bll.py
+++ b/server/bll/task/task_bll.py
@@ -85,22 +85,25 @@ class TaskBLL(object):
         company_id,
         task_id,
         required_status=None,
-        required_dataset=None,
         only_fields=None,
+        allow_public=False,
     ):
+        if only_fields:
+            if isinstance(only_fields, string_types):
+                only_fields = [only_fields]
+            else:
+                only_fields = list(only_fields)
+            only_fields = only_fields + ["status"]
         with TimingContext("mongo", "task_by_id_all"):
-            qs = Task.objects(id=task_id, company=company_id)
-            if only_fields:
-                qs = (
-                    qs.only(only_fields)
-                    if isinstance(only_fields, string_types)
-                    else qs.only(*only_fields)
-                )
-            qs = qs.only(
-                "status", "input"
-            )  # make sure all fields we rely on here are also returned
-            task = qs.first()
+            tasks = Task.get_many(
+                company=company_id,
+                query=Q(id=task_id),
+                allow_public=allow_public,
+                override_projection=only_fields,
+                return_dicts=False,
+            )
+            task = None if not tasks else tasks[0]
 
         if not task:
             raise errors.bad_request.InvalidTaskId(id=task_id)
@@ -108,17 +111,12 @@ class TaskBLL(object):
         if required_status and not task.status == required_status:
             raise errors.bad_request.InvalidTaskStatus(expected=required_status)
 
-        if required_dataset and required_dataset not in (
-            entry.dataset for entry in task.input.view.entries
-        ):
-            raise errors.bad_request.InvalidId(
-                "not in input view", dataset=required_dataset
-            )
-
         return task
 
     @staticmethod
-    def assert_exists(company_id, task_ids, only=None, allow_public=False):
+    def assert_exists(
+        company_id, task_ids, only=None, allow_public=False, return_tasks=True
+    ) -> Optional[Sequence[Task]]:
         task_ids = [task_ids] if isinstance(task_ids, six.string_types) else task_ids
         with translate_errors_context(), TimingContext("mongo", "task_exists"):
             ids = set(task_ids)
@@ -128,15 +126,18 @@ class TaskBLL(object):
                 allow_public=allow_public,
                 return_dicts=False,
             )
+            res = None
             if only:
                 res = q.only(*only)
-                count = len(res)
-            else:
-                count = q.count()
-                res = q.first()
+            elif return_tasks:
+                res = list(q)
+
+            count = len(res) if res is not None else q.count()
             if count != len(ids):
                 raise errors.bad_request.InvalidTaskId(ids=task_ids)
-            return res
+
+            if return_tasks:
+                return res
 
     @staticmethod
     def create(call: APICall, fields: dict):
@@ -181,7 +182,7 @@ class TaskBLL(object):
         execution_overrides: Optional[dict] = None,
         validate_references: bool = False,
     ) -> Task:
-        task = cls.get_by_id(company_id=company_id, task_id=task_id)
+        task = cls.get_by_id(company_id=company_id, task_id=task_id, allow_public=True)
         execution_dict = task.execution.to_proper_dict() if task.execution else {}
         execution_model_overriden = False
         if execution_overrides:
diff --git a/server/config/default/apiserver.conf b/server/config/default/apiserver.conf
index d1cd078..e248c35 100644
--- a/server/config/default/apiserver.conf
+++ b/server/config/default/apiserver.conf
@@ -33,6 +33,10 @@
         artifacts_path: "/mnt/fileserver"
     }
 
+    # time in seconds to take an exclusive lock to init es and mongodb
+    # not including the pre_populate
+    db_init_timout: 30
+
     mongo {
         # controls whether FieldDoesNotExist exception will be raised for any extra attribute existing in stored data
         # but not declared in a data model
diff --git a/server/database/__init__.py b/server/database/__init__.py
index 8c666a3..f2aa574 100644
--- a/server/database/__init__.py
+++ b/server/database/__init__.py
@@ -79,6 +79,10 @@ def get_entries():
     return _entries
 
 
+def get_hosts():
+    return [entry.host for entry in get_entries()]
+
+
 def get_aliases():
     return [entry.alias for entry in get_entries()]
 
diff --git a/server/schema/services/projects.conf b/server/schema/services/projects.conf
index d4fc43e..5cd5038 100644
--- a/server/schema/services/projects.conf
+++ b/server/schema/services/projects.conf
@@ -405,6 +405,11 @@ get_all_ex {
                 enum: [ active, archived ]
                 default: active
             }
+            non_public {
+                description: "Return only non-public projects"
+                type: boolean
+                default: false
+            }
         }
     }
 }
diff --git a/server/services/events.py b/server/services/events.py
index 2dc5f0b..6c07e5f 100644
--- a/server/services/events.py
+++ b/server/services/events.py
@@ -48,12 +48,14 @@ def add_batch(call: APICall, company_id, req_model):
 @endpoint("events.get_task_log", required_fields=["task"])
 def get_task_log_v1_5(call, company_id, req_model):
     task_id = call.data["task"]
-    task_bll.assert_exists(company_id, task_id, allow_public=True)
+    task = task_bll.assert_exists(
+        company_id, task_id, allow_public=True, only=("company",)
+    )[0]
     order = call.data.get("order") or "desc"
     scroll_id = call.data.get("scroll_id")
     batch_size = int(call.data.get("batch_size") or 500)
     events, scroll_id, total_events = event_bll.scroll_task_events(
-        company_id,
+        task.company,
         task_id,
         order,
         event_type="log",
@@ -68,7 +70,9 @@ def get_task_log_v1_5(call, company_id, req_model):
 @endpoint("events.get_task_log", min_version="1.7", required_fields=["task"])
 def get_task_log_v1_7(call, company_id, req_model):
     task_id = call.data["task"]
-    task_bll.assert_exists(company_id, task_id, allow_public=True)
+    task = task_bll.assert_exists(
+        company_id, task_id, allow_public=True, only=("company",)
+    )[0]
     order = call.data.get("order") or "desc"
 
     from_ = call.data.get("from") or "head"
@@ -78,7 +82,7 @@ def get_task_log_v1_7(call, company_id, req_model):
     scroll_order = "asc" if (from_ == "head") else "desc"
 
     events, scroll_id, total_events = event_bll.scroll_task_events(
-        company_id=company_id,
+        company_id=task.company,
         task_id=task_id,
         order=scroll_order,
         event_type="log",
@@ -97,10 +101,12 @@ def get_task_log_v1_7(call, company_id, req_model):
 @endpoint("events.get_task_log", min_version="2.9", request_data_model=LogEventsRequest)
 def get_task_log(call, company_id, req_model: LogEventsRequest):
     task_id = req_model.task
-    task_bll.assert_exists(company_id, task_id, allow_public=True)
+    task = task_bll.assert_exists(
+        company_id, task_id, allow_public=True, only=("company",)
+    )[0]
 
     res = event_bll.log_events_iterator.get_task_events(
-        company_id=company_id,
+        company_id=task.company,
         task_id=task_id,
         batch_size=req_model.batch_size,
         navigate_earlier=req_model.navigate_earlier,
@@ -108,16 +114,16 @@ def get_task_log(call, company_id, req_model: LogEventsRequest):
     )
 
     call.result.data = dict(
-        events=res.events,
-        returned=len(res.events),
-        total=res.total_events,
+        events=res.events, returned=len(res.events), total=res.total_events
     )
 
 
 @endpoint("events.download_task_log", required_fields=["task"])
-def download_task_log(call, company_id, req_model):
+def download_task_log(call, company_id, _):
     task_id = call.data["task"]
-    task_bll.assert_exists(company_id, task_id, allow_public=True)
+    task = task_bll.assert_exists(
+        company_id, task_id, allow_public=True, only=("company",)
+    )[0]
     line_type = call.data.get("line_type", "json").lower()
     line_format = str(call.data.get("line_format", "{asctime} {worker} {level} {msg}"))
 
@@ -160,7 +166,7 @@ def download_task_log(call, company_id, req_model):
         batch_size = 1000
         while True:
             log_events, scroll_id, _ = event_bll.scroll_task_events(
-                company_id,
+                task.company,
                 task_id,
                 order="asc",
                 event_type="log",
@@ -193,23 +199,27 @@ def download_task_log(call, company_id, req_model):
 
 
 @endpoint("events.get_vector_metrics_and_variants", required_fields=["task"])
-def get_vector_metrics_and_variants(call, company_id, req_model):
+def get_vector_metrics_and_variants(call, company_id, _):
     task_id = call.data["task"]
-    task_bll.assert_exists(company_id, task_id, allow_public=True)
+    task = task_bll.assert_exists(
+        company_id, task_id, allow_public=True, only=("company",)
+    )[0]
     call.result.data = dict(
         metrics=event_bll.get_metrics_and_variants(
-            company_id, task_id, "training_stats_vector"
+            task.company, task_id, "training_stats_vector"
         )
     )
 
 
 @endpoint("events.get_scalar_metrics_and_variants", required_fields=["task"])
-def get_scalar_metrics_and_variants(call, company_id, req_model):
+def get_scalar_metrics_and_variants(call, company_id, _):
     task_id = call.data["task"]
-    task_bll.assert_exists(company_id, task_id, allow_public=True)
+    task = task_bll.assert_exists(
+        company_id, task_id, allow_public=True, only=("company",)
+    )[0]
     call.result.data = dict(
         metrics=event_bll.get_metrics_and_variants(
-            company_id, task_id, "training_stats_scalar"
+            task.company, task_id, "training_stats_scalar"
"training_stats_scalar" ) ) @@ -219,13 +229,15 @@ def get_scalar_metrics_and_variants(call, company_id, req_model): "events.vector_metrics_iter_histogram", required_fields=["task", "metric", "variant"], ) -def vector_metrics_iter_histogram(call, company_id, req_model): +def vector_metrics_iter_histogram(call, company_id, _): task_id = call.data["task"] - task_bll.assert_exists(company_id, task_id, allow_public=True) + task = task_bll.assert_exists( + company_id, task_id, allow_public=True, only=("company",) + )[0] metric = call.data["metric"] variant = call.data["variant"] iterations, vectors = event_bll.get_vector_metrics_per_iter( - company_id, task_id, metric, variant + task.company, task_id, metric, variant ) call.result.data = dict( metric=metric, variant=variant, vectors=vectors, iterations=iterations @@ -240,9 +252,11 @@ def get_task_events(call, company_id, _): scroll_id = call.data.get("scroll_id") order = call.data.get("order") or "asc" - task_bll.assert_exists(company_id, task_id, allow_public=True) + task = task_bll.assert_exists( + company_id, task_id, allow_public=True, only=("company",) + )[0] result = event_bll.get_task_events( - company_id, + task.company, task_id, sort=[{"timestamp": {"order": order}}], event_type=event_type, @@ -259,14 +273,16 @@ def get_task_events(call, company_id, _): @endpoint("events.get_scalar_metric_data", required_fields=["task", "metric"]) -def get_scalar_metric_data(call, company_id, req_model): +def get_scalar_metric_data(call, company_id, _): task_id = call.data["task"] metric = call.data["metric"] scroll_id = call.data.get("scroll_id") - task_bll.assert_exists(company_id, task_id, allow_public=True) + task = task_bll.assert_exists( + company_id, task_id, allow_public=True, only=("company",) + )[0] result = event_bll.get_task_events( - company_id, + task.company, task_id, event_type="training_stats_scalar", sort=[{"iter": {"order": "desc"}}], @@ -283,13 +299,15 @@ def get_scalar_metric_data(call, company_id, req_model): @endpoint("events.get_task_latest_scalar_values", required_fields=["task"]) -def get_task_latest_scalar_values(call, company_id, req_model): +def get_task_latest_scalar_values(call, company_id, _): task_id = call.data["task"] - task = task_bll.assert_exists(company_id, task_id, allow_public=True) + task = task_bll.assert_exists( + company_id, task_id, allow_public=True, only=("company",) + )[0] metrics, last_timestamp = event_bll.get_task_latest_scalar_values( - company_id, task_id + task.company, task_id ) - es_index = EventMetrics.get_index_name(company_id, "*") + es_index = EventMetrics.get_index_name(task.company, "*") last_iters = event_bll.get_last_iters(es_index, task_id, None, 1) call.result.data = dict( metrics=metrics, @@ -306,11 +324,13 @@ def get_task_latest_scalar_values(call, company_id, req_model): request_data_model=ScalarMetricsIterHistogramRequest, ) def scalar_metrics_iter_histogram( - call, company_id, req_model: ScalarMetricsIterHistogramRequest + call, company_id, request: ScalarMetricsIterHistogramRequest ): - task_bll.assert_exists(call.identity.company, req_model.task, allow_public=True) + task = task_bll.assert_exists( + company_id, request.task, allow_public=True, only=("company",) + )[0] metrics = event_bll.metrics.get_scalar_metrics_average_per_iter( - company_id, task_id=req_model.task, samples=req_model.samples, key=req_model.key + task.company, task_id=request.task, samples=request.samples, key=request.key ) call.result.data = metrics @@ -338,21 +358,27 @@ def 
 
 
 @endpoint("events.get_multi_task_plots", required_fields=["tasks"])
-def get_multi_task_plots_v1_7(call, company_id, req_model):
+def get_multi_task_plots_v1_7(call, company_id, _):
     task_ids = call.data["tasks"]
     iters = call.data.get("iters", 1)
     scroll_id = call.data.get("scroll_id")
 
     tasks = task_bll.assert_exists(
-        company_id=call.identity.company,
-        only=("id", "name"),
+        company_id=company_id,
+        only=("id", "name", "company"),
         task_ids=task_ids,
         allow_public=True,
     )
 
+    companies = {t.company for t in tasks}
+    if len(companies) > 1:
+        raise errors.bad_request.InvalidTaskId(
+            "only tasks from the same company are supported"
+        )
+
     # Get last 10K events by iteration and group them by unique metric+variant, returning top events for combination
     result = event_bll.get_task_events(
-        company_id,
+        next(iter(companies)),
         task_ids,
         event_type="plot",
         sort=[{"iter": {"order": "desc"}}],
@@ -382,13 +408,19 @@ def get_multi_task_plots(call, company_id, req_model):
 
     tasks = task_bll.assert_exists(
         company_id=call.identity.company,
-        only=("id", "name"),
+        only=("id", "name", "company"),
         task_ids=task_ids,
         allow_public=True,
     )
 
+    companies = {t.company for t in tasks}
+    if len(companies) > 1:
+        raise errors.bad_request.InvalidTaskId(
+            "only tasks from the same company are supported"
+        )
+
     result = event_bll.get_task_events(
-        company_id,
+        next(iter(companies)),
         task_ids,
         event_type="plot",
         sort=[{"iter": {"order": "desc"}}],
@@ -411,12 +443,14 @@ def get_multi_task_plots(call, company_id, req_model):
 
 
 @endpoint("events.get_task_plots", required_fields=["task"])
-def get_task_plots_v1_7(call, company_id, req_model):
+def get_task_plots_v1_7(call, company_id, _):
     task_id = call.data["task"]
     iters = call.data.get("iters", 1)
     scroll_id = call.data.get("scroll_id")
 
-    task_bll.assert_exists(call.identity.company, task_id, allow_public=True)
+    task = task_bll.assert_exists(
+        company_id, task_id, allow_public=True, only=("company",)
+    )[0]
     # events, next_scroll_id, total_events = event_bll.get_task_events(
     #     company, task_id,
     #     event_type="plot",
@@ -426,7 +460,7 @@ def get_task_plots_v1_7(call, company_id, req_model):
 
     # get last 10K events by iteration and group them by unique metric+variant, returning top events for combination
     result = event_bll.get_task_events(
-        company_id,
+        task.company,
         task_id,
         event_type="plot",
         sort=[{"iter": {"order": "desc"}}],
@@ -445,14 +479,16 @@ def get_task_plots_v1_7(call, company_id, req_model):
 
 
 @endpoint("events.get_task_plots", min_version="1.8", required_fields=["task"])
-def get_task_plots(call, company_id, req_model):
+def get_task_plots(call, company_id, _):
     task_id = call.data["task"]
     iters = call.data.get("iters", 1)
     scroll_id = call.data.get("scroll_id")
 
-    task_bll.assert_exists(call.identity.company, task_id, allow_public=True)
+    task = task_bll.assert_exists(
+        company_id, task_id, allow_public=True, only=("company",)
+    )[0]
     result = event_bll.get_task_plots(
-        company_id,
+        task.company,
         tasks=[task_id],
         sort=[{"iter": {"order": "desc"}}],
         last_iterations_per_plot=iters,
@@ -470,12 +506,14 @@ def get_task_plots(call, company_id, req_model):
 
 
 @endpoint("events.debug_images", required_fields=["task"])
-def get_debug_images_v1_7(call, company_id, req_model):
+def get_debug_images_v1_7(call, company_id, _):
     task_id = call.data["task"]
     iters = call.data.get("iters") or 1
     scroll_id = call.data.get("scroll_id")
 
-    task_bll.assert_exists(call.identity.company, task_id, allow_public=True)
+    task = task_bll.assert_exists(
+        company_id, task_id, allow_public=True, only=("company",)
+    )[0]
     # events, next_scroll_id, total_events = event_bll.get_task_events(
     #     company, task_id,
     #     event_type="training_debug_image",
@@ -485,7 +523,7 @@ def get_debug_images_v1_7(call, company_id, req_model):
 
     # get last 10K events by iteration and group them by unique metric+variant, returning top events for combination
     result = event_bll.get_task_events(
-        company_id,
+        task.company,
         task_id,
         event_type="training_debug_image",
         sort=[{"iter": {"order": "desc"}}],
@@ -505,14 +543,16 @@ def get_debug_images_v1_7(call, company_id, req_model):
 
 
 @endpoint("events.debug_images", min_version="1.8", required_fields=["task"])
-def get_debug_images_v1_8(call, company_id, req_model):
+def get_debug_images_v1_8(call, company_id, _):
     task_id = call.data["task"]
     iters = call.data.get("iters") or 1
     scroll_id = call.data.get("scroll_id")
 
-    task_bll.assert_exists(call.identity.company, task_id, allow_public=True)
+    task = task_bll.assert_exists(
+        company_id, task_id, allow_public=True, only=("company",)
+    )[0]
     result = event_bll.get_task_events(
-        company_id,
+        task.company,
         task_id,
         event_type="training_debug_image",
         sort=[{"iter": {"order": "desc"}}],
@@ -537,16 +577,25 @@ def get_debug_images_v1_8(call, company_id, req_model):
     request_data_model=DebugImagesRequest,
     response_data_model=DebugImageResponse,
 )
-def get_debug_images(call, company_id, req_model: DebugImagesRequest):
-    tasks = set(m.task for m in req_model.metrics)
-    task_bll.assert_exists(call.identity.company, task_ids=tasks, allow_public=True)
+def get_debug_images(call, company_id, request: DebugImagesRequest):
+    task_ids = {m.task for m in request.metrics}
+    tasks = task_bll.assert_exists(
+        company_id, task_ids=task_ids, allow_public=True, only=("company",)
+    )
+
+    companies = {t.company for t in tasks}
+    if len(companies) > 1:
+        raise errors.bad_request.InvalidTaskId(
+            "only tasks from the same company are supported"
+        )
+
     result = event_bll.debug_images_iterator.get_task_events(
-        company_id=company_id,
-        metrics=[(m.task, m.metric) for m in req_model.metrics],
-        iter_count=req_model.iters,
-        navigate_earlier=req_model.navigate_earlier,
-        refresh=req_model.refresh,
-        state_id=req_model.scroll_id,
+        company_id=next(iter(companies)),
+        metrics=[(m.task, m.metric) for m in request.metrics],
+        iter_count=request.iters,
+        navigate_earlier=request.navigate_earlier,
+        refresh=request.refresh,
+        state_id=request.scroll_id,
     )
 
     call.result.data_model = DebugImageResponse(
@@ -566,12 +615,12 @@ def get_debug_images(call, company_id, req_model: DebugImagesRequest):
 
 
 @endpoint("events.get_task_metrics", request_data_model=TaskMetricsRequest)
-def get_tasks_metrics(call: APICall, company_id, req_model: TaskMetricsRequest):
-    task_bll.assert_exists(
-        call.identity.company, task_ids=req_model.tasks, allow_public=True
-    )
+def get_tasks_metrics(call: APICall, company_id, request: TaskMetricsRequest):
+    task = task_bll.assert_exists(
+        company_id, task_ids=request.tasks, allow_public=True, only=("company",)
+    )[0]
     res = event_bll.metrics.get_tasks_metrics(
-        company_id, task_ids=req_model.tasks, event_type=req_model.event_type
+        task.company, task_ids=request.tasks, event_type=request.event_type
    )
     call.result.data = {
         "metrics": [{"task": task, "metrics": metrics} for (task, metrics) in res]
@@ -583,7 +632,7 @@ def delete_for_task(call, company_id, req_model):
     task_id = call.data["task"]
     allow_locked = call.data.get("allow_locked", False)
 
-    task_bll.assert_exists(company_id, task_id)
+    task_bll.assert_exists(company_id, task_id, return_tasks=False)
     call.result.data = dict(
         deleted=event_bll.delete_task_events(
             company_id, task_id, allow_locked=allow_locked
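
Note on the recurring pattern in this patch: read-style events endpoints now resolve the owning company from the task itself (fetched with allow_public=True) instead of using the caller's company, so that public "example" tasks on the demo server, which are stored under a different company, remain readable. Below is a minimal sketch of the pattern, assuming the endpoint decorator, task_bll, and event_bll objects from server/services/events.py; the endpoint name and the exact event_bll call here are illustrative only:

    @endpoint("events.example_get", required_fields=["task"])
    def example_get(call, company_id, _):
        task_id = call.data["task"]
        # Fetch the task, allowing public (example) tasks, and project only
        # the "company" field; assert_exists now returns the matched tasks.
        task = task_bll.assert_exists(
            company_id, task_id, allow_public=True, only=("company",)
        )[0]
        # Query events under the task's own company, not the caller's.
        result = event_bll.get_task_events(
            task.company, task_id, event_type="log"
        )
        call.result.data = dict(events=result.events)

Multi-task endpoints guard the same idea with a same-company check before a single company id is passed on, exactly as in the hunks above:

    companies = {t.company for t in tasks}
    if len(companies) > 1:
        raise errors.bad_request.InvalidTaskId(
            "only tasks from the same company are supported"
        )
    company = next(iter(companies))  # safe: assert_exists already raised if no tasks matched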