diff --git a/deep_gemm/jit/template.py b/deep_gemm/jit/template.py index c983817..cdca4c4 100644 --- a/deep_gemm/jit/template.py +++ b/deep_gemm/jit/template.py @@ -1,9 +1,9 @@ import copy import ctypes import os +import torch from typing import Any, Dict, Iterable, Tuple -import torch # Name map for Python `eval` typename_map: Dict[Any, str] = {