from typing import List

import pandas as pd
import pyarrow as pa
import pytest

from smallpond.dataframe import Session
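
# These tests exercise the smallpond DataFrame API; `sp` is a Session provided by a pytest fixture.
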
def test_pandas(sp: Session):
    pandas_df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
    df = sp.from_pandas(pandas_df)
    assert df.to_pandas().equals(pandas_df)


def test_arrow(sp: Session):
    arrow_table = pa.table({"a": [1, 2, 3], "b": [4, 5, 6]})
    df = sp.from_arrow(arrow_table)
    assert df.to_arrow() == arrow_table


def test_items(sp: Session):
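    # scalar items become a single "item" column; dict items map keys to columns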
    df = sp.from_items([1, 2, 3])
    assert df.take_all() == [{"item": 1}, {"item": 2}, {"item": 3}]

    df = sp.from_items([{"a": 1, "b": 4}, {"a": 2, "b": 5}, {"a": 3, "b": 6}])
    assert df.take_all() == [{"a": 1, "b": 4}, {"a": 2, "b": 5}, {"a": 3, "b": 6}]


def test_csv(sp: Session):
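    # read tab-separated files with an explicit schema and delimiter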
    df = sp.read_csv(
        "tests/data/mock_urls/*.tsv",
        schema={"urlstr": "varchar", "valstr": "varchar"},
        delim=r"\t",
    )
    assert df.count() == 1000


def test_parquet(sp: Session):
    df = sp.read_parquet("tests/data/mock_urls/*.parquet")
    assert df.count() == 1000


def test_take(sp: Session):
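    # take(n) returns the first n rows as dicts; take_all() returns every row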
    df = sp.from_pandas(pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}))
    assert df.take(2) == [{"a": 1, "b": 4}, {"a": 2, "b": 5}]
    assert df.take_all() == [{"a": 1, "b": 4}, {"a": 2, "b": 5}, {"a": 3, "b": 6}]


def test_map(sp: Session):
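    # map() accepts either a SQL projection or a Python function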
    df = sp.from_arrow(pa.table({"a": [1, 2, 3], "b": [4, 5, 6]}))
    df1 = df.map("a + b as c")
    assert df1.to_arrow() == pa.table({"c": [5, 7, 9]})
    df2 = df.map(lambda r: {"c": r["a"] + r["b"]})
    assert df2.to_arrow() == pa.table({"c": [5, 7, 9]})

    # the user needs to specify the schema if it cannot be inferred from the mapped values
    df3 = df.map(
        lambda r: {"c": None if r["a"] == 1 else r["a"] + r["b"]},
        schema=pa.schema([("c", pa.int64())]),
    )
    assert df3.to_arrow() == pa.table({"c": pa.array([None, 7, 9], type=pa.int64())})


def test_flat_map(sp: Session):
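    # flat_map() can emit multiple output rows per input row, via a Python function or a SQL expression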
    df = sp.from_arrow(pa.table({"a": [1, 2, 3], "b": [4, 5, 6]}))
    df1 = df.flat_map(lambda r: [{"c": r["a"]}, {"c": r["b"]}])
    assert df1.to_arrow() == pa.table({"c": [1, 4, 2, 5, 3, 6]})
    df2 = df.flat_map("unnest(array[a, b]) as c")
    assert df2.to_arrow() == pa.table({"c": [1, 4, 2, 5, 3, 6]})

    # the user needs to specify the schema if it cannot be inferred from the mapped values
    df3 = df.flat_map(lambda r: [{"c": None}], schema=pa.schema([("c", pa.int64())]))
    assert df3.to_arrow() == pa.table({"c": pa.array([None, None, None], type=pa.int64())})


def test_map_batches(sp: Session):
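    # the 1000 input rows are split into batches of at most 350 rows (350 + 350 + 300)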
    df = sp.read_parquet("tests/data/mock_urls/*.parquet")
    df = df.map_batches(
        lambda batch: pa.table({"num_rows": [batch.num_rows]}),
        batch_size=350,
    )
    assert df.take_all() == [{"num_rows": 350}, {"num_rows": 350}, {"num_rows": 300}]


def test_filter(sp: Session):
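    # filter() accepts either a SQL predicate or a Python function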
    df = sp.from_arrow(pa.table({"a": [1, 2, 3], "b": [4, 5, 6]}))
    df1 = df.filter("a > 1")
    assert df1.to_arrow() == pa.table({"a": [2, 3], "b": [5, 6]})
    df2 = df.filter(lambda r: r["a"] > 1)
    assert df2.to_arrow() == pa.table({"a": [2, 3], "b": [5, 6]})


def test_random_shuffle(sp: Session):
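    # the shuffle must keep every item exactly once and be sufficiently random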
    df = sp.from_items(list(range(1000))).repartition(10, by_rows=True)
    df = df.random_shuffle()
    shuffled = [d["item"] for d in df.take_all()]
    assert sorted(shuffled) == list(range(1000))

    def count_inversions(arr: List[int]) -> int:
        # an inversion is a pair (i, j) with i < j and arr[i] > arr[j]
        return sum(sum(1 for j in range(i + 1, len(arr)) if arr[i] > arr[j]) for i in range(len(arr)))

    # check that the shuffle is random enough:
    # the expected number of inversions is n*(n-1)/4 = 249750
    assert 220000 <= count_inversions(shuffled) <= 280000


def test_partition_by(sp: Session):
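    # each partition should contain only items with the same value of item % 10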
    df = sp.from_items(list(range(1000))).repartition(10, by="item % 10")
    df = df.map("min(item % 10) as min, max(item % 10) as max")
    assert df.take_all() == [{"min": i, "max": i} for i in range(10)]


def test_partition_by_key_out_of_range(sp: Session):
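    # item % 11 can produce key 10, which does not fit into 10 partitions (valid keys are 0-9)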
    df = sp.from_items(list(range(1000))).repartition(10, by="item % 11")
    try:
        df.to_arrow()
    except Exception as ex:
        assert "partition key 10 is out of range 0-9" in str(ex)
    else:
        assert False, "expected exception"


def test_partition_by_hash(sp: Session):
    df = sp.from_items(list(range(1000))).repartition(10, hash_by="item")
    items = [d["item"] for d in df.take_all()]
    assert sorted(items) == list(range(1000))


def test_count(sp: Session):
    df = sp.from_items([1, 2, 3])
    assert df.count() == 3


def test_limit(sp: Session):
    df = sp.from_items(list(range(1000))).repartition(10, by_rows=True)
    assert df.limit(2).count() == 2


@pytest.mark.skip(reason="limit cannot be pushed down to the sql node for now")
@pytest.mark.timeout(10)
def test_limit_large(sp: Session):
    # limit will be fused with the previous select;
    # otherwise, this test will time out
    df = sp.partial_sql("select * from range(1000000000)")
    assert df.limit(2).count() == 2


def test_partial_sql(sp: Session):
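    # {0}, {1} placeholders in the SQL refer to the given dataframes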
    # no input deps
    df = sp.partial_sql("select * from range(3)")
    assert df.to_arrow() == pa.table({"range": [0, 1, 2]})

    # join
    df1 = sp.from_arrow(pa.table({"id1": [1, 2, 3], "val1": ["a", "b", "c"]}))
    df2 = sp.from_arrow(pa.table({"id2": [1, 2, 3], "val2": ["d", "e", "f"]}))
    joined = sp.partial_sql("select id1, val1, val2 from {0} join {1} on id1 = id2", df1, df2)
    assert joined.to_arrow() == pa.table(
        {"id1": [1, 2, 3], "val1": ["a", "b", "c"], "val2": ["d", "e", "f"]},
        schema=pa.schema(
            [
                ("id1", pa.int64()),
                ("val1", pa.large_string()),
                ("val2", pa.large_string()),
            ]
        ),
    )


def test_error_message(sp: Session):
    df = sp.from_items([1, 2, 3])
    df = sp.partial_sql("select a,, from {0}", df)
    try:
        df.to_arrow()
    except Exception as ex:
        # the failing SQL query should be included in the exception message
        assert "select a,, from" in str(ex)
    else:
        assert False, "expected exception"


def test_unpicklable_task_exception(sp: Session):
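    # a task that captures the externally imported loguru logger cannot be pickled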
    from loguru import logger

    df = sp.from_items([1, 2, 3])
    try:
        df.map(lambda x: logger.info("use outside logger")).to_arrow()
    except Exception as ex:
        assert "Can't pickle task" in str(ex)
        assert "HINT: DO NOT use externally imported loguru logger in your task. Please import it within the task." in str(ex)
    else:
        assert False, "expected exception"


def test_log(sp: Session):
    df = sp.from_items([1, 2, 3])

    def log_record(x):
        import logging
        import sys

        from loguru import logger

        print("stdout")
        print("stderr", file=sys.stderr)
        logger.info("loguru")
        logging.info("logging")
        return x

    df.map(log_record).to_arrow()

    # TODO: check that the logs appear in the log file
    # FIXME: logs in unit tests are not written to the log file
    # because we share the same ray instance for all tests