From c0764366d99c1296878c8559c10d3d58fde4b873 Mon Sep 17 00:00:00 2001 From: Evan Lezar Date: Wed, 11 Sep 2024 11:56:55 +0200 Subject: [PATCH] [no-relnote] Refactor config handling for hook This change removes indirect calls to get the default config from the nvidia-container-runtime-hook. Signed-off-by: Evan Lezar --- .../container_config.go | 20 ++--- .../container_config_test.go | 78 +++++++++++-------- .../hook_config.go | 15 ++-- .../hook_config_test.go | 13 ++-- cmd/nvidia-container-runtime-hook/main.go | 2 +- 5 files changed, 74 insertions(+), 54 deletions(-) diff --git a/cmd/nvidia-container-runtime-hook/container_config.go b/cmd/nvidia-container-runtime-hook/container_config.go index ac34fb93..3ae8c98f 100644 --- a/cmd/nvidia-container-runtime-hook/container_config.go +++ b/cmd/nvidia-container-runtime-hook/container_config.go @@ -157,7 +157,7 @@ func getDevicesFromEnvvar(containerImage image.CUDA, swarmResourceEnvvars []stri return containerImage.VisibleDevicesFromEnvVar() } -func getDevices(hookConfig *HookConfig, image image.CUDA, privileged bool) []string { +func (hookConfig *hookConfig) getDevices(image image.CUDA, privileged bool) []string { // If enabled, try and get the device list from volume mounts first if hookConfig.AcceptDeviceListAsVolumeMounts { devices := image.VisibleDevicesFromMounts() @@ -197,7 +197,7 @@ func getMigDevices(image image.CUDA, envvar string) *string { return &devices } -func getImexChannels(hookConfig *HookConfig, image image.CUDA, privileged bool) []string { +func (hookConfig *hookConfig) getImexChannels(image image.CUDA, privileged bool) []string { // If enabled, try and get the device list from volume mounts first if hookConfig.AcceptDeviceListAsVolumeMounts { devices := image.ImexChannelsFromMounts() @@ -217,10 +217,10 @@ func getImexChannels(hookConfig *HookConfig, image image.CUDA, privileged bool) return nil } -func (c *HookConfig) getDriverCapabilities(cudaImage image.CUDA, legacyImage bool) image.DriverCapabilities { +func (hookConfig *hookConfig) getDriverCapabilities(cudaImage image.CUDA, legacyImage bool) image.DriverCapabilities { // We use the default driver capabilities by default. This is filtered to only include the // supported capabilities - supportedDriverCapabilities := image.NewDriverCapabilities(c.SupportedDriverCapabilities) + supportedDriverCapabilities := image.NewDriverCapabilities(hookConfig.SupportedDriverCapabilities) capabilities := supportedDriverCapabilities.Intersection(image.DefaultDriverCapabilities) @@ -244,10 +244,10 @@ func (c *HookConfig) getDriverCapabilities(cudaImage image.CUDA, legacyImage boo return capabilities } -func getNvidiaConfig(hookConfig *HookConfig, image image.CUDA, privileged bool) *nvidiaConfig { +func (hookConfig *hookConfig) getNvidiaConfig(image image.CUDA, privileged bool) *nvidiaConfig { legacyImage := image.IsLegacy() - devices := getDevices(hookConfig, image, privileged) + devices := hookConfig.getDevices(image, privileged) if len(devices) == 0 { // empty devices means this is not a GPU container. return nil @@ -269,7 +269,7 @@ func getNvidiaConfig(hookConfig *HookConfig, image image.CUDA, privileged bool) log.Panicln("cannot set MIG_MONITOR_DEVICES in non privileged container") } - imexChannels := getImexChannels(hookConfig, image, privileged) + imexChannels := hookConfig.getImexChannels(image, privileged) driverCapabilities := hookConfig.getDriverCapabilities(image, legacyImage).String() @@ -288,7 +288,7 @@ func getNvidiaConfig(hookConfig *HookConfig, image image.CUDA, privileged bool) } } -func getContainerConfig(hook HookConfig) (config containerConfig) { +func (hookConfig *hookConfig) getContainerConfig() (config containerConfig) { var h HookState d := json.NewDecoder(os.Stdin) if err := d.Decode(&h); err != nil { @@ -305,7 +305,7 @@ func getContainerConfig(hook HookConfig) (config containerConfig) { image, err := image.New( image.WithEnv(s.Process.Env), image.WithMounts(s.Mounts), - image.WithDisableRequire(hook.DisableRequire), + image.WithDisableRequire(hookConfig.DisableRequire), ) if err != nil { log.Panicln(err) @@ -316,6 +316,6 @@ func getContainerConfig(hook HookConfig) (config containerConfig) { Pid: h.Pid, Rootfs: s.Root.Path, Image: image, - Nvidia: getNvidiaConfig(&hook, image, privileged), + Nvidia: hookConfig.getNvidiaConfig(image, privileged), } } diff --git a/cmd/nvidia-container-runtime-hook/container_config_test.go b/cmd/nvidia-container-runtime-hook/container_config_test.go index cad9ca6e..1f3858b1 100644 --- a/cmd/nvidia-container-runtime-hook/container_config_test.go +++ b/cmd/nvidia-container-runtime-hook/container_config_test.go @@ -7,6 +7,7 @@ import ( "github.com/opencontainers/runtime-spec/specs-go" "github.com/stretchr/testify/require" + "github.com/NVIDIA/nvidia-container-toolkit/internal/config" "github.com/NVIDIA/nvidia-container-toolkit/internal/config/image" ) @@ -15,7 +16,7 @@ func TestGetNvidiaConfig(t *testing.T) { description string env map[string]string privileged bool - hookConfig *HookConfig + hookConfig *hookConfig expectedConfig *nvidiaConfig expectedPanic bool }{ @@ -394,8 +395,10 @@ func TestGetNvidiaConfig(t *testing.T) { image.EnvVarNvidiaDriverCapabilities: "all", }, privileged: true, - hookConfig: &HookConfig{ - SupportedDriverCapabilities: "video,display", + hookConfig: &hookConfig{ + Config: &config.Config{ + SupportedDriverCapabilities: "video,display", + }, }, expectedConfig: &nvidiaConfig{ Devices: []string{"all"}, @@ -409,8 +412,10 @@ func TestGetNvidiaConfig(t *testing.T) { image.EnvVarNvidiaDriverCapabilities: "video,display", }, privileged: true, - hookConfig: &HookConfig{ - SupportedDriverCapabilities: "video,display,compute,utility", + hookConfig: &hookConfig{ + Config: &config.Config{ + SupportedDriverCapabilities: "video,display,compute,utility", + }, }, expectedConfig: &nvidiaConfig{ Devices: []string{"all"}, @@ -423,8 +428,10 @@ func TestGetNvidiaConfig(t *testing.T) { image.EnvVarNvidiaVisibleDevices: "all", }, privileged: true, - hookConfig: &HookConfig{ - SupportedDriverCapabilities: "video,display,utility,compute", + hookConfig: &hookConfig{ + Config: &config.Config{ + SupportedDriverCapabilities: "video,display,utility,compute", + }, }, expectedConfig: &nvidiaConfig{ Devices: []string{"all"}, @@ -438,9 +445,11 @@ func TestGetNvidiaConfig(t *testing.T) { "DOCKER_SWARM_RESOURCE": "GPU1,GPU2", }, privileged: true, - hookConfig: &HookConfig{ - SwarmResource: "DOCKER_SWARM_RESOURCE", - SupportedDriverCapabilities: "video,display,utility,compute", + hookConfig: &hookConfig{ + Config: &config.Config{ + SwarmResource: "DOCKER_SWARM_RESOURCE", + SupportedDriverCapabilities: "video,display,utility,compute", + }, }, expectedConfig: &nvidiaConfig{ Devices: []string{"GPU1", "GPU2"}, @@ -454,9 +463,11 @@ func TestGetNvidiaConfig(t *testing.T) { "DOCKER_SWARM_RESOURCE": "GPU1,GPU2", }, privileged: true, - hookConfig: &HookConfig{ - SwarmResource: "NOT_DOCKER_SWARM_RESOURCE,DOCKER_SWARM_RESOURCE", - SupportedDriverCapabilities: "video,display,utility,compute", + hookConfig: &hookConfig{ + Config: &config.Config{ + SwarmResource: "NOT_DOCKER_SWARM_RESOURCE,DOCKER_SWARM_RESOURCE", + SupportedDriverCapabilities: "video,display,utility,compute", + }, }, expectedConfig: &nvidiaConfig{ Devices: []string{"GPU1", "GPU2"}, @@ -470,14 +481,14 @@ func TestGetNvidiaConfig(t *testing.T) { image.WithEnvMap(tc.env), ) // Wrap the call to getNvidiaConfig() in a closure. - var config *nvidiaConfig + var cfg *nvidiaConfig getConfig := func() { - hookConfig := tc.hookConfig - if hookConfig == nil { - defaultConfig, _ := getDefaultHookConfig() - hookConfig = &defaultConfig + hookCfg := tc.hookConfig + if hookCfg == nil { + defaultConfig, _ := config.GetDefault() + hookCfg = &hookConfig{defaultConfig} } - config = getNvidiaConfig(hookConfig, image, tc.privileged) + cfg = hookCfg.getNvidiaConfig(image, tc.privileged) } // For any tests that are expected to panic, make sure they do. @@ -491,18 +502,18 @@ func TestGetNvidiaConfig(t *testing.T) { // And start comparing the test results to the expected results. if tc.expectedConfig == nil { - require.Nil(t, config, tc.description) + require.Nil(t, cfg, tc.description) return } - require.NotNil(t, config, tc.description) + require.NotNil(t, cfg, tc.description) - require.Equal(t, tc.expectedConfig.Devices, config.Devices) - require.Equal(t, tc.expectedConfig.MigConfigDevices, config.MigConfigDevices) - require.Equal(t, tc.expectedConfig.MigMonitorDevices, config.MigMonitorDevices) - require.Equal(t, tc.expectedConfig.DriverCapabilities, config.DriverCapabilities) + require.Equal(t, tc.expectedConfig.Devices, cfg.Devices) + require.Equal(t, tc.expectedConfig.MigConfigDevices, cfg.MigConfigDevices) + require.Equal(t, tc.expectedConfig.MigMonitorDevices, cfg.MigMonitorDevices) + require.Equal(t, tc.expectedConfig.DriverCapabilities, cfg.DriverCapabilities) - require.ElementsMatch(t, tc.expectedConfig.Requirements, config.Requirements) + require.ElementsMatch(t, tc.expectedConfig.Requirements, cfg.Requirements) }) } } @@ -612,10 +623,11 @@ func TestDeviceListSourcePriority(t *testing.T) { ), image.WithMounts(tc.mountDevices), ) - hookConfig, _ := getDefaultHookConfig() - hookConfig.AcceptEnvvarUnprivileged = tc.acceptUnprivileged - hookConfig.AcceptDeviceListAsVolumeMounts = tc.acceptMounts - devices = getDevices(&hookConfig, image, tc.privileged) + defaultConfig, _ := config.GetDefault() + cfg := &hookConfig{defaultConfig} + cfg.AcceptEnvvarUnprivileged = tc.acceptUnprivileged + cfg.AcceptDeviceListAsVolumeMounts = tc.acceptMounts + devices = cfg.getDevices(image, tc.privileged) } // For all other tests, just grab the devices and check the results @@ -940,8 +952,10 @@ func TestGetDriverCapabilities(t *testing.T) { t.Run(tc.description, func(t *testing.T) { var capabilities string - c := HookConfig{ - SupportedDriverCapabilities: tc.supportedCapabilities, + c := hookConfig{ + Config: &config.Config{ + SupportedDriverCapabilities: tc.supportedCapabilities, + }, } image, _ := image.New( diff --git a/cmd/nvidia-container-runtime-hook/hook_config.go b/cmd/nvidia-container-runtime-hook/hook_config.go index 4936af67..f88815e5 100644 --- a/cmd/nvidia-container-runtime-hook/hook_config.go +++ b/cmd/nvidia-container-runtime-hook/hook_config.go @@ -17,8 +17,11 @@ const ( driverPath = "/run/nvidia/driver" ) -// HookConfig : options for the nvidia-container-runtime-hook. -type HookConfig config.Config +// hookConfig wraps the toolkit config. +// This allows for functions to be defined on the local type. +type hookConfig struct { + *config.Config +} // loadConfig loads the required paths for the hook config. func loadConfig() (*config.Config, error) { @@ -47,12 +50,12 @@ func loadConfig() (*config.Config, error) { return config.GetDefault() } -func getHookConfig() (*HookConfig, error) { +func getHookConfig() (*hookConfig, error) { cfg, err := loadConfig() if err != nil { return nil, fmt.Errorf("failed to load config: %v", err) } - config := (*HookConfig)(cfg) + config := &hookConfig{cfg} allSupportedDriverCapabilities := image.SupportedDriverCapabilities if config.SupportedDriverCapabilities == "all" { @@ -70,7 +73,7 @@ func getHookConfig() (*HookConfig, error) { // getConfigOption returns the toml config option associated with the // specified struct field. -func (c HookConfig) getConfigOption(fieldName string) string { +func (c hookConfig) getConfigOption(fieldName string) string { t := reflect.TypeOf(c) f, ok := t.FieldByName(fieldName) if !ok { @@ -84,7 +87,7 @@ func (c HookConfig) getConfigOption(fieldName string) string { } // getSwarmResourceEnvvars returns the swarm resource envvars for the config. -func (c *HookConfig) getSwarmResourceEnvvars() []string { +func (c *hookConfig) getSwarmResourceEnvvars() []string { if c.SwarmResource == "" { return nil } diff --git a/cmd/nvidia-container-runtime-hook/hook_config_test.go b/cmd/nvidia-container-runtime-hook/hook_config_test.go index 7c50ec12..ade45d27 100644 --- a/cmd/nvidia-container-runtime-hook/hook_config_test.go +++ b/cmd/nvidia-container-runtime-hook/hook_config_test.go @@ -23,6 +23,7 @@ import ( "github.com/stretchr/testify/require" + "github.com/NVIDIA/nvidia-container-toolkit/internal/config" "github.com/NVIDIA/nvidia-container-toolkit/internal/config/image" ) @@ -89,10 +90,10 @@ func TestGetHookConfig(t *testing.T) { } } - var config HookConfig + var cfg hookConfig getHookConfig := func() { c, _ := getHookConfig() - config = *c + cfg = *c } if tc.expectedPanic { @@ -102,7 +103,7 @@ func TestGetHookConfig(t *testing.T) { getHookConfig() - require.EqualValues(t, tc.expectedDriverCapabilities, config.SupportedDriverCapabilities) + require.EqualValues(t, tc.expectedDriverCapabilities, cfg.SupportedDriverCapabilities) }) } } @@ -144,8 +145,10 @@ func TestGetSwarmResourceEnvvars(t *testing.T) { for i, tc := range testCases { t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { - c := &HookConfig{ - SwarmResource: tc.value, + c := &hookConfig{ + Config: &config.Config{ + SwarmResource: tc.value, + }, } envvars := c.getSwarmResourceEnvvars() diff --git a/cmd/nvidia-container-runtime-hook/main.go b/cmd/nvidia-container-runtime-hook/main.go index a0cfe3ec..cf2322ef 100644 --- a/cmd/nvidia-container-runtime-hook/main.go +++ b/cmd/nvidia-container-runtime-hook/main.go @@ -75,7 +75,7 @@ func doPrestart() { } cli := hook.NVIDIAContainerCLIConfig - container := getContainerConfig(*hook) + container := hook.getContainerConfig() nvidia := container.Nvidia if nvidia == nil { // Not a GPU container, nothing to do.