Support plots navigation by iteration

This commit is contained in:
allegroai 2022-11-29 17:34:57 +02:00
parent c23e8a90d0
commit caaf801cd0
6 changed files with 48 additions and 6 deletions

View File

@ -79,6 +79,7 @@ class NextHistorySampleRequest(Base):
task: str = StringField(required=True)
scroll_id: Optional[str] = StringField()
navigate_earlier: bool = BoolField(default=True)
next_iteration: bool = BoolField(default=False)
model_events: bool = BoolField(default=False)

View File

@ -42,6 +42,7 @@ class StartedResponse(UpdateResponse):
class EnqueueResponse(UpdateResponse):
queued = IntField()
queue_watched = BoolField()
class EnqueueBatchItem(UpdateBatchItem):
@ -50,6 +51,7 @@ class EnqueueBatchItem(UpdateBatchItem):
class EnqueueManyResponse(BatchResponse):
succeeded: Sequence[EnqueueBatchItem] = ListField(EnqueueBatchItem)
queue_watched = BoolField()
class DequeueResponse(UpdateResponse):
@ -97,6 +99,7 @@ class UpdateRequest(TaskUpdateRequest):
class EnqueueRequest(UpdateRequest):
queue = StringField()
queue_name = StringField()
verify_watched_queue = BoolField(default=False)
class DeleteRequest(UpdateRequest):
@ -275,6 +278,7 @@ class EnqueueManyRequest(TaskBatchRequest):
queue = StringField()
queue_name = StringField()
validate_tasks = BoolField(default=False)
verify_watched_queue = BoolField(default=False)
class DeleteManyRequest(TaskBatchRequest):

View File

@ -63,7 +63,12 @@ class HistorySampleIterator(abc.ABC):
)
def get_next_sample(
self, company_id: str, task: str, state_id: str, navigate_earlier: bool
self,
company_id: str,
task: str,
state_id: str,
navigate_earlier: bool,
next_iteration: bool,
) -> HistorySampleResult:
"""
Get the sample for next/prev variant on the current iteration
@ -77,11 +82,19 @@ class HistorySampleIterator(abc.ABC):
if check_empty_data(self.es, company_id=company_id, event_type=self.event_type):
return res
event = self._get_next_for_current_iteration(
company_id=company_id, navigate_earlier=navigate_earlier, state=state
) or self._get_next_for_another_iteration(
company_id=company_id, navigate_earlier=navigate_earlier, state=state
)
if next_iteration:
event = self._get_next_for_another_iteration(
company_id=company_id, navigate_earlier=navigate_earlier, state=state
)
else:
# noinspection PyArgumentList
event = first(
f(company_id=company_id, navigate_earlier=navigate_earlier, state=state)
for f in (
self._get_next_for_current_iteration,
self._get_next_for_another_iteration,
)
)
if not event:
return res

View File

@ -192,6 +192,13 @@ class QueueBLL(object):
ret_params=ret_params,
)
def check_for_workers(self, company_id: str, queue_id: str) -> bool:
for worker in self.worker_bll.get_all(company_id):
if queue_id in worker.queues:
return True
return False
def get_queue_infos(
self,
company_id: str,

View File

@ -860,6 +860,7 @@ def next_debug_image_sample(call, company_id, request: NextHistorySampleRequest)
task=request.task,
state_id=request.scroll_id,
navigate_earlier=request.navigate_earlier,
next_iteration=request.next_iteration,
)
call.result.data = attr.asdict(res, recurse=False)
@ -896,6 +897,7 @@ def next_plot_sample(call, company_id, request: NextHistorySampleRequest):
task=request.task,
state_id=request.scroll_id,
navigate_earlier=request.navigate_earlier,
next_iteration=request.next_iteration,
)
call.result.data = attr.asdict(res, recurse=False)

View File

@ -137,6 +137,21 @@ class TestTaskPlots(TestService):
)
self._assertEqualEvent(res.event, events[-3])
# next_iteration
res = self.api.events.next_plot_sample(
task=task, scroll_id=res.scroll_id, next_iteration=True
)
self._assertEqualEvent(res.event, None)
res = self.api.events.next_plot_sample(
task=task, scroll_id=res.scroll_id, next_iteration=True, navigate_earlier=False
)
self._assertEqualEvent(res.event, events[-2])
self.assertEqual(res.event.iter, 1)
res = self.api.events.next_plot_sample(
task=task, scroll_id=res.scroll_id, next_iteration=True, navigate_earlier=False
)
self._assertEqualEvent(res.event, None)
def _assertEqualEvent(self, ev1: dict, ev2: Optional[dict]):
if ev2 is None:
self.assertIsNone(ev1)