pytest Integration

Decorate the test function with @pytest.mark.skipif(...) so the test is skipped if CUDA is unavailable.
Move all testing logic into a function named test_jit() so it’s automatically discoverable by pytest.
This commit is contained in:
A-transformer 2025-02-27 09:48:20 +04:00 committed by GitHub
parent a6d97a1c1b
commit 58046b4e01
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,4 +1,5 @@
import os import os
import pytest
import torch import torch
from typing import Any from typing import Any
@ -6,6 +7,10 @@ from deep_gemm import jit
class Capture: class Capture:
"""
Context manager to capture stdout via OS pipes.
"""
def __init__(self) -> None: def __init__(self) -> None:
self.read_fd = None self.read_fd = None
self.write_fd = None self.write_fd = None
@ -28,37 +33,66 @@ class Capture:
return self.captured return self.captured
if __name__ == '__main__': @pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA is required for this test.')
# Runtime def test_jit():
# Print NVCC compiler
print(f'NVCC compiler: {jit.get_nvcc_compiler()}\n') print(f'NVCC compiler: {jit.get_nvcc_compiler()}\n')
# Templates # Define function arguments and code body
print('Generated code:') print('Generated code:')
args = (('lhs', torch.float8_e4m3fn), ('rhs', torch.float8_e4m3fn), ('scale', torch.float), ('out', torch.bfloat16), args = (
('enable_double_streams', bool), ('stream', torch.cuda.Stream)) ('lhs', torch.float8_e4m3fn),
body = "\n" ('rhs', torch.float8_e4m3fn),
('scale', torch.float),
('out', torch.bfloat16),
('enable_double_streams', bool),
('stream', torch.cuda.Stream),
)
body = ''
body += 'std::cout << reinterpret_cast<uint64_t>(lhs) << std::endl;\n' body += 'std::cout << reinterpret_cast<uint64_t>(lhs) << std::endl;\n'
body += 'std::cout << reinterpret_cast<uint64_t>(rhs) << std::endl;\n' body += 'std::cout << reinterpret_cast<uint64_t>(rhs) << std::endl;\n'
body += 'std::cout << reinterpret_cast<uint64_t>(scale) << std::endl;\n' body += 'std::cout << reinterpret_cast<uint64_t>(scale) << std::endl;\n'
body += 'std::cout << reinterpret_cast<uint64_t>(out) << std::endl;\n' body += 'std::cout << reinterpret_cast<uint64_t>(out) << std::endl;\n'
body += 'std::cout << enable_double_streams << std::endl;\n' body += 'std::cout << enable_double_streams << std::endl;\n'
body += 'std::cout << reinterpret_cast<uint64_t>(stream) << std::endl;\n' body += 'std::cout << reinterpret_cast<uint64_t>(stream) << std::endl;\n'
code = jit.generate((), args, body) code = jit.generate((), args, body)
print(code) print(code)
# Build # Build the function
print('Building ...') print('Building ...')
func = jit.build('test_func', args, code) func = jit.build('test_func', args, code)
# Test correctness # Test correctness
print('Running ...') print('Running ...')
fp8_tensor = torch.empty((1, ), dtype=torch.float8_e4m3fn, device='cuda') fp8_tensor = torch.empty((1,), dtype=torch.float8_e4m3fn, device='cuda')
fp32_tensor = torch.empty((1, ), dtype=torch.float, device='cuda') fp32_tensor = torch.empty((1,), dtype=torch.float, device='cuda')
bf16_tensor = torch.empty((1, ), dtype=torch.bfloat16, device='cuda') bf16_tensor = torch.empty((1,), dtype=torch.bfloat16, device='cuda')
with Capture() as capture: with Capture() as capture:
assert func(fp8_tensor, fp8_tensor, fp32_tensor, bf16_tensor, True, torch.cuda.current_stream()) == 0 ret = func(
fp8_tensor,
fp8_tensor,
fp32_tensor,
bf16_tensor,
True,
torch.cuda.current_stream(),
)
# If your JIT returns an error code, test it here
assert ret == 0, f'JIT function returned error code: {ret}'
output = capture.capture() output = capture.capture()
ref_output = f'{fp8_tensor.data_ptr()}\n{fp8_tensor.data_ptr()}\n{fp32_tensor.data_ptr()}\n{bf16_tensor.data_ptr()}\n1\n{torch.cuda.current_stream().cuda_stream}\n' ref_output = (
assert output == ref_output, f'{output=}, {ref_output=}' f'{fp8_tensor.data_ptr()}\n'
f'{fp8_tensor.data_ptr()}\n'
f'{fp32_tensor.data_ptr()}\n'
f'{bf16_tensor.data_ptr()}\n'
f'1\n'
f'{torch.cuda.current_stream().cuda_stream}\n'
)
# Compare the captured output to the reference
assert output == ref_output, f'Mismatch!\nGot:\n{output}\nExpected:\n{ref_output}'
print('JIT test passed') print('JIT test passed')