diff --git a/cmd/nvidia-container-runtime/main_test.go b/cmd/nvidia-container-runtime/main_test.go index 749f1f0c..192be38d 100644 --- a/cmd/nvidia-container-runtime/main_test.go +++ b/cmd/nvidia-container-runtime/main_test.go @@ -125,8 +125,8 @@ func TestGoodInput(t *testing.T) { // Check config.json for NVIDIA prestart hook spec, err = cfg.getRuntimeSpec() require.NoError(t, err, "should be no errors when reading and parsing spec from config.json") - require.NotEmpty(t, spec.Hooks, "there should be hooks in config.json") - require.Equal(t, 1, nvidiaHookCount(spec.Hooks), "exactly one nvidia prestart hook should be inserted correctly into config.json") + require.Empty(t, spec.Hooks, "there should be no hooks in config.json") + require.Equal(t, 0, nvidiaHookCount(spec.Hooks), "no nvidia prestart hook should be inserted into config.json") } // NVIDIA prestart hook already present in config file @@ -171,8 +171,8 @@ func TestDuplicateHook(t *testing.T) { // Check config.json for NVIDIA prestart hook spec, err = cfg.getRuntimeSpec() require.NoError(t, err, "should be no errors when reading and parsing spec from config.json") - require.NotEmpty(t, spec.Hooks, "there should be hooks in config.json") - require.Equal(t, 1, nvidiaHookCount(spec.Hooks), "exactly one nvidia prestart hook should be inserted correctly into config.json") + require.Empty(t, spec.Hooks, "there should be no hooks in config.json") + require.Equal(t, 0, nvidiaHookCount(spec.Hooks), "no nvidia prestart hook should be inserted into config.json") } // addNVIDIAHook is a basic wrapper for an addHookModifier that is used for diff --git a/internal/info/auto.go b/internal/info/auto.go index 27e5132e..be3610c2 100644 --- a/internal/info/auto.go +++ b/internal/info/auto.go @@ -30,6 +30,7 @@ const ( RuntimeModeLegacy = RuntimeMode("legacy") RuntimeModeCSV = RuntimeMode("csv") RuntimeModeCDI = RuntimeMode("cdi") + RuntimeModeJitCDI = RuntimeMode("jit-cdi") ) // ResolveAutoMode 
determines the correct mode for the platform if set to "auto" @@ -57,9 +58,9 @@ func resolveMode(logger logger.Interface, mode string, image image.CUDA, propert switch nvinfo.ResolvePlatform() { case info.PlatformNVML, info.PlatformWSL: - return RuntimeModeLegacy + return RuntimeModeJitCDI case info.PlatformTegra: return RuntimeModeCSV } - return RuntimeModeLegacy + return RuntimeModeJitCDI } diff --git a/internal/info/auto_test.go b/internal/info/auto_test.go index 4fbfcde4..8703bfe0 100644 --- a/internal/info/auto_test.go +++ b/internal/info/auto_test.go @@ -44,10 +44,15 @@ func TestResolveAutoMode(t *testing.T) { expectedMode: "not-auto", }, { - description: "no info defaults to legacy", + description: "legacy resolves to legacy", + mode: "legacy", + expectedMode: "legacy", + }, + { + description: "no info defaults to jit-cdi", mode: "auto", info: map[string]bool{}, - expectedMode: "legacy", + expectedMode: "jit-cdi", }, { description: "non-nvml, non-tegra, nvgpu resolves to csv", @@ -80,14 +85,14 @@ func TestResolveAutoMode(t *testing.T) { expectedMode: "csv", }, { - description: "nvml, non-tegra, non-nvgpu resolves to legacy", + description: "nvml, non-tegra, non-nvgpu resolves to jit-cdi", mode: "auto", info: map[string]bool{ "nvml": true, "tegra": false, "nvgpu": false, }, - expectedMode: "legacy", + expectedMode: "jit-cdi", }, { description: "nvml, non-tegra, nvgpu resolves to csv", @@ -100,14 +105,14 @@ func TestResolveAutoMode(t *testing.T) { expectedMode: "csv", }, { - description: "nvml, tegra, non-nvgpu resolves to legacy", + description: "nvml, tegra, non-nvgpu resolves to jit-cdi", mode: "auto", info: map[string]bool{ "nvml": true, "tegra": true, "nvgpu": false, }, - expectedMode: "legacy", + expectedMode: "jit-cdi", }, { description: "nvml, tegra, nvgpu resolves to csv", @@ -136,7 +141,7 @@ func TestResolveAutoMode(t *testing.T) { }, }, { - description: "at least one non-cdi device resolves to legacy", + description: "at least one non-cdi device 
resolves to jit-cdi", mode: "auto", envmap: map[string]string{ "NVIDIA_VISIBLE_DEVICES": "nvidia.com/gpu=0,0", @@ -146,7 +151,7 @@ func TestResolveAutoMode(t *testing.T) { "tegra": false, "nvgpu": false, }, - expectedMode: "legacy", + expectedMode: "jit-cdi", }, { description: "at least one non-cdi device resolves to csv", @@ -170,7 +175,7 @@ func TestResolveAutoMode(t *testing.T) { expectedMode: "cdi", }, { - description: "cdi mount and non-CDI devices resolves to legacy", + description: "cdi mount and non-CDI devices resolves to jit-cdi", mode: "auto", mounts: []string{ "/var/run/nvidia-container-devices/cdi/nvidia.com/gpu/0", @@ -181,10 +186,10 @@ func TestResolveAutoMode(t *testing.T) { "tegra": false, "nvgpu": false, }, - expectedMode: "legacy", + expectedMode: "jit-cdi", }, { - description: "cdi mount and non-CDI envvar resolves to legacy", + description: "cdi mount and non-CDI envvar resolves to jit-cdi", mode: "auto", envmap: map[string]string{ "NVIDIA_VISIBLE_DEVICES": "0", @@ -197,7 +202,7 @@ func TestResolveAutoMode(t *testing.T) { "tegra": false, "nvgpu": false, }, - expectedMode: "legacy", + expectedMode: "jit-cdi", }, } diff --git a/internal/modifier/cdi.go b/internal/modifier/cdi.go index 90cd481b..aac6c762 100644 --- a/internal/modifier/cdi.go +++ b/internal/modifier/cdi.go @@ -31,11 +31,22 @@ import ( "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec" ) +const ( + automaticDeviceVendor = "runtime.nvidia.com" + automaticDeviceClass = "gpu" + automaticDeviceKind = automaticDeviceVendor + "/" + automaticDeviceClass +) + // NewCDIModifier creates an OCI spec modifier that determines the modifications to make based on the // CDI specifications available on the system. The NVIDIA_VISIBLE_DEVICES environment variable is // used to select the devices to include. 
-func NewCDIModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Spec) (oci.SpecModifier, error) { - devices, err := getDevicesFromSpec(logger, ociSpec, cfg) +func NewCDIModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Spec, isJitCDI bool) (oci.SpecModifier, error) { + defaultKind := cfg.NVIDIAContainerRuntimeConfig.Modes.CDI.DefaultKind + if isJitCDI { + defaultKind = automaticDeviceKind + } + + devices, err := getDevicesFromSpec(logger, ociSpec, cfg, defaultKind) if err != nil { return nil, fmt.Errorf("failed to get required devices from OCI specification: %v", err) } @@ -65,7 +76,7 @@ func NewCDIModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Spe ) } -func getDevicesFromSpec(logger logger.Interface, ociSpec oci.Spec, cfg *config.Config) ([]string, error) { +func getDevicesFromSpec(logger logger.Interface, ociSpec oci.Spec, cfg *config.Config, defaultKind string) ([]string, error) { rawSpec, err := ociSpec.Load() if err != nil { return nil, fmt.Errorf("failed to load OCI spec: %v", err) @@ -83,26 +94,16 @@ func getDevicesFromSpec(logger logger.Interface, ociSpec oci.Spec, cfg *config.C if err != nil { return nil, err } - if cfg.AcceptDeviceListAsVolumeMounts { - mountDevices := container.CDIDevicesFromMounts() - if len(mountDevices) > 0 { - return mountDevices, nil - } - } var devices []string - seen := make(map[string]bool) - for _, name := range container.VisibleDevicesFromEnvVar() { - if !parser.IsQualifiedName(name) { - name = fmt.Sprintf("%s=%s", cfg.NVIDIAContainerRuntimeConfig.Modes.CDI.DefaultKind, name) + if cfg.AcceptDeviceListAsVolumeMounts { + devices = normalizeDeviceList(logger, defaultKind, append(container.DevicesFromMounts(), container.CDIDevicesFromMounts()...)...) 
+ if len(devices) > 0 { + return devices, nil } - if seen[name] { - logger.Debugf("Ignoring duplicate device %q", name) - continue - } - devices = append(devices, name) } + devices = normalizeDeviceList(logger, defaultKind, container.VisibleDevicesFromEnvVar()...) if len(devices) == 0 { return nil, nil } @@ -116,6 +117,24 @@ func getDevicesFromSpec(logger logger.Interface, ociSpec oci.Spec, cfg *config.C return nil, nil } +func normalizeDeviceList(logger logger.Interface, defaultKind string, devices ...string) []string { + seen := make(map[string]bool) + var normalized []string + for _, name := range devices { + if !parser.IsQualifiedName(name) { + name = fmt.Sprintf("%s=%s", defaultKind, name) + } + if seen[name] { + logger.Debugf("Ignoring duplicate device %q", name) + continue + } + normalized = append(normalized, name) + seen[name] = true + } + + return normalized +} + // getAnnotationDevices returns a list of devices specified in the annotations. // Keys starting with the specified prefixes are considered and expected to contain a comma-separated list of // fully-qualified CDI devices names. If any device name is not fully-quality an error is returned. @@ -156,7 +175,7 @@ func filterAutomaticDevices(devices []string) []string { var automatic []string for _, device := range devices { vendor, class, _ := parser.ParseDevice(device) - if vendor == "runtime.nvidia.com" && class == "gpu" { + if vendor == automaticDeviceVendor && class == automaticDeviceClass { automatic = append(automatic, device) } } @@ -165,6 +184,8 @@ func filterAutomaticDevices(devices []string) []string { func newAutomaticCDISpecModifier(logger logger.Interface, cfg *config.Config, devices []string) (oci.SpecModifier, error) { logger.Debugf("Generating in-memory CDI specs for devices %v", devices) + // TODO: We should try to load the kernel modules and create the device nodes here. + // Failures should raise a warning and not error out. 
spec, err := generateAutomaticCDISpec(logger, cfg, devices) if err != nil { return nil, fmt.Errorf("failed to generate CDI spec: %w", err) diff --git a/internal/modifier/cdi_test.go b/internal/modifier/cdi_test.go index 88ff697a..d725a4cb 100644 --- a/internal/modifier/cdi_test.go +++ b/internal/modifier/cdi_test.go @@ -20,9 +20,49 @@ import ( "fmt" "testing" + "github.com/opencontainers/runtime-spec/specs-go" + testlog "github.com/sirupsen/logrus/hooks/test" "github.com/stretchr/testify/require" + + "github.com/NVIDIA/nvidia-container-toolkit/internal/config" + "github.com/NVIDIA/nvidia-container-toolkit/internal/oci" ) +func TestGetDevicesFromSpec(t *testing.T) { + logger, _ := testlog.NewNullLogger() + testCases := []struct { + description string + spec *specs.Spec + config *config.Config + defaultKind string + expectedDevices []string + }{ + { + description: "NVIDIA_VISIBLE_DEVICES=all", + spec: &specs.Spec{ + Process: &specs.Process{ + Env: []string{"NVIDIA_VISIBLE_DEVICES=all"}, + }, + }, + config: func() *config.Config { + c, _ := config.GetDefault() + return c + }(), + defaultKind: "runtime.nvidia.com/gpu", + expectedDevices: []string{"runtime.nvidia.com/gpu=all"}, + }, + } + + for _, tc := range testCases { + t.Run(tc.description, func(t *testing.T) { + devices, err := getDevicesFromSpec(logger, oci.NewMemorySpec(tc.spec), tc.config, tc.defaultKind) + require.NoError(t, err) + + require.EqualValues(t, tc.expectedDevices, devices) + }) + } +} + func TestGetAnnotationDevices(t *testing.T) { testCases := []struct { description string diff --git a/internal/runtime/runtime_factory.go b/internal/runtime/runtime_factory.go index 90e004d4..2ac6fc8b 100644 --- a/internal/runtime/runtime_factory.go +++ b/internal/runtime/runtime_factory.go @@ -111,8 +111,8 @@ func newModeModifier(logger logger.Interface, mode info.RuntimeMode, cfg *config return modifier.NewStableRuntimeModifier(logger, cfg.NVIDIAContainerRuntimeHookConfig.Path), nil case info.RuntimeModeCSV: return 
modifier.NewCSVModifier(logger, cfg, image) - case info.RuntimeModeCDI: - return modifier.NewCDIModifier(logger, cfg, ociSpec) + case info.RuntimeModeCDI, info.RuntimeModeJitCDI: + return modifier.NewCDIModifier(logger, cfg, ociSpec, mode == info.RuntimeModeJitCDI) } return nil, fmt.Errorf("invalid runtime mode: %v", cfg.NVIDIAContainerRuntimeConfig.Mode) @@ -121,7 +121,7 @@ func newModeModifier(logger logger.Interface, mode info.RuntimeMode, cfg *config // supportedModifierTypes returns the modifiers supported for a specific runtime mode. func supportedModifierTypes(mode info.RuntimeMode) []string { switch mode { - case info.RuntimeModeCDI: + case info.RuntimeModeCDI, info.RuntimeModeJitCDI: // For CDI mode we make no additional modifications. return []string{"nvidia-hook-remover", "mode"} case info.RuntimeModeCSV: