Refactor handling of DriverCapabilities

This change consolidates the handling of NVIDIA_DRIVER_CAPABILITIES in the
interal/image package.

Signed-off-by: Evan Lezar <elezar@nvidia.com>
This commit is contained in:
Evan Lezar
2023-08-10 14:15:24 +02:00
parent 4dcaa61167
commit b18ac09f77
10 changed files with 303 additions and 268 deletions

View File

@@ -2,15 +2,6 @@ package main
import (
"log"
"strings"
)
const (
allDriverCapabilities = DriverCapabilities("compute,compat32,graphics,utility,video,display,ngx")
defaultDriverCapabilities = DriverCapabilities("utility,compute")
none = DriverCapabilities("")
all = DriverCapabilities("all")
)
func capabilityToCLI(cap string) string {
@@ -34,50 +25,3 @@ func capabilityToCLI(cap string) string {
}
return ""
}
// DriverCapabilities is used to process the NVIDIA_DRIVER_CAPABILITIES environment
// variable. Operations include default values, filtering, and handling meta values such as "all"
type DriverCapabilities string
// Intersection returns intersection between two sets of capabilities.
func (d DriverCapabilities) Intersection(capabilities DriverCapabilities) DriverCapabilities {
if capabilities == all {
return d
}
if d == all {
return capabilities
}
lookup := make(map[string]bool)
for _, c := range d.list() {
lookup[c] = true
}
var found []string
for _, c := range capabilities.list() {
if lookup[c] {
found = append(found, c)
}
}
intersection := DriverCapabilities(strings.Join(found, ","))
return intersection
}
// String returns the string representation of the driver capabilities
func (d DriverCapabilities) String() string {
return string(d)
}
// list returns the driver capabilities as a list
func (d DriverCapabilities) list() []string {
var caps []string
for _, c := range strings.Split(string(d), ",") {
trimmed := strings.TrimSpace(c)
if len(trimmed) == 0 {
continue
}
caps = append(caps, trimmed)
}
return caps
}

View File

@@ -1,134 +0,0 @@
/**
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package main
import (
"fmt"
"testing"
"github.com/stretchr/testify/require"
)
func TestDriverCapabilitiesIntersection(t *testing.T) {
testCases := []struct {
capabilities DriverCapabilities
supportedCapabilities DriverCapabilities
expectedIntersection DriverCapabilities
}{
{
capabilities: none,
supportedCapabilities: none,
expectedIntersection: none,
},
{
capabilities: all,
supportedCapabilities: none,
expectedIntersection: none,
},
{
capabilities: all,
supportedCapabilities: allDriverCapabilities,
expectedIntersection: allDriverCapabilities,
},
{
capabilities: allDriverCapabilities,
supportedCapabilities: all,
expectedIntersection: allDriverCapabilities,
},
{
capabilities: none,
supportedCapabilities: all,
expectedIntersection: none,
},
{
capabilities: none,
supportedCapabilities: DriverCapabilities("cap1"),
expectedIntersection: none,
},
{
capabilities: DriverCapabilities("cap0,cap1"),
supportedCapabilities: DriverCapabilities("cap1,cap0"),
expectedIntersection: DriverCapabilities("cap0,cap1"),
},
{
capabilities: defaultDriverCapabilities,
supportedCapabilities: allDriverCapabilities,
expectedIntersection: defaultDriverCapabilities,
},
{
capabilities: DriverCapabilities("compute,compat32,graphics,utility,video,display"),
supportedCapabilities: DriverCapabilities("compute,compat32,graphics,utility,video,display,ngx"),
expectedIntersection: DriverCapabilities("compute,compat32,graphics,utility,video,display"),
},
{
capabilities: DriverCapabilities("cap1"),
supportedCapabilities: none,
expectedIntersection: none,
},
{
capabilities: DriverCapabilities("compute,compat32,graphics,utility,video,display,ngx"),
supportedCapabilities: DriverCapabilities("compute,compat32,graphics,utility,video,display"),
expectedIntersection: DriverCapabilities("compute,compat32,graphics,utility,video,display"),
},
}
for i, tc := range testCases {
t.Run(fmt.Sprintf("test case %d", i), func(t *testing.T) {
intersection := tc.supportedCapabilities.Intersection(tc.capabilities)
require.EqualValues(t, tc.expectedIntersection, intersection)
})
}
}
func TestDriverCapabilitiesList(t *testing.T) {
testCases := []struct {
capabilities DriverCapabilities
expected []string
}{
{
capabilities: DriverCapabilities(""),
},
{
capabilities: DriverCapabilities(" "),
},
{
capabilities: DriverCapabilities(","),
},
{
capabilities: DriverCapabilities(",cap"),
expected: []string{"cap"},
},
{
capabilities: DriverCapabilities("cap,"),
expected: []string{"cap"},
},
{
capabilities: DriverCapabilities("cap0,,cap1"),
expected: []string{"cap0", "cap1"},
},
{
capabilities: DriverCapabilities("cap1,cap0,cap3"),
expected: []string{"cap1", "cap0", "cap3"},
},
}
for i, tc := range testCases {
t.Run(fmt.Sprintf("test case %d", i), func(t *testing.T) {
require.EqualValues(t, tc.expected, tc.capabilities.list())
})
}
}

View File

@@ -271,10 +271,12 @@ func getMigMonitorDevices(env map[string]string) *string {
return nil
}
func getDriverCapabilities(env map[string]string, supportedDriverCapabilities DriverCapabilities, legacyImage bool) DriverCapabilities {
func (c *HookConfig) getDriverCapabilities(env map[string]string, legacyImage bool) image.DriverCapabilities {
// We use the default driver capabilities by default. This is filtered to only include the
// supported capabilities
capabilities := supportedDriverCapabilities.Intersection(defaultDriverCapabilities)
supportedDriverCapabilities := image.NewDriverCapabilities(c.SupportedDriverCapabilities)
capabilities := supportedDriverCapabilities.Intersection(image.DefaultDriverCapabilities)
capsEnv, capsEnvSpecified := env[envNVDriverCapabilities]
@@ -285,9 +287,9 @@ func getDriverCapabilities(env map[string]string, supportedDriverCapabilities Dr
if capsEnvSpecified && len(capsEnv) > 0 {
// If the envvironment variable is specified and is non-empty, use the capabilities value
envCapabilities := DriverCapabilities(capsEnv)
envCapabilities := image.NewDriverCapabilities(capsEnv)
capabilities = supportedDriverCapabilities.Intersection(envCapabilities)
if envCapabilities != all && capabilities != envCapabilities {
if !envCapabilities.IsAll() && len(capabilities) != len(envCapabilities) {
log.Panicln(fmt.Errorf("unsupported capabilities found in '%v' (allowed '%v')", envCapabilities, capabilities))
}
}
@@ -322,7 +324,7 @@ func getNvidiaConfig(hookConfig *HookConfig, image image.CUDA, mounts []Mount, p
log.Panicln("cannot set MIG_MONITOR_DEVICES in non privileged container")
}
driverCapabilities := getDriverCapabilities(image, hookConfig.SupportedDriverCapabilities, legacyImage).String()
driverCapabilities := hookConfig.getDriverCapabilities(image, legacyImage).String()
requirements, err := image.GetRequirements()
if err != nil {

View File

@@ -5,7 +5,6 @@ import (
"path/filepath"
"testing"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
"github.com/stretchr/testify/require"
)
@@ -39,7 +38,7 @@ func TestGetNvidiaConfig(t *testing.T) {
privileged: false,
expectedConfig: &nvidiaConfig{
Devices: "all",
DriverCapabilities: allDriverCapabilities.String(),
DriverCapabilities: image.SupportedDriverCapabilities.String(),
Requirements: []string{"cuda>=9.0"},
},
},
@@ -52,7 +51,7 @@ func TestGetNvidiaConfig(t *testing.T) {
privileged: false,
expectedConfig: &nvidiaConfig{
Devices: "all",
DriverCapabilities: allDriverCapabilities.String(),
DriverCapabilities: image.SupportedDriverCapabilities.String(),
Requirements: []string{"cuda>=9.0"},
},
},
@@ -83,7 +82,7 @@ func TestGetNvidiaConfig(t *testing.T) {
privileged: false,
expectedConfig: &nvidiaConfig{
Devices: "",
DriverCapabilities: allDriverCapabilities.String(),
DriverCapabilities: image.SupportedDriverCapabilities.String(),
Requirements: []string{"cuda>=9.0"},
},
},
@@ -96,7 +95,7 @@ func TestGetNvidiaConfig(t *testing.T) {
privileged: false,
expectedConfig: &nvidiaConfig{
Devices: "gpu0,gpu1",
DriverCapabilities: allDriverCapabilities.String(),
DriverCapabilities: image.SupportedDriverCapabilities.String(),
Requirements: []string{"cuda>=9.0"},
},
},
@@ -110,7 +109,7 @@ func TestGetNvidiaConfig(t *testing.T) {
privileged: false,
expectedConfig: &nvidiaConfig{
Devices: "gpu0,gpu1",
DriverCapabilities: defaultDriverCapabilities.String(),
DriverCapabilities: image.DefaultDriverCapabilities.String(),
Requirements: []string{"cuda>=9.0"},
},
},
@@ -124,7 +123,7 @@ func TestGetNvidiaConfig(t *testing.T) {
privileged: false,
expectedConfig: &nvidiaConfig{
Devices: "gpu0,gpu1",
DriverCapabilities: allDriverCapabilities.String(),
DriverCapabilities: image.SupportedDriverCapabilities.String(),
Requirements: []string{"cuda>=9.0"},
},
},
@@ -138,7 +137,7 @@ func TestGetNvidiaConfig(t *testing.T) {
privileged: false,
expectedConfig: &nvidiaConfig{
Devices: "gpu0,gpu1",
DriverCapabilities: "video,display",
DriverCapabilities: "display,video",
Requirements: []string{"cuda>=9.0"},
},
},
@@ -154,7 +153,7 @@ func TestGetNvidiaConfig(t *testing.T) {
privileged: false,
expectedConfig: &nvidiaConfig{
Devices: "gpu0,gpu1",
DriverCapabilities: "video,display",
DriverCapabilities: "display,video",
Requirements: []string{"cuda>=9.0", "req0=true", "req1=false"},
},
},
@@ -171,7 +170,7 @@ func TestGetNvidiaConfig(t *testing.T) {
privileged: false,
expectedConfig: &nvidiaConfig{
Devices: "gpu0,gpu1",
DriverCapabilities: "video,display",
DriverCapabilities: "display,video",
Requirements: []string{},
},
},
@@ -201,7 +200,7 @@ func TestGetNvidiaConfig(t *testing.T) {
privileged: false,
expectedConfig: &nvidiaConfig{
Devices: "all",
DriverCapabilities: defaultDriverCapabilities.String(),
DriverCapabilities: image.DefaultDriverCapabilities.String(),
Requirements: []string{"cuda>=9.0"},
},
},
@@ -232,7 +231,7 @@ func TestGetNvidiaConfig(t *testing.T) {
privileged: false,
expectedConfig: &nvidiaConfig{
Devices: "",
DriverCapabilities: defaultDriverCapabilities.String(),
DriverCapabilities: image.DefaultDriverCapabilities.String(),
Requirements: []string{"cuda>=9.0"},
},
},
@@ -245,7 +244,7 @@ func TestGetNvidiaConfig(t *testing.T) {
privileged: false,
expectedConfig: &nvidiaConfig{
Devices: "gpu0,gpu1",
DriverCapabilities: defaultDriverCapabilities.String(),
DriverCapabilities: image.DefaultDriverCapabilities.String(),
Requirements: []string{"cuda>=9.0"},
},
},
@@ -259,7 +258,7 @@ func TestGetNvidiaConfig(t *testing.T) {
privileged: false,
expectedConfig: &nvidiaConfig{
Devices: "gpu0,gpu1",
DriverCapabilities: defaultDriverCapabilities.String(),
DriverCapabilities: image.DefaultDriverCapabilities.String(),
Requirements: []string{"cuda>=9.0"},
},
},
@@ -273,7 +272,7 @@ func TestGetNvidiaConfig(t *testing.T) {
privileged: false,
expectedConfig: &nvidiaConfig{
Devices: "gpu0,gpu1",
DriverCapabilities: allDriverCapabilities.String(),
DriverCapabilities: image.SupportedDriverCapabilities.String(),
Requirements: []string{"cuda>=9.0"},
},
},
@@ -287,7 +286,7 @@ func TestGetNvidiaConfig(t *testing.T) {
privileged: false,
expectedConfig: &nvidiaConfig{
Devices: "gpu0,gpu1",
DriverCapabilities: "video,display",
DriverCapabilities: "display,video",
Requirements: []string{"cuda>=9.0"},
},
},
@@ -303,7 +302,7 @@ func TestGetNvidiaConfig(t *testing.T) {
privileged: false,
expectedConfig: &nvidiaConfig{
Devices: "gpu0,gpu1",
DriverCapabilities: "video,display",
DriverCapabilities: "display,video",
Requirements: []string{"cuda>=9.0", "req0=true", "req1=false"},
},
},
@@ -320,7 +319,7 @@ func TestGetNvidiaConfig(t *testing.T) {
privileged: false,
expectedConfig: &nvidiaConfig{
Devices: "gpu0,gpu1",
DriverCapabilities: "video,display",
DriverCapabilities: "display,video",
Requirements: []string{},
},
},
@@ -333,7 +332,7 @@ func TestGetNvidiaConfig(t *testing.T) {
expectedConfig: &nvidiaConfig{
Devices: "all",
DriverCapabilities: defaultDriverCapabilities.String(),
DriverCapabilities: image.DefaultDriverCapabilities.String(),
Requirements: []string{},
},
},
@@ -348,7 +347,7 @@ func TestGetNvidiaConfig(t *testing.T) {
expectedConfig: &nvidiaConfig{
Devices: "all",
MigConfigDevices: "mig0,mig1",
DriverCapabilities: defaultDriverCapabilities.String(),
DriverCapabilities: image.DefaultDriverCapabilities.String(),
Requirements: []string{"cuda>=9.0"},
},
},
@@ -373,7 +372,7 @@ func TestGetNvidiaConfig(t *testing.T) {
expectedConfig: &nvidiaConfig{
Devices: "all",
MigMonitorDevices: "mig0,mig1",
DriverCapabilities: defaultDriverCapabilities.String(),
DriverCapabilities: image.DefaultDriverCapabilities.String(),
Requirements: []string{"cuda>=9.0"},
},
},
@@ -399,7 +398,7 @@ func TestGetNvidiaConfig(t *testing.T) {
},
expectedConfig: &nvidiaConfig{
Devices: "all",
DriverCapabilities: "video,display",
DriverCapabilities: "display,video",
},
},
{
@@ -414,7 +413,7 @@ func TestGetNvidiaConfig(t *testing.T) {
},
expectedConfig: &nvidiaConfig{
Devices: "all",
DriverCapabilities: "video,display",
DriverCapabilities: "display,video",
},
},
{
@@ -428,7 +427,7 @@ func TestGetNvidiaConfig(t *testing.T) {
},
expectedConfig: &nvidiaConfig{
Devices: "all",
DriverCapabilities: defaultDriverCapabilities.String(),
DriverCapabilities: image.DefaultDriverCapabilities.String(),
},
},
{
@@ -439,14 +438,12 @@ func TestGetNvidiaConfig(t *testing.T) {
},
privileged: true,
hookConfig: &HookConfig{
Config: config.Config{
SwarmResource: "DOCKER_SWARM_RESOURCE",
},
SwarmResource: "DOCKER_SWARM_RESOURCE",
SupportedDriverCapabilities: "video,display,utility,compute",
},
expectedConfig: &nvidiaConfig{
Devices: "GPU1,GPU2",
DriverCapabilities: defaultDriverCapabilities.String(),
DriverCapabilities: image.DefaultDriverCapabilities.String(),
},
},
{
@@ -457,14 +454,12 @@ func TestGetNvidiaConfig(t *testing.T) {
},
privileged: true,
hookConfig: &HookConfig{
Config: config.Config{
SwarmResource: "NOT_DOCKER_SWARM_RESOURCE,DOCKER_SWARM_RESOURCE",
},
SwarmResource: "NOT_DOCKER_SWARM_RESOURCE,DOCKER_SWARM_RESOURCE",
SupportedDriverCapabilities: "video,display,utility,compute",
},
expectedConfig: &nvidiaConfig{
Devices: "GPU1,GPU2",
DriverCapabilities: defaultDriverCapabilities.String(),
DriverCapabilities: image.DefaultDriverCapabilities.String(),
},
},
}
@@ -924,7 +919,7 @@ func TestGetDevicesFromEnvvar(t *testing.T) {
func TestGetDriverCapabilities(t *testing.T) {
supportedCapabilities := "compute,utility,display,video"
supportedCapabilities := "compute,display,utility,video"
testCases := []struct {
description string
@@ -959,7 +954,7 @@ func TestGetDriverCapabilities(t *testing.T) {
},
legacyImage: true,
supportedCapabilities: supportedCapabilities,
expectedCapabilities: defaultDriverCapabilities.String(),
expectedCapabilities: image.DefaultDriverCapabilities.String(),
},
{
description: "Env unset for legacy image is 'all'",
@@ -982,7 +977,7 @@ func TestGetDriverCapabilities(t *testing.T) {
env: map[string]string{},
legacyImage: false,
supportedCapabilities: supportedCapabilities,
expectedCapabilities: defaultDriverCapabilities.String(),
expectedCapabilities: image.DefaultDriverCapabilities.String(),
},
{
description: "Env is all for modern image",
@@ -1000,7 +995,7 @@ func TestGetDriverCapabilities(t *testing.T) {
},
legacyImage: false,
supportedCapabilities: supportedCapabilities,
expectedCapabilities: defaultDriverCapabilities.String(),
expectedCapabilities: image.DefaultDriverCapabilities.String(),
},
{
description: "Invalid capabilities panic",
@@ -1020,11 +1015,14 @@ func TestGetDriverCapabilities(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) {
var capabilites DriverCapabilities
var capabilites string
c := HookConfig{
SupportedDriverCapabilities: tc.supportedCapabilities,
}
getDriverCapabilities := func() {
supportedCapabilities := DriverCapabilities(tc.supportedCapabilities)
capabilites = getDriverCapabilities(tc.env, supportedCapabilities, tc.legacyImage)
capabilites = c.getDriverCapabilities(tc.env, tc.legacyImage).String()
}
if tc.expectedPanic {

View File

@@ -10,6 +10,7 @@ import (
"github.com/BurntSushi/toml"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
)
const (
@@ -23,11 +24,7 @@ var defaultPaths = [...]string{
}
// HookConfig : options for the nvidia-container-runtime-hook.
type HookConfig struct {
config.Config
// TODO: We should also migrate the driver capabilities
SupportedDriverCapabilities DriverCapabilities `toml:"supported-driver-capabilities"`
}
type HookConfig config.Config
func getDefaultHookConfig() (HookConfig, error) {
defaultCfg, err := config.GetDefault()
@@ -35,12 +32,7 @@ func getDefaultHookConfig() (HookConfig, error) {
return HookConfig{}, err
}
c := HookConfig{
Config: *defaultCfg,
SupportedDriverCapabilities: allDriverCapabilities,
}
return c, nil
return *(*HookConfig)(defaultCfg), nil
}
func getHookConfig() (*HookConfig, error) {
@@ -71,13 +63,15 @@ func getHookConfig() (*HookConfig, error) {
}
}
if config.SupportedDriverCapabilities == all {
config.SupportedDriverCapabilities = allDriverCapabilities
allSupportedDriverCapabilities := image.SupportedDriverCapabilities
if config.SupportedDriverCapabilities == "all" {
config.SupportedDriverCapabilities = allSupportedDriverCapabilities.String()
}
// We ensure that the supported-driver-capabilites option is a subset of allDriverCapabilities
if intersection := allDriverCapabilities.Intersection(config.SupportedDriverCapabilities); intersection != config.SupportedDriverCapabilities {
configuredCapabilities := image.NewDriverCapabilities(config.SupportedDriverCapabilities)
// We ensure that the configured value is a subset of all supported capabilities
if !allSupportedDriverCapabilities.IsSuperset(configuredCapabilities) {
configName := config.getConfigOption("SupportedDriverCapabilities")
log.Panicf("Invalid value for config option '%v'; %v (supported: %v)\n", configName, config.SupportedDriverCapabilities, allDriverCapabilities)
log.Panicf("Invalid value for config option '%v'; %v (supported: %v)\n", configName, config.SupportedDriverCapabilities, allSupportedDriverCapabilities.String())
}
return &config, nil

View File

@@ -21,7 +21,7 @@ import (
"os"
"testing"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
"github.com/stretchr/testify/require"
)
@@ -29,16 +29,16 @@ func TestGetHookConfig(t *testing.T) {
testCases := []struct {
lines []string
expectedPanic bool
expectedDriverCapabilities DriverCapabilities
expectedDriverCapabilities string
}{
{
expectedDriverCapabilities: allDriverCapabilities,
expectedDriverCapabilities: image.SupportedDriverCapabilities.String(),
},
{
lines: []string{
"supported-driver-capabilities = \"all\"",
},
expectedDriverCapabilities: allDriverCapabilities,
expectedDriverCapabilities: image.SupportedDriverCapabilities.String(),
},
{
lines: []string{
@@ -48,19 +48,19 @@ func TestGetHookConfig(t *testing.T) {
},
{
lines: []string{},
expectedDriverCapabilities: allDriverCapabilities,
expectedDriverCapabilities: image.SupportedDriverCapabilities.String(),
},
{
lines: []string{
"supported-driver-capabilities = \"\"",
},
expectedDriverCapabilities: none,
expectedDriverCapabilities: "",
},
{
lines: []string{
"supported-driver-capabilities = \"utility,compute\"",
"supported-driver-capabilities = \"compute,utility\"",
},
expectedDriverCapabilities: DriverCapabilities("utility,compute"),
expectedDriverCapabilities: "compute,utility",
},
}
@@ -144,9 +144,7 @@ func TestGetSwarmResourceEnvvars(t *testing.T) {
for i, tc := range testCases {
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
c := &HookConfig{
Config: config.Config{
SwarmResource: tc.value,
},
SwarmResource: tc.value,
}
envvars := c.getSwarmResourceEnvvars()