Add function to get config file path.

This commit is contained in:
Evan Lezar 2023-08-07 15:24:02 +02:00
parent 5216e89a70
commit c2d4de54b0
2 changed files with 68 additions and 10 deletions

View File

@ -21,7 +21,6 @@ import (
"fmt"
"io"
"os"
"path"
"path/filepath"
"strings"
@ -51,8 +50,6 @@ var (
NVIDIAContainerRuntimeHookExecutable = "nvidia-container-runtime-hook"
// NVIDIAContainerToolkitExecutable is the executable name for the NVIDIA Container Toolkit (an alias for the NVIDIA Container Runtime Hook)
NVIDIAContainerToolkitExecutable = "nvidia-container-toolkit"
configDir = "/etc/"
)
// Config represents the contents of the config.toml file for the NVIDIA Container Toolkit
@ -70,16 +67,19 @@ type Config struct {
NVIDIAContainerRuntimeHookConfig RuntimeHookConfig `toml:"nvidia-container-runtime-hook"`
}
// GetConfigFilePath returns the path to the config file for the configured system
func GetConfigFilePath() string {
if XDGConfigDir := os.Getenv(configOverride); len(XDGConfigDir) != 0 {
return filepath.Join(XDGConfigDir, configFilePath)
}
return filepath.Join("/etc", configFilePath)
}
// GetConfig sets up the config struct. Values are read from a toml file
// or set via the environment.
func GetConfig() (*Config, error) {
if XDGConfigDir := os.Getenv(configOverride); len(XDGConfigDir) != 0 {
configDir = XDGConfigDir
}
configFilePath := path.Join(configDir, configFilePath)
return Load(configFilePath)
return Load(GetConfigFilePath())
}
// Load loads the config from the specified file path.

View File

@ -284,3 +284,61 @@ func TestConfigDefault(t *testing.T) {
require.Subset(t, lines, expectedLines)
}
func TestFormat(t *testing.T) {
testCases := []struct {
input string
expected string
}{
{
input: "# comment",
expected: "#comment",
},
{
input: " #comment",
expected: "#comment",
},
{
input: " # comment",
expected: "#comment",
},
{
input: strings.Join([]string{
"some",
"# comment",
" # comment",
" #comment",
"other"}, "\n"),
expected: strings.Join([]string{
"some",
"#comment",
"#comment",
"#comment",
"other"}, "\n"),
},
}
for _, tc := range testCases {
t.Run(tc.input, func(t *testing.T) {
actual, _ := Config{}.format([]byte(tc.input))
require.Equal(t, tc.expected, string(actual))
})
}
}
func TestGetFormattedConfig(t *testing.T) {
expectedLines := []string{
"#no-cgroups = false",
"#debug = \"/var/log/nvidia-container-toolkit.log\"",
"#debug = \"/var/log/nvidia-container-runtime.log\"",
}
config, _ := GetDefault()
contents, err := config.contents()
require.NoError(t, err)
lines := strings.Split(string(contents), "\n")
for _, line := range expectedLines {
require.Contains(t, lines, line)
}
}