diff --git a/internal/cuda/cuda.go b/internal/cuda/cuda.go index 6224de3e..2c70a821 100644 --- a/internal/cuda/cuda.go +++ b/internal/cuda/cuda.go @@ -31,11 +31,21 @@ import ( #define CUDAAPI #endif +typedef int CUdevice; + +typedef enum CUdevice_attribute_enum { + CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR = 75, + CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR = 76 +} CUdevice_attribute; + typedef enum cudaError_enum { CUDA_SUCCESS = 0 } CUresult; +CUresult CUDAAPI cuInit(unsigned int Flags); CUresult CUDAAPI cuDriverGetVersion(int *driverVersion); +CUresult CUDAAPI cuDeviceGet(CUdevice *device, int ordinal); +CUresult CUDAAPI cuDeviceGetAttribute(int *pi, CUdevice_attribute attrib, CUdevice dev); */ import "C" @@ -71,6 +81,48 @@ func Version() (string, error) { return fmt.Sprintf("%d.%d", major, minor), nil } +// ComputeCapability returns the CUDA compute capability of a device with the specified index as a string +// or an error if this cannot be determined. +func ComputeCapability(index int) (string, error) { + lib, err := load() + if err != nil { + return "", err + } + defer lib.Close() + + if err := lib.Lookup("cuInit"); err != nil { + return "", fmt.Errorf("failed to lookup symbol: %v", err) + } + if err := lib.Lookup("cuDeviceGet"); err != nil { + return "", fmt.Errorf("failed to lookup symbol: %v", err) + } + if err := lib.Lookup("cuDeviceGetAttribute"); err != nil { + return "", fmt.Errorf("failed to lookup symbol: %v", err) + } + + if result := C.cuInit(C.uint(0)); result != C.CUDA_SUCCESS { + return "", fmt.Errorf("failed to initialize CUDA: result=%v", result) + } + + var device C.CUdevice + // NOTE: We only query the first device + if result := C.cuDeviceGet(&device, C.int(index)); result != C.CUDA_SUCCESS { + return "", fmt.Errorf("failed to get CUDA device %v: result=%v", 0, result) + } + + var major C.int + if result := C.cuDeviceGetAttribute(&major, C.CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device); result != C.CUDA_SUCCESS { + return "", fmt.Errorf("failed to get CUDA compute capability major for device %v : result=%v", 0, result) + } + + var minor C.int + if result := C.cuDeviceGetAttribute(&minor, C.CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device); result != C.CUDA_SUCCESS { + return "", fmt.Errorf("failed to get CUDA compute capability minor for device %v: result=%v", 0, result) + } + + return fmt.Sprintf("%d.%d", major, minor), nil +} + func load() (*dl.DynamicLibrary, error) { lib := dl.New(libraryName, libraryLoadFlags) if lib == nil {