Use EventType enum instead of string

This commit is contained in:
allegroai 2021-01-05 18:21:11 +02:00
parent 6974aa3a99
commit 4707647c92
8 changed files with 239 additions and 146 deletions

View File

@ -19,6 +19,7 @@ from apiserver.bll.event.event_common import (
EventSettings,
check_empty_data,
search_company_events,
EventType,
)
from apiserver.bll.redis_cache_manager import RedisCacheManager
from apiserver.database.errors import translate_errors_context
@ -59,7 +60,7 @@ class DebugImagesResult(object):
class DebugImagesIterator:
EVENT_TYPE = "training_debug_image"
EVENT_TYPE = EventType.metrics_image
def __init__(self, redis: StrictRedis, es: Elasticsearch):
self.es = es
@ -142,10 +143,10 @@ class DebugImagesIterator:
return [
(
(task.id, stats.metric),
stats.event_stats_by_type[self.EVENT_TYPE].last_update,
stats.event_stats_by_type[self.EVENT_TYPE.value].last_update,
)
for stats in metric_stats.values()
if self.EVENT_TYPE in stats.event_stats_by_type
if self.EVENT_TYPE.value in stats.event_stats_by_type
]
update_times = dict(
@ -257,7 +258,6 @@ class DebugImagesIterator:
company_id=company_id,
event_type=self.EVENT_TYPE,
body=es_req,
routing=task,
)
if "aggregations" not in es_res:
return []
@ -401,7 +401,6 @@ class DebugImagesIterator:
company_id=company_id,
event_type=self.EVENT_TYPE,
body=es_req,
routing=metric.task,
)
if "aggregations" not in es_res:
return metric.task, metric.name, []

View File

@ -10,7 +10,12 @@ from redis import StrictRedis
from apiserver.apierrors import errors
from apiserver.apimodels import JsonSerializableMixin
from apiserver.bll.event.event_common import EventSettings, get_index_name
from apiserver.bll.event.event_common import (
EventSettings,
EventType,
check_empty_data,
search_company_events,
)
from apiserver.bll.redis_cache_manager import RedisCacheManager
from apiserver.database.errors import translate_errors_context
from apiserver.timing_context import TimingContext
@ -44,7 +49,7 @@ class DebugSampleHistoryResult(object):
class DebugSampleHistory:
EVENT_TYPE = "training_debug_image"
EVENT_TYPE = EventType.metrics_image
def __init__(self, redis: StrictRedis, es: Elasticsearch):
self.es = es
@ -66,14 +71,13 @@ class DebugSampleHistory:
if not state or state.task != task:
raise errors.bad_request.InvalidScrollId(scroll_id=state_id)
es_index = get_index_name(company_id, self.EVENT_TYPE)
if not self.es.indices.exists(es_index):
if check_empty_data(self.es, company_id=company_id, event_type=self.EVENT_TYPE):
return res
image = self._get_next_for_current_iteration(
es_index=es_index, navigate_earlier=navigate_earlier, state=state
company_id=company_id, navigate_earlier=navigate_earlier, state=state
) or self._get_next_for_another_iteration(
es_index=es_index, navigate_earlier=navigate_earlier, state=state
company_id=company_id, navigate_earlier=navigate_earlier, state=state
)
if not image:
return res
@ -94,7 +98,7 @@ class DebugSampleHistory:
res.max_iteration = var_state.max_iteration
def _get_next_for_current_iteration(
self, es_index: str, navigate_earlier: bool, state: DebugSampleHistoryState
self, company_id: str, navigate_earlier: bool, state: DebugSampleHistoryState
) -> Optional[dict]:
"""
Get the image for the next (if navigate_earlier is False) or previous variant, sorted by name, for the same iteration
@ -127,7 +131,9 @@ class DebugSampleHistory:
with translate_errors_context(), TimingContext(
"es", "get_next_for_current_iteration"
):
es_res = self.es.search(index=es_index, body=es_req, routing=state.task)
es_res = search_company_events(
self.es, company_id=company_id, event_type=self.EVENT_TYPE, body=es_req
)
hits = nested_get(es_res, ("hits", "hits"))
if not hits:
@ -136,7 +142,7 @@ class DebugSampleHistory:
return hits[0]["_source"]
def _get_next_for_another_iteration(
self, es_index: str, navigate_earlier: bool, state: DebugSampleHistoryState
self, company_id: str, navigate_earlier: bool, state: DebugSampleHistoryState
) -> Optional[dict]:
"""
Get the image for the first variant for the next iteration (if navigate_earlier is set to False)
@ -189,7 +195,9 @@ class DebugSampleHistory:
with translate_errors_context(), TimingContext(
"es", "get_next_for_another_iteration"
):
es_res = self.es.search(index=es_index, body=es_req, routing=state.task)
es_res = search_company_events(
self.es, company_id=company_id, event_type=self.EVENT_TYPE, body=es_req
)
hits = nested_get(es_res, ("hits", "hits"))
if not hits:
@ -212,14 +220,13 @@ class DebugSampleHistory:
If the iteration is not passed then get the latest event
"""
res = DebugSampleHistoryResult()
es_index = get_index_name(company_id, self.EVENT_TYPE)
if not self.es.indices.exists(es_index):
if check_empty_data(self.es, company_id=company_id, event_type=self.EVENT_TYPE):
return res
def init_state(state_: DebugSampleHistoryState):
state_.task = task
state_.metric = metric
self._reset_variant_states(es_index, state=state_)
self._reset_variant_states(company_id=company_id, state=state_)
def validate_state(state_: DebugSampleHistoryState):
if state_.task != task or state_.metric != metric:
@ -228,7 +235,7 @@ class DebugSampleHistory:
scroll_id=state_.id,
)
if refresh:
self._reset_variant_states(es_index, state=state_)
self._reset_variant_states(company_id=company_id, state=state_)
state: DebugSampleHistoryState
with self.cache_manager.get_or_create_state(
@ -271,7 +278,12 @@ class DebugSampleHistory:
with translate_errors_context(), TimingContext(
"es", "get_debug_image_for_variant"
):
es_res = self.es.search(index=es_index, body=es_req, routing=task)
es_res = search_company_events(
self.es,
company_id=company_id,
event_type=self.EVENT_TYPE,
body=es_req,
)
hits = nested_get(es_res, ("hits", "hits"))
if not hits:
@ -282,9 +294,9 @@ class DebugSampleHistory:
)
return res
def _reset_variant_states(self, es_index, state: DebugSampleHistoryState):
def _reset_variant_states(self, company_id: str, state: DebugSampleHistoryState):
variant_iterations = self._get_variant_iterations(
es_index=es_index, task=state.task, metric=state.metric
company_id=company_id, task=state.task, metric=state.metric
)
state.variant_states = [
VariantState(name=var_name, min_iteration=min_iter, max_iteration=max_iter)
@ -293,7 +305,7 @@ class DebugSampleHistory:
def _get_variant_iterations(
self,
es_index: str,
company_id: str,
task: str,
metric: str,
variants: Optional[Sequence[str]] = None,
@ -344,7 +356,9 @@ class DebugSampleHistory:
with translate_errors_context(), TimingContext(
"es", "get_debug_image_iterations"
):
es_res = self.es.search(index=es_index, body=es_req, routing=task)
es_res = search_company_events(
self.es, company_id=company_id, event_type=self.EVENT_TYPE, body=es_req
)
def get_variant_data(variant_bucket: dict) -> Tuple[str, int, int]:
variant = variant_bucket["key"]

View File

@ -13,7 +13,14 @@ from mongoengine import Q
from nested_dict import nested_dict
from apiserver.bll.event.debug_sample_history import DebugSampleHistory
from apiserver.bll.event.event_common import EventType, EventSettings, get_index_name
from apiserver.bll.event.event_common import (
EventType,
EventSettings,
get_index_name,
check_empty_data,
search_company_events,
delete_company_events,
)
from apiserver.bll.util import parallel_chunked_decorator
from apiserver.database import utils as dbutils
from apiserver.es_factory import es_factory
@ -156,7 +163,7 @@ class EventBLL(object):
}
# for "log" events, don't assign a custom _id - whatever is sent, is written (not overwritten)
if event_type != "log":
if event_type != EventType.task_log.value:
es_action["_id"] = self._get_event_id(event)
else:
es_action["_id"] = dbutils.id()
@ -389,10 +396,10 @@ class EventBLL(object):
def scroll_task_events(
self,
company_id,
task_id,
order,
event_type=None,
company_id: str,
task_id: str,
order: str,
event_type: EventType,
batch_size=10000,
scroll_id=None,
):
@ -404,12 +411,7 @@ class EventBLL(object):
es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
else:
size = min(batch_size, 10000)
if event_type is None:
event_type = "*"
es_index = get_index_name(company_id, event_type)
if not self.es.indices.exists(es_index):
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
return [], None, 0
es_req = {
@ -419,15 +421,25 @@ class EventBLL(object):
}
with translate_errors_context(), TimingContext("es", "scroll_task_events"):
es_res = self.es.search(index=es_index, body=es_req, scroll="1h")
es_res = search_company_events(
self.es,
company_id=company_id,
event_type=event_type,
body=es_req,
scroll="1h",
)
events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
return events, next_scroll_id, total_events
def get_last_iterations_per_event_metric_variant(
self, es_index: str, task_id: str, num_last_iterations: int, event_type: str
self,
company_id: str,
task_id: str,
num_last_iterations: int,
event_type: EventType,
):
if not self.es.indices.exists(es_index):
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
return []
es_req: dict = {
@ -461,13 +473,14 @@ class EventBLL(object):
},
"query": {"bool": {"must": [{"term": {"task": task_id}}]}},
}
if event_type:
es_req["query"]["bool"]["must"].append({"term": {"type": event_type}})
with translate_errors_context(), TimingContext(
"es", "task_last_iter_metric_variant"
):
es_res = self.es.search(index=es_index, body=es_req)
es_res = search_company_events(
self.es, company_id=company_id, event_type=event_type, body=es_req
)
if "aggregations" not in es_res:
return []
@ -494,9 +507,8 @@ class EventBLL(object):
with translate_errors_context(), TimingContext("es", "get_task_events"):
es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
else:
event_type = "plot"
es_index = get_index_name(company_id, event_type)
if not self.es.indices.exists(es_index):
event_type = EventType.metrics_plot
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
return TaskEventsResult()
plot_valid_condition = {
@ -519,7 +531,10 @@ class EventBLL(object):
should = []
for i, task_id in enumerate(tasks):
last_iters = self.get_last_iterations_per_event_metric_variant(
es_index, task_id, last_iterations_per_plot, event_type
company_id=company_id,
task_id=task_id,
num_last_iterations=last_iterations_per_plot,
event_type=event_type,
)
if not last_iters:
continue
@ -551,8 +566,13 @@ class EventBLL(object):
}
with translate_errors_context(), TimingContext("es", "get_task_plots"):
es_res = self.es.search(
index=es_index, body=es_req, ignore=404, scroll="1h",
es_res = search_company_events(
self.es,
company_id=company_id,
event_type=event_type,
body=es_req,
ignore=404,
scroll="1h",
)
events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
@ -577,9 +597,9 @@ class EventBLL(object):
def get_task_events(
self,
company_id,
task_id,
event_type=None,
company_id: str,
task_id: str,
event_type: EventType,
metric=None,
variant=None,
last_iter_count=None,
@ -595,11 +615,8 @@ class EventBLL(object):
es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
else:
task_ids = [task_id] if isinstance(task_id, six.string_types) else task_id
if event_type is None:
event_type = "*"
es_index = get_index_name(company_id, event_type)
if not self.es.indices.exists(es_index):
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
return TaskEventsResult()
must = []
@ -614,7 +631,10 @@ class EventBLL(object):
should = []
for i, task_id in enumerate(task_ids):
last_iters = self.get_last_iters(
es_index, task_id, event_type, last_iter_count
company_id=company_id,
event_type=event_type,
task_id=task_id,
iters=last_iter_count,
)
if not last_iters:
continue
@ -642,8 +662,13 @@ class EventBLL(object):
}
with translate_errors_context(), TimingContext("es", "get_task_events"):
es_res = self.es.search(
index=es_index, body=es_req, ignore=404, scroll="1h",
es_res = search_company_events(
self.es,
company_id=company_id,
event_type=event_type,
body=es_req,
ignore=404,
scroll="1h",
)
events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
@ -651,11 +676,10 @@ class EventBLL(object):
events=events, next_scroll_id=next_scroll_id, total_events=total_events
)
def get_metrics_and_variants(self, company_id, task_id, event_type):
es_index = get_index_name(company_id, event_type)
if not self.es.indices.exists(es_index):
def get_metrics_and_variants(
self, company_id: str, task_id: str, event_type: EventType
):
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
return {}
es_req = {
@ -684,7 +708,9 @@ class EventBLL(object):
with translate_errors_context(), TimingContext(
"es", "events_get_metrics_and_variants"
):
es_res = self.es.search(index=es_index, body=es_req)
es_res = search_company_events(
self.es, company_id=company_id, event_type=event_type, body=es_req
)
metrics = {}
for metric_bucket in es_res["aggregations"]["metrics"].get("buckets"):
@ -695,10 +721,9 @@ class EventBLL(object):
return metrics
def get_task_latest_scalar_values(self, company_id, task_id):
es_index = get_index_name(company_id, "training_stats_scalar")
if not self.es.indices.exists(es_index):
def get_task_latest_scalar_values(self, company_id: str, task_id: str):
event_type = EventType.metrics_scalar
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
return {}
es_req = {
@ -753,7 +778,9 @@ class EventBLL(object):
with translate_errors_context(), TimingContext(
"es", "events_get_metrics_and_variants"
):
es_res = self.es.search(index=es_index, body=es_req)
es_res = search_company_events(
self.es, company_id=company_id, event_type=event_type, body=es_req
)
metrics = []
max_timestamp = 0
@ -780,9 +807,8 @@ class EventBLL(object):
return metrics, max_timestamp
def get_vector_metrics_per_iter(self, company_id, task_id, metric, variant):
es_index = get_index_name(company_id, "training_stats_vector")
if not self.es.indices.exists(es_index):
event_type = EventType.metrics_vector
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
return [], []
es_req = {
@ -800,7 +826,9 @@ class EventBLL(object):
"sort": ["iter"],
}
with translate_errors_context(), TimingContext("es", "task_stats_vector"):
es_res = self.es.search(index=es_index, body=es_req)
es_res = search_company_events(
self.es, company_id=company_id, event_type=event_type, body=es_req
)
vectors = []
iterations = []
@ -810,8 +838,10 @@ class EventBLL(object):
return iterations, vectors
def get_last_iters(self, es_index, task_id, event_type, iters):
if not self.es.indices.exists(es_index):
def get_last_iters(
self, company_id: str, event_type: EventType, task_id: str, iters: int
):
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
return []
es_req: dict = {
@ -827,11 +857,12 @@ class EventBLL(object):
},
"query": {"bool": {"must": [{"term": {"task": task_id}}]}},
}
if event_type:
es_req["query"]["bool"]["must"].append({"term": {"type": event_type}})
with translate_errors_context(), TimingContext("es", "task_last_iter"):
es_res = self.es.search(index=es_index, body=es_req)
es_res = search_company_events(
self.es, company_id=company_id, event_type=event_type, body=es_req
)
if "aggregations" not in es_res:
return []
@ -850,9 +881,14 @@ class EventBLL(object):
extra_msg, company=company_id, id=task_id
)
es_index = get_index_name(company_id, "*")
es_req = {"query": {"term": {"task": task_id}}}
with translate_errors_context(), TimingContext("es", "delete_task_events"):
es_res = self.es.delete_by_query(index=es_index, body=es_req, refresh=True)
es_res = delete_company_events(
es=self.es,
company_id=company_id,
event_type=EventType.all,
body=es_req,
refresh=True,
)
return es_res.get("deleted", 0)

View File

@ -13,6 +13,7 @@ class EventType(Enum):
metrics_image = "training_debug_image"
metrics_plot = "plot"
task_log = "log"
all = "*"
class EventSettings:
@ -40,8 +41,8 @@ def get_index_name(company_id: str, event_type: str):
return f"events-{event_type}-{company_id}"
def check_empty_data(es: Elasticsearch, company_id: str, event_type: str) -> bool:
es_index = get_index_name(company_id, event_type)
def check_empty_data(es: Elasticsearch, company_id: str, event_type: EventType) -> bool:
es_index = get_index_name(company_id, event_type.value)
if not es.indices.exists(es_index):
return True
return False
@ -50,16 +51,16 @@ def check_empty_data(es: Elasticsearch, company_id: str, event_type: str) -> boo
def search_company_events(
es: Elasticsearch,
company_id: Union[str, Sequence[str]],
event_type: str,
event_type: EventType,
body: dict,
**kwargs,
) -> dict:
es_index = get_index_name(company_id, event_type)
es_index = get_index_name(company_id, event_type.value)
return es.search(index=es_index, body=body, **kwargs)
def delete_company_events(
es: Elasticsearch, company_id: str, event_type: str, body: dict, **kwargs
es: Elasticsearch, company_id: str, event_type: EventType, body: dict, **kwargs
) -> dict:
es_index = get_index_name(company_id, event_type)
es_index = get_index_name(company_id, event_type.value)
return es.delete_by_query(index=es_index, body=body, **kwargs)

View File

@ -10,7 +10,12 @@ from elasticsearch import Elasticsearch
from mongoengine import Q
from apiserver.apierrors import errors
from apiserver.bll.event.event_common import EventType, get_index_name, EventSettings
from apiserver.bll.event.event_common import (
EventType,
EventSettings,
search_company_events,
check_empty_data,
)
from apiserver.bll.event.scalar_key import ScalarKey, ScalarKeyEnum
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
@ -36,31 +41,40 @@ class EventMetrics:
The amount of points in each histogram should not exceed
the requested samples
"""
es_index = get_index_name(company_id, "training_stats_scalar")
if not self.es.indices.exists(es_index):
event_type = EventType.metrics_scalar
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
return {}
return self._get_scalar_average_per_iter_core(
task_id, es_index, samples, ScalarKey.resolve(key)
task_id, company_id, event_type, samples, ScalarKey.resolve(key)
)
def _get_scalar_average_per_iter_core(
self,
task_id: str,
es_index: str,
company_id: str,
event_type: EventType,
samples: int,
key: ScalarKey,
run_parallel: bool = True,
) -> dict:
intervals = self._get_task_metric_intervals(
es_index=es_index, task_id=task_id, samples=samples, field=key.field
company_id=company_id,
event_type=event_type,
task_id=task_id,
samples=samples,
field=key.field,
)
if not intervals:
return {}
interval_groups = self._group_task_metric_intervals(intervals)
get_scalar_average = partial(
self._get_scalar_average, task_id=task_id, es_index=es_index, key=key
self._get_scalar_average,
task_id=task_id,
company_id=company_id,
event_type=event_type,
key=key,
)
if run_parallel:
with ThreadPoolExecutor(max_workers=EventSettings.max_workers) as pool:
@ -110,13 +124,15 @@ class EventMetrics:
"only tasks from the same company are supported"
)
es_index = get_index_name(next(iter(companies)), "training_stats_scalar")
if not self.es.indices.exists(es_index):
event_type = EventType.metrics_scalar
company_id = next(iter(companies))
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
return {}
get_scalar_average_per_iter = partial(
self._get_scalar_average_per_iter_core,
es_index=es_index,
company_id=company_id,
event_type=event_type,
samples=samples,
key=ScalarKey.resolve(key),
run_parallel=False,
@ -175,7 +191,12 @@ class EventMetrics:
return metric_interval_groups
def _get_task_metric_intervals(
self, es_index, task_id: str, samples: int, field: str = "iter"
self,
company_id: str,
event_type: EventType,
task_id: str,
samples: int,
field: str = "iter",
) -> Sequence[MetricInterval]:
"""
Calculate interval per task metric variant so that the resulting
@ -212,7 +233,9 @@ class EventMetrics:
}
with translate_errors_context(), TimingContext("es", "task_stats_get_interval"):
es_res = self.es.search(index=es_index, body=es_req)
es_res = search_company_events(
self.es, company_id=company_id, event_type=event_type, body=es_req,
)
aggs_result = es_res.get("aggregations")
if not aggs_result:
@ -255,7 +278,8 @@ class EventMetrics:
self,
metrics_interval: MetricIntervalGroup,
task_id: str,
es_index: str,
company_id: str,
event_type: EventType,
key: ScalarKey,
) -> Sequence[MetricData]:
"""
@ -283,7 +307,11 @@ class EventMetrics:
}
}
aggs_result = self._query_aggregation_for_task_metrics(
es_index, aggs=aggs, task_id=task_id, metrics=metrics
company_id=company_id,
event_type=event_type,
aggs=aggs,
task_id=task_id,
metrics=metrics,
)
if not aggs_result:
@ -314,7 +342,8 @@ class EventMetrics:
def _query_aggregation_for_task_metrics(
self,
es_index: str,
company_id: str,
event_type: EventType,
aggs: dict,
task_id: str,
metrics: Sequence[Tuple[str, str]],
@ -345,7 +374,9 @@ class EventMetrics:
}
with translate_errors_context(), TimingContext("es", "task_stats_scalar"):
es_res = self.es.search(index=es_index, body=es_req)
es_res = search_company_events(
self.es, company_id=company_id, event_type=event_type, body=es_req,
)
return es_res.get("aggregations")
@ -356,30 +387,26 @@ class EventMetrics:
For the requested tasks return all the metrics that
reported events of the requested types
"""
es_index = get_index_name(company_id, event_type.value)
if not self.es.indices.exists(es_index):
if check_empty_data(self.es, company_id, event_type):
return {}
with ThreadPoolExecutor(EventSettings.max_workers) as pool:
res = pool.map(
partial(
self._get_task_metrics, es_index=es_index, event_type=event_type,
self._get_task_metrics,
company_id=company_id,
event_type=event_type,
),
task_ids,
)
return list(zip(task_ids, res))
def _get_task_metrics(self, task_id, es_index, event_type: EventType) -> Sequence:
def _get_task_metrics(
self, task_id: str, company_id: str, event_type: EventType
) -> Sequence:
es_req = {
"size": 0,
"query": {
"bool": {
"must": [
{"term": {"task": task_id}},
{"term": {"type": event_type.value}},
]
}
},
"query": {"bool": {"must": [{"term": {"task": task_id}}]}},
"aggs": {
"metrics": {
"terms": {
@ -392,7 +419,9 @@ class EventMetrics:
}
with translate_errors_context(), TimingContext("es", "_get_task_metrics"):
es_res = self.es.search(index=es_index, body=es_req)
es_res = search_company_events(
self.es, company_id=company_id, event_type=event_type, body=es_req
)
return [
metric["key"]

View File

@ -3,7 +3,11 @@ from typing import Optional, Tuple, Sequence
import attr
from elasticsearch import Elasticsearch
from apiserver.bll.event.event_common import check_empty_data, search_company_events
from apiserver.bll.event.event_common import (
check_empty_data,
search_company_events,
EventType,
)
from apiserver.database.errors import translate_errors_context
from apiserver.timing_context import TimingContext
@ -16,7 +20,7 @@ class TaskEventsResult:
class LogEventsIterator:
EVENT_TYPE = "log"
EVENT_TYPE = EventType.task_log
def __init__(self, es: Elasticsearch):
self.es = es
@ -75,7 +79,6 @@ class LogEventsIterator:
company_id=company_id,
event_type=self.EVENT_TYPE,
body=es_req,
routing=task_id,
)
hits = es_result["hits"]["hits"]
hits_total = es_result["hits"]["total"]["value"]
@ -102,7 +105,6 @@ class LogEventsIterator:
company_id=company_id,
event_type=self.EVENT_TYPE,
body=es_req,
routing=task_id,
)
last_second_hits = es_result["hits"]["hits"]
if not last_second_hits or len(last_second_hits) < 2:

View File

@ -32,6 +32,7 @@ from furl import furl
from mongoengine import Q
from apiserver.bll.event import EventBLL
from apiserver.bll.event.event_common import EventType
from apiserver.bll.task.artifacts import get_artifact_id
from apiserver.bll.task.param_utils import (
split_param_name,
@ -532,19 +533,22 @@ class PrePopulate:
scroll_id = None
while True:
res = cls.event_bll.get_task_events(
task.company, task.id, scroll_id=scroll_id
company_id=task.company,
task_id=task.id,
event_type=EventType.all,
scroll_id=scroll_id,
)
if not res.events:
break
scroll_id = res.next_scroll_id
for event in res.events:
event_type = event.get("type")
if event_type == "training_debug_image":
if event_type == EventType.metrics_image.value:
url = cls._get_fixed_url(event.get("url"))
if url:
event["url"] = url
artifacts.append(url)
elif event_type == "plot":
elif event_type == EventType.metrics_plot.value:
plot_str: str = event.get("plot_str", "")
for match in cls.img_source_regex.findall(plot_str):
url = cls._get_fixed_url(match)

View File

@ -19,7 +19,7 @@ from apiserver.apimodels.events import (
NextDebugImageSampleRequest,
)
from apiserver.bll.event import EventBLL
from apiserver.bll.event.event_common import get_index_name
from apiserver.bll.event.event_common import EventType
from apiserver.bll.task import TaskBLL
from apiserver.service_repo import APICall, endpoint
from apiserver.utilities import json
@ -63,7 +63,7 @@ def get_task_log_v1_5(call, company_id, _):
task.get_index_company(),
task_id,
order,
event_type="log",
event_type=EventType.task_log,
batch_size=batch_size,
scroll_id=scroll_id,
)
@ -90,7 +90,7 @@ def get_task_log_v1_7(call, company_id, _):
company_id=task.get_index_company(),
task_id=task_id,
order=scroll_order,
event_type="log",
event_type=EventType.task_log,
batch_size=batch_size,
scroll_id=scroll_id,
)
@ -118,11 +118,9 @@ def get_task_log(call, company_id, request: LogEventsRequest):
from_timestamp=request.from_timestamp,
)
if (
request.order and (
(request.navigate_earlier and request.order == LogOrderEnum.asc)
or (not request.navigate_earlier and request.order == LogOrderEnum.desc)
)
if request.order and (
(request.navigate_earlier and request.order == LogOrderEnum.asc)
or (not request.navigate_earlier and request.order == LogOrderEnum.desc)
):
res.events.reverse()
@ -182,7 +180,7 @@ def download_task_log(call, company_id, _):
task.get_index_company(),
task_id,
order="asc",
event_type="log",
event_type=EventType.task_log,
batch_size=batch_size,
scroll_id=scroll_id,
)
@ -219,7 +217,7 @@ def get_vector_metrics_and_variants(call, company_id, _):
)[0]
call.result.data = dict(
metrics=event_bll.get_metrics_and_variants(
task.get_index_company(), task_id, "training_stats_vector"
task.get_index_company(), task_id, EventType.metrics_vector
)
)
@ -232,7 +230,7 @@ def get_scalar_metrics_and_variants(call, company_id, _):
)[0]
call.result.data = dict(
metrics=event_bll.get_metrics_and_variants(
task.get_index_company(), task_id, "training_stats_scalar"
task.get_index_company(), task_id, EventType.metrics_scalar
)
)
@ -272,7 +270,7 @@ def get_task_events(call, company_id, _):
task.get_index_company(),
task_id,
sort=[{"timestamp": {"order": order}}],
event_type=event_type,
event_type=EventType(event_type) if event_type else EventType.all,
scroll_id=scroll_id,
size=batch_size,
)
@ -297,7 +295,7 @@ def get_scalar_metric_data(call, company_id, _):
result = event_bll.get_task_events(
task.get_index_company(),
task_id,
event_type="training_stats_scalar",
event_type=EventType.metrics_scalar,
sort=[{"iter": {"order": "desc"}}],
metric=metric,
scroll_id=scroll_id,
@ -321,8 +319,9 @@ def get_task_latest_scalar_values(call, company_id, _):
metrics, last_timestamp = event_bll.get_task_latest_scalar_values(
index_company, task_id
)
es_index = get_index_name(index_company, "*")
last_iters = event_bll.get_last_iters(es_index, task_id, None, 1)
last_iters = event_bll.get_last_iters(
company_id=company_id, event_type=EventType.all, task_id=task_id, iters=1
)
call.result.data = dict(
metrics=metrics,
last_iter=last_iters[0] if last_iters else 0,
@ -344,7 +343,10 @@ def scalar_metrics_iter_histogram(
company_id, request.task, allow_public=True, only=("company", "company_origin")
)[0]
metrics = event_bll.metrics.get_scalar_metrics_average_per_iter(
task.get_index_company(), task_id=request.task, samples=request.samples, key=request.key
task.get_index_company(),
task_id=request.task,
samples=request.samples,
key=request.key,
)
call.result.data = metrics
@ -394,7 +396,7 @@ def get_multi_task_plots_v1_7(call, company_id, _):
result = event_bll.get_task_events(
next(iter(companies)),
task_ids,
event_type="plot",
event_type=EventType.metrics_plot,
sort=[{"iter": {"order": "desc"}}],
size=10000,
scroll_id=scroll_id,
@ -436,7 +438,7 @@ def get_multi_task_plots(call, company_id, req_model):
result = event_bll.get_task_events(
next(iter(companies)),
task_ids,
event_type="plot",
event_type=EventType.metrics_plot,
sort=[{"iter": {"order": "desc"}}],
last_iter_count=iters,
scroll_id=scroll_id,
@ -476,7 +478,7 @@ def get_task_plots_v1_7(call, company_id, _):
result = event_bll.get_task_events(
task.get_index_company(),
task_id,
event_type="plot",
event_type=EventType.metrics_plot,
sort=[{"iter": {"order": "desc"}}],
size=10000,
scroll_id=scroll_id,
@ -539,7 +541,7 @@ def get_debug_images_v1_7(call, company_id, _):
result = event_bll.get_task_events(
task.get_index_company(),
task_id,
event_type="training_debug_image",
event_type=EventType.metrics_image,
sort=[{"iter": {"order": "desc"}}],
size=10000,
scroll_id=scroll_id,
@ -568,7 +570,7 @@ def get_debug_images_v1_8(call, company_id, _):
result = event_bll.get_task_events(
task.get_index_company(),
task_id,
event_type="training_debug_image",
event_type=EventType.metrics_image,
sort=[{"iter": {"order": "desc"}}],
last_iter_count=iters,
scroll_id=scroll_id,
@ -594,7 +596,10 @@ def get_debug_images_v1_8(call, company_id, _):
def get_debug_images(call, company_id, request: DebugImagesRequest):
task_ids = {m.task for m in request.metrics}
tasks = task_bll.assert_exists(
company_id, task_ids=task_ids, allow_public=True, only=("company", "company_origin")
company_id,
task_ids=task_ids,
allow_public=True,
only=("company", "company_origin"),
)
companies = {t.get_index_company() for t in tasks}
@ -662,7 +667,7 @@ def next_debug_image_sample(call, company_id, request: NextDebugImageSampleReque
company_id=task.company,
task=request.task,
state_id=request.scroll_id,
navigate_earlier=request.navigate_earlier
navigate_earlier=request.navigate_earlier,
)
call.result.data = attr.asdict(res, recurse=False)
@ -670,7 +675,10 @@ def next_debug_image_sample(call, company_id, request: NextDebugImageSampleReque
@endpoint("events.get_task_metrics", request_data_model=TaskMetricsRequest)
def get_tasks_metrics(call: APICall, company_id, request: TaskMetricsRequest):
task = task_bll.assert_exists(
company_id, task_ids=request.tasks, allow_public=True, only=("company", "company_origin")
company_id,
task_ids=request.tasks,
allow_public=True,
only=("company", "company_origin"),
)[0]
res = event_bll.metrics.get_tasks_metrics(
task.get_index_company(), task_ids=request.tasks, event_type=request.event_type