From 58046b4e0189d1391bccbd922e1c501ab826acde Mon Sep 17 00:00:00 2001
From: A-transformer
Date: Thu, 27 Feb 2025 09:48:20 +0400
Subject: [PATCH] pytest integration
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Decorate the test function with @pytest.mark.skipif(...) so the test is
skipped when CUDA is unavailable.

Move all testing logic into a function named test_jit() so it is
automatically discovered by pytest.
---
 tests/test_jit.py | 60 +++++++++++++++++++++++++++++++++++++----------
 1 file changed, 47 insertions(+), 13 deletions(-)

diff --git a/tests/test_jit.py b/tests/test_jit.py
index 78bc77b..952157a 100644
--- a/tests/test_jit.py
+++ b/tests/test_jit.py
@@ -1,4 +1,5 @@
 import os
+import pytest
 import torch
 from typing import Any
 
@@ -6,6 +7,10 @@ from deep_gemm import jit
 
 
 class Capture:
+    """
+    Context manager to capture stdout via OS pipes.
+    """
+
     def __init__(self) -> None:
         self.read_fd = None
         self.write_fd = None
@@ -28,37 +33,66 @@ class Capture:
         return self.captured
 
 
-if __name__ == '__main__':
-    # Runtime
+@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA is required for this test.')
+def test_jit():
+    # Print the NVCC compiler
     print(f'NVCC compiler: {jit.get_nvcc_compiler()}\n')
 
-    # Templates
+    # Define function arguments and code body
     print('Generated code:')
-    args = (('lhs', torch.float8_e4m3fn), ('rhs', torch.float8_e4m3fn), ('scale', torch.float), ('out', torch.bfloat16),
-            ('enable_double_streams', bool), ('stream', torch.cuda.Stream))
-    body = "\n"
+    args = (
+        ('lhs', torch.float8_e4m3fn),
+        ('rhs', torch.float8_e4m3fn),
+        ('scale', torch.float),
+        ('out', torch.bfloat16),
+        ('enable_double_streams', bool),
+        ('stream', torch.cuda.Stream),
+    )
+
+    body = ''
     body += 'std::cout << reinterpret_cast<uint64_t>(lhs) << std::endl;\n'
     body += 'std::cout << reinterpret_cast<uint64_t>(rhs) << std::endl;\n'
     body += 'std::cout << reinterpret_cast<uint64_t>(scale) << std::endl;\n'
     body += 'std::cout << reinterpret_cast<uint64_t>(out) << std::endl;\n'
     body += 'std::cout << enable_double_streams << std::endl;\n'
     body += 'std::cout << reinterpret_cast<uint64_t>(stream) << std::endl;\n'
+
     code = jit.generate((), args, body)
     print(code)
 
-    # Build
+    # Build the function
     print('Building ...')
     func = jit.build('test_func', args, code)
 
     # Test correctness
     print('Running ...')
-    fp8_tensor = torch.empty((1, ), dtype=torch.float8_e4m3fn, device='cuda')
-    fp32_tensor = torch.empty((1, ), dtype=torch.float, device='cuda')
-    bf16_tensor = torch.empty((1, ), dtype=torch.bfloat16, device='cuda')
+    fp8_tensor = torch.empty((1,), dtype=torch.float8_e4m3fn, device='cuda')
+    fp32_tensor = torch.empty((1,), dtype=torch.float, device='cuda')
+    bf16_tensor = torch.empty((1,), dtype=torch.bfloat16, device='cuda')
+
     with Capture() as capture:
-        assert func(fp8_tensor, fp8_tensor, fp32_tensor, bf16_tensor, True, torch.cuda.current_stream()) == 0
+        ret = func(
+            fp8_tensor,
+            fp8_tensor,
+            fp32_tensor,
+            bf16_tensor,
+            True,
+            torch.cuda.current_stream(),
+        )
+        # The JIT-built function returns a status code; 0 means success
+        assert ret == 0, f'JIT function returned error code: {ret}'
+
     output = capture.capture()
-    ref_output = f'{fp8_tensor.data_ptr()}\n{fp8_tensor.data_ptr()}\n{fp32_tensor.data_ptr()}\n{bf16_tensor.data_ptr()}\n1\n{torch.cuda.current_stream().cuda_stream}\n'
-    assert output == ref_output, f'{output=}, {ref_output=}'
+    ref_output = (
+        f'{fp8_tensor.data_ptr()}\n'
+        f'{fp8_tensor.data_ptr()}\n'
+        f'{fp32_tensor.data_ptr()}\n'
+        f'{bf16_tensor.data_ptr()}\n'
+        f'1\n'
+        f'{torch.cuda.current_stream().cuda_stream}\n'
+    )
+
+    # Compare the captured output to the reference
+    assert output == ref_output, f'Mismatch!\nGot:\n{output}\nExpected:\n{ref_output}'
 
     print('JIT test passed')
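
With this patch applied, no __main__ entry point is needed: pytest collects
test_jit() by name, and on a machine without CUDA the test is reported as
skipped rather than failed. A minimal invocation sketch (the flags below are
ordinary pytest options, shown only as examples):

    pytest tests/test_jit.py -v           # run the file with verbose output
    pytest tests/test_jit.py -k test_jit  # select the test by keyword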