/** # Copyright 2023 NVIDIA CORPORATION # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. **/ package nvml import ( "errors" "fmt" "sync" "github.com/NVIDIA/go-nvml/pkg/dl" ) import "C" const ( defaultNvmlLibraryName = "libnvidia-ml.so.1" defaultNvmlLibraryLoadFlags = dl.RTLD_LAZY | dl.RTLD_GLOBAL ) var errLibraryNotLoaded = errors.New("library not loaded") var errLibraryAlreadyLoaded = errors.New("library already loaded") // dynamicLibrary is an interface for abstacting the underlying library. // This also allows for mocking and testing. //go:generate moq -stub -out dynamicLibrary_mock.go . dynamicLibrary type dynamicLibrary interface { Lookup(string) error Open() error Close() error } // library represents an nvml library. // This includes a reference to the underlying DynamicLibrary type library struct { sync.Mutex path string refcount refcount dl dynamicLibrary } var _ Interface = (*library)(nil) // libnvml is a global instance of the nvml library. var libnvml = newLibrary() func New(opts ...LibraryOption) Interface { return newLibrary(opts...) } func newLibrary(opts ...LibraryOption) *library { l := &library{} l.init(opts...) return l } func (l *library) init(opts ...LibraryOption) { o := libraryOptions{} for _, opt := range opts { opt(&o) } if o.path == "" { o.path = defaultNvmlLibraryName } if o.flags == 0 { o.flags = defaultNvmlLibraryLoadFlags } l.path = o.path l.dl = dl.New(o.path, o.flags) } func (l *library) Extensions() ExtendedInterface { return l } // LookupSymbol checks whether the specified library symbol exists in the library. // Note that this requires that the library be loaded. func (l *library) LookupSymbol(name string) error { if l == nil || l.refcount == 0 { return fmt.Errorf("error looking up %s: %w", name, errLibraryNotLoaded) } return l.dl.Lookup(name) } // load initializes the library and updates the versioned symbols. // Multiple calls to an already loaded library will return without error. func (l *library) load() (rerr error) { l.Lock() defer l.Unlock() defer func() { l.refcount.IncOnNoError(rerr) }() if l.refcount > 0 { return nil } if err := l.dl.Open(); err != nil { return fmt.Errorf("error opening %s: %w", l.path, err) } // Update the errorStringFunc to point to nvml.ErrorString errorStringFunc = nvmlErrorString // Update all versioned symbols l.updateVersionedSymbols() return nil } // close the underlying library and ensure that the global pointer to the // library is set to nil to ensure that subsequent calls to open will reinitialize it. // Multiple calls to an already closed nvml library will return without error. func (l *library) close() (rerr error) { l.Lock() defer l.Unlock() defer func() { l.refcount.DecOnNoError(rerr) }() if l.refcount != 1 { return nil } if err := l.dl.Close(); err != nil { return fmt.Errorf("error closing %s: %w", l.path, err) } // Update the errorStringFunc to point to defaultErrorStringFunc errorStringFunc = defaultErrorStringFunc return nil } // Default all versioned APIs to v1 (to infer the types) var nvmlInit = nvmlInit_v1 var nvmlDeviceGetPciInfo = nvmlDeviceGetPciInfo_v1 var nvmlDeviceGetCount = nvmlDeviceGetCount_v1 var nvmlDeviceGetHandleByIndex = nvmlDeviceGetHandleByIndex_v1 var nvmlDeviceGetHandleByPciBusId = nvmlDeviceGetHandleByPciBusId_v1 var nvmlDeviceGetNvLinkRemotePciInfo = nvmlDeviceGetNvLinkRemotePciInfo_v1 var nvmlDeviceRemoveGpu = nvmlDeviceRemoveGpu_v1 var nvmlDeviceGetGridLicensableFeatures = nvmlDeviceGetGridLicensableFeatures_v1 var nvmlEventSetWait = nvmlEventSetWait_v1 var nvmlDeviceGetAttributes = nvmlDeviceGetAttributes_v1 var nvmlComputeInstanceGetInfo = nvmlComputeInstanceGetInfo_v1 var deviceGetComputeRunningProcesses = deviceGetComputeRunningProcesses_v1 var deviceGetGraphicsRunningProcesses = deviceGetGraphicsRunningProcesses_v1 var deviceGetMPSComputeRunningProcesses = deviceGetMPSComputeRunningProcesses_v1 var GetBlacklistDeviceCount = GetExcludedDeviceCount var GetBlacklistDeviceInfoByIndex = GetExcludedDeviceInfoByIndex var nvmlDeviceGetGpuInstancePossiblePlacements = nvmlDeviceGetGpuInstancePossiblePlacements_v1 var nvmlVgpuInstanceGetLicenseInfo = nvmlVgpuInstanceGetLicenseInfo_v1 // BlacklistDeviceInfo was replaced by ExcludedDeviceInfo type BlacklistDeviceInfo = ExcludedDeviceInfo type ProcessInfo_v1Slice []ProcessInfo_v1 type ProcessInfo_v2Slice []ProcessInfo_v2 func (pis ProcessInfo_v1Slice) ToProcessInfoSlice() []ProcessInfo { var newInfos []ProcessInfo for _, pi := range pis { info := ProcessInfo{ Pid: pi.Pid, UsedGpuMemory: pi.UsedGpuMemory, GpuInstanceId: 0xFFFFFFFF, // GPU instance ID is invalid in v1 ComputeInstanceId: 0xFFFFFFFF, // Compute instance ID is invalid in v1 } newInfos = append(newInfos, info) } return newInfos } func (pis ProcessInfo_v2Slice) ToProcessInfoSlice() []ProcessInfo { var newInfos []ProcessInfo for _, pi := range pis { info := ProcessInfo(pi) newInfos = append(newInfos, info) } return newInfos } // updateVersionedSymbols checks for versioned symbols in the loaded dynamic library. // If newer versioned symbols exist, these replace the default `v1` symbols initialized above. // When new versioned symbols are added, these would have to be initialized above and have // corresponding checks and subsequent assignments added below. func (l *library) updateVersionedSymbols() { err := l.LookupSymbol("nvmlInit_v2") if err == nil { nvmlInit = nvmlInit_v2 } err = l.LookupSymbol("nvmlDeviceGetPciInfo_v2") if err == nil { nvmlDeviceGetPciInfo = nvmlDeviceGetPciInfo_v2 } err = l.LookupSymbol("nvmlDeviceGetPciInfo_v3") if err == nil { nvmlDeviceGetPciInfo = nvmlDeviceGetPciInfo_v3 } err = l.LookupSymbol("nvmlDeviceGetCount_v2") if err == nil { nvmlDeviceGetCount = nvmlDeviceGetCount_v2 } err = l.LookupSymbol("nvmlDeviceGetHandleByIndex_v2") if err == nil { nvmlDeviceGetHandleByIndex = nvmlDeviceGetHandleByIndex_v2 } err = l.LookupSymbol("nvmlDeviceGetHandleByPciBusId_v2") if err == nil { nvmlDeviceGetHandleByPciBusId = nvmlDeviceGetHandleByPciBusId_v2 } err = l.LookupSymbol("nvmlDeviceGetNvLinkRemotePciInfo_v2") if err == nil { nvmlDeviceGetNvLinkRemotePciInfo = nvmlDeviceGetNvLinkRemotePciInfo_v2 } // Unable to overwrite nvmlDeviceRemoveGpu() because the v2 function takes // a different set of parameters than the v1 function. //err = l.LookupSymbol("nvmlDeviceRemoveGpu_v2") //if err == nil { // nvmlDeviceRemoveGpu = nvmlDeviceRemoveGpu_v2 //} err = l.LookupSymbol("nvmlDeviceGetGridLicensableFeatures_v2") if err == nil { nvmlDeviceGetGridLicensableFeatures = nvmlDeviceGetGridLicensableFeatures_v2 } err = l.LookupSymbol("nvmlDeviceGetGridLicensableFeatures_v3") if err == nil { nvmlDeviceGetGridLicensableFeatures = nvmlDeviceGetGridLicensableFeatures_v3 } err = l.LookupSymbol("nvmlDeviceGetGridLicensableFeatures_v4") if err == nil { nvmlDeviceGetGridLicensableFeatures = nvmlDeviceGetGridLicensableFeatures_v4 } err = l.LookupSymbol("nvmlEventSetWait_v2") if err == nil { nvmlEventSetWait = nvmlEventSetWait_v2 } err = l.LookupSymbol("nvmlDeviceGetAttributes_v2") if err == nil { nvmlDeviceGetAttributes = nvmlDeviceGetAttributes_v2 } err = l.LookupSymbol("nvmlComputeInstanceGetInfo_v2") if err == nil { nvmlComputeInstanceGetInfo = nvmlComputeInstanceGetInfo_v2 } err = l.LookupSymbol("nvmlDeviceGetComputeRunningProcesses_v2") if err == nil { deviceGetComputeRunningProcesses = deviceGetComputeRunningProcesses_v2 } err = l.LookupSymbol("nvmlDeviceGetComputeRunningProcesses_v3") if err == nil { deviceGetComputeRunningProcesses = deviceGetComputeRunningProcesses_v3 } err = l.LookupSymbol("nvmlDeviceGetGraphicsRunningProcesses_v2") if err == nil { deviceGetGraphicsRunningProcesses = deviceGetGraphicsRunningProcesses_v2 } err = l.LookupSymbol("nvmlDeviceGetGraphicsRunningProcesses_v3") if err == nil { deviceGetGraphicsRunningProcesses = deviceGetGraphicsRunningProcesses_v3 } err = l.LookupSymbol("nvmlDeviceGetMPSComputeRunningProcesses_v2") if err == nil { deviceGetMPSComputeRunningProcesses = deviceGetMPSComputeRunningProcesses_v2 } err = l.LookupSymbol("nvmlDeviceGetMPSComputeRunningProcesses_v3") if err == nil { deviceGetMPSComputeRunningProcesses = deviceGetMPSComputeRunningProcesses_v3 } err = l.LookupSymbol("nvmlDeviceGetGpuInstancePossiblePlacements_v2") if err == nil { nvmlDeviceGetGpuInstancePossiblePlacements = nvmlDeviceGetGpuInstancePossiblePlacements_v2 } err = l.LookupSymbol("nvmlVgpuInstanceGetLicenseInfo_v2") if err == nil { nvmlVgpuInstanceGetLicenseInfo = nvmlVgpuInstanceGetLicenseInfo_v2 } }