diff --git a/cmd/nvidia-ctk/cdi/generate/generate.go b/cmd/nvidia-ctk/cdi/generate/generate.go index 9f9e994b..598a40c1 100644 --- a/cmd/nvidia-ctk/cdi/generate/generate.go +++ b/cmd/nvidia-ctk/cdi/generate/generate.go @@ -104,10 +104,12 @@ func (m command) build() *cli.Command { Destination: &opts.format, }, &cli.StringFlag{ - Name: "mode", - Aliases: []string{"discovery-mode"}, - Usage: "The mode to use when discovering the available entities. One of [auto | nvml | wsl]. If mode is set to 'auto' the mode will be determined based on the system configuration.", - Value: nvcdi.ModeAuto, + Name: "mode", + Aliases: []string{"discovery-mode"}, + Usage: "The mode to use when discovering the available entities. " + + "One of [" + strings.Join(nvcdi.AllModes[string](), " | ") + "]. " + + "If mode is set to 'auto' the mode will be determined based on the system configuration.", + Value: string(nvcdi.ModeAuto), Destination: &opts.mode, }, &cli.StringFlag{ @@ -184,13 +186,7 @@ func (m command) validateFlags(c *cli.Context, opts *options) error { } opts.mode = strings.ToLower(opts.mode) - switch opts.mode { - case nvcdi.ModeAuto: - case nvcdi.ModeCSV: - case nvcdi.ModeNvml: - case nvcdi.ModeWsl: - case nvcdi.ModeManagement: - default: + if !nvcdi.IsValidMode(opts.mode) { return fmt.Errorf("invalid discovery mode: %v", opts.mode) } diff --git a/pkg/nvcdi/api.go b/pkg/nvcdi/api.go index 1c84fefa..f1c7b97a 100644 --- a/pkg/nvcdi/api.go +++ b/pkg/nvcdi/api.go @@ -24,24 +24,6 @@ import ( "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec" ) -const ( - // ModeAuto configures the CDI spec generator to automatically detect the system configuration - ModeAuto = "auto" - // ModeNvml configures the CDI spec generator to use the NVML library. - ModeNvml = "nvml" - // ModeWsl configures the CDI spec generator to generate a WSL spec. - ModeWsl = "wsl" - // ModeManagement configures the CDI spec generator to generate a management spec. - ModeManagement = "management" - // ModeGds configures the CDI spec generator to generate a GDS spec. - ModeGds = "gds" - // ModeMofed configures the CDI spec generator to generate a MOFED spec. - ModeMofed = "mofed" - // ModeCSV configures the CDI spec generator to generate a spec based on the contents of CSV - // mountspec files. - ModeCSV = "csv" -) - // Interface defines the API for the nvcdi package type Interface interface { GetSpec() (spec.Interface, error) diff --git a/pkg/nvcdi/lib-imex.go b/pkg/nvcdi/lib-imex.go new file mode 100644 index 00000000..3c375d56 --- /dev/null +++ b/pkg/nvcdi/lib-imex.go @@ -0,0 +1,118 @@ +/** +# Copyright 2024 NVIDIA CORPORATION +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package nvcdi + +import ( + "fmt" + "path/filepath" + "strconv" + "strings" + + "tags.cncf.io/container-device-interface/pkg/cdi" + "tags.cncf.io/container-device-interface/specs-go" + + "github.com/NVIDIA/go-nvlib/pkg/nvlib/device" + + "github.com/NVIDIA/nvidia-container-toolkit/internal/discover" + "github.com/NVIDIA/nvidia-container-toolkit/internal/edits" + "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec" +) + +type imexlib nvcdilib + +var _ Interface = (*imexlib)(nil) + +const ( + classImexChannel = "imex-channel" +) + +// GetSpec should not be called for imexlib. +func (l *imexlib) GetSpec() (spec.Interface, error) { + return nil, fmt.Errorf("unexpected call to imexlib.GetSpec()") +} + +// GetAllDeviceSpecs returns the device specs for all available devices. +func (l *imexlib) GetAllDeviceSpecs() ([]specs.Device, error) { + channelsDiscoverer := discover.NewCharDeviceDiscoverer( + l.logger, + l.devRoot, + []string{"/dev/nvidia-caps-imex-channels/channel*"}, + ) + + channels, err := channelsDiscoverer.Devices() + if err != nil { + return nil, err + } + + var channelIDs []string + for _, channel := range channels { + channelIDs = append(channelIDs, filepath.Base(channel.Path)) + } + + return l.GetDeviceSpecsByID(channelIDs...) +} + +// GetCommonEdits returns an empty set of edits for IMEX devices. +func (l *imexlib) GetCommonEdits() (*cdi.ContainerEdits, error) { + return edits.FromDiscoverer(discover.None{}) +} + +// GetDeviceSpecsByID returns the CDI device specs for the IMEX channels specified. +func (l *imexlib) GetDeviceSpecsByID(ids ...string) ([]specs.Device, error) { + var deviceSpecs []specs.Device + for _, id := range ids { + trimmed := strings.TrimPrefix(id, "channel") + _, err := strconv.ParseUint(trimmed, 10, 64) + if err != nil { + return nil, fmt.Errorf("invalid channel ID %v: %w", id, err) + } + path := "/dev/nvidia-caps-imex-channels/channel" + trimmed + deviceSpec := specs.Device{ + Name: trimmed, + ContainerEdits: specs.ContainerEdits{ + DeviceNodes: []*specs.DeviceNode{ + { + Path: path, + HostPath: filepath.Join(l.devRoot, path), + }, + }, + }, + } + deviceSpecs = append(deviceSpecs, deviceSpec) + } + return deviceSpecs, nil +} + +// GetGPUDeviceEdits is unsupported for the imexlib specs +func (l *imexlib) GetGPUDeviceEdits(device.Device) (*cdi.ContainerEdits, error) { + return nil, fmt.Errorf("GetGPUDeviceEdits is not supported") +} + +// GetGPUDeviceSpecs is unsupported for the imexlib specs +func (l *imexlib) GetGPUDeviceSpecs(int, device.Device) ([]specs.Device, error) { + return nil, fmt.Errorf("GetGPUDeviceSpecs is not supported") +} + +// GetMIGDeviceEdits is unsupported for the imexlib specs +func (l *imexlib) GetMIGDeviceEdits(device.Device, device.MigDevice) (*cdi.ContainerEdits, error) { + return nil, fmt.Errorf("GetMIGDeviceEdits is not supported") +} + +// GetMIGDeviceSpecs is unsupported for the imexlib specs +func (l *imexlib) GetMIGDeviceSpecs(int, device.Device, int, device.MigDevice) ([]specs.Device, error) { + return nil, fmt.Errorf("GetMIGDeviceSpecs is not supported") +} diff --git a/pkg/nvcdi/lib-imex_test.go b/pkg/nvcdi/lib-imex_test.go new file mode 100644 index 00000000..c6d099c9 --- /dev/null +++ b/pkg/nvcdi/lib-imex_test.go @@ -0,0 +1,80 @@ +/** +# Copyright 2024 NVIDIA CORPORATION +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package nvcdi + +import ( + "bytes" + "path/filepath" + "strings" + "testing" + + testlog "github.com/sirupsen/logrus/hooks/test" + "github.com/stretchr/testify/require" + + "github.com/NVIDIA/nvidia-container-toolkit/internal/test" +) + +func TestImexMode(t *testing.T) { + t.Setenv("__NVCT_TESTING_DEVICES_ARE_FILES", "true") + + logger, _ := testlog.NewNullLogger() + + moduleRoot, err := test.GetModuleRoot() + require.NoError(t, err) + hostRoot := filepath.Join(moduleRoot, "testdata", "lookup", "rootfs-1") + + expectedSpec := `--- +cdiVersion: 0.5.0 +containerEdits: + env: + - NVIDIA_VISIBLE_DEVICES=void +devices: +- containerEdits: + deviceNodes: + - hostPath: {{ .hostRoot }}/dev/nvidia-caps-imex-channels/channel0 + path: /dev/nvidia-caps-imex-channels/channel0 + name: "0" +- containerEdits: + deviceNodes: + - hostPath: {{ .hostRoot }}/dev/nvidia-caps-imex-channels/channel1 + path: /dev/nvidia-caps-imex-channels/channel1 + name: "1" +- containerEdits: + deviceNodes: + - hostPath: {{ .hostRoot }}/dev/nvidia-caps-imex-channels/channel2047 + path: /dev/nvidia-caps-imex-channels/channel2047 + name: "2047" +kind: nvidia.com/imex-channel +` + expectedSpec = strings.ReplaceAll(expectedSpec, "{{ .hostRoot }}", hostRoot) + + lib, err := New( + WithLogger(logger), + WithMode(ModeImex), + WithDriverRoot(hostRoot), + ) + require.NoError(t, err) + + spec, err := lib.GetSpec() + require.NoError(t, err) + + var b bytes.Buffer + + _, err = spec.WriteTo(&b) + require.NoError(t, err) + require.Equal(t, expectedSpec, b.String()) +} diff --git a/pkg/nvcdi/lib-nvml.go b/pkg/nvcdi/lib-nvml.go index 01c22ff3..c940b090 100644 --- a/pkg/nvcdi/lib-nvml.go +++ b/pkg/nvcdi/lib-nvml.go @@ -37,7 +37,7 @@ var _ Interface = (*nvmllib)(nil) // GetSpec should not be called for nvmllib func (l *nvmllib) GetSpec() (spec.Interface, error) { - return nil, fmt.Errorf("Unexpected call to nvmllib.GetSpec()") + return nil, fmt.Errorf("unexpected call to nvmllib.GetSpec()") } // GetAllDeviceSpecs returns the device specs for all available devices. diff --git a/pkg/nvcdi/lib.go b/pkg/nvcdi/lib.go index 8ed9e5aa..91c837fa 100644 --- a/pkg/nvcdi/lib.go +++ b/pkg/nvcdi/lib.go @@ -46,7 +46,7 @@ type nvcdilib struct { logger logger.Interface nvmllib nvml.Interface nvsandboxutilslib nvsandboxutils.Interface - mode string + mode Mode devicelib device.Interface deviceNamers DeviceNamers driverRoot string @@ -161,6 +161,11 @@ func New(opts ...Option) (Interface, error) { l.class = "mofed" } lib = (*mofedlib)(l) + case ModeImex: + if l.class == "" { + l.class = classImexChannel + } + lib = (*imexlib)(l) default: return nil, fmt.Errorf("unknown mode %q", l.mode) } @@ -206,28 +211,6 @@ func (m *wrapper) GetCommonEdits() (*cdi.ContainerEdits, error) { return edits, nil } -// resolveMode resolves the mode for CDI spec generation based on the current system. -func (l *nvcdilib) resolveMode() (rmode string) { - if l.mode != ModeAuto { - return l.mode - } - defer func() { - l.logger.Infof("Auto-detected mode as '%v'", rmode) - }() - - platform := l.infolib.ResolvePlatform() - switch platform { - case info.PlatformNVML: - return ModeNvml - case info.PlatformTegra: - return ModeCSV - case info.PlatformWSL: - return ModeWsl - } - l.logger.Warningf("Unsupported platform detected: %v; assuming %v", platform, ModeNvml) - return ModeNvml -} - // getCudaVersion returns the CUDA version of the current system. func (l *nvcdilib) getCudaVersion() (string, error) { version, err := l.getCudaVersionNvsandboxutils() diff --git a/pkg/nvcdi/mode.go b/pkg/nvcdi/mode.go new file mode 100644 index 00000000..5b8f0369 --- /dev/null +++ b/pkg/nvcdi/mode.go @@ -0,0 +1,119 @@ +/** +# Copyright 2024 NVIDIA CORPORATION +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package nvcdi + +import ( + "sync" + + "github.com/NVIDIA/go-nvlib/pkg/nvlib/info" +) + +type Mode string + +const ( + // ModeAuto configures the CDI spec generator to automatically detect the system configuration + ModeAuto = Mode("auto") + // ModeNvml configures the CDI spec generator to use the NVML library. + ModeNvml = Mode("nvml") + // ModeWsl configures the CDI spec generator to generate a WSL spec. + ModeWsl = Mode("wsl") + // ModeManagement configures the CDI spec generator to generate a management spec. + ModeManagement = Mode("management") + // ModeGds configures the CDI spec generator to generate a GDS spec. + ModeGds = Mode("gds") + // ModeMofed configures the CDI spec generator to generate a MOFED spec. + ModeMofed = Mode("mofed") + // ModeCSV configures the CDI spec generator to generate a spec based on the contents of CSV + // mountspec files. + ModeCSV = Mode("csv") + // ModeImex configures the CDI spec generated to generate a spec for the available IMEX channels. + ModeImex = Mode("imex") +) + +type modeConstraint interface { + string | Mode +} + +type modes struct { + lookup map[Mode]bool + all []Mode +} + +var validModes modes +var validModesOnce sync.Once + +func getModes() modes { + validModesOnce.Do(func() { + all := []Mode{ + ModeAuto, + ModeNvml, + ModeWsl, + ModeManagement, + ModeGds, + ModeMofed, + ModeCSV, + } + lookup := make(map[Mode]bool) + + for _, m := range all { + lookup[m] = true + } + + validModes = modes{ + lookup: lookup, + all: all, + } + }, + ) + return validModes +} + +// AllModes returns the set of valid modes. +func AllModes[T modeConstraint]() []T { + var output []T + for _, m := range getModes().all { + output = append(output, T(m)) + } + return output +} + +// IsValidMode checks whether a specified mode is valid. +func IsValidMode[T modeConstraint](mode T) bool { + return getModes().lookup[Mode(mode)] +} + +// resolveMode resolves the mode for CDI spec generation based on the current system. +func (l *nvcdilib) resolveMode() (rmode Mode) { + if l.mode != ModeAuto { + return l.mode + } + defer func() { + l.logger.Infof("Auto-detected mode as '%v'", rmode) + }() + + platform := l.infolib.ResolvePlatform() + switch platform { + case info.PlatformNVML: + return ModeNvml + case info.PlatformTegra: + return ModeCSV + case info.PlatformWSL: + return ModeWsl + } + l.logger.Warningf("Unsupported platform detected: %v; assuming %v", platform, ModeNvml) + return ModeNvml +} diff --git a/pkg/nvcdi/options.go b/pkg/nvcdi/options.go index 417687b9..362545d2 100644 --- a/pkg/nvcdi/options.go +++ b/pkg/nvcdi/options.go @@ -99,9 +99,9 @@ func WithNvmlLib(nvmllib nvml.Interface) Option { } // WithMode sets the discovery mode for the library -func WithMode(mode string) Option { +func WithMode[m modeConstraint](mode m) Option { return func(l *nvcdilib) { - l.mode = mode + l.mode = Mode(mode) } } diff --git a/testdata/lookup/rootfs-1/dev/nvidia-caps-imex-channels/channel0 b/testdata/lookup/rootfs-1/dev/nvidia-caps-imex-channels/channel0 new file mode 100644 index 00000000..e69de29b diff --git a/testdata/lookup/rootfs-1/dev/nvidia-caps-imex-channels/channel1 b/testdata/lookup/rootfs-1/dev/nvidia-caps-imex-channels/channel1 new file mode 100644 index 00000000..e69de29b diff --git a/testdata/lookup/rootfs-1/dev/nvidia-caps-imex-channels/channel2047 b/testdata/lookup/rootfs-1/dev/nvidia-caps-imex-channels/channel2047 new file mode 100644 index 00000000..e69de29b diff --git a/tools/container/toolkit/toolkit_test.go b/tools/container/toolkit/toolkit_test.go index 358d9166..203b3245 100644 --- a/tools/container/toolkit/toolkit_test.go +++ b/tools/container/toolkit/toolkit_test.go @@ -100,6 +100,12 @@ devices: path: /dev/nvidia0 - hostPath: /host/driver/root/dev/nvidiactl path: /dev/nvidiactl + - hostPath: /host/driver/root/dev/nvidia-caps-imex-channels/channel0 + path: /dev/nvidia-caps-imex-channels/channel0 + - hostPath: /host/driver/root/dev/nvidia-caps-imex-channels/channel1 + path: /dev/nvidia-caps-imex-channels/channel1 + - hostPath: /host/driver/root/dev/nvidia-caps-imex-channels/channel2047 + path: /dev/nvidia-caps-imex-channels/channel2047 name: all kind: example.com/class `,