diff --git a/internal/platform-support/dgpu/dgpu.go b/internal/platform-support/dgpu/dgpu.go index 00982a62..b79f6bd4 100644 --- a/internal/platform-support/dgpu/dgpu.go +++ b/internal/platform-support/dgpu/dgpu.go @@ -21,32 +21,19 @@ import ( "github.com/NVIDIA/nvidia-container-toolkit/internal/discover" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" + "github.com/NVIDIA/nvidia-container-toolkit/internal/nvcaps" ) // NewForDevice creates a discoverer for the specified Device. func NewForDevice(d device.Device, opts ...Option) (discover.Discover, error) { - o := &options{} - for _, opt := range opts { - opt(o) - } - - if o.logger == nil { - o.logger = logger.New() - } + o := new(opts...) return o.newNvmlDGPUDiscoverer(&toRequiredInfo{d}) } // NewForDevice creates a discoverer for the specified device and its associated MIG device. func NewForMigDevice(d device.Device, mig device.MigDevice, opts ...Option) (discover.Discover, error) { - o := &options{} - for _, opt := range opts { - opt(o) - } - - if o.logger == nil { - o.logger = logger.New() - } + o := new(opts...) return o.newNvmlMigDiscoverer( &toRequiredMigInfo{ @@ -55,3 +42,26 @@ func NewForMigDevice(d device.Device, mig device.MigDevice, opts ...Option) (dis }, ) } + +func new(opts ...Option) *options { + o := &options{} + for _, opt := range opts { + opt(o) + } + + if o.logger == nil { + o.logger = logger.New() + } + + if o.migCaps == nil { + migCaps, err := nvcaps.NewMigCaps() + if err != nil { + o.logger.Debugf("ignoring error getting MIG capability device paths: %v", err) + o.migCapsError = err + } else { + o.migCaps = migCaps + } + } + + return o +} diff --git a/internal/platform-support/dgpu/nvml.go b/internal/platform-support/dgpu/nvml.go index be111102..f24f4d55 100644 --- a/internal/platform-support/dgpu/nvml.go +++ b/internal/platform-support/dgpu/nvml.go @@ -78,24 +78,23 @@ type requiredMigInfo interface { } func (o *options) newNvmlMigDiscoverer(d requiredMigInfo) (discover.Discover, error) { + if o.migCaps == nil || o.migCapsError != nil { + return nil, fmt.Errorf("error getting MIG capability device paths: %v", o.migCapsError) + } + gpu, gi, ci, err := d.getPlacementInfo() if err != nil { return nil, fmt.Errorf("error getting placement info: %w", err) } - migCaps, err := nvcaps.NewMigCaps() - if err != nil { - return nil, fmt.Errorf("error getting MIG capability device paths: %v", err) - } - giCap := nvcaps.NewGPUInstanceCap(gpu, gi) - giCapDevicePath, err := migCaps.GetCapDevicePath(giCap) + giCapDevicePath, err := o.migCaps.GetCapDevicePath(giCap) if err != nil { return nil, fmt.Errorf("failed to get GI cap device path: %v", err) } ciCap := nvcaps.NewComputeInstanceCap(gpu, gi, ci) - ciCapDevicePath, err := migCaps.GetCapDevicePath(ciCap) + ciCapDevicePath, err := o.migCaps.GetCapDevicePath(ciCap) if err != nil { return nil, fmt.Errorf("failed to get CI cap device path: %v", err) } diff --git a/internal/platform-support/dgpu/nvml_test.go b/internal/platform-support/dgpu/nvml_test.go index 3e8306c9..da4ac2a7 100644 --- a/internal/platform-support/dgpu/nvml_test.go +++ b/internal/platform-support/dgpu/nvml_test.go @@ -26,6 +26,7 @@ import ( "github.com/stretchr/testify/require" "github.com/NVIDIA/nvidia-container-toolkit/internal/discover" + "github.com/NVIDIA/nvidia-container-toolkit/internal/nvcaps" ) // TODO: In order to properly test this, we need a mechanism to inject / @@ -85,3 +86,86 @@ func TestNewNvmlDGPUDiscoverer(t *testing.T) { }) } } + +func TestNewNvmlMIGDiscoverer(t *testing.T) { + logger, _ := testlog.NewNullLogger() + + nvmllib := &mock.Interface{} + devicelib := device.New( + nvmllib, + ) + + testCases := []struct { + description string + mig *mock.Device + parent nvml.Device + migCaps nvcaps.MigCaps + expectedError error + expectedDevices []discover.Device + expectedHooks []discover.Hook + expectedMounts []discover.Mount + }{ + { + description: "", + mig: &mock.Device{ + IsMigDeviceHandleFunc: func() (bool, nvml.Return) { + return true, nvml.SUCCESS + }, + GetGpuInstanceIdFunc: func() (int, nvml.Return) { + return 1, nvml.SUCCESS + }, + GetComputeInstanceIdFunc: func() (int, nvml.Return) { + return 2, nvml.SUCCESS + }, + }, + parent: &mock.Device{ + GetMinorNumberFunc: func() (int, nvml.Return) { + return 3, nvml.SUCCESS + }, + GetPciInfoFunc: func() (nvml.PciInfo, nvml.Return) { + var busID [32]int8 + for i, b := range []byte("00000000:45:00:00") { + busID[i] = int8(b) + } + info := nvml.PciInfo{ + BusId: busID, + } + return info, nvml.SUCCESS + }, + }, + migCaps: nvcaps.MigCaps{ + "gpu3/gi1/access": 31, + "gpu3/gi1/ci2/access": 312, + }, + expectedDevices: nil, + expectedMounts: nil, + expectedHooks: []discover.Hook{}, + }, + } + for _, tc := range testCases { + t.Run(tc.description, func(t *testing.T) { + + tc.mig.GetDeviceHandleFromMigDeviceHandleFunc = func() (nvml.Device, nvml.Return) { + return tc.parent, nvml.SUCCESS + } + parent, err := devicelib.NewDevice(tc.parent) + require.NoError(t, err) + + mig, err := devicelib.NewMigDevice(tc.mig) + require.NoError(t, err) + + d, err := NewForMigDevice(parent, mig, + WithLogger(logger), + WithMIGCaps(tc.migCaps), + ) + require.ErrorIs(t, err, tc.expectedError) + + devices, _ := d.Devices() + require.EqualValues(t, tc.expectedDevices, devices) + hooks, _ := d.Hooks() + require.EqualValues(t, tc.expectedHooks, hooks) + mounts, _ := d.Mounts() + require.EqualValues(t, tc.expectedMounts, mounts) + }) + } +} diff --git a/internal/platform-support/dgpu/options.go b/internal/platform-support/dgpu/options.go index cea58c6d..41e4d7a9 100644 --- a/internal/platform-support/dgpu/options.go +++ b/internal/platform-support/dgpu/options.go @@ -18,12 +18,18 @@ package dgpu import ( "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" + "github.com/NVIDIA/nvidia-container-toolkit/internal/nvcaps" ) type options struct { logger logger.Interface devRoot string nvidiaCDIHookPath string + + // migCaps stores the MIG capabilities for the system. + // If MIG is not available, this is nil. + migCaps nvcaps.MigCaps + migCapsError error } type Option func(*options) @@ -48,3 +54,10 @@ func WithNVIDIACDIHookPath(path string) Option { l.nvidiaCDIHookPath = path } } + +// WithMIGCaps sets the MIG capabilities. +func WithMIGCaps(migCaps nvcaps.MigCaps) Option { + return func(l *options) { + l.migCaps = migCaps + } +}