import itertools
from collections import defaultdict
from operator import itemgetter

import six

from apierrors import errors
from bll.event import EventBLL
from bll.task import TaskBLL
from service_repo import APICall, endpoint
from utilities import json

task_bll = TaskBLL()
event_bll = EventBLL()


@endpoint("events.add")
def add(call, company_id, req_model):
    assert isinstance(call, APICall)
    added, batch_errors = event_bll.add_events(company_id, [call.data.copy()], call.worker)
    call.result.data = dict(added=added, errors=len(batch_errors))
    call.kpis["events"] = 1


@endpoint("events.add_batch")
def add_batch(call, company_id, req_model):
    assert isinstance(call, APICall)
    events = call.batched_data
    if events is None or len(events) == 0:
        raise errors.bad_request.BatchContainsNoItems()

    added, batch_errors = event_bll.add_events(company_id, events, call.worker)
    call.result.data = dict(added=added, errors=len(batch_errors))
    call.kpis["events"] = len(events)


@endpoint("events.get_task_log", required_fields=["task"])
def get_task_log(call, company_id, req_model):
    task_id = call.data["task"]
    task_bll.assert_exists(company_id, task_id, allow_public=True)

    order = call.data.get("order") or "desc"
    scroll_id = call.data.get("scroll_id")
    batch_size = int(call.data.get("batch_size") or 500)

    events, scroll_id, total_events = event_bll.scroll_task_events(
        company_id, task_id, order,
        event_type="log",
        batch_size=batch_size,
        scroll_id=scroll_id,
    )

    call.result.data = dict(
        events=events,
        returned=len(events),
        total=total_events,
        scroll_id=scroll_id,
    )


@endpoint("events.get_task_log", min_version="1.7", required_fields=["task"])
def get_task_log_v1_7(call, company_id, req_model):
    task_id = call.data["task"]
    task_bll.assert_exists(company_id, task_id, allow_public=True)

    order = call.data.get("order") or "desc"
    from_ = call.data.get("from") or "head"
    scroll_id = call.data.get("scroll_id")
    batch_size = int(call.data.get("batch_size") or 500)

    scroll_order = "asc" if (from_ == "head") else "desc"

    events, scroll_id, total_events = event_bll.scroll_task_events(
        company_id=company_id,
        task_id=task_id,
        order=scroll_order,
        event_type="log",
        batch_size=batch_size,
        scroll_id=scroll_id,
    )

    if scroll_order != order:
        events = events[::-1]

    call.result.data = dict(
        events=events,
        returned=len(events),
        total=total_events,
        scroll_id=scroll_id,
    )


@endpoint("events.download_task_log", required_fields=["task"])
def download_task_log(call, company_id, req_model):
    task_id = call.data["task"]
    task_bll.assert_exists(company_id, task_id, allow_public=True)

    line_type = call.data.get("line_type", "json").lower()
    line_format = str(call.data.get("line_format", "{asctime} {worker} {level} {msg}"))

    is_json = (line_type == "json")
    if not is_json:
        if not line_format:
            raise errors.bad_request.MissingRequiredFields(
                "line_format is required for plain text lines")

        # validate line format placeholders
        valid_task_log_fields = {"asctime", "timestamp", "level", "worker", "msg"}

        invalid_placeholders = set()
        while True:
            try:
                line_format.format(**dict.fromkeys(valid_task_log_fields | invalid_placeholders))
                break
            except KeyError as e:
                invalid_placeholders.add(e.args[0])
            except Exception as e:
                raise errors.bad_request.FieldsValueError("invalid line format", error=e.args[0])

        if invalid_placeholders:
            raise errors.bad_request.FieldsValueError(
                "undefined placeholders in line format",
                placeholders=invalid_placeholders,
            )

        # make sure line_format has a trailing newline
        line_format = line_format.rstrip("\n") + "\n"

    def generate():
        scroll_id = None
        batch_size = 1000
        while True:
            log_events, scroll_id, _ = event_bll.scroll_task_events(
                company_id, task_id,
                order="asc",
                event_type="log",
                batch_size=batch_size,
                scroll_id=scroll_id,
            )
            if not log_events:
                break
            for ev in log_events:
                ev["asctime"] = ev.pop("@timestamp")
                if is_json:
                    ev.pop("type")
                    ev.pop("task")
                    yield json.dumps(ev) + "\n"
                else:
                    try:
                        yield line_format.format(**ev)
                    except KeyError as ex:
                        raise errors.bad_request.FieldsValueError(
                            "undefined placeholders in line format",
                            placeholders=[str(ex)],
                        )
            if len(log_events) < batch_size:
                break

    call.result.filename = "task_%s.log" % task_id
    call.result.content_type = "text/plain"
    call.result.raw_data = generate()
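
# A minimal, self-contained sketch (not part of the service code) of the
# placeholder-validation trick used in download_task_log above: formatting
# against dict.fromkeys() of the known fields maps every known placeholder to
# None, so each KeyError names one undefined placeholder. The format string
# below is hypothetical.
#
#     fields = {"asctime", "timestamp", "level", "worker", "msg"}
#     fmt = "{asctime} {level} {msg} {foo}"
#     unknown = set()
#     while True:
#         try:
#             fmt.format(**dict.fromkeys(fields | unknown))
#             break
#         except KeyError as e:
#             unknown.add(e.args[0])
#     assert unknown == {"foo"}
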
@endpoint("events.get_vector_metrics_and_variants", required_fields=["task"])
def get_vector_metrics_and_variants(call, company_id, req_model):
    task_id = call.data["task"]
    task_bll.assert_exists(company_id, task_id, allow_public=True)
    call.result.data = dict(
        metrics=event_bll.get_metrics_and_variants(company_id, task_id, "training_stats_vector")
    )


@endpoint("events.get_scalar_metrics_and_variants", required_fields=["task"])
def get_scalar_metrics_and_variants(call, company_id, req_model):
    task_id = call.data["task"]
    task_bll.assert_exists(company_id, task_id, allow_public=True)
    call.result.data = dict(
        metrics=event_bll.get_metrics_and_variants(company_id, task_id, "training_stats_scalar")
    )


# TODO: this currently returns up to 10,000 records; decide on a better way to control it
@endpoint("events.vector_metrics_iter_histogram", required_fields=["task", "metric", "variant"])
def vector_metrics_iter_histogram(call, company_id, req_model):
    task_id = call.data["task"]
    task_bll.assert_exists(company_id, task_id, allow_public=True)
    metric = call.data["metric"]
    variant = call.data["variant"]
    iterations, vectors = event_bll.get_vector_metrics_per_iter(company_id, task_id, metric, variant)
    call.result.data = dict(
        metric=metric,
        variant=variant,
        vectors=vectors,
        iterations=iterations,
    )


@endpoint("events.get_task_events", required_fields=["task"])
def get_task_events(call, company_id, req_model):
    task_id = call.data["task"]
    event_type = call.data.get("event_type")
    scroll_id = call.data.get("scroll_id")
    order = call.data.get("order") or "asc"
    task_bll.assert_exists(company_id, task_id, allow_public=True)

    result = event_bll.get_task_events(
        company_id, task_id,
        sort=[{"timestamp": {"order": order}}],
        event_type=event_type,
        scroll_id=scroll_id,
    )

    call.result.data = dict(
        events=result.events,
        returned=len(result.events),
        total=result.total_events,
        scroll_id=result.next_scroll_id,
    )


@endpoint("events.get_scalar_metric_data", required_fields=["task", "metric"])
def get_scalar_metric_data(call, company_id, req_model):
    task_id = call.data["task"]
    metric = call.data["metric"]
    scroll_id = call.data.get("scroll_id")
    task_bll.assert_exists(company_id, task_id, allow_public=True)

    result = event_bll.get_task_events(
        company_id, task_id,
        event_type="training_stats_scalar",
        sort=[{"iter": {"order": "desc"}}],
        metric=metric,
        scroll_id=scroll_id,
    )

    call.result.data = dict(
        events=result.events,
        returned=len(result.events),
        total=result.total_events,
        scroll_id=result.next_scroll_id,
    )


@endpoint("events.get_task_latest_scalar_values", required_fields=["task"])
def get_task_latest_scalar_values(call, company_id, req_model):
    task_id = call.data["task"]
    task = task_bll.assert_exists(company_id, task_id, allow_public=True)
    metrics, last_timestamp = event_bll.get_task_latest_scalar_values(company_id, task_id)
    es_index = EventBLL.get_index_name(company_id, "*")
    last_iters = event_bll.get_last_iters(es_index, task_id, None, 1)
    call.result.data = dict(
        metrics=metrics,
        last_iter=last_iters[0] if last_iters else 0,
        name=task.name,
        status=task.status,
        last_timestamp=last_timestamp,
    )
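
# Illustrative shape of the payload produced above (all values hypothetical;
# "metrics" is whatever event_bll.get_task_latest_scalar_values returns):
#   {
#       "metrics": [...],
#       "last_iter": 1000,        # 0 when the task reported no iterations
#       "name": "<task name>",
#       "status": "<task status>",
#       "last_timestamp": 1546300800000,
#   }
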
# TODO: should not repeat iter (x-axis) for each metric/variant; the JS client
# should get the raw data and fill gaps if needed
@endpoint("events.scalar_metrics_iter_histogram", required_fields=["task"])
def scalar_metrics_iter_histogram(call, company_id, req_model):
    task_id = call.data["task"]
    task_bll.assert_exists(company_id, task_id, allow_public=True)
    metrics = event_bll.get_scalar_metrics_average_per_iter(company_id, task_id)
    call.result.data = metrics


@endpoint("events.multi_task_scalar_metrics_iter_histogram", required_fields=["tasks"])
def multi_task_scalar_metrics_iter_histogram(call, company_id, req_model):
    task_ids = call.data["tasks"]
    if isinstance(task_ids, six.string_types):
        task_ids = [s.strip() for s in task_ids.split(",")]
    # note: the bll layer already validates the task ids, since it needs their names
    call.result.data = dict(
        metrics=event_bll.compare_scalar_metrics_average_per_iter(company_id, task_ids, allow_public=True)
    )


@endpoint("events.get_multi_task_plots", required_fields=["tasks"])
def get_multi_task_plots_v1_7(call, company_id, req_model):
    task_ids = call.data["tasks"]
    iters = call.data.get("iters", 1)
    scroll_id = call.data.get("scroll_id")

    tasks = task_bll.assert_exists(
        company_id=company_id,
        only=("id", "name"),
        task_ids=task_ids,
        allow_public=True,
    )

    # get the last 10K events by iteration and group them by unique metric+variant,
    # returning the top events for each combination
    result = event_bll.get_task_events(
        company_id, task_ids,
        event_type="plot",
        sort=[{"iter": {"order": "desc"}}],
        size=10000,
        scroll_id=scroll_id,
    )

    tasks = {t.id: t.name for t in tasks}
    return_events = _get_top_iter_unique_events_per_task(result.events, max_iters=iters, tasks=tasks)

    call.result.data = dict(
        plots=return_events,
        returned=len(return_events),
        total=result.total_events,
        scroll_id=result.next_scroll_id,
    )


@endpoint("events.get_multi_task_plots", min_version="1.8", required_fields=["tasks"])
def get_multi_task_plots(call, company_id, req_model):
    task_ids = call.data["tasks"]
    iters = call.data.get("iters", 1)
    scroll_id = call.data.get("scroll_id")

    tasks = task_bll.assert_exists(
        company_id=company_id,
        only=("id", "name"),
        task_ids=task_ids,
        allow_public=True,
    )

    result = event_bll.get_task_events(
        company_id, task_ids,
        event_type="plot",
        sort=[{"iter": {"order": "desc"}}],
        last_iter_count=iters,
        scroll_id=scroll_id,
    )

    tasks = {t.id: t.name for t in tasks}
    return_events = _get_top_iter_unique_events_per_task(result.events, max_iters=iters, tasks=tasks)

    call.result.data = dict(
        plots=return_events,
        returned=len(return_events),
        total=result.total_events,
        scroll_id=result.next_scroll_id,
    )
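
# Illustrative response envelope for both get_multi_task_plots versions (values
# hypothetical): "plots" is nested metric -> variant -> task -> iteration by
# _get_top_iter_unique_events_per_task (defined at the bottom of this module),
# so "returned" counts the top-level metric keys, not individual events.
#   {
#       "plots": {"<metric>": {"<variant>": {"<task-id>": {"<iter>": {...}}}}},
#       "returned": 1,
#       "total": 10,
#       "scroll_id": "<next scroll id>",
#   }
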
"desc"}}], size=10000, scroll_id=scroll_id ) return_events = _get_top_iter_unique_events(result.events, max_iters=iters) call.result.data = dict( plots=return_events, returned=len(return_events), total=result.total_events, scroll_id=result.next_scroll_id, ) @endpoint("events.get_task_plots", min_version="1.8", required_fields=["task"]) def get_task_plots(call, company_id, req_model): task_id = call.data["task"] iters = call.data.get("iters", 1) scroll_id = call.data.get("scroll_id") task_bll.assert_exists(call.identity.company, task_id, allow_public=True) result = event_bll.get_task_events( company_id, task_id, event_type="plot", sort=[{"iter": {"order": "desc"}}], last_iter_count=iters, scroll_id=scroll_id ) return_events = result.events call.result.data = dict( plots=return_events, returned=len(return_events), total=result.total_events, scroll_id=result.next_scroll_id, ) @endpoint("events.debug_images", required_fields=["task"]) def get_debug_images_v1_7(call, company_id, req_model): task_id = call.data["task"] iters = call.data.get("iters") or 1 scroll_id = call.data.get("scroll_id") task_bll.assert_exists(call.identity.company, task_id, allow_public=True) # events, next_scroll_id, total_events = event_bll.get_task_events( # company_id, task_id, # event_type="training_debug_image", # sort=[{"iter": {"order": "desc"}}], # last_iter_count=iters, # scroll_id=scroll_id) # get last 10K events by iteration and group them by unique metric+variant, returning top events for combination result = event_bll.get_task_events( company_id, task_id, event_type="training_debug_image", sort=[{"iter": {"order": "desc"}}], size=10000, scroll_id=scroll_id ) return_events = _get_top_iter_unique_events(result.events, max_iters=iters) call.result.data = dict( task=task_id, images=return_events, returned=len(return_events), total=result.total_events, scroll_id=result.next_scroll_id, ) @endpoint("events.debug_images", min_version="1.8", required_fields=["task"]) def get_debug_images(call, company_id, req_model): task_id = call.data["task"] iters = call.data.get("iters") or 1 scroll_id = call.data.get("scroll_id") task_bll.assert_exists(call.identity.company, task_id, allow_public=True) result = event_bll.get_task_events( company_id, task_id, event_type="training_debug_image", sort=[{"iter": {"order": "desc"}}], last_iter_count=iters, scroll_id=scroll_id ) return_events = result.events call.result.data = dict( task=task_id, images=return_events, returned=len(return_events), total=result.total_events, scroll_id=result.next_scroll_id, ) @endpoint("events.delete_for_task", required_fields=["task"]) def delete_for_task(call, company_id, req_model): task_id = call.data["task"] task_bll.assert_exists(company_id, task_id) call.result.data = dict( deleted=event_bll.delete_task_events(company_id, task_id) ) def _get_top_iter_unique_events_per_task(events, max_iters, tasks): key = itemgetter('metric', 'variant', 'task', 'iter') unique_events = itertools.chain.from_iterable( itertools.islice(group, max_iters) for _, group in itertools.groupby(sorted(events, key=key, reverse=True), key=key)) def collect(evs, fields): if not fields: evs = list(evs) return { 'name': tasks.get(evs[0].get('task')), 'plots': evs } return { str(k): collect(group, fields[1:]) for k, group in itertools.groupby(evs, key=itemgetter(fields[0])) } collect_fields = ('metric', 'variant', 'task', 'iter') return collect( sorted(unique_events, key=itemgetter(*collect_fields), reverse=True), collect_fields ) def _get_top_iter_unique_events(events, max_iters): 
def _get_top_iter_unique_events(events, max_iters):
    # keep the first max_iters events seen for each (metric, variant) pair;
    # callers pass events sorted by iteration in descending order, so these are
    # also the newest ones. A tuple key is used rather than string concatenation
    # to avoid collisions between e.g. ("ab", "c") and ("a", "bc").
    top_unique_events = defaultdict(list)
    for e in events:
        key = (e.get("metric"), e.get("variant"))
        evs = top_unique_events[key]
        if len(evs) < max_iters:
            evs.append(e)
    unique_events = list(itertools.chain.from_iterable(top_unique_events.values()))
    unique_events.sort(key=lambda e: e["iter"], reverse=True)
    return unique_events
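
# Worked example for _get_top_iter_unique_events (hypothetical events,
# max_iters=1): events arrive sorted by iteration, newest first, so the first
# event seen per (metric, variant) pair is also its newest one.
#   events = [
#       {"metric": "loss", "variant": "v", "iter": 3},
#       {"metric": "loss", "variant": "v", "iter": 2},
#       {"metric": "acc", "variant": "v", "iter": 3},
#   ]
#   _get_top_iter_unique_events(events, max_iters=1)
#   -> [{"metric": "loss", "variant": "v", "iter": 3},
#       {"metric": "acc", "variant": "v", "iter": 3}]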