Fix JIT tests

This commit is contained in:
Chenggang Zhao
2025-05-16 14:39:58 +08:00
parent 78d8362e7a
commit 391755ada0
6 changed files with 14 additions and 11 deletions

View File

@@ -74,7 +74,8 @@ static void __instantiate_kernel() {{
if __name__ == '__main__':
print('Generated code:')
code = VectorAddRuntime.generate(T='float')
kwargs = {'T': 'float'}
code = VectorAddRuntime.generate(kwargs)
print(code)
print()
@@ -85,7 +86,7 @@ if __name__ == '__main__':
# Build
print('Building ...')
func = compiler_cls.build('test_func', code, VectorAddRuntime)
func = compiler_cls.build('test_func', code, VectorAddRuntime, kwargs)
# Run and check
a = torch.randn((1024, ), dtype=torch.float32, device='cuda')