mirror of
https://github.com/NVIDIA/nvidia-container-toolkit
synced 2024-11-22 00:08:11 +00:00
Refactor handling of DriverCapabilities
This change consolidates the handling of NVIDIA_DRIVER_CAPABILITIES in the interal/image package. Signed-off-by: Evan Lezar <elezar@nvidia.com>
This commit is contained in:
parent
4dcaa61167
commit
b18ac09f77
@ -2,15 +2,6 @@ package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
allDriverCapabilities = DriverCapabilities("compute,compat32,graphics,utility,video,display,ngx")
|
||||
defaultDriverCapabilities = DriverCapabilities("utility,compute")
|
||||
|
||||
none = DriverCapabilities("")
|
||||
all = DriverCapabilities("all")
|
||||
)
|
||||
|
||||
func capabilityToCLI(cap string) string {
|
||||
@ -34,50 +25,3 @@ func capabilityToCLI(cap string) string {
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// DriverCapabilities is used to process the NVIDIA_DRIVER_CAPABILITIES environment
|
||||
// variable. Operations include default values, filtering, and handling meta values such as "all"
|
||||
type DriverCapabilities string
|
||||
|
||||
// Intersection returns intersection between two sets of capabilities.
|
||||
func (d DriverCapabilities) Intersection(capabilities DriverCapabilities) DriverCapabilities {
|
||||
if capabilities == all {
|
||||
return d
|
||||
}
|
||||
if d == all {
|
||||
return capabilities
|
||||
}
|
||||
|
||||
lookup := make(map[string]bool)
|
||||
for _, c := range d.list() {
|
||||
lookup[c] = true
|
||||
}
|
||||
var found []string
|
||||
for _, c := range capabilities.list() {
|
||||
if lookup[c] {
|
||||
found = append(found, c)
|
||||
}
|
||||
}
|
||||
|
||||
intersection := DriverCapabilities(strings.Join(found, ","))
|
||||
return intersection
|
||||
}
|
||||
|
||||
// String returns the string representation of the driver capabilities
|
||||
func (d DriverCapabilities) String() string {
|
||||
return string(d)
|
||||
}
|
||||
|
||||
// list returns the driver capabilities as a list
|
||||
func (d DriverCapabilities) list() []string {
|
||||
var caps []string
|
||||
for _, c := range strings.Split(string(d), ",") {
|
||||
trimmed := strings.TrimSpace(c)
|
||||
if len(trimmed) == 0 {
|
||||
continue
|
||||
}
|
||||
caps = append(caps, trimmed)
|
||||
}
|
||||
|
||||
return caps
|
||||
}
|
||||
|
@ -1,134 +0,0 @@
|
||||
/**
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
**/
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestDriverCapabilitiesIntersection(t *testing.T) {
|
||||
testCases := []struct {
|
||||
capabilities DriverCapabilities
|
||||
supportedCapabilities DriverCapabilities
|
||||
expectedIntersection DriverCapabilities
|
||||
}{
|
||||
{
|
||||
capabilities: none,
|
||||
supportedCapabilities: none,
|
||||
expectedIntersection: none,
|
||||
},
|
||||
{
|
||||
capabilities: all,
|
||||
supportedCapabilities: none,
|
||||
expectedIntersection: none,
|
||||
},
|
||||
{
|
||||
capabilities: all,
|
||||
supportedCapabilities: allDriverCapabilities,
|
||||
expectedIntersection: allDriverCapabilities,
|
||||
},
|
||||
{
|
||||
capabilities: allDriverCapabilities,
|
||||
supportedCapabilities: all,
|
||||
expectedIntersection: allDriverCapabilities,
|
||||
},
|
||||
{
|
||||
capabilities: none,
|
||||
supportedCapabilities: all,
|
||||
expectedIntersection: none,
|
||||
},
|
||||
{
|
||||
capabilities: none,
|
||||
supportedCapabilities: DriverCapabilities("cap1"),
|
||||
expectedIntersection: none,
|
||||
},
|
||||
{
|
||||
capabilities: DriverCapabilities("cap0,cap1"),
|
||||
supportedCapabilities: DriverCapabilities("cap1,cap0"),
|
||||
expectedIntersection: DriverCapabilities("cap0,cap1"),
|
||||
},
|
||||
{
|
||||
capabilities: defaultDriverCapabilities,
|
||||
supportedCapabilities: allDriverCapabilities,
|
||||
expectedIntersection: defaultDriverCapabilities,
|
||||
},
|
||||
{
|
||||
capabilities: DriverCapabilities("compute,compat32,graphics,utility,video,display"),
|
||||
supportedCapabilities: DriverCapabilities("compute,compat32,graphics,utility,video,display,ngx"),
|
||||
expectedIntersection: DriverCapabilities("compute,compat32,graphics,utility,video,display"),
|
||||
},
|
||||
{
|
||||
capabilities: DriverCapabilities("cap1"),
|
||||
supportedCapabilities: none,
|
||||
expectedIntersection: none,
|
||||
},
|
||||
{
|
||||
capabilities: DriverCapabilities("compute,compat32,graphics,utility,video,display,ngx"),
|
||||
supportedCapabilities: DriverCapabilities("compute,compat32,graphics,utility,video,display"),
|
||||
expectedIntersection: DriverCapabilities("compute,compat32,graphics,utility,video,display"),
|
||||
},
|
||||
}
|
||||
|
||||
for i, tc := range testCases {
|
||||
t.Run(fmt.Sprintf("test case %d", i), func(t *testing.T) {
|
||||
intersection := tc.supportedCapabilities.Intersection(tc.capabilities)
|
||||
require.EqualValues(t, tc.expectedIntersection, intersection)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDriverCapabilitiesList(t *testing.T) {
|
||||
testCases := []struct {
|
||||
capabilities DriverCapabilities
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
capabilities: DriverCapabilities(""),
|
||||
},
|
||||
{
|
||||
capabilities: DriverCapabilities(" "),
|
||||
},
|
||||
{
|
||||
capabilities: DriverCapabilities(","),
|
||||
},
|
||||
{
|
||||
capabilities: DriverCapabilities(",cap"),
|
||||
expected: []string{"cap"},
|
||||
},
|
||||
{
|
||||
capabilities: DriverCapabilities("cap,"),
|
||||
expected: []string{"cap"},
|
||||
},
|
||||
{
|
||||
capabilities: DriverCapabilities("cap0,,cap1"),
|
||||
expected: []string{"cap0", "cap1"},
|
||||
},
|
||||
{
|
||||
capabilities: DriverCapabilities("cap1,cap0,cap3"),
|
||||
expected: []string{"cap1", "cap0", "cap3"},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tc := range testCases {
|
||||
t.Run(fmt.Sprintf("test case %d", i), func(t *testing.T) {
|
||||
require.EqualValues(t, tc.expected, tc.capabilities.list())
|
||||
})
|
||||
}
|
||||
}
|
@ -271,10 +271,12 @@ func getMigMonitorDevices(env map[string]string) *string {
|
||||
return nil
|
||||
}
|
||||
|
||||
func getDriverCapabilities(env map[string]string, supportedDriverCapabilities DriverCapabilities, legacyImage bool) DriverCapabilities {
|
||||
func (c *HookConfig) getDriverCapabilities(env map[string]string, legacyImage bool) image.DriverCapabilities {
|
||||
// We use the default driver capabilities by default. This is filtered to only include the
|
||||
// supported capabilities
|
||||
capabilities := supportedDriverCapabilities.Intersection(defaultDriverCapabilities)
|
||||
supportedDriverCapabilities := image.NewDriverCapabilities(c.SupportedDriverCapabilities)
|
||||
|
||||
capabilities := supportedDriverCapabilities.Intersection(image.DefaultDriverCapabilities)
|
||||
|
||||
capsEnv, capsEnvSpecified := env[envNVDriverCapabilities]
|
||||
|
||||
@ -285,9 +287,9 @@ func getDriverCapabilities(env map[string]string, supportedDriverCapabilities Dr
|
||||
|
||||
if capsEnvSpecified && len(capsEnv) > 0 {
|
||||
// If the envvironment variable is specified and is non-empty, use the capabilities value
|
||||
envCapabilities := DriverCapabilities(capsEnv)
|
||||
envCapabilities := image.NewDriverCapabilities(capsEnv)
|
||||
capabilities = supportedDriverCapabilities.Intersection(envCapabilities)
|
||||
if envCapabilities != all && capabilities != envCapabilities {
|
||||
if !envCapabilities.IsAll() && len(capabilities) != len(envCapabilities) {
|
||||
log.Panicln(fmt.Errorf("unsupported capabilities found in '%v' (allowed '%v')", envCapabilities, capabilities))
|
||||
}
|
||||
}
|
||||
@ -322,7 +324,7 @@ func getNvidiaConfig(hookConfig *HookConfig, image image.CUDA, mounts []Mount, p
|
||||
log.Panicln("cannot set MIG_MONITOR_DEVICES in non privileged container")
|
||||
}
|
||||
|
||||
driverCapabilities := getDriverCapabilities(image, hookConfig.SupportedDriverCapabilities, legacyImage).String()
|
||||
driverCapabilities := hookConfig.getDriverCapabilities(image, legacyImage).String()
|
||||
|
||||
requirements, err := image.GetRequirements()
|
||||
if err != nil {
|
||||
|
@ -5,7 +5,6 @@ import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/config"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@ -39,7 +38,7 @@ func TestGetNvidiaConfig(t *testing.T) {
|
||||
privileged: false,
|
||||
expectedConfig: &nvidiaConfig{
|
||||
Devices: "all",
|
||||
DriverCapabilities: allDriverCapabilities.String(),
|
||||
DriverCapabilities: image.SupportedDriverCapabilities.String(),
|
||||
Requirements: []string{"cuda>=9.0"},
|
||||
},
|
||||
},
|
||||
@ -52,7 +51,7 @@ func TestGetNvidiaConfig(t *testing.T) {
|
||||
privileged: false,
|
||||
expectedConfig: &nvidiaConfig{
|
||||
Devices: "all",
|
||||
DriverCapabilities: allDriverCapabilities.String(),
|
||||
DriverCapabilities: image.SupportedDriverCapabilities.String(),
|
||||
Requirements: []string{"cuda>=9.0"},
|
||||
},
|
||||
},
|
||||
@ -83,7 +82,7 @@ func TestGetNvidiaConfig(t *testing.T) {
|
||||
privileged: false,
|
||||
expectedConfig: &nvidiaConfig{
|
||||
Devices: "",
|
||||
DriverCapabilities: allDriverCapabilities.String(),
|
||||
DriverCapabilities: image.SupportedDriverCapabilities.String(),
|
||||
Requirements: []string{"cuda>=9.0"},
|
||||
},
|
||||
},
|
||||
@ -96,7 +95,7 @@ func TestGetNvidiaConfig(t *testing.T) {
|
||||
privileged: false,
|
||||
expectedConfig: &nvidiaConfig{
|
||||
Devices: "gpu0,gpu1",
|
||||
DriverCapabilities: allDriverCapabilities.String(),
|
||||
DriverCapabilities: image.SupportedDriverCapabilities.String(),
|
||||
Requirements: []string{"cuda>=9.0"},
|
||||
},
|
||||
},
|
||||
@ -110,7 +109,7 @@ func TestGetNvidiaConfig(t *testing.T) {
|
||||
privileged: false,
|
||||
expectedConfig: &nvidiaConfig{
|
||||
Devices: "gpu0,gpu1",
|
||||
DriverCapabilities: defaultDriverCapabilities.String(),
|
||||
DriverCapabilities: image.DefaultDriverCapabilities.String(),
|
||||
Requirements: []string{"cuda>=9.0"},
|
||||
},
|
||||
},
|
||||
@ -124,7 +123,7 @@ func TestGetNvidiaConfig(t *testing.T) {
|
||||
privileged: false,
|
||||
expectedConfig: &nvidiaConfig{
|
||||
Devices: "gpu0,gpu1",
|
||||
DriverCapabilities: allDriverCapabilities.String(),
|
||||
DriverCapabilities: image.SupportedDriverCapabilities.String(),
|
||||
Requirements: []string{"cuda>=9.0"},
|
||||
},
|
||||
},
|
||||
@ -138,7 +137,7 @@ func TestGetNvidiaConfig(t *testing.T) {
|
||||
privileged: false,
|
||||
expectedConfig: &nvidiaConfig{
|
||||
Devices: "gpu0,gpu1",
|
||||
DriverCapabilities: "video,display",
|
||||
DriverCapabilities: "display,video",
|
||||
Requirements: []string{"cuda>=9.0"},
|
||||
},
|
||||
},
|
||||
@ -154,7 +153,7 @@ func TestGetNvidiaConfig(t *testing.T) {
|
||||
privileged: false,
|
||||
expectedConfig: &nvidiaConfig{
|
||||
Devices: "gpu0,gpu1",
|
||||
DriverCapabilities: "video,display",
|
||||
DriverCapabilities: "display,video",
|
||||
Requirements: []string{"cuda>=9.0", "req0=true", "req1=false"},
|
||||
},
|
||||
},
|
||||
@ -171,7 +170,7 @@ func TestGetNvidiaConfig(t *testing.T) {
|
||||
privileged: false,
|
||||
expectedConfig: &nvidiaConfig{
|
||||
Devices: "gpu0,gpu1",
|
||||
DriverCapabilities: "video,display",
|
||||
DriverCapabilities: "display,video",
|
||||
Requirements: []string{},
|
||||
},
|
||||
},
|
||||
@ -201,7 +200,7 @@ func TestGetNvidiaConfig(t *testing.T) {
|
||||
privileged: false,
|
||||
expectedConfig: &nvidiaConfig{
|
||||
Devices: "all",
|
||||
DriverCapabilities: defaultDriverCapabilities.String(),
|
||||
DriverCapabilities: image.DefaultDriverCapabilities.String(),
|
||||
Requirements: []string{"cuda>=9.0"},
|
||||
},
|
||||
},
|
||||
@ -232,7 +231,7 @@ func TestGetNvidiaConfig(t *testing.T) {
|
||||
privileged: false,
|
||||
expectedConfig: &nvidiaConfig{
|
||||
Devices: "",
|
||||
DriverCapabilities: defaultDriverCapabilities.String(),
|
||||
DriverCapabilities: image.DefaultDriverCapabilities.String(),
|
||||
Requirements: []string{"cuda>=9.0"},
|
||||
},
|
||||
},
|
||||
@ -245,7 +244,7 @@ func TestGetNvidiaConfig(t *testing.T) {
|
||||
privileged: false,
|
||||
expectedConfig: &nvidiaConfig{
|
||||
Devices: "gpu0,gpu1",
|
||||
DriverCapabilities: defaultDriverCapabilities.String(),
|
||||
DriverCapabilities: image.DefaultDriverCapabilities.String(),
|
||||
Requirements: []string{"cuda>=9.0"},
|
||||
},
|
||||
},
|
||||
@ -259,7 +258,7 @@ func TestGetNvidiaConfig(t *testing.T) {
|
||||
privileged: false,
|
||||
expectedConfig: &nvidiaConfig{
|
||||
Devices: "gpu0,gpu1",
|
||||
DriverCapabilities: defaultDriverCapabilities.String(),
|
||||
DriverCapabilities: image.DefaultDriverCapabilities.String(),
|
||||
Requirements: []string{"cuda>=9.0"},
|
||||
},
|
||||
},
|
||||
@ -273,7 +272,7 @@ func TestGetNvidiaConfig(t *testing.T) {
|
||||
privileged: false,
|
||||
expectedConfig: &nvidiaConfig{
|
||||
Devices: "gpu0,gpu1",
|
||||
DriverCapabilities: allDriverCapabilities.String(),
|
||||
DriverCapabilities: image.SupportedDriverCapabilities.String(),
|
||||
Requirements: []string{"cuda>=9.0"},
|
||||
},
|
||||
},
|
||||
@ -287,7 +286,7 @@ func TestGetNvidiaConfig(t *testing.T) {
|
||||
privileged: false,
|
||||
expectedConfig: &nvidiaConfig{
|
||||
Devices: "gpu0,gpu1",
|
||||
DriverCapabilities: "video,display",
|
||||
DriverCapabilities: "display,video",
|
||||
Requirements: []string{"cuda>=9.0"},
|
||||
},
|
||||
},
|
||||
@ -303,7 +302,7 @@ func TestGetNvidiaConfig(t *testing.T) {
|
||||
privileged: false,
|
||||
expectedConfig: &nvidiaConfig{
|
||||
Devices: "gpu0,gpu1",
|
||||
DriverCapabilities: "video,display",
|
||||
DriverCapabilities: "display,video",
|
||||
Requirements: []string{"cuda>=9.0", "req0=true", "req1=false"},
|
||||
},
|
||||
},
|
||||
@ -320,7 +319,7 @@ func TestGetNvidiaConfig(t *testing.T) {
|
||||
privileged: false,
|
||||
expectedConfig: &nvidiaConfig{
|
||||
Devices: "gpu0,gpu1",
|
||||
DriverCapabilities: "video,display",
|
||||
DriverCapabilities: "display,video",
|
||||
Requirements: []string{},
|
||||
},
|
||||
},
|
||||
@ -333,7 +332,7 @@ func TestGetNvidiaConfig(t *testing.T) {
|
||||
|
||||
expectedConfig: &nvidiaConfig{
|
||||
Devices: "all",
|
||||
DriverCapabilities: defaultDriverCapabilities.String(),
|
||||
DriverCapabilities: image.DefaultDriverCapabilities.String(),
|
||||
Requirements: []string{},
|
||||
},
|
||||
},
|
||||
@ -348,7 +347,7 @@ func TestGetNvidiaConfig(t *testing.T) {
|
||||
expectedConfig: &nvidiaConfig{
|
||||
Devices: "all",
|
||||
MigConfigDevices: "mig0,mig1",
|
||||
DriverCapabilities: defaultDriverCapabilities.String(),
|
||||
DriverCapabilities: image.DefaultDriverCapabilities.String(),
|
||||
Requirements: []string{"cuda>=9.0"},
|
||||
},
|
||||
},
|
||||
@ -373,7 +372,7 @@ func TestGetNvidiaConfig(t *testing.T) {
|
||||
expectedConfig: &nvidiaConfig{
|
||||
Devices: "all",
|
||||
MigMonitorDevices: "mig0,mig1",
|
||||
DriverCapabilities: defaultDriverCapabilities.String(),
|
||||
DriverCapabilities: image.DefaultDriverCapabilities.String(),
|
||||
Requirements: []string{"cuda>=9.0"},
|
||||
},
|
||||
},
|
||||
@ -399,7 +398,7 @@ func TestGetNvidiaConfig(t *testing.T) {
|
||||
},
|
||||
expectedConfig: &nvidiaConfig{
|
||||
Devices: "all",
|
||||
DriverCapabilities: "video,display",
|
||||
DriverCapabilities: "display,video",
|
||||
},
|
||||
},
|
||||
{
|
||||
@ -414,7 +413,7 @@ func TestGetNvidiaConfig(t *testing.T) {
|
||||
},
|
||||
expectedConfig: &nvidiaConfig{
|
||||
Devices: "all",
|
||||
DriverCapabilities: "video,display",
|
||||
DriverCapabilities: "display,video",
|
||||
},
|
||||
},
|
||||
{
|
||||
@ -428,7 +427,7 @@ func TestGetNvidiaConfig(t *testing.T) {
|
||||
},
|
||||
expectedConfig: &nvidiaConfig{
|
||||
Devices: "all",
|
||||
DriverCapabilities: defaultDriverCapabilities.String(),
|
||||
DriverCapabilities: image.DefaultDriverCapabilities.String(),
|
||||
},
|
||||
},
|
||||
{
|
||||
@ -439,14 +438,12 @@ func TestGetNvidiaConfig(t *testing.T) {
|
||||
},
|
||||
privileged: true,
|
||||
hookConfig: &HookConfig{
|
||||
Config: config.Config{
|
||||
SwarmResource: "DOCKER_SWARM_RESOURCE",
|
||||
},
|
||||
SwarmResource: "DOCKER_SWARM_RESOURCE",
|
||||
SupportedDriverCapabilities: "video,display,utility,compute",
|
||||
},
|
||||
expectedConfig: &nvidiaConfig{
|
||||
Devices: "GPU1,GPU2",
|
||||
DriverCapabilities: defaultDriverCapabilities.String(),
|
||||
DriverCapabilities: image.DefaultDriverCapabilities.String(),
|
||||
},
|
||||
},
|
||||
{
|
||||
@ -457,14 +454,12 @@ func TestGetNvidiaConfig(t *testing.T) {
|
||||
},
|
||||
privileged: true,
|
||||
hookConfig: &HookConfig{
|
||||
Config: config.Config{
|
||||
SwarmResource: "NOT_DOCKER_SWARM_RESOURCE,DOCKER_SWARM_RESOURCE",
|
||||
},
|
||||
SwarmResource: "NOT_DOCKER_SWARM_RESOURCE,DOCKER_SWARM_RESOURCE",
|
||||
SupportedDriverCapabilities: "video,display,utility,compute",
|
||||
},
|
||||
expectedConfig: &nvidiaConfig{
|
||||
Devices: "GPU1,GPU2",
|
||||
DriverCapabilities: defaultDriverCapabilities.String(),
|
||||
DriverCapabilities: image.DefaultDriverCapabilities.String(),
|
||||
},
|
||||
},
|
||||
}
|
||||
@ -924,7 +919,7 @@ func TestGetDevicesFromEnvvar(t *testing.T) {
|
||||
|
||||
func TestGetDriverCapabilities(t *testing.T) {
|
||||
|
||||
supportedCapabilities := "compute,utility,display,video"
|
||||
supportedCapabilities := "compute,display,utility,video"
|
||||
|
||||
testCases := []struct {
|
||||
description string
|
||||
@ -959,7 +954,7 @@ func TestGetDriverCapabilities(t *testing.T) {
|
||||
},
|
||||
legacyImage: true,
|
||||
supportedCapabilities: supportedCapabilities,
|
||||
expectedCapabilities: defaultDriverCapabilities.String(),
|
||||
expectedCapabilities: image.DefaultDriverCapabilities.String(),
|
||||
},
|
||||
{
|
||||
description: "Env unset for legacy image is 'all'",
|
||||
@ -982,7 +977,7 @@ func TestGetDriverCapabilities(t *testing.T) {
|
||||
env: map[string]string{},
|
||||
legacyImage: false,
|
||||
supportedCapabilities: supportedCapabilities,
|
||||
expectedCapabilities: defaultDriverCapabilities.String(),
|
||||
expectedCapabilities: image.DefaultDriverCapabilities.String(),
|
||||
},
|
||||
{
|
||||
description: "Env is all for modern image",
|
||||
@ -1000,7 +995,7 @@ func TestGetDriverCapabilities(t *testing.T) {
|
||||
},
|
||||
legacyImage: false,
|
||||
supportedCapabilities: supportedCapabilities,
|
||||
expectedCapabilities: defaultDriverCapabilities.String(),
|
||||
expectedCapabilities: image.DefaultDriverCapabilities.String(),
|
||||
},
|
||||
{
|
||||
description: "Invalid capabilities panic",
|
||||
@ -1020,11 +1015,14 @@ func TestGetDriverCapabilities(t *testing.T) {
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.description, func(t *testing.T) {
|
||||
var capabilites DriverCapabilities
|
||||
var capabilites string
|
||||
|
||||
c := HookConfig{
|
||||
SupportedDriverCapabilities: tc.supportedCapabilities,
|
||||
}
|
||||
|
||||
getDriverCapabilities := func() {
|
||||
supportedCapabilities := DriverCapabilities(tc.supportedCapabilities)
|
||||
capabilites = getDriverCapabilities(tc.env, supportedCapabilities, tc.legacyImage)
|
||||
capabilites = c.getDriverCapabilities(tc.env, tc.legacyImage).String()
|
||||
}
|
||||
|
||||
if tc.expectedPanic {
|
||||
|
@ -10,6 +10,7 @@ import (
|
||||
|
||||
"github.com/BurntSushi/toml"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/config"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -23,11 +24,7 @@ var defaultPaths = [...]string{
|
||||
}
|
||||
|
||||
// HookConfig : options for the nvidia-container-runtime-hook.
|
||||
type HookConfig struct {
|
||||
config.Config
|
||||
// TODO: We should also migrate the driver capabilities
|
||||
SupportedDriverCapabilities DriverCapabilities `toml:"supported-driver-capabilities"`
|
||||
}
|
||||
type HookConfig config.Config
|
||||
|
||||
func getDefaultHookConfig() (HookConfig, error) {
|
||||
defaultCfg, err := config.GetDefault()
|
||||
@ -35,12 +32,7 @@ func getDefaultHookConfig() (HookConfig, error) {
|
||||
return HookConfig{}, err
|
||||
}
|
||||
|
||||
c := HookConfig{
|
||||
Config: *defaultCfg,
|
||||
SupportedDriverCapabilities: allDriverCapabilities,
|
||||
}
|
||||
|
||||
return c, nil
|
||||
return *(*HookConfig)(defaultCfg), nil
|
||||
}
|
||||
|
||||
func getHookConfig() (*HookConfig, error) {
|
||||
@ -71,13 +63,15 @@ func getHookConfig() (*HookConfig, error) {
|
||||
}
|
||||
}
|
||||
|
||||
if config.SupportedDriverCapabilities == all {
|
||||
config.SupportedDriverCapabilities = allDriverCapabilities
|
||||
allSupportedDriverCapabilities := image.SupportedDriverCapabilities
|
||||
if config.SupportedDriverCapabilities == "all" {
|
||||
config.SupportedDriverCapabilities = allSupportedDriverCapabilities.String()
|
||||
}
|
||||
// We ensure that the supported-driver-capabilites option is a subset of allDriverCapabilities
|
||||
if intersection := allDriverCapabilities.Intersection(config.SupportedDriverCapabilities); intersection != config.SupportedDriverCapabilities {
|
||||
configuredCapabilities := image.NewDriverCapabilities(config.SupportedDriverCapabilities)
|
||||
// We ensure that the configured value is a subset of all supported capabilities
|
||||
if !allSupportedDriverCapabilities.IsSuperset(configuredCapabilities) {
|
||||
configName := config.getConfigOption("SupportedDriverCapabilities")
|
||||
log.Panicf("Invalid value for config option '%v'; %v (supported: %v)\n", configName, config.SupportedDriverCapabilities, allDriverCapabilities)
|
||||
log.Panicf("Invalid value for config option '%v'; %v (supported: %v)\n", configName, config.SupportedDriverCapabilities, allSupportedDriverCapabilities.String())
|
||||
}
|
||||
|
||||
return &config, nil
|
||||
|
@ -21,7 +21,7 @@ import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/config"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@ -29,16 +29,16 @@ func TestGetHookConfig(t *testing.T) {
|
||||
testCases := []struct {
|
||||
lines []string
|
||||
expectedPanic bool
|
||||
expectedDriverCapabilities DriverCapabilities
|
||||
expectedDriverCapabilities string
|
||||
}{
|
||||
{
|
||||
expectedDriverCapabilities: allDriverCapabilities,
|
||||
expectedDriverCapabilities: image.SupportedDriverCapabilities.String(),
|
||||
},
|
||||
{
|
||||
lines: []string{
|
||||
"supported-driver-capabilities = \"all\"",
|
||||
},
|
||||
expectedDriverCapabilities: allDriverCapabilities,
|
||||
expectedDriverCapabilities: image.SupportedDriverCapabilities.String(),
|
||||
},
|
||||
{
|
||||
lines: []string{
|
||||
@ -48,19 +48,19 @@ func TestGetHookConfig(t *testing.T) {
|
||||
},
|
||||
{
|
||||
lines: []string{},
|
||||
expectedDriverCapabilities: allDriverCapabilities,
|
||||
expectedDriverCapabilities: image.SupportedDriverCapabilities.String(),
|
||||
},
|
||||
{
|
||||
lines: []string{
|
||||
"supported-driver-capabilities = \"\"",
|
||||
},
|
||||
expectedDriverCapabilities: none,
|
||||
expectedDriverCapabilities: "",
|
||||
},
|
||||
{
|
||||
lines: []string{
|
||||
"supported-driver-capabilities = \"utility,compute\"",
|
||||
"supported-driver-capabilities = \"compute,utility\"",
|
||||
},
|
||||
expectedDriverCapabilities: DriverCapabilities("utility,compute"),
|
||||
expectedDriverCapabilities: "compute,utility",
|
||||
},
|
||||
}
|
||||
|
||||
@ -144,9 +144,7 @@ func TestGetSwarmResourceEnvvars(t *testing.T) {
|
||||
for i, tc := range testCases {
|
||||
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
|
||||
c := &HookConfig{
|
||||
Config: config.Config{
|
||||
SwarmResource: tc.value,
|
||||
},
|
||||
SwarmResource: tc.value,
|
||||
}
|
||||
|
||||
envvars := c.getSwarmResourceEnvvars()
|
||||
|
@ -25,6 +25,7 @@ import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup"
|
||||
"github.com/container-orchestrated-devices/container-device-interface/pkg/cdi"
|
||||
@ -61,7 +62,7 @@ type Config struct {
|
||||
SwarmResource string `toml:"swarm-resource"`
|
||||
AcceptEnvvarUnprivileged bool `toml:"accept-nvidia-visible-devices-envvar-when-unprivileged"`
|
||||
AcceptDeviceListAsVolumeMounts bool `toml:"accept-nvidia-visible-devices-as-volume-mounts"`
|
||||
// SupportedDriverCapabilities DriverCapabilities `toml:"supported-driver-capabilities"`
|
||||
SupportedDriverCapabilities string `toml:"supported-driver-capabilities"`
|
||||
|
||||
NVIDIAContainerCLIConfig ContainerCLIConfig `toml:"nvidia-container-cli"`
|
||||
NVIDIACTKConfig CTKConfig `toml:"nvidia-ctk"`
|
||||
@ -135,7 +136,8 @@ func getFromTree(toml *toml.Tree) (*Config, error) {
|
||||
// GetDefault defines the default values for the config
|
||||
func GetDefault() (*Config, error) {
|
||||
d := Config{
|
||||
AcceptEnvvarUnprivileged: true,
|
||||
AcceptEnvvarUnprivileged: true,
|
||||
SupportedDriverCapabilities: image.SupportedDriverCapabilities.String(),
|
||||
NVIDIAContainerCLIConfig: ContainerCLIConfig{
|
||||
LoadKmods: true,
|
||||
Ldconfig: getLdConfigPath(),
|
||||
|
@ -60,7 +60,8 @@ func TestGetConfig(t *testing.T) {
|
||||
description: "empty config is default",
|
||||
inspectLdconfig: true,
|
||||
expectedConfig: &Config{
|
||||
AcceptEnvvarUnprivileged: true,
|
||||
AcceptEnvvarUnprivileged: true,
|
||||
SupportedDriverCapabilities: "compat32,compute,display,graphics,ngx,utility,video",
|
||||
NVIDIAContainerCLIConfig: ContainerCLIConfig{
|
||||
Root: "",
|
||||
LoadKmods: true,
|
||||
@ -94,6 +95,7 @@ func TestGetConfig(t *testing.T) {
|
||||
description: "config options set inline",
|
||||
contents: []string{
|
||||
"accept-nvidia-visible-devices-envvar-when-unprivileged = false",
|
||||
"supported-driver-capabilities = \"compute,utility\"",
|
||||
"nvidia-container-cli.root = \"/bar/baz\"",
|
||||
"nvidia-container-cli.load-kmods = false",
|
||||
"nvidia-container-cli.ldconfig = \"/foo/bar/ldconfig\"",
|
||||
@ -110,7 +112,8 @@ func TestGetConfig(t *testing.T) {
|
||||
"nvidia-ctk.path = \"/foo/bar/nvidia-ctk\"",
|
||||
},
|
||||
expectedConfig: &Config{
|
||||
AcceptEnvvarUnprivileged: false,
|
||||
AcceptEnvvarUnprivileged: false,
|
||||
SupportedDriverCapabilities: "compute,utility",
|
||||
NVIDIAContainerCLIConfig: ContainerCLIConfig{
|
||||
Root: "/bar/baz",
|
||||
LoadKmods: false,
|
||||
@ -150,6 +153,7 @@ func TestGetConfig(t *testing.T) {
|
||||
description: "config options set in section",
|
||||
contents: []string{
|
||||
"accept-nvidia-visible-devices-envvar-when-unprivileged = false",
|
||||
"supported-driver-capabilities = \"compute,utility\"",
|
||||
"[nvidia-container-cli]",
|
||||
"root = \"/bar/baz\"",
|
||||
"load-kmods = false",
|
||||
@ -172,7 +176,8 @@ func TestGetConfig(t *testing.T) {
|
||||
"path = \"/foo/bar/nvidia-ctk\"",
|
||||
},
|
||||
expectedConfig: &Config{
|
||||
AcceptEnvvarUnprivileged: false,
|
||||
AcceptEnvvarUnprivileged: false,
|
||||
SupportedDriverCapabilities: "compute,utility",
|
||||
NVIDIAContainerCLIConfig: ContainerCLIConfig{
|
||||
Root: "/bar/baz",
|
||||
LoadKmods: false,
|
||||
|
@ -16,12 +16,18 @@
|
||||
|
||||
package image
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// DriverCapability represents the possible values of NVIDIA_DRIVER_CAPABILITIES
|
||||
type DriverCapability string
|
||||
|
||||
// Constants for the supported driver capabilities
|
||||
const (
|
||||
DriverCapabilityAll DriverCapability = "all"
|
||||
DriverCapabilityNone DriverCapability = "none"
|
||||
DriverCapabilityCompat32 DriverCapability = "compat32"
|
||||
DriverCapabilityCompute DriverCapability = "compute"
|
||||
DriverCapabilityDisplay DriverCapability = "display"
|
||||
@ -31,12 +37,37 @@ const (
|
||||
DriverCapabilityVideo DriverCapability = "video"
|
||||
)
|
||||
|
||||
var (
|
||||
driverCapabilitiesNone = NewDriverCapabilities()
|
||||
driverCapabilitiesAll = NewDriverCapabilities("all")
|
||||
|
||||
// DefaultDriverCapabilities sets the value for driver capabilities if no value is set.
|
||||
DefaultDriverCapabilities = NewDriverCapabilities("utility,compute")
|
||||
// SupportedDriverCapabilities defines the set of all supported driver capabilities.
|
||||
SupportedDriverCapabilities = NewDriverCapabilities("compute,compat32,graphics,utility,video,display,ngx")
|
||||
)
|
||||
|
||||
// NewDriverCapabilities creates a set of driver capabilities from the specified capabilities
|
||||
func NewDriverCapabilities(capabilities ...string) DriverCapabilities {
|
||||
dc := make(DriverCapabilities)
|
||||
for _, capability := range capabilities {
|
||||
for _, c := range strings.Split(capability, ",") {
|
||||
trimmed := strings.TrimSpace(c)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
dc[DriverCapability(trimmed)] = true
|
||||
}
|
||||
}
|
||||
return dc
|
||||
}
|
||||
|
||||
// DriverCapabilities represents the NVIDIA_DRIVER_CAPABILITIES set for the specified image.
|
||||
type DriverCapabilities map[DriverCapability]bool
|
||||
|
||||
// Has check whether the specified capability is selected.
|
||||
func (c DriverCapabilities) Has(capability DriverCapability) bool {
|
||||
if c[DriverCapabilityAll] {
|
||||
if c.IsAll() {
|
||||
return true
|
||||
}
|
||||
return c[capability]
|
||||
@ -44,11 +75,72 @@ func (c DriverCapabilities) Has(capability DriverCapability) bool {
|
||||
|
||||
// Any checks whether any of the specified capabilites are set
|
||||
func (c DriverCapabilities) Any(capabilities ...DriverCapability) bool {
|
||||
if c.IsAll() {
|
||||
return true
|
||||
}
|
||||
for _, cap := range capabilities {
|
||||
if c.Has(cap) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// List returns the list of driver capabilities.
|
||||
// The list is sorted.
|
||||
func (c DriverCapabilities) List() []string {
|
||||
var capabilities []string
|
||||
for capability := range c {
|
||||
capabilities = append(capabilities, string(capability))
|
||||
}
|
||||
sort.Strings(capabilities)
|
||||
return capabilities
|
||||
}
|
||||
|
||||
// String returns the string repesentation of the driver capabilities.
|
||||
func (c DriverCapabilities) String() string {
|
||||
if c.IsAll() {
|
||||
return "all"
|
||||
}
|
||||
return strings.Join(c.List(), ",")
|
||||
}
|
||||
|
||||
// IsAll indicates whether the set of capabilities is `all`
|
||||
func (c DriverCapabilities) IsAll() bool {
|
||||
return c[DriverCapabilityAll]
|
||||
}
|
||||
|
||||
// Intersection returns a new set which includes the item in BOTH d and s2.
|
||||
// For example: d = {a1, a2} s2 = {a2, a3} s1.Intersection(s2) = {a2}
|
||||
func (c DriverCapabilities) Intersection(s2 DriverCapabilities) DriverCapabilities {
|
||||
if s2.IsAll() {
|
||||
return c
|
||||
}
|
||||
if c.IsAll() {
|
||||
return s2
|
||||
}
|
||||
|
||||
intersection := make(DriverCapabilities)
|
||||
for capability := range s2 {
|
||||
if c[capability] {
|
||||
intersection[capability] = true
|
||||
}
|
||||
}
|
||||
|
||||
return intersection
|
||||
}
|
||||
|
||||
// IsSuperset returns true if and only if d is a superset of s2.
|
||||
func (c DriverCapabilities) IsSuperset(s2 DriverCapabilities) bool {
|
||||
if c.IsAll() {
|
||||
return true
|
||||
}
|
||||
|
||||
for capability := range s2 {
|
||||
if !c[capability] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
134
internal/config/image/capabilities_test.go
Normal file
134
internal/config/image/capabilities_test.go
Normal file
@ -0,0 +1,134 @@
|
||||
/**
|
||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
**/
|
||||
|
||||
package image
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestDriverCapabilitiesIntersection(t *testing.T) {
|
||||
testCases := []struct {
|
||||
capabilities DriverCapabilities
|
||||
supportedCapabilities DriverCapabilities
|
||||
expectedIntersection DriverCapabilities
|
||||
}{
|
||||
{
|
||||
capabilities: driverCapabilitiesNone,
|
||||
supportedCapabilities: driverCapabilitiesNone,
|
||||
expectedIntersection: driverCapabilitiesNone,
|
||||
},
|
||||
{
|
||||
capabilities: driverCapabilitiesAll,
|
||||
supportedCapabilities: driverCapabilitiesNone,
|
||||
expectedIntersection: driverCapabilitiesNone,
|
||||
},
|
||||
{
|
||||
capabilities: driverCapabilitiesAll,
|
||||
supportedCapabilities: SupportedDriverCapabilities,
|
||||
expectedIntersection: SupportedDriverCapabilities,
|
||||
},
|
||||
{
|
||||
capabilities: SupportedDriverCapabilities,
|
||||
supportedCapabilities: driverCapabilitiesAll,
|
||||
expectedIntersection: SupportedDriverCapabilities,
|
||||
},
|
||||
{
|
||||
capabilities: driverCapabilitiesNone,
|
||||
supportedCapabilities: driverCapabilitiesAll,
|
||||
expectedIntersection: driverCapabilitiesNone,
|
||||
},
|
||||
{
|
||||
capabilities: driverCapabilitiesNone,
|
||||
supportedCapabilities: NewDriverCapabilities("cap1"),
|
||||
expectedIntersection: driverCapabilitiesNone,
|
||||
},
|
||||
{
|
||||
capabilities: NewDriverCapabilities("cap0,cap1"),
|
||||
supportedCapabilities: NewDriverCapabilities("cap1,cap0"),
|
||||
expectedIntersection: NewDriverCapabilities("cap0,cap1"),
|
||||
},
|
||||
{
|
||||
capabilities: DefaultDriverCapabilities,
|
||||
supportedCapabilities: SupportedDriverCapabilities,
|
||||
expectedIntersection: DefaultDriverCapabilities,
|
||||
},
|
||||
{
|
||||
capabilities: NewDriverCapabilities("compute,compat32,graphics,utility,video,display"),
|
||||
supportedCapabilities: NewDriverCapabilities("compute,compat32,graphics,utility,video,display,ngx"),
|
||||
expectedIntersection: NewDriverCapabilities("compute,compat32,graphics,utility,video,display"),
|
||||
},
|
||||
{
|
||||
capabilities: NewDriverCapabilities("cap1"),
|
||||
supportedCapabilities: driverCapabilitiesNone,
|
||||
expectedIntersection: driverCapabilitiesNone,
|
||||
},
|
||||
{
|
||||
capabilities: NewDriverCapabilities("compute,compat32,graphics,utility,video,display,ngx"),
|
||||
supportedCapabilities: NewDriverCapabilities("compute,compat32,graphics,utility,video,display"),
|
||||
expectedIntersection: NewDriverCapabilities("compute,compat32,graphics,utility,video,display"),
|
||||
},
|
||||
}
|
||||
|
||||
for i, tc := range testCases {
|
||||
t.Run(fmt.Sprintf("test case %d", i), func(t *testing.T) {
|
||||
intersection := tc.supportedCapabilities.Intersection(tc.capabilities)
|
||||
require.EqualValues(t, tc.expectedIntersection, intersection)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDriverCapabilitiesList(t *testing.T) {
|
||||
testCases := []struct {
|
||||
capabilities DriverCapabilities
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
capabilities: NewDriverCapabilities(""),
|
||||
},
|
||||
{
|
||||
capabilities: NewDriverCapabilities(" "),
|
||||
},
|
||||
{
|
||||
capabilities: NewDriverCapabilities(","),
|
||||
},
|
||||
{
|
||||
capabilities: NewDriverCapabilities(",cap"),
|
||||
expected: []string{"cap"},
|
||||
},
|
||||
{
|
||||
capabilities: NewDriverCapabilities("cap,"),
|
||||
expected: []string{"cap"},
|
||||
},
|
||||
{
|
||||
capabilities: NewDriverCapabilities("cap0,,cap1"),
|
||||
expected: []string{"cap0", "cap1"},
|
||||
},
|
||||
{
|
||||
capabilities: NewDriverCapabilities("cap1,cap0,cap3"),
|
||||
expected: []string{"cap0", "cap1", "cap3"},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tc := range testCases {
|
||||
t.Run(fmt.Sprintf("test case %d", i), func(t *testing.T) {
|
||||
require.EqualValues(t, tc.expected, tc.capabilities.List())
|
||||
})
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user