mirror of
https://github.com/deepseek-ai/smallpond
synced 2025-06-26 18:27:45 +00:00
72 lines
3.0 KiB
Python
72 lines
3.0 KiB
Python
import itertools
|
|
import unittest
|
|
|
|
import numpy as np
|
|
from hypothesis import given
|
|
from hypothesis import strategies as st
|
|
|
|
from smallpond.common import get_nth_partition, split_into_cols, split_into_rows
|
|
from tests.test_fabric import TestFabric
|
|
|
|
|
|
class TestCommon(TestFabric, unittest.TestCase):
    """Tests for the partitioning helpers imported from ``smallpond.common``.

    The assertions below pin these expectations:

    * ``get_nth_partition`` returns contiguous slices, with earlier partitions
      taking the remainder and trailing partitions empty when there are more
      partitions than items.
    * ``split_into_rows`` equals calling ``get_nth_partition`` for every index.
    * ``split_into_cols`` deals items out round-robin into balanced chunks.

    The ``*_bench`` tests run the splitters on large numpy inputs and only
    check how many chunks come back.
    """

    # NOTE(review): TestFabric comes from tests/test_fabric.py; presumably it
    # supplies shared setup for smallpond tests -- confirm before relying on it.

    def test_get_nth_partition(self):
        """Pin the exact partitions produced for a three-item list."""
        items = [1, 2, 3]
        # split into 1 partition: everything lands in partition 0
        self.assertListEqual([1, 2, 3], get_nth_partition(items, 0, 1))
        # split into 2 partitions: the first partition takes the extra item
        self.assertListEqual([1, 2], get_nth_partition(items, 0, 2))
        self.assertListEqual([3], get_nth_partition(items, 1, 2))
        # split into 3 partitions: one item each
        self.assertListEqual([1], get_nth_partition(items, 0, 3))
        self.assertListEqual([2], get_nth_partition(items, 1, 3))
        self.assertListEqual([3], get_nth_partition(items, 2, 3))
        # split into 5 partitions: more partitions than items, so the
        # trailing partitions are empty
        self.assertListEqual([1], get_nth_partition(items, 0, 5))
        self.assertListEqual([2], get_nth_partition(items, 1, 5))
        self.assertListEqual([3], get_nth_partition(items, 2, 5))
        self.assertListEqual([], get_nth_partition(items, 3, 5))
        self.assertListEqual([], get_nth_partition(items, 4, 5))

    @given(st.data())
    def test_split_into_rows(self, data):
        """``split_into_rows`` must agree with ``get_nth_partition`` for every index.

        ``data`` is the interactive draw object hypothesis injects for
        ``st.data()``; it lets ``npartitions`` depend on ``nelements``.
        """
        nelements = data.draw(st.integers(1, 100))
        # Up to twice as many partitions as items, so empty partitions are covered.
        npartitions = data.draw(st.integers(1, 2 * nelements))
        items = list(range(nelements))
        computed = split_into_rows(items, npartitions)
        expected = [get_nth_partition(items, n, npartitions) for n in range(npartitions)]
        self.assertEqual(expected, computed)

    @given(st.data())
    def test_split_into_cols(self, data):
        """``split_into_cols`` must deal items round-robin into balanced chunks.

        ``data`` is the interactive draw object hypothesis injects for
        ``st.data()``.
        """
        nelements = data.draw(st.integers(1, 100))
        # Up to twice as many partitions as items, so empty chunks are covered.
        npartitions = data.draw(st.integers(1, 2 * nelements))
        items = list(range(nelements))
        chunks = split_into_cols(items, npartitions)
        self.assertEqual(npartitions, len(chunks))
        # Reading one element from each chunk in turn must reproduce the input,
        # i.e. items were dealt out round-robin. zip_longest pads short chunks
        # with None; filtering it out is safe because every item is an int.
        self.assertListEqual(
            items,
            [x for row in itertools.zip_longest(*chunks) for x in row if x is not None],
        )
        # Balanced split: nothing lost or duplicated, and chunk lengths differ
        # by at most one. Asserting on the spread (rather than unpacking exactly
        # two sizes) turns a badly imbalanced split into a clear failure instead
        # of a ValueError.
        chunk_sizes = sorted({len(chunk) for chunk in chunks})
        self.assertEqual(len(items), sum(len(chunk) for chunk in chunks))
        self.assertLessEqual(chunk_sizes[-1] - chunk_sizes[0], 1)

    def _assert_chunk_counts(self, split):
        """Run ``split(items, npartitions)`` on large numpy ranges and check the chunk count.

        The partition counts go up to twice the number of elements, so many
        returned chunks are empty. Only ``len(chunks)`` is checked.
        """
        for nelements in [100000, 1000000]:
            items = np.arange(nelements)
            for npartitions in [1024, 4096, 10240, nelements, 2 * nelements]:
                chunks = split(items, npartitions)
                self.assertEqual(npartitions, len(chunks))

    def test_split_into_rows_bench(self):
        """Benchmark-style smoke test for ``split_into_rows`` on large inputs."""
        self._assert_chunk_counts(split_into_rows)

    def test_split_into_cols_bench(self):
        """Benchmark-style smoke test for ``split_into_cols`` on large inputs."""
        self._assert_chunk_counts(split_into_cols)
|