Mirror of https://github.com/deepseek-ai/DeepGEMM, synced 2025-05-09 22:00:34 +00:00
pytest Integration
Decorate the test function with @pytest.mark.skipif(...) so the test is skipped if CUDA is unavailable. Move all testing logic into a function named test_jit() so it’s automatically discoverable by pytest.
commit 58046b4e01
parent a6d97a1c1b
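For readers unfamiliar with the mechanics, here is a minimal self-contained sketch of the pattern this commit applies (the function name test_cuda_smoke is illustrative, not part of the diff): pytest collects functions named test_* from files named test_*.py, and the skipif marker is evaluated at collection time, so machines without CUDA report a skip rather than a failure.

import pytest
import torch

# Evaluated when pytest collects the test: without CUDA the function is
# reported as skipped (with the given reason) instead of failing.
@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA is required for this test.')
def test_cuda_smoke():
    assert torch.cuda.device_count() > 0

With the if __name__ == '__main__' guard gone, the file is now meant to be run through pytest; should pytest's own output capturing ever interfere with the fd-level Capture class used below, running pytest -s (which disables capturing) is the usual workaround.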
@@ -1,4 +1,5 @@
 import os
+import pytest
 import torch
 from typing import Any
 
@@ -6,6 +7,10 @@ from deep_gemm import jit
 
 
 class Capture:
+    """
+    Context manager to capture stdout via OS pipes.
+    """
+
     def __init__(self) -> None:
         self.read_fd = None
         self.write_fd = None
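The hunk above adds only the docstring; the rest of Capture's body is elided by the diff. For context, here is a minimal sketch of how such an OS-pipe stdout capture typically works, with hypothetical names (everything beyond __init__ and capture is an assumption, not the repository's code). Capturing at the file-descriptor level matters for this test: the JIT-built C++ function prints through std::cout, which writes to fd 1 directly and never passes through Python's sys.stdout, so a Python-level redirect would miss it.

import os
import sys

class PipeCaptureSketch:
    # Hypothetical re-implementation for illustration only; the real class
    # in this file is named Capture and may differ in detail.
    def __enter__(self):
        self.read_fd, self.write_fd = os.pipe()  # pipe that will receive stdout bytes
        sys.stdout.flush()                       # drain Python-level buffers first
        self.saved_fd = os.dup(1)                # remember the original stdout fd
        os.dup2(self.write_fd, 1)                # point fd 1 at the pipe's write end
        return self

    def __exit__(self, *exc):
        sys.stdout.flush()
        os.dup2(self.saved_fd, 1)                # restore the original stdout
        os.close(self.write_fd)                  # close the write end so reads see EOF
        # Note: output larger than the OS pipe buffer (~64 KiB) would block the
        # writer before this point; fine for a small test like this one.
        self.captured = os.read(self.read_fd, 1 << 20).decode()
        os.close(self.read_fd)
        os.close(self.saved_fd)
        return False                             # do not swallow exceptions

    def capture(self) -> str:
        return self.captured

The test's assertions, shown in the next hunk, compare capture.capture() against a reference string built from the expected pointer values.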
@@ -28,37 +33,66 @@ class Capture:
         return self.captured
 
 
-if __name__ == '__main__':
-    # Runtime
+@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA is required for this test.')
+def test_jit():
+    # Print NVCC compiler
     print(f'NVCC compiler: {jit.get_nvcc_compiler()}\n')
 
-    # Templates
+    # Define function arguments and code body
     print('Generated code:')
-    args = (('lhs', torch.float8_e4m3fn), ('rhs', torch.float8_e4m3fn), ('scale', torch.float), ('out', torch.bfloat16),
-            ('enable_double_streams', bool), ('stream', torch.cuda.Stream))
-    body = "\n"
+    args = (
+        ('lhs', torch.float8_e4m3fn),
+        ('rhs', torch.float8_e4m3fn),
+        ('scale', torch.float),
+        ('out', torch.bfloat16),
+        ('enable_double_streams', bool),
+        ('stream', torch.cuda.Stream),
+    )
+
+    body = ''
     body += 'std::cout << reinterpret_cast<uint64_t>(lhs) << std::endl;\n'
     body += 'std::cout << reinterpret_cast<uint64_t>(rhs) << std::endl;\n'
     body += 'std::cout << reinterpret_cast<uint64_t>(scale) << std::endl;\n'
     body += 'std::cout << reinterpret_cast<uint64_t>(out) << std::endl;\n'
     body += 'std::cout << enable_double_streams << std::endl;\n'
     body += 'std::cout << reinterpret_cast<uint64_t>(stream) << std::endl;\n'
+
     code = jit.generate((), args, body)
     print(code)
 
-    # Build
+    # Build the function
     print('Building ...')
     func = jit.build('test_func', args, code)
 
     # Test correctness
     print('Running ...')
-    fp8_tensor = torch.empty((1, ), dtype=torch.float8_e4m3fn, device='cuda')
-    fp32_tensor = torch.empty((1, ), dtype=torch.float, device='cuda')
-    bf16_tensor = torch.empty((1, ), dtype=torch.bfloat16, device='cuda')
+    fp8_tensor = torch.empty((1,), dtype=torch.float8_e4m3fn, device='cuda')
+    fp32_tensor = torch.empty((1,), dtype=torch.float, device='cuda')
+    bf16_tensor = torch.empty((1,), dtype=torch.bfloat16, device='cuda')
+
     with Capture() as capture:
-        assert func(fp8_tensor, fp8_tensor, fp32_tensor, bf16_tensor, True, torch.cuda.current_stream()) == 0
+        ret = func(
+            fp8_tensor,
+            fp8_tensor,
+            fp32_tensor,
+            bf16_tensor,
+            True,
+            torch.cuda.current_stream(),
+        )
+        # If your JIT returns an error code, test it here
+        assert ret == 0, f'JIT function returned error code: {ret}'
 
     output = capture.capture()
-    ref_output = f'{fp8_tensor.data_ptr()}\n{fp8_tensor.data_ptr()}\n{fp32_tensor.data_ptr()}\n{bf16_tensor.data_ptr()}\n1\n{torch.cuda.current_stream().cuda_stream}\n'
-    assert output == ref_output, f'{output=}, {ref_output=}'
+    ref_output = (
+        f'{fp8_tensor.data_ptr()}\n'
+        f'{fp8_tensor.data_ptr()}\n'
+        f'{fp32_tensor.data_ptr()}\n'
+        f'{bf16_tensor.data_ptr()}\n'
+        f'1\n'
+        f'{torch.cuda.current_stream().cuda_stream}\n'
+    )
+
+    # Compare the captured output to the reference
+    assert output == ref_output, f'Mismatch!\nGot:\n{output}\nExpected:\n{ref_output}'
+
     print('JIT test passed')
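A closing note on why the string comparison in the last hunk works: Tensor.data_ptr() returns the device address as a plain Python int, the generated C++ prints reinterpret_cast<uint64_t> of the same raw pointer, and std::cout formats unsigned integers in decimal just as Python's f-string does, so the lines match digit for digit. A bool streamed without std::boolalpha prints as 1 or 0, which is where the f'1\n' entry for the True argument comes from, and torch.cuda.current_stream().cuda_stream exposes the raw CUDA stream handle as an int the same way. A tiny illustration (requires a CUDA device; the variable name and example address are illustrative):

import torch

t = torch.empty((1,), dtype=torch.bfloat16, device='cuda')
addr = t.data_ptr()   # device address as a Python int, e.g. 139637976727552
line = f'{addr}\n'    # same text 'std::cout << reinterpret_cast<uint64_t>(out) << std::endl' emits
print(line, end='')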