# clearml-server/apiserver/services/events.py

import itertools
from collections import defaultdict
from operator import itemgetter
from apiserver.apierrors import errors
from apiserver.apimodels.events import (
MultiTaskScalarMetricsIterHistogramRequest,
ScalarMetricsIterHistogramRequest,
DebugImagesRequest,
DebugImageResponse,
MetricEvents,
IterationEvents,
TaskMetricsRequest,
LogEventsRequest,
LogOrderEnum,
)
from apiserver.bll.event import EventBLL
from apiserver.bll.event.event_metrics import EventMetrics
from apiserver.bll.task import TaskBLL
from apiserver.service_repo import APICall, endpoint
from apiserver.utilities import json


task_bll = TaskBLL()
event_bll = EventBLL()


@endpoint("events.add")
def add(call: APICall, company_id, _):
data = call.data.copy()
allow_locked = data.pop("allow_locked", False)
added, err_count, err_info = event_bll.add_events(
company_id, [data], call.worker, allow_locked_tasks=allow_locked
)
call.result.data = dict(added=added, errors=err_count, errors_info=err_info)
call.kpis["events"] = 1
@endpoint("events.add_batch")
def add_batch(call: APICall, company_id, _):
events = call.batched_data
    if not events:
raise errors.bad_request.BatchContainsNoItems()
added, err_count, err_info = event_bll.add_events(company_id, events, call.worker)
call.result.data = dict(added=added, errors=err_count, errors_info=err_info)
call.kpis["events"] = len(events)
@endpoint("events.get_task_log", required_fields=["task"])
def get_task_log_v1_5(call, company_id, _):
task_id = call.data["task"]
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company", "company_origin")
)[0]
order = call.data.get("order") or "desc"
scroll_id = call.data.get("scroll_id")
batch_size = int(call.data.get("batch_size") or 500)
events, scroll_id, total_events = event_bll.scroll_task_events(
task.get_index_company(),
task_id,
order,
event_type="log",
batch_size=batch_size,
scroll_id=scroll_id,
)
call.result.data = dict(
events=events, returned=len(events), total=total_events, scroll_id=scroll_id
)
@endpoint("events.get_task_log", min_version="1.7", required_fields=["task"])
def get_task_log_v1_7(call, company_id, _):
task_id = call.data["task"]
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company", "company_origin")
)[0]
order = call.data.get("order") or "desc"
from_ = call.data.get("from") or "head"
scroll_id = call.data.get("scroll_id")
batch_size = int(call.data.get("batch_size") or 500)
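    # "head" means scroll from the start of the log (ascending); anything else
    # ("tail") scrolls from the end. If the requested presentation order
    # differs from the scroll order, the page is reversed below.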
scroll_order = "asc" if (from_ == "head") else "desc"
events, scroll_id, total_events = event_bll.scroll_task_events(
company_id=task.get_index_company(),
task_id=task_id,
order=scroll_order,
event_type="log",
batch_size=batch_size,
scroll_id=scroll_id,
)
if scroll_order != order:
events = events[::-1]
call.result.data = dict(
events=events, returned=len(events), total=total_events, scroll_id=scroll_id
)
@endpoint("events.get_task_log", min_version="2.9", request_data_model=LogEventsRequest)
def get_task_log(call, company_id, request: LogEventsRequest):
task_id = request.task
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company", "company_origin")
)[0]
res = event_bll.log_events_iterator.get_task_events(
company_id=task.get_index_company(),
task_id=task_id,
batch_size=request.batch_size,
navigate_earlier=request.navigate_earlier,
from_timestamp=request.from_timestamp,
)
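    # The iterator returns events in navigation order (earlier-first when
    # navigate_earlier is set, later-first otherwise); reverse the page when
    # an explicit order requests the opposite.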
    if request.order and (
        (request.navigate_earlier and request.order == LogOrderEnum.asc)
        or (not request.navigate_earlier and request.order == LogOrderEnum.desc)
    ):
        res.events.reverse()
call.result.data = dict(
events=res.events, returned=len(res.events), total=res.total_events
)
@endpoint("events.download_task_log", required_fields=["task"])
def download_task_log(call, company_id, _):
task_id = call.data["task"]
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company", "company_origin")
)[0]
line_type = call.data.get("line_type", "json").lower()
line_format = str(call.data.get("line_format", "{asctime} {worker} {level} {msg}"))
is_json = line_type == "json"
if not is_json:
if not line_format:
raise errors.bad_request.MissingRequiredFields(
"line_format is required for plain text lines"
)
# validate line format placeholders
valid_task_log_fields = {"asctime", "timestamp", "level", "worker", "msg"}
invalid_placeholders = set()
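        # Probe the format string: format with every known field (values are
        # None); each KeyError exposes an unknown placeholder, which is
        # collected and the probe retried until formatting succeeds.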
while True:
try:
line_format.format(
**dict.fromkeys(valid_task_log_fields | invalid_placeholders)
)
break
except KeyError as e:
invalid_placeholders.add(e.args[0])
            except Exception as e:
                raise errors.bad_request.FieldsValueError(
                    "invalid line format", error=str(e)
                )
if invalid_placeholders:
raise errors.bad_request.FieldsValueError(
"undefined placeholders in line format",
placeholders=invalid_placeholders,
)
# make sure line_format has a trailing newline
line_format = line_format.rstrip("\n") + "\n"
def generate():
scroll_id = None
batch_size = 1000
while True:
log_events, scroll_id, _ = event_bll.scroll_task_events(
task.get_index_company(),
task_id,
order="asc",
event_type="log",
batch_size=batch_size,
scroll_id=scroll_id,
)
if not log_events:
break
for ev in log_events:
ev["asctime"] = ev.pop("timestamp")
if is_json:
ev.pop("type")
ev.pop("task")
yield json.dumps(ev) + "\n"
else:
try:
yield line_format.format(**ev)
except KeyError as ex:
raise errors.bad_request.FieldsValueError(
"undefined placeholders in line format",
                            placeholders=[ex.args[0]],
)
if len(log_events) < batch_size:
break
call.result.filename = "task_%s.log" % task_id
call.result.content_type = "text/plain"
call.result.raw_data = generate()
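

# A hypothetical plain-text download request for this endpoint:
#   {"task": "<task-id>", "line_type": "text",
#    "line_format": "{asctime} {level} {msg}"}
# line_type defaults to "json", in which case line_format is not used.

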
@endpoint("events.get_vector_metrics_and_variants", required_fields=["task"])
def get_vector_metrics_and_variants(call, company_id, _):
task_id = call.data["task"]
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company", "company_origin")
)[0]
call.result.data = dict(
metrics=event_bll.get_metrics_and_variants(
task.get_index_company(), task_id, "training_stats_vector"
)
)
@endpoint("events.get_scalar_metrics_and_variants", required_fields=["task"])
def get_scalar_metrics_and_variants(call, company_id, _):
task_id = call.data["task"]
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company", "company_origin")
)[0]
call.result.data = dict(
metrics=event_bll.get_metrics_and_variants(
task.get_index_company(), task_id, "training_stats_scalar"
)
)


# TODO: currently returns up to 10,000 records; decide on a better way to
# control this
@endpoint(
"events.vector_metrics_iter_histogram",
required_fields=["task", "metric", "variant"],
)
def vector_metrics_iter_histogram(call, company_id, _):
task_id = call.data["task"]
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company", "company_origin")
)[0]
metric = call.data["metric"]
variant = call.data["variant"]
iterations, vectors = event_bll.get_vector_metrics_per_iter(
task.get_index_company(), task_id, metric, variant
)
call.result.data = dict(
metric=metric, variant=variant, vectors=vectors, iterations=iterations
)
@endpoint("events.get_task_events", required_fields=["task"])
def get_task_events(call, company_id, _):
task_id = call.data["task"]
    batch_size = int(call.data.get("batch_size", 500))
event_type = call.data.get("event_type")
scroll_id = call.data.get("scroll_id")
order = call.data.get("order") or "asc"
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company", "company_origin")
)[0]
result = event_bll.get_task_events(
task.get_index_company(),
task_id,
sort=[{"timestamp": {"order": order}}],
event_type=event_type,
scroll_id=scroll_id,
size=batch_size,
)
call.result.data = dict(
events=result.events,
returned=len(result.events),
total=result.total_events,
scroll_id=result.next_scroll_id,
)
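

# Pagination note for the scroll-based endpoints: pass the returned scroll_id
# back on the next call to continue from the previous page, e.g. a
# hypothetical follow-up request:
#   {"task": "<task-id>", "scroll_id": "<scroll_id from previous response>"}

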
@endpoint("events.get_scalar_metric_data", required_fields=["task", "metric"])
def get_scalar_metric_data(call, company_id, _):
task_id = call.data["task"]
metric = call.data["metric"]
scroll_id = call.data.get("scroll_id")
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company", "company_origin")
)[0]
result = event_bll.get_task_events(
task.get_index_company(),
task_id,
event_type="training_stats_scalar",
sort=[{"iter": {"order": "desc"}}],
metric=metric,
scroll_id=scroll_id,
)
call.result.data = dict(
events=result.events,
returned=len(result.events),
total=result.total_events,
scroll_id=result.next_scroll_id,
)
@endpoint("events.get_task_latest_scalar_values", required_fields=["task"])
def get_task_latest_scalar_values(call, company_id, _):
task_id = call.data["task"]
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company", "company_origin")
)[0]
index_company = task.get_index_company()
metrics, last_timestamp = event_bll.get_task_latest_scalar_values(
index_company, task_id
)
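    # An assumption based on EventMetrics.get_index_name semantics: "*" matches
    # the per-event-type indices of this company, so the last iteration is
    # taken across all event types.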
es_index = EventMetrics.get_index_name(index_company, "*")
last_iters = event_bll.get_last_iters(es_index, task_id, None, 1)
call.result.data = dict(
metrics=metrics,
last_iter=last_iters[0] if last_iters else 0,
name=task.name,
status=task.status,
last_timestamp=last_timestamp,
)


# TODO: should not repeat iter (x-axis) for each metric/variant; the JS
# client should get raw data and fill gaps if needed
@endpoint(
"events.scalar_metrics_iter_histogram",
request_data_model=ScalarMetricsIterHistogramRequest,
)
def scalar_metrics_iter_histogram(
call, company_id, request: ScalarMetricsIterHistogramRequest
):
task = task_bll.assert_exists(
company_id, request.task, allow_public=True, only=("company", "company_origin")
)[0]
    metrics = event_bll.metrics.get_scalar_metrics_average_per_iter(
        task.get_index_company(),
        task_id=request.task,
        samples=request.samples,
        key=request.key,
    )
call.result.data = metrics


@endpoint(
"events.multi_task_scalar_metrics_iter_histogram",
request_data_model=MultiTaskScalarMetricsIterHistogramRequest,
)
def multi_task_scalar_metrics_iter_histogram(
call, company_id, req_model: MultiTaskScalarMetricsIterHistogramRequest
):
task_ids = req_model.tasks
if isinstance(task_ids, str):
task_ids = [s.strip() for s in task_ids.split(",")]
    # Note: the BLL already validates the task ids, as it needs their names
call.result.data = dict(
metrics=event_bll.metrics.compare_scalar_metrics_average_per_iter(
company_id,
task_ids=task_ids,
samples=req_model.samples,
allow_public=True,
key=req_model.key,
)
)
@endpoint("events.get_multi_task_plots", required_fields=["tasks"])
def get_multi_task_plots_v1_7(call, company_id, _):
task_ids = call.data["tasks"]
iters = call.data.get("iters", 1)
scroll_id = call.data.get("scroll_id")
tasks = task_bll.assert_exists(
company_id=company_id,
only=("id", "name", "company", "company_origin"),
task_ids=task_ids,
allow_public=True,
)
companies = {t.get_index_company() for t in tasks}
if len(companies) > 1:
raise errors.bad_request.InvalidTaskId(
"only tasks from the same company are supported"
)
    # Get the last 10K events by iteration and group them by unique
    # metric+variant, returning the top events for each combination
result = event_bll.get_task_events(
next(iter(companies)),
task_ids,
event_type="plot",
sort=[{"iter": {"order": "desc"}}],
size=10000,
scroll_id=scroll_id,
)
tasks = {t.id: t.name for t in tasks}
return_events = _get_top_iter_unique_events_per_task(
result.events, max_iters=iters, tasks=tasks
)
call.result.data = dict(
plots=return_events,
returned=len(return_events),
total=result.total_events,
scroll_id=result.next_scroll_id,
)
@endpoint("events.get_multi_task_plots", min_version="1.8", required_fields=["tasks"])
def get_multi_task_plots(call, company_id, req_model):
task_ids = call.data["tasks"]
iters = call.data.get("iters", 1)
scroll_id = call.data.get("scroll_id")
tasks = task_bll.assert_exists(
        company_id=company_id,
only=("id", "name", "company", "company_origin"),
task_ids=task_ids,
allow_public=True,
)
companies = {t.get_index_company() for t in tasks}
if len(companies) > 1:
raise errors.bad_request.InvalidTaskId(
"only tasks from the same company are supported"
)
result = event_bll.get_task_events(
next(iter(companies)),
task_ids,
event_type="plot",
sort=[{"iter": {"order": "desc"}}],
last_iter_count=iters,
scroll_id=scroll_id,
)
tasks = {t.id: t.name for t in tasks}
return_events = _get_top_iter_unique_events_per_task(
result.events, max_iters=iters, tasks=tasks
)
call.result.data = dict(
plots=return_events,
returned=len(return_events),
total=result.total_events,
scroll_id=result.next_scroll_id,
)
@endpoint("events.get_task_plots", required_fields=["task"])
def get_task_plots_v1_7(call, company_id, _):
task_id = call.data["task"]
iters = call.data.get("iters", 1)
scroll_id = call.data.get("scroll_id")
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company", "company_origin")
)[0]
    # Get the last 10K events by iteration and group them by unique
    # metric+variant, returning the top events for each combination
result = event_bll.get_task_events(
task.get_index_company(),
task_id,
event_type="plot",
sort=[{"iter": {"order": "desc"}}],
size=10000,
scroll_id=scroll_id,
)
return_events = _get_top_iter_unique_events(result.events, max_iters=iters)
call.result.data = dict(
plots=return_events,
returned=len(return_events),
total=result.total_events,
scroll_id=result.next_scroll_id,
)
@endpoint("events.get_task_plots", min_version="1.8", required_fields=["task"])
def get_task_plots(call, company_id, _):
task_id = call.data["task"]
iters = call.data.get("iters", 1)
scroll_id = call.data.get("scroll_id")
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company", "company_origin")
)[0]
result = event_bll.get_task_plots(
task.get_index_company(),
tasks=[task_id],
sort=[{"iter": {"order": "desc"}}],
last_iterations_per_plot=iters,
scroll_id=scroll_id,
)
return_events = result.events
call.result.data = dict(
plots=return_events,
returned=len(return_events),
total=result.total_events,
scroll_id=result.next_scroll_id,
)
@endpoint("events.debug_images", required_fields=["task"])
def get_debug_images_v1_7(call, company_id, _):
task_id = call.data["task"]
iters = call.data.get("iters") or 1
scroll_id = call.data.get("scroll_id")
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company", "company_origin")
)[0]
    # Get the last 10K events by iteration and group them by unique
    # metric+variant, returning the top events for each combination
result = event_bll.get_task_events(
task.get_index_company(),
task_id,
event_type="training_debug_image",
sort=[{"iter": {"order": "desc"}}],
size=10000,
scroll_id=scroll_id,
)
return_events = _get_top_iter_unique_events(result.events, max_iters=iters)
call.result.data = dict(
task=task_id,
images=return_events,
returned=len(return_events),
total=result.total_events,
scroll_id=result.next_scroll_id,
)
@endpoint("events.debug_images", min_version="1.8", required_fields=["task"])
def get_debug_images_v1_8(call, company_id, _):
task_id = call.data["task"]
iters = call.data.get("iters") or 1
scroll_id = call.data.get("scroll_id")
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company", "company_origin")
)[0]
result = event_bll.get_task_events(
task.get_index_company(),
task_id,
event_type="training_debug_image",
sort=[{"iter": {"order": "desc"}}],
last_iter_count=iters,
scroll_id=scroll_id,
)
return_events = result.events
call.result.data = dict(
task=task_id,
images=return_events,
returned=len(return_events),
total=result.total_events,
scroll_id=result.next_scroll_id,
)


@endpoint(
"events.debug_images",
min_version="2.7",
request_data_model=DebugImagesRequest,
response_data_model=DebugImageResponse,
)
def get_debug_images(call, company_id, request: DebugImagesRequest):
task_ids = {m.task for m in request.metrics}
    tasks = task_bll.assert_exists(
        company_id,
        task_ids=task_ids,
        allow_public=True,
        only=("company", "company_origin"),
    )
companies = {t.get_index_company() for t in tasks}
if len(companies) > 1:
raise errors.bad_request.InvalidTaskId(
"only tasks from the same company are supported"
)
result = event_bll.debug_images_iterator.get_task_events(
company_id=next(iter(companies)),
metrics=[(m.task, m.metric) for m in request.metrics],
iter_count=request.iters,
navigate_earlier=request.navigate_earlier,
refresh=request.refresh,
state_id=request.scroll_id,
)
call.result.data_model = DebugImageResponse(
scroll_id=result.next_scroll_id,
metrics=[
MetricEvents(
task=task,
metric=metric,
iterations=[
IterationEvents(iter=iteration["iter"], events=iteration["events"])
for iteration in iterations
],
)
for (task, metric, iterations) in result.metric_events
],
)
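

# The response groups events per (task, metric): each MetricEvents entry holds
# the iterations returned for that pairing, and scroll_id lets the client
# resume the same view on subsequent calls.

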
@endpoint("events.get_task_metrics", request_data_model=TaskMetricsRequest)
def get_tasks_metrics(call: APICall, company_id, request: TaskMetricsRequest):
    task = task_bll.assert_exists(
        company_id,
        task_ids=request.tasks,
        allow_public=True,
        only=("company", "company_origin"),
    )[0]
res = event_bll.metrics.get_tasks_metrics(
task.get_index_company(), task_ids=request.tasks, event_type=request.event_type
)
    call.result.data = {
        "metrics": [{"task": tid, "metrics": metrics} for (tid, metrics) in res]
    }
@endpoint("events.delete_for_task", required_fields=["task"])
def delete_for_task(call, company_id, req_model):
task_id = call.data["task"]
allow_locked = call.data.get("allow_locked", False)
task_bll.assert_exists(company_id, task_id, return_tasks=False)
call.result.data = dict(
deleted=event_bll.delete_task_events(
company_id, task_id, allow_locked=allow_locked
)
)


def _get_top_iter_unique_events_per_task(events, max_iters, tasks):
key = itemgetter("metric", "variant", "task", "iter")
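    # Keep at most max_iters events per unique (metric, variant, task, iter)
    # combination; sorting first ensures groupby sees each combination once.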
unique_events = itertools.chain.from_iterable(
itertools.islice(group, max_iters)
for _, group in itertools.groupby(
sorted(events, key=key, reverse=True), key=key
)
)
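
    # Recursively bucket events into a nested mapping:
    # metric -> variant -> task -> iter -> {"name": <task name>, "plots": [...]}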
def collect(evs, fields):
if not fields:
evs = list(evs)
return {"name": tasks.get(evs[0].get("task")), "plots": evs}
return {
str(k): collect(group, fields[1:])
for k, group in itertools.groupby(evs, key=itemgetter(fields[0]))
}
collect_fields = ("metric", "variant", "task", "iter")
return collect(
sorted(unique_events, key=itemgetter(*collect_fields), reverse=True),
collect_fields,
)


def _get_top_iter_unique_events(events, max_iters):
    # Keep the first max_iters events seen per (metric, variant) pair; callers
    # pass events already sorted by iteration descending, so these are the
    # latest iterations. A tuple key avoids the collisions that plain string
    # concatenation of metric + variant could produce.
    top_unique_events = defaultdict(list)
    for e in events:
        key = (e.get("metric", ""), e.get("variant", ""))
        evs = top_unique_events[key]
        if len(evs) < max_iters:
            evs.append(e)
    unique_events = list(
        itertools.chain.from_iterable(top_unique_events.values())
    )
    unique_events.sort(key=lambda ev: ev["iter"], reverse=True)
    return unique_events