diff --git a/cmd/nvidia-container-runtime/main_test.go b/cmd/nvidia-container-runtime/main_test.go index 749f1f0c..ab5ddee7 100644 --- a/cmd/nvidia-container-runtime/main_test.go +++ b/cmd/nvidia-container-runtime/main_test.go @@ -122,11 +122,10 @@ func TestGoodInput(t *testing.T) { err = cmdCreate.Run() require.NoError(t, err, "runtime should not return an error") - // Check config.json for NVIDIA prestart hook + // Check config.json to ensure that the NVIDIA prestart hook was not inserted. spec, err = cfg.getRuntimeSpec() require.NoError(t, err, "should be no errors when reading and parsing spec from config.json") - require.NotEmpty(t, spec.Hooks, "there should be hooks in config.json") - require.Equal(t, 1, nvidiaHookCount(spec.Hooks), "exactly one nvidia prestart hook should be inserted correctly into config.json") + require.Empty(t, spec.Hooks, "there should be no hooks in config.json") } // NVIDIA prestart hook already present in config file @@ -168,11 +167,10 @@ func TestDuplicateHook(t *testing.T) { output, err := cmdCreate.CombinedOutput() require.NoErrorf(t, err, "runtime should not return an error", "output=%v", string(output)) - // Check config.json for NVIDIA prestart hook + // Check config.json to ensure that the NVIDIA prestart hook was removed. 
spec, err = cfg.getRuntimeSpec() require.NoError(t, err, "should be no errors when reading and parsing spec from config.json") - require.NotEmpty(t, spec.Hooks, "there should be hooks in config.json") - require.Equal(t, 1, nvidiaHookCount(spec.Hooks), "exactly one nvidia prestart hook should be inserted correctly into config.json") + require.Empty(t, spec.Hooks, "there should be no hooks in config.json") } // addNVIDIAHook is a basic wrapper for an addHookModifier that is used for @@ -240,18 +238,3 @@ func (c testConfig) generateNewRuntimeSpec() error { } return nil } - -// Return number of valid NVIDIA prestart hooks in runtime spec -func nvidiaHookCount(hooks *specs.Hooks) int { - if hooks == nil { - return 0 - } - - count := 0 - for _, hook := range hooks.Prestart { - if strings.Contains(hook.Path, nvidiaHook) { - count++ - } - } - return count -} diff --git a/internal/info/auto.go b/internal/info/auto.go index 7546731a..ce64fc6e 100644 --- a/internal/info/auto.go +++ b/internal/info/auto.go @@ -41,6 +41,9 @@ const ( // to the container config required for the requested CDI devices in the // same way that other CDI clients would. CDIRuntimeMode = RuntimeMode("cdi") + // In JitCDIRuntimeMode the nvidia-container-runtime generates in-memory CDI + // specifications for requested NVIDIA devices. 
+ JitCDIRuntimeMode = RuntimeMode("jit-cdi") ) type RuntimeModeResolver interface { @@ -116,9 +119,9 @@ func (m *modeResolver) ResolveRuntimeMode(mode string) (rmode RuntimeMode) { switch nvinfo.ResolvePlatform() { case info.PlatformNVML, info.PlatformWSL: - return LegacyRuntimeMode + return JitCDIRuntimeMode case info.PlatformTegra: return CSVRuntimeMode } - return LegacyRuntimeMode + return JitCDIRuntimeMode } diff --git a/internal/info/auto_test.go b/internal/info/auto_test.go index 25f14327..f6d99c7e 100644 --- a/internal/info/auto_test.go +++ b/internal/info/auto_test.go @@ -43,11 +43,16 @@ func TestResolveAutoMode(t *testing.T) { mode: "not-auto", expectedMode: "not-auto", }, + { + description: "legacy resolves to legacy", + mode: "legacy", + expectedMode: "legacy", + }, { description: "no info defaults to legacy", mode: "auto", info: map[string]bool{}, - expectedMode: "legacy", + expectedMode: "jit-cdi", }, { description: "non-nvml, non-tegra, nvgpu resolves to csv", @@ -80,14 +85,14 @@ func TestResolveAutoMode(t *testing.T) { expectedMode: "csv", }, { - description: "nvml, non-tegra, non-nvgpu resolves to legacy", + description: "nvml, non-tegra, non-nvgpu resolves to jit-cdi", mode: "auto", info: map[string]bool{ "nvml": true, "tegra": false, "nvgpu": false, }, - expectedMode: "legacy", + expectedMode: "jit-cdi", }, { description: "nvml, non-tegra, nvgpu resolves to csv", @@ -100,14 +105,14 @@ func TestResolveAutoMode(t *testing.T) { expectedMode: "csv", }, { - description: "nvml, tegra, non-nvgpu resolves to legacy", + description: "nvml, tegra, non-nvgpu resolves to jit-cdi", mode: "auto", info: map[string]bool{ "nvml": true, "tegra": true, "nvgpu": false, }, - expectedMode: "legacy", + expectedMode: "jit-cdi", }, { description: "nvml, tegra, nvgpu resolves to csv", @@ -136,7 +141,7 @@ func TestResolveAutoMode(t *testing.T) { }, }, { - description: "at least one non-cdi device resolves to legacy", + description: "at least one non-cdi device resolves to 
jit-cdi", mode: "auto", envmap: map[string]string{ "NVIDIA_VISIBLE_DEVICES": "nvidia.com/gpu=0,0", @@ -146,7 +151,7 @@ func TestResolveAutoMode(t *testing.T) { "tegra": false, "nvgpu": false, }, - expectedMode: "legacy", + expectedMode: "jit-cdi", }, { description: "at least one non-cdi device resolves to csv", @@ -170,7 +175,7 @@ func TestResolveAutoMode(t *testing.T) { expectedMode: "cdi", }, { - description: "cdi mount and non-CDI devices resolves to legacy", + description: "cdi mount and non-CDI devices resolves to jit-cdi", mode: "auto", mounts: []string{ "/var/run/nvidia-container-devices/cdi/nvidia.com/gpu/0", @@ -181,7 +186,7 @@ func TestResolveAutoMode(t *testing.T) { "tegra": false, "nvgpu": false, }, - expectedMode: "legacy", + expectedMode: "jit-cdi", }, { description: "cdi mount and non-CDI envvar resolves to cdi", @@ -199,22 +204,6 @@ func TestResolveAutoMode(t *testing.T) { }, expectedMode: "cdi", }, - { - description: "non-cdi mount and CDI envvar resolves to legacy", - mode: "auto", - envmap: map[string]string{ - "NVIDIA_VISIBLE_DEVICES": "nvidia.com/gpu=0", - }, - mounts: []string{ - "/var/run/nvidia-container-devices/0", - }, - info: map[string]bool{ - "nvml": true, - "tegra": false, - "nvgpu": false, - }, - expectedMode: "legacy", - }, } for _, tc := range testCases { diff --git a/internal/modifier/cdi.go b/internal/modifier/cdi.go index ff75eccc..2a73a054 100644 --- a/internal/modifier/cdi.go +++ b/internal/modifier/cdi.go @@ -18,6 +18,7 @@ package modifier import ( "fmt" + "strings" "tags.cncf.io/container-device-interface/pkg/parser" @@ -27,17 +28,27 @@ import ( "github.com/NVIDIA/nvidia-container-toolkit/internal/modifier/cdi" "github.com/NVIDIA/nvidia-container-toolkit/internal/oci" "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi" - "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec" +) + +const ( + automaticDeviceVendor = "runtime.nvidia.com" + automaticDeviceClass = "gpu" + automaticDeviceKind = automaticDeviceVendor + "/" + 
automaticDeviceClass + automaticDevicePrefix = automaticDeviceKind + "=" ) // NewCDIModifier creates an OCI spec modifier that determines the modifications to make based on the // CDI specifications available on the system. The NVIDIA_VISIBLE_DEVICES environment variable is // used to select the devices to include. -func NewCDIModifier(logger logger.Interface, cfg *config.Config, image image.CUDA) (oci.SpecModifier, error) { +func NewCDIModifier(logger logger.Interface, cfg *config.Config, image image.CUDA, isJitCDI bool) (oci.SpecModifier, error) { + defaultKind := cfg.NVIDIAContainerRuntimeConfig.Modes.CDI.DefaultKind + if isJitCDI { + defaultKind = automaticDeviceKind + } deviceRequestor := newCDIDeviceRequestor( logger, image, - cfg.NVIDIAContainerRuntimeConfig.Modes.CDI.DefaultKind, + defaultKind, ) devices := deviceRequestor.DeviceRequests() if len(devices) == 0 { @@ -107,17 +118,34 @@ func (c *cdiDeviceRequestor) DeviceRequests() []string { func filterAutomaticDevices(devices []string) []string { var automatic []string for _, device := range devices { - vendor, class, _ := parser.ParseDevice(device) - if vendor == "runtime.nvidia.com" && class == "gpu" { - automatic = append(automatic, device) + if !strings.HasPrefix(device, automaticDevicePrefix) { + continue } + automatic = append(automatic, device) } return automatic } func newAutomaticCDISpecModifier(logger logger.Interface, cfg *config.Config, devices []string) (oci.SpecModifier, error) { logger.Debugf("Generating in-memory CDI specs for devices %v", devices) - spec, err := generateAutomaticCDISpec(logger, cfg, devices) + + var identifiers []string + for _, device := range devices { + identifiers = append(identifiers, strings.TrimPrefix(device, automaticDevicePrefix)) + } + + cdilib, err := nvcdi.New( + nvcdi.WithLogger(logger), + nvcdi.WithNVIDIACDIHookPath(cfg.NVIDIACTKConfig.Path), + nvcdi.WithDriverRoot(cfg.NVIDIAContainerCLIConfig.Root), + nvcdi.WithVendor(automaticDeviceVendor), + 
nvcdi.WithClass(automaticDeviceClass), + ) + if err != nil { + return nil, fmt.Errorf("failed to construct CDI library: %w", err) + } + + spec, err := cdilib.GetSpec(identifiers...) if err != nil { return nil, fmt.Errorf("failed to generate CDI spec: %w", err) } @@ -132,27 +160,6 @@ func newAutomaticCDISpecModifier(logger logger.Interface, cfg *config.Config, de return cdiDeviceRequestor, nil } -func generateAutomaticCDISpec(logger logger.Interface, cfg *config.Config, devices []string) (spec.Interface, error) { - cdilib, err := nvcdi.New( - nvcdi.WithLogger(logger), - nvcdi.WithNVIDIACDIHookPath(cfg.NVIDIACTKConfig.Path), - nvcdi.WithDriverRoot(cfg.NVIDIAContainerCLIConfig.Root), - nvcdi.WithVendor("runtime.nvidia.com"), - nvcdi.WithClass("gpu"), - ) - if err != nil { - return nil, fmt.Errorf("failed to construct CDI library: %w", err) - } - - var identifiers []string - for _, device := range devices { - _, _, id := parser.ParseDevice(device) - identifiers = append(identifiers, id) - } - - return cdilib.GetSpec(identifiers...) 
-} - type deduplicatedDeviceRequestor struct { deviceRequestor } diff --git a/internal/modifier/cdi_test.go b/internal/modifier/cdi_test.go index a5ca78f5..0163e4ce 100644 --- a/internal/modifier/cdi_test.go +++ b/internal/modifier/cdi_test.go @@ -70,6 +70,18 @@ func TestDeviceRequests(t *testing.T) { }, expectedDevices: []string{"nvidia.com/gpu=0", "example.com/class=device"}, }, + { + description: "cdi devices from envvar with default kind", + input: cdiDeviceRequestor{ + defaultKind: "runtime.nvidia.com/gpu", + }, + spec: &specs.Spec{ + Process: &specs.Process{ + Env: []string{"NVIDIA_VISIBLE_DEVICES=all"}, + }, + }, + expectedDevices: []string{"runtime.nvidia.com/gpu=all"}, + }, { description: "no matching annotations", prefixes: []string{"not-prefix/"}, diff --git a/internal/runtime/runtime_factory.go b/internal/runtime/runtime_factory.go index 99da0f94..c5ea892f 100644 --- a/internal/runtime/runtime_factory.go +++ b/internal/runtime/runtime_factory.go @@ -107,8 +107,8 @@ func newModeModifier(logger logger.Interface, mode info.RuntimeMode, cfg *config return modifier.NewStableRuntimeModifier(logger, cfg.NVIDIAContainerRuntimeHookConfig.Path), nil case info.CSVRuntimeMode: return modifier.NewCSVModifier(logger, cfg, image) - case info.CDIRuntimeMode: - return modifier.NewCDIModifier(logger, cfg, image) + case info.CDIRuntimeMode, info.JitCDIRuntimeMode: + return modifier.NewCDIModifier(logger, cfg, image, mode == info.JitCDIRuntimeMode) } return nil, fmt.Errorf("invalid runtime mode: %v", cfg.NVIDIAContainerRuntimeConfig.Mode) @@ -160,7 +160,7 @@ func initRuntimeModeAndImage(logger logger.Interface, cfg *config.Config, ociSpe // supportedModifierTypes returns the modifiers supported for a specific runtime mode. func supportedModifierTypes(mode info.RuntimeMode) []string { switch mode { - case info.CDIRuntimeMode: + case info.CDIRuntimeMode, info.JitCDIRuntimeMode: // For CDI mode we make no additional modifications. 
return []string{"nvidia-hook-remover", "mode"} case info.CSVRuntimeMode: