diff --git a/apiserver/apimodels/events.py b/apiserver/apimodels/events.py index e4db85d..39b38ff 100644 --- a/apiserver/apimodels/events.py +++ b/apiserver/apimodels/events.py @@ -26,6 +26,7 @@ class MetricVariants(Base): class ScalarMetricsIterHistogramRequest(HistogramRequestBase): task: str = StringField(required=True) metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants) + model_events: bool = BoolField(default=False) class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase): @@ -40,6 +41,7 @@ class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase): ) ], ) + model_events: bool = BoolField(default=False) class TaskMetric(Base): @@ -56,6 +58,7 @@ class MetricEventsRequest(Base): navigate_earlier: bool = BoolField(default=True) refresh: bool = BoolField(default=False) scroll_id: str = StringField() + model_events: bool = BoolField() class TaskMetricVariant(Base): @@ -69,12 +72,14 @@ class GetHistorySampleRequest(TaskMetricVariant): refresh: bool = BoolField(default=False) scroll_id: Optional[str] = StringField() navigate_current_metric: bool = BoolField(default=True) + model_events: bool = BoolField(default=False) class NextHistorySampleRequest(Base): task: str = StringField(required=True) scroll_id: Optional[str] = StringField() navigate_earlier: bool = BoolField(default=True) + model_events: bool = BoolField(default=False) class LogOrderEnum(StringEnum): @@ -93,6 +98,7 @@ class TaskEventsRequest(TaskEventsRequestBase): order: Optional[str] = ActualEnumField(LogOrderEnum, default=LogOrderEnum.asc) scroll_id: str = StringField() count_total: bool = BoolField(default=True) + model_events: bool = BoolField(default=False) class LogEventsRequest(TaskEventsRequestBase): @@ -108,6 +114,7 @@ class ScalarMetricsIterRawRequest(TaskEventsRequestBase): metric: MetricVariants = EmbeddedField(MetricVariants, required=True) count_total: bool = BoolField(default=False) scroll_id: str = StringField() + model_events: bool = 
BoolField(default=False) class IterationEvents(Base): @@ -129,6 +136,7 @@ class MultiTasksRequestBase(Base): tasks: Sequence[str] = ListField( items_types=str, validators=[Length(minimum_value=1)] ) + model_events: bool = BoolField(default=False) class SingleValueMetricsRequest(MultiTasksRequestBase): @@ -145,6 +153,7 @@ class TaskPlotsRequest(Base): scroll_id: str = StringField() no_scroll: bool = BoolField(default=False) metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants) + model_events: bool = BoolField(default=False) class ClearScrollRequest(Base): diff --git a/apiserver/bll/event/event_bll.py b/apiserver/bll/event/event_bll.py index 9a45183..9e6e068 100644 --- a/apiserver/bll/event/event_bll.py +++ b/apiserver/bll/event/event_bll.py @@ -31,6 +31,7 @@ from apiserver.bll.event.metric_debug_images_iterator import MetricDebugImagesIt from apiserver.bll.event.metric_plots_iterator import MetricPlotsIterator from apiserver.bll.util import parallel_chunked_decorator from apiserver.database import utils as dbutils +from apiserver.database.model.model import Model from apiserver.es_factory import es_factory from apiserver.apierrors import errors from apiserver.bll.event.event_metrics import EventMetrics @@ -68,6 +69,15 @@ class EventBLL(object): r"['\"]source['\"]:\s?['\"]([a-z][a-z0-9+\-.]*://.*?)['\"]", flags=re.IGNORECASE, ) + _task_event_query = { + "bool": { + "should": [ + {"term": {"model_event": False}}, + {"bool": {"must_not": [{"exists": {"field": "model_event"}}]}}, + ] + } + } + _model_event_query = {"term": {"model_event": True}} def __init__(self, events_es=None, redis=None): self.es = events_es or es_factory.connect("events") @@ -103,11 +113,35 @@ class EventBLL(object): res = Task.objects(query).only("id") return {r.id for r in res} + @staticmethod + def _get_valid_models(company_id, model_ids: Set, allow_locked_models=False) -> Set: + """Verify that task exists and can be updated""" + if not model_ids: + return set() + + with 
translate_errors_context(): + query = Q(id__in=model_ids, company=company_id) + if not allow_locked_models: + query &= Q(ready__ne=True) + res = Model.objects(query).only("id") + return {r.id for r in res} + def add_events( - self, company_id, events, worker, allow_locked_tasks=False + self, company_id, events, worker, allow_locked=False ) -> Tuple[int, int, dict]: + model_events = events[0].get("model_event", False) + for event in events: + if event.get("model_event", model_events) != model_events: + raise errors.bad_request.ValidationError( + "Inconsistent model_event setting in the passed events" + ) + if event.pop("allow_locked", allow_locked) != allow_locked: + raise errors.bad_request.ValidationError( + "Inconsistent allow_locked setting in the passed events" + ) + actions: List[dict] = [] - task_ids = set() + task_or_model_ids = set() task_iteration = defaultdict(lambda: 0) task_last_scalar_events = nested_dict( 3, dict @@ -117,13 +151,28 @@ class EventBLL(object): ) # task_id -> metric_hash -> event_type -> MetricEvent errors_per_type = defaultdict(int) invalid_iteration_error = f"Iteration number should not exceed {MAX_LONG}" - valid_tasks = self._get_valid_tasks( - company_id, - task_ids={ - event["task"] for event in events if event.get("task") is not None - }, - allow_locked_tasks=allow_locked_tasks, - ) + if model_events: + for event in events: + model = event.pop("model", None) + if model is not None: + event["task"] = model + valid_entities = self._get_valid_models( + company_id, + model_ids={ + event["task"] for event in events if event.get("task") is not None + }, + allow_locked_models=allow_locked, + ) + entity_name = "model" + else: + valid_entities = self._get_valid_tasks( + company_id, + task_ids={ + event["task"] for event in events if event.get("task") is not None + }, + allow_locked_tasks=allow_locked, + ) + entity_name = "task" for event in events: # remove spaces from event type @@ -137,13 +186,17 @@ class EventBLL(object): 
errors_per_type[f"Invalid event type {event_type}"] += 1 continue - task_id = event.get("task") - if task_id is None: + if model_events and event_type == EventType.task_log.value: + errors_per_type[f"Task log events are not supported for models"] += 1 + continue + + task_or_model_id = event.get("task") + if task_or_model_id is None: errors_per_type["Event must have a 'task' field"] += 1 continue - if task_id not in valid_tasks: - errors_per_type["Invalid task id"] += 1 + if task_or_model_id not in valid_entities: + errors_per_type[f"Invalid {entity_name} id {task_or_model_id}"] += 1 continue event["type"] = event_type @@ -165,10 +218,13 @@ class EventBLL(object): # force iter to be a long int iter = event.get("iter") if iter is not None: - iter = int(iter) - if iter > MAX_LONG or iter < MIN_LONG: - errors_per_type[invalid_iteration_error] += 1 - continue + if model_events: + iter = 0 + else: + iter = int(iter) + if iter > MAX_LONG or iter < MIN_LONG: + errors_per_type[invalid_iteration_error] += 1 + continue event["iter"] = iter # used to have "values" to indicate array. 
no need anymore @@ -178,6 +234,7 @@ class EventBLL(object): event["metric"] = event.get("metric") or "" event["variant"] = event.get("variant") or "" + event["model_event"] = model_events index_name = get_index_name(company_id, event_type) es_action = { @@ -192,21 +249,26 @@ class EventBLL(object): else: es_action["_id"] = dbutils.id() - task_ids.add(task_id) + task_or_model_ids.add(task_or_model_id) if ( iter is not None + and not model_events and event.get("metric") not in self._skip_iteration_for_metric ): - task_iteration[task_id] = max(iter, task_iteration[task_id]) - - self._update_last_metric_events_for_task( - last_events=task_last_events[task_id], event=event, - ) - if event_type == EventType.metrics_scalar.value: - self._update_last_scalar_events_for_task( - last_events=task_last_scalar_events[task_id], event=event + task_iteration[task_or_model_id] = max( + iter, task_iteration[task_or_model_id] ) + if not model_events: + self._update_last_metric_events_for_task( + last_events=task_last_events[task_or_model_id], event=event, + ) + if event_type == EventType.metrics_scalar.value: + self._update_last_scalar_events_for_task( + last_events=task_last_scalar_events[task_or_model_id], + event=event, + ) + actions.append(es_action) plot_actions = [ @@ -243,28 +305,31 @@ class EventBLL(object): else: errors_per_type["Error when indexing events batch"] += 1 - remaining_tasks = set() - now = datetime.utcnow() - for task_id in task_ids: - # Update related tasks. For reasons of performance, we prefer to update - # all of them and not only those who's events were successful - updated = self._update_task( - company_id=company_id, - task_id=task_id, - now=now, - iter_max=task_iteration.get(task_id), - last_scalar_events=task_last_scalar_events.get(task_id), - last_events=task_last_events.get(task_id), - ) + if not model_events: + remaining_tasks = set() + now = datetime.utcnow() + for task_or_model_id in task_or_model_ids: + # Update related tasks. 
For reasons of performance, we prefer to update + # all of them and not only those who's events were successful + updated = self._update_task( + company_id=company_id, + task_id=task_or_model_id, + now=now, + iter_max=task_iteration.get(task_or_model_id), + last_scalar_events=task_last_scalar_events.get( + task_or_model_id + ), + last_events=task_last_events.get(task_or_model_id), + ) - if not updated: - remaining_tasks.add(task_id) - continue + if not updated: + remaining_tasks.add(task_or_model_id) + continue - if remaining_tasks: - TaskBLL.set_last_update( - remaining_tasks, company_id, last_update=now - ) + if remaining_tasks: + TaskBLL.set_last_update( + remaining_tasks, company_id, last_update=now + ) # this is for backwards compatibility with streaming bulk throwing exception on those invalid_iterations_count = errors_per_type.get(invalid_iteration_error) @@ -527,6 +592,7 @@ class EventBLL(object): scroll_id: str = None, no_scroll: bool = False, metric_variants: MetricVariants = None, + model_events: bool = False, ): if scroll_id == self.empty_scroll: return TaskEventsResult() @@ -553,7 +619,7 @@ class EventBLL(object): } must = [plot_valid_condition] - if last_iterations_per_plot is None: + if last_iterations_per_plot is None or model_events: must.append({"terms": {"task": tasks}}) if metric_variants: must.append(get_metric_variants_condition(metric_variants)) @@ -709,6 +775,7 @@ class EventBLL(object): size=500, scroll_id=None, no_scroll=False, + model_events=False, ) -> TaskEventsResult: if scroll_id == self.empty_scroll: return TaskEventsResult() @@ -728,7 +795,7 @@ class EventBLL(object): if variant: must.append({"term": {"variant": variant}}) - if last_iter_count is None: + if last_iter_count is None or model_events: must.append({"terms": {"task": task_ids}}) else: tasks_iters = self.get_last_iters( @@ -989,6 +1056,21 @@ class EventBLL(object): for tb in es_res["aggregations"]["tasks"]["buckets"] } + @staticmethod + def _validate_model_state( + 
company_id: str, model_id: str, allow_locked: bool = False + ): + extra_msg = None + query = Q(id=model_id, company=company_id) + if not allow_locked: + query &= Q(ready__ne=True) + extra_msg = "or model published" + res = Model.objects(query).only("id").first() + if not res: + raise errors.bad_request.InvalidModelId( + extra_msg, company=company_id, id=model_id + ) + @staticmethod def _validate_task_state(company_id: str, task_id: str, allow_locked: bool = False): extra_msg = None @@ -1002,10 +1084,15 @@ class EventBLL(object): extra_msg, company=company_id, id=task_id ) - def delete_task_events(self, company_id, task_id, allow_locked=False): - self._validate_task_state( - company_id=company_id, task_id=task_id, allow_locked=allow_locked - ) + def delete_task_events(self, company_id, task_id, allow_locked=False, model=False): + if model: + self._validate_model_state( + company_id=company_id, model_id=task_id, allow_locked=allow_locked, + ) + else: + self._validate_task_state( + company_id=company_id, task_id=task_id, allow_locked=allow_locked + ) es_req = {"query": {"term": {"task": task_id}}} with translate_errors_context(): diff --git a/apiserver/bll/event/event_metrics.py b/apiserver/bll/event/event_metrics.py index 7da4bc5..7b5dba3 100644 --- a/apiserver/bll/event/event_metrics.py +++ b/apiserver/bll/event/event_metrics.py @@ -8,9 +8,7 @@ from typing import Sequence, Tuple, Mapping from boltons.iterutils import bucketize from elasticsearch import Elasticsearch -from mongoengine import Q -from apiserver.apierrors import errors from apiserver.bll.event.event_common import ( EventType, EventSettings, @@ -111,40 +109,19 @@ class EventMetrics: def compare_scalar_metrics_average_per_iter( self, company_id, - task_ids: Sequence[str], + tasks: Sequence[Task], samples, key: ScalarKeyEnum, - allow_public=True, ): """ Compare scalar metrics for different tasks per metric and variant The amount of points in each histogram should not exceed the requested samples """ - 
task_name_by_id = {} - with translate_errors_context(): - task_objs = Task.get_many( - company=company_id, - query=Q(id__in=task_ids), - allow_public=allow_public, - override_projection=("id", "name", "company", "company_origin"), - return_dicts=False, - ) - if len(task_objs) < len(task_ids): - invalid = tuple(set(task_ids) - set(r.id for r in task_objs)) - raise errors.bad_request.InvalidTaskId(company=company_id, ids=invalid) - task_name_by_id = {t.id: t.name for t in task_objs} - - companies = {t.get_index_company() for t in task_objs} - if len(companies) > 1: - raise errors.bad_request.InvalidTaskId( - "only tasks from the same company are supported" - ) - event_type = EventType.metrics_scalar - company_id = next(iter(companies)) if check_empty_data(self.es, company_id=company_id, event_type=event_type): return {} + task_name_by_id = {t.id: t.name for t in tasks} get_scalar_average_per_iter = partial( self._get_scalar_average_per_iter_core, company_id=company_id, @@ -153,6 +130,7 @@ class EventMetrics: key=ScalarKey.resolve(key), run_parallel=False, ) + task_ids = [t.id for t in tasks] with ThreadPoolExecutor(max_workers=EventSettings.max_workers) as pool: task_metrics = zip( task_ids, pool.map(get_scalar_average_per_iter, task_ids) @@ -169,7 +147,7 @@ class EventMetrics: return res def get_task_single_value_metrics( - self, company_id: str, task_ids: Sequence[str] + self, company_id: str, tasks: Sequence[Task] ) -> Mapping[str, dict]: """ For the requested tasks return all the events delivered for the single iteration (-2**31) @@ -179,6 +157,7 @@ class EventMetrics: ): return {} + task_ids = [t.id for t in tasks] task_events = self._get_task_single_value_metrics(company_id, task_ids) def _get_value(event: dict): diff --git a/apiserver/bll/model/__init__.py b/apiserver/bll/model/__init__.py index 3d47b77..f5e234a 100644 --- a/apiserver/bll/model/__init__.py +++ b/apiserver/bll/model/__init__.py @@ -1,5 +1,7 @@ from datetime import datetime -from typing import 
Callable, Tuple, Sequence, Dict +from typing import Callable, Tuple, Sequence, Dict, Optional + +from mongoengine import Q from apiserver.apierrors import errors from apiserver.apimodels.models import ModelTaskPublishResponse @@ -24,6 +26,33 @@ class ModelBLL: raise errors.bad_request.InvalidModelId(**query) return model + @staticmethod + def assert_exists( + company_id, + model_ids, + only=None, + allow_public=False, + return_models=True, + ) -> Optional[Sequence[Model]]: + model_ids = [model_ids] if isinstance(model_ids, str) else model_ids + ids = set(model_ids) + query = Q(id__in=ids) + + q = Model.get_many( + company=company_id, + query=query, + allow_public=allow_public, + return_dicts=False, + ) + if only: + q = q.only(*only) + + if q.count() != len(ids): + raise errors.bad_request.InvalidModelId(ids=model_ids) + + if return_models: + return list(q) + @classmethod def publish_model( cls, diff --git a/apiserver/database/model/model.py b/apiserver/database/model/model.py index 48c4aee..8ce4d20 100644 --- a/apiserver/database/model/model.py +++ b/apiserver/database/model/model.py @@ -92,3 +92,6 @@ class Model(AttributedDocument): metadata = SafeMapField( field=EmbeddedDocumentField(MetadataItem), user_set_allowed=True ) + + def get_index_company(self) -> str: + return self.company or self.company_origin or "" diff --git a/apiserver/elastic/mappings/events/events.json b/apiserver/elastic/mappings/events/events.json index b0f8edf..ace61c8 100644 --- a/apiserver/elastic/mappings/events/events.json +++ b/apiserver/elastic/mappings/events/events.json @@ -35,6 +35,12 @@ }, "value": { "type": "float" + }, + "company_id": { + "type": "keyword" + }, + "model_event": { + "type": "boolean" } } } diff --git a/apiserver/mongo/initialize/pre_populate.py b/apiserver/mongo/initialize/pre_populate.py index 11b7864..9ea08c1 100644 --- a/apiserver/mongo/initialize/pre_populate.py +++ b/apiserver/mongo/initialize/pre_populate.py @@ -968,5 +968,5 @@ class PrePopulate: ev["task"] = 
task_id ev["company_id"] = company_id cls.event_bll.add_events( - company_id, events=events, worker="", allow_locked_tasks=True + company_id, events=events, worker="", allow_locked=True ) diff --git a/apiserver/schema/services/events.conf b/apiserver/schema/services/events.conf index 9f33885..bd2bf60 100644 --- a/apiserver/schema/services/events.conf +++ b/apiserver/schema/services/events.conf @@ -405,13 +405,27 @@ add { additionalProperties: true } } + "999.0": ${add."2.1"} { + request.properties { + model_event { + type: boolean + description: If set then the event is for a model. Otherwise for a task. Cannot be used with task log events. If used in batch then all the events should be marked the same + default: false + } + allow_locked { + type: boolean + description: Allow adding events to published tasks or models + default: false + } + } + } } add_batch { "2.1" { description: "Adds a batch of events in a single call (json-lines format, stream-friendly)" batch_request: { action: add - version: 1.5 + version: 2.1 } response { type: object @@ -422,10 +436,16 @@ add_batch { } } } + "999.0": ${add_batch."2.1"} { + batch_request: { + action: add + version: 999.0 + } + } } delete_for_task { "2.1" { - description: "Delete all task event. *This cannot be undone!*" + description: "Delete all task events. *This cannot be undone!*" request { type: object required: [ @@ -454,6 +474,37 @@ delete_for_task { } } } +delete_for_model { + "999.0" { + description: "Delete all model events. 
*This cannot be undone!*" + request { + type: object + required: [ + model + ] + properties { + model { + type: string + description: "Model ID" + } + allow_locked { + type: boolean + description: "Allow deleting events even if the model is locked" + default: false + } + } + } + response { + type: object + properties { + deleted { + type: boolean + description: "Number of deleted events" + } + } + } + } +} debug_images { "2.1" { description: "Get all debug images of a task" @@ -548,6 +599,13 @@ debug_images { } } } + "999.0": ${debug_images."2.14"} { + model_events { + type: boolean + description: If set then the retrieving model events. Otherwise task events + default: false + } + } } plots { "2.20" { @@ -583,6 +641,13 @@ plots { } response {"$ref": "#/definitions/plots_response"} } + "999.0": ${plots."2.20"} { + model_events { + type: boolean + description: If set then the retrieving model plots. Otherwise task plots + default: false + } + } } get_debug_image_sample { "2.12": { @@ -626,6 +691,13 @@ get_debug_image_sample { default: true } } + "999.0": ${get_debug_image_sample."2.20"} { + model_events { + type: boolean + description: If set then the retrieving model debug images. Otherwise task debug images + default: false + } + } } next_debug_image_sample { "2.12": { @@ -651,6 +723,18 @@ next_debug_image_sample { } response {"$ref": "#/definitions/debug_image_sample_response"} } + "999.0": ${next_debug_image_sample."2.12"} { + request.properties.next_iteration { + type: boolean + default: false + description: If set then navigate to the next/previous iteration + } + model_events { + type: boolean + description: If set then the retrieving model debug images. 
Otherwise task debug images + default: false + } + } } get_plot_sample { "2.20": { @@ -692,6 +776,13 @@ get_plot_sample { } response {"$ref": "#/definitions/plot_sample_response"} } + "999.0": ${get_plot_sample."2.20"} { + model_events { + type: boolean + description: If set then the retrieving model plots. Otherwise task plots + default: false + } + } } next_plot_sample { "2.20": { @@ -717,6 +808,18 @@ next_plot_sample { } response {"$ref": "#/definitions/plot_sample_response"} } + "999.0": ${next_plot_sample."2.20"} { + request.properties.next_iteration { + type: boolean + default: false + description: If set then navigate to the next/previous iteration + } + model_events { + type: boolean + description: If set then the retrieving model plots. Otherwise task plots + default: false + } + } } get_task_metrics{ "2.7": { @@ -749,6 +852,13 @@ get_task_metrics{ } } } + "999.0": ${get_task_metrics."2.7"} { + model_events { + type: boolean + description: If set then get metrics from model events. Otherwise from task events + default: false + } + } } get_task_log { "1.5" { @@ -966,6 +1076,13 @@ get_task_events { } } } + "999.0": ${get_task_events."2.1"} { + model_events { + type: boolean + description: If set then the retrieving model events. Otherwise task events + default: false + } + } } download_task_log { @@ -1067,6 +1184,13 @@ get_task_plots { default: false } } + "999.0": ${get_task_plots."2.16"} { + model_events { + type: boolean + description: If set then the retrieving model events. Otherwise task events + default: false + } + } } get_multi_task_plots { "2.1" { @@ -1124,6 +1248,13 @@ get_multi_task_plots { default: false } } + "999.0": ${get_multi_task_plots."2.16"} { + model_events { + type: boolean + description: If set then the retrieving model events. 
Otherwise task events + default: false + } + } } get_vector_metrics_and_variants { "2.1" { @@ -1152,6 +1283,13 @@ get_vector_metrics_and_variants { } } } + "999.0": ${get_vector_metrics_and_variants."2.1"} { + model_events { + type: boolean + description: If set then the retrieving model events. Otherwise task events + default: false + } + } } vector_metrics_iter_histogram { "2.1" { @@ -1190,6 +1328,13 @@ vector_metrics_iter_histogram { } } } + "999.0": ${vector_metrics_iter_histogram."2.1"} { + model_events { + type: boolean + description: If set then the retrieving model events. Otherwise task events + default: false + } + } } scalar_metrics_iter_histogram { "2.1" { @@ -1243,6 +1388,13 @@ scalar_metrics_iter_histogram { } } } + "999.0": ${scalar_metrics_iter_histogram."2.14"} { + model_events { + type: boolean + description: If set then the retrieving model events. Otherwise task events + default: false + } + } } multi_task_scalar_metrics_iter_histogram { "2.1" { @@ -1282,6 +1434,13 @@ multi_task_scalar_metrics_iter_histogram { additionalProperties: true } } + "999.0": ${multi_task_scalar_metrics_iter_histogram."2.1"} { + model_events { + type: boolean + description: If set then the retrieving model events. Otherwise task events + default: false + } + } } get_task_single_value_metrics { "2.20" { @@ -1331,6 +1490,13 @@ get_task_single_value_metrics { } } } + "999.0": ${get_task_single_value_metrics."2.20"} { + model_events { + type: boolean + description: If set then the retrieving model events. Otherwise task events + default: false + } + } } get_task_latest_scalar_values { "2.1" { @@ -1410,6 +1576,13 @@ get_scalar_metrics_and_variants { } } } + "999.0": ${get_scalar_metrics_and_variants."2.1"} { + model_events { + type: boolean + description: If set then the retrieving model events. 
Otherwise task events + default: false + } + } } get_scalar_metric_data { "2.1" { @@ -1459,6 +1632,13 @@ get_scalar_metric_data { default: false } } + "999.0": ${get_scalar_metric_data."2.16"} { + model_events { + type: boolean + description: If set then the retrieving model events. Otherwise task events + default: false + } + } } scalar_metrics_iter_raw { "2.16" { @@ -1523,6 +1703,13 @@ scalar_metrics_iter_raw { } } } + "999.0": ${scalar_metrics_iter_raw."2.16"} { + model_events { + type: boolean + description: If set then the retrieving model events. Otherwise task events + default: false + } + } } clear_scroll { "2.18" { diff --git a/apiserver/services/events.py b/apiserver/services/events.py index bb87fc8..9457a0d 100644 --- a/apiserver/services/events.py +++ b/apiserver/services/events.py @@ -2,7 +2,7 @@ import itertools import math from collections import defaultdict from operator import itemgetter -from typing import Sequence, Optional +from typing import Sequence, Optional, Union, Tuple import attr import jsonmodels.fields @@ -33,13 +33,36 @@ from apiserver.bll.event import EventBLL from apiserver.bll.event.event_common import EventType, MetricVariants from apiserver.bll.event.events_iterator import Scroll from apiserver.bll.event.scalar_key import ScalarKeyEnum, ScalarKey +from apiserver.bll.model import ModelBLL from apiserver.bll.task import TaskBLL from apiserver.config_repo import config +from apiserver.database.model.model import Model +from apiserver.database.model.task.task import Task from apiserver.service_repo import APICall, endpoint from apiserver.utilities import json, extract_properties_to_lists task_bll = TaskBLL() event_bll = EventBLL() +model_bll = ModelBLL() + + +def _assert_task_or_model_exists( + company_id: str, task_ids: Union[str, Sequence[str]], model_events: bool +) -> Union[Sequence[Model], Sequence[Task]]: + if model_events: + return model_bll.assert_exists( + company_id, + task_ids, + allow_public=True, + only=("id", "name", 
"company", "company_origin"), + ) + + return task_bll.assert_exists( + company_id, + task_ids, + allow_public=True, + only=("id", "name", "company", "company_origin"), + ) @endpoint("events.add") @@ -47,7 +70,7 @@ def add(call: APICall, company_id, _): data = call.data.copy() allow_locked = data.pop("allow_locked", False) added, err_count, err_info = event_bll.add_events( - company_id, [data], call.worker, allow_locked_tasks=allow_locked + company_id, [data], call.worker, allow_locked=allow_locked ) call.result.data = dict(added=added, errors=err_count, errors_info=err_info) @@ -58,7 +81,12 @@ def add_batch(call: APICall, company_id, _): if events is None or len(events) == 0: raise errors.bad_request.BatchContainsNoItems() - added, err_count, err_info = event_bll.add_events(company_id, events, call.worker) + added, err_count, err_info = event_bll.add_events( + company_id, + events, + call.worker, + allow_locked=events[0].get("allow_locked", False), + ) call.result.data = dict(added=added, errors=err_count, errors_info=err_info) @@ -225,12 +253,13 @@ def download_task_log(call, company_id, _): @endpoint("events.get_vector_metrics_and_variants", required_fields=["task"]) def get_vector_metrics_and_variants(call, company_id, _): task_id = call.data["task"] - task = task_bll.assert_exists( - company_id, task_id, allow_public=True, only=("company", "company_origin") + model_events = call.data["model_events"] + task_or_model = _assert_task_or_model_exists( + company_id, task_id, model_events=model_events, )[0] call.result.data = dict( metrics=event_bll.get_metrics_and_variants( - task.get_index_company(), task_id, EventType.metrics_vector + task_or_model.get_index_company(), task_id, EventType.metrics_vector ) ) @@ -238,12 +267,13 @@ def get_vector_metrics_and_variants(call, company_id, _): @endpoint("events.get_scalar_metrics_and_variants", required_fields=["task"]) def get_scalar_metrics_and_variants(call, company_id, _): task_id = call.data["task"] - task = 
task_bll.assert_exists( - company_id, task_id, allow_public=True, only=("company", "company_origin") + model_events = call.data["model_events"] + task_or_model = _assert_task_or_model_exists( + company_id, task_id, model_events=model_events, )[0] call.result.data = dict( metrics=event_bll.get_metrics_and_variants( - task.get_index_company(), task_id, EventType.metrics_scalar + task_or_model.get_index_company(), task_id, EventType.metrics_scalar ) ) @@ -255,13 +285,14 @@ def get_scalar_metrics_and_variants(call, company_id, _): ) def vector_metrics_iter_histogram(call, company_id, _): task_id = call.data["task"] - task = task_bll.assert_exists( - company_id, task_id, allow_public=True, only=("company", "company_origin") + model_events = call.data["model_events"] + task_or_model = _assert_task_or_model_exists( + company_id, task_id, model_events=model_events, )[0] metric = call.data["metric"] variant = call.data["variant"] iterations, vectors = event_bll.get_vector_metrics_per_iter( - task.get_index_company(), task_id, metric, variant + task_or_model.get_index_company(), task_id, metric, variant ) call.result.data = dict( metric=metric, variant=variant, vectors=vectors, iterations=iterations @@ -286,11 +317,10 @@ def make_response( @endpoint("events.get_task_events", request_data_model=TaskEventsRequest) -def get_task_events(call, company_id, request: TaskEventsRequest): +def get_task_events(_, company_id, request: TaskEventsRequest): task_id = request.task - - task = task_bll.assert_exists( - company_id, task_id, allow_public=True, only=("company",), + task_or_model = _assert_task_or_model_exists( + company_id, task_id, model_events=request.model_events, )[0] key = ScalarKeyEnum.iter @@ -322,7 +352,7 @@ def get_task_events(call, company_id, request: TaskEventsRequest): if request.count_total and total is None: total = event_bll.events_iterator.count_task_events( event_type=request.event_type, - company_id=task.company, + company_id=task_or_model.get_index_company(), 
task_id=task_id, metric_variants=metric_variants, ) @@ -336,7 +366,7 @@ def get_task_events(call, company_id, request: TaskEventsRequest): res = event_bll.events_iterator.get_task_events( event_type=request.event_type, - company_id=task.company, + company_id=task_or_model.get_index_company(), task_id=task_id, batch_size=batch_size, key=ScalarKeyEnum.iter, @@ -365,18 +395,20 @@ def get_scalar_metric_data(call, company_id, _): metric = call.data["metric"] scroll_id = call.data.get("scroll_id") no_scroll = call.data.get("no_scroll", False) + model_events = call.data.get("model_events", False) - task = task_bll.assert_exists( - company_id, task_id, allow_public=True, only=("company", "company_origin") + task_or_model = _assert_task_or_model_exists( + company_id, task_id, model_events=model_events, )[0] result = event_bll.get_task_events( - task.get_index_company(), + task_or_model.get_index_company(), task_id, event_type=EventType.metrics_scalar, sort=[{"iter": {"order": "desc"}}], metric=metric, scroll_id=scroll_id, no_scroll=no_scroll, + model_events=model_events, ) call.result.data = dict( @@ -398,7 +430,7 @@ def get_task_latest_scalar_values(call, company_id, _): index_company, task_id ) last_iters = event_bll.get_last_iters( - company_id=company_id, event_type=EventType.all, task_id=task_id, iters=1 + company_id=index_company, event_type=EventType.all, task_id=task_id, iters=1 ).get(task_id) call.result.data = dict( metrics=metrics, @@ -417,11 +449,11 @@ def get_task_latest_scalar_values(call, company_id, _): def scalar_metrics_iter_histogram( call, company_id, request: ScalarMetricsIterHistogramRequest ): - task = task_bll.assert_exists( - company_id, request.task, allow_public=True, only=("company", "company_origin") + task_or_model = _assert_task_or_model_exists( + company_id, request.task, model_events=request.model_events )[0] metrics = event_bll.metrics.get_scalar_metrics_average_per_iter( - task.get_index_company(), + 
company_id=task_or_model.get_index_company(), task_id=request.task, samples=request.samples, key=request.key, @@ -429,24 +461,55 @@ def scalar_metrics_iter_histogram( call.result.data = metrics +def _get_task_or_model_index_company( + company_id: str, task_ids: Sequence[str], model_events=False, +) -> Tuple[str, Sequence[Task]]: + """ + Verify that all tasks exists and belong to store data in the same company index + Return company and tasks + """ + tasks_or_models = _assert_task_or_model_exists( + company_id, task_ids, model_events=model_events, + ) + + unique_ids = set(task_ids) + if len(tasks_or_models) < len(unique_ids): + invalid = tuple(unique_ids - {t.id for t in tasks_or_models}) + error_cls = ( + errors.bad_request.InvalidModelId + if model_events + else errors.bad_request.InvalidTaskId + ) + raise error_cls(company=company_id, ids=invalid) + + companies = {t.get_index_company() for t in tasks_or_models} + if len(companies) > 1: + raise errors.bad_request.InvalidTaskId( + "only tasks from the same company are supported" + ) + + return companies.pop(), tasks_or_models + + @endpoint( "events.multi_task_scalar_metrics_iter_histogram", request_data_model=MultiTaskScalarMetricsIterHistogramRequest, ) def multi_task_scalar_metrics_iter_histogram( - call, company_id, req_model: MultiTaskScalarMetricsIterHistogramRequest + call, company_id, request: MultiTaskScalarMetricsIterHistogramRequest ): - task_ids = req_model.tasks + task_ids = request.tasks if isinstance(task_ids, str): task_ids = [s.strip() for s in task_ids.split(",")] - # Note, bll already validates task ids as it needs their names + company, tasks_or_models = _get_task_or_model_index_company( + company_id, task_ids, request.model_events + ) call.result.data = dict( metrics=event_bll.metrics.compare_scalar_metrics_average_per_iter( - company_id, - task_ids=task_ids, - samples=req_model.samples, - allow_public=True, - key=req_model.key, + company_id=company, + tasks=tasks_or_models, + 
samples=request.samples, + key=request.key, ) ) @@ -455,21 +518,11 @@ def multi_task_scalar_metrics_iter_histogram( def get_task_single_value_metrics( call, company_id: str, request: SingleValueMetricsRequest ): - task_ids = call.data["tasks"] - tasks = task_bll.assert_exists( - company_id=call.identity.company, - only=("id", "name", "company", "company_origin"), - task_ids=task_ids, - allow_public=True, + company, tasks_or_models = _get_task_or_model_index_company( + company_id, request.tasks, request.model_events ) - companies = {t.get_index_company() for t in tasks} - if len(companies) > 1: - raise errors.bad_request.InvalidTaskId( - "only tasks from the same company are supported" - ) - - res = event_bll.metrics.get_task_single_value_metrics(company_id, task_ids) + res = event_bll.metrics.get_task_single_value_metrics(company, tasks_or_models) call.result.data = dict( tasks=[{"task": task, "values": values} for task, values in res.items()] ) @@ -481,22 +534,11 @@ def get_multi_task_plots_v1_7(call, company_id, _): iters = call.data.get("iters", 1) scroll_id = call.data.get("scroll_id") - tasks = task_bll.assert_exists( - company_id=company_id, - only=("id", "name", "company", "company_origin"), - task_ids=task_ids, - allow_public=True, - ) - - companies = {t.get_index_company() for t in tasks} - if len(companies) > 1: - raise errors.bad_request.InvalidTaskId( - "only tasks from the same company are supported" - ) + company, tasks_or_models = _get_task_or_model_index_company(company_id, task_ids) # Get last 10K events by iteration and group them by unique metric+variant, returning top events for combination result = event_bll.get_task_events( - next(iter(companies)), + company, task_ids, event_type=EventType.metrics_plot, sort=[{"iter": {"order": "desc"}}], @@ -504,7 +546,7 @@ def get_multi_task_plots_v1_7(call, company_id, _): scroll_id=scroll_id, ) - tasks = {t.id: t.name for t in tasks} + tasks = {t.id: t.name for t in tasks_or_models} return_events = 
_get_top_iter_unique_events_per_task( result.events, max_iters=iters, tasks=tasks @@ -519,36 +561,29 @@ def get_multi_task_plots_v1_7(call, company_id, _): @endpoint("events.get_multi_task_plots", min_version="1.8", required_fields=["tasks"]) -def get_multi_task_plots(call, company_id, req_model): +def get_multi_task_plots(call, company_id, _): task_ids = call.data["tasks"] iters = call.data.get("iters", 1) scroll_id = call.data.get("scroll_id") no_scroll = call.data.get("no_scroll", False) + model_events = call.data.get("model_events", False) - tasks = task_bll.assert_exists( - company_id=call.identity.company, - only=("id", "name", "company", "company_origin"), - task_ids=task_ids, - allow_public=True, + company, tasks_or_models = _get_task_or_model_index_company( + company_id, task_ids, model_events ) - companies = {t.get_index_company() for t in tasks} - if len(companies) > 1: - raise errors.bad_request.InvalidTaskId( - "only tasks from the same company are supported" - ) - result = event_bll.get_task_events( - next(iter(companies)), + company, task_ids, event_type=EventType.metrics_plot, sort=[{"iter": {"order": "desc"}}], last_iter_count=iters, scroll_id=scroll_id, no_scroll=no_scroll, + model_events=model_events, ) - tasks = {t.id: t.name for t in tasks} + tasks = {t.id: t.name for t in tasks_or_models} return_events = _get_top_iter_unique_events_per_task( result.events, max_iters=iters, tasks=tasks @@ -615,17 +650,18 @@ def get_task_plots(call, company_id, request: TaskPlotsRequest): iters = request.iters scroll_id = request.scroll_id - task = task_bll.assert_exists( - company_id, task_id, allow_public=True, only=("company", "company_origin") + task_or_model = _assert_task_or_model_exists( + company_id, task_id, model_events=request.model_events )[0] result = event_bll.get_task_plots( - task.get_index_company(), + task_or_model.get_index_company(), tasks=[task_id], sort=[{"iter": {"order": "desc"}}], last_iterations_per_plot=iters, scroll_id=scroll_id, 
no_scroll=request.no_scroll, metric_variants=_get_metric_variants_from_request(request.metrics), + model_events=request.model_events, ) return_events = result.events @@ -651,21 +687,11 @@ def task_plots(call, company_id, request: MetricEventsRequest): if None in metrics: metrics.clear() - tasks = task_bll.assert_exists( - company_id, - task_ids=list(task_metrics), - allow_public=True, - only=("company", "company_origin"), + company, _ = _get_task_or_model_index_company( + company_id, task_ids=list(task_metrics), model_events=request.model_events ) - - companies = {t.get_index_company() for t in tasks} - if len(companies) > 1: - raise errors.bad_request.InvalidTaskId( - "only tasks from the same company are supported" - ) - result = event_bll.plots_iterator.get_task_events( - company_id=next(iter(companies)), + company_id=company, task_metrics=task_metrics, iter_count=request.iters, navigate_earlier=request.navigate_earlier, @@ -730,17 +756,19 @@ def get_debug_images_v1_8(call, company_id, _): task_id = call.data["task"] iters = call.data.get("iters") or 1 scroll_id = call.data.get("scroll_id") + model_events = call.data.get("model_events", False) - task = task_bll.assert_exists( - company_id, task_id, allow_public=True, only=("company", "company_origin") + tasks_or_model = _assert_task_or_model_exists( + company_id, task_id, model_events=model_events, )[0] result = event_bll.get_task_events( - task.get_index_company(), + tasks_or_model.get_index_company(), task_id, event_type=EventType.metrics_image, sort=[{"iter": {"order": "desc"}}], last_iter_count=iters, scroll_id=scroll_id, + model_events=model_events, ) return_events = result.events @@ -768,21 +796,12 @@ def get_debug_images(call, company_id, request: MetricEventsRequest): if None in metrics: metrics.clear() - tasks = task_bll.assert_exists( - company_id, - task_ids=list(task_metrics), - allow_public=True, - only=("company", "company_origin"), + company, _ = _get_task_or_model_index_company( + company_id, 
task_ids=list(task_metrics), model_events=request.model_events ) - companies = {t.get_index_company() for t in tasks} - if len(companies) > 1: - raise errors.bad_request.InvalidTaskId( - "only tasks from the same company are supported" - ) - result = event_bll.debug_images_iterator.get_task_events( - company_id=next(iter(companies)), + company_id=company, task_metrics=task_metrics, iter_count=request.iters, navigate_earlier=request.navigate_earlier, @@ -811,11 +830,11 @@ def get_debug_images(call, company_id, request: MetricEventsRequest): request_data_model=GetHistorySampleRequest, ) def get_debug_image_sample(call, company_id, request: GetHistorySampleRequest): - task = task_bll.assert_exists( - company_id, task_ids=[request.task], allow_public=True, only=("company",) + task_or_model = _assert_task_or_model_exists( + company_id, request.task, model_events=request.model_events, )[0] res = event_bll.debug_image_sample_history.get_sample_for_variant( - company_id=task.company, + company_id=task_or_model.get_index_company(), task=request.task, metric=request.metric, variant=request.variant, @@ -833,11 +852,11 @@ def get_debug_image_sample(call, company_id, request: GetHistorySampleRequest): request_data_model=NextHistorySampleRequest, ) def next_debug_image_sample(call, company_id, request: NextHistorySampleRequest): - task = task_bll.assert_exists( - company_id, task_ids=[request.task], allow_public=True, only=("company",) + task_or_model = _assert_task_or_model_exists( + company_id, request.task, model_events=request.model_events, )[0] res = event_bll.debug_image_sample_history.get_next_sample( - company_id=task.company, + company_id=task_or_model.get_index_company(), task=request.task, state_id=request.scroll_id, navigate_earlier=request.navigate_earlier, @@ -849,11 +868,11 @@ def next_debug_image_sample(call, company_id, request: NextHistorySampleRequest) "events.get_plot_sample", request_data_model=GetHistorySampleRequest, ) def get_plot_sample(call, company_id, 
request: GetHistorySampleRequest): - task = task_bll.assert_exists( - company_id, task_ids=[request.task], allow_public=True, only=("company",) + task_or_model = _assert_task_or_model_exists( + company_id, request.task, model_events=request.model_events, )[0] res = event_bll.plot_sample_history.get_sample_for_variant( - company_id=task.company, + company_id=task_or_model.get_index_company(), task=request.task, metric=request.metric, variant=request.variant, @@ -869,11 +888,11 @@ def get_plot_sample(call, company_id, request: GetHistorySampleRequest): "events.next_plot_sample", request_data_model=NextHistorySampleRequest, ) def next_plot_sample(call, company_id, request: NextHistorySampleRequest): - task = task_bll.assert_exists( - company_id, task_ids=[request.task], allow_public=True, only=("company",) + task_or_model = _assert_task_or_model_exists( + company_id, request.task, model_events=request.model_events, )[0] res = event_bll.plot_sample_history.get_next_sample( - company_id=task.company, + company_id=task_or_model.get_index_company(), task=request.task, state_id=request.scroll_id, navigate_earlier=request.navigate_earlier, @@ -883,14 +902,11 @@ def next_plot_sample(call, company_id, request: NextHistorySampleRequest): @endpoint("events.get_task_metrics", request_data_model=TaskMetricsRequest) def get_task_metrics(call: APICall, company_id, request: TaskMetricsRequest): - task = task_bll.assert_exists( - company_id, - task_ids=request.tasks, - allow_public=True, - only=("company", "company_origin"), - )[0] + company, _ = _get_task_or_model_index_company( + company_id, request.tasks, model_events=request.model_events + ) res = event_bll.metrics.get_task_metrics( - task.get_index_company(), task_ids=request.tasks, event_type=request.event_type + company, task_ids=request.tasks, event_type=request.event_type ) call.result.data = { "metrics": [{"task": task, "metrics": metrics} for (task, metrics) in res] @@ -898,7 +914,7 @@ def get_task_metrics(call: APICall, 
company_id, request: TaskMetricsRequest): @endpoint("events.delete_for_task", required_fields=["task"]) -def delete_for_task(call, company_id, req_model): +def delete_for_task(call, company_id, _): task_id = call.data["task"] allow_locked = call.data.get("allow_locked", False) @@ -910,6 +926,19 @@ def delete_for_task(call, company_id, req_model): ) +@endpoint("events.delete_for_model", required_fields=["model"]) +def delete_for_model(call: APICall, company_id: str, _): + model_id = call.data["model"] + allow_locked = call.data.get("allow_locked", False) + + model_bll.assert_exists(company_id, model_id, return_models=False) + call.result.data = dict( + deleted=event_bll.delete_task_events( + company_id, model_id, allow_locked=allow_locked, model=True + ) + ) + + @endpoint("events.clear_task_log") def clear_task_log(call: APICall, company_id: str, request: ClearTaskLogRequest): task_id = request.task @@ -1004,17 +1033,13 @@ def scalar_metrics_iter_raw( request.batch_size = request.batch_size or scroll.request.batch_size task_id = request.task - - task = task_bll.assert_exists( - company_id, task_id, allow_public=True, only=("company",), - )[0] - + task_or_model = _assert_task_or_model_exists(company_id, task_id, model_events=request.model_events)[0] metric_variants = _get_metric_variants_from_request([request.metric]) if request.count_total and total is None: total = event_bll.events_iterator.count_task_events( event_type=EventType.metrics_scalar, - company_id=task.company, + company_id=task_or_model.get_index_company(), task_id=task_id, metric_variants=metric_variants, ) @@ -1030,7 +1055,7 @@ def scalar_metrics_iter_raw( for iteration in range(0, math.ceil(batch_size / 10_000)): res = event_bll.events_iterator.get_task_events( event_type=EventType.metrics_scalar, - company_id=task.company, + company_id=task_or_model.get_index_company(), task_id=task_id, batch_size=min(batch_size, 10_000), navigate_earlier=False, diff --git 
a/apiserver/tests/automated/test_task_events.py b/apiserver/tests/automated/test_task_events.py index fea05a3..5a9b7e1 100644 --- a/apiserver/tests/automated/test_task_events.py +++ b/apiserver/tests/automated/test_task_events.py @@ -6,17 +6,24 @@ from typing import Sequence, Optional, Tuple from boltons.iterutils import first +from apiserver.apierrors import errors from apiserver.es_factory import es_factory from apiserver.apierrors.errors.bad_request import EventsNotAdded from apiserver.tests.automated import TestService class TestTaskEvents(TestService): + delete_params = dict(can_fail=True, force=True) + def _temp_task(self, name="test task events"): task_input = dict( name=name, type="training", input=dict(mapping={}, view=dict(entries=[])), ) - return self.create_temp("tasks", **task_input) + return self.create_temp("tasks", delete_params=self.delete_params, **task_input) + + def _temp_model(self, name="test model events", **kwargs): + self.update_missing(kwargs, name=name, uri="file:///a/b", labels={}) + return self.create_temp("models", delete_params=self.delete_params, **kwargs) @staticmethod def _create_task_event(type_, task, iteration, **kwargs): @@ -172,6 +179,42 @@ class TestTaskEvents(TestService): self.assertEqual(iter_count - 1, metric_data.max_value) self.assertEqual(0, metric_data.min_value) + def test_model_events(self): + model = self._temp_model(ready=False) + + # task log events are not allowed + log_event = self._create_task_event( + "log", + task=model, + iteration=0, + msg="This is a log message", + model_event=True, + ) + with self.api.raises(errors.bad_request.EventsNotAdded): + self.send(log_event) + + # send metric events and check that model data always have iteration 0 and only last data is saved + events = [ + { + **self._create_task_event("training_stats_scalar", model, iteration), + "metric": f"Metric{metric_idx}", + "variant": f"Variant{variant_idx}", + "value": iteration, + "model_event": True, + } + for iteration in range(2) 
+ for metric_idx in range(5) + for variant_idx in range(5) + ] + self.send_batch(events) + data = self.api.events.scalar_metrics_iter_histogram(task=model, model_events=True) + self.assertEqual(list(data), [f"Metric{idx}" for idx in range(5)]) + metric_data = data.Metric0 + self.assertEqual(list(metric_data), [f"Variant{idx}" for idx in range(5)]) + variant_data = metric_data.Variant0 + self.assertEqual(variant_data.x, [0]) + self.assertEqual(variant_data.y, [1.0]) + def test_error_events(self): task = self._temp_task() events = [ @@ -555,7 +598,8 @@ class TestTaskEvents(TestService): return data def send(self, event): - self.api.send("events.add", event) + _, data = self.api.send("events.add", event) + return data if __name__ == "__main__": diff --git a/apiserver/tests/automated/test_tasks_delete.py b/apiserver/tests/automated/test_tasks_delete.py index 019c123..df16751 100644 --- a/apiserver/tests/automated/test_tasks_delete.py +++ b/apiserver/tests/automated/test_tasks_delete.py @@ -50,10 +50,12 @@ class TestTasksResetDelete(TestService): self.assertEqual(res.urls.artifact_urls, []) task = self.new_task() + (_, published_model_urls), (model, draft_model_urls) = self.create_task_models(task) - published_model_urls, draft_model_urls = self.create_task_models(task) artifact_urls = self.send_artifacts(task) event_urls = self.send_debug_image_events(task) event_urls.update(self.send_plot_events(task)) + event_urls.update(self.send_model_events(model)) res = self.assert_delete_task(task, force=True, return_file_urls=True) self.assertEqual(set(res.urls.model_urls), draft_model_urls) self.assertEqual(set(res.urls.event_urls), event_urls) @@ -120,10 +122,12 @@ class TestTasksResetDelete(TestService): self, **kwargs ) -> Tuple[str, Tuple[Set[str], Set[str]], Set[str], Set[str]]: task = self.new_task(**kwargs) + (_, published_model_urls), (model, draft_model_urls) = self.create_task_models(task, **kwargs) - published_model_urls, draft_model_urls = self.create_task_models(task, 
**kwargs) artifact_urls = self.send_artifacts(task) event_urls = self.send_debug_image_events(task) event_urls.update(self.send_plot_events(task)) + event_urls.update(self.send_model_events(model)) return task, (published_model_urls, draft_model_urls), artifact_urls, event_urls def assert_delete_task(self, task_id, force=False, return_file_urls=False): @@ -137,15 +141,17 @@ class TestTasksResetDelete(TestService): self.assertEqual(tasks, []) return res - def create_task_models(self, task, **kwargs) -> Tuple[Set[str], Set[str]]: + def create_task_models(self, task, **kwargs) -> Tuple: """ Update models from task and return only non public models """ - model_ready = self.new_model(uri="ready", **kwargs) - model_not_ready = self.new_model(uri="not_ready", ready=False, **kwargs) + ready_uri = "ready" + not_ready_uri = "not_ready" + model_ready = self.new_model(uri=ready_uri, **kwargs) + model_not_ready = self.new_model(uri=not_ready_uri, ready=False, **kwargs) self.api.models.edit(model=model_not_ready, task=task) self.api.models.edit(model=model_ready, task=task) - return {"ready"}, {"not_ready"} + return (model_ready, {ready_uri}), (model_not_ready, {not_ready_uri}) def send_artifacts(self, task) -> Set[str]: """ @@ -159,6 +165,20 @@ class TestTasksResetDelete(TestService): self.api.tasks.add_or_update_artifacts(task=task, artifacts=artifacts) return {"test2"} + def send_model_events(self, model) -> Set[str]: + url1 = "http://link1" + url2 = "http://link2" + events = [ + self.create_event( + model, "training_debug_image", 0, url=url1, model_event=True + ), + self.create_event( + model, "plot", 0, plot_str=f'{{"source": "{url2}"}}', model_event=True + ) + ] + self.send_batch(events) + return {url1, url2} + def send_debug_image_events(self, task) -> Set[str]: events = [ self.create_event(