# clearml-server/server/services/events.py
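# REST endpoints for reporting and querying task events (logs, scalar/vector
# metrics, plots and debug images). This is a thin request-handling layer:
# validation and response shaping happen here, while storage and queries are
# delegated to EventBLL and TaskBLL.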
import itertools
from collections import defaultdict
from operator import itemgetter

import six

from apierrors import errors
from bll.event import EventBLL
from bll.task import TaskBLL
from service_repo import APICall, endpoint
from utilities import json

task_bll = TaskBLL()
event_bll = EventBLL()
@endpoint("events.add")
def add(call, company_id, req_model):
assert isinstance(call, APICall)
added, batch_errors = event_bll.add_events(company_id, [call.data.copy()], call.worker)
call.result.data = dict(
added=added,
errors=len(batch_errors)
)
call.kpis["events"] = 1
@endpoint("events.add_batch")
def add_batch(call, company_id, req_model):
assert isinstance(call, APICall)
events = call.batched_data
if events is None or len(events) == 0:
raise errors.bad_request.BatchContainsNoItems()
added, batch_errors = event_bll.add_events(company_id, events, call.worker)
call.result.data = dict(
added=added,
errors=len(batch_errors)
)
call.kpis["events"] = len(events)
@endpoint("events.get_task_log", required_fields=["task"])
def get_task_log(call, company_id, req_model):
task_id = call.data["task"]
task_bll.assert_exists(company_id, task_id, allow_public=True)
order = call.data.get("order") or "desc"
scroll_id = call.data.get("scroll_id")
batch_size = int(call.data.get("batch_size") or 500)
events, scroll_id, total_events = event_bll.scroll_task_events(
company_id, task_id, order,
event_type="log",
batch_size=batch_size,
scroll_id=scroll_id)
call.result.data = dict(
events=events,
returned=len(events),
total=total_events,
scroll_id=scroll_id,
)
@endpoint("events.get_task_log", min_version="1.7", required_fields=["task"])
def get_task_log_v1_7(call, company_id, req_model):
task_id = call.data["task"]
task_bll.assert_exists(company_id, task_id, allow_public=True)
order = call.data.get("order") or "desc"
from_ = call.data.get("from") or "head"
scroll_id = call.data.get("scroll_id")
batch_size = int(call.data.get("batch_size") or 500)
scroll_order = 'asc' if (from_ == 'head') else 'desc'
events, scroll_id, total_events = event_bll.scroll_task_events(
company_id=company_id,
task_id=task_id,
order=scroll_order,
event_type="log",
batch_size=batch_size,
scroll_id=scroll_id
)
if scroll_order != order:
events = events[::-1]
call.result.data = dict(
events=events,
returned=len(events),
total=total_events,
scroll_id=scroll_id,
)
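# Note on the v1.7 semantics above: "from" picks which end of the log the
# scroll starts at (head=oldest-first, i.e. "asc"; tail=newest-first, i.e.
# "desc"), while "order" controls the ordering of the returned batch. E.g.
# from=tail with order=asc returns the newest batch_size lines, reversed back
# into chronological order.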


@endpoint("events.download_task_log", required_fields=["task"])
def download_task_log(call, company_id, req_model):
    task_id = call.data["task"]
    task_bll.assert_exists(company_id, task_id, allow_public=True)
    line_type = call.data.get("line_type", "json").lower()
    line_format = str(call.data.get("line_format", "{asctime} {worker} {level} {msg}"))
    is_json = line_type == "json"
    if not is_json:
        if not line_format:
            raise errors.bad_request.MissingRequiredFields(
                "line_format is required for plain text lines"
            )
        # validate line format placeholders
        valid_task_log_fields = {"asctime", "timestamp", "level", "worker", "msg"}
        invalid_placeholders = set()
        while True:
            try:
                line_format.format(
                    **dict.fromkeys(valid_task_log_fields | invalid_placeholders)
                )
                break
            except KeyError as e:
                invalid_placeholders.add(e.args[0])
            except Exception as e:
                raise errors.bad_request.FieldsValueError(
                    "invalid line format", error=e.args[0]
                )
        if invalid_placeholders:
            raise errors.bad_request.FieldsValueError(
                "undefined placeholders in line format",
                placeholders=invalid_placeholders,
            )
        # make sure line_format has a trailing newline
        line_format = line_format.rstrip("\n") + "\n"
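    # The validation loop above formats the template against all known fields
    # plus any placeholders already found invalid, collecting each unknown
    # placeholder reported via KeyError until formatting succeeds. Example
    # (illustrative): line_format="{asctime} {foo}" is rejected with
    # placeholders={'foo'}, while the default format passes.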

    def generate():
        scroll_id = None
        batch_size = 1000
        while True:
            log_events, scroll_id, _ = event_bll.scroll_task_events(
                company_id,
                task_id,
                order="asc",
                event_type="log",
                batch_size=batch_size,
                scroll_id=scroll_id,
            )
            if not log_events:
                break
            for ev in log_events:
                ev["asctime"] = ev.pop("@timestamp")
                if is_json:
                    ev.pop("type")
                    ev.pop("task")
                    yield json.dumps(ev) + "\n"
                else:
                    try:
                        yield line_format.format(**ev)
                    except KeyError as ex:
                        raise errors.bad_request.FieldsValueError(
                            "undefined placeholders in line format",
                            placeholders=[str(ex)],
                        )
            if len(log_events) < batch_size:
                break

    call.result.filename = "task_%s.log" % task_id
    call.result.content_type = "text/plain"
    call.result.raw_data = generate()
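# The generator above streams the log in ascending-time batches of 1000, so the
# full log is never held in memory; raw_data is presumably consumed lazily by
# the service layer when writing the download response.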
@endpoint("events.get_vector_metrics_and_variants", required_fields=["task"])
def get_vector_metrics_and_variants(call, company_id, req_model):
task_id = call.data["task"]
task_bll.assert_exists(company_id, task_id, allow_public=True)
call.result.data = dict(
metrics=event_bll.get_metrics_and_variants(company_id, task_id, "training_stats_vector")
)
@endpoint("events.get_scalar_metrics_and_variants", required_fields=["task"])
def get_scalar_metrics_and_variants(call, company_id, req_model):
task_id = call.data["task"]
task_bll.assert_exists(company_id, task_id, allow_public=True)
call.result.data = dict(
metrics=event_bll.get_metrics_and_variants(company_id, task_id, "training_stats_scalar")
)


# TODO: currently returns up to 10,000 records; decide on a better way to control this
@endpoint("events.vector_metrics_iter_histogram", required_fields=["task", "metric", "variant"])
def vector_metrics_iter_histogram(call, company_id, req_model):
    task_id = call.data["task"]
    task_bll.assert_exists(company_id, task_id, allow_public=True)
    metric = call.data["metric"]
    variant = call.data["variant"]
    iterations, vectors = event_bll.get_vector_metrics_per_iter(
        company_id, task_id, metric, variant
    )
    call.result.data = dict(
        metric=metric,
        variant=variant,
        vectors=vectors,
        iterations=iterations,
    )
@endpoint("events.get_task_events", required_fields=["task"])
def get_task_events(call, company_id, req_model):
task_id = call.data["task"]
event_type = call.data.get("event_type")
scroll_id = call.data.get("scroll_id")
order = call.data.get("order") or "asc"
task_bll.assert_exists(company_id, task_id, allow_public=True)
result = event_bll.get_task_events(
company_id, task_id,
sort=[{"timestamp": {"order": order}}],
event_type=event_type,
scroll_id=scroll_id
)
call.result.data = dict(
events=result.events,
returned=len(result.events),
total=result.total_events,
scroll_id=result.next_scroll_id,
)
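# Pagination note for the scrolling endpoints: clients page through large
# result sets by passing the returned scroll_id back on the next call,
# repeating until no more events are returned.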
@endpoint("events.get_scalar_metric_data", required_fields=["task", "metric"])
def get_scalar_metric_data(call, company_id, req_model):
task_id = call.data["task"]
metric = call.data["metric"]
scroll_id = call.data.get("scroll_id")
task_bll.assert_exists(company_id, task_id, allow_public=True)
result = event_bll.get_task_events(
company_id, task_id,
event_type="training_stats_scalar",
sort=[{"iter": {"order": "desc"}}],
metric=metric,
scroll_id=scroll_id
)
call.result.data = dict(
events=result.events,
returned=len(result.events),
total=result.total_events,
scroll_id=result.next_scroll_id,
)
@endpoint("events.get_task_latest_scalar_values", required_fields=["task"])
def get_task_latest_scalar_values(call, company_id, req_model):
task_id = call.data["task"]
task = task_bll.assert_exists(company_id, task_id, allow_public=True)
metrics, last_timestamp = event_bll.get_task_latest_scalar_values(company_id, task_id)
es_index = EventBLL.get_index_name(company_id, "*")
last_iters = event_bll.get_last_iters(es_index, task_id, None, 1)
call.result.data = dict(
metrics=metrics,
last_iter=last_iters[0] if last_iters else 0,
name=task.name,
status=task.status,
last_timestamp=last_timestamp
)


# TODO: do not repeat iter (x-axis) for each metric/variant; the JS client
# should get raw data and fill gaps if needed
@endpoint("events.scalar_metrics_iter_histogram", required_fields=["task"])
def scalar_metrics_iter_histogram(call, company_id, req_model):
    task_id = call.data["task"]
    task_bll.assert_exists(company_id, task_id, allow_public=True)
    metrics = event_bll.get_scalar_metrics_average_per_iter(company_id, task_id)
    call.result.data = metrics
@endpoint("events.multi_task_scalar_metrics_iter_histogram", required_fields=["tasks"])
def multi_task_scalar_metrics_iter_histogram(call, company_id, req_model):
task_ids = call.data["tasks"]
if isinstance(task_ids, six.string_types):
task_ids = [s.strip() for s in task_ids.split(",")]
# Note, bll already validates task ids as it needs their names
call.result.data = dict(
metrics=event_bll.compare_scalar_metrics_average_per_iter(company_id, task_ids, allow_public=True)
)
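# "tasks" may be passed either as a list of task ids or as a single
# comma-separated string, e.g. "id1, id2" is normalized to ["id1", "id2"].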
@endpoint("events.get_multi_task_plots", required_fields=["tasks"])
def get_multi_task_plots_v1_7(call, company_id, req_model):
task_ids = call.data["tasks"]
iters = call.data.get("iters", 1)
scroll_id = call.data.get("scroll_id")
tasks = task_bll.assert_exists(
company_id=call.identity.company, only=('id', 'name'), task_ids=task_ids, allow_public=True
)
# Get last 10K events by iteration and group them by unique metric+variant, returning top events for combination
result = event_bll.get_task_events(
company_id, task_ids,
event_type="plot",
sort=[{"iter": {"order": "desc"}}],
size=10000,
scroll_id=scroll_id
)
tasks = {t.id: t.name for t in tasks}
return_events = _get_top_iter_unique_events_per_task(result.events, max_iters=iters, tasks=tasks)
call.result.data = dict(
plots=return_events,
returned=len(return_events),
total=result.total_events,
scroll_id=result.next_scroll_id,
)
@endpoint("events.get_multi_task_plots", min_version="1.8", required_fields=["tasks"])
def get_multi_task_plots(call, company_id, req_model):
task_ids = call.data["tasks"]
iters = call.data.get("iters", 1)
scroll_id = call.data.get("scroll_id")
tasks = task_bll.assert_exists(
company_id=call.identity.company, only=('id', 'name'), task_ids=task_ids, allow_public=True
)
result = event_bll.get_task_events(
company_id, task_ids,
event_type="plot",
sort=[{"iter": {"order": "desc"}}],
last_iter_count=iters,
scroll_id=scroll_id
)
tasks = {t.id: t.name for t in tasks}
return_events = _get_top_iter_unique_events_per_task(result.events, max_iters=iters, tasks=tasks)
call.result.data = dict(
plots=return_events,
returned=len(return_events),
total=result.total_events,
scroll_id=result.next_scroll_id,
)
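# The pre-1.8 handler above fetches a flat window of the last 10K plot events
# and trims per metric/variant/task in Python; the 1.8 handler additionally
# pushes the iteration limit down into the event query via last_iter_count.
# The same pattern is used by get_task_plots and debug_images below.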
@endpoint("events.get_task_plots", required_fields=["task"])
def get_task_plots_v1_7(call, company_id, req_model):
task_id = call.data["task"]
iters = call.data.get("iters", 1)
scroll_id = call.data.get("scroll_id")
task_bll.assert_exists(call.identity.company, task_id, allow_public=True)
# events, next_scroll_id, total_events = event_bll.get_task_events(
# company_id, task_id,
# event_type="plot",
# sort=[{"iter": {"order": "desc"}}],
# last_iter_count=iters,
# scroll_id=scroll_id)
# get last 10K events by iteration and group them by unique metric+variant, returning top events for combination
result = event_bll.get_task_events(
company_id, task_id,
event_type="plot",
sort=[{"iter": {"order": "desc"}}],
size=10000,
scroll_id=scroll_id
)
return_events = _get_top_iter_unique_events(result.events, max_iters=iters)
call.result.data = dict(
plots=return_events,
returned=len(return_events),
total=result.total_events,
scroll_id=result.next_scroll_id,
)
@endpoint("events.get_task_plots", min_version="1.8", required_fields=["task"])
def get_task_plots(call, company_id, req_model):
task_id = call.data["task"]
iters = call.data.get("iters", 1)
scroll_id = call.data.get("scroll_id")
task_bll.assert_exists(call.identity.company, task_id, allow_public=True)
result = event_bll.get_task_events(
company_id, task_id,
event_type="plot",
sort=[{"iter": {"order": "desc"}}],
last_iter_count=iters,
scroll_id=scroll_id
)
return_events = result.events
call.result.data = dict(
plots=return_events,
returned=len(return_events),
total=result.total_events,
scroll_id=result.next_scroll_id,
)
@endpoint("events.debug_images", required_fields=["task"])
def get_debug_images_v1_7(call, company_id, req_model):
task_id = call.data["task"]
iters = call.data.get("iters") or 1
scroll_id = call.data.get("scroll_id")
task_bll.assert_exists(call.identity.company, task_id, allow_public=True)
# events, next_scroll_id, total_events = event_bll.get_task_events(
# company_id, task_id,
# event_type="training_debug_image",
# sort=[{"iter": {"order": "desc"}}],
# last_iter_count=iters,
# scroll_id=scroll_id)
# get last 10K events by iteration and group them by unique metric+variant, returning top events for combination
result = event_bll.get_task_events(
company_id, task_id,
event_type="training_debug_image",
sort=[{"iter": {"order": "desc"}}],
size=10000,
scroll_id=scroll_id
)
return_events = _get_top_iter_unique_events(result.events, max_iters=iters)
call.result.data = dict(
task=task_id,
images=return_events,
returned=len(return_events),
total=result.total_events,
scroll_id=result.next_scroll_id,
)
@endpoint("events.debug_images", min_version="1.8", required_fields=["task"])
def get_debug_images(call, company_id, req_model):
task_id = call.data["task"]
iters = call.data.get("iters") or 1
scroll_id = call.data.get("scroll_id")
task_bll.assert_exists(call.identity.company, task_id, allow_public=True)
result = event_bll.get_task_events(
company_id, task_id,
event_type="training_debug_image",
sort=[{"iter": {"order": "desc"}}],
last_iter_count=iters,
scroll_id=scroll_id
)
return_events = result.events
call.result.data = dict(
task=task_id,
images=return_events,
returned=len(return_events),
total=result.total_events,
scroll_id=result.next_scroll_id,
)
@endpoint("events.delete_for_task", required_fields=["task"])
def delete_for_task(call, company_id, req_model):
task_id = call.data["task"]
task_bll.assert_exists(company_id, task_id)
call.result.data = dict(
deleted=event_bll.delete_task_events(company_id, task_id)
)
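# Note: unlike the read endpoints, assert_exists is called here without
# allow_public=True, so (assuming the flag defaults to False) events of public
# tasks cannot be deleted through this endpoint.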


def _get_top_iter_unique_events_per_task(events, max_iters, tasks):
    key = itemgetter("metric", "variant", "task", "iter")

    unique_events = itertools.chain.from_iterable(
        itertools.islice(group, max_iters)
        for _, group in itertools.groupby(sorted(events, key=key, reverse=True), key=key)
    )

    def collect(evs, fields):
        if not fields:
            evs = list(evs)
            return {
                "name": tasks.get(evs[0].get("task")),
                "plots": evs,
            }
        return {
            str(k): collect(group, fields[1:])
            for k, group in itertools.groupby(evs, key=itemgetter(fields[0]))
        }

    collect_fields = ("metric", "variant", "task", "iter")
    return collect(
        sorted(unique_events, key=itemgetter(*collect_fields), reverse=True),
        collect_fields,
    )
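# Shape of the structure returned above (illustrative): collect() nests events
# by each collect field in turn, resolving task names at the innermost level:
#   {metric: {variant: {task_id: {iter: {"name": task_name, "plots": [events]}}}}}
# (all grouping keys are str()-ed).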


def _get_top_iter_unique_events(events, max_iters):
    top_unique_events = defaultdict(list)
    for e in events:
        key = e.get("metric", "") + e.get("variant", "")
        evs = top_unique_events[key]
        if len(evs) < max_iters:
            evs.append(e)
    unique_events = list(itertools.chain.from_iterable(top_unique_events.values()))
    unique_events.sort(key=lambda e: e["iter"], reverse=True)
    return unique_events
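# Keeps at most max_iters events per metric+variant combination. The callers
# sort events by iteration descending before passing them in, so the events
# kept are those from the most recent iterations, returned newest-first.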