mirror of https://github.com/clearml/clearml-server
synced 2025-03-03 10:43:10 +00:00

Support task log retrieval with no scroll

This commit is contained in:
parent d9bdebefc7
commit 35a11db58e
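For orientation (not part of the diff below): the commit adds a scroll-based `events.get_task_log` endpoint (API version >= 2.7) backed by a Redis-cached `LogEventsScrollState`, replacing the per-request Elasticsearch scroll. A minimal client-side sketch of paging through the new endpoint follows; the base URL, version prefix, authentication, and response envelope are assumptions for illustration only and are not defined by this commit:

import requests

BASE_URL = "http://localhost:8008/v2.7"  # assumed apiserver address and version prefix


def iter_task_log(task_id: str, batch_size: int = 500, session: requests.Session = None):
    """Yield a task's log events newest-first, passing scroll_id back on every call."""
    session = session or requests.Session()  # auth headers omitted (assumption)
    scroll_id = None
    while True:
        resp = session.post(
            f"{BASE_URL}/events.get_task_log",
            json={
                "task": task_id,
                "batch_size": batch_size,
                "navigate_earlier": True,  # walk from the latest events backwards
                "refresh": False,
                "scroll_id": scroll_id,
            },
        )
        resp.raise_for_status()
        data = resp.json()["data"]  # assumed response envelope
        if not data["events"]:
            return
        yield from data["events"]
        scroll_id = data["scroll_id"]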
@@ -13,6 +13,7 @@ from luqum.parser import parser, ParseError
from validators import email as email_validator, domain as domain_validator

from apierrors import errors
from utilities.json import loads, dumps


def make_default(field_cls, default_value):
@@ -213,3 +214,12 @@ class StringEnum(Enum):
    # noinspection PyMethodParameters
    def _generate_next_value_(name, start, count, last_values):
        return name


class JsonSerializableMixin:
    def to_json(self: ModelBase):
        return dumps(self.to_struct())

    @classmethod
    def from_json(cls: Type[ModelBase], s):
        return cls(**loads(s))
@@ -40,6 +40,14 @@ class DebugImagesRequest(Base):
    scroll_id: str = StringField()


class LogEventsRequest(Base):
    task: str = StringField(required=True)
    batch_size: int = IntField(default=500)
    navigate_earlier: bool = BoolField(default=True)
    refresh: bool = BoolField(default=False)
    scroll_id: str = StringField()


class IterationEvents(Base):
    iter: int = IntField()
    events: Sequence[dict] = ListField(items_types=dict)
@@ -1,4 +1,3 @@
import json
from enum import Enum

import six
@@ -13,7 +12,7 @@ from jsonmodels.fields import (
)
from jsonmodels.models import Base

from apimodels import make_default, ListField, EnumField
from apimodels import make_default, ListField, EnumField, JsonSerializableMixin

DEFAULT_TIMEOUT = 10 * 60

@@ -61,7 +60,7 @@ class IdNameEntry(Base):
    name = StringField()


class WorkerEntry(Base):
class WorkerEntry(Base, JsonSerializableMixin):
    key = StringField()  # not required due to migration issues
    id = StringField(required=True)
    user = EmbeddedField(IdNameEntry)
@@ -75,13 +74,6 @@ class WorkerEntry(Base):
    last_activity_time = DateTimeField(required=True)
    last_report_time = DateTimeField()

    def to_json(self):
        return json.dumps(self.to_struct())

    @classmethod
    def from_json(cls, s):
        return cls(**json.loads(s))


class CurrentTaskEntry(IdNameEntry):
    running_time = IntField()
@@ -3,27 +3,25 @@ from concurrent.futures.thread import ThreadPoolExecutor
from functools import partial
from itertools import chain
from operator import attrgetter, itemgetter
from typing import Sequence, Tuple, Optional, Mapping

import attr
import dpath
from boltons.iterutils import bucketize
from elasticsearch import Elasticsearch
from jsonmodels.fields import StringField, ListField, IntField
from jsonmodels.models import Base
from redis import StrictRedis
from typing import Sequence, Tuple, Optional, Mapping

import database
from apierrors import errors
from bll.redis_cache_manager import RedisCacheManager
from apimodels import JsonSerializableMixin
from bll.event.event_metrics import EventMetrics
from bll.redis_cache_manager import RedisCacheManager
from config import config
from database.errors import translate_errors_context
from jsonmodels.models import Base
from jsonmodels.fields import StringField, ListField, IntField

from database.model.task.metrics import MetricEventStats
from database.model.task.task import Task
from timing_context import TimingContext
from utilities.json import loads, dumps


class VariantScrollState(Base):
@@ -45,17 +43,10 @@ class MetricScrollState(Base):
        self.last_min_iter = self.last_max_iter = None


class DebugImageEventsScrollState(Base):
class DebugImageEventsScrollState(Base, JsonSerializableMixin):
    id: str = StringField(required=True)
    metrics: Sequence[MetricScrollState] = ListField([MetricScrollState])

    def to_json(self):
        return dumps(self.to_struct())

    @classmethod
    def from_json(cls, s):
        return cls(**loads(s))


@attr.s(auto_attribs=True)
class DebugImagesResult(object):
@@ -65,7 +56,12 @@ class DebugImagesResult(object):

class DebugImagesIterator:
    EVENT_TYPE = "training_debug_image"
    STATE_EXPIRATION_SECONDS = 3600

    @property
    def state_expiration_sec(self):
        return config.get(
            f"services.events.events_retrieval.state_expiration_sec", 3600
        )

    @property
    def _max_workers(self):
@@ -76,7 +72,7 @@ class DebugImagesIterator:
        self.cache_manager = RedisCacheManager(
            state_class=DebugImageEventsScrollState,
            redis=redis,
            expiration_interval=self.STATE_EXPIRATION_SECONDS,
            expiration_interval=self.state_expiration_sec,
        )

    def get_task_events(
@@ -92,27 +88,31 @@ class DebugImagesIterator:
        if not self.es.indices.exists(es_index):
            return DebugImagesResult()

        unique_metrics = set(metrics)
        state = self.cache_manager.get_state(state_id) if state_id else None
        if not state:
            state = DebugImageEventsScrollState(
                id=database.utils.id(),
                metrics=self._init_metric_states(es_index, list(unique_metrics)),
            )
        else:
            state_metrics = set((m.task, m.name) for m in state.metrics)
            if state_metrics != unique_metrics:
                raise errors.bad_request.InvalidScrollId(
                    "while getting debug images events", scroll_id=state_id
                )
        def init_state(state_: DebugImageEventsScrollState):
            unique_metrics = set(metrics)
            state_.metrics = self._init_metric_states(es_index, list(unique_metrics))

        def validate_state(state_: DebugImageEventsScrollState):
            """
            Validate that the metrics stored in the state are the same
            as requested in the current call.
            Refresh the state if requested
            """
            state_metrics = set((m.task, m.name) for m in state_.metrics)
            if state_metrics != set(metrics):
                raise errors.bad_request.InvalidScrollId(
                    "Task metrics stored in the state do not match the passed ones",
                    scroll_id=state_.id,
                )
            if refresh:
                self._reinit_outdated_metric_states(company_id, es_index, state)
                for metric_state in state.metrics:
                self._reinit_outdated_metric_states(company_id, es_index, state_)
                for metric_state in state_.metrics:
                    metric_state.reset()

        res = DebugImagesResult(next_scroll_id=state.id)
        try:
        with self.cache_manager.get_or_create_state(
            state_id=state_id, init_state=init_state, validate_state=validate_state
        ) as state:
            res = DebugImagesResult(next_scroll_id=state.id)
            with ThreadPoolExecutor(self._max_workers) as pool:
                res.metric_events = list(
                    pool.map(
@@ -125,10 +125,8 @@ class DebugImagesIterator:
                        state.metrics,
                    )
                )
        finally:
            self.cache_manager.set_state(state)

        return res
            return res

    def _reinit_outdated_metric_states(
        self, company_id, es_index, state: DebugImageEventsScrollState
@@ -5,7 +5,6 @@ from datetime import datetime
from operator import attrgetter
from typing import Sequence

import attr
import six
from elasticsearch import helpers
from mongoengine import Q
@@ -16,6 +15,7 @@ import es_factory
from apierrors import errors
from bll.event.debug_images_iterator import DebugImagesIterator
from bll.event.event_metrics import EventMetrics, EventType
from bll.event.log_events_iterator import LogEventsIterator, TaskEventsResult
from bll.task import TaskBLL
from config import config
from database.errors import translate_errors_context
@@ -29,13 +29,6 @@ EVENT_TYPES = set(map(attrgetter("value"), EventType))
LOCKED_TASK_STATUSES = (TaskStatus.publishing, TaskStatus.published)


@attr.s(auto_attribs=True)
class TaskEventsResult(object):
    total_events: int = 0
    next_scroll_id: str = None
    events: list = attr.ib(factory=list)


class EventBLL(object):
    id_fields = ("task", "iter", "metric", "variant", "key")

@@ -47,6 +40,7 @@ class EventBLL(object):
        )
        self.redis = redis or redman.connection("apiserver")
        self.debug_images_iterator = DebugImagesIterator(es=self.es, redis=self.redis)
        self.log_events_iterator = LogEventsIterator(es=self.es, redis=self.redis)

    @property
    def metrics(self) -> EventMetrics:
server/bll/event/log_events_iterator.py (new file, 169 lines)
@@ -0,0 +1,169 @@
from typing import Optional, Tuple, Sequence

import attr
from elasticsearch import Elasticsearch
from jsonmodels.fields import StringField, IntField
from jsonmodels.models import Base
from redis import StrictRedis

from apierrors import errors
from apimodels import JsonSerializableMixin
from bll.event.event_metrics import EventMetrics
from bll.redis_cache_manager import RedisCacheManager
from config import config
from database.errors import translate_errors_context
from timing_context import TimingContext


class LogEventsScrollState(Base, JsonSerializableMixin):
    id: str = StringField(required=True)
    task: str = StringField(required=True)
    last_min_timestamp: Optional[int] = IntField()
    last_max_timestamp: Optional[int] = IntField()

    def reset(self):
        """Reset the scrolling state"""
        self.last_min_timestamp = self.last_max_timestamp = None


@attr.s(auto_attribs=True)
class TaskEventsResult:
    total_events: int = 0
    next_scroll_id: str = None
    events: list = attr.Factory(list)


class LogEventsIterator:
    EVENT_TYPE = "log"

    @property
    def state_expiration_sec(self):
        return config.get(
            f"services.events.events_retrieval.state_expiration_sec", 3600
        )

    def __init__(self, redis: StrictRedis, es: Elasticsearch):
        self.es = es
        self.cache_manager = RedisCacheManager(
            state_class=LogEventsScrollState,
            redis=redis,
            expiration_interval=self.state_expiration_sec,
        )

    def get_task_events(
        self,
        company_id: str,
        task_id: str,
        batch_size: int,
        navigate_earlier: bool = True,
        refresh: bool = False,
        state_id: str = None,
    ) -> TaskEventsResult:
        es_index = EventMetrics.get_index_name(company_id, self.EVENT_TYPE)
        if not self.es.indices.exists(es_index):
            return TaskEventsResult()

        def init_state(state_: LogEventsScrollState):
            state_.task = task_id

        def validate_state(state_: LogEventsScrollState):
            """
            Checks that the task id stored in the state
            is equal to the one passed with the current call
            Refresh the state if requested
            """
            if state_.task != task_id:
                raise errors.bad_request.InvalidScrollId(
                    "Task stored in the state does not match the passed one",
                    scroll_id=state_.id,
                )
            if refresh:
                state_.reset()

        with self.cache_manager.get_or_create_state(
            state_id=state_id, init_state=init_state, validate_state=validate_state,
        ) as state:
            res = TaskEventsResult(next_scroll_id=state.id)
            res.events, res.total_events = self._get_events(
                es_index=es_index,
                batch_size=batch_size,
                navigate_earlier=navigate_earlier,
                state=state,
            )
            return res

    def _get_events(
        self,
        es_index,
        batch_size: int,
        navigate_earlier: bool,
        state: LogEventsScrollState,
    ) -> Tuple[Sequence[dict], int]:
        """
        Return up to 'batch size' events starting from the previous timestamp either in the
        direction of earlier events (navigate_earlier=True) or in the direction of later events.
        If last_min_timestamp and last_max_timestamp are not set then start either from latest or earliest.
        For the last timestamp all the events are brought (even if the resulting size
        exceeds batch_size) so that this timestamp events will not be lost between the calls.
        In case any events were received update 'last_min_timestamp' and 'last_max_timestamp'
        """

        # retrieve the next batch of events
        es_req = {
            "size": batch_size,
            "query": {"term": {"task": state.task}},
            "sort": {"timestamp": "desc" if navigate_earlier else "asc"},
        }

        if navigate_earlier and state.last_min_timestamp is not None:
            es_req["search_after"] = [state.last_min_timestamp]
        elif not navigate_earlier and state.last_max_timestamp is not None:
            es_req["search_after"] = [state.last_max_timestamp]

        with translate_errors_context(), TimingContext("es", "get_task_events"):
            es_result = self.es.search(index=es_index, body=es_req, routing=state.task)
            hits = es_result["hits"]["hits"]
            hits_total = es_result["hits"]["total"]
            if not hits:
                return [], hits_total

            events = [hit["_source"] for hit in hits]
            if navigate_earlier:
                state.last_max_timestamp = events[0]["timestamp"]
                state.last_min_timestamp = events[-1]["timestamp"]
            else:
                state.last_min_timestamp = events[0]["timestamp"]
                state.last_max_timestamp = events[-1]["timestamp"]

            # retrieve the events that match the last event timestamp
            # but did not make it into the previous call due to batch_size limitation
            es_req = {
                "size": 10000,
                "query": {
                    "bool": {
                        "must": [
                            {"term": {"task": state.task}},
                            {"term": {"timestamp": events[-1]["timestamp"]}},
                        ]
                    }
                },
            }
            es_result = self.es.search(index=es_index, body=es_req, routing=state.task)
            hits = es_result["hits"]["hits"]
            if not hits or len(hits) < 2:
                # if only one element is returned for the last timestamp
                # then it is already present in the events
                return events, hits_total

            last_events = [hit["_source"] for hit in es_result["hits"]["hits"]]
            already_present_ids = set(ev["_id"] for ev in events)

            # return the list merged from original query results +
            # leftovers from the last timestamp
            return (
                [
                    *events,
                    *(ev for ev in last_events if ev["_id"] not in already_present_ids),
                ],
                hits_total,
            )
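A note on the pagination above: `_get_events` sorts by `timestamp` and continues from the stored boundary via Elasticsearch `search_after`, then issues a second query for every event that shares the boundary timestamp so nothing is dropped when one timestamp straddles two batches. A standalone sketch of that merge step, illustrative only and not code from the commit, assuming ES hits shaped like `{"_id": ..., "_source": {...}}`:

def merge_boundary_events(page_hits: list, boundary_hits: list) -> list:
    """Merge a page of ES hits with all hits sharing the page's last timestamp,
    skipping documents already present in the page (matched by ES document id)."""
    seen_ids = {hit["_id"] for hit in page_hits}
    merged = [hit["_source"] for hit in page_hits]
    merged.extend(
        hit["_source"] for hit in boundary_hits if hit["_id"] not in seen_ids
    )
    return merged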
@@ -1,15 +1,21 @@
from typing import Optional, TypeVar, Generic, Type
from contextlib import contextmanager
from typing import Optional, TypeVar, Generic, Type, Callable

from redis import StrictRedis

import database
from timing_context import TimingContext

T = TypeVar("T")


def _do_nothing(_: T):
    return


class RedisCacheManager(Generic[T]):
    """
    Class for store/retreive of state objects from redis
    Class for store/retrieve of state objects from redis

    self.state_class - class of the state
    self.redis - instance of redis
@@ -42,3 +48,32 @@ class RedisCacheManager(Generic[T]):

    def _get_redis_key(self, state_id):
        return f"{self.state_class}/{state_id}"

    @contextmanager
    def get_or_create_state(
        self,
        state_id=None,
        init_state: Callable[[T], None] = _do_nothing,
        validate_state: Callable[[T], None] = _do_nothing,
    ):
        """
        Try to retrieve state with the given id from the Redis cache if yes then validates it
        If no then create a new one with randomly generated id
        Yield the state and write it back to redis once the user code block exits
        :param state_id: id of the state to retrieve
        :param init_state: user callback to init the newly created state
            If not passed then no init except for the id generation is done
        :param validate_state: user callback to validate the state if retrieved from cache
            Should throw an exception if the state is not valid. If not passed then no validation is done
        """
        state = self.get_state(state_id) if state_id else None
        if state:
            validate_state(state)
        else:
            state = self.state_class(id=database.utils.id())
            init_state(state)

        try:
            yield state
        finally:
            self.set_state(state)
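The new `get_or_create_state` context manager is the building block that both `DebugImagesIterator` and `LogEventsIterator` use in this commit. A minimal usage sketch; `MyScrollState` and `_fetch` are hypothetical names, not part of the diff:

def get_items(self, key: str, state_id: str = None):
    def init_state(state_: MyScrollState):
        state_.key = key  # populate a freshly created state

    def validate_state(state_: MyScrollState):
        if state_.key != key:  # reject a scroll_id issued for a different key
            raise ValueError("scroll_id does not match the requested key")

    # The state is loaded (or created with a new id), yielded, and written back
    # to Redis when the block exits.
    with self.cache_manager.get_or_create_state(
        state_id=state_id, init_state=init_state, validate_state=validate_state
    ) as state:
        return self._fetch(state)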
@@ -6,4 +6,8 @@ ignore_iteration {

# max number of concurrent queries to ES when calculating events metrics
# should not exceed the amount of concurrent connections set in the ES driver
max_metrics_concurrency: 4
max_metrics_concurrency: 4

events_retrieval {
    state_expiration_sec: 3600
}
@@ -258,7 +258,6 @@
            properties {
                added { type: integer }
                errors { type: integer }
                errors_info { type: object }
            }
        }
    }
@@ -11,6 +11,7 @@ from apimodels.events import (
    MetricEvents,
    IterationEvents,
    TaskMetricsRequest,
    LogEventsRequest,
)
from bll.event import EventBLL
from bll.event.event_metrics import EventMetrics
@@ -45,7 +46,7 @@ def add_batch(call: APICall, company_id, req_model):


@endpoint("events.get_task_log", required_fields=["task"])
def get_task_log(call, company_id, req_model):
def get_task_log_v1_5(call, company_id, req_model):
    task_id = call.data["task"]
    task_bll.assert_exists(company_id, task_id, allow_public=True)
    order = call.data.get("order") or "desc"
@@ -93,6 +94,28 @@ def get_task_log_v1_7(call, company_id, req_model):
    )


@endpoint("events.get_task_log", min_version="2.7", request_data_model=LogEventsRequest)
def get_task_log(call, company_id, req_model: LogEventsRequest):
    task_id = req_model.task
    task_bll.assert_exists(company_id, task_id, allow_public=True)

    res = event_bll.log_events_iterator.get_task_events(
        company_id=company_id,
        task_id=task_id,
        batch_size=req_model.batch_size,
        navigate_earlier=req_model.navigate_earlier,
        refresh=req_model.refresh,
        state_id=req_model.scroll_id,
    )

    call.result.data = dict(
        events=res.events,
        returned=len(res.events),
        total=res.total_events,
        scroll_id=res.next_scroll_id,
    )


@endpoint("events.download_task_log", required_fields=["task"])
def download_task_log(call, company_id, req_model):
    task_id = call.data["task"]
@@ -210,7 +210,7 @@ def get_all_ex(call: APICall):

    status_count = defaultdict(lambda: {})
    key = itemgetter(EntityVisibility.archived.value)
    for result in Task.aggregate(status_count_pipeline):
    for result in Task.aggregate(*status_count_pipeline):
        for k, group in groupby(sorted(result["counts"], key=key), key):
            section = (
                EntityVisibility.archived if k else EntityVisibility.active
@@ -224,7 +224,7 @@ def get_all_ex(call: APICall):

    runtime = {
        result["_id"]: {k: v for k, v in result.items() if k != "_id"}
        for result in Task.aggregate(runtime_pipeline)
        for result in Task.aggregate(*runtime_pipeline)
    }

    def safe_get(obj, path, default=None):
@@ -2,7 +2,7 @@
Comprehensive test of all(?) use cases of datasets and frames
"""
import json
import time
import operator
import unittest
from functools import partial
from statistics import mean
@@ -22,21 +22,17 @@ class TestTaskEvents(TestService):
        )
        return self.create_temp("tasks", **task_input)

    def _create_task_event(self, type_, task, iteration, **kwargs):
    @staticmethod
    def _create_task_event(type_, task, iteration, **kwargs):
        return {
            "worker": "test",
            "type": type_,
            "task": task,
            "iter": iteration,
            "timestamp": es_factory.get_timestamp_millis(),
            "timestamp": kwargs.get("timestamp") or es_factory.get_timestamp_millis(),
            **kwargs,
        }

    def _copy_and_update(self, src_obj, new_data):
        obj = src_obj.copy()
        obj.update(new_data)
        return obj

    def test_task_metrics(self):
        tasks = {
            self._temp_task(): {
@@ -83,8 +79,7 @@ class TestTaskEvents(TestService):

        # test empty
        res = self.api.events.debug_images(
            metrics=[{"task": task, "metric": metric}],
            iters=5,
            metrics=[{"task": task, "metric": metric}], iters=5,
        )
        self.assertFalse(res.metrics)

@@ -116,11 +111,11 @@ class TestTaskEvents(TestService):

        # test forward navigation
        for page in range(3):
            scroll_id = assert_debug_images(scroll_id=scroll_id, page=page)
            scroll_id = assert_debug_images(scroll_id=scroll_id, expected_page=page)

        # test backwards navigation
        scroll_id = assert_debug_images(
            scroll_id=scroll_id, page=0, navigate_earlier=False
            scroll_id=scroll_id, expected_page=0, navigate_earlier=False
        )

        # beyond the latest iteration and back
@@ -131,10 +126,10 @@ class TestTaskEvents(TestService):
            navigate_earlier=False,
        )
        self.assertEqual(len(res["metrics"][0]["iterations"]), 0)
        assert_debug_images(scroll_id=scroll_id, page=1)
        assert_debug_images(scroll_id=scroll_id, expected_page=1)

        # refresh
        assert_debug_images(scroll_id=scroll_id, page=0, refresh=True)
        assert_debug_images(scroll_id=scroll_id, expected_page=0, refresh=True)

    def _assertDebugImages(
        self,
@@ -143,7 +138,7 @@ class TestTaskEvents(TestService):
        max_iter: int,
        unique_images: Sequence[int],
        scroll_id,
        page: int,
        expected_page: int,
        iters: int = 5,
        **extra_params,
    ):
@@ -156,7 +151,7 @@ class TestTaskEvents(TestService):
        data = res["metrics"][0]
        self.assertEqual(data["task"], task)
        self.assertEqual(data["metric"], metric)
        left_iterations = max(0, max(unique_images) - page * iters)
        left_iterations = max(0, max(unique_images) - expected_page * iters)
        self.assertEqual(len(data["iterations"]), min(iters, left_iterations))
        for it in data["iterations"]:
            events_per_iter = sum(
@@ -166,26 +161,67 @@ class TestTaskEvents(TestService):
        return res.scroll_id

    def test_task_logs(self):
        events = []
        task = self._temp_task()
        for iter_ in range(10):
            log_event = self._create_task_event("log", task, iteration=iter_)
            events.append(
                self._copy_and_update(
                    log_event,
                    {"msg": "This is a log message from test task iter " + str(iter_)},
                )
        timestamp = es_factory.get_timestamp_millis()
        events = [
            self._create_task_event(
                "log",
                task=task,
                iteration=iter_,
                timestamp=timestamp + iter_ * 1000,
                msg=f"This is a log message from test task iter {iter_}",
            )
            # sleep so timestamp is not the same
            time.sleep(0.01)
            for iter_ in range(10)
        ]
        self.send_batch(events)

        data = self.api.events.get_task_log(task=task)
        assert len(data["events"]) == 10
        # test forward navigation
        scroll_id = None
        for page in range(3):
            scroll_id = self._assert_log_events(
                task=task, scroll_id=scroll_id, expected_page=page
            )

        self.api.tasks.reset(task=task)
        data = self.api.events.get_task_log(task=task)
        assert len(data["events"]) == 0
        # test backwards navigation
        scroll_id = self._assert_log_events(
            task=task, scroll_id=scroll_id, navigate_earlier=False
        )

        # refresh
        self._assert_log_events(task=task, scroll_id=scroll_id)
        self._assert_log_events(task=task, scroll_id=scroll_id, refresh=True)

    def _assert_log_events(
        self,
        task,
        scroll_id,
        batch_size: int = 5,
        expected_total: int = 10,
        expected_page: int = 0,
        **extra_params,
    ):
        res = self.api.events.get_task_log(
            task=task, batch_size=batch_size, scroll_id=scroll_id, **extra_params,
        )
        self.assertEqual(res.total, expected_total)
        expected_events = max(
            0, batch_size - max(0, (expected_page + 1) * batch_size - expected_total)
        )
        self.assertEqual(res.returned, expected_events)
        self.assertEqual(len(res.events), expected_events)
        unique_events = len({ev.iter for ev in res.events})
        self.assertEqual(len(res.events), unique_events)
        if res.events:
            cmp_operator = operator.ge
            if not extra_params.get("navigate_earlier", True):
                cmp_operator = operator.le
            self.assertTrue(
                all(
                    cmp_operator(first.timestamp, second.timestamp)
                    for first, second in zip(res.events, res.events[1:])
                )
            )
        return res.scroll_id

    def test_task_metric_value_intervals_keys(self):
        metric = "Metric1"