diff --git a/cmd/nvidia-container-runtime-hook/hook_config.go b/cmd/nvidia-container-runtime-hook/hook_config.go index ec4e0434..e06b44bf 100644 --- a/cmd/nvidia-container-runtime-hook/hook_config.go +++ b/cmd/nvidia-container-runtime-hook/hook_config.go @@ -4,7 +4,6 @@ import ( "fmt" "log" "os" - "path" "reflect" "strings" @@ -13,7 +12,6 @@ import ( ) const ( - configPath = "/etc/nvidia-container-runtime/config.toml" driverPath = "/run/nvidia/driver" ) @@ -25,29 +23,27 @@ type hookConfig struct { // loadConfig loads the required paths for the hook config. func loadConfig() (*config.Config, error) { - var configPaths []string - var required bool - if len(*configflag) != 0 { - configPaths = append(configPaths, *configflag) - required = true - } else { - configPaths = append(configPaths, path.Join(driverPath, configPath), configPath) + configFilePath, required := getConfigFilePath() + cfg, err := config.New( + config.WithConfigFile(configFilePath), + config.WithRequired(true), + ) + if err == nil { + return cfg.Config() + } else if os.IsNotExist(err) && !required { + return config.GetDefault() } + return nil, fmt.Errorf("couldn't open required configuration file: %v", err) +} - for _, p := range configPaths { - cfg, err := config.New( - config.WithConfigFile(p), - config.WithRequired(true), - ) - if err == nil { - return cfg.Config() - } else if os.IsNotExist(err) && !required { - continue - } - return nil, fmt.Errorf("couldn't open required configuration file: %v", err) +func getConfigFilePath() (string, bool) { + if configFromFlag := *configflag; configFromFlag != "" { + return configFromFlag, true } - - return config.GetDefault() + if configFromEnvvar := os.Getenv(config.FilePathOverrideEnvVar); configFromEnvvar != "" { + return configFromEnvvar, true + } + return config.GetConfigFilePath(), false } func getHookConfig() (*hookConfig, error) { diff --git a/internal/config/config.go b/internal/config/config.go index 4d4ba605..5da0aba5 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -31,6 +31,8 @@ import ( ) const ( + FilePathOverrideEnvVar = "NVIDIA_CTK_CONFIG_FILE_PATH" + configOverride = "XDG_CONFIG_HOME" configFilePath = "nvidia-container-runtime/config.toml" @@ -74,6 +76,9 @@ type Config struct { // GetConfigFilePath returns the path to the config file for the configured system func GetConfigFilePath() string { + if configFilePathOverride := os.Getenv(FilePathOverrideEnvVar); configFilePathOverride != "" { + return configFilePathOverride + } if XDGConfigDir := os.Getenv(configOverride); len(XDGConfigDir) != 0 { return filepath.Join(XDGConfigDir, configFilePath) }