diff --git a/internal/config/config.go b/internal/config/config.go index 38a65e22..fa326c27 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -21,7 +21,6 @@ import ( "fmt" "io" "os" - "path" "path/filepath" "strings" @@ -51,8 +50,6 @@ var ( NVIDIAContainerRuntimeHookExecutable = "nvidia-container-runtime-hook" // NVIDIAContainerToolkitExecutable is the executable name for the NVIDIA Container Toolkit (an alias for the NVIDIA Container Runtime Hook) NVIDIAContainerToolkitExecutable = "nvidia-container-toolkit" - - configDir = "/etc/" ) // Config represents the contents of the config.toml file for the NVIDIA Container Toolkit @@ -70,16 +67,19 @@ type Config struct { NVIDIAContainerRuntimeHookConfig RuntimeHookConfig `toml:"nvidia-container-runtime-hook"` } +// GetConfigFilePath returns the path to the config file for the configured system +func GetConfigFilePath() string { + if XDGConfigDir := os.Getenv(configOverride); len(XDGConfigDir) != 0 { + return filepath.Join(XDGConfigDir, configFilePath) + } + + return filepath.Join("/etc", configFilePath) +} + // GetConfig sets up the config struct. Values are read from a toml file // or set via the environment. func GetConfig() (*Config, error) { - if XDGConfigDir := os.Getenv(configOverride); len(XDGConfigDir) != 0 { - configDir = XDGConfigDir - } - - configFilePath := path.Join(configDir, configFilePath) - - return Load(configFilePath) + return Load(GetConfigFilePath()) } // Load loads the config from the specified file path. diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 9cb4a946..55513433 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -284,3 +284,61 @@ func TestConfigDefault(t *testing.T) { require.Subset(t, lines, expectedLines) } + +func TestFormat(t *testing.T) { + testCases := []struct { + input string + expected string + }{ + { + input: "# comment", + expected: "#comment", + }, + { + input: " #comment", + expected: "#comment", + }, + { + input: " # comment", + expected: "#comment", + }, + { + input: strings.Join([]string{ + "some", + "# comment", + " # comment", + " #comment", + "other"}, "\n"), + expected: strings.Join([]string{ + "some", + "#comment", + "#comment", + "#comment", + "other"}, "\n"), + }, + } + + for _, tc := range testCases { + t.Run(tc.input, func(t *testing.T) { + actual, _ := Config{}.format([]byte(tc.input)) + require.Equal(t, tc.expected, string(actual)) + }) + } +} + +func TestGetFormattedConfig(t *testing.T) { + expectedLines := []string{ + "#no-cgroups = false", + "#debug = \"/var/log/nvidia-container-toolkit.log\"", + "#debug = \"/var/log/nvidia-container-runtime.log\"", + } + + config, _ := GetDefault() + contents, err := config.contents() + require.NoError(t, err) + lines := strings.Split(string(contents), "\n") + + for _, line := range expectedLines { + require.Contains(t, lines, line) + } +}