Support plots navigation by iteration

This commit is contained in:
allegroai 2022-11-29 17:34:57 +02:00
parent c23e8a90d0
commit caaf801cd0
6 changed files with 48 additions and 6 deletions

View File

@ -79,6 +79,7 @@ class NextHistorySampleRequest(Base):
task: str = StringField(required=True) task: str = StringField(required=True)
scroll_id: Optional[str] = StringField() scroll_id: Optional[str] = StringField()
navigate_earlier: bool = BoolField(default=True) navigate_earlier: bool = BoolField(default=True)
next_iteration: bool = BoolField(default=False)
model_events: bool = BoolField(default=False) model_events: bool = BoolField(default=False)

View File

@ -42,6 +42,7 @@ class StartedResponse(UpdateResponse):
class EnqueueResponse(UpdateResponse): class EnqueueResponse(UpdateResponse):
queued = IntField() queued = IntField()
queue_watched = BoolField()
class EnqueueBatchItem(UpdateBatchItem): class EnqueueBatchItem(UpdateBatchItem):
@ -50,6 +51,7 @@ class EnqueueBatchItem(UpdateBatchItem):
class EnqueueManyResponse(BatchResponse): class EnqueueManyResponse(BatchResponse):
succeeded: Sequence[EnqueueBatchItem] = ListField(EnqueueBatchItem) succeeded: Sequence[EnqueueBatchItem] = ListField(EnqueueBatchItem)
queue_watched = BoolField()
class DequeueResponse(UpdateResponse): class DequeueResponse(UpdateResponse):
@ -97,6 +99,7 @@ class UpdateRequest(TaskUpdateRequest):
class EnqueueRequest(UpdateRequest): class EnqueueRequest(UpdateRequest):
queue = StringField() queue = StringField()
queue_name = StringField() queue_name = StringField()
verify_watched_queue = BoolField(default=False)
class DeleteRequest(UpdateRequest): class DeleteRequest(UpdateRequest):
@ -275,6 +278,7 @@ class EnqueueManyRequest(TaskBatchRequest):
queue = StringField() queue = StringField()
queue_name = StringField() queue_name = StringField()
validate_tasks = BoolField(default=False) validate_tasks = BoolField(default=False)
verify_watched_queue = BoolField(default=False)
class DeleteManyRequest(TaskBatchRequest): class DeleteManyRequest(TaskBatchRequest):

View File

@ -63,7 +63,12 @@ class HistorySampleIterator(abc.ABC):
) )
def get_next_sample( def get_next_sample(
self, company_id: str, task: str, state_id: str, navigate_earlier: bool self,
company_id: str,
task: str,
state_id: str,
navigate_earlier: bool,
next_iteration: bool,
) -> HistorySampleResult: ) -> HistorySampleResult:
""" """
Get the sample for next/prev variant on the current iteration Get the sample for next/prev variant on the current iteration
@ -77,11 +82,19 @@ class HistorySampleIterator(abc.ABC):
if check_empty_data(self.es, company_id=company_id, event_type=self.event_type): if check_empty_data(self.es, company_id=company_id, event_type=self.event_type):
return res return res
event = self._get_next_for_current_iteration( if next_iteration:
company_id=company_id, navigate_earlier=navigate_earlier, state=state event = self._get_next_for_another_iteration(
) or self._get_next_for_another_iteration( company_id=company_id, navigate_earlier=navigate_earlier, state=state
company_id=company_id, navigate_earlier=navigate_earlier, state=state )
) else:
# noinspection PyArgumentList
event = first(
f(company_id=company_id, navigate_earlier=navigate_earlier, state=state)
for f in (
self._get_next_for_current_iteration,
self._get_next_for_another_iteration,
)
)
if not event: if not event:
return res return res

View File

@ -192,6 +192,13 @@ class QueueBLL(object):
ret_params=ret_params, ret_params=ret_params,
) )
def check_for_workers(self, company_id: str, queue_id: str) -> bool:
for worker in self.worker_bll.get_all(company_id):
if queue_id in worker.queues:
return True
return False
def get_queue_infos( def get_queue_infos(
self, self,
company_id: str, company_id: str,

View File

@ -860,6 +860,7 @@ def next_debug_image_sample(call, company_id, request: NextHistorySampleRequest)
task=request.task, task=request.task,
state_id=request.scroll_id, state_id=request.scroll_id,
navigate_earlier=request.navigate_earlier, navigate_earlier=request.navigate_earlier,
next_iteration=request.next_iteration,
) )
call.result.data = attr.asdict(res, recurse=False) call.result.data = attr.asdict(res, recurse=False)
@ -896,6 +897,7 @@ def next_plot_sample(call, company_id, request: NextHistorySampleRequest):
task=request.task, task=request.task,
state_id=request.scroll_id, state_id=request.scroll_id,
navigate_earlier=request.navigate_earlier, navigate_earlier=request.navigate_earlier,
next_iteration=request.next_iteration,
) )
call.result.data = attr.asdict(res, recurse=False) call.result.data = attr.asdict(res, recurse=False)

View File

@ -137,6 +137,21 @@ class TestTaskPlots(TestService):
) )
self._assertEqualEvent(res.event, events[-3]) self._assertEqualEvent(res.event, events[-3])
# next_iteration
res = self.api.events.next_plot_sample(
task=task, scroll_id=res.scroll_id, next_iteration=True
)
self._assertEqualEvent(res.event, None)
res = self.api.events.next_plot_sample(
task=task, scroll_id=res.scroll_id, next_iteration=True, navigate_earlier=False
)
self._assertEqualEvent(res.event, events[-2])
self.assertEqual(res.event.iter, 1)
res = self.api.events.next_plot_sample(
task=task, scroll_id=res.scroll_id, next_iteration=True, navigate_earlier=False
)
self._assertEqualEvent(res.event, None)
def _assertEqualEvent(self, ev1: dict, ev2: Optional[dict]): def _assertEqualEvent(self, ev1: dict, ev2: Optional[dict]):
if ev2 is None: if ev2 is None:
self.assertIsNone(ev1) self.assertIsNone(ev1)