diff --git a/cmd/nvidia-container-toolkit/capabilities.go b/cmd/nvidia-container-toolkit/capabilities.go index 5c53c8da..062af2e7 100644 --- a/cmd/nvidia-container-toolkit/capabilities.go +++ b/cmd/nvidia-container-toolkit/capabilities.go @@ -2,6 +2,15 @@ package main import ( "log" + "strings" +) + +const ( + allDriverCapabilities = DriverCapabilities("compute,compat32,graphics,utility,video,display,ngx") + defaultDriverCapabilities = DriverCapabilities("utility,compute") + + none = DriverCapabilities("") + all = DriverCapabilities("all") ) func capabilityToCLI(cap string) string { @@ -25,3 +34,50 @@ func capabilityToCLI(cap string) string { } return "" } + +// DriverCapabilities is used to process the NVIDIA_DRIVER_CAPABILITIES environment +// variable. Operations include default values, filtering, and handling meta values such as "all" +type DriverCapabilities string + +// Intersection returns intersection between two sets of capabilities. +func (d DriverCapabilities) Intersection(capabilities DriverCapabilities) DriverCapabilities { + if capabilities == all { + return d + } + if d == all { + return capabilities + } + + lookup := make(map[string]bool) + for _, c := range d.list() { + lookup[c] = true + } + var found []string + for _, c := range capabilities.list() { + if lookup[c] { + found = append(found, c) + } + } + + intersection := DriverCapabilities(strings.Join(found, ",")) + return intersection +} + +// String returns the string representation of the driver capabilities +func (d DriverCapabilities) String() string { + return string(d) +} + +// list returns the driver capabilities as a list +func (d DriverCapabilities) list() []string { + var caps []string + for _, c := range strings.Split(string(d), ",") { + trimmed := strings.TrimSpace(c) + if len(trimmed) == 0 { + continue + } + caps = append(caps, trimmed) + } + + return caps +} diff --git a/cmd/nvidia-container-toolkit/capabilities_test.go b/cmd/nvidia-container-toolkit/capabilities_test.go new file mode 100644 index 00000000..0386e6a7 --- /dev/null +++ b/cmd/nvidia-container-toolkit/capabilities_test.go @@ -0,0 +1,134 @@ +/** +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package main + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestDriverCapabilitiesIntersection(t *testing.T) { + testCases := []struct { + capabilities DriverCapabilities + supportedCapabilities DriverCapabilities + expectedIntersection DriverCapabilities + }{ + { + capabilities: none, + supportedCapabilities: none, + expectedIntersection: none, + }, + { + capabilities: all, + supportedCapabilities: none, + expectedIntersection: none, + }, + { + capabilities: all, + supportedCapabilities: allDriverCapabilities, + expectedIntersection: allDriverCapabilities, + }, + { + capabilities: allDriverCapabilities, + supportedCapabilities: all, + expectedIntersection: allDriverCapabilities, + }, + { + capabilities: none, + supportedCapabilities: all, + expectedIntersection: none, + }, + { + capabilities: none, + supportedCapabilities: DriverCapabilities("cap1"), + expectedIntersection: none, + }, + { + capabilities: DriverCapabilities("cap0,cap1"), + supportedCapabilities: DriverCapabilities("cap1,cap0"), + expectedIntersection: DriverCapabilities("cap0,cap1"), + }, + { + capabilities: defaultDriverCapabilities, + supportedCapabilities: allDriverCapabilities, + expectedIntersection: defaultDriverCapabilities, + }, + { + capabilities: DriverCapabilities("compute,compat32,graphics,utility,video,display"), + supportedCapabilities: DriverCapabilities("compute,compat32,graphics,utility,video,display,ngx"), + expectedIntersection: DriverCapabilities("compute,compat32,graphics,utility,video,display"), + }, + { + capabilities: DriverCapabilities("cap1"), + supportedCapabilities: none, + expectedIntersection: none, + }, + { + capabilities: DriverCapabilities("compute,compat32,graphics,utility,video,display,ngx"), + supportedCapabilities: DriverCapabilities("compute,compat32,graphics,utility,video,display"), + expectedIntersection: DriverCapabilities("compute,compat32,graphics,utility,video,display"), + }, + } + + for i, tc := range testCases { + t.Run(fmt.Sprintf("test case %d", i), func(t *testing.T) { + intersection := tc.supportedCapabilities.Intersection(tc.capabilities) + require.EqualValues(t, tc.expectedIntersection, intersection) + }) + } +} + +func TestDriverCapabilitiesList(t *testing.T) { + testCases := []struct { + capabilities DriverCapabilities + expected []string + }{ + { + capabilities: DriverCapabilities(""), + }, + { + capabilities: DriverCapabilities(" "), + }, + { + capabilities: DriverCapabilities(","), + }, + { + capabilities: DriverCapabilities(",cap"), + expected: []string{"cap"}, + }, + { + capabilities: DriverCapabilities("cap,"), + expected: []string{"cap"}, + }, + { + capabilities: DriverCapabilities("cap0,,cap1"), + expected: []string{"cap0", "cap1"}, + }, + { + capabilities: DriverCapabilities("cap1,cap0,cap3"), + expected: []string{"cap1", "cap0", "cap3"}, + }, + } + + for i, tc := range testCases { + t.Run(fmt.Sprintf("test case %d", i), func(t *testing.T) { + require.EqualValues(t, tc.expected, tc.capabilities.list()) + }) + } +} diff --git a/cmd/nvidia-container-toolkit/container_config.go b/cmd/nvidia-container-toolkit/container_config.go index dae4cc7b..b129895f 100644 --- a/cmd/nvidia-container-toolkit/container_config.go +++ b/cmd/nvidia-container-toolkit/container_config.go @@ -26,11 +26,6 @@ const ( envNVDriverCapabilities = "NVIDIA_DRIVER_CAPABILITIES" ) -const ( - allDriverCapabilities = "compute,compat32,graphics,utility,video,display,ngx" - defaultDriverCapabilities = "utility,compute" -) - const ( capSysAdmin = "CAP_SYS_ADMIN" ) @@ -316,33 +311,27 @@ func getMigMonitorDevices(env map[string]string) *string { return nil } -func getDriverCapabilities(env map[string]string, legacyImage bool) *string { - // Grab a reference to the capabilities from the envvar - // if it actually exists in the environment. - var capabilities *string - if caps, ok := env[envNVDriverCapabilities]; ok { - capabilities = &caps +func getDriverCapabilities(env map[string]string, supportedDriverCapabilities DriverCapabilities, legacyImage bool) DriverCapabilities { + // We use the default driver capabilities by default. This is filtered to only include the + // supported capabilities + capabilities := supportedDriverCapabilities.Intersection(defaultDriverCapabilities) + + capsEnv, capsEnvSpecified := env[envNVDriverCapabilities] + + if !capsEnvSpecified && legacyImage { + // Environment variable unset with legacy image: set all capabilities. + return supportedDriverCapabilities } - // Environment variable unset with legacy image: set all capabilities. - if capabilities == nil && legacyImage { - allCaps := allDriverCapabilities - return &allCaps + if capsEnvSpecified && len(capsEnv) > 0 { + // If the envvironment variable is specified and is non-empty, use the capabilities value + envCapabilities := DriverCapabilities(capsEnv) + capabilities = supportedDriverCapabilities.Intersection(envCapabilities) + if envCapabilities != all && capabilities != envCapabilities { + log.Panicln(fmt.Errorf("unsupported capabilities found in '%v' (allowed '%v')", envCapabilities, capabilities)) + } } - // Environment variable unset or set but empty: set default capabilities. - if capabilities == nil || len(*capabilities) == 0 { - defaultCaps := defaultDriverCapabilities - return &defaultCaps - } - - // Environment variable set to "all": set all capabilities. - if *capabilities == "all" { - allCaps := allDriverCapabilities - return &allCaps - } - - // Any other value return capabilities } @@ -389,10 +378,7 @@ func getNvidiaConfig(hookConfig *HookConfig, env map[string]string, mounts []Mou log.Panicln("cannot set MIG_MONITOR_DEVICES in non privileged container") } - var driverCapabilities string - if c := getDriverCapabilities(env, legacyImage); c != nil { - driverCapabilities = *c - } + driverCapabilities := getDriverCapabilities(env, hookConfig.SupportedDriverCapabilities, legacyImage).String() requirements := getRequirements(env, legacyImage) diff --git a/cmd/nvidia-container-toolkit/container_test.go b/cmd/nvidia-container-toolkit/container_config_test.go similarity index 78% rename from cmd/nvidia-container-toolkit/container_test.go rename to cmd/nvidia-container-toolkit/container_config_test.go index 3cda6280..19d0bf23 100644 --- a/cmd/nvidia-container-toolkit/container_test.go +++ b/cmd/nvidia-container-toolkit/container_config_test.go @@ -12,6 +12,7 @@ func TestGetNvidiaConfig(t *testing.T) { description string env map[string]string privileged bool + hookConfig *HookConfig expectedConfig *nvidiaConfig expectedPanic bool }{ @@ -35,7 +36,7 @@ func TestGetNvidiaConfig(t *testing.T) { privileged: false, expectedConfig: &nvidiaConfig{ Devices: "all", - DriverCapabilities: allDriverCapabilities, + DriverCapabilities: allDriverCapabilities.String(), Requirements: []string{"cuda>=9.0"}, DisableRequire: false, }, @@ -49,7 +50,7 @@ func TestGetNvidiaConfig(t *testing.T) { privileged: false, expectedConfig: &nvidiaConfig{ Devices: "all", - DriverCapabilities: allDriverCapabilities, + DriverCapabilities: allDriverCapabilities.String(), Requirements: []string{"cuda>=9.0"}, DisableRequire: false, }, @@ -81,7 +82,7 @@ func TestGetNvidiaConfig(t *testing.T) { privileged: false, expectedConfig: &nvidiaConfig{ Devices: "", - DriverCapabilities: allDriverCapabilities, + DriverCapabilities: allDriverCapabilities.String(), Requirements: []string{"cuda>=9.0"}, DisableRequire: false, }, @@ -95,7 +96,7 @@ func TestGetNvidiaConfig(t *testing.T) { privileged: false, expectedConfig: &nvidiaConfig{ Devices: "gpu0,gpu1", - DriverCapabilities: allDriverCapabilities, + DriverCapabilities: allDriverCapabilities.String(), Requirements: []string{"cuda>=9.0"}, DisableRequire: false, }, @@ -110,7 +111,7 @@ func TestGetNvidiaConfig(t *testing.T) { privileged: false, expectedConfig: &nvidiaConfig{ Devices: "gpu0,gpu1", - DriverCapabilities: defaultDriverCapabilities, + DriverCapabilities: defaultDriverCapabilities.String(), Requirements: []string{"cuda>=9.0"}, DisableRequire: false, }, @@ -125,7 +126,7 @@ func TestGetNvidiaConfig(t *testing.T) { privileged: false, expectedConfig: &nvidiaConfig{ Devices: "gpu0,gpu1", - DriverCapabilities: allDriverCapabilities, + DriverCapabilities: allDriverCapabilities.String(), Requirements: []string{"cuda>=9.0"}, DisableRequire: false, }, @@ -135,12 +136,12 @@ func TestGetNvidiaConfig(t *testing.T) { env: map[string]string{ envCUDAVersion: "9.0", envNVVisibleDevices: "gpu0,gpu1", - envNVDriverCapabilities: "cap0,cap1", + envNVDriverCapabilities: "video,display", }, privileged: false, expectedConfig: &nvidiaConfig{ Devices: "gpu0,gpu1", - DriverCapabilities: "cap0,cap1", + DriverCapabilities: "video,display", Requirements: []string{"cuda>=9.0"}, DisableRequire: false, }, @@ -150,14 +151,14 @@ func TestGetNvidiaConfig(t *testing.T) { env: map[string]string{ envCUDAVersion: "9.0", envNVVisibleDevices: "gpu0,gpu1", - envNVDriverCapabilities: "cap0,cap1", + envNVDriverCapabilities: "video,display", envNVRequirePrefix + "REQ0": "req0=true", envNVRequirePrefix + "REQ1": "req1=false", }, privileged: false, expectedConfig: &nvidiaConfig{ Devices: "gpu0,gpu1", - DriverCapabilities: "cap0,cap1", + DriverCapabilities: "video,display", Requirements: []string{"cuda>=9.0", "req0=true", "req1=false"}, DisableRequire: false, }, @@ -167,7 +168,7 @@ func TestGetNvidiaConfig(t *testing.T) { env: map[string]string{ envCUDAVersion: "9.0", envNVVisibleDevices: "gpu0,gpu1", - envNVDriverCapabilities: "cap0,cap1", + envNVDriverCapabilities: "video,display", envNVRequirePrefix + "REQ0": "req0=true", envNVRequirePrefix + "REQ1": "req1=false", envNVDisableRequire: "true", @@ -175,7 +176,7 @@ func TestGetNvidiaConfig(t *testing.T) { privileged: false, expectedConfig: &nvidiaConfig{ Devices: "gpu0,gpu1", - DriverCapabilities: "cap0,cap1", + DriverCapabilities: "video,display", Requirements: []string{"cuda>=9.0", "req0=true", "req1=false"}, DisableRequire: true, }, @@ -206,7 +207,7 @@ func TestGetNvidiaConfig(t *testing.T) { privileged: false, expectedConfig: &nvidiaConfig{ Devices: "all", - DriverCapabilities: defaultDriverCapabilities, + DriverCapabilities: defaultDriverCapabilities.String(), Requirements: []string{"cuda>=9.0"}, DisableRequire: false, }, @@ -238,7 +239,7 @@ func TestGetNvidiaConfig(t *testing.T) { privileged: false, expectedConfig: &nvidiaConfig{ Devices: "", - DriverCapabilities: defaultDriverCapabilities, + DriverCapabilities: defaultDriverCapabilities.String(), Requirements: []string{"cuda>=9.0"}, DisableRequire: false, }, @@ -252,7 +253,7 @@ func TestGetNvidiaConfig(t *testing.T) { privileged: false, expectedConfig: &nvidiaConfig{ Devices: "gpu0,gpu1", - DriverCapabilities: defaultDriverCapabilities, + DriverCapabilities: defaultDriverCapabilities.String(), Requirements: []string{"cuda>=9.0"}, DisableRequire: false, }, @@ -267,7 +268,7 @@ func TestGetNvidiaConfig(t *testing.T) { privileged: false, expectedConfig: &nvidiaConfig{ Devices: "gpu0,gpu1", - DriverCapabilities: defaultDriverCapabilities, + DriverCapabilities: defaultDriverCapabilities.String(), Requirements: []string{"cuda>=9.0"}, DisableRequire: false, }, @@ -282,7 +283,7 @@ func TestGetNvidiaConfig(t *testing.T) { privileged: false, expectedConfig: &nvidiaConfig{ Devices: "gpu0,gpu1", - DriverCapabilities: allDriverCapabilities, + DriverCapabilities: allDriverCapabilities.String(), Requirements: []string{"cuda>=9.0"}, DisableRequire: false, }, @@ -292,12 +293,12 @@ func TestGetNvidiaConfig(t *testing.T) { env: map[string]string{ envNVRequireCUDA: "cuda>=9.0", envNVVisibleDevices: "gpu0,gpu1", - envNVDriverCapabilities: "cap0,cap1", + envNVDriverCapabilities: "video,display", }, privileged: false, expectedConfig: &nvidiaConfig{ Devices: "gpu0,gpu1", - DriverCapabilities: "cap0,cap1", + DriverCapabilities: "video,display", Requirements: []string{"cuda>=9.0"}, DisableRequire: false, }, @@ -307,14 +308,14 @@ func TestGetNvidiaConfig(t *testing.T) { env: map[string]string{ envNVRequireCUDA: "cuda>=9.0", envNVVisibleDevices: "gpu0,gpu1", - envNVDriverCapabilities: "cap0,cap1", + envNVDriverCapabilities: "video,display", envNVRequirePrefix + "REQ0": "req0=true", envNVRequirePrefix + "REQ1": "req1=false", }, privileged: false, expectedConfig: &nvidiaConfig{ Devices: "gpu0,gpu1", - DriverCapabilities: "cap0,cap1", + DriverCapabilities: "video,display", Requirements: []string{"cuda>=9.0", "req0=true", "req1=false"}, DisableRequire: false, }, @@ -324,7 +325,7 @@ func TestGetNvidiaConfig(t *testing.T) { env: map[string]string{ envNVRequireCUDA: "cuda>=9.0", envNVVisibleDevices: "gpu0,gpu1", - envNVDriverCapabilities: "cap0,cap1", + envNVDriverCapabilities: "video,display", envNVRequirePrefix + "REQ0": "req0=true", envNVRequirePrefix + "REQ1": "req1=false", envNVDisableRequire: "true", @@ -332,7 +333,7 @@ func TestGetNvidiaConfig(t *testing.T) { privileged: false, expectedConfig: &nvidiaConfig{ Devices: "gpu0,gpu1", - DriverCapabilities: "cap0,cap1", + DriverCapabilities: "video,display", Requirements: []string{"cuda>=9.0", "req0=true", "req1=false"}, DisableRequire: true, }, @@ -346,7 +347,7 @@ func TestGetNvidiaConfig(t *testing.T) { expectedConfig: &nvidiaConfig{ Devices: "all", - DriverCapabilities: defaultDriverCapabilities, + DriverCapabilities: defaultDriverCapabilities.String(), Requirements: []string{}, DisableRequire: false, }, @@ -362,7 +363,7 @@ func TestGetNvidiaConfig(t *testing.T) { expectedConfig: &nvidiaConfig{ Devices: "all", MigConfigDevices: "mig0,mig1", - DriverCapabilities: defaultDriverCapabilities, + DriverCapabilities: defaultDriverCapabilities.String(), Requirements: []string{"cuda>=9.0"}, DisableRequire: false, }, @@ -388,7 +389,7 @@ func TestGetNvidiaConfig(t *testing.T) { expectedConfig: &nvidiaConfig{ Devices: "all", MigMonitorDevices: "mig0,mig1", - DriverCapabilities: defaultDriverCapabilities, + DriverCapabilities: defaultDriverCapabilities.String(), Requirements: []string{"cuda>=9.0"}, DisableRequire: false, }, @@ -403,14 +404,62 @@ func TestGetNvidiaConfig(t *testing.T) { privileged: false, expectedPanic: true, }, + { + description: "Hook config set as driver-capabilities-all", + env: map[string]string{ + envNVVisibleDevices: "all", + envNVDriverCapabilities: "all", + }, + privileged: true, + hookConfig: &HookConfig{ + SupportedDriverCapabilities: "video,display", + }, + expectedConfig: &nvidiaConfig{ + Devices: "all", + DriverCapabilities: "video,display", + }, + }, + { + description: "Hook config set, envvar sets driver-capabilities", + env: map[string]string{ + envNVVisibleDevices: "all", + envNVDriverCapabilities: "video,display", + }, + privileged: true, + hookConfig: &HookConfig{ + SupportedDriverCapabilities: "video,display,compute,utility", + }, + expectedConfig: &nvidiaConfig{ + Devices: "all", + DriverCapabilities: "video,display", + }, + }, + { + description: "Hook config set, envvar unset sets default driver-capabilities", + env: map[string]string{ + envNVVisibleDevices: "all", + }, + privileged: true, + hookConfig: &HookConfig{ + SupportedDriverCapabilities: "video,display,utility,compute", + }, + expectedConfig: &nvidiaConfig{ + Devices: "all", + DriverCapabilities: defaultDriverCapabilities.String(), + }, + }, } for _, tc := range tests { t.Run(tc.description, func(t *testing.T) { // Wrap the call to getNvidiaConfig() in a closure. var config *nvidiaConfig getConfig := func() { - hookConfig := getDefaultHookConfig() - config = getNvidiaConfig(&hookConfig, tc.env, nil, tc.privileged) + hookConfig := tc.hookConfig + if hookConfig == nil { + defaultConfig := getDefaultHookConfig() + hookConfig = &defaultConfig + } + config = getNvidiaConfig(hookConfig, tc.env, nil, tc.privileged) } // For any tests that are expected to panic, make sure they do. @@ -822,3 +871,119 @@ func TestGetDevicesFromEnvvar(t *testing.T) { }) } } + +func TestGetDriverCapabilities(t *testing.T) { + + supportedCapabilities := "compute,utility,display,video" + + testCases := []struct { + description string + env map[string]string + legacyImage bool + supportedCapabilities string + expectedPanic bool + expectedCapabilities string + }{ + { + description: "Env is set for legacy image", + env: map[string]string{ + envNVDriverCapabilities: "display,video", + }, + legacyImage: true, + supportedCapabilities: supportedCapabilities, + expectedCapabilities: "display,video", + }, + { + description: "Env is all for legacy image", + env: map[string]string{ + envNVDriverCapabilities: "all", + }, + legacyImage: true, + supportedCapabilities: supportedCapabilities, + expectedCapabilities: supportedCapabilities, + }, + { + description: "Env is empty for legacy image", + env: map[string]string{ + envNVDriverCapabilities: "", + }, + legacyImage: true, + supportedCapabilities: supportedCapabilities, + expectedCapabilities: defaultDriverCapabilities.String(), + }, + { + description: "Env unset for legacy image is 'all'", + env: map[string]string{}, + legacyImage: true, + supportedCapabilities: supportedCapabilities, + expectedCapabilities: supportedCapabilities, + }, + { + description: "Env is set for modern image", + env: map[string]string{ + envNVDriverCapabilities: "display,video", + }, + legacyImage: false, + supportedCapabilities: supportedCapabilities, + expectedCapabilities: "display,video", + }, + { + description: "Env unset for modern image is default", + env: map[string]string{}, + legacyImage: false, + supportedCapabilities: supportedCapabilities, + expectedCapabilities: defaultDriverCapabilities.String(), + }, + { + description: "Env is all for modern image", + env: map[string]string{ + envNVDriverCapabilities: "all", + }, + legacyImage: false, + supportedCapabilities: supportedCapabilities, + expectedCapabilities: supportedCapabilities, + }, + { + description: "Env is empty for modern image", + env: map[string]string{ + envNVDriverCapabilities: "", + }, + legacyImage: false, + supportedCapabilities: supportedCapabilities, + expectedCapabilities: defaultDriverCapabilities.String(), + }, + { + description: "Invalid capabilities panic", + env: map[string]string{ + envNVDriverCapabilities: "compute,utility", + }, + supportedCapabilities: "not-compute,not-utility", + expectedPanic: true, + }, + { + description: "Default is restricted for modern image", + legacyImage: false, + supportedCapabilities: "compute", + expectedCapabilities: "compute", + }, + } + + for _, tc := range testCases { + t.Run(tc.description, func(t *testing.T) { + var capabilites DriverCapabilities + + getDriverCapabilities := func() { + supportedCapabilities := DriverCapabilities(tc.supportedCapabilities) + capabilites = getDriverCapabilities(tc.env, supportedCapabilities, tc.legacyImage) + } + + if tc.expectedPanic { + require.Panics(t, getDriverCapabilities) + return + } + + getDriverCapabilities() + require.EqualValues(t, tc.expectedCapabilities, capabilites) + }) + } +} diff --git a/cmd/nvidia-container-toolkit/hook_config.go b/cmd/nvidia-container-toolkit/hook_config.go index 8d35cb19..b4cb0133 100644 --- a/cmd/nvidia-container-toolkit/hook_config.go +++ b/cmd/nvidia-container-toolkit/hook_config.go @@ -35,10 +35,11 @@ type CLIConfig struct { // HookConfig : options for the nvidia-container-toolkit. type HookConfig struct { - DisableRequire bool `toml:"disable-require"` - SwarmResource *string `toml:"swarm-resource"` - AcceptEnvvarUnprivileged bool `toml:"accept-nvidia-visible-devices-envvar-when-unprivileged"` - AcceptDeviceListAsVolumeMounts bool `toml:"accept-nvidia-visible-devices-as-volume-mounts"` + DisableRequire bool `toml:"disable-require"` + SwarmResource *string `toml:"swarm-resource"` + AcceptEnvvarUnprivileged bool `toml:"accept-nvidia-visible-devices-envvar-when-unprivileged"` + AcceptDeviceListAsVolumeMounts bool `toml:"accept-nvidia-visible-devices-as-volume-mounts"` + SupportedDriverCapabilities DriverCapabilities `toml:"supported-driver-capabilities"` NvidiaContainerCLI CLIConfig `toml:"nvidia-container-cli"` } @@ -49,6 +50,7 @@ func getDefaultHookConfig() (config HookConfig) { SwarmResource: nil, AcceptEnvvarUnprivileged: true, AcceptDeviceListAsVolumeMounts: false, + SupportedDriverCapabilities: allDriverCapabilities, NvidiaContainerCLI: CLIConfig{ Root: nil, Path: nil, @@ -85,6 +87,15 @@ func getHookConfig() (config HookConfig) { } } + if config.SupportedDriverCapabilities == all { + config.SupportedDriverCapabilities = allDriverCapabilities + } + // We ensure that the supported-driver-capabilites option is a subset of allDriverCapabilities + if intersection := allDriverCapabilities.Intersection(config.SupportedDriverCapabilities); intersection != config.SupportedDriverCapabilities { + configName := config.getConfigOption("SupportedDriverCapabilities") + log.Panicf("Invalid value for config option '%v'; %v (supported: %v)\n", configName, config.SupportedDriverCapabilities, allDriverCapabilities) + } + return config } diff --git a/cmd/nvidia-container-toolkit/hook_config_test.go b/cmd/nvidia-container-toolkit/hook_config_test.go new file mode 100644 index 00000000..4886664d --- /dev/null +++ b/cmd/nvidia-container-toolkit/hook_config_test.go @@ -0,0 +1,105 @@ +/** +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package main + +import ( + "fmt" + "os" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestGetHookConfig(t *testing.T) { + testCases := []struct { + lines []string + expectedPanic bool + expectedDriverCapabilities DriverCapabilities + }{ + { + expectedDriverCapabilities: allDriverCapabilities, + }, + { + lines: []string{ + "supported-driver-capabilities = \"all\"", + }, + expectedDriverCapabilities: allDriverCapabilities, + }, + { + lines: []string{ + "supported-driver-capabilities = \"compute,utility,not-compute\"", + }, + expectedPanic: true, + }, + { + lines: []string{}, + expectedDriverCapabilities: allDriverCapabilities, + }, + { + lines: []string{ + "supported-driver-capabilities = \"\"", + }, + expectedDriverCapabilities: none, + }, + { + lines: []string{ + "supported-driver-capabilities = \"utility,compute\"", + }, + expectedDriverCapabilities: DriverCapabilities("utility,compute"), + }, + } + + for i, tc := range testCases { + t.Run(fmt.Sprintf("test case %d", i), func(t *testing.T) { + var filename string + defer func() { + if len(filename) > 0 { + os.Remove(filename) + } + configflag = nil + }() + + if tc.lines != nil { + configFile, err := os.CreateTemp("", "*.toml") + require.NoError(t, err) + defer configFile.Close() + + filename = configFile.Name() + configflag = &filename + + for _, line := range tc.lines { + _, err := configFile.WriteString(fmt.Sprintf("%s\n", line)) + require.NoError(t, err) + } + } + + var config HookConfig + getHookConfig := func() { + config = getHookConfig() + } + + if tc.expectedPanic { + require.Panics(t, getHookConfig) + return + } + + getHookConfig() + + require.EqualValues(t, tc.expectedDriverCapabilities, config.SupportedDriverCapabilities) + }) + } +} diff --git a/packaging/debian/changelog b/packaging/debian/changelog index 26ea10b6..7a10f506 100644 --- a/packaging/debian/changelog +++ b/packaging/debian/changelog @@ -1,5 +1,6 @@ nvidia-container-toolkit (1.6.0~rc.3-1) experimental; urgency=medium + * Add supported-driver-capabilities config option to the nvidia-container-toolkit * Move OCI and command line checks for runtime to internal oci package -- NVIDIA CORPORATION Mon, 15 Nov 2021 13:02:23 +0100 diff --git a/packaging/rpm/SPECS/nvidia-container-toolkit.spec b/packaging/rpm/SPECS/nvidia-container-toolkit.spec index 93f78b6d..5ff266a7 100644 --- a/packaging/rpm/SPECS/nvidia-container-toolkit.spec +++ b/packaging/rpm/SPECS/nvidia-container-toolkit.spec @@ -66,6 +66,7 @@ rm -f %{_bindir}/nvidia-container-runtime-hook %changelog * Mon Nov 15 2021 NVIDIA CORPORATION 1.6.0-0.1.rc.3 +- Add supported-driver-capabilities config option to the nvidia-container-toolkit - Move OCI and command line checks for runtime to internal oci package * Fri Nov 05 2021 NVIDIA CORPORATION 1.6.0-0.1.rc.2