diff --git a/pkg/nvcdi/full-gpu-nvml.go b/pkg/nvcdi/full-gpu-nvml.go index 9dc6780e..ab1fe9cc 100644 --- a/pkg/nvcdi/full-gpu-nvml.go +++ b/pkg/nvcdi/full-gpu-nvml.go @@ -73,6 +73,7 @@ type byPathHookDiscoverer struct { driverRoot string nvidiaCTKPath string pciBusID string + deviceNodes discover.Discover } var _ discover.Discover = (*byPathHookDiscoverer)(nil) @@ -111,6 +112,7 @@ func newFullGPUDiscoverer(logger *logrus.Logger, driverRoot string, nvidiaCTKPat driverRoot: driverRoot, nvidiaCTKPath: nvidiaCTKPath, pciBusID: pciBusID, + deviceNodes: deviceNodes, } dd := discover.Merge( @@ -158,6 +160,20 @@ func (d *byPathHookDiscoverer) Mounts() ([]discover.Mount, error) { } func (d *byPathHookDiscoverer) deviceNodeLinks() ([]string, error) { + devices, err := d.deviceNodes.Devices() + if err != nil { + return nil, fmt.Errorf("failed to discover device nodes: %v", err) + } + + if len(devices) == 0 { + return nil, nil + } + + selectedDevices := make(map[string]bool) + for _, d := range devices { + selectedDevices[d.HostPath] = true + } + candidates := []string{ fmt.Sprintf("/dev/dri/by-path/pci-%s-card", d.pciBusID), fmt.Sprintf("/dev/dri/by-path/pci-%s-render", d.pciBusID), @@ -172,6 +188,14 @@ func (d *byPathHookDiscoverer) deviceNodeLinks() ([]string, error) { continue } + deviceNode := device + if !filepath.IsAbs(device) { + deviceNode = filepath.Join(filepath.Dir(linkPath), device) + } + if !selectedDevices[deviceNode] { + d.logger.Debugf("ignoring device symlink %v -> %v since %v is not mounted", linkPath, device, deviceNode) + continue + } d.logger.Debugf("adding device symlink %v -> %v", linkPath, device) links = append(links, fmt.Sprintf("%v::%v", device, linkPath)) }