Refactor handling of DriverCapabilities

This change consolidates the handling of NVIDIA_DRIVER_CAPABILITIES in the
interal/image package.

Signed-off-by: Evan Lezar <elezar@nvidia.com>
This commit is contained in:
Evan Lezar 2023-08-10 14:15:24 +02:00
parent 4dcaa61167
commit b18ac09f77
10 changed files with 303 additions and 268 deletions

View File

@ -2,15 +2,6 @@ package main
import (
"log"
"strings"
)
const (
allDriverCapabilities = DriverCapabilities("compute,compat32,graphics,utility,video,display,ngx")
defaultDriverCapabilities = DriverCapabilities("utility,compute")
none = DriverCapabilities("")
all = DriverCapabilities("all")
)
func capabilityToCLI(cap string) string {
@ -34,50 +25,3 @@ func capabilityToCLI(cap string) string {
}
return ""
}
// DriverCapabilities is used to process the NVIDIA_DRIVER_CAPABILITIES environment
// variable. Operations include default values, filtering, and handling meta values such as "all"
type DriverCapabilities string
// Intersection returns intersection between two sets of capabilities.
func (d DriverCapabilities) Intersection(capabilities DriverCapabilities) DriverCapabilities {
if capabilities == all {
return d
}
if d == all {
return capabilities
}
lookup := make(map[string]bool)
for _, c := range d.list() {
lookup[c] = true
}
var found []string
for _, c := range capabilities.list() {
if lookup[c] {
found = append(found, c)
}
}
intersection := DriverCapabilities(strings.Join(found, ","))
return intersection
}
// String returns the string representation of the driver capabilities
func (d DriverCapabilities) String() string {
return string(d)
}
// list returns the driver capabilities as a list
func (d DriverCapabilities) list() []string {
var caps []string
for _, c := range strings.Split(string(d), ",") {
trimmed := strings.TrimSpace(c)
if len(trimmed) == 0 {
continue
}
caps = append(caps, trimmed)
}
return caps
}

View File

@ -1,134 +0,0 @@
/**
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package main
import (
"fmt"
"testing"
"github.com/stretchr/testify/require"
)
func TestDriverCapabilitiesIntersection(t *testing.T) {
testCases := []struct {
capabilities DriverCapabilities
supportedCapabilities DriverCapabilities
expectedIntersection DriverCapabilities
}{
{
capabilities: none,
supportedCapabilities: none,
expectedIntersection: none,
},
{
capabilities: all,
supportedCapabilities: none,
expectedIntersection: none,
},
{
capabilities: all,
supportedCapabilities: allDriverCapabilities,
expectedIntersection: allDriverCapabilities,
},
{
capabilities: allDriverCapabilities,
supportedCapabilities: all,
expectedIntersection: allDriverCapabilities,
},
{
capabilities: none,
supportedCapabilities: all,
expectedIntersection: none,
},
{
capabilities: none,
supportedCapabilities: DriverCapabilities("cap1"),
expectedIntersection: none,
},
{
capabilities: DriverCapabilities("cap0,cap1"),
supportedCapabilities: DriverCapabilities("cap1,cap0"),
expectedIntersection: DriverCapabilities("cap0,cap1"),
},
{
capabilities: defaultDriverCapabilities,
supportedCapabilities: allDriverCapabilities,
expectedIntersection: defaultDriverCapabilities,
},
{
capabilities: DriverCapabilities("compute,compat32,graphics,utility,video,display"),
supportedCapabilities: DriverCapabilities("compute,compat32,graphics,utility,video,display,ngx"),
expectedIntersection: DriverCapabilities("compute,compat32,graphics,utility,video,display"),
},
{
capabilities: DriverCapabilities("cap1"),
supportedCapabilities: none,
expectedIntersection: none,
},
{
capabilities: DriverCapabilities("compute,compat32,graphics,utility,video,display,ngx"),
supportedCapabilities: DriverCapabilities("compute,compat32,graphics,utility,video,display"),
expectedIntersection: DriverCapabilities("compute,compat32,graphics,utility,video,display"),
},
}
for i, tc := range testCases {
t.Run(fmt.Sprintf("test case %d", i), func(t *testing.T) {
intersection := tc.supportedCapabilities.Intersection(tc.capabilities)
require.EqualValues(t, tc.expectedIntersection, intersection)
})
}
}
func TestDriverCapabilitiesList(t *testing.T) {
testCases := []struct {
capabilities DriverCapabilities
expected []string
}{
{
capabilities: DriverCapabilities(""),
},
{
capabilities: DriverCapabilities(" "),
},
{
capabilities: DriverCapabilities(","),
},
{
capabilities: DriverCapabilities(",cap"),
expected: []string{"cap"},
},
{
capabilities: DriverCapabilities("cap,"),
expected: []string{"cap"},
},
{
capabilities: DriverCapabilities("cap0,,cap1"),
expected: []string{"cap0", "cap1"},
},
{
capabilities: DriverCapabilities("cap1,cap0,cap3"),
expected: []string{"cap1", "cap0", "cap3"},
},
}
for i, tc := range testCases {
t.Run(fmt.Sprintf("test case %d", i), func(t *testing.T) {
require.EqualValues(t, tc.expected, tc.capabilities.list())
})
}
}

View File

@ -271,10 +271,12 @@ func getMigMonitorDevices(env map[string]string) *string {
return nil
}
func getDriverCapabilities(env map[string]string, supportedDriverCapabilities DriverCapabilities, legacyImage bool) DriverCapabilities {
func (c *HookConfig) getDriverCapabilities(env map[string]string, legacyImage bool) image.DriverCapabilities {
// We use the default driver capabilities by default. This is filtered to only include the
// supported capabilities
capabilities := supportedDriverCapabilities.Intersection(defaultDriverCapabilities)
supportedDriverCapabilities := image.NewDriverCapabilities(c.SupportedDriverCapabilities)
capabilities := supportedDriverCapabilities.Intersection(image.DefaultDriverCapabilities)
capsEnv, capsEnvSpecified := env[envNVDriverCapabilities]
@ -285,9 +287,9 @@ func getDriverCapabilities(env map[string]string, supportedDriverCapabilities Dr
if capsEnvSpecified && len(capsEnv) > 0 {
// If the envvironment variable is specified and is non-empty, use the capabilities value
envCapabilities := DriverCapabilities(capsEnv)
envCapabilities := image.NewDriverCapabilities(capsEnv)
capabilities = supportedDriverCapabilities.Intersection(envCapabilities)
if envCapabilities != all && capabilities != envCapabilities {
if !envCapabilities.IsAll() && len(capabilities) != len(envCapabilities) {
log.Panicln(fmt.Errorf("unsupported capabilities found in '%v' (allowed '%v')", envCapabilities, capabilities))
}
}
@ -322,7 +324,7 @@ func getNvidiaConfig(hookConfig *HookConfig, image image.CUDA, mounts []Mount, p
log.Panicln("cannot set MIG_MONITOR_DEVICES in non privileged container")
}
driverCapabilities := getDriverCapabilities(image, hookConfig.SupportedDriverCapabilities, legacyImage).String()
driverCapabilities := hookConfig.getDriverCapabilities(image, legacyImage).String()
requirements, err := image.GetRequirements()
if err != nil {

View File

@ -5,7 +5,6 @@ import (
"path/filepath"
"testing"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
"github.com/stretchr/testify/require"
)
@ -39,7 +38,7 @@ func TestGetNvidiaConfig(t *testing.T) {
privileged: false,
expectedConfig: &nvidiaConfig{
Devices: "all",
DriverCapabilities: allDriverCapabilities.String(),
DriverCapabilities: image.SupportedDriverCapabilities.String(),
Requirements: []string{"cuda>=9.0"},
},
},
@ -52,7 +51,7 @@ func TestGetNvidiaConfig(t *testing.T) {
privileged: false,
expectedConfig: &nvidiaConfig{
Devices: "all",
DriverCapabilities: allDriverCapabilities.String(),
DriverCapabilities: image.SupportedDriverCapabilities.String(),
Requirements: []string{"cuda>=9.0"},
},
},
@ -83,7 +82,7 @@ func TestGetNvidiaConfig(t *testing.T) {
privileged: false,
expectedConfig: &nvidiaConfig{
Devices: "",
DriverCapabilities: allDriverCapabilities.String(),
DriverCapabilities: image.SupportedDriverCapabilities.String(),
Requirements: []string{"cuda>=9.0"},
},
},
@ -96,7 +95,7 @@ func TestGetNvidiaConfig(t *testing.T) {
privileged: false,
expectedConfig: &nvidiaConfig{
Devices: "gpu0,gpu1",
DriverCapabilities: allDriverCapabilities.String(),
DriverCapabilities: image.SupportedDriverCapabilities.String(),
Requirements: []string{"cuda>=9.0"},
},
},
@ -110,7 +109,7 @@ func TestGetNvidiaConfig(t *testing.T) {
privileged: false,
expectedConfig: &nvidiaConfig{
Devices: "gpu0,gpu1",
DriverCapabilities: defaultDriverCapabilities.String(),
DriverCapabilities: image.DefaultDriverCapabilities.String(),
Requirements: []string{"cuda>=9.0"},
},
},
@ -124,7 +123,7 @@ func TestGetNvidiaConfig(t *testing.T) {
privileged: false,
expectedConfig: &nvidiaConfig{
Devices: "gpu0,gpu1",
DriverCapabilities: allDriverCapabilities.String(),
DriverCapabilities: image.SupportedDriverCapabilities.String(),
Requirements: []string{"cuda>=9.0"},
},
},
@ -138,7 +137,7 @@ func TestGetNvidiaConfig(t *testing.T) {
privileged: false,
expectedConfig: &nvidiaConfig{
Devices: "gpu0,gpu1",
DriverCapabilities: "video,display",
DriverCapabilities: "display,video",
Requirements: []string{"cuda>=9.0"},
},
},
@ -154,7 +153,7 @@ func TestGetNvidiaConfig(t *testing.T) {
privileged: false,
expectedConfig: &nvidiaConfig{
Devices: "gpu0,gpu1",
DriverCapabilities: "video,display",
DriverCapabilities: "display,video",
Requirements: []string{"cuda>=9.0", "req0=true", "req1=false"},
},
},
@ -171,7 +170,7 @@ func TestGetNvidiaConfig(t *testing.T) {
privileged: false,
expectedConfig: &nvidiaConfig{
Devices: "gpu0,gpu1",
DriverCapabilities: "video,display",
DriverCapabilities: "display,video",
Requirements: []string{},
},
},
@ -201,7 +200,7 @@ func TestGetNvidiaConfig(t *testing.T) {
privileged: false,
expectedConfig: &nvidiaConfig{
Devices: "all",
DriverCapabilities: defaultDriverCapabilities.String(),
DriverCapabilities: image.DefaultDriverCapabilities.String(),
Requirements: []string{"cuda>=9.0"},
},
},
@ -232,7 +231,7 @@ func TestGetNvidiaConfig(t *testing.T) {
privileged: false,
expectedConfig: &nvidiaConfig{
Devices: "",
DriverCapabilities: defaultDriverCapabilities.String(),
DriverCapabilities: image.DefaultDriverCapabilities.String(),
Requirements: []string{"cuda>=9.0"},
},
},
@ -245,7 +244,7 @@ func TestGetNvidiaConfig(t *testing.T) {
privileged: false,
expectedConfig: &nvidiaConfig{
Devices: "gpu0,gpu1",
DriverCapabilities: defaultDriverCapabilities.String(),
DriverCapabilities: image.DefaultDriverCapabilities.String(),
Requirements: []string{"cuda>=9.0"},
},
},
@ -259,7 +258,7 @@ func TestGetNvidiaConfig(t *testing.T) {
privileged: false,
expectedConfig: &nvidiaConfig{
Devices: "gpu0,gpu1",
DriverCapabilities: defaultDriverCapabilities.String(),
DriverCapabilities: image.DefaultDriverCapabilities.String(),
Requirements: []string{"cuda>=9.0"},
},
},
@ -273,7 +272,7 @@ func TestGetNvidiaConfig(t *testing.T) {
privileged: false,
expectedConfig: &nvidiaConfig{
Devices: "gpu0,gpu1",
DriverCapabilities: allDriverCapabilities.String(),
DriverCapabilities: image.SupportedDriverCapabilities.String(),
Requirements: []string{"cuda>=9.0"},
},
},
@ -287,7 +286,7 @@ func TestGetNvidiaConfig(t *testing.T) {
privileged: false,
expectedConfig: &nvidiaConfig{
Devices: "gpu0,gpu1",
DriverCapabilities: "video,display",
DriverCapabilities: "display,video",
Requirements: []string{"cuda>=9.0"},
},
},
@ -303,7 +302,7 @@ func TestGetNvidiaConfig(t *testing.T) {
privileged: false,
expectedConfig: &nvidiaConfig{
Devices: "gpu0,gpu1",
DriverCapabilities: "video,display",
DriverCapabilities: "display,video",
Requirements: []string{"cuda>=9.0", "req0=true", "req1=false"},
},
},
@ -320,7 +319,7 @@ func TestGetNvidiaConfig(t *testing.T) {
privileged: false,
expectedConfig: &nvidiaConfig{
Devices: "gpu0,gpu1",
DriverCapabilities: "video,display",
DriverCapabilities: "display,video",
Requirements: []string{},
},
},
@ -333,7 +332,7 @@ func TestGetNvidiaConfig(t *testing.T) {
expectedConfig: &nvidiaConfig{
Devices: "all",
DriverCapabilities: defaultDriverCapabilities.String(),
DriverCapabilities: image.DefaultDriverCapabilities.String(),
Requirements: []string{},
},
},
@ -348,7 +347,7 @@ func TestGetNvidiaConfig(t *testing.T) {
expectedConfig: &nvidiaConfig{
Devices: "all",
MigConfigDevices: "mig0,mig1",
DriverCapabilities: defaultDriverCapabilities.String(),
DriverCapabilities: image.DefaultDriverCapabilities.String(),
Requirements: []string{"cuda>=9.0"},
},
},
@ -373,7 +372,7 @@ func TestGetNvidiaConfig(t *testing.T) {
expectedConfig: &nvidiaConfig{
Devices: "all",
MigMonitorDevices: "mig0,mig1",
DriverCapabilities: defaultDriverCapabilities.String(),
DriverCapabilities: image.DefaultDriverCapabilities.String(),
Requirements: []string{"cuda>=9.0"},
},
},
@ -399,7 +398,7 @@ func TestGetNvidiaConfig(t *testing.T) {
},
expectedConfig: &nvidiaConfig{
Devices: "all",
DriverCapabilities: "video,display",
DriverCapabilities: "display,video",
},
},
{
@ -414,7 +413,7 @@ func TestGetNvidiaConfig(t *testing.T) {
},
expectedConfig: &nvidiaConfig{
Devices: "all",
DriverCapabilities: "video,display",
DriverCapabilities: "display,video",
},
},
{
@ -428,7 +427,7 @@ func TestGetNvidiaConfig(t *testing.T) {
},
expectedConfig: &nvidiaConfig{
Devices: "all",
DriverCapabilities: defaultDriverCapabilities.String(),
DriverCapabilities: image.DefaultDriverCapabilities.String(),
},
},
{
@ -439,14 +438,12 @@ func TestGetNvidiaConfig(t *testing.T) {
},
privileged: true,
hookConfig: &HookConfig{
Config: config.Config{
SwarmResource: "DOCKER_SWARM_RESOURCE",
},
SwarmResource: "DOCKER_SWARM_RESOURCE",
SupportedDriverCapabilities: "video,display,utility,compute",
},
expectedConfig: &nvidiaConfig{
Devices: "GPU1,GPU2",
DriverCapabilities: defaultDriverCapabilities.String(),
DriverCapabilities: image.DefaultDriverCapabilities.String(),
},
},
{
@ -457,14 +454,12 @@ func TestGetNvidiaConfig(t *testing.T) {
},
privileged: true,
hookConfig: &HookConfig{
Config: config.Config{
SwarmResource: "NOT_DOCKER_SWARM_RESOURCE,DOCKER_SWARM_RESOURCE",
},
SwarmResource: "NOT_DOCKER_SWARM_RESOURCE,DOCKER_SWARM_RESOURCE",
SupportedDriverCapabilities: "video,display,utility,compute",
},
expectedConfig: &nvidiaConfig{
Devices: "GPU1,GPU2",
DriverCapabilities: defaultDriverCapabilities.String(),
DriverCapabilities: image.DefaultDriverCapabilities.String(),
},
},
}
@ -924,7 +919,7 @@ func TestGetDevicesFromEnvvar(t *testing.T) {
func TestGetDriverCapabilities(t *testing.T) {
supportedCapabilities := "compute,utility,display,video"
supportedCapabilities := "compute,display,utility,video"
testCases := []struct {
description string
@ -959,7 +954,7 @@ func TestGetDriverCapabilities(t *testing.T) {
},
legacyImage: true,
supportedCapabilities: supportedCapabilities,
expectedCapabilities: defaultDriverCapabilities.String(),
expectedCapabilities: image.DefaultDriverCapabilities.String(),
},
{
description: "Env unset for legacy image is 'all'",
@ -982,7 +977,7 @@ func TestGetDriverCapabilities(t *testing.T) {
env: map[string]string{},
legacyImage: false,
supportedCapabilities: supportedCapabilities,
expectedCapabilities: defaultDriverCapabilities.String(),
expectedCapabilities: image.DefaultDriverCapabilities.String(),
},
{
description: "Env is all for modern image",
@ -1000,7 +995,7 @@ func TestGetDriverCapabilities(t *testing.T) {
},
legacyImage: false,
supportedCapabilities: supportedCapabilities,
expectedCapabilities: defaultDriverCapabilities.String(),
expectedCapabilities: image.DefaultDriverCapabilities.String(),
},
{
description: "Invalid capabilities panic",
@ -1020,11 +1015,14 @@ func TestGetDriverCapabilities(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) {
var capabilites DriverCapabilities
var capabilites string
c := HookConfig{
SupportedDriverCapabilities: tc.supportedCapabilities,
}
getDriverCapabilities := func() {
supportedCapabilities := DriverCapabilities(tc.supportedCapabilities)
capabilites = getDriverCapabilities(tc.env, supportedCapabilities, tc.legacyImage)
capabilites = c.getDriverCapabilities(tc.env, tc.legacyImage).String()
}
if tc.expectedPanic {

View File

@ -10,6 +10,7 @@ import (
"github.com/BurntSushi/toml"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
)
const (
@ -23,11 +24,7 @@ var defaultPaths = [...]string{
}
// HookConfig : options for the nvidia-container-runtime-hook.
type HookConfig struct {
config.Config
// TODO: We should also migrate the driver capabilities
SupportedDriverCapabilities DriverCapabilities `toml:"supported-driver-capabilities"`
}
type HookConfig config.Config
func getDefaultHookConfig() (HookConfig, error) {
defaultCfg, err := config.GetDefault()
@ -35,12 +32,7 @@ func getDefaultHookConfig() (HookConfig, error) {
return HookConfig{}, err
}
c := HookConfig{
Config: *defaultCfg,
SupportedDriverCapabilities: allDriverCapabilities,
}
return c, nil
return *(*HookConfig)(defaultCfg), nil
}
func getHookConfig() (*HookConfig, error) {
@ -71,13 +63,15 @@ func getHookConfig() (*HookConfig, error) {
}
}
if config.SupportedDriverCapabilities == all {
config.SupportedDriverCapabilities = allDriverCapabilities
allSupportedDriverCapabilities := image.SupportedDriverCapabilities
if config.SupportedDriverCapabilities == "all" {
config.SupportedDriverCapabilities = allSupportedDriverCapabilities.String()
}
// We ensure that the supported-driver-capabilites option is a subset of allDriverCapabilities
if intersection := allDriverCapabilities.Intersection(config.SupportedDriverCapabilities); intersection != config.SupportedDriverCapabilities {
configuredCapabilities := image.NewDriverCapabilities(config.SupportedDriverCapabilities)
// We ensure that the configured value is a subset of all supported capabilities
if !allSupportedDriverCapabilities.IsSuperset(configuredCapabilities) {
configName := config.getConfigOption("SupportedDriverCapabilities")
log.Panicf("Invalid value for config option '%v'; %v (supported: %v)\n", configName, config.SupportedDriverCapabilities, allDriverCapabilities)
log.Panicf("Invalid value for config option '%v'; %v (supported: %v)\n", configName, config.SupportedDriverCapabilities, allSupportedDriverCapabilities.String())
}
return &config, nil

View File

@ -21,7 +21,7 @@ import (
"os"
"testing"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
"github.com/stretchr/testify/require"
)
@ -29,16 +29,16 @@ func TestGetHookConfig(t *testing.T) {
testCases := []struct {
lines []string
expectedPanic bool
expectedDriverCapabilities DriverCapabilities
expectedDriverCapabilities string
}{
{
expectedDriverCapabilities: allDriverCapabilities,
expectedDriverCapabilities: image.SupportedDriverCapabilities.String(),
},
{
lines: []string{
"supported-driver-capabilities = \"all\"",
},
expectedDriverCapabilities: allDriverCapabilities,
expectedDriverCapabilities: image.SupportedDriverCapabilities.String(),
},
{
lines: []string{
@ -48,19 +48,19 @@ func TestGetHookConfig(t *testing.T) {
},
{
lines: []string{},
expectedDriverCapabilities: allDriverCapabilities,
expectedDriverCapabilities: image.SupportedDriverCapabilities.String(),
},
{
lines: []string{
"supported-driver-capabilities = \"\"",
},
expectedDriverCapabilities: none,
expectedDriverCapabilities: "",
},
{
lines: []string{
"supported-driver-capabilities = \"utility,compute\"",
"supported-driver-capabilities = \"compute,utility\"",
},
expectedDriverCapabilities: DriverCapabilities("utility,compute"),
expectedDriverCapabilities: "compute,utility",
},
}
@ -144,9 +144,7 @@ func TestGetSwarmResourceEnvvars(t *testing.T) {
for i, tc := range testCases {
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
c := &HookConfig{
Config: config.Config{
SwarmResource: tc.value,
},
SwarmResource: tc.value,
}
envvars := c.getSwarmResourceEnvvars()

View File

@ -25,6 +25,7 @@ import (
"path/filepath"
"strings"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup"
"github.com/container-orchestrated-devices/container-device-interface/pkg/cdi"
@ -61,7 +62,7 @@ type Config struct {
SwarmResource string `toml:"swarm-resource"`
AcceptEnvvarUnprivileged bool `toml:"accept-nvidia-visible-devices-envvar-when-unprivileged"`
AcceptDeviceListAsVolumeMounts bool `toml:"accept-nvidia-visible-devices-as-volume-mounts"`
// SupportedDriverCapabilities DriverCapabilities `toml:"supported-driver-capabilities"`
SupportedDriverCapabilities string `toml:"supported-driver-capabilities"`
NVIDIAContainerCLIConfig ContainerCLIConfig `toml:"nvidia-container-cli"`
NVIDIACTKConfig CTKConfig `toml:"nvidia-ctk"`
@ -135,7 +136,8 @@ func getFromTree(toml *toml.Tree) (*Config, error) {
// GetDefault defines the default values for the config
func GetDefault() (*Config, error) {
d := Config{
AcceptEnvvarUnprivileged: true,
AcceptEnvvarUnprivileged: true,
SupportedDriverCapabilities: image.SupportedDriverCapabilities.String(),
NVIDIAContainerCLIConfig: ContainerCLIConfig{
LoadKmods: true,
Ldconfig: getLdConfigPath(),

View File

@ -60,7 +60,8 @@ func TestGetConfig(t *testing.T) {
description: "empty config is default",
inspectLdconfig: true,
expectedConfig: &Config{
AcceptEnvvarUnprivileged: true,
AcceptEnvvarUnprivileged: true,
SupportedDriverCapabilities: "compat32,compute,display,graphics,ngx,utility,video",
NVIDIAContainerCLIConfig: ContainerCLIConfig{
Root: "",
LoadKmods: true,
@ -94,6 +95,7 @@ func TestGetConfig(t *testing.T) {
description: "config options set inline",
contents: []string{
"accept-nvidia-visible-devices-envvar-when-unprivileged = false",
"supported-driver-capabilities = \"compute,utility\"",
"nvidia-container-cli.root = \"/bar/baz\"",
"nvidia-container-cli.load-kmods = false",
"nvidia-container-cli.ldconfig = \"/foo/bar/ldconfig\"",
@ -110,7 +112,8 @@ func TestGetConfig(t *testing.T) {
"nvidia-ctk.path = \"/foo/bar/nvidia-ctk\"",
},
expectedConfig: &Config{
AcceptEnvvarUnprivileged: false,
AcceptEnvvarUnprivileged: false,
SupportedDriverCapabilities: "compute,utility",
NVIDIAContainerCLIConfig: ContainerCLIConfig{
Root: "/bar/baz",
LoadKmods: false,
@ -150,6 +153,7 @@ func TestGetConfig(t *testing.T) {
description: "config options set in section",
contents: []string{
"accept-nvidia-visible-devices-envvar-when-unprivileged = false",
"supported-driver-capabilities = \"compute,utility\"",
"[nvidia-container-cli]",
"root = \"/bar/baz\"",
"load-kmods = false",
@ -172,7 +176,8 @@ func TestGetConfig(t *testing.T) {
"path = \"/foo/bar/nvidia-ctk\"",
},
expectedConfig: &Config{
AcceptEnvvarUnprivileged: false,
AcceptEnvvarUnprivileged: false,
SupportedDriverCapabilities: "compute,utility",
NVIDIAContainerCLIConfig: ContainerCLIConfig{
Root: "/bar/baz",
LoadKmods: false,

View File

@ -16,12 +16,18 @@
package image
import (
"sort"
"strings"
)
// DriverCapability represents the possible values of NVIDIA_DRIVER_CAPABILITIES
type DriverCapability string
// Constants for the supported driver capabilities
const (
DriverCapabilityAll DriverCapability = "all"
DriverCapabilityNone DriverCapability = "none"
DriverCapabilityCompat32 DriverCapability = "compat32"
DriverCapabilityCompute DriverCapability = "compute"
DriverCapabilityDisplay DriverCapability = "display"
@ -31,12 +37,37 @@ const (
DriverCapabilityVideo DriverCapability = "video"
)
var (
driverCapabilitiesNone = NewDriverCapabilities()
driverCapabilitiesAll = NewDriverCapabilities("all")
// DefaultDriverCapabilities sets the value for driver capabilities if no value is set.
DefaultDriverCapabilities = NewDriverCapabilities("utility,compute")
// SupportedDriverCapabilities defines the set of all supported driver capabilities.
SupportedDriverCapabilities = NewDriverCapabilities("compute,compat32,graphics,utility,video,display,ngx")
)
// NewDriverCapabilities creates a set of driver capabilities from the specified capabilities
func NewDriverCapabilities(capabilities ...string) DriverCapabilities {
dc := make(DriverCapabilities)
for _, capability := range capabilities {
for _, c := range strings.Split(capability, ",") {
trimmed := strings.TrimSpace(c)
if trimmed == "" {
continue
}
dc[DriverCapability(trimmed)] = true
}
}
return dc
}
// DriverCapabilities represents the NVIDIA_DRIVER_CAPABILITIES set for the specified image.
type DriverCapabilities map[DriverCapability]bool
// Has check whether the specified capability is selected.
func (c DriverCapabilities) Has(capability DriverCapability) bool {
if c[DriverCapabilityAll] {
if c.IsAll() {
return true
}
return c[capability]
@ -44,11 +75,72 @@ func (c DriverCapabilities) Has(capability DriverCapability) bool {
// Any checks whether any of the specified capabilites are set
func (c DriverCapabilities) Any(capabilities ...DriverCapability) bool {
if c.IsAll() {
return true
}
for _, cap := range capabilities {
if c.Has(cap) {
return true
}
}
return false
}
// List returns the list of driver capabilities.
// The list is sorted.
func (c DriverCapabilities) List() []string {
var capabilities []string
for capability := range c {
capabilities = append(capabilities, string(capability))
}
sort.Strings(capabilities)
return capabilities
}
// String returns the string repesentation of the driver capabilities.
func (c DriverCapabilities) String() string {
if c.IsAll() {
return "all"
}
return strings.Join(c.List(), ",")
}
// IsAll indicates whether the set of capabilities is `all`
func (c DriverCapabilities) IsAll() bool {
return c[DriverCapabilityAll]
}
// Intersection returns a new set which includes the item in BOTH d and s2.
// For example: d = {a1, a2} s2 = {a2, a3} s1.Intersection(s2) = {a2}
func (c DriverCapabilities) Intersection(s2 DriverCapabilities) DriverCapabilities {
if s2.IsAll() {
return c
}
if c.IsAll() {
return s2
}
intersection := make(DriverCapabilities)
for capability := range s2 {
if c[capability] {
intersection[capability] = true
}
}
return intersection
}
// IsSuperset returns true if and only if d is a superset of s2.
func (c DriverCapabilities) IsSuperset(s2 DriverCapabilities) bool {
if c.IsAll() {
return true
}
for capability := range s2 {
if !c[capability] {
return false
}
}
return true
}

View File

@ -0,0 +1,134 @@
/**
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package image
import (
"fmt"
"testing"
"github.com/stretchr/testify/require"
)
func TestDriverCapabilitiesIntersection(t *testing.T) {
testCases := []struct {
capabilities DriverCapabilities
supportedCapabilities DriverCapabilities
expectedIntersection DriverCapabilities
}{
{
capabilities: driverCapabilitiesNone,
supportedCapabilities: driverCapabilitiesNone,
expectedIntersection: driverCapabilitiesNone,
},
{
capabilities: driverCapabilitiesAll,
supportedCapabilities: driverCapabilitiesNone,
expectedIntersection: driverCapabilitiesNone,
},
{
capabilities: driverCapabilitiesAll,
supportedCapabilities: SupportedDriverCapabilities,
expectedIntersection: SupportedDriverCapabilities,
},
{
capabilities: SupportedDriverCapabilities,
supportedCapabilities: driverCapabilitiesAll,
expectedIntersection: SupportedDriverCapabilities,
},
{
capabilities: driverCapabilitiesNone,
supportedCapabilities: driverCapabilitiesAll,
expectedIntersection: driverCapabilitiesNone,
},
{
capabilities: driverCapabilitiesNone,
supportedCapabilities: NewDriverCapabilities("cap1"),
expectedIntersection: driverCapabilitiesNone,
},
{
capabilities: NewDriverCapabilities("cap0,cap1"),
supportedCapabilities: NewDriverCapabilities("cap1,cap0"),
expectedIntersection: NewDriverCapabilities("cap0,cap1"),
},
{
capabilities: DefaultDriverCapabilities,
supportedCapabilities: SupportedDriverCapabilities,
expectedIntersection: DefaultDriverCapabilities,
},
{
capabilities: NewDriverCapabilities("compute,compat32,graphics,utility,video,display"),
supportedCapabilities: NewDriverCapabilities("compute,compat32,graphics,utility,video,display,ngx"),
expectedIntersection: NewDriverCapabilities("compute,compat32,graphics,utility,video,display"),
},
{
capabilities: NewDriverCapabilities("cap1"),
supportedCapabilities: driverCapabilitiesNone,
expectedIntersection: driverCapabilitiesNone,
},
{
capabilities: NewDriverCapabilities("compute,compat32,graphics,utility,video,display,ngx"),
supportedCapabilities: NewDriverCapabilities("compute,compat32,graphics,utility,video,display"),
expectedIntersection: NewDriverCapabilities("compute,compat32,graphics,utility,video,display"),
},
}
for i, tc := range testCases {
t.Run(fmt.Sprintf("test case %d", i), func(t *testing.T) {
intersection := tc.supportedCapabilities.Intersection(tc.capabilities)
require.EqualValues(t, tc.expectedIntersection, intersection)
})
}
}
func TestDriverCapabilitiesList(t *testing.T) {
testCases := []struct {
capabilities DriverCapabilities
expected []string
}{
{
capabilities: NewDriverCapabilities(""),
},
{
capabilities: NewDriverCapabilities(" "),
},
{
capabilities: NewDriverCapabilities(","),
},
{
capabilities: NewDriverCapabilities(",cap"),
expected: []string{"cap"},
},
{
capabilities: NewDriverCapabilities("cap,"),
expected: []string{"cap"},
},
{
capabilities: NewDriverCapabilities("cap0,,cap1"),
expected: []string{"cap0", "cap1"},
},
{
capabilities: NewDriverCapabilities("cap1,cap0,cap3"),
expected: []string{"cap0", "cap1", "cap3"},
},
}
for i, tc := range testCases {
t.Run(fmt.Sprintf("test case %d", i), func(t *testing.T) {
require.EqualValues(t, tc.expected, tc.capabilities.List())
})
}
}