From b18ac09f775462ee08494a2d783ad18c7d482914 Mon Sep 17 00:00:00 2001 From: Evan Lezar Date: Thu, 10 Aug 2023 14:15:24 +0200 Subject: [PATCH] Refactor handling of DriverCapabilities This change consolidates the handling of NVIDIA_DRIVER_CAPABILITIES in the interal/image package. Signed-off-by: Evan Lezar --- .../capabilities.go | 56 -------- .../capabilities_test.go | 134 ------------------ .../container_config.go | 12 +- .../container_config_test.go | 76 +++++----- .../hook_config.go | 26 ++-- .../hook_config_test.go | 20 ++- internal/config/config.go | 6 +- internal/config/config_test.go | 11 +- internal/config/image/capabilities.go | 96 ++++++++++++- internal/config/image/capabilities_test.go | 134 ++++++++++++++++++ 10 files changed, 303 insertions(+), 268 deletions(-) delete mode 100644 cmd/nvidia-container-runtime-hook/capabilities_test.go create mode 100644 internal/config/image/capabilities_test.go diff --git a/cmd/nvidia-container-runtime-hook/capabilities.go b/cmd/nvidia-container-runtime-hook/capabilities.go index 062af2e7..5c53c8da 100644 --- a/cmd/nvidia-container-runtime-hook/capabilities.go +++ b/cmd/nvidia-container-runtime-hook/capabilities.go @@ -2,15 +2,6 @@ package main import ( "log" - "strings" -) - -const ( - allDriverCapabilities = DriverCapabilities("compute,compat32,graphics,utility,video,display,ngx") - defaultDriverCapabilities = DriverCapabilities("utility,compute") - - none = DriverCapabilities("") - all = DriverCapabilities("all") ) func capabilityToCLI(cap string) string { @@ -34,50 +25,3 @@ func capabilityToCLI(cap string) string { } return "" } - -// DriverCapabilities is used to process the NVIDIA_DRIVER_CAPABILITIES environment -// variable. Operations include default values, filtering, and handling meta values such as "all" -type DriverCapabilities string - -// Intersection returns intersection between two sets of capabilities. -func (d DriverCapabilities) Intersection(capabilities DriverCapabilities) DriverCapabilities { - if capabilities == all { - return d - } - if d == all { - return capabilities - } - - lookup := make(map[string]bool) - for _, c := range d.list() { - lookup[c] = true - } - var found []string - for _, c := range capabilities.list() { - if lookup[c] { - found = append(found, c) - } - } - - intersection := DriverCapabilities(strings.Join(found, ",")) - return intersection -} - -// String returns the string representation of the driver capabilities -func (d DriverCapabilities) String() string { - return string(d) -} - -// list returns the driver capabilities as a list -func (d DriverCapabilities) list() []string { - var caps []string - for _, c := range strings.Split(string(d), ",") { - trimmed := strings.TrimSpace(c) - if len(trimmed) == 0 { - continue - } - caps = append(caps, trimmed) - } - - return caps -} diff --git a/cmd/nvidia-container-runtime-hook/capabilities_test.go b/cmd/nvidia-container-runtime-hook/capabilities_test.go deleted file mode 100644 index 0386e6a7..00000000 --- a/cmd/nvidia-container-runtime-hook/capabilities_test.go +++ /dev/null @@ -1,134 +0,0 @@ -/** -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -**/ - -package main - -import ( - "fmt" - "testing" - - "github.com/stretchr/testify/require" -) - -func TestDriverCapabilitiesIntersection(t *testing.T) { - testCases := []struct { - capabilities DriverCapabilities - supportedCapabilities DriverCapabilities - expectedIntersection DriverCapabilities - }{ - { - capabilities: none, - supportedCapabilities: none, - expectedIntersection: none, - }, - { - capabilities: all, - supportedCapabilities: none, - expectedIntersection: none, - }, - { - capabilities: all, - supportedCapabilities: allDriverCapabilities, - expectedIntersection: allDriverCapabilities, - }, - { - capabilities: allDriverCapabilities, - supportedCapabilities: all, - expectedIntersection: allDriverCapabilities, - }, - { - capabilities: none, - supportedCapabilities: all, - expectedIntersection: none, - }, - { - capabilities: none, - supportedCapabilities: DriverCapabilities("cap1"), - expectedIntersection: none, - }, - { - capabilities: DriverCapabilities("cap0,cap1"), - supportedCapabilities: DriverCapabilities("cap1,cap0"), - expectedIntersection: DriverCapabilities("cap0,cap1"), - }, - { - capabilities: defaultDriverCapabilities, - supportedCapabilities: allDriverCapabilities, - expectedIntersection: defaultDriverCapabilities, - }, - { - capabilities: DriverCapabilities("compute,compat32,graphics,utility,video,display"), - supportedCapabilities: DriverCapabilities("compute,compat32,graphics,utility,video,display,ngx"), - expectedIntersection: DriverCapabilities("compute,compat32,graphics,utility,video,display"), - }, - { - capabilities: DriverCapabilities("cap1"), - supportedCapabilities: none, - expectedIntersection: none, - }, - { - capabilities: DriverCapabilities("compute,compat32,graphics,utility,video,display,ngx"), - supportedCapabilities: DriverCapabilities("compute,compat32,graphics,utility,video,display"), - expectedIntersection: DriverCapabilities("compute,compat32,graphics,utility,video,display"), - }, - } - - for i, tc := range testCases { - t.Run(fmt.Sprintf("test case %d", i), func(t *testing.T) { - intersection := tc.supportedCapabilities.Intersection(tc.capabilities) - require.EqualValues(t, tc.expectedIntersection, intersection) - }) - } -} - -func TestDriverCapabilitiesList(t *testing.T) { - testCases := []struct { - capabilities DriverCapabilities - expected []string - }{ - { - capabilities: DriverCapabilities(""), - }, - { - capabilities: DriverCapabilities(" "), - }, - { - capabilities: DriverCapabilities(","), - }, - { - capabilities: DriverCapabilities(",cap"), - expected: []string{"cap"}, - }, - { - capabilities: DriverCapabilities("cap,"), - expected: []string{"cap"}, - }, - { - capabilities: DriverCapabilities("cap0,,cap1"), - expected: []string{"cap0", "cap1"}, - }, - { - capabilities: DriverCapabilities("cap1,cap0,cap3"), - expected: []string{"cap1", "cap0", "cap3"}, - }, - } - - for i, tc := range testCases { - t.Run(fmt.Sprintf("test case %d", i), func(t *testing.T) { - require.EqualValues(t, tc.expected, tc.capabilities.list()) - }) - } -} diff --git a/cmd/nvidia-container-runtime-hook/container_config.go b/cmd/nvidia-container-runtime-hook/container_config.go index 2bf371aa..42732bfd 100644 --- a/cmd/nvidia-container-runtime-hook/container_config.go +++ b/cmd/nvidia-container-runtime-hook/container_config.go @@ -271,10 +271,12 @@ func getMigMonitorDevices(env map[string]string) *string { return nil } -func getDriverCapabilities(env map[string]string, supportedDriverCapabilities DriverCapabilities, legacyImage bool) DriverCapabilities { +func (c *HookConfig) getDriverCapabilities(env map[string]string, legacyImage bool) image.DriverCapabilities { // We use the default driver capabilities by default. This is filtered to only include the // supported capabilities - capabilities := supportedDriverCapabilities.Intersection(defaultDriverCapabilities) + supportedDriverCapabilities := image.NewDriverCapabilities(c.SupportedDriverCapabilities) + + capabilities := supportedDriverCapabilities.Intersection(image.DefaultDriverCapabilities) capsEnv, capsEnvSpecified := env[envNVDriverCapabilities] @@ -285,9 +287,9 @@ func getDriverCapabilities(env map[string]string, supportedDriverCapabilities Dr if capsEnvSpecified && len(capsEnv) > 0 { // If the envvironment variable is specified and is non-empty, use the capabilities value - envCapabilities := DriverCapabilities(capsEnv) + envCapabilities := image.NewDriverCapabilities(capsEnv) capabilities = supportedDriverCapabilities.Intersection(envCapabilities) - if envCapabilities != all && capabilities != envCapabilities { + if !envCapabilities.IsAll() && len(capabilities) != len(envCapabilities) { log.Panicln(fmt.Errorf("unsupported capabilities found in '%v' (allowed '%v')", envCapabilities, capabilities)) } } @@ -322,7 +324,7 @@ func getNvidiaConfig(hookConfig *HookConfig, image image.CUDA, mounts []Mount, p log.Panicln("cannot set MIG_MONITOR_DEVICES in non privileged container") } - driverCapabilities := getDriverCapabilities(image, hookConfig.SupportedDriverCapabilities, legacyImage).String() + driverCapabilities := hookConfig.getDriverCapabilities(image, legacyImage).String() requirements, err := image.GetRequirements() if err != nil { diff --git a/cmd/nvidia-container-runtime-hook/container_config_test.go b/cmd/nvidia-container-runtime-hook/container_config_test.go index f4b6c217..c918bf06 100644 --- a/cmd/nvidia-container-runtime-hook/container_config_test.go +++ b/cmd/nvidia-container-runtime-hook/container_config_test.go @@ -5,7 +5,6 @@ import ( "path/filepath" "testing" - "github.com/NVIDIA/nvidia-container-toolkit/internal/config" "github.com/NVIDIA/nvidia-container-toolkit/internal/config/image" "github.com/stretchr/testify/require" ) @@ -39,7 +38,7 @@ func TestGetNvidiaConfig(t *testing.T) { privileged: false, expectedConfig: &nvidiaConfig{ Devices: "all", - DriverCapabilities: allDriverCapabilities.String(), + DriverCapabilities: image.SupportedDriverCapabilities.String(), Requirements: []string{"cuda>=9.0"}, }, }, @@ -52,7 +51,7 @@ func TestGetNvidiaConfig(t *testing.T) { privileged: false, expectedConfig: &nvidiaConfig{ Devices: "all", - DriverCapabilities: allDriverCapabilities.String(), + DriverCapabilities: image.SupportedDriverCapabilities.String(), Requirements: []string{"cuda>=9.0"}, }, }, @@ -83,7 +82,7 @@ func TestGetNvidiaConfig(t *testing.T) { privileged: false, expectedConfig: &nvidiaConfig{ Devices: "", - DriverCapabilities: allDriverCapabilities.String(), + DriverCapabilities: image.SupportedDriverCapabilities.String(), Requirements: []string{"cuda>=9.0"}, }, }, @@ -96,7 +95,7 @@ func TestGetNvidiaConfig(t *testing.T) { privileged: false, expectedConfig: &nvidiaConfig{ Devices: "gpu0,gpu1", - DriverCapabilities: allDriverCapabilities.String(), + DriverCapabilities: image.SupportedDriverCapabilities.String(), Requirements: []string{"cuda>=9.0"}, }, }, @@ -110,7 +109,7 @@ func TestGetNvidiaConfig(t *testing.T) { privileged: false, expectedConfig: &nvidiaConfig{ Devices: "gpu0,gpu1", - DriverCapabilities: defaultDriverCapabilities.String(), + DriverCapabilities: image.DefaultDriverCapabilities.String(), Requirements: []string{"cuda>=9.0"}, }, }, @@ -124,7 +123,7 @@ func TestGetNvidiaConfig(t *testing.T) { privileged: false, expectedConfig: &nvidiaConfig{ Devices: "gpu0,gpu1", - DriverCapabilities: allDriverCapabilities.String(), + DriverCapabilities: image.SupportedDriverCapabilities.String(), Requirements: []string{"cuda>=9.0"}, }, }, @@ -138,7 +137,7 @@ func TestGetNvidiaConfig(t *testing.T) { privileged: false, expectedConfig: &nvidiaConfig{ Devices: "gpu0,gpu1", - DriverCapabilities: "video,display", + DriverCapabilities: "display,video", Requirements: []string{"cuda>=9.0"}, }, }, @@ -154,7 +153,7 @@ func TestGetNvidiaConfig(t *testing.T) { privileged: false, expectedConfig: &nvidiaConfig{ Devices: "gpu0,gpu1", - DriverCapabilities: "video,display", + DriverCapabilities: "display,video", Requirements: []string{"cuda>=9.0", "req0=true", "req1=false"}, }, }, @@ -171,7 +170,7 @@ func TestGetNvidiaConfig(t *testing.T) { privileged: false, expectedConfig: &nvidiaConfig{ Devices: "gpu0,gpu1", - DriverCapabilities: "video,display", + DriverCapabilities: "display,video", Requirements: []string{}, }, }, @@ -201,7 +200,7 @@ func TestGetNvidiaConfig(t *testing.T) { privileged: false, expectedConfig: &nvidiaConfig{ Devices: "all", - DriverCapabilities: defaultDriverCapabilities.String(), + DriverCapabilities: image.DefaultDriverCapabilities.String(), Requirements: []string{"cuda>=9.0"}, }, }, @@ -232,7 +231,7 @@ func TestGetNvidiaConfig(t *testing.T) { privileged: false, expectedConfig: &nvidiaConfig{ Devices: "", - DriverCapabilities: defaultDriverCapabilities.String(), + DriverCapabilities: image.DefaultDriverCapabilities.String(), Requirements: []string{"cuda>=9.0"}, }, }, @@ -245,7 +244,7 @@ func TestGetNvidiaConfig(t *testing.T) { privileged: false, expectedConfig: &nvidiaConfig{ Devices: "gpu0,gpu1", - DriverCapabilities: defaultDriverCapabilities.String(), + DriverCapabilities: image.DefaultDriverCapabilities.String(), Requirements: []string{"cuda>=9.0"}, }, }, @@ -259,7 +258,7 @@ func TestGetNvidiaConfig(t *testing.T) { privileged: false, expectedConfig: &nvidiaConfig{ Devices: "gpu0,gpu1", - DriverCapabilities: defaultDriverCapabilities.String(), + DriverCapabilities: image.DefaultDriverCapabilities.String(), Requirements: []string{"cuda>=9.0"}, }, }, @@ -273,7 +272,7 @@ func TestGetNvidiaConfig(t *testing.T) { privileged: false, expectedConfig: &nvidiaConfig{ Devices: "gpu0,gpu1", - DriverCapabilities: allDriverCapabilities.String(), + DriverCapabilities: image.SupportedDriverCapabilities.String(), Requirements: []string{"cuda>=9.0"}, }, }, @@ -287,7 +286,7 @@ func TestGetNvidiaConfig(t *testing.T) { privileged: false, expectedConfig: &nvidiaConfig{ Devices: "gpu0,gpu1", - DriverCapabilities: "video,display", + DriverCapabilities: "display,video", Requirements: []string{"cuda>=9.0"}, }, }, @@ -303,7 +302,7 @@ func TestGetNvidiaConfig(t *testing.T) { privileged: false, expectedConfig: &nvidiaConfig{ Devices: "gpu0,gpu1", - DriverCapabilities: "video,display", + DriverCapabilities: "display,video", Requirements: []string{"cuda>=9.0", "req0=true", "req1=false"}, }, }, @@ -320,7 +319,7 @@ func TestGetNvidiaConfig(t *testing.T) { privileged: false, expectedConfig: &nvidiaConfig{ Devices: "gpu0,gpu1", - DriverCapabilities: "video,display", + DriverCapabilities: "display,video", Requirements: []string{}, }, }, @@ -333,7 +332,7 @@ func TestGetNvidiaConfig(t *testing.T) { expectedConfig: &nvidiaConfig{ Devices: "all", - DriverCapabilities: defaultDriverCapabilities.String(), + DriverCapabilities: image.DefaultDriverCapabilities.String(), Requirements: []string{}, }, }, @@ -348,7 +347,7 @@ func TestGetNvidiaConfig(t *testing.T) { expectedConfig: &nvidiaConfig{ Devices: "all", MigConfigDevices: "mig0,mig1", - DriverCapabilities: defaultDriverCapabilities.String(), + DriverCapabilities: image.DefaultDriverCapabilities.String(), Requirements: []string{"cuda>=9.0"}, }, }, @@ -373,7 +372,7 @@ func TestGetNvidiaConfig(t *testing.T) { expectedConfig: &nvidiaConfig{ Devices: "all", MigMonitorDevices: "mig0,mig1", - DriverCapabilities: defaultDriverCapabilities.String(), + DriverCapabilities: image.DefaultDriverCapabilities.String(), Requirements: []string{"cuda>=9.0"}, }, }, @@ -399,7 +398,7 @@ func TestGetNvidiaConfig(t *testing.T) { }, expectedConfig: &nvidiaConfig{ Devices: "all", - DriverCapabilities: "video,display", + DriverCapabilities: "display,video", }, }, { @@ -414,7 +413,7 @@ func TestGetNvidiaConfig(t *testing.T) { }, expectedConfig: &nvidiaConfig{ Devices: "all", - DriverCapabilities: "video,display", + DriverCapabilities: "display,video", }, }, { @@ -428,7 +427,7 @@ func TestGetNvidiaConfig(t *testing.T) { }, expectedConfig: &nvidiaConfig{ Devices: "all", - DriverCapabilities: defaultDriverCapabilities.String(), + DriverCapabilities: image.DefaultDriverCapabilities.String(), }, }, { @@ -439,14 +438,12 @@ func TestGetNvidiaConfig(t *testing.T) { }, privileged: true, hookConfig: &HookConfig{ - Config: config.Config{ - SwarmResource: "DOCKER_SWARM_RESOURCE", - }, + SwarmResource: "DOCKER_SWARM_RESOURCE", SupportedDriverCapabilities: "video,display,utility,compute", }, expectedConfig: &nvidiaConfig{ Devices: "GPU1,GPU2", - DriverCapabilities: defaultDriverCapabilities.String(), + DriverCapabilities: image.DefaultDriverCapabilities.String(), }, }, { @@ -457,14 +454,12 @@ func TestGetNvidiaConfig(t *testing.T) { }, privileged: true, hookConfig: &HookConfig{ - Config: config.Config{ - SwarmResource: "NOT_DOCKER_SWARM_RESOURCE,DOCKER_SWARM_RESOURCE", - }, + SwarmResource: "NOT_DOCKER_SWARM_RESOURCE,DOCKER_SWARM_RESOURCE", SupportedDriverCapabilities: "video,display,utility,compute", }, expectedConfig: &nvidiaConfig{ Devices: "GPU1,GPU2", - DriverCapabilities: defaultDriverCapabilities.String(), + DriverCapabilities: image.DefaultDriverCapabilities.String(), }, }, } @@ -924,7 +919,7 @@ func TestGetDevicesFromEnvvar(t *testing.T) { func TestGetDriverCapabilities(t *testing.T) { - supportedCapabilities := "compute,utility,display,video" + supportedCapabilities := "compute,display,utility,video" testCases := []struct { description string @@ -959,7 +954,7 @@ func TestGetDriverCapabilities(t *testing.T) { }, legacyImage: true, supportedCapabilities: supportedCapabilities, - expectedCapabilities: defaultDriverCapabilities.String(), + expectedCapabilities: image.DefaultDriverCapabilities.String(), }, { description: "Env unset for legacy image is 'all'", @@ -982,7 +977,7 @@ func TestGetDriverCapabilities(t *testing.T) { env: map[string]string{}, legacyImage: false, supportedCapabilities: supportedCapabilities, - expectedCapabilities: defaultDriverCapabilities.String(), + expectedCapabilities: image.DefaultDriverCapabilities.String(), }, { description: "Env is all for modern image", @@ -1000,7 +995,7 @@ func TestGetDriverCapabilities(t *testing.T) { }, legacyImage: false, supportedCapabilities: supportedCapabilities, - expectedCapabilities: defaultDriverCapabilities.String(), + expectedCapabilities: image.DefaultDriverCapabilities.String(), }, { description: "Invalid capabilities panic", @@ -1020,11 +1015,14 @@ func TestGetDriverCapabilities(t *testing.T) { for _, tc := range testCases { t.Run(tc.description, func(t *testing.T) { - var capabilites DriverCapabilities + var capabilites string + + c := HookConfig{ + SupportedDriverCapabilities: tc.supportedCapabilities, + } getDriverCapabilities := func() { - supportedCapabilities := DriverCapabilities(tc.supportedCapabilities) - capabilites = getDriverCapabilities(tc.env, supportedCapabilities, tc.legacyImage) + capabilites = c.getDriverCapabilities(tc.env, tc.legacyImage).String() } if tc.expectedPanic { diff --git a/cmd/nvidia-container-runtime-hook/hook_config.go b/cmd/nvidia-container-runtime-hook/hook_config.go index c3e06d0b..bd3535ec 100644 --- a/cmd/nvidia-container-runtime-hook/hook_config.go +++ b/cmd/nvidia-container-runtime-hook/hook_config.go @@ -10,6 +10,7 @@ import ( "github.com/BurntSushi/toml" "github.com/NVIDIA/nvidia-container-toolkit/internal/config" + "github.com/NVIDIA/nvidia-container-toolkit/internal/config/image" ) const ( @@ -23,11 +24,7 @@ var defaultPaths = [...]string{ } // HookConfig : options for the nvidia-container-runtime-hook. -type HookConfig struct { - config.Config - // TODO: We should also migrate the driver capabilities - SupportedDriverCapabilities DriverCapabilities `toml:"supported-driver-capabilities"` -} +type HookConfig config.Config func getDefaultHookConfig() (HookConfig, error) { defaultCfg, err := config.GetDefault() @@ -35,12 +32,7 @@ func getDefaultHookConfig() (HookConfig, error) { return HookConfig{}, err } - c := HookConfig{ - Config: *defaultCfg, - SupportedDriverCapabilities: allDriverCapabilities, - } - - return c, nil + return *(*HookConfig)(defaultCfg), nil } func getHookConfig() (*HookConfig, error) { @@ -71,13 +63,15 @@ func getHookConfig() (*HookConfig, error) { } } - if config.SupportedDriverCapabilities == all { - config.SupportedDriverCapabilities = allDriverCapabilities + allSupportedDriverCapabilities := image.SupportedDriverCapabilities + if config.SupportedDriverCapabilities == "all" { + config.SupportedDriverCapabilities = allSupportedDriverCapabilities.String() } - // We ensure that the supported-driver-capabilites option is a subset of allDriverCapabilities - if intersection := allDriverCapabilities.Intersection(config.SupportedDriverCapabilities); intersection != config.SupportedDriverCapabilities { + configuredCapabilities := image.NewDriverCapabilities(config.SupportedDriverCapabilities) + // We ensure that the configured value is a subset of all supported capabilities + if !allSupportedDriverCapabilities.IsSuperset(configuredCapabilities) { configName := config.getConfigOption("SupportedDriverCapabilities") - log.Panicf("Invalid value for config option '%v'; %v (supported: %v)\n", configName, config.SupportedDriverCapabilities, allDriverCapabilities) + log.Panicf("Invalid value for config option '%v'; %v (supported: %v)\n", configName, config.SupportedDriverCapabilities, allSupportedDriverCapabilities.String()) } return &config, nil diff --git a/cmd/nvidia-container-runtime-hook/hook_config_test.go b/cmd/nvidia-container-runtime-hook/hook_config_test.go index dddd1cdc..4e71a4ab 100644 --- a/cmd/nvidia-container-runtime-hook/hook_config_test.go +++ b/cmd/nvidia-container-runtime-hook/hook_config_test.go @@ -21,7 +21,7 @@ import ( "os" "testing" - "github.com/NVIDIA/nvidia-container-toolkit/internal/config" + "github.com/NVIDIA/nvidia-container-toolkit/internal/config/image" "github.com/stretchr/testify/require" ) @@ -29,16 +29,16 @@ func TestGetHookConfig(t *testing.T) { testCases := []struct { lines []string expectedPanic bool - expectedDriverCapabilities DriverCapabilities + expectedDriverCapabilities string }{ { - expectedDriverCapabilities: allDriverCapabilities, + expectedDriverCapabilities: image.SupportedDriverCapabilities.String(), }, { lines: []string{ "supported-driver-capabilities = \"all\"", }, - expectedDriverCapabilities: allDriverCapabilities, + expectedDriverCapabilities: image.SupportedDriverCapabilities.String(), }, { lines: []string{ @@ -48,19 +48,19 @@ func TestGetHookConfig(t *testing.T) { }, { lines: []string{}, - expectedDriverCapabilities: allDriverCapabilities, + expectedDriverCapabilities: image.SupportedDriverCapabilities.String(), }, { lines: []string{ "supported-driver-capabilities = \"\"", }, - expectedDriverCapabilities: none, + expectedDriverCapabilities: "", }, { lines: []string{ - "supported-driver-capabilities = \"utility,compute\"", + "supported-driver-capabilities = \"compute,utility\"", }, - expectedDriverCapabilities: DriverCapabilities("utility,compute"), + expectedDriverCapabilities: "compute,utility", }, } @@ -144,9 +144,7 @@ func TestGetSwarmResourceEnvvars(t *testing.T) { for i, tc := range testCases { t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { c := &HookConfig{ - Config: config.Config{ - SwarmResource: tc.value, - }, + SwarmResource: tc.value, } envvars := c.getSwarmResourceEnvvars() diff --git a/internal/config/config.go b/internal/config/config.go index 660c505c..38a65e22 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -25,6 +25,7 @@ import ( "path/filepath" "strings" + "github.com/NVIDIA/nvidia-container-toolkit/internal/config/image" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup" "github.com/container-orchestrated-devices/container-device-interface/pkg/cdi" @@ -61,7 +62,7 @@ type Config struct { SwarmResource string `toml:"swarm-resource"` AcceptEnvvarUnprivileged bool `toml:"accept-nvidia-visible-devices-envvar-when-unprivileged"` AcceptDeviceListAsVolumeMounts bool `toml:"accept-nvidia-visible-devices-as-volume-mounts"` - // SupportedDriverCapabilities DriverCapabilities `toml:"supported-driver-capabilities"` + SupportedDriverCapabilities string `toml:"supported-driver-capabilities"` NVIDIAContainerCLIConfig ContainerCLIConfig `toml:"nvidia-container-cli"` NVIDIACTKConfig CTKConfig `toml:"nvidia-ctk"` @@ -135,7 +136,8 @@ func getFromTree(toml *toml.Tree) (*Config, error) { // GetDefault defines the default values for the config func GetDefault() (*Config, error) { d := Config{ - AcceptEnvvarUnprivileged: true, + AcceptEnvvarUnprivileged: true, + SupportedDriverCapabilities: image.SupportedDriverCapabilities.String(), NVIDIAContainerCLIConfig: ContainerCLIConfig{ LoadKmods: true, Ldconfig: getLdConfigPath(), diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 41be9e40..9cb4a946 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -60,7 +60,8 @@ func TestGetConfig(t *testing.T) { description: "empty config is default", inspectLdconfig: true, expectedConfig: &Config{ - AcceptEnvvarUnprivileged: true, + AcceptEnvvarUnprivileged: true, + SupportedDriverCapabilities: "compat32,compute,display,graphics,ngx,utility,video", NVIDIAContainerCLIConfig: ContainerCLIConfig{ Root: "", LoadKmods: true, @@ -94,6 +95,7 @@ func TestGetConfig(t *testing.T) { description: "config options set inline", contents: []string{ "accept-nvidia-visible-devices-envvar-when-unprivileged = false", + "supported-driver-capabilities = \"compute,utility\"", "nvidia-container-cli.root = \"/bar/baz\"", "nvidia-container-cli.load-kmods = false", "nvidia-container-cli.ldconfig = \"/foo/bar/ldconfig\"", @@ -110,7 +112,8 @@ func TestGetConfig(t *testing.T) { "nvidia-ctk.path = \"/foo/bar/nvidia-ctk\"", }, expectedConfig: &Config{ - AcceptEnvvarUnprivileged: false, + AcceptEnvvarUnprivileged: false, + SupportedDriverCapabilities: "compute,utility", NVIDIAContainerCLIConfig: ContainerCLIConfig{ Root: "/bar/baz", LoadKmods: false, @@ -150,6 +153,7 @@ func TestGetConfig(t *testing.T) { description: "config options set in section", contents: []string{ "accept-nvidia-visible-devices-envvar-when-unprivileged = false", + "supported-driver-capabilities = \"compute,utility\"", "[nvidia-container-cli]", "root = \"/bar/baz\"", "load-kmods = false", @@ -172,7 +176,8 @@ func TestGetConfig(t *testing.T) { "path = \"/foo/bar/nvidia-ctk\"", }, expectedConfig: &Config{ - AcceptEnvvarUnprivileged: false, + AcceptEnvvarUnprivileged: false, + SupportedDriverCapabilities: "compute,utility", NVIDIAContainerCLIConfig: ContainerCLIConfig{ Root: "/bar/baz", LoadKmods: false, diff --git a/internal/config/image/capabilities.go b/internal/config/image/capabilities.go index 2ff1f40b..9c05acc0 100644 --- a/internal/config/image/capabilities.go +++ b/internal/config/image/capabilities.go @@ -16,12 +16,18 @@ package image +import ( + "sort" + "strings" +) + // DriverCapability represents the possible values of NVIDIA_DRIVER_CAPABILITIES type DriverCapability string // Constants for the supported driver capabilities const ( DriverCapabilityAll DriverCapability = "all" + DriverCapabilityNone DriverCapability = "none" DriverCapabilityCompat32 DriverCapability = "compat32" DriverCapabilityCompute DriverCapability = "compute" DriverCapabilityDisplay DriverCapability = "display" @@ -31,12 +37,37 @@ const ( DriverCapabilityVideo DriverCapability = "video" ) +var ( + driverCapabilitiesNone = NewDriverCapabilities() + driverCapabilitiesAll = NewDriverCapabilities("all") + + // DefaultDriverCapabilities sets the value for driver capabilities if no value is set. + DefaultDriverCapabilities = NewDriverCapabilities("utility,compute") + // SupportedDriverCapabilities defines the set of all supported driver capabilities. + SupportedDriverCapabilities = NewDriverCapabilities("compute,compat32,graphics,utility,video,display,ngx") +) + +// NewDriverCapabilities creates a set of driver capabilities from the specified capabilities +func NewDriverCapabilities(capabilities ...string) DriverCapabilities { + dc := make(DriverCapabilities) + for _, capability := range capabilities { + for _, c := range strings.Split(capability, ",") { + trimmed := strings.TrimSpace(c) + if trimmed == "" { + continue + } + dc[DriverCapability(trimmed)] = true + } + } + return dc +} + // DriverCapabilities represents the NVIDIA_DRIVER_CAPABILITIES set for the specified image. type DriverCapabilities map[DriverCapability]bool // Has check whether the specified capability is selected. func (c DriverCapabilities) Has(capability DriverCapability) bool { - if c[DriverCapabilityAll] { + if c.IsAll() { return true } return c[capability] @@ -44,11 +75,72 @@ func (c DriverCapabilities) Has(capability DriverCapability) bool { // Any checks whether any of the specified capabilites are set func (c DriverCapabilities) Any(capabilities ...DriverCapability) bool { + if c.IsAll() { + return true + } for _, cap := range capabilities { if c.Has(cap) { return true } } - return false } + +// List returns the list of driver capabilities. +// The list is sorted. +func (c DriverCapabilities) List() []string { + var capabilities []string + for capability := range c { + capabilities = append(capabilities, string(capability)) + } + sort.Strings(capabilities) + return capabilities +} + +// String returns the string repesentation of the driver capabilities. +func (c DriverCapabilities) String() string { + if c.IsAll() { + return "all" + } + return strings.Join(c.List(), ",") +} + +// IsAll indicates whether the set of capabilities is `all` +func (c DriverCapabilities) IsAll() bool { + return c[DriverCapabilityAll] +} + +// Intersection returns a new set which includes the item in BOTH d and s2. +// For example: d = {a1, a2} s2 = {a2, a3} s1.Intersection(s2) = {a2} +func (c DriverCapabilities) Intersection(s2 DriverCapabilities) DriverCapabilities { + if s2.IsAll() { + return c + } + if c.IsAll() { + return s2 + } + + intersection := make(DriverCapabilities) + for capability := range s2 { + if c[capability] { + intersection[capability] = true + } + } + + return intersection +} + +// IsSuperset returns true if and only if d is a superset of s2. +func (c DriverCapabilities) IsSuperset(s2 DriverCapabilities) bool { + if c.IsAll() { + return true + } + + for capability := range s2 { + if !c[capability] { + return false + } + } + + return true +} diff --git a/internal/config/image/capabilities_test.go b/internal/config/image/capabilities_test.go new file mode 100644 index 00000000..f178392f --- /dev/null +++ b/internal/config/image/capabilities_test.go @@ -0,0 +1,134 @@ +/** +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package image + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestDriverCapabilitiesIntersection(t *testing.T) { + testCases := []struct { + capabilities DriverCapabilities + supportedCapabilities DriverCapabilities + expectedIntersection DriverCapabilities + }{ + { + capabilities: driverCapabilitiesNone, + supportedCapabilities: driverCapabilitiesNone, + expectedIntersection: driverCapabilitiesNone, + }, + { + capabilities: driverCapabilitiesAll, + supportedCapabilities: driverCapabilitiesNone, + expectedIntersection: driverCapabilitiesNone, + }, + { + capabilities: driverCapabilitiesAll, + supportedCapabilities: SupportedDriverCapabilities, + expectedIntersection: SupportedDriverCapabilities, + }, + { + capabilities: SupportedDriverCapabilities, + supportedCapabilities: driverCapabilitiesAll, + expectedIntersection: SupportedDriverCapabilities, + }, + { + capabilities: driverCapabilitiesNone, + supportedCapabilities: driverCapabilitiesAll, + expectedIntersection: driverCapabilitiesNone, + }, + { + capabilities: driverCapabilitiesNone, + supportedCapabilities: NewDriverCapabilities("cap1"), + expectedIntersection: driverCapabilitiesNone, + }, + { + capabilities: NewDriverCapabilities("cap0,cap1"), + supportedCapabilities: NewDriverCapabilities("cap1,cap0"), + expectedIntersection: NewDriverCapabilities("cap0,cap1"), + }, + { + capabilities: DefaultDriverCapabilities, + supportedCapabilities: SupportedDriverCapabilities, + expectedIntersection: DefaultDriverCapabilities, + }, + { + capabilities: NewDriverCapabilities("compute,compat32,graphics,utility,video,display"), + supportedCapabilities: NewDriverCapabilities("compute,compat32,graphics,utility,video,display,ngx"), + expectedIntersection: NewDriverCapabilities("compute,compat32,graphics,utility,video,display"), + }, + { + capabilities: NewDriverCapabilities("cap1"), + supportedCapabilities: driverCapabilitiesNone, + expectedIntersection: driverCapabilitiesNone, + }, + { + capabilities: NewDriverCapabilities("compute,compat32,graphics,utility,video,display,ngx"), + supportedCapabilities: NewDriverCapabilities("compute,compat32,graphics,utility,video,display"), + expectedIntersection: NewDriverCapabilities("compute,compat32,graphics,utility,video,display"), + }, + } + + for i, tc := range testCases { + t.Run(fmt.Sprintf("test case %d", i), func(t *testing.T) { + intersection := tc.supportedCapabilities.Intersection(tc.capabilities) + require.EqualValues(t, tc.expectedIntersection, intersection) + }) + } +} + +func TestDriverCapabilitiesList(t *testing.T) { + testCases := []struct { + capabilities DriverCapabilities + expected []string + }{ + { + capabilities: NewDriverCapabilities(""), + }, + { + capabilities: NewDriverCapabilities(" "), + }, + { + capabilities: NewDriverCapabilities(","), + }, + { + capabilities: NewDriverCapabilities(",cap"), + expected: []string{"cap"}, + }, + { + capabilities: NewDriverCapabilities("cap,"), + expected: []string{"cap"}, + }, + { + capabilities: NewDriverCapabilities("cap0,,cap1"), + expected: []string{"cap0", "cap1"}, + }, + { + capabilities: NewDriverCapabilities("cap1,cap0,cap3"), + expected: []string{"cap0", "cap1", "cap3"}, + }, + } + + for i, tc := range testCases { + t.Run(fmt.Sprintf("test case %d", i), func(t *testing.T) { + require.EqualValues(t, tc.expected, tc.capabilities.List()) + }) + } +}