mirror of
https://github.com/NVIDIA/nvidia-container-toolkit
synced 2025-05-04 04:05:14 +00:00
Refactor handling of DriverCapabilities
This change consolidates the handling of NVIDIA_DRIVER_CAPABILITIES in the interal/image package. Signed-off-by: Evan Lezar <elezar@nvidia.com>
This commit is contained in:
parent
4dcaa61167
commit
b18ac09f77
@ -2,15 +2,6 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"log"
|
"log"
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
allDriverCapabilities = DriverCapabilities("compute,compat32,graphics,utility,video,display,ngx")
|
|
||||||
defaultDriverCapabilities = DriverCapabilities("utility,compute")
|
|
||||||
|
|
||||||
none = DriverCapabilities("")
|
|
||||||
all = DriverCapabilities("all")
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func capabilityToCLI(cap string) string {
|
func capabilityToCLI(cap string) string {
|
||||||
@ -34,50 +25,3 @@ func capabilityToCLI(cap string) string {
|
|||||||
}
|
}
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// DriverCapabilities is used to process the NVIDIA_DRIVER_CAPABILITIES environment
|
|
||||||
// variable. Operations include default values, filtering, and handling meta values such as "all"
|
|
||||||
type DriverCapabilities string
|
|
||||||
|
|
||||||
// Intersection returns intersection between two sets of capabilities.
|
|
||||||
func (d DriverCapabilities) Intersection(capabilities DriverCapabilities) DriverCapabilities {
|
|
||||||
if capabilities == all {
|
|
||||||
return d
|
|
||||||
}
|
|
||||||
if d == all {
|
|
||||||
return capabilities
|
|
||||||
}
|
|
||||||
|
|
||||||
lookup := make(map[string]bool)
|
|
||||||
for _, c := range d.list() {
|
|
||||||
lookup[c] = true
|
|
||||||
}
|
|
||||||
var found []string
|
|
||||||
for _, c := range capabilities.list() {
|
|
||||||
if lookup[c] {
|
|
||||||
found = append(found, c)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
intersection := DriverCapabilities(strings.Join(found, ","))
|
|
||||||
return intersection
|
|
||||||
}
|
|
||||||
|
|
||||||
// String returns the string representation of the driver capabilities
|
|
||||||
func (d DriverCapabilities) String() string {
|
|
||||||
return string(d)
|
|
||||||
}
|
|
||||||
|
|
||||||
// list returns the driver capabilities as a list
|
|
||||||
func (d DriverCapabilities) list() []string {
|
|
||||||
var caps []string
|
|
||||||
for _, c := range strings.Split(string(d), ",") {
|
|
||||||
trimmed := strings.TrimSpace(c)
|
|
||||||
if len(trimmed) == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
caps = append(caps, trimmed)
|
|
||||||
}
|
|
||||||
|
|
||||||
return caps
|
|
||||||
}
|
|
||||||
|
@ -1,134 +0,0 @@
|
|||||||
/**
|
|
||||||
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
**/
|
|
||||||
|
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestDriverCapabilitiesIntersection(t *testing.T) {
|
|
||||||
testCases := []struct {
|
|
||||||
capabilities DriverCapabilities
|
|
||||||
supportedCapabilities DriverCapabilities
|
|
||||||
expectedIntersection DriverCapabilities
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
capabilities: none,
|
|
||||||
supportedCapabilities: none,
|
|
||||||
expectedIntersection: none,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
capabilities: all,
|
|
||||||
supportedCapabilities: none,
|
|
||||||
expectedIntersection: none,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
capabilities: all,
|
|
||||||
supportedCapabilities: allDriverCapabilities,
|
|
||||||
expectedIntersection: allDriverCapabilities,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
capabilities: allDriverCapabilities,
|
|
||||||
supportedCapabilities: all,
|
|
||||||
expectedIntersection: allDriverCapabilities,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
capabilities: none,
|
|
||||||
supportedCapabilities: all,
|
|
||||||
expectedIntersection: none,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
capabilities: none,
|
|
||||||
supportedCapabilities: DriverCapabilities("cap1"),
|
|
||||||
expectedIntersection: none,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
capabilities: DriverCapabilities("cap0,cap1"),
|
|
||||||
supportedCapabilities: DriverCapabilities("cap1,cap0"),
|
|
||||||
expectedIntersection: DriverCapabilities("cap0,cap1"),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
capabilities: defaultDriverCapabilities,
|
|
||||||
supportedCapabilities: allDriverCapabilities,
|
|
||||||
expectedIntersection: defaultDriverCapabilities,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
capabilities: DriverCapabilities("compute,compat32,graphics,utility,video,display"),
|
|
||||||
supportedCapabilities: DriverCapabilities("compute,compat32,graphics,utility,video,display,ngx"),
|
|
||||||
expectedIntersection: DriverCapabilities("compute,compat32,graphics,utility,video,display"),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
capabilities: DriverCapabilities("cap1"),
|
|
||||||
supportedCapabilities: none,
|
|
||||||
expectedIntersection: none,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
capabilities: DriverCapabilities("compute,compat32,graphics,utility,video,display,ngx"),
|
|
||||||
supportedCapabilities: DriverCapabilities("compute,compat32,graphics,utility,video,display"),
|
|
||||||
expectedIntersection: DriverCapabilities("compute,compat32,graphics,utility,video,display"),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, tc := range testCases {
|
|
||||||
t.Run(fmt.Sprintf("test case %d", i), func(t *testing.T) {
|
|
||||||
intersection := tc.supportedCapabilities.Intersection(tc.capabilities)
|
|
||||||
require.EqualValues(t, tc.expectedIntersection, intersection)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDriverCapabilitiesList(t *testing.T) {
|
|
||||||
testCases := []struct {
|
|
||||||
capabilities DriverCapabilities
|
|
||||||
expected []string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
capabilities: DriverCapabilities(""),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
capabilities: DriverCapabilities(" "),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
capabilities: DriverCapabilities(","),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
capabilities: DriverCapabilities(",cap"),
|
|
||||||
expected: []string{"cap"},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
capabilities: DriverCapabilities("cap,"),
|
|
||||||
expected: []string{"cap"},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
capabilities: DriverCapabilities("cap0,,cap1"),
|
|
||||||
expected: []string{"cap0", "cap1"},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
capabilities: DriverCapabilities("cap1,cap0,cap3"),
|
|
||||||
expected: []string{"cap1", "cap0", "cap3"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, tc := range testCases {
|
|
||||||
t.Run(fmt.Sprintf("test case %d", i), func(t *testing.T) {
|
|
||||||
require.EqualValues(t, tc.expected, tc.capabilities.list())
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
@ -271,10 +271,12 @@ func getMigMonitorDevices(env map[string]string) *string {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getDriverCapabilities(env map[string]string, supportedDriverCapabilities DriverCapabilities, legacyImage bool) DriverCapabilities {
|
func (c *HookConfig) getDriverCapabilities(env map[string]string, legacyImage bool) image.DriverCapabilities {
|
||||||
// We use the default driver capabilities by default. This is filtered to only include the
|
// We use the default driver capabilities by default. This is filtered to only include the
|
||||||
// supported capabilities
|
// supported capabilities
|
||||||
capabilities := supportedDriverCapabilities.Intersection(defaultDriverCapabilities)
|
supportedDriverCapabilities := image.NewDriverCapabilities(c.SupportedDriverCapabilities)
|
||||||
|
|
||||||
|
capabilities := supportedDriverCapabilities.Intersection(image.DefaultDriverCapabilities)
|
||||||
|
|
||||||
capsEnv, capsEnvSpecified := env[envNVDriverCapabilities]
|
capsEnv, capsEnvSpecified := env[envNVDriverCapabilities]
|
||||||
|
|
||||||
@ -285,9 +287,9 @@ func getDriverCapabilities(env map[string]string, supportedDriverCapabilities Dr
|
|||||||
|
|
||||||
if capsEnvSpecified && len(capsEnv) > 0 {
|
if capsEnvSpecified && len(capsEnv) > 0 {
|
||||||
// If the envvironment variable is specified and is non-empty, use the capabilities value
|
// If the envvironment variable is specified and is non-empty, use the capabilities value
|
||||||
envCapabilities := DriverCapabilities(capsEnv)
|
envCapabilities := image.NewDriverCapabilities(capsEnv)
|
||||||
capabilities = supportedDriverCapabilities.Intersection(envCapabilities)
|
capabilities = supportedDriverCapabilities.Intersection(envCapabilities)
|
||||||
if envCapabilities != all && capabilities != envCapabilities {
|
if !envCapabilities.IsAll() && len(capabilities) != len(envCapabilities) {
|
||||||
log.Panicln(fmt.Errorf("unsupported capabilities found in '%v' (allowed '%v')", envCapabilities, capabilities))
|
log.Panicln(fmt.Errorf("unsupported capabilities found in '%v' (allowed '%v')", envCapabilities, capabilities))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -322,7 +324,7 @@ func getNvidiaConfig(hookConfig *HookConfig, image image.CUDA, mounts []Mount, p
|
|||||||
log.Panicln("cannot set MIG_MONITOR_DEVICES in non privileged container")
|
log.Panicln("cannot set MIG_MONITOR_DEVICES in non privileged container")
|
||||||
}
|
}
|
||||||
|
|
||||||
driverCapabilities := getDriverCapabilities(image, hookConfig.SupportedDriverCapabilities, legacyImage).String()
|
driverCapabilities := hookConfig.getDriverCapabilities(image, legacyImage).String()
|
||||||
|
|
||||||
requirements, err := image.GetRequirements()
|
requirements, err := image.GetRequirements()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -5,7 +5,6 @@ import (
|
|||||||
"path/filepath"
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/config"
|
|
||||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
|
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
@ -39,7 +38,7 @@ func TestGetNvidiaConfig(t *testing.T) {
|
|||||||
privileged: false,
|
privileged: false,
|
||||||
expectedConfig: &nvidiaConfig{
|
expectedConfig: &nvidiaConfig{
|
||||||
Devices: "all",
|
Devices: "all",
|
||||||
DriverCapabilities: allDriverCapabilities.String(),
|
DriverCapabilities: image.SupportedDriverCapabilities.String(),
|
||||||
Requirements: []string{"cuda>=9.0"},
|
Requirements: []string{"cuda>=9.0"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@ -52,7 +51,7 @@ func TestGetNvidiaConfig(t *testing.T) {
|
|||||||
privileged: false,
|
privileged: false,
|
||||||
expectedConfig: &nvidiaConfig{
|
expectedConfig: &nvidiaConfig{
|
||||||
Devices: "all",
|
Devices: "all",
|
||||||
DriverCapabilities: allDriverCapabilities.String(),
|
DriverCapabilities: image.SupportedDriverCapabilities.String(),
|
||||||
Requirements: []string{"cuda>=9.0"},
|
Requirements: []string{"cuda>=9.0"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@ -83,7 +82,7 @@ func TestGetNvidiaConfig(t *testing.T) {
|
|||||||
privileged: false,
|
privileged: false,
|
||||||
expectedConfig: &nvidiaConfig{
|
expectedConfig: &nvidiaConfig{
|
||||||
Devices: "",
|
Devices: "",
|
||||||
DriverCapabilities: allDriverCapabilities.String(),
|
DriverCapabilities: image.SupportedDriverCapabilities.String(),
|
||||||
Requirements: []string{"cuda>=9.0"},
|
Requirements: []string{"cuda>=9.0"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@ -96,7 +95,7 @@ func TestGetNvidiaConfig(t *testing.T) {
|
|||||||
privileged: false,
|
privileged: false,
|
||||||
expectedConfig: &nvidiaConfig{
|
expectedConfig: &nvidiaConfig{
|
||||||
Devices: "gpu0,gpu1",
|
Devices: "gpu0,gpu1",
|
||||||
DriverCapabilities: allDriverCapabilities.String(),
|
DriverCapabilities: image.SupportedDriverCapabilities.String(),
|
||||||
Requirements: []string{"cuda>=9.0"},
|
Requirements: []string{"cuda>=9.0"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@ -110,7 +109,7 @@ func TestGetNvidiaConfig(t *testing.T) {
|
|||||||
privileged: false,
|
privileged: false,
|
||||||
expectedConfig: &nvidiaConfig{
|
expectedConfig: &nvidiaConfig{
|
||||||
Devices: "gpu0,gpu1",
|
Devices: "gpu0,gpu1",
|
||||||
DriverCapabilities: defaultDriverCapabilities.String(),
|
DriverCapabilities: image.DefaultDriverCapabilities.String(),
|
||||||
Requirements: []string{"cuda>=9.0"},
|
Requirements: []string{"cuda>=9.0"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@ -124,7 +123,7 @@ func TestGetNvidiaConfig(t *testing.T) {
|
|||||||
privileged: false,
|
privileged: false,
|
||||||
expectedConfig: &nvidiaConfig{
|
expectedConfig: &nvidiaConfig{
|
||||||
Devices: "gpu0,gpu1",
|
Devices: "gpu0,gpu1",
|
||||||
DriverCapabilities: allDriverCapabilities.String(),
|
DriverCapabilities: image.SupportedDriverCapabilities.String(),
|
||||||
Requirements: []string{"cuda>=9.0"},
|
Requirements: []string{"cuda>=9.0"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@ -138,7 +137,7 @@ func TestGetNvidiaConfig(t *testing.T) {
|
|||||||
privileged: false,
|
privileged: false,
|
||||||
expectedConfig: &nvidiaConfig{
|
expectedConfig: &nvidiaConfig{
|
||||||
Devices: "gpu0,gpu1",
|
Devices: "gpu0,gpu1",
|
||||||
DriverCapabilities: "video,display",
|
DriverCapabilities: "display,video",
|
||||||
Requirements: []string{"cuda>=9.0"},
|
Requirements: []string{"cuda>=9.0"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@ -154,7 +153,7 @@ func TestGetNvidiaConfig(t *testing.T) {
|
|||||||
privileged: false,
|
privileged: false,
|
||||||
expectedConfig: &nvidiaConfig{
|
expectedConfig: &nvidiaConfig{
|
||||||
Devices: "gpu0,gpu1",
|
Devices: "gpu0,gpu1",
|
||||||
DriverCapabilities: "video,display",
|
DriverCapabilities: "display,video",
|
||||||
Requirements: []string{"cuda>=9.0", "req0=true", "req1=false"},
|
Requirements: []string{"cuda>=9.0", "req0=true", "req1=false"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@ -171,7 +170,7 @@ func TestGetNvidiaConfig(t *testing.T) {
|
|||||||
privileged: false,
|
privileged: false,
|
||||||
expectedConfig: &nvidiaConfig{
|
expectedConfig: &nvidiaConfig{
|
||||||
Devices: "gpu0,gpu1",
|
Devices: "gpu0,gpu1",
|
||||||
DriverCapabilities: "video,display",
|
DriverCapabilities: "display,video",
|
||||||
Requirements: []string{},
|
Requirements: []string{},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@ -201,7 +200,7 @@ func TestGetNvidiaConfig(t *testing.T) {
|
|||||||
privileged: false,
|
privileged: false,
|
||||||
expectedConfig: &nvidiaConfig{
|
expectedConfig: &nvidiaConfig{
|
||||||
Devices: "all",
|
Devices: "all",
|
||||||
DriverCapabilities: defaultDriverCapabilities.String(),
|
DriverCapabilities: image.DefaultDriverCapabilities.String(),
|
||||||
Requirements: []string{"cuda>=9.0"},
|
Requirements: []string{"cuda>=9.0"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@ -232,7 +231,7 @@ func TestGetNvidiaConfig(t *testing.T) {
|
|||||||
privileged: false,
|
privileged: false,
|
||||||
expectedConfig: &nvidiaConfig{
|
expectedConfig: &nvidiaConfig{
|
||||||
Devices: "",
|
Devices: "",
|
||||||
DriverCapabilities: defaultDriverCapabilities.String(),
|
DriverCapabilities: image.DefaultDriverCapabilities.String(),
|
||||||
Requirements: []string{"cuda>=9.0"},
|
Requirements: []string{"cuda>=9.0"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@ -245,7 +244,7 @@ func TestGetNvidiaConfig(t *testing.T) {
|
|||||||
privileged: false,
|
privileged: false,
|
||||||
expectedConfig: &nvidiaConfig{
|
expectedConfig: &nvidiaConfig{
|
||||||
Devices: "gpu0,gpu1",
|
Devices: "gpu0,gpu1",
|
||||||
DriverCapabilities: defaultDriverCapabilities.String(),
|
DriverCapabilities: image.DefaultDriverCapabilities.String(),
|
||||||
Requirements: []string{"cuda>=9.0"},
|
Requirements: []string{"cuda>=9.0"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@ -259,7 +258,7 @@ func TestGetNvidiaConfig(t *testing.T) {
|
|||||||
privileged: false,
|
privileged: false,
|
||||||
expectedConfig: &nvidiaConfig{
|
expectedConfig: &nvidiaConfig{
|
||||||
Devices: "gpu0,gpu1",
|
Devices: "gpu0,gpu1",
|
||||||
DriverCapabilities: defaultDriverCapabilities.String(),
|
DriverCapabilities: image.DefaultDriverCapabilities.String(),
|
||||||
Requirements: []string{"cuda>=9.0"},
|
Requirements: []string{"cuda>=9.0"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@ -273,7 +272,7 @@ func TestGetNvidiaConfig(t *testing.T) {
|
|||||||
privileged: false,
|
privileged: false,
|
||||||
expectedConfig: &nvidiaConfig{
|
expectedConfig: &nvidiaConfig{
|
||||||
Devices: "gpu0,gpu1",
|
Devices: "gpu0,gpu1",
|
||||||
DriverCapabilities: allDriverCapabilities.String(),
|
DriverCapabilities: image.SupportedDriverCapabilities.String(),
|
||||||
Requirements: []string{"cuda>=9.0"},
|
Requirements: []string{"cuda>=9.0"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@ -287,7 +286,7 @@ func TestGetNvidiaConfig(t *testing.T) {
|
|||||||
privileged: false,
|
privileged: false,
|
||||||
expectedConfig: &nvidiaConfig{
|
expectedConfig: &nvidiaConfig{
|
||||||
Devices: "gpu0,gpu1",
|
Devices: "gpu0,gpu1",
|
||||||
DriverCapabilities: "video,display",
|
DriverCapabilities: "display,video",
|
||||||
Requirements: []string{"cuda>=9.0"},
|
Requirements: []string{"cuda>=9.0"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@ -303,7 +302,7 @@ func TestGetNvidiaConfig(t *testing.T) {
|
|||||||
privileged: false,
|
privileged: false,
|
||||||
expectedConfig: &nvidiaConfig{
|
expectedConfig: &nvidiaConfig{
|
||||||
Devices: "gpu0,gpu1",
|
Devices: "gpu0,gpu1",
|
||||||
DriverCapabilities: "video,display",
|
DriverCapabilities: "display,video",
|
||||||
Requirements: []string{"cuda>=9.0", "req0=true", "req1=false"},
|
Requirements: []string{"cuda>=9.0", "req0=true", "req1=false"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@ -320,7 +319,7 @@ func TestGetNvidiaConfig(t *testing.T) {
|
|||||||
privileged: false,
|
privileged: false,
|
||||||
expectedConfig: &nvidiaConfig{
|
expectedConfig: &nvidiaConfig{
|
||||||
Devices: "gpu0,gpu1",
|
Devices: "gpu0,gpu1",
|
||||||
DriverCapabilities: "video,display",
|
DriverCapabilities: "display,video",
|
||||||
Requirements: []string{},
|
Requirements: []string{},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@ -333,7 +332,7 @@ func TestGetNvidiaConfig(t *testing.T) {
|
|||||||
|
|
||||||
expectedConfig: &nvidiaConfig{
|
expectedConfig: &nvidiaConfig{
|
||||||
Devices: "all",
|
Devices: "all",
|
||||||
DriverCapabilities: defaultDriverCapabilities.String(),
|
DriverCapabilities: image.DefaultDriverCapabilities.String(),
|
||||||
Requirements: []string{},
|
Requirements: []string{},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@ -348,7 +347,7 @@ func TestGetNvidiaConfig(t *testing.T) {
|
|||||||
expectedConfig: &nvidiaConfig{
|
expectedConfig: &nvidiaConfig{
|
||||||
Devices: "all",
|
Devices: "all",
|
||||||
MigConfigDevices: "mig0,mig1",
|
MigConfigDevices: "mig0,mig1",
|
||||||
DriverCapabilities: defaultDriverCapabilities.String(),
|
DriverCapabilities: image.DefaultDriverCapabilities.String(),
|
||||||
Requirements: []string{"cuda>=9.0"},
|
Requirements: []string{"cuda>=9.0"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@ -373,7 +372,7 @@ func TestGetNvidiaConfig(t *testing.T) {
|
|||||||
expectedConfig: &nvidiaConfig{
|
expectedConfig: &nvidiaConfig{
|
||||||
Devices: "all",
|
Devices: "all",
|
||||||
MigMonitorDevices: "mig0,mig1",
|
MigMonitorDevices: "mig0,mig1",
|
||||||
DriverCapabilities: defaultDriverCapabilities.String(),
|
DriverCapabilities: image.DefaultDriverCapabilities.String(),
|
||||||
Requirements: []string{"cuda>=9.0"},
|
Requirements: []string{"cuda>=9.0"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@ -399,7 +398,7 @@ func TestGetNvidiaConfig(t *testing.T) {
|
|||||||
},
|
},
|
||||||
expectedConfig: &nvidiaConfig{
|
expectedConfig: &nvidiaConfig{
|
||||||
Devices: "all",
|
Devices: "all",
|
||||||
DriverCapabilities: "video,display",
|
DriverCapabilities: "display,video",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -414,7 +413,7 @@ func TestGetNvidiaConfig(t *testing.T) {
|
|||||||
},
|
},
|
||||||
expectedConfig: &nvidiaConfig{
|
expectedConfig: &nvidiaConfig{
|
||||||
Devices: "all",
|
Devices: "all",
|
||||||
DriverCapabilities: "video,display",
|
DriverCapabilities: "display,video",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -428,7 +427,7 @@ func TestGetNvidiaConfig(t *testing.T) {
|
|||||||
},
|
},
|
||||||
expectedConfig: &nvidiaConfig{
|
expectedConfig: &nvidiaConfig{
|
||||||
Devices: "all",
|
Devices: "all",
|
||||||
DriverCapabilities: defaultDriverCapabilities.String(),
|
DriverCapabilities: image.DefaultDriverCapabilities.String(),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -439,14 +438,12 @@ func TestGetNvidiaConfig(t *testing.T) {
|
|||||||
},
|
},
|
||||||
privileged: true,
|
privileged: true,
|
||||||
hookConfig: &HookConfig{
|
hookConfig: &HookConfig{
|
||||||
Config: config.Config{
|
SwarmResource: "DOCKER_SWARM_RESOURCE",
|
||||||
SwarmResource: "DOCKER_SWARM_RESOURCE",
|
|
||||||
},
|
|
||||||
SupportedDriverCapabilities: "video,display,utility,compute",
|
SupportedDriverCapabilities: "video,display,utility,compute",
|
||||||
},
|
},
|
||||||
expectedConfig: &nvidiaConfig{
|
expectedConfig: &nvidiaConfig{
|
||||||
Devices: "GPU1,GPU2",
|
Devices: "GPU1,GPU2",
|
||||||
DriverCapabilities: defaultDriverCapabilities.String(),
|
DriverCapabilities: image.DefaultDriverCapabilities.String(),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -457,14 +454,12 @@ func TestGetNvidiaConfig(t *testing.T) {
|
|||||||
},
|
},
|
||||||
privileged: true,
|
privileged: true,
|
||||||
hookConfig: &HookConfig{
|
hookConfig: &HookConfig{
|
||||||
Config: config.Config{
|
SwarmResource: "NOT_DOCKER_SWARM_RESOURCE,DOCKER_SWARM_RESOURCE",
|
||||||
SwarmResource: "NOT_DOCKER_SWARM_RESOURCE,DOCKER_SWARM_RESOURCE",
|
|
||||||
},
|
|
||||||
SupportedDriverCapabilities: "video,display,utility,compute",
|
SupportedDriverCapabilities: "video,display,utility,compute",
|
||||||
},
|
},
|
||||||
expectedConfig: &nvidiaConfig{
|
expectedConfig: &nvidiaConfig{
|
||||||
Devices: "GPU1,GPU2",
|
Devices: "GPU1,GPU2",
|
||||||
DriverCapabilities: defaultDriverCapabilities.String(),
|
DriverCapabilities: image.DefaultDriverCapabilities.String(),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@ -924,7 +919,7 @@ func TestGetDevicesFromEnvvar(t *testing.T) {
|
|||||||
|
|
||||||
func TestGetDriverCapabilities(t *testing.T) {
|
func TestGetDriverCapabilities(t *testing.T) {
|
||||||
|
|
||||||
supportedCapabilities := "compute,utility,display,video"
|
supportedCapabilities := "compute,display,utility,video"
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
description string
|
description string
|
||||||
@ -959,7 +954,7 @@ func TestGetDriverCapabilities(t *testing.T) {
|
|||||||
},
|
},
|
||||||
legacyImage: true,
|
legacyImage: true,
|
||||||
supportedCapabilities: supportedCapabilities,
|
supportedCapabilities: supportedCapabilities,
|
||||||
expectedCapabilities: defaultDriverCapabilities.String(),
|
expectedCapabilities: image.DefaultDriverCapabilities.String(),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "Env unset for legacy image is 'all'",
|
description: "Env unset for legacy image is 'all'",
|
||||||
@ -982,7 +977,7 @@ func TestGetDriverCapabilities(t *testing.T) {
|
|||||||
env: map[string]string{},
|
env: map[string]string{},
|
||||||
legacyImage: false,
|
legacyImage: false,
|
||||||
supportedCapabilities: supportedCapabilities,
|
supportedCapabilities: supportedCapabilities,
|
||||||
expectedCapabilities: defaultDriverCapabilities.String(),
|
expectedCapabilities: image.DefaultDriverCapabilities.String(),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "Env is all for modern image",
|
description: "Env is all for modern image",
|
||||||
@ -1000,7 +995,7 @@ func TestGetDriverCapabilities(t *testing.T) {
|
|||||||
},
|
},
|
||||||
legacyImage: false,
|
legacyImage: false,
|
||||||
supportedCapabilities: supportedCapabilities,
|
supportedCapabilities: supportedCapabilities,
|
||||||
expectedCapabilities: defaultDriverCapabilities.String(),
|
expectedCapabilities: image.DefaultDriverCapabilities.String(),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "Invalid capabilities panic",
|
description: "Invalid capabilities panic",
|
||||||
@ -1020,11 +1015,14 @@ func TestGetDriverCapabilities(t *testing.T) {
|
|||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.description, func(t *testing.T) {
|
t.Run(tc.description, func(t *testing.T) {
|
||||||
var capabilites DriverCapabilities
|
var capabilites string
|
||||||
|
|
||||||
|
c := HookConfig{
|
||||||
|
SupportedDriverCapabilities: tc.supportedCapabilities,
|
||||||
|
}
|
||||||
|
|
||||||
getDriverCapabilities := func() {
|
getDriverCapabilities := func() {
|
||||||
supportedCapabilities := DriverCapabilities(tc.supportedCapabilities)
|
capabilites = c.getDriverCapabilities(tc.env, tc.legacyImage).String()
|
||||||
capabilites = getDriverCapabilities(tc.env, supportedCapabilities, tc.legacyImage)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if tc.expectedPanic {
|
if tc.expectedPanic {
|
||||||
|
@ -10,6 +10,7 @@ import (
|
|||||||
|
|
||||||
"github.com/BurntSushi/toml"
|
"github.com/BurntSushi/toml"
|
||||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/config"
|
"github.com/NVIDIA/nvidia-container-toolkit/internal/config"
|
||||||
|
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -23,11 +24,7 @@ var defaultPaths = [...]string{
|
|||||||
}
|
}
|
||||||
|
|
||||||
// HookConfig : options for the nvidia-container-runtime-hook.
|
// HookConfig : options for the nvidia-container-runtime-hook.
|
||||||
type HookConfig struct {
|
type HookConfig config.Config
|
||||||
config.Config
|
|
||||||
// TODO: We should also migrate the driver capabilities
|
|
||||||
SupportedDriverCapabilities DriverCapabilities `toml:"supported-driver-capabilities"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func getDefaultHookConfig() (HookConfig, error) {
|
func getDefaultHookConfig() (HookConfig, error) {
|
||||||
defaultCfg, err := config.GetDefault()
|
defaultCfg, err := config.GetDefault()
|
||||||
@ -35,12 +32,7 @@ func getDefaultHookConfig() (HookConfig, error) {
|
|||||||
return HookConfig{}, err
|
return HookConfig{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
c := HookConfig{
|
return *(*HookConfig)(defaultCfg), nil
|
||||||
Config: *defaultCfg,
|
|
||||||
SupportedDriverCapabilities: allDriverCapabilities,
|
|
||||||
}
|
|
||||||
|
|
||||||
return c, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func getHookConfig() (*HookConfig, error) {
|
func getHookConfig() (*HookConfig, error) {
|
||||||
@ -71,13 +63,15 @@ func getHookConfig() (*HookConfig, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.SupportedDriverCapabilities == all {
|
allSupportedDriverCapabilities := image.SupportedDriverCapabilities
|
||||||
config.SupportedDriverCapabilities = allDriverCapabilities
|
if config.SupportedDriverCapabilities == "all" {
|
||||||
|
config.SupportedDriverCapabilities = allSupportedDriverCapabilities.String()
|
||||||
}
|
}
|
||||||
// We ensure that the supported-driver-capabilites option is a subset of allDriverCapabilities
|
configuredCapabilities := image.NewDriverCapabilities(config.SupportedDriverCapabilities)
|
||||||
if intersection := allDriverCapabilities.Intersection(config.SupportedDriverCapabilities); intersection != config.SupportedDriverCapabilities {
|
// We ensure that the configured value is a subset of all supported capabilities
|
||||||
|
if !allSupportedDriverCapabilities.IsSuperset(configuredCapabilities) {
|
||||||
configName := config.getConfigOption("SupportedDriverCapabilities")
|
configName := config.getConfigOption("SupportedDriverCapabilities")
|
||||||
log.Panicf("Invalid value for config option '%v'; %v (supported: %v)\n", configName, config.SupportedDriverCapabilities, allDriverCapabilities)
|
log.Panicf("Invalid value for config option '%v'; %v (supported: %v)\n", configName, config.SupportedDriverCapabilities, allSupportedDriverCapabilities.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
return &config, nil
|
return &config, nil
|
||||||
|
@ -21,7 +21,7 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/config"
|
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -29,16 +29,16 @@ func TestGetHookConfig(t *testing.T) {
|
|||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
lines []string
|
lines []string
|
||||||
expectedPanic bool
|
expectedPanic bool
|
||||||
expectedDriverCapabilities DriverCapabilities
|
expectedDriverCapabilities string
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
expectedDriverCapabilities: allDriverCapabilities,
|
expectedDriverCapabilities: image.SupportedDriverCapabilities.String(),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
lines: []string{
|
lines: []string{
|
||||||
"supported-driver-capabilities = \"all\"",
|
"supported-driver-capabilities = \"all\"",
|
||||||
},
|
},
|
||||||
expectedDriverCapabilities: allDriverCapabilities,
|
expectedDriverCapabilities: image.SupportedDriverCapabilities.String(),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
lines: []string{
|
lines: []string{
|
||||||
@ -48,19 +48,19 @@ func TestGetHookConfig(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
lines: []string{},
|
lines: []string{},
|
||||||
expectedDriverCapabilities: allDriverCapabilities,
|
expectedDriverCapabilities: image.SupportedDriverCapabilities.String(),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
lines: []string{
|
lines: []string{
|
||||||
"supported-driver-capabilities = \"\"",
|
"supported-driver-capabilities = \"\"",
|
||||||
},
|
},
|
||||||
expectedDriverCapabilities: none,
|
expectedDriverCapabilities: "",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
lines: []string{
|
lines: []string{
|
||||||
"supported-driver-capabilities = \"utility,compute\"",
|
"supported-driver-capabilities = \"compute,utility\"",
|
||||||
},
|
},
|
||||||
expectedDriverCapabilities: DriverCapabilities("utility,compute"),
|
expectedDriverCapabilities: "compute,utility",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -144,9 +144,7 @@ func TestGetSwarmResourceEnvvars(t *testing.T) {
|
|||||||
for i, tc := range testCases {
|
for i, tc := range testCases {
|
||||||
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
|
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
|
||||||
c := &HookConfig{
|
c := &HookConfig{
|
||||||
Config: config.Config{
|
SwarmResource: tc.value,
|
||||||
SwarmResource: tc.value,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
envvars := c.getSwarmResourceEnvvars()
|
envvars := c.getSwarmResourceEnvvars()
|
||||||
|
@ -25,6 +25,7 @@ import (
|
|||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
|
||||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
|
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
|
||||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup"
|
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup"
|
||||||
"github.com/container-orchestrated-devices/container-device-interface/pkg/cdi"
|
"github.com/container-orchestrated-devices/container-device-interface/pkg/cdi"
|
||||||
@ -61,7 +62,7 @@ type Config struct {
|
|||||||
SwarmResource string `toml:"swarm-resource"`
|
SwarmResource string `toml:"swarm-resource"`
|
||||||
AcceptEnvvarUnprivileged bool `toml:"accept-nvidia-visible-devices-envvar-when-unprivileged"`
|
AcceptEnvvarUnprivileged bool `toml:"accept-nvidia-visible-devices-envvar-when-unprivileged"`
|
||||||
AcceptDeviceListAsVolumeMounts bool `toml:"accept-nvidia-visible-devices-as-volume-mounts"`
|
AcceptDeviceListAsVolumeMounts bool `toml:"accept-nvidia-visible-devices-as-volume-mounts"`
|
||||||
// SupportedDriverCapabilities DriverCapabilities `toml:"supported-driver-capabilities"`
|
SupportedDriverCapabilities string `toml:"supported-driver-capabilities"`
|
||||||
|
|
||||||
NVIDIAContainerCLIConfig ContainerCLIConfig `toml:"nvidia-container-cli"`
|
NVIDIAContainerCLIConfig ContainerCLIConfig `toml:"nvidia-container-cli"`
|
||||||
NVIDIACTKConfig CTKConfig `toml:"nvidia-ctk"`
|
NVIDIACTKConfig CTKConfig `toml:"nvidia-ctk"`
|
||||||
@ -135,7 +136,8 @@ func getFromTree(toml *toml.Tree) (*Config, error) {
|
|||||||
// GetDefault defines the default values for the config
|
// GetDefault defines the default values for the config
|
||||||
func GetDefault() (*Config, error) {
|
func GetDefault() (*Config, error) {
|
||||||
d := Config{
|
d := Config{
|
||||||
AcceptEnvvarUnprivileged: true,
|
AcceptEnvvarUnprivileged: true,
|
||||||
|
SupportedDriverCapabilities: image.SupportedDriverCapabilities.String(),
|
||||||
NVIDIAContainerCLIConfig: ContainerCLIConfig{
|
NVIDIAContainerCLIConfig: ContainerCLIConfig{
|
||||||
LoadKmods: true,
|
LoadKmods: true,
|
||||||
Ldconfig: getLdConfigPath(),
|
Ldconfig: getLdConfigPath(),
|
||||||
|
@ -60,7 +60,8 @@ func TestGetConfig(t *testing.T) {
|
|||||||
description: "empty config is default",
|
description: "empty config is default",
|
||||||
inspectLdconfig: true,
|
inspectLdconfig: true,
|
||||||
expectedConfig: &Config{
|
expectedConfig: &Config{
|
||||||
AcceptEnvvarUnprivileged: true,
|
AcceptEnvvarUnprivileged: true,
|
||||||
|
SupportedDriverCapabilities: "compat32,compute,display,graphics,ngx,utility,video",
|
||||||
NVIDIAContainerCLIConfig: ContainerCLIConfig{
|
NVIDIAContainerCLIConfig: ContainerCLIConfig{
|
||||||
Root: "",
|
Root: "",
|
||||||
LoadKmods: true,
|
LoadKmods: true,
|
||||||
@ -94,6 +95,7 @@ func TestGetConfig(t *testing.T) {
|
|||||||
description: "config options set inline",
|
description: "config options set inline",
|
||||||
contents: []string{
|
contents: []string{
|
||||||
"accept-nvidia-visible-devices-envvar-when-unprivileged = false",
|
"accept-nvidia-visible-devices-envvar-when-unprivileged = false",
|
||||||
|
"supported-driver-capabilities = \"compute,utility\"",
|
||||||
"nvidia-container-cli.root = \"/bar/baz\"",
|
"nvidia-container-cli.root = \"/bar/baz\"",
|
||||||
"nvidia-container-cli.load-kmods = false",
|
"nvidia-container-cli.load-kmods = false",
|
||||||
"nvidia-container-cli.ldconfig = \"/foo/bar/ldconfig\"",
|
"nvidia-container-cli.ldconfig = \"/foo/bar/ldconfig\"",
|
||||||
@ -110,7 +112,8 @@ func TestGetConfig(t *testing.T) {
|
|||||||
"nvidia-ctk.path = \"/foo/bar/nvidia-ctk\"",
|
"nvidia-ctk.path = \"/foo/bar/nvidia-ctk\"",
|
||||||
},
|
},
|
||||||
expectedConfig: &Config{
|
expectedConfig: &Config{
|
||||||
AcceptEnvvarUnprivileged: false,
|
AcceptEnvvarUnprivileged: false,
|
||||||
|
SupportedDriverCapabilities: "compute,utility",
|
||||||
NVIDIAContainerCLIConfig: ContainerCLIConfig{
|
NVIDIAContainerCLIConfig: ContainerCLIConfig{
|
||||||
Root: "/bar/baz",
|
Root: "/bar/baz",
|
||||||
LoadKmods: false,
|
LoadKmods: false,
|
||||||
@ -150,6 +153,7 @@ func TestGetConfig(t *testing.T) {
|
|||||||
description: "config options set in section",
|
description: "config options set in section",
|
||||||
contents: []string{
|
contents: []string{
|
||||||
"accept-nvidia-visible-devices-envvar-when-unprivileged = false",
|
"accept-nvidia-visible-devices-envvar-when-unprivileged = false",
|
||||||
|
"supported-driver-capabilities = \"compute,utility\"",
|
||||||
"[nvidia-container-cli]",
|
"[nvidia-container-cli]",
|
||||||
"root = \"/bar/baz\"",
|
"root = \"/bar/baz\"",
|
||||||
"load-kmods = false",
|
"load-kmods = false",
|
||||||
@ -172,7 +176,8 @@ func TestGetConfig(t *testing.T) {
|
|||||||
"path = \"/foo/bar/nvidia-ctk\"",
|
"path = \"/foo/bar/nvidia-ctk\"",
|
||||||
},
|
},
|
||||||
expectedConfig: &Config{
|
expectedConfig: &Config{
|
||||||
AcceptEnvvarUnprivileged: false,
|
AcceptEnvvarUnprivileged: false,
|
||||||
|
SupportedDriverCapabilities: "compute,utility",
|
||||||
NVIDIAContainerCLIConfig: ContainerCLIConfig{
|
NVIDIAContainerCLIConfig: ContainerCLIConfig{
|
||||||
Root: "/bar/baz",
|
Root: "/bar/baz",
|
||||||
LoadKmods: false,
|
LoadKmods: false,
|
||||||
|
@ -16,12 +16,18 @@
|
|||||||
|
|
||||||
package image
|
package image
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
// DriverCapability represents the possible values of NVIDIA_DRIVER_CAPABILITIES
|
// DriverCapability represents the possible values of NVIDIA_DRIVER_CAPABILITIES
|
||||||
type DriverCapability string
|
type DriverCapability string
|
||||||
|
|
||||||
// Constants for the supported driver capabilities
|
// Constants for the supported driver capabilities
|
||||||
const (
|
const (
|
||||||
DriverCapabilityAll DriverCapability = "all"
|
DriverCapabilityAll DriverCapability = "all"
|
||||||
|
DriverCapabilityNone DriverCapability = "none"
|
||||||
DriverCapabilityCompat32 DriverCapability = "compat32"
|
DriverCapabilityCompat32 DriverCapability = "compat32"
|
||||||
DriverCapabilityCompute DriverCapability = "compute"
|
DriverCapabilityCompute DriverCapability = "compute"
|
||||||
DriverCapabilityDisplay DriverCapability = "display"
|
DriverCapabilityDisplay DriverCapability = "display"
|
||||||
@ -31,12 +37,37 @@ const (
|
|||||||
DriverCapabilityVideo DriverCapability = "video"
|
DriverCapabilityVideo DriverCapability = "video"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
driverCapabilitiesNone = NewDriverCapabilities()
|
||||||
|
driverCapabilitiesAll = NewDriverCapabilities("all")
|
||||||
|
|
||||||
|
// DefaultDriverCapabilities sets the value for driver capabilities if no value is set.
|
||||||
|
DefaultDriverCapabilities = NewDriverCapabilities("utility,compute")
|
||||||
|
// SupportedDriverCapabilities defines the set of all supported driver capabilities.
|
||||||
|
SupportedDriverCapabilities = NewDriverCapabilities("compute,compat32,graphics,utility,video,display,ngx")
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewDriverCapabilities creates a set of driver capabilities from the specified capabilities
|
||||||
|
func NewDriverCapabilities(capabilities ...string) DriverCapabilities {
|
||||||
|
dc := make(DriverCapabilities)
|
||||||
|
for _, capability := range capabilities {
|
||||||
|
for _, c := range strings.Split(capability, ",") {
|
||||||
|
trimmed := strings.TrimSpace(c)
|
||||||
|
if trimmed == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
dc[DriverCapability(trimmed)] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return dc
|
||||||
|
}
|
||||||
|
|
||||||
// DriverCapabilities represents the NVIDIA_DRIVER_CAPABILITIES set for the specified image.
|
// DriverCapabilities represents the NVIDIA_DRIVER_CAPABILITIES set for the specified image.
|
||||||
type DriverCapabilities map[DriverCapability]bool
|
type DriverCapabilities map[DriverCapability]bool
|
||||||
|
|
||||||
// Has check whether the specified capability is selected.
|
// Has check whether the specified capability is selected.
|
||||||
func (c DriverCapabilities) Has(capability DriverCapability) bool {
|
func (c DriverCapabilities) Has(capability DriverCapability) bool {
|
||||||
if c[DriverCapabilityAll] {
|
if c.IsAll() {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
return c[capability]
|
return c[capability]
|
||||||
@ -44,11 +75,72 @@ func (c DriverCapabilities) Has(capability DriverCapability) bool {
|
|||||||
|
|
||||||
// Any checks whether any of the specified capabilites are set
|
// Any checks whether any of the specified capabilites are set
|
||||||
func (c DriverCapabilities) Any(capabilities ...DriverCapability) bool {
|
func (c DriverCapabilities) Any(capabilities ...DriverCapability) bool {
|
||||||
|
if c.IsAll() {
|
||||||
|
return true
|
||||||
|
}
|
||||||
for _, cap := range capabilities {
|
for _, cap := range capabilities {
|
||||||
if c.Has(cap) {
|
if c.Has(cap) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// List returns the list of driver capabilities.
|
||||||
|
// The list is sorted.
|
||||||
|
func (c DriverCapabilities) List() []string {
|
||||||
|
var capabilities []string
|
||||||
|
for capability := range c {
|
||||||
|
capabilities = append(capabilities, string(capability))
|
||||||
|
}
|
||||||
|
sort.Strings(capabilities)
|
||||||
|
return capabilities
|
||||||
|
}
|
||||||
|
|
||||||
|
// String returns the string repesentation of the driver capabilities.
|
||||||
|
func (c DriverCapabilities) String() string {
|
||||||
|
if c.IsAll() {
|
||||||
|
return "all"
|
||||||
|
}
|
||||||
|
return strings.Join(c.List(), ",")
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsAll indicates whether the set of capabilities is `all`
|
||||||
|
func (c DriverCapabilities) IsAll() bool {
|
||||||
|
return c[DriverCapabilityAll]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Intersection returns a new set which includes the item in BOTH d and s2.
|
||||||
|
// For example: d = {a1, a2} s2 = {a2, a3} s1.Intersection(s2) = {a2}
|
||||||
|
func (c DriverCapabilities) Intersection(s2 DriverCapabilities) DriverCapabilities {
|
||||||
|
if s2.IsAll() {
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
if c.IsAll() {
|
||||||
|
return s2
|
||||||
|
}
|
||||||
|
|
||||||
|
intersection := make(DriverCapabilities)
|
||||||
|
for capability := range s2 {
|
||||||
|
if c[capability] {
|
||||||
|
intersection[capability] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return intersection
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsSuperset returns true if and only if d is a superset of s2.
|
||||||
|
func (c DriverCapabilities) IsSuperset(s2 DriverCapabilities) bool {
|
||||||
|
if c.IsAll() {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
for capability := range s2 {
|
||||||
|
if !c[capability] {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
134
internal/config/image/capabilities_test.go
Normal file
134
internal/config/image/capabilities_test.go
Normal file
@ -0,0 +1,134 @@
|
|||||||
|
/**
|
||||||
|
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
**/
|
||||||
|
|
||||||
|
package image
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDriverCapabilitiesIntersection(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
capabilities DriverCapabilities
|
||||||
|
supportedCapabilities DriverCapabilities
|
||||||
|
expectedIntersection DriverCapabilities
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
capabilities: driverCapabilitiesNone,
|
||||||
|
supportedCapabilities: driverCapabilitiesNone,
|
||||||
|
expectedIntersection: driverCapabilitiesNone,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
capabilities: driverCapabilitiesAll,
|
||||||
|
supportedCapabilities: driverCapabilitiesNone,
|
||||||
|
expectedIntersection: driverCapabilitiesNone,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
capabilities: driverCapabilitiesAll,
|
||||||
|
supportedCapabilities: SupportedDriverCapabilities,
|
||||||
|
expectedIntersection: SupportedDriverCapabilities,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
capabilities: SupportedDriverCapabilities,
|
||||||
|
supportedCapabilities: driverCapabilitiesAll,
|
||||||
|
expectedIntersection: SupportedDriverCapabilities,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
capabilities: driverCapabilitiesNone,
|
||||||
|
supportedCapabilities: driverCapabilitiesAll,
|
||||||
|
expectedIntersection: driverCapabilitiesNone,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
capabilities: driverCapabilitiesNone,
|
||||||
|
supportedCapabilities: NewDriverCapabilities("cap1"),
|
||||||
|
expectedIntersection: driverCapabilitiesNone,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
capabilities: NewDriverCapabilities("cap0,cap1"),
|
||||||
|
supportedCapabilities: NewDriverCapabilities("cap1,cap0"),
|
||||||
|
expectedIntersection: NewDriverCapabilities("cap0,cap1"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
capabilities: DefaultDriverCapabilities,
|
||||||
|
supportedCapabilities: SupportedDriverCapabilities,
|
||||||
|
expectedIntersection: DefaultDriverCapabilities,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
capabilities: NewDriverCapabilities("compute,compat32,graphics,utility,video,display"),
|
||||||
|
supportedCapabilities: NewDriverCapabilities("compute,compat32,graphics,utility,video,display,ngx"),
|
||||||
|
expectedIntersection: NewDriverCapabilities("compute,compat32,graphics,utility,video,display"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
capabilities: NewDriverCapabilities("cap1"),
|
||||||
|
supportedCapabilities: driverCapabilitiesNone,
|
||||||
|
expectedIntersection: driverCapabilitiesNone,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
capabilities: NewDriverCapabilities("compute,compat32,graphics,utility,video,display,ngx"),
|
||||||
|
supportedCapabilities: NewDriverCapabilities("compute,compat32,graphics,utility,video,display"),
|
||||||
|
expectedIntersection: NewDriverCapabilities("compute,compat32,graphics,utility,video,display"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, tc := range testCases {
|
||||||
|
t.Run(fmt.Sprintf("test case %d", i), func(t *testing.T) {
|
||||||
|
intersection := tc.supportedCapabilities.Intersection(tc.capabilities)
|
||||||
|
require.EqualValues(t, tc.expectedIntersection, intersection)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDriverCapabilitiesList(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
capabilities DriverCapabilities
|
||||||
|
expected []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
capabilities: NewDriverCapabilities(""),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
capabilities: NewDriverCapabilities(" "),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
capabilities: NewDriverCapabilities(","),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
capabilities: NewDriverCapabilities(",cap"),
|
||||||
|
expected: []string{"cap"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
capabilities: NewDriverCapabilities("cap,"),
|
||||||
|
expected: []string{"cap"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
capabilities: NewDriverCapabilities("cap0,,cap1"),
|
||||||
|
expected: []string{"cap0", "cap1"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
capabilities: NewDriverCapabilities("cap1,cap0,cap3"),
|
||||||
|
expected: []string{"cap0", "cap1", "cap3"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, tc := range testCases {
|
||||||
|
t.Run(fmt.Sprintf("test case %d", i), func(t *testing.T) {
|
||||||
|
require.EqualValues(t, tc.expected, tc.capabilities.List())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user