Mirror of https://github.com/clearml/clearml-server, synced 2025-06-23 08:45:30 +00:00

Commit 4707647c92 (parent 6974aa3a99): Use EVENT_TYPE enum instead of string
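In short: raw event-type strings such as "training_debug_image", "plot", and "log" are replaced throughout with the EventType enum, the `event_type = "*"` fallbacks become EventType.all, and call sites stop computing es_index themselves, going through the event_common helpers (check_empty_data, search_company_events, delete_company_events) instead. For orientation, a sketch of the enum and helpers as they stand after this commit, assembled from the event_common hunks below; the metrics_scalar and metrics_vector values are inferred from call-site substitutions elsewhere in the diff, not from the enum hunk itself:

    from enum import Enum
    from typing import Sequence, Union

    from elasticsearch import Elasticsearch


    class EventType(Enum):
        metrics_scalar = "training_stats_scalar"  # inferred from call sites
        metrics_vector = "training_stats_vector"  # inferred from call sites
        metrics_image = "training_debug_image"
        metrics_plot = "plot"
        task_log = "log"
        all = "*"  # added by this commit


    def get_index_name(company_id: str, event_type: str):
        return f"events-{event_type}-{company_id}"


    def check_empty_data(es: Elasticsearch, company_id: str, event_type: EventType) -> bool:
        # True when the per-company, per-type index does not exist yet
        es_index = get_index_name(company_id, event_type.value)
        if not es.indices.exists(es_index):
            return True
        return False


    def search_company_events(
        es: Elasticsearch,
        company_id: Union[str, Sequence[str]],
        event_type: EventType,
        body: dict,
        **kwargs,
    ) -> dict:
        es_index = get_index_name(company_id, event_type.value)
        return es.search(index=es_index, body=body, **kwargs)


    def delete_company_events(
        es: Elasticsearch, company_id: str, event_type: EventType, body: dict, **kwargs
    ) -> dict:
        es_index = get_index_name(company_id, event_type.value)
        return es.delete_by_query(index=es_index, body=body, **kwargs)

The diff follows, one file section per run of hunks (file paths were not preserved by the mirror page).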
@@ -19,6 +19,7 @@ from apiserver.bll.event.event_common import (
     EventSettings,
     check_empty_data,
     search_company_events,
+    EventType,
 )
 from apiserver.bll.redis_cache_manager import RedisCacheManager
 from apiserver.database.errors import translate_errors_context
@@ -59,7 +60,7 @@ class DebugImagesResult(object):


 class DebugImagesIterator:
-    EVENT_TYPE = "training_debug_image"
+    EVENT_TYPE = EventType.metrics_image

     def __init__(self, redis: StrictRedis, es: Elasticsearch):
         self.es = es
@@ -142,10 +143,10 @@ class DebugImagesIterator:
         return [
             (
                 (task.id, stats.metric),
-                stats.event_stats_by_type[self.EVENT_TYPE].last_update,
+                stats.event_stats_by_type[self.EVENT_TYPE.value].last_update,
             )
             for stats in metric_stats.values()
-            if self.EVENT_TYPE in stats.event_stats_by_type
+            if self.EVENT_TYPE.value in stats.event_stats_by_type
         ]

         update_times = dict(
@@ -257,7 +258,6 @@ class DebugImagesIterator:
             company_id=company_id,
             event_type=self.EVENT_TYPE,
             body=es_req,
-            routing=task,
         )
         if "aggregations" not in es_res:
             return []
@@ -401,7 +401,6 @@ class DebugImagesIterator:
             company_id=company_id,
             event_type=self.EVENT_TYPE,
             body=es_req,
-            routing=metric.task,
         )
         if "aggregations" not in es_res:
             return metric.task, metric.name, []

@@ -10,7 +10,12 @@ from redis import StrictRedis

 from apiserver.apierrors import errors
 from apiserver.apimodels import JsonSerializableMixin
-from apiserver.bll.event.event_common import EventSettings, get_index_name
+from apiserver.bll.event.event_common import (
+    EventSettings,
+    EventType,
+    check_empty_data,
+    search_company_events,
+)
 from apiserver.bll.redis_cache_manager import RedisCacheManager
 from apiserver.database.errors import translate_errors_context
 from apiserver.timing_context import TimingContext
@@ -44,7 +49,7 @@ class DebugSampleHistoryResult(object):


 class DebugSampleHistory:
-    EVENT_TYPE = "training_debug_image"
+    EVENT_TYPE = EventType.metrics_image

     def __init__(self, redis: StrictRedis, es: Elasticsearch):
         self.es = es
@@ -66,14 +71,13 @@ class DebugSampleHistory:
         if not state or state.task != task:
             raise errors.bad_request.InvalidScrollId(scroll_id=state_id)

-        es_index = get_index_name(company_id, self.EVENT_TYPE)
-        if not self.es.indices.exists(es_index):
+        if check_empty_data(self.es, company_id=company_id, event_type=self.EVENT_TYPE):
             return res

         image = self._get_next_for_current_iteration(
-            es_index=es_index, navigate_earlier=navigate_earlier, state=state
+            company_id=company_id, navigate_earlier=navigate_earlier, state=state
         ) or self._get_next_for_another_iteration(
-            es_index=es_index, navigate_earlier=navigate_earlier, state=state
+            company_id=company_id, navigate_earlier=navigate_earlier, state=state
         )
         if not image:
             return res
@@ -94,7 +98,7 @@ class DebugSampleHistory:
         res.max_iteration = var_state.max_iteration

     def _get_next_for_current_iteration(
-        self, es_index: str, navigate_earlier: bool, state: DebugSampleHistoryState
+        self, company_id: str, navigate_earlier: bool, state: DebugSampleHistoryState
     ) -> Optional[dict]:
         """
         Get the image for next (if navigated earlier is False) or previous variant sorted by name for the same iteration
@@ -127,7 +131,9 @@ class DebugSampleHistory:
         with translate_errors_context(), TimingContext(
             "es", "get_next_for_current_iteration"
         ):
-            es_res = self.es.search(index=es_index, body=es_req, routing=state.task)
+            es_res = search_company_events(
+                self.es, company_id=company_id, event_type=self.EVENT_TYPE, body=es_req
+            )

         hits = nested_get(es_res, ("hits", "hits"))
         if not hits:
@@ -136,7 +142,7 @@ class DebugSampleHistory:
         return hits[0]["_source"]

     def _get_next_for_another_iteration(
-        self, es_index: str, navigate_earlier: bool, state: DebugSampleHistoryState
+        self, company_id: str, navigate_earlier: bool, state: DebugSampleHistoryState
     ) -> Optional[dict]:
         """
         Get the image for the first variant for the next iteration (if navigate_earlier is set to False)
@@ -189,7 +195,9 @@ class DebugSampleHistory:
         with translate_errors_context(), TimingContext(
             "es", "get_next_for_another_iteration"
         ):
-            es_res = self.es.search(index=es_index, body=es_req, routing=state.task)
+            es_res = search_company_events(
+                self.es, company_id=company_id, event_type=self.EVENT_TYPE, body=es_req
+            )

         hits = nested_get(es_res, ("hits", "hits"))
         if not hits:
@@ -212,14 +220,13 @@ class DebugSampleHistory:
         If the iteration is not passed then get the latest event
         """
         res = DebugSampleHistoryResult()
-        es_index = get_index_name(company_id, self.EVENT_TYPE)
-        if not self.es.indices.exists(es_index):
+        if check_empty_data(self.es, company_id=company_id, event_type=self.EVENT_TYPE):
             return res

         def init_state(state_: DebugSampleHistoryState):
             state_.task = task
             state_.metric = metric
-            self._reset_variant_states(es_index, state=state_)
+            self._reset_variant_states(company_id=company_id, state=state_)

         def validate_state(state_: DebugSampleHistoryState):
             if state_.task != task or state_.metric != metric:
@@ -228,7 +235,7 @@ class DebugSampleHistory:
                     scroll_id=state_.id,
                 )
             if refresh:
-                self._reset_variant_states(es_index, state=state_)
+                self._reset_variant_states(company_id=company_id, state=state_)

         state: DebugSampleHistoryState
         with self.cache_manager.get_or_create_state(
@@ -271,7 +278,12 @@ class DebugSampleHistory:
         with translate_errors_context(), TimingContext(
             "es", "get_debug_image_for_variant"
         ):
-            es_res = self.es.search(index=es_index, body=es_req, routing=task)
+            es_res = search_company_events(
+                self.es,
+                company_id=company_id,
+                event_type=self.EVENT_TYPE,
+                body=es_req,
+            )

         hits = nested_get(es_res, ("hits", "hits"))
         if not hits:
@@ -282,9 +294,9 @@ class DebugSampleHistory:
             )
         return res

-    def _reset_variant_states(self, es_index, state: DebugSampleHistoryState):
+    def _reset_variant_states(self, company_id: str, state: DebugSampleHistoryState):
         variant_iterations = self._get_variant_iterations(
-            es_index=es_index, task=state.task, metric=state.metric
+            company_id=company_id, task=state.task, metric=state.metric
         )
         state.variant_states = [
             VariantState(name=var_name, min_iteration=min_iter, max_iteration=max_iter)
@@ -293,7 +305,7 @@ class DebugSampleHistory:

     def _get_variant_iterations(
         self,
-        es_index: str,
+        company_id: str,
         task: str,
         metric: str,
         variants: Optional[Sequence[str]] = None,
@@ -344,7 +356,9 @@ class DebugSampleHistory:
         with translate_errors_context(), TimingContext(
             "es", "get_debug_image_iterations"
         ):
-            es_res = self.es.search(index=es_index, body=es_req, routing=task)
+            es_res = search_company_events(
+                self.es, company_id=company_id, event_type=self.EVENT_TYPE, body=es_req
+            )

         def get_variant_data(variant_bucket: dict) -> Tuple[str, int, int]:
             variant = variant_bucket["key"]

@@ -13,7 +13,14 @@ from mongoengine import Q
 from nested_dict import nested_dict

 from apiserver.bll.event.debug_sample_history import DebugSampleHistory
-from apiserver.bll.event.event_common import EventType, EventSettings, get_index_name
+from apiserver.bll.event.event_common import (
+    EventType,
+    EventSettings,
+    get_index_name,
+    check_empty_data,
+    search_company_events,
+    delete_company_events,
+)
 from apiserver.bll.util import parallel_chunked_decorator
 from apiserver.database import utils as dbutils
 from apiserver.es_factory import es_factory
@@ -156,7 +163,7 @@ class EventBLL(object):
         }

         # for "log" events, don't assing custom _id - whatever is sent, is written (not overwritten)
-        if event_type != "log":
+        if event_type != EventType.task_log.value:
             es_action["_id"] = self._get_event_id(event)
         else:
             es_action["_id"] = dbutils.id()
@@ -389,10 +396,10 @@ class EventBLL(object):

     def scroll_task_events(
         self,
-        company_id,
-        task_id,
-        order,
-        event_type=None,
+        company_id: str,
+        task_id: str,
+        order: str,
+        event_type: EventType,
         batch_size=10000,
         scroll_id=None,
     ):
@@ -404,12 +411,7 @@ class EventBLL(object):
             es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
         else:
             size = min(batch_size, 10000)
-            if event_type is None:
-                event_type = "*"
-
-            es_index = get_index_name(company_id, event_type)
-
-            if not self.es.indices.exists(es_index):
+            if check_empty_data(self.es, company_id=company_id, event_type=event_type):
                 return [], None, 0

             es_req = {
@@ -419,15 +421,25 @@ class EventBLL(object):
         }

         with translate_errors_context(), TimingContext("es", "scroll_task_events"):
-            es_res = self.es.search(index=es_index, body=es_req, scroll="1h")
+            es_res = search_company_events(
+                self.es,
+                company_id=company_id,
+                event_type=event_type,
+                body=es_req,
+                scroll="1h",
+            )

         events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
         return events, next_scroll_id, total_events

     def get_last_iterations_per_event_metric_variant(
-        self, es_index: str, task_id: str, num_last_iterations: int, event_type: str
+        self,
+        company_id: str,
+        task_id: str,
+        num_last_iterations: int,
+        event_type: EventType,
     ):
-        if not self.es.indices.exists(es_index):
+        if check_empty_data(self.es, company_id=company_id, event_type=event_type):
             return []

         es_req: dict = {
@@ -461,13 +473,14 @@ class EventBLL(object):
             },
             "query": {"bool": {"must": [{"term": {"task": task_id}}]}},
         }
-        if event_type:
-            es_req["query"]["bool"]["must"].append({"term": {"type": event_type}})
+
         with translate_errors_context(), TimingContext(
             "es", "task_last_iter_metric_variant"
         ):
-            es_res = self.es.search(index=es_index, body=es_req)
+            es_res = search_company_events(
+                self.es, company_id=company_id, event_type=event_type, body=es_req
+            )

         if "aggregations" not in es_res:
             return []

@@ -494,9 +507,8 @@ class EventBLL(object):
         with translate_errors_context(), TimingContext("es", "get_task_events"):
             es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
         else:
-            event_type = "plot"
-            es_index = get_index_name(company_id, event_type)
-            if not self.es.indices.exists(es_index):
+            event_type = EventType.metrics_plot
+            if check_empty_data(self.es, company_id=company_id, event_type=event_type):
                 return TaskEventsResult()

             plot_valid_condition = {
@@ -519,7 +531,10 @@ class EventBLL(object):
             should = []
             for i, task_id in enumerate(tasks):
                 last_iters = self.get_last_iterations_per_event_metric_variant(
-                    es_index, task_id, last_iterations_per_plot, event_type
+                    company_id=company_id,
+                    task_id=task_id,
+                    num_last_iterations=last_iterations_per_plot,
+                    event_type=event_type,
                 )
                 if not last_iters:
                     continue
@@ -551,8 +566,13 @@ class EventBLL(object):
             }

         with translate_errors_context(), TimingContext("es", "get_task_plots"):
-            es_res = self.es.search(
-                index=es_index, body=es_req, ignore=404, scroll="1h",
+            es_res = search_company_events(
+                self.es,
+                company_id=company_id,
+                event_type=event_type,
+                body=es_req,
+                ignore=404,
+                scroll="1h",
             )

         events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
@@ -577,9 +597,9 @@ class EventBLL(object):

     def get_task_events(
         self,
-        company_id,
-        task_id,
-        event_type=None,
+        company_id: str,
+        task_id: str,
+        event_type: EventType,
         metric=None,
         variant=None,
         last_iter_count=None,
@@ -595,11 +615,8 @@ class EventBLL(object):
             es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
         else:
             task_ids = [task_id] if isinstance(task_id, six.string_types) else task_id
-            if event_type is None:
-                event_type = "*"
-
-            es_index = get_index_name(company_id, event_type)
-            if not self.es.indices.exists(es_index):
+            if check_empty_data(self.es, company_id=company_id, event_type=event_type):
                 return TaskEventsResult()

             must = []
@@ -614,7 +631,10 @@ class EventBLL(object):
             should = []
             for i, task_id in enumerate(task_ids):
                 last_iters = self.get_last_iters(
-                    es_index, task_id, event_type, last_iter_count
+                    company_id=company_id,
+                    event_type=event_type,
+                    task_id=task_id,
+                    iters=last_iter_count,
                 )
                 if not last_iters:
                     continue
@@ -642,8 +662,13 @@ class EventBLL(object):
             }

         with translate_errors_context(), TimingContext("es", "get_task_events"):
-            es_res = self.es.search(
-                index=es_index, body=es_req, ignore=404, scroll="1h",
+            es_res = search_company_events(
+                self.es,
+                company_id=company_id,
+                event_type=event_type,
+                body=es_req,
+                ignore=404,
+                scroll="1h",
            )

         events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
@@ -651,11 +676,10 @@ class EventBLL(object):
             events=events, next_scroll_id=next_scroll_id, total_events=total_events
         )

-    def get_metrics_and_variants(self, company_id, task_id, event_type):
-        es_index = get_index_name(company_id, event_type)
-        if not self.es.indices.exists(es_index):
+    def get_metrics_and_variants(
+        self, company_id: str, task_id: str, event_type: EventType
+    ):
+        if check_empty_data(self.es, company_id=company_id, event_type=event_type):
             return {}

         es_req = {
@@ -684,7 +708,9 @@ class EventBLL(object):
         with translate_errors_context(), TimingContext(
             "es", "events_get_metrics_and_variants"
         ):
-            es_res = self.es.search(index=es_index, body=es_req)
+            es_res = search_company_events(
+                self.es, company_id=company_id, event_type=event_type, body=es_req
+            )

         metrics = {}
         for metric_bucket in es_res["aggregations"]["metrics"].get("buckets"):
@@ -695,10 +721,9 @@ class EventBLL(object):

         return metrics

-    def get_task_latest_scalar_values(self, company_id, task_id):
-        es_index = get_index_name(company_id, "training_stats_scalar")
-        if not self.es.indices.exists(es_index):
+    def get_task_latest_scalar_values(self, company_id: str, task_id: str):
+        event_type = EventType.metrics_scalar
+        if check_empty_data(self.es, company_id=company_id, event_type=event_type):
             return {}

         es_req = {
@@ -753,7 +778,9 @@ class EventBLL(object):
         with translate_errors_context(), TimingContext(
             "es", "events_get_metrics_and_variants"
         ):
-            es_res = self.es.search(index=es_index, body=es_req)
+            es_res = search_company_events(
+                self.es, company_id=company_id, event_type=event_type, body=es_req
+            )

         metrics = []
         max_timestamp = 0
@@ -780,9 +807,8 @@ class EventBLL(object):
         return metrics, max_timestamp

     def get_vector_metrics_per_iter(self, company_id, task_id, metric, variant):
-        es_index = get_index_name(company_id, "training_stats_vector")
-        if not self.es.indices.exists(es_index):
+        event_type = EventType.metrics_vector
+        if check_empty_data(self.es, company_id=company_id, event_type=event_type):
             return [], []

         es_req = {
@@ -800,7 +826,9 @@ class EventBLL(object):
             "sort": ["iter"],
         }
         with translate_errors_context(), TimingContext("es", "task_stats_vector"):
-            es_res = self.es.search(index=es_index, body=es_req)
+            es_res = search_company_events(
+                self.es, company_id=company_id, event_type=event_type, body=es_req
+            )

         vectors = []
         iterations = []
@@ -810,8 +838,10 @@ class EventBLL(object):

         return iterations, vectors

-    def get_last_iters(self, es_index, task_id, event_type, iters):
-        if not self.es.indices.exists(es_index):
+    def get_last_iters(
+        self, company_id: str, event_type: EventType, task_id: str, iters: int
+    ):
+        if check_empty_data(self.es, company_id=company_id, event_type=event_type):
             return []

         es_req: dict = {
@@ -827,11 +857,12 @@ class EventBLL(object):
             },
             "query": {"bool": {"must": [{"term": {"task": task_id}}]}},
         }
-        if event_type:
-            es_req["query"]["bool"]["must"].append({"term": {"type": event_type}})
+
         with translate_errors_context(), TimingContext("es", "task_last_iter"):
-            es_res = self.es.search(index=es_index, body=es_req)
+            es_res = search_company_events(
+                self.es, company_id=company_id, event_type=event_type, body=es_req
+            )

         if "aggregations" not in es_res:
             return []

@@ -850,9 +881,14 @@ class EventBLL(object):
                 extra_msg, company=company_id, id=task_id
             )

-        es_index = get_index_name(company_id, "*")
         es_req = {"query": {"term": {"task": task_id}}}
         with translate_errors_context(), TimingContext("es", "delete_task_events"):
-            es_res = self.es.delete_by_query(index=es_index, body=es_req, refresh=True)
+            es_res = delete_company_events(
+                es=self.es,
+                company_id=company_id,
+                event_type=EventType.all,
+                body=es_req,
+                refresh=True,
+            )

         return es_res.get("deleted", 0)

@@ -13,6 +13,7 @@ class EventType(Enum):
     metrics_image = "training_debug_image"
     metrics_plot = "plot"
     task_log = "log"
+    all = "*"


 class EventSettings:
@@ -40,8 +41,8 @@ def get_index_name(company_id: str, event_type: str):
     return f"events-{event_type}-{company_id}"


-def check_empty_data(es: Elasticsearch, company_id: str, event_type: str) -> bool:
-    es_index = get_index_name(company_id, event_type)
+def check_empty_data(es: Elasticsearch, company_id: str, event_type: EventType) -> bool:
+    es_index = get_index_name(company_id, event_type.value)
     if not es.indices.exists(es_index):
         return True
     return False
@@ -50,16 +51,16 @@ def check_empty_data(es: Elasticsearch, company_id: str, event_type: EventType) -> bool:
 def search_company_events(
     es: Elasticsearch,
     company_id: Union[str, Sequence[str]],
-    event_type: str,
+    event_type: EventType,
     body: dict,
     **kwargs,
 ) -> dict:
-    es_index = get_index_name(company_id, event_type)
+    es_index = get_index_name(company_id, event_type.value)
     return es.search(index=es_index, body=body, **kwargs)


 def delete_company_events(
-    es: Elasticsearch, company_id: str, event_type: str, body: dict, **kwargs
+    es: Elasticsearch, company_id: str, event_type: EventType, body: dict, **kwargs
 ) -> dict:
-    es_index = get_index_name(company_id, event_type)
+    es_index = get_index_name(company_id, event_type.value)
     return es.delete_by_query(index=es_index, body=body, **kwargs)
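
The resulting call-site pattern, in miniature. This is a hedged sketch: scalar_event_count is a hypothetical helper, not part of this commit; only the enum/helper usage mirrors the hunks above and below.

    def scalar_event_count(es: Elasticsearch, company_id: str, task_id: str) -> int:
        # Hypothetical illustration of the new pattern: pick the enum member,
        # bail out early if the backing index is missing, then search through
        # the helper instead of computing es_index by hand.
        event_type = EventType.metrics_scalar
        if check_empty_data(es, company_id=company_id, event_type=event_type):
            return 0
        es_res = search_company_events(
            es,
            company_id=company_id,
            event_type=event_type,
            body={"size": 0, "query": {"term": {"task": task_id}}},
        )
        return es_res["hits"]["total"]["value"]
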
@@ -10,7 +10,12 @@ from elasticsearch import Elasticsearch
 from mongoengine import Q

 from apiserver.apierrors import errors
-from apiserver.bll.event.event_common import EventType, get_index_name, EventSettings
+from apiserver.bll.event.event_common import (
+    EventType,
+    EventSettings,
+    search_company_events,
+    check_empty_data,
+)
 from apiserver.bll.event.scalar_key import ScalarKey, ScalarKeyEnum
 from apiserver.config_repo import config
 from apiserver.database.errors import translate_errors_context
@@ -36,31 +41,40 @@ class EventMetrics:
         The amount of points in each histogram should not exceed
         the requested samples
         """
-        es_index = get_index_name(company_id, "training_stats_scalar")
-        if not self.es.indices.exists(es_index):
+        event_type = EventType.metrics_scalar
+        if check_empty_data(self.es, company_id=company_id, event_type=event_type):
             return {}

         return self._get_scalar_average_per_iter_core(
-            task_id, es_index, samples, ScalarKey.resolve(key)
+            task_id, company_id, event_type, samples, ScalarKey.resolve(key)
         )

     def _get_scalar_average_per_iter_core(
         self,
         task_id: str,
-        es_index: str,
+        company_id: str,
+        event_type: EventType,
         samples: int,
         key: ScalarKey,
         run_parallel: bool = True,
     ) -> dict:
         intervals = self._get_task_metric_intervals(
-            es_index=es_index, task_id=task_id, samples=samples, field=key.field
+            company_id=company_id,
+            event_type=event_type,
+            task_id=task_id,
+            samples=samples,
+            field=key.field,
         )
         if not intervals:
             return {}
         interval_groups = self._group_task_metric_intervals(intervals)

         get_scalar_average = partial(
-            self._get_scalar_average, task_id=task_id, es_index=es_index, key=key
+            self._get_scalar_average,
+            task_id=task_id,
+            company_id=company_id,
+            event_type=event_type,
+            key=key,
         )
         if run_parallel:
             with ThreadPoolExecutor(max_workers=EventSettings.max_workers) as pool:
@@ -110,13 +124,15 @@ class EventMetrics:
                 "only tasks from the same company are supported"
             )

-        es_index = get_index_name(next(iter(companies)), "training_stats_scalar")
-        if not self.es.indices.exists(es_index):
+        event_type = EventType.metrics_scalar
+        company_id = next(iter(companies))
+        if check_empty_data(self.es, company_id=company_id, event_type=event_type):
             return {}

         get_scalar_average_per_iter = partial(
             self._get_scalar_average_per_iter_core,
-            es_index=es_index,
+            company_id=company_id,
+            event_type=event_type,
             samples=samples,
             key=ScalarKey.resolve(key),
             run_parallel=False,
@@ -175,7 +191,12 @@ class EventMetrics:
         return metric_interval_groups

     def _get_task_metric_intervals(
-        self, es_index, task_id: str, samples: int, field: str = "iter"
+        self,
+        company_id: str,
+        event_type: EventType,
+        task_id: str,
+        samples: int,
+        field: str = "iter",
     ) -> Sequence[MetricInterval]:
         """
         Calculate interval per task metric variant so that the resulting
@@ -212,7 +233,9 @@ class EventMetrics:
         }

         with translate_errors_context(), TimingContext("es", "task_stats_get_interval"):
-            es_res = self.es.search(index=es_index, body=es_req)
+            es_res = search_company_events(
+                self.es, company_id=company_id, event_type=event_type, body=es_req,
+            )

         aggs_result = es_res.get("aggregations")
         if not aggs_result:
@@ -255,7 +278,8 @@ class EventMetrics:
         self,
         metrics_interval: MetricIntervalGroup,
         task_id: str,
-        es_index: str,
+        company_id: str,
+        event_type: EventType,
         key: ScalarKey,
     ) -> Sequence[MetricData]:
         """
@@ -283,7 +307,11 @@ class EventMetrics:
             }
         }
         aggs_result = self._query_aggregation_for_task_metrics(
-            es_index, aggs=aggs, task_id=task_id, metrics=metrics
+            company_id=company_id,
+            event_type=event_type,
+            aggs=aggs,
+            task_id=task_id,
+            metrics=metrics,
         )

         if not aggs_result:
@@ -314,7 +342,8 @@ class EventMetrics:

     def _query_aggregation_for_task_metrics(
         self,
-        es_index: str,
+        company_id: str,
+        event_type: EventType,
         aggs: dict,
         task_id: str,
         metrics: Sequence[Tuple[str, str]],
@@ -345,7 +374,9 @@ class EventMetrics:
         }

         with translate_errors_context(), TimingContext("es", "task_stats_scalar"):
-            es_res = self.es.search(index=es_index, body=es_req)
+            es_res = search_company_events(
+                self.es, company_id=company_id, event_type=event_type, body=es_req,
+            )

         return es_res.get("aggregations")

@@ -356,30 +387,26 @@ class EventMetrics:
         For the requested tasks return all the metrics that
         reported events of the requested types
         """
-        es_index = get_index_name(company_id, event_type.value)
-        if not self.es.indices.exists(es_index):
+        if check_empty_data(self.es, company_id, event_type):
             return {}

         with ThreadPoolExecutor(EventSettings.max_workers) as pool:
             res = pool.map(
                 partial(
-                    self._get_task_metrics, es_index=es_index, event_type=event_type,
+                    self._get_task_metrics,
+                    company_id=company_id,
+                    event_type=event_type,
                 ),
                 task_ids,
             )
         return list(zip(task_ids, res))

-    def _get_task_metrics(self, task_id, es_index, event_type: EventType) -> Sequence:
+    def _get_task_metrics(
+        self, task_id: str, company_id: str, event_type: EventType
+    ) -> Sequence:
         es_req = {
             "size": 0,
-            "query": {
-                "bool": {
-                    "must": [
-                        {"term": {"task": task_id}},
-                        {"term": {"type": event_type.value}},
-                    ]
-                }
-            },
+            "query": {"bool": {"must": [{"term": {"task": task_id}}]}},
             "aggs": {
                 "metrics": {
                     "terms": {
@@ -392,7 +419,9 @@ class EventMetrics:
         }

         with translate_errors_context(), TimingContext("es", "_get_task_metrics"):
-            es_res = self.es.search(index=es_index, body=es_req)
+            es_res = search_company_events(
+                self.es, company_id=company_id, event_type=event_type, body=es_req
+            )

         return [
             metric["key"]

@@ -3,7 +3,11 @@ from typing import Optional, Tuple, Sequence
 import attr
 from elasticsearch import Elasticsearch

-from apiserver.bll.event.event_common import check_empty_data, search_company_events
+from apiserver.bll.event.event_common import (
+    check_empty_data,
+    search_company_events,
+    EventType,
+)
 from apiserver.database.errors import translate_errors_context
 from apiserver.timing_context import TimingContext
@@ -16,7 +20,7 @@ class TaskEventsResult:


 class LogEventsIterator:
-    EVENT_TYPE = "log"
+    EVENT_TYPE = EventType.task_log

     def __init__(self, es: Elasticsearch):
         self.es = es
@@ -75,7 +79,6 @@ class LogEventsIterator:
             company_id=company_id,
             event_type=self.EVENT_TYPE,
             body=es_req,
-            routing=task_id,
         )
         hits = es_result["hits"]["hits"]
         hits_total = es_result["hits"]["total"]["value"]
@@ -102,7 +105,6 @@ class LogEventsIterator:
             company_id=company_id,
             event_type=self.EVENT_TYPE,
             body=es_req,
-            routing=task_id,
         )
         last_second_hits = es_result["hits"]["hits"]
         if not last_second_hits or len(last_second_hits) < 2:

@@ -32,6 +32,7 @@ from furl import furl
 from mongoengine import Q

 from apiserver.bll.event import EventBLL
+from apiserver.bll.event.event_common import EventType
 from apiserver.bll.task.artifacts import get_artifact_id
 from apiserver.bll.task.param_utils import (
     split_param_name,
@@ -532,19 +533,22 @@ class PrePopulate:
         scroll_id = None
         while True:
             res = cls.event_bll.get_task_events(
-                task.company, task.id, scroll_id=scroll_id
+                company_id=task.company,
+                task_id=task.id,
+                event_type=EventType.all,
+                scroll_id=scroll_id,
             )
             if not res.events:
                 break
             scroll_id = res.next_scroll_id
             for event in res.events:
                 event_type = event.get("type")
-                if event_type == "training_debug_image":
+                if event_type == EventType.metrics_image.value:
                     url = cls._get_fixed_url(event.get("url"))
                     if url:
                         event["url"] = url
                         artifacts.append(url)
-                elif event_type == "plot":
+                elif event_type == EventType.metrics_plot.value:
                     plot_str: str = event.get("plot_str", "")
                     for match in cls.img_source_regex.findall(plot_str):
                         url = cls._get_fixed_url(match)

@@ -19,7 +19,7 @@ from apiserver.apimodels.events import (
     NextDebugImageSampleRequest,
 )
 from apiserver.bll.event import EventBLL
-from apiserver.bll.event.event_common import get_index_name
+from apiserver.bll.event.event_common import EventType
 from apiserver.bll.task import TaskBLL
 from apiserver.service_repo import APICall, endpoint
 from apiserver.utilities import json
@@ -63,7 +63,7 @@ def get_task_log_v1_5(call, company_id, _):
         task.get_index_company(),
         task_id,
         order,
-        event_type="log",
+        event_type=EventType.task_log,
         batch_size=batch_size,
         scroll_id=scroll_id,
     )
@@ -90,7 +90,7 @@ def get_task_log_v1_7(call, company_id, _):
         company_id=task.get_index_company(),
         task_id=task_id,
         order=scroll_order,
-        event_type="log",
+        event_type=EventType.task_log,
         batch_size=batch_size,
         scroll_id=scroll_id,
     )
@@ -118,11 +118,9 @@ def get_task_log(call, company_id, request: LogEventsRequest):
         from_timestamp=request.from_timestamp,
     )

-    if (
-        request.order and (
+    if request.order and (
         (request.navigate_earlier and request.order == LogOrderEnum.asc)
         or (not request.navigate_earlier and request.order == LogOrderEnum.desc)
-        )
     ):
         res.events.reverse()

@@ -182,7 +180,7 @@ def download_task_log(call, company_id, _):
         task.get_index_company(),
         task_id,
         order="asc",
-        event_type="log",
+        event_type=EventType.task_log,
         batch_size=batch_size,
         scroll_id=scroll_id,
     )
@@ -219,7 +217,7 @@ def get_vector_metrics_and_variants(call, company_id, _):
     )[0]
     call.result.data = dict(
         metrics=event_bll.get_metrics_and_variants(
-            task.get_index_company(), task_id, "training_stats_vector"
+            task.get_index_company(), task_id, EventType.metrics_vector
         )
     )

@@ -232,7 +230,7 @@ def get_scalar_metrics_and_variants(call, company_id, _):
     )[0]
     call.result.data = dict(
         metrics=event_bll.get_metrics_and_variants(
-            task.get_index_company(), task_id, "training_stats_scalar"
+            task.get_index_company(), task_id, EventType.metrics_scalar
        )
    )

@@ -272,7 +270,7 @@ def get_task_events(call, company_id, _):
         task.get_index_company(),
         task_id,
         sort=[{"timestamp": {"order": order}}],
-        event_type=event_type,
+        event_type=EventType(event_type) if event_type else EventType.all,
         scroll_id=scroll_id,
         size=batch_size,
     )
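
Note on the hunk above: endpoints that still receive a plain string from the API now convert it at the boundary with EventType(event_type). This relies on standard Python enum lookup by value, not on code from this commit:

    # An unknown string now fails fast with ValueError instead of silently
    # producing a non-existent index name.
    assert EventType("plot") is EventType.metrics_plot
    assert EventType("log") is EventType.task_log
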
@@ -297,7 +295,7 @@ def get_scalar_metric_data(call, company_id, _):
     result = event_bll.get_task_events(
         task.get_index_company(),
         task_id,
-        event_type="training_stats_scalar",
+        event_type=EventType.metrics_scalar,
         sort=[{"iter": {"order": "desc"}}],
         metric=metric,
         scroll_id=scroll_id,
@@ -321,8 +319,9 @@ def get_task_latest_scalar_values(call, company_id, _):
     metrics, last_timestamp = event_bll.get_task_latest_scalar_values(
         index_company, task_id
     )
-    es_index = get_index_name(index_company, "*")
-    last_iters = event_bll.get_last_iters(es_index, task_id, None, 1)
+    last_iters = event_bll.get_last_iters(
+        company_id=company_id, event_type=EventType.all, task_id=task_id, iters=1
+    )
     call.result.data = dict(
         metrics=metrics,
         last_iter=last_iters[0] if last_iters else 0,
@@ -344,7 +343,10 @@ def scalar_metrics_iter_histogram(
         company_id, request.task, allow_public=True, only=("company", "company_origin")
     )[0]
     metrics = event_bll.metrics.get_scalar_metrics_average_per_iter(
-        task.get_index_company(), task_id=request.task, samples=request.samples, key=request.key
+        task.get_index_company(),
+        task_id=request.task,
+        samples=request.samples,
+        key=request.key,
     )
     call.result.data = metrics

@@ -394,7 +396,7 @@ def get_multi_task_plots_v1_7(call, company_id, _):
     result = event_bll.get_task_events(
         next(iter(companies)),
         task_ids,
-        event_type="plot",
+        event_type=EventType.metrics_plot,
         sort=[{"iter": {"order": "desc"}}],
         size=10000,
         scroll_id=scroll_id,
@@ -436,7 +438,7 @@ def get_multi_task_plots(call, company_id, req_model):
     result = event_bll.get_task_events(
         next(iter(companies)),
         task_ids,
-        event_type="plot",
+        event_type=EventType.metrics_plot,
         sort=[{"iter": {"order": "desc"}}],
         last_iter_count=iters,
         scroll_id=scroll_id,
@@ -476,7 +478,7 @@ def get_task_plots_v1_7(call, company_id, _):
     result = event_bll.get_task_events(
         task.get_index_company(),
         task_id,
-        event_type="plot",
+        event_type=EventType.metrics_plot,
         sort=[{"iter": {"order": "desc"}}],
         size=10000,
         scroll_id=scroll_id,
@@ -539,7 +541,7 @@ def get_debug_images_v1_7(call, company_id, _):
     result = event_bll.get_task_events(
         task.get_index_company(),
         task_id,
-        event_type="training_debug_image",
+        event_type=EventType.metrics_image,
         sort=[{"iter": {"order": "desc"}}],
         size=10000,
         scroll_id=scroll_id,
@@ -568,7 +570,7 @@ def get_debug_images_v1_8(call, company_id, _):
     result = event_bll.get_task_events(
         task.get_index_company(),
         task_id,
-        event_type="training_debug_image",
+        event_type=EventType.metrics_image,
         sort=[{"iter": {"order": "desc"}}],
         last_iter_count=iters,
         scroll_id=scroll_id,
@@ -594,7 +596,10 @@ def get_debug_images_v1_8(call, company_id, _):
 def get_debug_images(call, company_id, request: DebugImagesRequest):
     task_ids = {m.task for m in request.metrics}
     tasks = task_bll.assert_exists(
-        company_id, task_ids=task_ids, allow_public=True, only=("company", "company_origin")
+        company_id,
+        task_ids=task_ids,
+        allow_public=True,
+        only=("company", "company_origin"),
     )

     companies = {t.get_index_company() for t in tasks}
@@ -662,7 +667,7 @@ def next_debug_image_sample(call, company_id, request: NextDebugImageSampleReque
         company_id=task.company,
         task=request.task,
         state_id=request.scroll_id,
-        navigate_earlier=request.navigate_earlier
+        navigate_earlier=request.navigate_earlier,
     )
     call.result.data = attr.asdict(res, recurse=False)

@@ -670,7 +675,10 @@ def next_debug_image_sample(call, company_id, request: NextDebugImageSampleReque
 @endpoint("events.get_task_metrics", request_data_model=TaskMetricsRequest)
 def get_tasks_metrics(call: APICall, company_id, request: TaskMetricsRequest):
     task = task_bll.assert_exists(
-        company_id, task_ids=request.tasks, allow_public=True, only=("company", "company_origin")
+        company_id,
+        task_ids=request.tasks,
+        allow_public=True,
+        only=("company", "company_origin"),
     )[0]
     res = event_bll.metrics.get_tasks_metrics(
         task.get_index_company(), task_ids=request.tasks, event_type=request.event_type