Mirror of https://github.com/clearml/clearml-server, synced 2025-06-23 08:45:30 +00:00
Use EVENT_TYPE enum instead of string
Commit 4707647c92 (parent 6974aa3a99)
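The change is mechanical but wide-reaching: call sites that previously built an Elasticsearch index name from a raw event-type string and probed the index themselves now pass an EventType member to the shared helpers in event_common. A minimal sketch of the before/after pattern (es, company_id, and es_req are stand-ins for the arguments available at each call site; the enum values and helper signatures are taken from the diff below):

    from elasticsearch import Elasticsearch

    from apiserver.bll.event.event_common import (
        EventType,
        check_empty_data,
        search_company_events,
    )

    def query_scalar_events(es: Elasticsearch, company_id: str, es_req: dict) -> dict:
        # Before: es_index = get_index_name(company_id, "training_stats_scalar")
        #         if not es.indices.exists(es_index):
        #             return {}
        #         es_res = es.search(index=es_index, body=es_req)
        # After: the enum member carries the type string, and the helpers
        # derive the per-company index name from it.
        if check_empty_data(es, company_id=company_id, event_type=EventType.metrics_scalar):
            return {}
        return search_company_events(
            es, company_id=company_id, event_type=EventType.metrics_scalar, body=es_req
        )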
@@ -19,6 +19,7 @@ from apiserver.bll.event.event_common import (
     EventSettings,
     check_empty_data,
     search_company_events,
+    EventType,
 )
 from apiserver.bll.redis_cache_manager import RedisCacheManager
 from apiserver.database.errors import translate_errors_context
@@ -59,7 +60,7 @@ class DebugImagesResult(object):


 class DebugImagesIterator:
-    EVENT_TYPE = "training_debug_image"
+    EVENT_TYPE = EventType.metrics_image

     def __init__(self, redis: StrictRedis, es: Elasticsearch):
         self.es = es
@@ -142,10 +143,10 @@ class DebugImagesIterator:
             return [
                 (
                     (task.id, stats.metric),
-                    stats.event_stats_by_type[self.EVENT_TYPE].last_update,
+                    stats.event_stats_by_type[self.EVENT_TYPE.value].last_update,
                 )
                 for stats in metric_stats.values()
-                if self.EVENT_TYPE in stats.event_stats_by_type
+                if self.EVENT_TYPE.value in stats.event_stats_by_type
             ]

         update_times = dict(
@@ -257,7 +258,6 @@ class DebugImagesIterator:
                 company_id=company_id,
                 event_type=self.EVENT_TYPE,
                 body=es_req,
-                routing=task,
             )
             if "aggregations" not in es_res:
                 return []
@@ -401,7 +401,6 @@ class DebugImagesIterator:
                 company_id=company_id,
                 event_type=self.EVENT_TYPE,
                 body=es_req,
-                routing=metric.task,
             )
             if "aggregations" not in es_res:
                 return metric.task, metric.name, []

@@ -10,7 +10,12 @@ from redis import StrictRedis

 from apiserver.apierrors import errors
 from apiserver.apimodels import JsonSerializableMixin
-from apiserver.bll.event.event_common import EventSettings, get_index_name
+from apiserver.bll.event.event_common import (
+    EventSettings,
+    EventType,
+    check_empty_data,
+    search_company_events,
+)
 from apiserver.bll.redis_cache_manager import RedisCacheManager
 from apiserver.database.errors import translate_errors_context
 from apiserver.timing_context import TimingContext
@@ -44,7 +49,7 @@ class DebugSampleHistoryResult(object):


 class DebugSampleHistory:
-    EVENT_TYPE = "training_debug_image"
+    EVENT_TYPE = EventType.metrics_image

     def __init__(self, redis: StrictRedis, es: Elasticsearch):
         self.es = es
@@ -66,14 +71,13 @@ class DebugSampleHistory:
         if not state or state.task != task:
             raise errors.bad_request.InvalidScrollId(scroll_id=state_id)

-        es_index = get_index_name(company_id, self.EVENT_TYPE)
-        if not self.es.indices.exists(es_index):
+        if check_empty_data(self.es, company_id=company_id, event_type=self.EVENT_TYPE):
             return res

         image = self._get_next_for_current_iteration(
-            es_index=es_index, navigate_earlier=navigate_earlier, state=state
+            company_id=company_id, navigate_earlier=navigate_earlier, state=state
         ) or self._get_next_for_another_iteration(
-            es_index=es_index, navigate_earlier=navigate_earlier, state=state
+            company_id=company_id, navigate_earlier=navigate_earlier, state=state
         )
         if not image:
             return res
@@ -94,7 +98,7 @@ class DebugSampleHistory:
         res.max_iteration = var_state.max_iteration

     def _get_next_for_current_iteration(
-        self, es_index: str, navigate_earlier: bool, state: DebugSampleHistoryState
+        self, company_id: str, navigate_earlier: bool, state: DebugSampleHistoryState
     ) -> Optional[dict]:
         """
         Get the image for next (if navigated earlier is False) or previous variant sorted by name for the same iteration
@@ -127,7 +131,9 @@ class DebugSampleHistory:
         with translate_errors_context(), TimingContext(
             "es", "get_next_for_current_iteration"
         ):
-            es_res = self.es.search(index=es_index, body=es_req, routing=state.task)
+            es_res = search_company_events(
+                self.es, company_id=company_id, event_type=self.EVENT_TYPE, body=es_req
+            )

         hits = nested_get(es_res, ("hits", "hits"))
         if not hits:
@@ -136,7 +142,7 @@ class DebugSampleHistory:
         return hits[0]["_source"]

     def _get_next_for_another_iteration(
-        self, es_index: str, navigate_earlier: bool, state: DebugSampleHistoryState
+        self, company_id: str, navigate_earlier: bool, state: DebugSampleHistoryState
     ) -> Optional[dict]:
         """
         Get the image for the first variant for the next iteration (if navigate_earlier is set to False)
@@ -189,7 +195,9 @@ class DebugSampleHistory:
         with translate_errors_context(), TimingContext(
             "es", "get_next_for_another_iteration"
         ):
-            es_res = self.es.search(index=es_index, body=es_req, routing=state.task)
+            es_res = search_company_events(
+                self.es, company_id=company_id, event_type=self.EVENT_TYPE, body=es_req
+            )

         hits = nested_get(es_res, ("hits", "hits"))
         if not hits:
@@ -212,14 +220,13 @@ class DebugSampleHistory:
         If the iteration is not passed then get the latest event
         """
         res = DebugSampleHistoryResult()
-        es_index = get_index_name(company_id, self.EVENT_TYPE)
-        if not self.es.indices.exists(es_index):
+        if check_empty_data(self.es, company_id=company_id, event_type=self.EVENT_TYPE):
             return res

         def init_state(state_: DebugSampleHistoryState):
             state_.task = task
             state_.metric = metric
-            self._reset_variant_states(es_index, state=state_)
+            self._reset_variant_states(company_id=company_id, state=state_)

         def validate_state(state_: DebugSampleHistoryState):
             if state_.task != task or state_.metric != metric:
@@ -228,7 +235,7 @@ class DebugSampleHistory:
                     scroll_id=state_.id,
                 )
             if refresh:
-                self._reset_variant_states(es_index, state=state_)
+                self._reset_variant_states(company_id=company_id, state=state_)

         state: DebugSampleHistoryState
         with self.cache_manager.get_or_create_state(
@@ -271,7 +278,12 @@ class DebugSampleHistory:
         with translate_errors_context(), TimingContext(
             "es", "get_debug_image_for_variant"
         ):
-            es_res = self.es.search(index=es_index, body=es_req, routing=task)
+            es_res = search_company_events(
+                self.es,
+                company_id=company_id,
+                event_type=self.EVENT_TYPE,
+                body=es_req,
+            )

         hits = nested_get(es_res, ("hits", "hits"))
         if not hits:
@@ -282,9 +294,9 @@ class DebugSampleHistory:
         )
         return res

-    def _reset_variant_states(self, es_index, state: DebugSampleHistoryState):
+    def _reset_variant_states(self, company_id: str, state: DebugSampleHistoryState):
         variant_iterations = self._get_variant_iterations(
-            es_index=es_index, task=state.task, metric=state.metric
+            company_id=company_id, task=state.task, metric=state.metric
         )
         state.variant_states = [
             VariantState(name=var_name, min_iteration=min_iter, max_iteration=max_iter)
@@ -293,7 +305,7 @@ class DebugSampleHistory:

     def _get_variant_iterations(
         self,
-        es_index: str,
+        company_id: str,
         task: str,
         metric: str,
         variants: Optional[Sequence[str]] = None,
@@ -344,7 +356,9 @@ class DebugSampleHistory:
         with translate_errors_context(), TimingContext(
             "es", "get_debug_image_iterations"
         ):
-            es_res = self.es.search(index=es_index, body=es_req, routing=task)
+            es_res = search_company_events(
+                self.es, company_id=company_id, event_type=self.EVENT_TYPE, body=es_req
+            )

         def get_variant_data(variant_bucket: dict) -> Tuple[str, int, int]:
             variant = variant_bucket["key"]

@@ -13,7 +13,14 @@ from mongoengine import Q
 from nested_dict import nested_dict

 from apiserver.bll.event.debug_sample_history import DebugSampleHistory
-from apiserver.bll.event.event_common import EventType, EventSettings, get_index_name
+from apiserver.bll.event.event_common import (
+    EventType,
+    EventSettings,
+    get_index_name,
+    check_empty_data,
+    search_company_events,
+    delete_company_events,
+)
 from apiserver.bll.util import parallel_chunked_decorator
 from apiserver.database import utils as dbutils
 from apiserver.es_factory import es_factory
@@ -156,7 +163,7 @@ class EventBLL(object):
                 }

                 # for "log" events, don't assing custom _id - whatever is sent, is written (not overwritten)
-                if event_type != "log":
+                if event_type != EventType.task_log.value:
                     es_action["_id"] = self._get_event_id(event)
                 else:
                     es_action["_id"] = dbutils.id()
@@ -389,10 +396,10 @@ class EventBLL(object):

     def scroll_task_events(
         self,
-        company_id,
-        task_id,
-        order,
-        event_type=None,
+        company_id: str,
+        task_id: str,
+        order: str,
+        event_type: EventType,
         batch_size=10000,
         scroll_id=None,
     ):
@@ -404,12 +411,7 @@ class EventBLL(object):
            es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
         else:
            size = min(batch_size, 10000)
-            if event_type is None:
-                event_type = "*"
-
-            es_index = get_index_name(company_id, event_type)
-
-            if not self.es.indices.exists(es_index):
+            if check_empty_data(self.es, company_id=company_id, event_type=event_type):
                 return [], None, 0

             es_req = {
@@ -419,15 +421,25 @@ class EventBLL(object):
             }

             with translate_errors_context(), TimingContext("es", "scroll_task_events"):
-                es_res = self.es.search(index=es_index, body=es_req, scroll="1h")
+                es_res = search_company_events(
+                    self.es,
+                    company_id=company_id,
+                    event_type=event_type,
+                    body=es_req,
+                    scroll="1h",
+                )

         events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
         return events, next_scroll_id, total_events

     def get_last_iterations_per_event_metric_variant(
-        self, es_index: str, task_id: str, num_last_iterations: int, event_type: str
+        self,
+        company_id: str,
+        task_id: str,
+        num_last_iterations: int,
+        event_type: EventType,
     ):
-        if not self.es.indices.exists(es_index):
+        if check_empty_data(self.es, company_id=company_id, event_type=event_type):
             return []

         es_req: dict = {
@@ -461,13 +473,14 @@ class EventBLL(object):
             },
             "query": {"bool": {"must": [{"term": {"task": task_id}}]}},
         }
         if event_type:
             es_req["query"]["bool"]["must"].append({"term": {"type": event_type}})

         with translate_errors_context(), TimingContext(
             "es", "task_last_iter_metric_variant"
         ):
-            es_res = self.es.search(index=es_index, body=es_req)
+            es_res = search_company_events(
+                self.es, company_id=company_id, event_type=event_type, body=es_req
+            )

         if "aggregations" not in es_res:
             return []

@@ -494,9 +507,8 @@ class EventBLL(object):
             with translate_errors_context(), TimingContext("es", "get_task_events"):
                 es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
         else:
-            event_type = "plot"
-            es_index = get_index_name(company_id, event_type)
-            if not self.es.indices.exists(es_index):
+            event_type = EventType.metrics_plot
+            if check_empty_data(self.es, company_id=company_id, event_type=event_type):
                 return TaskEventsResult()

             plot_valid_condition = {
@@ -519,7 +531,10 @@ class EventBLL(object):
             should = []
             for i, task_id in enumerate(tasks):
                 last_iters = self.get_last_iterations_per_event_metric_variant(
-                    es_index, task_id, last_iterations_per_plot, event_type
+                    company_id=company_id,
+                    task_id=task_id,
+                    num_last_iterations=last_iterations_per_plot,
+                    event_type=event_type,
                 )
                 if not last_iters:
                     continue
@@ -551,8 +566,13 @@ class EventBLL(object):
             }

             with translate_errors_context(), TimingContext("es", "get_task_plots"):
-                es_res = self.es.search(
-                    index=es_index, body=es_req, ignore=404, scroll="1h",
+                es_res = search_company_events(
+                    self.es,
+                    company_id=company_id,
+                    event_type=event_type,
+                    body=es_req,
+                    ignore=404,
+                    scroll="1h",
                 )

             events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
@@ -577,9 +597,9 @@ class EventBLL(object):

     def get_task_events(
         self,
-        company_id,
-        task_id,
-        event_type=None,
+        company_id: str,
+        task_id: str,
+        event_type: EventType,
         metric=None,
         variant=None,
         last_iter_count=None,
@@ -595,11 +615,8 @@ class EventBLL(object):
             es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
         else:
             task_ids = [task_id] if isinstance(task_id, six.string_types) else task_id
-            if event_type is None:
-                event_type = "*"

-            es_index = get_index_name(company_id, event_type)
-            if not self.es.indices.exists(es_index):
+            if check_empty_data(self.es, company_id=company_id, event_type=event_type):
                 return TaskEventsResult()

             must = []
@@ -614,7 +631,10 @@ class EventBLL(object):
             should = []
             for i, task_id in enumerate(task_ids):
                 last_iters = self.get_last_iters(
-                    es_index, task_id, event_type, last_iter_count
+                    company_id=company_id,
+                    event_type=event_type,
+                    task_id=task_id,
+                    iters=last_iter_count,
                 )
                 if not last_iters:
                     continue
@@ -642,8 +662,13 @@ class EventBLL(object):
             }

             with translate_errors_context(), TimingContext("es", "get_task_events"):
-                es_res = self.es.search(
-                    index=es_index, body=es_req, ignore=404, scroll="1h",
+                es_res = search_company_events(
+                    self.es,
+                    company_id=company_id,
+                    event_type=event_type,
+                    body=es_req,
+                    ignore=404,
+                    scroll="1h",
                 )

             events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
@@ -651,11 +676,10 @@ class EventBLL(object):
             events=events, next_scroll_id=next_scroll_id, total_events=total_events
         )

-    def get_metrics_and_variants(self, company_id, task_id, event_type):
-
-        es_index = get_index_name(company_id, event_type)
-
-        if not self.es.indices.exists(es_index):
+    def get_metrics_and_variants(
+        self, company_id: str, task_id: str, event_type: EventType
+    ):
+        if check_empty_data(self.es, company_id=company_id, event_type=event_type):
             return {}

         es_req = {
@@ -684,7 +708,9 @@ class EventBLL(object):
         with translate_errors_context(), TimingContext(
             "es", "events_get_metrics_and_variants"
         ):
-            es_res = self.es.search(index=es_index, body=es_req)
+            es_res = search_company_events(
+                self.es, company_id=company_id, event_type=event_type, body=es_req
+            )

         metrics = {}
         for metric_bucket in es_res["aggregations"]["metrics"].get("buckets"):
@@ -695,10 +721,9 @@ class EventBLL(object):

         return metrics

-    def get_task_latest_scalar_values(self, company_id, task_id):
-        es_index = get_index_name(company_id, "training_stats_scalar")
-
-        if not self.es.indices.exists(es_index):
+    def get_task_latest_scalar_values(self, company_id: str, task_id: str):
+        event_type = EventType.metrics_scalar
+        if check_empty_data(self.es, company_id=company_id, event_type=event_type):
             return {}

         es_req = {
@@ -753,7 +778,9 @@ class EventBLL(object):
         with translate_errors_context(), TimingContext(
             "es", "events_get_metrics_and_variants"
         ):
-            es_res = self.es.search(index=es_index, body=es_req)
+            es_res = search_company_events(
+                self.es, company_id=company_id, event_type=event_type, body=es_req
+            )

         metrics = []
         max_timestamp = 0
@@ -780,9 +807,8 @@ class EventBLL(object):
         return metrics, max_timestamp

     def get_vector_metrics_per_iter(self, company_id, task_id, metric, variant):
-
-        es_index = get_index_name(company_id, "training_stats_vector")
-        if not self.es.indices.exists(es_index):
+        event_type = EventType.metrics_vector
+        if check_empty_data(self.es, company_id=company_id, event_type=event_type):
             return [], []

         es_req = {
@@ -800,7 +826,9 @@ class EventBLL(object):
             "sort": ["iter"],
         }
         with translate_errors_context(), TimingContext("es", "task_stats_vector"):
-            es_res = self.es.search(index=es_index, body=es_req)
+            es_res = search_company_events(
+                self.es, company_id=company_id, event_type=event_type, body=es_req
+            )

         vectors = []
         iterations = []
@@ -810,8 +838,10 @@ class EventBLL(object):

         return iterations, vectors

-    def get_last_iters(self, es_index, task_id, event_type, iters):
-        if not self.es.indices.exists(es_index):
+    def get_last_iters(
+        self, company_id: str, event_type: EventType, task_id: str, iters: int
+    ):
+        if check_empty_data(self.es, company_id=company_id, event_type=event_type):
             return []

         es_req: dict = {
@@ -827,11 +857,12 @@ class EventBLL(object):
             },
             "query": {"bool": {"must": [{"term": {"task": task_id}}]}},
         }
         if event_type:
             es_req["query"]["bool"]["must"].append({"term": {"type": event_type}})

         with translate_errors_context(), TimingContext("es", "task_last_iter"):
-            es_res = self.es.search(index=es_index, body=es_req)
+            es_res = search_company_events(
+                self.es, company_id=company_id, event_type=event_type, body=es_req
+            )

         if "aggregations" not in es_res:
             return []

@@ -850,9 +881,14 @@ class EventBLL(object):
                 extra_msg, company=company_id, id=task_id
             )

-        es_index = get_index_name(company_id, "*")
         es_req = {"query": {"term": {"task": task_id}}}
         with translate_errors_context(), TimingContext("es", "delete_task_events"):
-            es_res = self.es.delete_by_query(index=es_index, body=es_req, refresh=True)
+            es_res = delete_company_events(
+                es=self.es,
+                company_id=company_id,
+                event_type=EventType.all,
+                body=es_req,
+                refresh=True,
+            )

         return es_res.get("deleted", 0)

@@ -13,6 +13,7 @@ class EventType(Enum):
     metrics_image = "training_debug_image"
     metrics_plot = "plot"
     task_log = "log"
+    all = "*"


 class EventSettings:
@@ -40,8 +41,8 @@ def get_index_name(company_id: str, event_type: str):
     return f"events-{event_type}-{company_id}"


-def check_empty_data(es: Elasticsearch, company_id: str, event_type: str) -> bool:
-    es_index = get_index_name(company_id, event_type)
+def check_empty_data(es: Elasticsearch, company_id: str, event_type: EventType) -> bool:
+    es_index = get_index_name(company_id, event_type.value)
     if not es.indices.exists(es_index):
         return True
     return False
@@ -50,16 +51,16 @@ def check_empty_data(es: Elasticsearch, company_id: str, event_type: str) -> bool:
 def search_company_events(
     es: Elasticsearch,
     company_id: Union[str, Sequence[str]],
-    event_type: str,
+    event_type: EventType,
     body: dict,
     **kwargs,
 ) -> dict:
-    es_index = get_index_name(company_id, event_type)
+    es_index = get_index_name(company_id, event_type.value)
     return es.search(index=es_index, body=body, **kwargs)


 def delete_company_events(
-    es: Elasticsearch, company_id: str, event_type: str, body: dict, **kwargs
+    es: Elasticsearch, company_id: str, event_type: EventType, body: dict, **kwargs
 ) -> dict:
-    es_index = get_index_name(company_id, event_type)
+    es_index = get_index_name(company_id, event_type.value)
     return es.delete_by_query(index=es_index, body=body, **kwargs)

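Since every helper above funnels through get_index_name(company_id, event_type.value), each enum member maps directly to a per-company index name, and EventType.all becomes a wildcard over all of them. A small sketch (the company id is a placeholder):

    from apiserver.bll.event.event_common import EventType, get_index_name

    company = "some_company_id"  # placeholder value
    get_index_name(company, EventType.task_log.value)      # "events-log-some_company_id"
    get_index_name(company, EventType.metrics_plot.value)  # "events-plot-some_company_id"
    get_index_name(company, EventType.all.value)           # "events-*-some_company_id", matching every event index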
@@ -10,7 +10,12 @@ from elasticsearch import Elasticsearch
 from mongoengine import Q

 from apiserver.apierrors import errors
-from apiserver.bll.event.event_common import EventType, get_index_name, EventSettings
+from apiserver.bll.event.event_common import (
+    EventType,
+    EventSettings,
+    search_company_events,
+    check_empty_data,
+)
 from apiserver.bll.event.scalar_key import ScalarKey, ScalarKeyEnum
 from apiserver.config_repo import config
 from apiserver.database.errors import translate_errors_context
@@ -36,31 +41,40 @@ class EventMetrics:
         The amount of points in each histogram should not exceed
         the requested samples
         """
-        es_index = get_index_name(company_id, "training_stats_scalar")
-        if not self.es.indices.exists(es_index):
+        event_type = EventType.metrics_scalar
+        if check_empty_data(self.es, company_id=company_id, event_type=event_type):
             return {}

         return self._get_scalar_average_per_iter_core(
-            task_id, es_index, samples, ScalarKey.resolve(key)
+            task_id, company_id, event_type, samples, ScalarKey.resolve(key)
         )

     def _get_scalar_average_per_iter_core(
         self,
         task_id: str,
-        es_index: str,
+        company_id: str,
+        event_type: EventType,
         samples: int,
         key: ScalarKey,
         run_parallel: bool = True,
     ) -> dict:
         intervals = self._get_task_metric_intervals(
-            es_index=es_index, task_id=task_id, samples=samples, field=key.field
+            company_id=company_id,
+            event_type=event_type,
+            task_id=task_id,
+            samples=samples,
+            field=key.field,
         )
         if not intervals:
             return {}
         interval_groups = self._group_task_metric_intervals(intervals)

         get_scalar_average = partial(
-            self._get_scalar_average, task_id=task_id, es_index=es_index, key=key
+            self._get_scalar_average,
+            task_id=task_id,
+            company_id=company_id,
+            event_type=event_type,
+            key=key,
         )
         if run_parallel:
             with ThreadPoolExecutor(max_workers=EventSettings.max_workers) as pool:
@@ -110,13 +124,15 @@ class EventMetrics:
                 "only tasks from the same company are supported"
             )

-        es_index = get_index_name(next(iter(companies)), "training_stats_scalar")
-        if not self.es.indices.exists(es_index):
+        event_type = EventType.metrics_scalar
+        company_id = next(iter(companies))
+        if check_empty_data(self.es, company_id=company_id, event_type=event_type):
             return {}

         get_scalar_average_per_iter = partial(
             self._get_scalar_average_per_iter_core,
-            es_index=es_index,
+            company_id=company_id,
+            event_type=event_type,
             samples=samples,
             key=ScalarKey.resolve(key),
             run_parallel=False,
@@ -175,7 +191,12 @@ class EventMetrics:
         return metric_interval_groups

     def _get_task_metric_intervals(
-        self, es_index, task_id: str, samples: int, field: str = "iter"
+        self,
+        company_id: str,
+        event_type: EventType,
+        task_id: str,
+        samples: int,
+        field: str = "iter",
     ) -> Sequence[MetricInterval]:
         """
         Calculate interval per task metric variant so that the resulting
@@ -212,7 +233,9 @@ class EventMetrics:
         }

         with translate_errors_context(), TimingContext("es", "task_stats_get_interval"):
-            es_res = self.es.search(index=es_index, body=es_req)
+            es_res = search_company_events(
+                self.es, company_id=company_id, event_type=event_type, body=es_req,
+            )

         aggs_result = es_res.get("aggregations")
         if not aggs_result:
@@ -255,7 +278,8 @@ class EventMetrics:
         self,
         metrics_interval: MetricIntervalGroup,
         task_id: str,
-        es_index: str,
+        company_id: str,
+        event_type: EventType,
         key: ScalarKey,
     ) -> Sequence[MetricData]:
         """
@@ -283,7 +307,11 @@ class EventMetrics:
             }
         }
         aggs_result = self._query_aggregation_for_task_metrics(
-            es_index, aggs=aggs, task_id=task_id, metrics=metrics
+            company_id=company_id,
+            event_type=event_type,
+            aggs=aggs,
+            task_id=task_id,
+            metrics=metrics,
         )

         if not aggs_result:
@@ -314,7 +342,8 @@ class EventMetrics:

     def _query_aggregation_for_task_metrics(
         self,
-        es_index: str,
+        company_id: str,
+        event_type: EventType,
         aggs: dict,
         task_id: str,
         metrics: Sequence[Tuple[str, str]],
@@ -345,7 +374,9 @@ class EventMetrics:
         }

         with translate_errors_context(), TimingContext("es", "task_stats_scalar"):
-            es_res = self.es.search(index=es_index, body=es_req)
+            es_res = search_company_events(
+                self.es, company_id=company_id, event_type=event_type, body=es_req,
+            )

         return es_res.get("aggregations")

@@ -356,30 +387,26 @@ class EventMetrics:
         For the requested tasks return all the metrics that
         reported events of the requested types
         """
-        es_index = get_index_name(company_id, event_type.value)
-        if not self.es.indices.exists(es_index):
+        if check_empty_data(self.es, company_id, event_type):
             return {}

         with ThreadPoolExecutor(EventSettings.max_workers) as pool:
             res = pool.map(
                 partial(
-                    self._get_task_metrics, es_index=es_index, event_type=event_type,
+                    self._get_task_metrics,
+                    company_id=company_id,
+                    event_type=event_type,
                 ),
                 task_ids,
             )
         return list(zip(task_ids, res))

-    def _get_task_metrics(self, task_id, es_index, event_type: EventType) -> Sequence:
+    def _get_task_metrics(
+        self, task_id: str, company_id: str, event_type: EventType
+    ) -> Sequence:
         es_req = {
             "size": 0,
-            "query": {
-                "bool": {
-                    "must": [
-                        {"term": {"task": task_id}},
-                        {"term": {"type": event_type.value}},
-                    ]
-                }
-            },
+            "query": {"bool": {"must": [{"term": {"task": task_id}}]}},
             "aggs": {
                 "metrics": {
                     "terms": {
@@ -392,7 +419,9 @@ class EventMetrics:
         }

         with translate_errors_context(), TimingContext("es", "_get_task_metrics"):
-            es_res = self.es.search(index=es_index, body=es_req)
+            es_res = search_company_events(
+                self.es, company_id=company_id, event_type=event_type, body=es_req
+            )

         return [
             metric["key"]

@@ -3,7 +3,11 @@ from typing import Optional, Tuple, Sequence
 import attr
 from elasticsearch import Elasticsearch

-from apiserver.bll.event.event_common import check_empty_data, search_company_events
+from apiserver.bll.event.event_common import (
+    check_empty_data,
+    search_company_events,
+    EventType,
+)
 from apiserver.database.errors import translate_errors_context
 from apiserver.timing_context import TimingContext

@@ -16,7 +20,7 @@ class TaskEventsResult:


 class LogEventsIterator:
-    EVENT_TYPE = "log"
+    EVENT_TYPE = EventType.task_log

     def __init__(self, es: Elasticsearch):
         self.es = es
@@ -75,7 +79,6 @@ class LogEventsIterator:
                 company_id=company_id,
                 event_type=self.EVENT_TYPE,
                 body=es_req,
-                routing=task_id,
             )
             hits = es_result["hits"]["hits"]
             hits_total = es_result["hits"]["total"]["value"]
@@ -102,7 +105,6 @@ class LogEventsIterator:
                 company_id=company_id,
                 event_type=self.EVENT_TYPE,
                 body=es_req,
-                routing=task_id,
             )
             last_second_hits = es_result["hits"]["hits"]
             if not last_second_hits or len(last_second_hits) < 2:

@@ -32,6 +32,7 @@ from furl import furl
 from mongoengine import Q

 from apiserver.bll.event import EventBLL
+from apiserver.bll.event.event_common import EventType
 from apiserver.bll.task.artifacts import get_artifact_id
 from apiserver.bll.task.param_utils import (
     split_param_name,
@@ -532,19 +533,22 @@ class PrePopulate:
         scroll_id = None
         while True:
             res = cls.event_bll.get_task_events(
-                task.company, task.id, scroll_id=scroll_id
+                company_id=task.company,
+                task_id=task.id,
+                event_type=EventType.all,
+                scroll_id=scroll_id,
             )
             if not res.events:
                 break
             scroll_id = res.next_scroll_id
             for event in res.events:
                 event_type = event.get("type")
-                if event_type == "training_debug_image":
+                if event_type == EventType.metrics_image.value:
                     url = cls._get_fixed_url(event.get("url"))
                     if url:
                         event["url"] = url
                         artifacts.append(url)
-                elif event_type == "plot":
+                elif event_type == EventType.metrics_plot.value:
                     plot_str: str = event.get("plot_str", "")
                     for match in cls.img_source_regex.findall(plot_str):
                         url = cls._get_fixed_url(match)

@@ -19,7 +19,7 @@ from apiserver.apimodels.events import (
     NextDebugImageSampleRequest,
 )
 from apiserver.bll.event import EventBLL
-from apiserver.bll.event.event_common import get_index_name
+from apiserver.bll.event.event_common import EventType
 from apiserver.bll.task import TaskBLL
 from apiserver.service_repo import APICall, endpoint
 from apiserver.utilities import json
@@ -63,7 +63,7 @@ def get_task_log_v1_5(call, company_id, _):
         task.get_index_company(),
         task_id,
         order,
-        event_type="log",
+        event_type=EventType.task_log,
         batch_size=batch_size,
         scroll_id=scroll_id,
     )
@@ -90,7 +90,7 @@ def get_task_log_v1_7(call, company_id, _):
         company_id=task.get_index_company(),
         task_id=task_id,
         order=scroll_order,
-        event_type="log",
+        event_type=EventType.task_log,
         batch_size=batch_size,
         scroll_id=scroll_id,
     )
@@ -118,11 +118,9 @@ def get_task_log(call, company_id, request: LogEventsRequest):
         from_timestamp=request.from_timestamp,
     )

-    if (
-        request.order and (
-            (request.navigate_earlier and request.order == LogOrderEnum.asc)
-            or (not request.navigate_earlier and request.order == LogOrderEnum.desc)
-        )
+    if request.order and (
+        (request.navigate_earlier and request.order == LogOrderEnum.asc)
+        or (not request.navigate_earlier and request.order == LogOrderEnum.desc)
     ):
         res.events.reverse()

@@ -182,7 +180,7 @@ def download_task_log(call, company_id, _):
         task.get_index_company(),
         task_id,
         order="asc",
-        event_type="log",
+        event_type=EventType.task_log,
         batch_size=batch_size,
         scroll_id=scroll_id,
     )
@@ -219,7 +217,7 @@ def get_vector_metrics_and_variants(call, company_id, _):
     )[0]
     call.result.data = dict(
         metrics=event_bll.get_metrics_and_variants(
-            task.get_index_company(), task_id, "training_stats_vector"
+            task.get_index_company(), task_id, EventType.metrics_vector
         )
     )

@@ -232,7 +230,7 @@ def get_scalar_metrics_and_variants(call, company_id, _):
     )[0]
     call.result.data = dict(
         metrics=event_bll.get_metrics_and_variants(
-            task.get_index_company(), task_id, "training_stats_scalar"
+            task.get_index_company(), task_id, EventType.metrics_scalar
        )
    )

@@ -272,7 +270,7 @@ def get_task_events(call, company_id, _):
         task.get_index_company(),
         task_id,
         sort=[{"timestamp": {"order": order}}],
-        event_type=event_type,
+        event_type=EventType(event_type) if event_type else EventType.all,
         scroll_id=scroll_id,
         size=batch_size,
     )
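Because EventType is a value-backed Enum, the endpoint above can promote the raw string arriving in the API call with the Enum constructor, falling back to EventType.all when no type was sent. A tiny illustration of that lookup:

    from apiserver.bll.event.event_common import EventType

    assert EventType("plot") is EventType.metrics_plot
    assert EventType("log") is EventType.task_log
    # Unknown strings fail fast: EventType("no_such_type") raises ValueError.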
@@ -297,7 +295,7 @@ def get_scalar_metric_data(call, company_id, _):
     result = event_bll.get_task_events(
         task.get_index_company(),
         task_id,
-        event_type="training_stats_scalar",
+        event_type=EventType.metrics_scalar,
         sort=[{"iter": {"order": "desc"}}],
         metric=metric,
         scroll_id=scroll_id,
@@ -321,8 +319,9 @@ def get_task_latest_scalar_values(call, company_id, _):
     metrics, last_timestamp = event_bll.get_task_latest_scalar_values(
         index_company, task_id
     )
-    es_index = get_index_name(index_company, "*")
-    last_iters = event_bll.get_last_iters(es_index, task_id, None, 1)
+    last_iters = event_bll.get_last_iters(
+        company_id=company_id, event_type=EventType.all, task_id=task_id, iters=1
+    )
     call.result.data = dict(
         metrics=metrics,
         last_iter=last_iters[0] if last_iters else 0,
@@ -344,7 +343,10 @@ def scalar_metrics_iter_histogram(
         company_id, request.task, allow_public=True, only=("company", "company_origin")
     )[0]
     metrics = event_bll.metrics.get_scalar_metrics_average_per_iter(
-        task.get_index_company(), task_id=request.task, samples=request.samples, key=request.key
+        task.get_index_company(),
+        task_id=request.task,
+        samples=request.samples,
+        key=request.key,
     )
     call.result.data = metrics

@@ -394,7 +396,7 @@ def get_multi_task_plots_v1_7(call, company_id, _):
     result = event_bll.get_task_events(
         next(iter(companies)),
         task_ids,
-        event_type="plot",
+        event_type=EventType.metrics_plot,
         sort=[{"iter": {"order": "desc"}}],
         size=10000,
         scroll_id=scroll_id,
@@ -436,7 +438,7 @@ def get_multi_task_plots(call, company_id, req_model):
     result = event_bll.get_task_events(
         next(iter(companies)),
         task_ids,
-        event_type="plot",
+        event_type=EventType.metrics_plot,
         sort=[{"iter": {"order": "desc"}}],
         last_iter_count=iters,
         scroll_id=scroll_id,
@@ -476,7 +478,7 @@ def get_task_plots_v1_7(call, company_id, _):
     result = event_bll.get_task_events(
         task.get_index_company(),
         task_id,
-        event_type="plot",
+        event_type=EventType.metrics_plot,
         sort=[{"iter": {"order": "desc"}}],
         size=10000,
         scroll_id=scroll_id,
@@ -539,7 +541,7 @@ def get_debug_images_v1_7(call, company_id, _):
     result = event_bll.get_task_events(
         task.get_index_company(),
         task_id,
-        event_type="training_debug_image",
+        event_type=EventType.metrics_image,
         sort=[{"iter": {"order": "desc"}}],
         size=10000,
         scroll_id=scroll_id,
@@ -568,7 +570,7 @@ def get_debug_images_v1_8(call, company_id, _):
     result = event_bll.get_task_events(
         task.get_index_company(),
         task_id,
-        event_type="training_debug_image",
+        event_type=EventType.metrics_image,
         sort=[{"iter": {"order": "desc"}}],
         last_iter_count=iters,
         scroll_id=scroll_id,
@@ -594,7 +596,10 @@
 def get_debug_images(call, company_id, request: DebugImagesRequest):
     task_ids = {m.task for m in request.metrics}
     tasks = task_bll.assert_exists(
-        company_id, task_ids=task_ids, allow_public=True, only=("company", "company_origin")
+        company_id,
+        task_ids=task_ids,
+        allow_public=True,
+        only=("company", "company_origin"),
     )

     companies = {t.get_index_company() for t in tasks}
@@ -662,7 +667,7 @@ def next_debug_image_sample(call, company_id, request: NextDebugImageSampleRequest):
         company_id=task.company,
         task=request.task,
         state_id=request.scroll_id,
-        navigate_earlier=request.navigate_earlier
+        navigate_earlier=request.navigate_earlier,
     )
     call.result.data = attr.asdict(res, recurse=False)

@@ -670,7 +675,10 @@ def next_debug_image_sample(call, company_id, request: NextDebugImageSampleRequest):
 @endpoint("events.get_task_metrics", request_data_model=TaskMetricsRequest)
 def get_tasks_metrics(call: APICall, company_id, request: TaskMetricsRequest):
     task = task_bll.assert_exists(
-        company_id, task_ids=request.tasks, allow_public=True, only=("company", "company_origin")
+        company_id,
+        task_ids=request.tasks,
+        allow_public=True,
+        only=("company", "company_origin"),
     )[0]
     res = event_bll.metrics.get_tasks_metrics(
         task.get_index_company(), task_ids=request.tasks, event_type=request.event_type