Fix type annotations

Fix obtaining events for tasks moved from private to public
Fix assert_exists() to return company_origin if requested
allegroai 2021-01-05 16:27:38 +02:00
parent c67a56eb8d
commit 22e9c2b7eb
6 changed files with 59 additions and 46 deletions

View File

@@ -115,7 +115,7 @@ class EventMetrics:
             company=company_id,
             query=Q(id__in=task_ids),
             allow_public=allow_public,
-            override_projection=("id", "name", "company"),
+            override_projection=("id", "name", "company", "company_origin"),
             return_dicts=False,
         )
         if len(task_objs) < len(task_ids):
@@ -123,7 +123,7 @@ class EventMetrics:
             raise errors.bad_request.InvalidTaskId(company=company_id, ids=invalid)

         task_name_by_id = {t.id: t.name for t in task_objs}
-        companies = {t.company for t in task_objs}
+        companies = {t.get_index_company() for t in task_objs}
         if len(companies) > 1:
             raise errors.bad_request.InvalidTaskId(
                 "only tasks from the same company are supported"

View File

@@ -126,7 +126,9 @@ class TaskBLL(object):
             return_dicts=False,
         )
         if only:
-            q = q.only(*only)
+            # Make sure to reset fields filters (some fields are excluded by default) since this
+            # is an internal call and specific fields were requested.
+            q = q.all_fields().only(*only)
         if q.count() != len(ids):
            raise errors.bad_request.InvalidTaskId(ids=task_ids)

View File

@@ -773,7 +773,7 @@ class DbModelMixin(GetMixin, ProperDictMixin, UpdateMixin):
     ):
         if enabled:
             items = list(cls.objects(id__in=ids, company=company_id).only("id"))
-            update = dict(set__company_origin=company_id, unset__company=1)
+            update = dict(set__company_origin=company_id, set__company="")
         else:
             items = list(
                 cls.objects(
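
With this change, marking items public keeps an empty company string instead of removing the field entirely; the get_index_company() helper added below then falls back from that empty value to company_origin. A rough sketch of how the mongoengine update kwargs map to MongoDB operators ("company-123" is a made-up ID):

    # Illustration only
    new_update = dict(set__company_origin="company-123", set__company="")
    # roughly: {"$set": {"company_origin": "company-123", "company": ""}}

    old_update = dict(set__company_origin="company-123", unset__company=1)
    # roughly: {"$set": {"company_origin": "company-123"}, "$unset": {"company": 1}}
    # i.e. a public item now keeps an (empty) company field instead of losing it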

View File

@@ -217,3 +217,13 @@ class Task(AttributedDocument):
     hyperparams = SafeMapField(field=SafeMapField(EmbeddedDocumentField(ParamsItem)))
     configuration = SafeMapField(field=EmbeddedDocumentField(ConfigurationItem))
     runtime = SafeDictField(default=dict)
+
+    def get_index_company(self) -> str:
+        """
+        Returns the company ID used for locating indices containing task data.
+        In case the task has a valid company, this is the company ID.
+        Otherwise, if the task has a company_origin, this is a task that has been made public and the
+        origin company should be used.
+        Otherwise, an empty company is used.
+        """
+        return self.company or self.company_origin or ""
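
A standalone sketch of the fallback order, using a stand-in class rather than the real Task document:

    class FakeTask:
        # Stand-in for the Task model above, for illustration only
        def __init__(self, company=None, company_origin=None):
            self.company = company
            self.company_origin = company_origin

        def get_index_company(self) -> str:
            return self.company or self.company_origin or ""

    assert FakeTask(company="company-123").get_index_company() == "company-123"  # private task
    assert FakeTask(company="", company_origin="company-123").get_index_company() == "company-123"  # made public
    assert FakeTask().get_index_company() == ""  # no company at all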

View File

@ -1,4 +1,4 @@
from typing import Text, Sequence, Callable, Union from typing import Text, Sequence, Callable, Union, Type
from funcsigs import signature from funcsigs import signature
from jsonmodels import models from jsonmodels import models
@ -18,8 +18,8 @@ def endpoint(
name: Text, name: Text,
min_version: Text = "1.0", min_version: Text = "1.0",
required_fields: Sequence[Text] = None, required_fields: Sequence[Text] = None,
request_data_model: models.Base = None, request_data_model: Type[models.Base] = None,
response_data_model: models.Base = None, response_data_model: Type[models.Base] = None,
validate_schema=False, validate_schema=False,
): ):
""" Endpoint decorator, used to declare a method as an endpoint handler """ """ Endpoint decorator, used to declare a method as an endpoint handler """

View File

@@ -50,13 +50,13 @@ def add_batch(call: APICall, company_id, _):
 def get_task_log_v1_5(call, company_id, _):
     task_id = call.data["task"]
     task = task_bll.assert_exists(
-        company_id, task_id, allow_public=True, only=("company",)
+        company_id, task_id, allow_public=True, only=("company", "company_origin")
     )[0]
     order = call.data.get("order") or "desc"
     scroll_id = call.data.get("scroll_id")
     batch_size = int(call.data.get("batch_size") or 500)
     events, scroll_id, total_events = event_bll.scroll_task_events(
-        task.company,
+        task.get_index_company(),
         task_id,
         order,
         event_type="log",
@@ -72,7 +72,7 @@ def get_task_log_v1_5(call, company_id, _):
 def get_task_log_v1_7(call, company_id, _):
     task_id = call.data["task"]
     task = task_bll.assert_exists(
-        company_id, task_id, allow_public=True, only=("company",)
+        company_id, task_id, allow_public=True, only=("company", "company_origin")
     )[0]
     order = call.data.get("order") or "desc"
@@ -83,7 +83,7 @@ def get_task_log_v1_7(call, company_id, _):
     scroll_order = "asc" if (from_ == "head") else "desc"
     events, scroll_id, total_events = event_bll.scroll_task_events(
-        company_id=task.company,
+        company_id=task.get_index_company(),
         task_id=task_id,
         order=scroll_order,
         event_type="log",
@@ -103,11 +103,11 @@ def get_task_log_v1_7(call, company_id, _):
 def get_task_log(call, company_id, request: LogEventsRequest):
     task_id = request.task
     task = task_bll.assert_exists(
-        company_id, task_id, allow_public=True, only=("company",)
+        company_id, task_id, allow_public=True, only=("company", "company_origin")
     )[0]

     res = event_bll.log_events_iterator.get_task_events(
-        company_id=task.company,
+        company_id=task.get_index_company(),
         task_id=task_id,
         batch_size=request.batch_size,
         navigate_earlier=request.navigate_earlier,
@@ -131,7 +131,7 @@ def get_task_log(call, company_id, request: LogEventsRequest):
 def download_task_log(call, company_id, _):
     task_id = call.data["task"]
     task = task_bll.assert_exists(
-        company_id, task_id, allow_public=True, only=("company",)
+        company_id, task_id, allow_public=True, only=("company", "company_origin")
     )[0]
     line_type = call.data.get("line_type", "json").lower()
@@ -175,7 +175,7 @@ def download_task_log(call, company_id, _):
         batch_size = 1000
         while True:
             log_events, scroll_id, _ = event_bll.scroll_task_events(
-                task.company,
+                task.get_index_company(),
                 task_id,
                 order="asc",
                 event_type="log",
@@ -211,11 +211,11 @@ def download_task_log(call, company_id, _):
 def get_vector_metrics_and_variants(call, company_id, _):
     task_id = call.data["task"]
     task = task_bll.assert_exists(
-        company_id, task_id, allow_public=True, only=("company",)
+        company_id, task_id, allow_public=True, only=("company", "company_origin")
     )[0]

     call.result.data = dict(
         metrics=event_bll.get_metrics_and_variants(
-            task.company, task_id, "training_stats_vector"
+            task.get_index_company(), task_id, "training_stats_vector"
         )
     )
@@ -224,11 +224,11 @@ def get_vector_metrics_and_variants(call, company_id, _):
 def get_scalar_metrics_and_variants(call, company_id, _):
     task_id = call.data["task"]
     task = task_bll.assert_exists(
-        company_id, task_id, allow_public=True, only=("company",)
+        company_id, task_id, allow_public=True, only=("company", "company_origin")
     )[0]

     call.result.data = dict(
         metrics=event_bll.get_metrics_and_variants(
-            task.company, task_id, "training_stats_scalar"
+            task.get_index_company(), task_id, "training_stats_scalar"
         )
     )
@@ -241,12 +241,12 @@ def get_scalar_metrics_and_variants(call, company_id, _):
 def vector_metrics_iter_histogram(call, company_id, _):
     task_id = call.data["task"]
     task = task_bll.assert_exists(
-        company_id, task_id, allow_public=True, only=("company",)
+        company_id, task_id, allow_public=True, only=("company", "company_origin")
     )[0]
     metric = call.data["metric"]
     variant = call.data["variant"]
     iterations, vectors = event_bll.get_vector_metrics_per_iter(
-        task.company, task_id, metric, variant
+        task.get_index_company(), task_id, metric, variant
     )
     call.result.data = dict(
         metric=metric, variant=variant, vectors=vectors, iterations=iterations
@@ -262,10 +262,10 @@ def get_task_events(call, company_id, _):
     order = call.data.get("order") or "asc"

     task = task_bll.assert_exists(
-        company_id, task_id, allow_public=True, only=("company",)
+        company_id, task_id, allow_public=True, only=("company", "company_origin")
     )[0]
     result = event_bll.get_task_events(
-        task.company,
+        task.get_index_company(),
         task_id,
         sort=[{"timestamp": {"order": order}}],
         event_type=event_type,
@@ -288,10 +288,10 @@ def get_scalar_metric_data(call, company_id, _):
     scroll_id = call.data.get("scroll_id")

     task = task_bll.assert_exists(
-        company_id, task_id, allow_public=True, only=("company",)
+        company_id, task_id, allow_public=True, only=("company", "company_origin")
     )[0]
     result = event_bll.get_task_events(
-        task.company,
+        task.get_index_company(),
         task_id,
         event_type="training_stats_scalar",
         sort=[{"iter": {"order": "desc"}}],
@@ -311,12 +311,13 @@ def get_scalar_metric_data(call, company_id, _):
 def get_task_latest_scalar_values(call, company_id, _):
     task_id = call.data["task"]
     task = task_bll.assert_exists(
-        company_id, task_id, allow_public=True, only=("company",)
+        company_id, task_id, allow_public=True, only=("company", "company_origin")
     )[0]
+    index_company = task.get_index_company()
     metrics, last_timestamp = event_bll.get_task_latest_scalar_values(
-        task.company, task_id
+        index_company, task_id
     )
-    es_index = EventMetrics.get_index_name(task.company, "*")
+    es_index = EventMetrics.get_index_name(index_company, "*")
     last_iters = event_bll.get_last_iters(es_index, task_id, None, 1)
     call.result.data = dict(
         metrics=metrics,
@@ -336,10 +337,10 @@ def scalar_metrics_iter_histogram(
     call, company_id, request: ScalarMetricsIterHistogramRequest
 ):
     task = task_bll.assert_exists(
-        company_id, request.task, allow_public=True, only=("company",)
+        company_id, request.task, allow_public=True, only=("company", "company_origin")
     )[0]

     metrics = event_bll.metrics.get_scalar_metrics_average_per_iter(
-        task.company, task_id=request.task, samples=request.samples, key=request.key
+        task.get_index_company(), task_id=request.task, samples=request.samples, key=request.key
     )
     call.result.data = metrics
@@ -374,12 +375,12 @@ def get_multi_task_plots_v1_7(call, company_id, _):
     tasks = task_bll.assert_exists(
         company_id=company_id,
-        only=("id", "name", "company"),
+        only=("id", "name", "company", "company_origin"),
         task_ids=task_ids,
         allow_public=True,
     )

-    companies = {t.company for t in tasks}
+    companies = {t.get_index_company() for t in tasks}
     if len(companies) > 1:
         raise errors.bad_request.InvalidTaskId(
             "only tasks from the same company are supported"
@@ -417,12 +418,12 @@ def get_multi_task_plots(call, company_id, req_model):
     tasks = task_bll.assert_exists(
         company_id=call.identity.company,
-        only=("id", "name", "company"),
+        only=("id", "name", "company", "company_origin"),
         task_ids=task_ids,
         allow_public=True,
     )

-    companies = {t.company for t in tasks}
+    companies = {t.get_index_company() for t in tasks}
     if len(companies) > 1:
         raise errors.bad_request.InvalidTaskId(
             "only tasks from the same company are supported"
@@ -458,7 +459,7 @@ def get_task_plots_v1_7(call, company_id, _):
     scroll_id = call.data.get("scroll_id")
     task = task_bll.assert_exists(
-        company_id, task_id, allow_public=True, only=("company",)
+        company_id, task_id, allow_public=True, only=("company", "company_origin")
     )[0]
     # events, next_scroll_id, total_events = event_bll.get_task_events(
     #     company, task_id,
@@ -469,7 +470,7 @@ def get_task_plots_v1_7(call, company_id, _):
     # get last 10K events by iteration and group them by unique metric+variant, returning top events for combination
     result = event_bll.get_task_events(
-        task.company,
+        task.get_index_company(),
         task_id,
         event_type="plot",
         sort=[{"iter": {"order": "desc"}}],
@@ -494,10 +495,10 @@ def get_task_plots(call, company_id, _):
     scroll_id = call.data.get("scroll_id")

     task = task_bll.assert_exists(
-        company_id, task_id, allow_public=True, only=("company",)
+        company_id, task_id, allow_public=True, only=("company", "company_origin")
     )[0]
     result = event_bll.get_task_plots(
-        task.company,
+        task.get_index_company(),
         tasks=[task_id],
         sort=[{"iter": {"order": "desc"}}],
         last_iterations_per_plot=iters,
@@ -521,7 +522,7 @@ def get_debug_images_v1_7(call, company_id, _):
     scroll_id = call.data.get("scroll_id")
     task = task_bll.assert_exists(
-        company_id, task_id, allow_public=True, only=("company",)
+        company_id, task_id, allow_public=True, only=("company", "company_origin")
     )[0]
     # events, next_scroll_id, total_events = event_bll.get_task_events(
     #     company, task_id,
@@ -532,7 +533,7 @@ def get_debug_images_v1_7(call, company_id, _):
     # get last 10K events by iteration and group them by unique metric+variant, returning top events for combination
     result = event_bll.get_task_events(
-        task.company,
+        task.get_index_company(),
         task_id,
         event_type="training_debug_image",
         sort=[{"iter": {"order": "desc"}}],
@@ -558,10 +559,10 @@ def get_debug_images_v1_8(call, company_id, _):
     scroll_id = call.data.get("scroll_id")

     task = task_bll.assert_exists(
-        company_id, task_id, allow_public=True, only=("company",)
+        company_id, task_id, allow_public=True, only=("company", "company_origin")
     )[0]
     result = event_bll.get_task_events(
-        task.company,
+        task.get_index_company(),
         task_id,
         event_type="training_debug_image",
         sort=[{"iter": {"order": "desc"}}],
@@ -589,10 +590,10 @@ def get_debug_images_v1_8(call, company_id, _):
 def get_debug_images(call, company_id, request: DebugImagesRequest):
     task_ids = {m.task for m in request.metrics}
     tasks = task_bll.assert_exists(
-        company_id, task_ids=task_ids, allow_public=True, only=("company",)
+        company_id, task_ids=task_ids, allow_public=True, only=("company", "company_origin")
     )

-    companies = {t.company for t in tasks}
+    companies = {t.get_index_company() for t in tasks}
     if len(companies) > 1:
         raise errors.bad_request.InvalidTaskId(
             "only tasks from the same company are supported"
@@ -626,10 +627,10 @@ def get_debug_images(call, company_id, request: DebugImagesRequest):
 @endpoint("events.get_task_metrics", request_data_model=TaskMetricsRequest)
 def get_tasks_metrics(call: APICall, company_id, request: TaskMetricsRequest):
     task = task_bll.assert_exists(
-        company_id, task_ids=request.tasks, allow_public=True, only=("company",)
+        company_id, task_ids=request.tasks, allow_public=True, only=("company", "company_origin")
     )[0]
     res = event_bll.metrics.get_tasks_metrics(
-        task.company, task_ids=request.tasks, event_type=request.event_type
+        task.get_index_company(), task_ids=request.tasks, event_type=request.event_type
     )
     call.result.data = {
         "metrics": [{"task": task, "metrics": metrics} for (task, metrics) in res]