smallpond/tests/test_workqueue.py
2025-03-05 22:46:23 +08:00

118 lines
3.4 KiB
Python

import multiprocessing
import multiprocessing.dummy
import multiprocessing.queues
import queue
import tempfile
import time
import unittest
from loguru import logger
from smallpond.execution.workqueue import (
WorkItem,
WorkQueue,
WorkQueueInMemory,
WorkQueueOnFilesystem,
)
from tests.test_fabric import TestFabric
class PrintWork(WorkItem):
    """Work item used by the queue tests: it carries a text message and logs it when run."""

    def __init__(self, name: str, message: str) -> None:
        """Create the item under `name` and remember `message`.

        The resource limits handed to WorkItem are fixed: cpu_limit=1,
        gpu_limit=0 and memory_limit=0.
        """
        limits = dict(cpu_limit=1, gpu_limit=0, memory_limit=0)
        super().__init__(name, **limits)
        self.message = message

    def run(self) -> bool:
        """Log this item's key and message at debug level and return True."""
        logger.debug(f"{self.key}: {self.message}")
        return True
def producer(wq: WorkQueue, id: int, numItems: int, numConsumers: int) -> None:
    """Fill `wq` with `numItems` "hello" items, then one "stop" item per consumer.

    Args:
        wq: queue that receives the work items.
        id: producer identifier, used only in the final log line.
        numItems: number of "hello" items to push.
        numConsumers: number of "stop" items pushed after the "hello" items.
    """
    print(f"wq.outbound_works: {wq.outbound_works}")

    for seq in range(numItems):
        work = PrintWork(f"item-{seq}", message="hello")
        # Every third item (indices 1, 4, 7, ...) is pushed with buffering enabled.
        use_buffer = seq % 3 == 1
        wq.push(work, buffering=use_buffer)
        # An explicit flush is issued on every fifth index.
        if seq % 5 == 0:
            wq.flush()

    # NOTE(review): when the last index is buffered but not a multiple of 5
    # (e.g. 199 of 200) no explicit flush follows it; presumably the
    # unbuffered pushes below deliver buffered items too -- confirm against
    # WorkQueue.push.
    for stop_idx in range(numConsumers):
        wq.push(PrintWork(f"stop-{stop_idx}", message="stop"))

    logger.success(f"producer {id} generated {numItems} items")
def consumer(wq: WorkQueue, id: int) -> int:
    """Pop and execute items from `wq` until a "stop" item is seen.

    Args:
        wq: queue to pop work items from.
        id: consumer identifier, used only in the final log line.

    Returns:
        The number of items executed; the "stop" item is not counted.
    """
    executed = 0
    idle_polls = 0
    stop_seen = False

    while not stop_seen:
        batch = wq.pop(count=1)
        if batch:
            for work in batch:
                assert isinstance(work, PrintWork)
                if work.message == "stop":
                    # Stop marker: leave the batch and end the polling loop.
                    stop_seen = True
                    break
                work.exec()
                executed += 1
        else:
            # Nothing available: count the miss and back off briefly.
            idle_polls += 1
            time.sleep(0.01)

    logger.success(f"consumer {id} collected {executed} items, {idle_polls} waits")
    logger.complete()
    return executed
class WorkQueueTestBase(object):
    """Shared test cases for WorkQueue implementations.

    This class is a mixin: concrete test classes list it ahead of
    ``unittest.TestCase`` and assign ``self.wq`` and ``self.pool`` in their
    own ``setUp``.
    """

    # Queue under test; assigned by the concrete subclass in setUp.
    wq: WorkQueue = None
    # Pool that runs the consumers; assigned by the concrete subclass in setUp.
    pool: multiprocessing.Pool = None

    def setUp(self) -> None:
        """Silence the work-queue module's log output, then continue the setUp chain."""
        logger.disable("smallpond.execution.workqueue")
        return super().setUp()

    def tearDown(self) -> None:
        """Stop the pool (if one is still active) before the rest of the tearDown chain runs."""
        self._stop_pool()
        return super().tearDown()

    def _stop_pool(self) -> None:
        """Terminate and join the pool once; later calls are no-ops."""
        pool, self.pool = self.pool, None
        if pool is not None:
            pool.terminate()
            pool.join()

    def test_basics(self):
        """Push items from this thread and check that every one can be popped back."""
        num_items = 200
        for i in range(num_items):
            self.wq.push(PrintWork(f"item-{i}", message="hello"))

        num_collected = 0
        # A pop may return several items, so stop as soon as all are accounted for.
        for _ in range(num_items):
            items = self.wq.pop()
            logger.info(f"{len(items)} items")
            num_collected += len(items)
            if num_items == num_collected:
                break

        # Without this check the test would pass even if nothing was popped.
        self.assertEqual(num_items, num_collected)

    def test_multi_consumers(self):
        """Run one producer against several pooled consumers and compare the item counts."""
        num_consumers = 10
        num_items = 200
        try:
            result = self.pool.starmap_async(
                consumer,
                [(self.wq, consumer_id) for consumer_id in range(num_consumers)],
            )
            producer(self.wq, 0, num_items, num_consumers)
            logger.info("waiting for result")
            num_collected = sum(result.get(timeout=20))
            logger.info(f"expected vs collected: {num_items} vs {num_collected}")
            self.assertEqual(num_items, num_collected)
            logger.success("all done")
        finally:
            # Always stop the workers: after a timeout or a failed assertion
            # the consumers may still be polling the queue.
            self._stop_pool()
            logger.success("workers stopped")
class TestWorkQueueInMemory(WorkQueueTestBase, TestFabric, unittest.TestCase):
    """Runs the shared work-queue tests against WorkQueueInMemory with a thread pool."""

    def setUp(self) -> None:
        """Build a queue.Queue-backed in-memory queue and a 10-worker thread pool."""
        super().setUp()
        worker_count = 10
        backing_queue_type = queue.Queue
        self.wq = WorkQueueInMemory(queue_type=backing_queue_type)
        # multiprocessing.dummy.Pool is thread-based, so the consumers run
        # in this same process.
        self.pool = multiprocessing.dummy.Pool(worker_count)
class TestWorkQueueOnFilesystem(WorkQueueTestBase, TestFabric, unittest.TestCase):
    """Runs the shared work-queue tests against WorkQueueOnFilesystem with a process pool."""

    # Directory created in setUp and handed to the filesystem-backed queue.
    workq_root: str

    def setUp(self) -> None:
        """Create a fresh queue directory, the queue on top of it, and a 10-worker process pool."""
        super().setUp()
        # runtime_ctx is not defined in this file; presumably TestFabric
        # provides it -- confirm.
        parent_dir = self.runtime_ctx.queue_root
        worker_count = 10
        self.workq_root = tempfile.mkdtemp(dir=parent_dir)
        self.wq = WorkQueueOnFilesystem(self.workq_root, sort=True)
        self.pool = multiprocessing.Pool(worker_count)