Refactor loading of hook configs

Signed-off-by: Evan Lezar <elezar@nvidia.com>
This commit is contained in:
Evan Lezar 2023-08-10 15:11:01 +02:00
parent b18ac09f77
commit 3670e7b89e

View File

@ -8,7 +8,6 @@ import (
"reflect" "reflect"
"strings" "strings"
"github.com/BurntSushi/toml"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config" "github.com/NVIDIA/nvidia-container-toolkit/internal/config"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image" "github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
) )
@ -18,10 +17,7 @@ const (
driverPath = "/run/nvidia/driver" driverPath = "/run/nvidia/driver"
) )
var defaultPaths = [...]string{ var defaultPaths = [...]string{}
path.Join(driverPath, configPath),
configPath,
}
// HookConfig : options for the nvidia-container-runtime-hook. // HookConfig : options for the nvidia-container-runtime-hook.
type HookConfig config.Config type HookConfig config.Config
@ -35,34 +31,37 @@ func getDefaultHookConfig() (HookConfig, error) {
return *(*HookConfig)(defaultCfg), nil return *(*HookConfig)(defaultCfg), nil
} }
func getHookConfig() (*HookConfig, error) { // loadConfig loads the required paths for the hook config.
var err error func loadConfig() (*config.Config, error) {
var config HookConfig var configPaths []string
var required bool
if len(*configflag) > 0 { if len(*configflag) != 0 {
config, err = getDefaultHookConfig() configPaths = append(configPaths, *configflag)
if err != nil { required = true
return nil, fmt.Errorf("couldn't get default configuration: %v", err)
}
_, err = toml.DecodeFile(*configflag, &config)
if err != nil {
return nil, fmt.Errorf("couldn't open configuration file: %v", err)
}
} else { } else {
for _, p := range defaultPaths { configPaths = append(configPaths, path.Join(driverPath, configPath), configPath)
config, err = getDefaultHookConfig()
if err != nil {
return nil, fmt.Errorf("couldn't get default configuration: %v", err)
}
_, err = toml.DecodeFile(p, &config)
if err == nil {
break
} else if !os.IsNotExist(err) {
return nil, fmt.Errorf("couldn't open default configuration file: %v", err)
}
}
} }
for _, p := range configPaths {
cfg, err := config.Load(p)
if err == nil {
return cfg, nil
} else if os.IsNotExist(err) && !required {
continue
}
return nil, fmt.Errorf("couldn't open configuration file: %v", err)
}
return config.GetDefault()
}
func getHookConfig() (*HookConfig, error) {
cfg, err := loadConfig()
if err != nil {
return nil, fmt.Errorf("failed to load config: %v", err)
}
config := (*HookConfig)(cfg)
allSupportedDriverCapabilities := image.SupportedDriverCapabilities allSupportedDriverCapabilities := image.SupportedDriverCapabilities
if config.SupportedDriverCapabilities == "all" { if config.SupportedDriverCapabilities == "all" {
config.SupportedDriverCapabilities = allSupportedDriverCapabilities.String() config.SupportedDriverCapabilities = allSupportedDriverCapabilities.String()
@ -74,7 +73,7 @@ func getHookConfig() (*HookConfig, error) {
log.Panicf("Invalid value for config option '%v'; %v (supported: %v)\n", configName, config.SupportedDriverCapabilities, allSupportedDriverCapabilities.String()) log.Panicf("Invalid value for config option '%v'; %v (supported: %v)\n", configName, config.SupportedDriverCapabilities, allSupportedDriverCapabilities.String())
} }
return &config, nil return config, nil
} }
// getConfigOption returns the toml config option associated with the // getConfigOption returns the toml config option associated with the