[no-relnote] Refactor config handling for hook

This change removes indirect calls to get the default config
from the nvidia-container-runtime-hook.

Signed-off-by: Evan Lezar <elezar@nvidia.com>
This commit is contained in:
Evan Lezar 2024-09-11 11:56:55 +02:00
parent 1afada7de5
commit c0764366d9
No known key found for this signature in database
5 changed files with 74 additions and 54 deletions

View File

@ -157,7 +157,7 @@ func getDevicesFromEnvvar(containerImage image.CUDA, swarmResourceEnvvars []stri
return containerImage.VisibleDevicesFromEnvVar() return containerImage.VisibleDevicesFromEnvVar()
} }
func getDevices(hookConfig *HookConfig, image image.CUDA, privileged bool) []string { func (hookConfig *hookConfig) getDevices(image image.CUDA, privileged bool) []string {
// If enabled, try and get the device list from volume mounts first // If enabled, try and get the device list from volume mounts first
if hookConfig.AcceptDeviceListAsVolumeMounts { if hookConfig.AcceptDeviceListAsVolumeMounts {
devices := image.VisibleDevicesFromMounts() devices := image.VisibleDevicesFromMounts()
@ -197,7 +197,7 @@ func getMigDevices(image image.CUDA, envvar string) *string {
return &devices return &devices
} }
func getImexChannels(hookConfig *HookConfig, image image.CUDA, privileged bool) []string { func (hookConfig *hookConfig) getImexChannels(image image.CUDA, privileged bool) []string {
// If enabled, try and get the device list from volume mounts first // If enabled, try and get the device list from volume mounts first
if hookConfig.AcceptDeviceListAsVolumeMounts { if hookConfig.AcceptDeviceListAsVolumeMounts {
devices := image.ImexChannelsFromMounts() devices := image.ImexChannelsFromMounts()
@ -217,10 +217,10 @@ func getImexChannels(hookConfig *HookConfig, image image.CUDA, privileged bool)
return nil return nil
} }
func (c *HookConfig) getDriverCapabilities(cudaImage image.CUDA, legacyImage bool) image.DriverCapabilities { func (hookConfig *hookConfig) getDriverCapabilities(cudaImage image.CUDA, legacyImage bool) image.DriverCapabilities {
// We use the default driver capabilities by default. This is filtered to only include the // We use the default driver capabilities by default. This is filtered to only include the
// supported capabilities // supported capabilities
supportedDriverCapabilities := image.NewDriverCapabilities(c.SupportedDriverCapabilities) supportedDriverCapabilities := image.NewDriverCapabilities(hookConfig.SupportedDriverCapabilities)
capabilities := supportedDriverCapabilities.Intersection(image.DefaultDriverCapabilities) capabilities := supportedDriverCapabilities.Intersection(image.DefaultDriverCapabilities)
@ -244,10 +244,10 @@ func (c *HookConfig) getDriverCapabilities(cudaImage image.CUDA, legacyImage boo
return capabilities return capabilities
} }
func getNvidiaConfig(hookConfig *HookConfig, image image.CUDA, privileged bool) *nvidiaConfig { func (hookConfig *hookConfig) getNvidiaConfig(image image.CUDA, privileged bool) *nvidiaConfig {
legacyImage := image.IsLegacy() legacyImage := image.IsLegacy()
devices := getDevices(hookConfig, image, privileged) devices := hookConfig.getDevices(image, privileged)
if len(devices) == 0 { if len(devices) == 0 {
// empty devices means this is not a GPU container. // empty devices means this is not a GPU container.
return nil return nil
@ -269,7 +269,7 @@ func getNvidiaConfig(hookConfig *HookConfig, image image.CUDA, privileged bool)
log.Panicln("cannot set MIG_MONITOR_DEVICES in non privileged container") log.Panicln("cannot set MIG_MONITOR_DEVICES in non privileged container")
} }
imexChannels := getImexChannels(hookConfig, image, privileged) imexChannels := hookConfig.getImexChannels(image, privileged)
driverCapabilities := hookConfig.getDriverCapabilities(image, legacyImage).String() driverCapabilities := hookConfig.getDriverCapabilities(image, legacyImage).String()
@ -288,7 +288,7 @@ func getNvidiaConfig(hookConfig *HookConfig, image image.CUDA, privileged bool)
} }
} }
func getContainerConfig(hook HookConfig) (config containerConfig) { func (hookConfig *hookConfig) getContainerConfig() (config containerConfig) {
var h HookState var h HookState
d := json.NewDecoder(os.Stdin) d := json.NewDecoder(os.Stdin)
if err := d.Decode(&h); err != nil { if err := d.Decode(&h); err != nil {
@ -305,7 +305,7 @@ func getContainerConfig(hook HookConfig) (config containerConfig) {
image, err := image.New( image, err := image.New(
image.WithEnv(s.Process.Env), image.WithEnv(s.Process.Env),
image.WithMounts(s.Mounts), image.WithMounts(s.Mounts),
image.WithDisableRequire(hook.DisableRequire), image.WithDisableRequire(hookConfig.DisableRequire),
) )
if err != nil { if err != nil {
log.Panicln(err) log.Panicln(err)
@ -316,6 +316,6 @@ func getContainerConfig(hook HookConfig) (config containerConfig) {
Pid: h.Pid, Pid: h.Pid,
Rootfs: s.Root.Path, Rootfs: s.Root.Path,
Image: image, Image: image,
Nvidia: getNvidiaConfig(&hook, image, privileged), Nvidia: hookConfig.getNvidiaConfig(image, privileged),
} }
} }

View File

@ -7,6 +7,7 @@ import (
"github.com/opencontainers/runtime-spec/specs-go" "github.com/opencontainers/runtime-spec/specs-go"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image" "github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
) )
@ -15,7 +16,7 @@ func TestGetNvidiaConfig(t *testing.T) {
description string description string
env map[string]string env map[string]string
privileged bool privileged bool
hookConfig *HookConfig hookConfig *hookConfig
expectedConfig *nvidiaConfig expectedConfig *nvidiaConfig
expectedPanic bool expectedPanic bool
}{ }{
@ -394,9 +395,11 @@ func TestGetNvidiaConfig(t *testing.T) {
image.EnvVarNvidiaDriverCapabilities: "all", image.EnvVarNvidiaDriverCapabilities: "all",
}, },
privileged: true, privileged: true,
hookConfig: &HookConfig{ hookConfig: &hookConfig{
Config: &config.Config{
SupportedDriverCapabilities: "video,display", SupportedDriverCapabilities: "video,display",
}, },
},
expectedConfig: &nvidiaConfig{ expectedConfig: &nvidiaConfig{
Devices: []string{"all"}, Devices: []string{"all"},
DriverCapabilities: "display,video", DriverCapabilities: "display,video",
@ -409,9 +412,11 @@ func TestGetNvidiaConfig(t *testing.T) {
image.EnvVarNvidiaDriverCapabilities: "video,display", image.EnvVarNvidiaDriverCapabilities: "video,display",
}, },
privileged: true, privileged: true,
hookConfig: &HookConfig{ hookConfig: &hookConfig{
Config: &config.Config{
SupportedDriverCapabilities: "video,display,compute,utility", SupportedDriverCapabilities: "video,display,compute,utility",
}, },
},
expectedConfig: &nvidiaConfig{ expectedConfig: &nvidiaConfig{
Devices: []string{"all"}, Devices: []string{"all"},
DriverCapabilities: "display,video", DriverCapabilities: "display,video",
@ -423,9 +428,11 @@ func TestGetNvidiaConfig(t *testing.T) {
image.EnvVarNvidiaVisibleDevices: "all", image.EnvVarNvidiaVisibleDevices: "all",
}, },
privileged: true, privileged: true,
hookConfig: &HookConfig{ hookConfig: &hookConfig{
Config: &config.Config{
SupportedDriverCapabilities: "video,display,utility,compute", SupportedDriverCapabilities: "video,display,utility,compute",
}, },
},
expectedConfig: &nvidiaConfig{ expectedConfig: &nvidiaConfig{
Devices: []string{"all"}, Devices: []string{"all"},
DriverCapabilities: image.DefaultDriverCapabilities.String(), DriverCapabilities: image.DefaultDriverCapabilities.String(),
@ -438,10 +445,12 @@ func TestGetNvidiaConfig(t *testing.T) {
"DOCKER_SWARM_RESOURCE": "GPU1,GPU2", "DOCKER_SWARM_RESOURCE": "GPU1,GPU2",
}, },
privileged: true, privileged: true,
hookConfig: &HookConfig{ hookConfig: &hookConfig{
Config: &config.Config{
SwarmResource: "DOCKER_SWARM_RESOURCE", SwarmResource: "DOCKER_SWARM_RESOURCE",
SupportedDriverCapabilities: "video,display,utility,compute", SupportedDriverCapabilities: "video,display,utility,compute",
}, },
},
expectedConfig: &nvidiaConfig{ expectedConfig: &nvidiaConfig{
Devices: []string{"GPU1", "GPU2"}, Devices: []string{"GPU1", "GPU2"},
DriverCapabilities: image.DefaultDriverCapabilities.String(), DriverCapabilities: image.DefaultDriverCapabilities.String(),
@ -454,10 +463,12 @@ func TestGetNvidiaConfig(t *testing.T) {
"DOCKER_SWARM_RESOURCE": "GPU1,GPU2", "DOCKER_SWARM_RESOURCE": "GPU1,GPU2",
}, },
privileged: true, privileged: true,
hookConfig: &HookConfig{ hookConfig: &hookConfig{
Config: &config.Config{
SwarmResource: "NOT_DOCKER_SWARM_RESOURCE,DOCKER_SWARM_RESOURCE", SwarmResource: "NOT_DOCKER_SWARM_RESOURCE,DOCKER_SWARM_RESOURCE",
SupportedDriverCapabilities: "video,display,utility,compute", SupportedDriverCapabilities: "video,display,utility,compute",
}, },
},
expectedConfig: &nvidiaConfig{ expectedConfig: &nvidiaConfig{
Devices: []string{"GPU1", "GPU2"}, Devices: []string{"GPU1", "GPU2"},
DriverCapabilities: image.DefaultDriverCapabilities.String(), DriverCapabilities: image.DefaultDriverCapabilities.String(),
@ -470,14 +481,14 @@ func TestGetNvidiaConfig(t *testing.T) {
image.WithEnvMap(tc.env), image.WithEnvMap(tc.env),
) )
// Wrap the call to getNvidiaConfig() in a closure. // Wrap the call to getNvidiaConfig() in a closure.
var config *nvidiaConfig var cfg *nvidiaConfig
getConfig := func() { getConfig := func() {
hookConfig := tc.hookConfig hookCfg := tc.hookConfig
if hookConfig == nil { if hookCfg == nil {
defaultConfig, _ := getDefaultHookConfig() defaultConfig, _ := config.GetDefault()
hookConfig = &defaultConfig hookCfg = &hookConfig{defaultConfig}
} }
config = getNvidiaConfig(hookConfig, image, tc.privileged) cfg = hookCfg.getNvidiaConfig(image, tc.privileged)
} }
// For any tests that are expected to panic, make sure they do. // For any tests that are expected to panic, make sure they do.
@ -491,18 +502,18 @@ func TestGetNvidiaConfig(t *testing.T) {
// And start comparing the test results to the expected results. // And start comparing the test results to the expected results.
if tc.expectedConfig == nil { if tc.expectedConfig == nil {
require.Nil(t, config, tc.description) require.Nil(t, cfg, tc.description)
return return
} }
require.NotNil(t, config, tc.description) require.NotNil(t, cfg, tc.description)
require.Equal(t, tc.expectedConfig.Devices, config.Devices) require.Equal(t, tc.expectedConfig.Devices, cfg.Devices)
require.Equal(t, tc.expectedConfig.MigConfigDevices, config.MigConfigDevices) require.Equal(t, tc.expectedConfig.MigConfigDevices, cfg.MigConfigDevices)
require.Equal(t, tc.expectedConfig.MigMonitorDevices, config.MigMonitorDevices) require.Equal(t, tc.expectedConfig.MigMonitorDevices, cfg.MigMonitorDevices)
require.Equal(t, tc.expectedConfig.DriverCapabilities, config.DriverCapabilities) require.Equal(t, tc.expectedConfig.DriverCapabilities, cfg.DriverCapabilities)
require.ElementsMatch(t, tc.expectedConfig.Requirements, config.Requirements) require.ElementsMatch(t, tc.expectedConfig.Requirements, cfg.Requirements)
}) })
} }
} }
@ -612,10 +623,11 @@ func TestDeviceListSourcePriority(t *testing.T) {
), ),
image.WithMounts(tc.mountDevices), image.WithMounts(tc.mountDevices),
) )
hookConfig, _ := getDefaultHookConfig() defaultConfig, _ := config.GetDefault()
hookConfig.AcceptEnvvarUnprivileged = tc.acceptUnprivileged cfg := &hookConfig{defaultConfig}
hookConfig.AcceptDeviceListAsVolumeMounts = tc.acceptMounts cfg.AcceptEnvvarUnprivileged = tc.acceptUnprivileged
devices = getDevices(&hookConfig, image, tc.privileged) cfg.AcceptDeviceListAsVolumeMounts = tc.acceptMounts
devices = cfg.getDevices(image, tc.privileged)
} }
// For all other tests, just grab the devices and check the results // For all other tests, just grab the devices and check the results
@ -940,8 +952,10 @@ func TestGetDriverCapabilities(t *testing.T) {
t.Run(tc.description, func(t *testing.T) { t.Run(tc.description, func(t *testing.T) {
var capabilities string var capabilities string
c := HookConfig{ c := hookConfig{
Config: &config.Config{
SupportedDriverCapabilities: tc.supportedCapabilities, SupportedDriverCapabilities: tc.supportedCapabilities,
},
} }
image, _ := image.New( image, _ := image.New(

View File

@ -17,8 +17,11 @@ const (
driverPath = "/run/nvidia/driver" driverPath = "/run/nvidia/driver"
) )
// HookConfig : options for the nvidia-container-runtime-hook. // hookConfig wraps the toolkit config.
type HookConfig config.Config // This allows for functions to be defined on the local type.
type hookConfig struct {
*config.Config
}
// loadConfig loads the required paths for the hook config. // loadConfig loads the required paths for the hook config.
func loadConfig() (*config.Config, error) { func loadConfig() (*config.Config, error) {
@ -47,12 +50,12 @@ func loadConfig() (*config.Config, error) {
return config.GetDefault() return config.GetDefault()
} }
func getHookConfig() (*HookConfig, error) { func getHookConfig() (*hookConfig, error) {
cfg, err := loadConfig() cfg, err := loadConfig()
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to load config: %v", err) return nil, fmt.Errorf("failed to load config: %v", err)
} }
config := (*HookConfig)(cfg) config := &hookConfig{cfg}
allSupportedDriverCapabilities := image.SupportedDriverCapabilities allSupportedDriverCapabilities := image.SupportedDriverCapabilities
if config.SupportedDriverCapabilities == "all" { if config.SupportedDriverCapabilities == "all" {
@ -70,7 +73,7 @@ func getHookConfig() (*HookConfig, error) {
// getConfigOption returns the toml config option associated with the // getConfigOption returns the toml config option associated with the
// specified struct field. // specified struct field.
func (c HookConfig) getConfigOption(fieldName string) string { func (c hookConfig) getConfigOption(fieldName string) string {
t := reflect.TypeOf(c) t := reflect.TypeOf(c)
f, ok := t.FieldByName(fieldName) f, ok := t.FieldByName(fieldName)
if !ok { if !ok {
@ -84,7 +87,7 @@ func (c HookConfig) getConfigOption(fieldName string) string {
} }
// getSwarmResourceEnvvars returns the swarm resource envvars for the config. // getSwarmResourceEnvvars returns the swarm resource envvars for the config.
func (c *HookConfig) getSwarmResourceEnvvars() []string { func (c *hookConfig) getSwarmResourceEnvvars() []string {
if c.SwarmResource == "" { if c.SwarmResource == "" {
return nil return nil
} }

View File

@ -23,6 +23,7 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image" "github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
) )
@ -89,10 +90,10 @@ func TestGetHookConfig(t *testing.T) {
} }
} }
var config HookConfig var cfg hookConfig
getHookConfig := func() { getHookConfig := func() {
c, _ := getHookConfig() c, _ := getHookConfig()
config = *c cfg = *c
} }
if tc.expectedPanic { if tc.expectedPanic {
@ -102,7 +103,7 @@ func TestGetHookConfig(t *testing.T) {
getHookConfig() getHookConfig()
require.EqualValues(t, tc.expectedDriverCapabilities, config.SupportedDriverCapabilities) require.EqualValues(t, tc.expectedDriverCapabilities, cfg.SupportedDriverCapabilities)
}) })
} }
} }
@ -144,8 +145,10 @@ func TestGetSwarmResourceEnvvars(t *testing.T) {
for i, tc := range testCases { for i, tc := range testCases {
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
c := &HookConfig{ c := &hookConfig{
Config: &config.Config{
SwarmResource: tc.value, SwarmResource: tc.value,
},
} }
envvars := c.getSwarmResourceEnvvars() envvars := c.getSwarmResourceEnvvars()

View File

@ -75,7 +75,7 @@ func doPrestart() {
} }
cli := hook.NVIDIAContainerCLIConfig cli := hook.NVIDIAContainerCLIConfig
container := getContainerConfig(*hook) container := hook.getContainerConfig()
nvidia := container.Nvidia nvidia := container.Nvidia
if nvidia == nil { if nvidia == nil {
// Not a GPU container, nothing to do. // Not a GPU container, nothing to do.