Fix support for example projects and experiments in demo server

allegroai 2020-07-06 22:06:42 +03:00
parent 901ec37290
commit aa22170ab4
7 changed files with 181 additions and 94 deletions

View File

@@ -84,7 +84,7 @@ class EventMetrics:
company=company_id,
query=Q(id__in=task_ids),
allow_public=allow_public,
override_projection=("id", "name"),
override_projection=("id", "name", "company"),
return_dicts=False,
)
if len(task_objs) < len(task_ids):
@@ -93,8 +93,14 @@ class EventMetrics:
task_name_by_id = {t.id: t.name for t in task_objs}
companies = {t.company for t in task_objs}
if len(companies) > 1:
raise errors.bad_request.InvalidTaskId(
"only tasks from the same company are supported"
)
ret = self._run_get_scalar_metrics_as_parallel(
company_id,
next(iter(companies)),
task_ids=task_ids,
samples=samples,
key=ScalarKey.resolve(key),

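This hunk establishes the pattern the rest of the commit applies everywhere: public (example) tasks can belong to a different company than the caller, so events must be read under the tasks' own company, and a multi-task query is only well-defined when every task resolves to the same company. A minimal sketch of the shared-company check, with a hypothetical helper name and assuming task objects that expose a `company` attribute as in the projection override above:

# Sketch (hypothetical helper): resolve the single owning company for a
# set of task objects. A mixed set is rejected, as in the hunk above.
def resolve_company(tasks):
    companies = {t.company for t in tasks}
    if len(companies) > 1:
        raise ValueError("only tasks from the same company are supported")
    return next(iter(companies))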
View File

@@ -0,0 +1,18 @@
from typing import Optional, Sequence

from mongoengine import Q

from database.model.model import Model
from database.utils import get_company_or_none_constraint


class ModelBLL:
    def get_frameworks(self, company, project_ids: Optional[Sequence]) -> Sequence:
        """
        Return the list of unique frameworks used by the company's and public models.
        If project ids are passed, only models from these projects are considered.
        """
        query = get_company_or_none_constraint(company)
        if project_ids:
            query &= Q(project__in=project_ids)
        return Model.objects(query).distinct(field="framework")
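A hypothetical call site for the new helper (the company id and project ids below are made up for illustration):

# Hypothetical usage: list the distinct `framework` values of a company's
# own and public models, optionally narrowed to specific projects.
model_bll = ModelBLL()
frameworks = model_bll.get_frameworks("company-id", project_ids=["project-1"])
# -> e.g. ["tensorflow", "pytorch"], depending on the stored models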

View File

@@ -85,22 +85,25 @@ class TaskBLL(object):
company_id,
task_id,
required_status=None,
required_dataset=None,
only_fields=None,
allow_public=False,
):
if only_fields:
if isinstance(only_fields, string_types):
only_fields = [only_fields]
else:
only_fields = list(only_fields)
only_fields = only_fields + ["status"]
with TimingContext("mongo", "task_by_id_all"):
qs = Task.objects(id=task_id, company=company_id)
if only_fields:
qs = (
qs.only(only_fields)
if isinstance(only_fields, string_types)
else qs.only(*only_fields)
)
qs = qs.only(
"status", "input"
) # make sure all fields we rely on here are also returned
task = qs.first()
tasks = Task.get_many(
company=company_id,
query=Q(id=task_id),
allow_public=allow_public,
override_projection=only_fields,
return_dicts=False,
)
task = None if not tasks else tasks[0]
if not task:
raise errors.bad_request.InvalidTaskId(id=task_id)
@@ -108,17 +111,12 @@ class TaskBLL(object):
if required_status and not task.status == required_status:
raise errors.bad_request.InvalidTaskStatus(expected=required_status)
if required_dataset and required_dataset not in (
entry.dataset for entry in task.input.view.entries
):
raise errors.bad_request.InvalidId(
"not in input view", dataset=required_dataset
)
return task
@staticmethod
def assert_exists(company_id, task_ids, only=None, allow_public=False):
def assert_exists(
company_id, task_ids, only=None, allow_public=False, return_tasks=True
) -> Optional[Sequence[Task]]:
task_ids = [task_ids] if isinstance(task_ids, six.string_types) else task_ids
with translate_errors_context(), TimingContext("mongo", "task_exists"):
ids = set(task_ids)
@@ -128,15 +126,18 @@ class TaskBLL(object):
allow_public=allow_public,
return_dicts=False,
)
res = None
if only:
res = q.only(*only)
count = len(res)
else:
count = q.count()
res = q.first()
elif return_tasks:
res = list(q)
count = len(res) if res is not None else q.count()
if count != len(ids):
raise errors.bad_request.InvalidTaskId(ids=task_ids)
return res
if return_tasks:
return res
@staticmethod
def create(call: APICall, fields: dict):
@@ -181,7 +182,7 @@ class TaskBLL(object):
execution_overrides: Optional[dict] = None,
validate_references: bool = False,
) -> Task:
task = cls.get_by_id(company_id=company_id, task_id=task_id)
task = cls.get_by_id(company_id=company_id, task_id=task_id, allow_public=True)
execution_dict = task.execution.to_proper_dict() if task.execution else {}
execution_model_overriden = False
if execution_overrides:

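A short sketch of how the reworked `assert_exists` is meant to be consumed (the ids are illustrative): with `only` and `return_tasks` it now doubles as a cheap fetch, while callers that only need validation can skip materializing results.

# Validation only -- no task objects are returned:
TaskBLL.assert_exists("company-id", "task-id", return_tasks=False)

# Validate and fetch a projection in one round trip:
tasks = TaskBLL.assert_exists(
    "company-id", ["task-1", "task-2"], only=("company",), allow_public=True
)
owning_companies = {t.company for t in tasks}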
View File

@@ -33,6 +33,10 @@
artifacts_path: "/mnt/fileserver"
}
# time in seconds to take an exclusive lock to init es and mongodb
# not including the pre_populate
db_init_timout: 30
mongo {
# controls whether FieldDoesNotExist exception will be raised for any extra attribute existing in stored data
# but not declared in a data model

View File

@@ -79,6 +79,10 @@ def get_entries():
return _entries
def get_hosts():
return [entry.host for entry in get_entries()]
def get_aliases():
return [entry.alias for entry in get_entries()]

View File

@@ -405,6 +405,11 @@ get_all_ex {
enum: [ active, archived ]
default: active
}
non_public {
description: "Return only non-public projects"
type: boolean
default: false
}
}
}
}
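A hypothetical request exercising the new flag, following the server's usual JSON-POST endpoint convention (the host, port, and omitted auth handling below are assumptions, not part of this commit):

# Hypothetical client call: fetch only non-public (company-owned) projects.
import requests

resp = requests.post(
    "http://localhost:8008/projects.get_all_ex",  # assumed apiserver address
    json={"non_public": True},
    # auth headers/session omitted for brevity
)
projects = resp.json()["data"]["projects"]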

View File

@@ -48,12 +48,14 @@ def add_batch(call: APICall, company_id, req_model):
@endpoint("events.get_task_log", required_fields=["task"])
def get_task_log_v1_5(call, company_id, req_model):
task_id = call.data["task"]
task_bll.assert_exists(company_id, task_id, allow_public=True)
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company",)
)[0]
order = call.data.get("order") or "desc"
scroll_id = call.data.get("scroll_id")
batch_size = int(call.data.get("batch_size") or 500)
events, scroll_id, total_events = event_bll.scroll_task_events(
company_id,
task.company,
task_id,
order,
event_type="log",
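Every single-task endpoint below repeats this change: fetch the task with a `company`-only projection, then pass the task's owning company rather than the caller's `company_id` to the event layer, so that logs of public example tasks resolve to the correct per-company Elasticsearch index. Condensed (using the service module's own names), the pattern looks like this:

# Condensed pattern: read events under the task's own company so that
# public example tasks (owned by another company) resolve correctly.
task = task_bll.assert_exists(
    company_id, task_id, allow_public=True, only=("company",)
)[0]
events, scroll_id, total_events = event_bll.scroll_task_events(
    task.company,  # not company_id
    task_id,
    order="desc",
    event_type="log",
)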
@@ -68,7 +70,9 @@ def get_task_log_v1_5(call, company_id, req_model):
@endpoint("events.get_task_log", min_version="1.7", required_fields=["task"])
def get_task_log_v1_7(call, company_id, req_model):
task_id = call.data["task"]
task_bll.assert_exists(company_id, task_id, allow_public=True)
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company",)
)[0]
order = call.data.get("order") or "desc"
from_ = call.data.get("from") or "head"
@@ -78,7 +82,7 @@ def get_task_log_v1_7(call, company_id, req_model):
scroll_order = "asc" if (from_ == "head") else "desc"
events, scroll_id, total_events = event_bll.scroll_task_events(
company_id=company_id,
company_id=task.company,
task_id=task_id,
order=scroll_order,
event_type="log",
@@ -97,10 +101,12 @@ def get_task_log_v1_7(call, company_id, req_model):
@endpoint("events.get_task_log", min_version="2.9", request_data_model=LogEventsRequest)
def get_task_log(call, company_id, req_model: LogEventsRequest):
task_id = req_model.task
task_bll.assert_exists(company_id, task_id, allow_public=True)
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company",)
)[0]
res = event_bll.log_events_iterator.get_task_events(
company_id=company_id,
company_id=task.company,
task_id=task_id,
batch_size=req_model.batch_size,
navigate_earlier=req_model.navigate_earlier,
@@ -108,16 +114,16 @@ def get_task_log(call, company_id, req_model: LogEventsRequest):
)
call.result.data = dict(
events=res.events,
returned=len(res.events),
total=res.total_events,
events=res.events, returned=len(res.events), total=res.total_events
)
@endpoint("events.download_task_log", required_fields=["task"])
def download_task_log(call, company_id, req_model):
def download_task_log(call, company_id, _):
task_id = call.data["task"]
task_bll.assert_exists(company_id, task_id, allow_public=True)
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company",)
)[0]
line_type = call.data.get("line_type", "json").lower()
line_format = str(call.data.get("line_format", "{asctime} {worker} {level} {msg}"))
@@ -160,7 +166,7 @@ def download_task_log(call, company_id, req_model):
batch_size = 1000
while True:
log_events, scroll_id, _ = event_bll.scroll_task_events(
company_id,
task.company,
task_id,
order="asc",
event_type="log",
@@ -193,23 +199,27 @@ def download_task_log(call, company_id, req_model):
@endpoint("events.get_vector_metrics_and_variants", required_fields=["task"])
def get_vector_metrics_and_variants(call, company_id, req_model):
def get_vector_metrics_and_variants(call, company_id, _):
task_id = call.data["task"]
task_bll.assert_exists(company_id, task_id, allow_public=True)
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company",)
)[0]
call.result.data = dict(
metrics=event_bll.get_metrics_and_variants(
company_id, task_id, "training_stats_vector"
task.company, task_id, "training_stats_vector"
)
)
@endpoint("events.get_scalar_metrics_and_variants", required_fields=["task"])
def get_scalar_metrics_and_variants(call, company_id, req_model):
def get_scalar_metrics_and_variants(call, company_id, _):
task_id = call.data["task"]
task_bll.assert_exists(company_id, task_id, allow_public=True)
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company",)
)[0]
call.result.data = dict(
metrics=event_bll.get_metrics_and_variants(
company_id, task_id, "training_stats_scalar"
task.company, task_id, "training_stats_scalar"
)
)
@@ -219,13 +229,15 @@ def get_scalar_metrics_and_variants(call, company_id, req_model):
"events.vector_metrics_iter_histogram",
required_fields=["task", "metric", "variant"],
)
def vector_metrics_iter_histogram(call, company_id, req_model):
def vector_metrics_iter_histogram(call, company_id, _):
task_id = call.data["task"]
task_bll.assert_exists(company_id, task_id, allow_public=True)
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company",)
)[0]
metric = call.data["metric"]
variant = call.data["variant"]
iterations, vectors = event_bll.get_vector_metrics_per_iter(
company_id, task_id, metric, variant
task.company, task_id, metric, variant
)
call.result.data = dict(
metric=metric, variant=variant, vectors=vectors, iterations=iterations
@@ -240,9 +252,11 @@ def get_task_events(call, company_id, _):
scroll_id = call.data.get("scroll_id")
order = call.data.get("order") or "asc"
task_bll.assert_exists(company_id, task_id, allow_public=True)
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company",)
)[0]
result = event_bll.get_task_events(
company_id,
task.company,
task_id,
sort=[{"timestamp": {"order": order}}],
event_type=event_type,
@@ -259,14 +273,16 @@ def get_task_events(call, company_id, _):
@endpoint("events.get_scalar_metric_data", required_fields=["task", "metric"])
def get_scalar_metric_data(call, company_id, req_model):
def get_scalar_metric_data(call, company_id, _):
task_id = call.data["task"]
metric = call.data["metric"]
scroll_id = call.data.get("scroll_id")
task_bll.assert_exists(company_id, task_id, allow_public=True)
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company",)
)[0]
result = event_bll.get_task_events(
company_id,
task.company,
task_id,
event_type="training_stats_scalar",
sort=[{"iter": {"order": "desc"}}],
@@ -283,13 +299,15 @@ def get_scalar_metric_data(call, company_id, req_model):
@endpoint("events.get_task_latest_scalar_values", required_fields=["task"])
def get_task_latest_scalar_values(call, company_id, req_model):
def get_task_latest_scalar_values(call, company_id, _):
task_id = call.data["task"]
task = task_bll.assert_exists(company_id, task_id, allow_public=True)
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company",)
)[0]
metrics, last_timestamp = event_bll.get_task_latest_scalar_values(
company_id, task_id
task.company, task_id
)
es_index = EventMetrics.get_index_name(company_id, "*")
es_index = EventMetrics.get_index_name(task.company, "*")
last_iters = event_bll.get_last_iters(es_index, task_id, None, 1)
call.result.data = dict(
metrics=metrics,
@@ -306,11 +324,13 @@ def get_task_latest_scalar_values(call, company_id, req_model):
request_data_model=ScalarMetricsIterHistogramRequest,
)
def scalar_metrics_iter_histogram(
call, company_id, req_model: ScalarMetricsIterHistogramRequest
call, company_id, request: ScalarMetricsIterHistogramRequest
):
task_bll.assert_exists(call.identity.company, req_model.task, allow_public=True)
task = task_bll.assert_exists(
company_id, request.task, allow_public=True, only=("company",)
)[0]
metrics = event_bll.metrics.get_scalar_metrics_average_per_iter(
company_id, task_id=req_model.task, samples=req_model.samples, key=req_model.key
task.company, task_id=request.task, samples=request.samples, key=request.key
)
call.result.data = metrics
@@ -338,21 +358,27 @@ def multi_task_scalar_metrics_iter_histogram(
@endpoint("events.get_multi_task_plots", required_fields=["tasks"])
def get_multi_task_plots_v1_7(call, company_id, req_model):
def get_multi_task_plots_v1_7(call, company_id, _):
task_ids = call.data["tasks"]
iters = call.data.get("iters", 1)
scroll_id = call.data.get("scroll_id")
tasks = task_bll.assert_exists(
company_id=call.identity.company,
only=("id", "name"),
company_id=company_id,
only=("id", "name", "company"),
task_ids=task_ids,
allow_public=True,
)
companies = {t.company for t in tasks}
if len(companies) > 1:
raise errors.bad_request.InvalidTaskId(
"only tasks from the same company are supported"
)
# Get last 10K events by iteration and group them by unique metric+variant, returning top events for combination
result = event_bll.get_task_events(
company_id,
next(iter(companies)),
task_ids,
event_type="plot",
sort=[{"iter": {"order": "desc"}}],
@@ -382,13 +408,19 @@ def get_multi_task_plots(call, company_id, req_model):
tasks = task_bll.assert_exists(
company_id=call.identity.company,
only=("id", "name"),
only=("id", "name", "company"),
task_ids=task_ids,
allow_public=True,
)
companies = {t.company for t in tasks}
if len(companies) > 1:
raise errors.bad_request.InvalidTaskId(
"only tasks from the same company are supported"
)
result = event_bll.get_task_events(
company_id,
next(iter(companies)),
task_ids,
event_type="plot",
sort=[{"iter": {"order": "desc"}}],
@@ -411,12 +443,14 @@ def get_multi_task_plots(call, company_id, req_model):
@endpoint("events.get_task_plots", required_fields=["task"])
def get_task_plots_v1_7(call, company_id, req_model):
def get_task_plots_v1_7(call, company_id, _):
task_id = call.data["task"]
iters = call.data.get("iters", 1)
scroll_id = call.data.get("scroll_id")
task_bll.assert_exists(call.identity.company, task_id, allow_public=True)
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company",)
)[0]
# events, next_scroll_id, total_events = event_bll.get_task_events(
# company, task_id,
# event_type="plot",
@@ -426,7 +460,7 @@ def get_task_plots_v1_7(call, company_id, req_model):
# get last 10K events by iteration and group them by unique metric+variant, returning top events for combination
result = event_bll.get_task_events(
company_id,
task.company,
task_id,
event_type="plot",
sort=[{"iter": {"order": "desc"}}],
@@ -445,14 +479,16 @@ def get_task_plots_v1_7(call, company_id, req_model):
@endpoint("events.get_task_plots", min_version="1.8", required_fields=["task"])
def get_task_plots(call, company_id, req_model):
def get_task_plots(call, company_id, _):
task_id = call.data["task"]
iters = call.data.get("iters", 1)
scroll_id = call.data.get("scroll_id")
task_bll.assert_exists(call.identity.company, task_id, allow_public=True)
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company",)
)[0]
result = event_bll.get_task_plots(
company_id,
task.company,
tasks=[task_id],
sort=[{"iter": {"order": "desc"}}],
last_iterations_per_plot=iters,
@@ -470,12 +506,14 @@ def get_task_plots(call, company_id, req_model):
@endpoint("events.debug_images", required_fields=["task"])
def get_debug_images_v1_7(call, company_id, req_model):
def get_debug_images_v1_7(call, company_id, _):
task_id = call.data["task"]
iters = call.data.get("iters") or 1
scroll_id = call.data.get("scroll_id")
task_bll.assert_exists(call.identity.company, task_id, allow_public=True)
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company",)
)[0]
# events, next_scroll_id, total_events = event_bll.get_task_events(
# company, task_id,
# event_type="training_debug_image",
@@ -485,7 +523,7 @@ def get_debug_images_v1_7(call, company_id, req_model):
# get last 10K events by iteration and group them by unique metric+variant, returning top events for combination
result = event_bll.get_task_events(
company_id,
task.company,
task_id,
event_type="training_debug_image",
sort=[{"iter": {"order": "desc"}}],
@@ -505,14 +543,16 @@ def get_debug_images_v1_7(call, company_id, req_model):
@endpoint("events.debug_images", min_version="1.8", required_fields=["task"])
def get_debug_images_v1_8(call, company_id, req_model):
def get_debug_images_v1_8(call, company_id, _):
task_id = call.data["task"]
iters = call.data.get("iters") or 1
scroll_id = call.data.get("scroll_id")
task_bll.assert_exists(call.identity.company, task_id, allow_public=True)
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company",)
)[0]
result = event_bll.get_task_events(
company_id,
task.company,
task_id,
event_type="training_debug_image",
sort=[{"iter": {"order": "desc"}}],
@@ -537,16 +577,25 @@ def get_debug_images_v1_8(call, company_id, req_model):
request_data_model=DebugImagesRequest,
response_data_model=DebugImageResponse,
)
def get_debug_images(call, company_id, req_model: DebugImagesRequest):
tasks = set(m.task for m in req_model.metrics)
task_bll.assert_exists(call.identity.company, task_ids=tasks, allow_public=True)
def get_debug_images(call, company_id, request: DebugImagesRequest):
task_ids = {m.task for m in request.metrics}
tasks = task_bll.assert_exists(
company_id, task_ids=task_ids, allow_public=True, only=("company",)
)
companies = {t.company for t in tasks}
if len(companies) > 1:
raise errors.bad_request.InvalidTaskId(
"only tasks from the same company are supported"
)
result = event_bll.debug_images_iterator.get_task_events(
company_id=company_id,
metrics=[(m.task, m.metric) for m in req_model.metrics],
iter_count=req_model.iters,
navigate_earlier=req_model.navigate_earlier,
refresh=req_model.refresh,
state_id=req_model.scroll_id,
company_id=next(iter(companies)),
metrics=[(m.task, m.metric) for m in request.metrics],
iter_count=request.iters,
navigate_earlier=request.navigate_earlier,
refresh=request.refresh,
state_id=request.scroll_id,
)
call.result.data_model = DebugImageResponse(
@@ -566,12 +615,12 @@ def get_debug_images(call, company_id, req_model: DebugImagesRequest):
@endpoint("events.get_task_metrics", request_data_model=TaskMetricsRequest)
def get_tasks_metrics(call: APICall, company_id, req_model: TaskMetricsRequest):
task_bll.assert_exists(
call.identity.company, task_ids=req_model.tasks, allow_public=True
)
def get_tasks_metrics(call: APICall, company_id, request: TaskMetricsRequest):
task = task_bll.assert_exists(
company_id, task_ids=request.tasks, allow_public=True, only=("company",)
)[0]
res = event_bll.metrics.get_tasks_metrics(
company_id, task_ids=req_model.tasks, event_type=req_model.event_type
task.company, task_ids=request.tasks, event_type=request.event_type
)
call.result.data = {
"metrics": [{"task": task, "metrics": metrics} for (task, metrics) in res]
@@ -583,7 +632,7 @@ def delete_for_task(call, company_id, req_model):
task_id = call.data["task"]
allow_locked = call.data.get("allow_locked", False)
task_bll.assert_exists(company_id, task_id)
task_bll.assert_exists(company_id, task_id, return_tasks=False)
call.result.data = dict(
deleted=event_bll.delete_task_events(
company_id, task_id, allow_locked=allow_locked