diff --git a/cmd/nvidia-container-runtime-hook/hook_config.go b/cmd/nvidia-container-runtime-hook/hook_config.go index bd3535ec..d1dda572 100644 --- a/cmd/nvidia-container-runtime-hook/hook_config.go +++ b/cmd/nvidia-container-runtime-hook/hook_config.go @@ -8,7 +8,6 @@ import ( "reflect" "strings" - "github.com/BurntSushi/toml" "github.com/NVIDIA/nvidia-container-toolkit/internal/config" "github.com/NVIDIA/nvidia-container-toolkit/internal/config/image" ) @@ -18,10 +17,7 @@ const ( driverPath = "/run/nvidia/driver" ) -var defaultPaths = [...]string{ - path.Join(driverPath, configPath), - configPath, -} +var defaultPaths = [...]string{} // HookConfig : options for the nvidia-container-runtime-hook. type HookConfig config.Config @@ -35,34 +31,37 @@ func getDefaultHookConfig() (HookConfig, error) { return *(*HookConfig)(defaultCfg), nil } -func getHookConfig() (*HookConfig, error) { - var err error - var config HookConfig - - if len(*configflag) > 0 { - config, err = getDefaultHookConfig() - if err != nil { - return nil, fmt.Errorf("couldn't get default configuration: %v", err) - } - _, err = toml.DecodeFile(*configflag, &config) - if err != nil { - return nil, fmt.Errorf("couldn't open configuration file: %v", err) - } +// loadConfig loads the required paths for the hook config. +func loadConfig() (*config.Config, error) { + var configPaths []string + var required bool + if len(*configflag) != 0 { + configPaths = append(configPaths, *configflag) + required = true } else { - for _, p := range defaultPaths { - config, err = getDefaultHookConfig() - if err != nil { - return nil, fmt.Errorf("couldn't get default configuration: %v", err) - } - _, err = toml.DecodeFile(p, &config) - if err == nil { - break - } else if !os.IsNotExist(err) { - return nil, fmt.Errorf("couldn't open default configuration file: %v", err) - } - } + configPaths = append(configPaths, path.Join(driverPath, configPath), configPath) } + for _, p := range configPaths { + cfg, err := config.Load(p) + if err == nil { + return cfg, nil + } else if os.IsNotExist(err) && !required { + continue + } + return nil, fmt.Errorf("couldn't open configuration file: %v", err) + } + + return config.GetDefault() +} + +func getHookConfig() (*HookConfig, error) { + cfg, err := loadConfig() + if err != nil { + return nil, fmt.Errorf("failed to load config: %v", err) + } + config := (*HookConfig)(cfg) + allSupportedDriverCapabilities := image.SupportedDriverCapabilities if config.SupportedDriverCapabilities == "all" { config.SupportedDriverCapabilities = allSupportedDriverCapabilities.String() @@ -74,7 +73,7 @@ func getHookConfig() (*HookConfig, error) { log.Panicf("Invalid value for config option '%v'; %v (supported: %v)\n", configName, config.SupportedDriverCapabilities, allSupportedDriverCapabilities.String()) } - return &config, nil + return config, nil } // getConfigOption returns the toml config option associated with the