Fix type annotations

Fix obtaining events for tasks moved from private to public
Fix assert_exists() to return company_origin if requested
This commit is contained in:
allegroai 2021-01-05 16:27:38 +02:00
parent c67a56eb8d
commit 22e9c2b7eb
6 changed files with 59 additions and 46 deletions

View File

@ -115,7 +115,7 @@ class EventMetrics:
company=company_id,
query=Q(id__in=task_ids),
allow_public=allow_public,
override_projection=("id", "name", "company"),
override_projection=("id", "name", "company", "company_origin"),
return_dicts=False,
)
if len(task_objs) < len(task_ids):
@ -123,7 +123,7 @@ class EventMetrics:
raise errors.bad_request.InvalidTaskId(company=company_id, ids=invalid)
task_name_by_id = {t.id: t.name for t in task_objs}
companies = {t.company for t in task_objs}
companies = {t.get_index_company() for t in task_objs}
if len(companies) > 1:
raise errors.bad_request.InvalidTaskId(
"only tasks from the same company are supported"

View File

@ -126,7 +126,9 @@ class TaskBLL(object):
return_dicts=False,
)
if only:
q = q.only(*only)
# Make sure to reset fields filters (some fields are excluded by default) since this
# is an internal call and specific fields were requested.
q = q.all_fields().only(*only)
if q.count() != len(ids):
raise errors.bad_request.InvalidTaskId(ids=task_ids)

View File

@ -773,7 +773,7 @@ class DbModelMixin(GetMixin, ProperDictMixin, UpdateMixin):
):
if enabled:
items = list(cls.objects(id__in=ids, company=company_id).only("id"))
update = dict(set__company_origin=company_id, unset__company=1)
update = dict(set__company_origin=company_id, set__company="")
else:
items = list(
cls.objects(

View File

@ -217,3 +217,13 @@ class Task(AttributedDocument):
hyperparams = SafeMapField(field=SafeMapField(EmbeddedDocumentField(ParamsItem)))
configuration = SafeMapField(field=EmbeddedDocumentField(ConfigurationItem))
runtime = SafeDictField(default=dict)
def get_index_company(self) -> str:
    """
    Return the company ID used to locate the indices holding this task's data.

    Resolution order:
    1. The task's own ``company``, when set (a regular, private task).
    2. ``company_origin``, when present — the task was made public and its
       ``company`` field was cleared, so the originating company owns the data.
    3. An empty string when neither is available.
    """
    # First truthy candidate wins; fall through to the empty company.
    for candidate in (self.company, self.company_origin):
        if candidate:
            return candidate
    return ""

View File

@ -1,4 +1,4 @@
from typing import Text, Sequence, Callable, Union
from typing import Text, Sequence, Callable, Union, Type
from funcsigs import signature
from jsonmodels import models
@ -18,8 +18,8 @@ def endpoint(
name: Text,
min_version: Text = "1.0",
required_fields: Sequence[Text] = None,
request_data_model: models.Base = None,
response_data_model: models.Base = None,
request_data_model: Type[models.Base] = None,
response_data_model: Type[models.Base] = None,
validate_schema=False,
):
""" Endpoint decorator, used to declare a method as an endpoint handler """

View File

@ -50,13 +50,13 @@ def add_batch(call: APICall, company_id, _):
def get_task_log_v1_5(call, company_id, _):
task_id = call.data["task"]
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company",)
company_id, task_id, allow_public=True, only=("company", "company_origin")
)[0]
order = call.data.get("order") or "desc"
scroll_id = call.data.get("scroll_id")
batch_size = int(call.data.get("batch_size") or 500)
events, scroll_id, total_events = event_bll.scroll_task_events(
task.company,
task.get_index_company(),
task_id,
order,
event_type="log",
@ -72,7 +72,7 @@ def get_task_log_v1_5(call, company_id, _):
def get_task_log_v1_7(call, company_id, _):
task_id = call.data["task"]
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company",)
company_id, task_id, allow_public=True, only=("company", "company_origin")
)[0]
order = call.data.get("order") or "desc"
@ -83,7 +83,7 @@ def get_task_log_v1_7(call, company_id, _):
scroll_order = "asc" if (from_ == "head") else "desc"
events, scroll_id, total_events = event_bll.scroll_task_events(
company_id=task.company,
company_id=task.get_index_company(),
task_id=task_id,
order=scroll_order,
event_type="log",
@ -103,11 +103,11 @@ def get_task_log_v1_7(call, company_id, _):
def get_task_log(call, company_id, request: LogEventsRequest):
task_id = request.task
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company",)
company_id, task_id, allow_public=True, only=("company", "company_origin")
)[0]
res = event_bll.log_events_iterator.get_task_events(
company_id=task.company,
company_id=task.get_index_company(),
task_id=task_id,
batch_size=request.batch_size,
navigate_earlier=request.navigate_earlier,
@ -131,7 +131,7 @@ def get_task_log(call, company_id, request: LogEventsRequest):
def download_task_log(call, company_id, _):
task_id = call.data["task"]
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company",)
company_id, task_id, allow_public=True, only=("company", "company_origin")
)[0]
line_type = call.data.get("line_type", "json").lower()
@ -175,7 +175,7 @@ def download_task_log(call, company_id, _):
batch_size = 1000
while True:
log_events, scroll_id, _ = event_bll.scroll_task_events(
task.company,
task.get_index_company(),
task_id,
order="asc",
event_type="log",
@ -211,11 +211,11 @@ def download_task_log(call, company_id, _):
def get_vector_metrics_and_variants(call, company_id, _):
task_id = call.data["task"]
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company",)
company_id, task_id, allow_public=True, only=("company", "company_origin")
)[0]
call.result.data = dict(
metrics=event_bll.get_metrics_and_variants(
task.company, task_id, "training_stats_vector"
task.get_index_company(), task_id, "training_stats_vector"
)
)
@ -224,11 +224,11 @@ def get_vector_metrics_and_variants(call, company_id, _):
def get_scalar_metrics_and_variants(call, company_id, _):
task_id = call.data["task"]
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company",)
company_id, task_id, allow_public=True, only=("company", "company_origin")
)[0]
call.result.data = dict(
metrics=event_bll.get_metrics_and_variants(
task.company, task_id, "training_stats_scalar"
task.get_index_company(), task_id, "training_stats_scalar"
)
)
@ -241,12 +241,12 @@ def get_scalar_metrics_and_variants(call, company_id, _):
def vector_metrics_iter_histogram(call, company_id, _):
task_id = call.data["task"]
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company",)
company_id, task_id, allow_public=True, only=("company", "company_origin")
)[0]
metric = call.data["metric"]
variant = call.data["variant"]
iterations, vectors = event_bll.get_vector_metrics_per_iter(
task.company, task_id, metric, variant
task.get_index_company(), task_id, metric, variant
)
call.result.data = dict(
metric=metric, variant=variant, vectors=vectors, iterations=iterations
@ -262,10 +262,10 @@ def get_task_events(call, company_id, _):
order = call.data.get("order") or "asc"
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company",)
company_id, task_id, allow_public=True, only=("company", "company_origin")
)[0]
result = event_bll.get_task_events(
task.company,
task.get_index_company(),
task_id,
sort=[{"timestamp": {"order": order}}],
event_type=event_type,
@ -288,10 +288,10 @@ def get_scalar_metric_data(call, company_id, _):
scroll_id = call.data.get("scroll_id")
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company",)
company_id, task_id, allow_public=True, only=("company", "company_origin")
)[0]
result = event_bll.get_task_events(
task.company,
task.get_index_company(),
task_id,
event_type="training_stats_scalar",
sort=[{"iter": {"order": "desc"}}],
@ -311,12 +311,13 @@ def get_scalar_metric_data(call, company_id, _):
def get_task_latest_scalar_values(call, company_id, _):
task_id = call.data["task"]
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company",)
company_id, task_id, allow_public=True, only=("company", "company_origin")
)[0]
index_company = task.get_index_company()
metrics, last_timestamp = event_bll.get_task_latest_scalar_values(
task.company, task_id
index_company, task_id
)
es_index = EventMetrics.get_index_name(task.company, "*")
es_index = EventMetrics.get_index_name(index_company, "*")
last_iters = event_bll.get_last_iters(es_index, task_id, None, 1)
call.result.data = dict(
metrics=metrics,
@ -336,10 +337,10 @@ def scalar_metrics_iter_histogram(
call, company_id, request: ScalarMetricsIterHistogramRequest
):
task = task_bll.assert_exists(
company_id, request.task, allow_public=True, only=("company",)
company_id, request.task, allow_public=True, only=("company", "company_origin")
)[0]
metrics = event_bll.metrics.get_scalar_metrics_average_per_iter(
task.company, task_id=request.task, samples=request.samples, key=request.key
task.get_index_company(), task_id=request.task, samples=request.samples, key=request.key
)
call.result.data = metrics
@ -374,12 +375,12 @@ def get_multi_task_plots_v1_7(call, company_id, _):
tasks = task_bll.assert_exists(
company_id=company_id,
only=("id", "name", "company"),
only=("id", "name", "company", "company_origin"),
task_ids=task_ids,
allow_public=True,
)
companies = {t.company for t in tasks}
companies = {t.get_index_company() for t in tasks}
if len(companies) > 1:
raise errors.bad_request.InvalidTaskId(
"only tasks from the same company are supported"
@ -417,12 +418,12 @@ def get_multi_task_plots(call, company_id, req_model):
tasks = task_bll.assert_exists(
company_id=call.identity.company,
only=("id", "name", "company"),
only=("id", "name", "company", "company_origin"),
task_ids=task_ids,
allow_public=True,
)
companies = {t.company for t in tasks}
companies = {t.get_index_company() for t in tasks}
if len(companies) > 1:
raise errors.bad_request.InvalidTaskId(
"only tasks from the same company are supported"
@ -458,7 +459,7 @@ def get_task_plots_v1_7(call, company_id, _):
scroll_id = call.data.get("scroll_id")
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company",)
company_id, task_id, allow_public=True, only=("company", "company_origin")
)[0]
# events, next_scroll_id, total_events = event_bll.get_task_events(
# company, task_id,
@ -469,7 +470,7 @@ def get_task_plots_v1_7(call, company_id, _):
# get last 10K events by iteration and group them by unique metric+variant, returning top events for combination
result = event_bll.get_task_events(
task.company,
task.get_index_company(),
task_id,
event_type="plot",
sort=[{"iter": {"order": "desc"}}],
@ -494,10 +495,10 @@ def get_task_plots(call, company_id, _):
scroll_id = call.data.get("scroll_id")
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company",)
company_id, task_id, allow_public=True, only=("company", "company_origin")
)[0]
result = event_bll.get_task_plots(
task.company,
task.get_index_company(),
tasks=[task_id],
sort=[{"iter": {"order": "desc"}}],
last_iterations_per_plot=iters,
@ -521,7 +522,7 @@ def get_debug_images_v1_7(call, company_id, _):
scroll_id = call.data.get("scroll_id")
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company",)
company_id, task_id, allow_public=True, only=("company", "company_origin")
)[0]
# events, next_scroll_id, total_events = event_bll.get_task_events(
# company, task_id,
@ -532,7 +533,7 @@ def get_debug_images_v1_7(call, company_id, _):
# get last 10K events by iteration and group them by unique metric+variant, returning top events for combination
result = event_bll.get_task_events(
task.company,
task.get_index_company(),
task_id,
event_type="training_debug_image",
sort=[{"iter": {"order": "desc"}}],
@ -558,10 +559,10 @@ def get_debug_images_v1_8(call, company_id, _):
scroll_id = call.data.get("scroll_id")
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company",)
company_id, task_id, allow_public=True, only=("company", "company_origin")
)[0]
result = event_bll.get_task_events(
task.company,
task.get_index_company(),
task_id,
event_type="training_debug_image",
sort=[{"iter": {"order": "desc"}}],
@ -589,10 +590,10 @@ def get_debug_images_v1_8(call, company_id, _):
def get_debug_images(call, company_id, request: DebugImagesRequest):
task_ids = {m.task for m in request.metrics}
tasks = task_bll.assert_exists(
company_id, task_ids=task_ids, allow_public=True, only=("company",)
company_id, task_ids=task_ids, allow_public=True, only=("company", "company_origin")
)
companies = {t.company for t in tasks}
companies = {t.get_index_company() for t in tasks}
if len(companies) > 1:
raise errors.bad_request.InvalidTaskId(
"only tasks from the same company are supported"
@ -626,10 +627,10 @@ def get_debug_images(call, company_id, request: DebugImagesRequest):
@endpoint("events.get_task_metrics", request_data_model=TaskMetricsRequest)
def get_tasks_metrics(call: APICall, company_id, request: TaskMetricsRequest):
task = task_bll.assert_exists(
company_id, task_ids=request.tasks, allow_public=True, only=("company",)
company_id, task_ids=request.tasks, allow_public=True, only=("company", "company_origin")
)[0]
res = event_bll.metrics.get_tasks_metrics(
task.company, task_ids=request.tasks, event_type=request.event_type
task.get_index_company(), task_ids=request.tasks, event_type=request.event_type
)
call.result.data = {
"metrics": [{"task": task, "metrics": metrics} for (task, metrics) in res]