From 7137f4b05ba25ecc5465c738955a29b57b5c653b Mon Sep 17 00:00:00 2001 From: Evan Lezar Date: Tue, 15 Mar 2022 17:27:34 +0200 Subject: [PATCH] Move runtime config to internal package Signed-off-by: Evan Lezar --- cmd/nvidia-container-runtime/main.go | 48 +-------- cmd/nvidia-container-runtime/main_test.go | 21 ---- .../runtime_factory.go | 3 +- .../runtime_factory_test.go | 5 +- internal/config/runtime.go | 89 +++++++++++++++++ internal/config/runtime_test.go | 98 +++++++++++++++++++ 6 files changed, 195 insertions(+), 69 deletions(-) create mode 100644 internal/config/runtime.go create mode 100644 internal/config/runtime_test.go diff --git a/cmd/nvidia-container-runtime/main.go b/cmd/nvidia-container-runtime/main.go index 6d103983..3bc6d36f 100644 --- a/cmd/nvidia-container-runtime/main.go +++ b/cmd/nvidia-container-runtime/main.go @@ -3,18 +3,8 @@ package main import ( "fmt" "os" - "path" - "github.com/pelletier/go-toml" -) - -const ( - configOverride = "XDG_CONFIG_HOME" - configFilePath = "nvidia-container-runtime/config.toml" -) - -var ( - configDir = "/etc/" + "github.com/NVIDIA/nvidia-container-toolkit/internal/config" ) var logger = NewLogger() @@ -30,12 +20,12 @@ func main() { // run is an entry point that allows for idiomatic handling of errors // when calling from the main function. func run(argv []string) (rerr error) { - cfg, err := getConfig() + cfg, err := config.GetRuntimeConfig() if err != nil { return fmt.Errorf("error loading config: %v", err) } - err = logger.LogToFile(cfg.debugFilePath) + err = logger.LogToFile(cfg.DebugFilePath) if err != nil { return fmt.Errorf("error opening debug log file: %v", err) } @@ -54,35 +44,3 @@ func run(argv []string) (rerr error) { return runtime.Exec(argv) } - -type config struct { - debugFilePath string - Experimental bool -} - -// getConfig sets up the config struct. Values are read from a toml file -// or set via the environment. -func getConfig() (*config, error) { - cfg := &config{} - - if XDGConfigDir := os.Getenv(configOverride); len(XDGConfigDir) != 0 { - configDir = XDGConfigDir - } - - configFilePath := path.Join(configDir, configFilePath) - - tomlContent, err := os.ReadFile(configFilePath) - if err != nil { - return nil, err - } - - toml, err := toml.Load(string(tomlContent)) - if err != nil { - return nil, err - } - - cfg.debugFilePath = toml.GetDefault("nvidia-container-runtime.debug", "/dev/null").(string) - cfg.Experimental = toml.GetDefault("nvidia-container-runtime.experimental", false).(bool) - - return cfg, nil -} diff --git a/cmd/nvidia-container-runtime/main_test.go b/cmd/nvidia-container-runtime/main_test.go index 0b576661..d6ba14ae 100644 --- a/cmd/nvidia-container-runtime/main_test.go +++ b/cmd/nvidia-container-runtime/main_test.go @@ -245,24 +245,3 @@ func nvidiaHookCount(hooks *specs.Hooks) int { } return count } - -func TestGetConfigWithCustomConfig(t *testing.T) { - wd, err := os.Getwd() - require.NoError(t, err) - - // By default debug is disabled - contents := []byte("[nvidia-container-runtime]\ndebug = \"/nvidia-container-toolkit.log\"") - testDir := filepath.Join(wd, "test") - filename := filepath.Join(testDir, configFilePath) - - os.Setenv(configOverride, testDir) - - require.NoError(t, os.MkdirAll(filepath.Dir(filename), 0766)) - require.NoError(t, ioutil.WriteFile(filename, contents, 0766)) - - defer func() { require.NoError(t, os.RemoveAll(testDir)) }() - - cfg, err := getConfig() - require.NoError(t, err) - require.Equal(t, cfg.debugFilePath, "/nvidia-container-toolkit.log") -} diff --git a/cmd/nvidia-container-runtime/runtime_factory.go b/cmd/nvidia-container-runtime/runtime_factory.go index 3b6dbef7..fd4c0960 100644 --- a/cmd/nvidia-container-runtime/runtime_factory.go +++ b/cmd/nvidia-container-runtime/runtime_factory.go @@ -20,6 +20,7 @@ import ( "fmt" "github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-container-runtime/modifier" + "github.com/NVIDIA/nvidia-container-toolkit/internal/config" "github.com/NVIDIA/nvidia-container-toolkit/internal/oci" "github.com/NVIDIA/nvidia-container-toolkit/internal/runtime" "github.com/sirupsen/logrus" @@ -31,7 +32,7 @@ const ( ) // newNVIDIAContainerRuntime is a factory method that constructs a runtime based on the selected configuration and specified logger -func newNVIDIAContainerRuntime(logger *logrus.Logger, cfg *config, argv []string) (oci.Runtime, error) { +func newNVIDIAContainerRuntime(logger *logrus.Logger, cfg *config.RuntimeConfig, argv []string) (oci.Runtime, error) { ociSpec, err := oci.NewSpec(logger, argv) if err != nil { return nil, fmt.Errorf("error constructing OCI specification: %v", err) diff --git a/cmd/nvidia-container-runtime/runtime_factory_test.go b/cmd/nvidia-container-runtime/runtime_factory_test.go index b29e07bf..467e5efc 100644 --- a/cmd/nvidia-container-runtime/runtime_factory_test.go +++ b/cmd/nvidia-container-runtime/runtime_factory_test.go @@ -19,6 +19,7 @@ package main import ( "testing" + "github.com/NVIDIA/nvidia-container-toolkit/internal/config" testlog "github.com/sirupsen/logrus/hooks/test" "github.com/stretchr/testify/require" ) @@ -28,13 +29,13 @@ func TestFactoryMethod(t *testing.T) { testCases := []struct { description string - config config + config config.RuntimeConfig argv []string expectedError bool }{ { description: "empty config no error", - config: config{}, + config: config.RuntimeConfig{}, }, } diff --git a/internal/config/runtime.go b/internal/config/runtime.go new file mode 100644 index 00000000..ca2d1565 --- /dev/null +++ b/internal/config/runtime.go @@ -0,0 +1,89 @@ +/** +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package config + +import ( + "fmt" + "io" + "os" + "path" + + "github.com/pelletier/go-toml" +) + +const ( + configOverride = "XDG_CONFIG_HOME" + configFilePath = "nvidia-container-runtime/config.toml" +) + +var ( + configDir = "/etc/" +) + +// RuntimeConfig stores the config options for the NVIDIA Container Runtime +type RuntimeConfig struct { + DebugFilePath string + Experimental bool +} + +// GetRuntimeConfig sets up the config struct. Values are read from a toml file +// or set via the environment. +func GetRuntimeConfig() (*RuntimeConfig, error) { + if XDGConfigDir := os.Getenv(configOverride); len(XDGConfigDir) != 0 { + configDir = XDGConfigDir + } + + configFilePath := path.Join(configDir, configFilePath) + + tomlFile, err := os.Open(configFilePath) + if err != nil { + return nil, fmt.Errorf("failed to open config file %v: %v", configFilePath, err) + } + defer tomlFile.Close() + + cfg, err := getRuntimeConfigFrom(tomlFile) + if err != nil { + return nil, fmt.Errorf("failed to read config values: %v", err) + } + + return cfg, nil +} + +// getRuntimeConfigFrom reads the config from the specified Reader +func getRuntimeConfigFrom(reader io.Reader) (*RuntimeConfig, error) { + toml, err := toml.LoadReader(reader) + if err != nil { + return nil, err + } + + cfg := getDefaultRuntimeConfig() + + cfg.DebugFilePath = toml.GetDefault("nvidia-container-runtime.debug", cfg.DebugFilePath).(string) + cfg.Experimental = toml.GetDefault("nvidia-container-runtime.experimental", cfg.Experimental).(bool) + + return cfg, nil +} + +// getDefaultRuntimeConfig defines the default values for the config +func getDefaultRuntimeConfig() *RuntimeConfig { + c := RuntimeConfig{ + DebugFilePath: "/dev/null", + Experimental: false, + } + + return &c +} diff --git a/internal/config/runtime_test.go b/internal/config/runtime_test.go new file mode 100644 index 00000000..91f1b3a5 --- /dev/null +++ b/internal/config/runtime_test.go @@ -0,0 +1,98 @@ +/** +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package config + +import ( + "io/ioutil" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestGerRuntimeConfigWithCustomConfig(t *testing.T) { + wd, err := os.Getwd() + require.NoError(t, err) + + // By default debug is disabled + contents := []byte("[nvidia-container-runtime]\ndebug = \"/nvidia-container-toolkit.log\"") + testDir := filepath.Join(wd, "test") + filename := filepath.Join(testDir, configFilePath) + + os.Setenv(configOverride, testDir) + + require.NoError(t, os.MkdirAll(filepath.Dir(filename), 0766)) + require.NoError(t, ioutil.WriteFile(filename, contents, 0766)) + + defer func() { require.NoError(t, os.RemoveAll(testDir)) }() + + cfg, err := GetRuntimeConfig() + require.NoError(t, err) + require.Equal(t, cfg.DebugFilePath, "/nvidia-container-toolkit.log") +} + +func TestGerRuntimeConfig(t *testing.T) { + testCases := []struct { + description string + contents []string + expectedError error + expectedConfig *RuntimeConfig + }{ + { + description: "empty config is default", + expectedConfig: &RuntimeConfig{ + DebugFilePath: "/dev/null", + }, + }, + { + description: "config options set inline", + contents: []string{ + "nvidia-container-runtime.debug = \"/foo/bar\"", + }, + expectedConfig: &RuntimeConfig{ + DebugFilePath: "/foo/bar", + }, + }, + { + description: "config options set in section", + contents: []string{ + "[nvidia-container-runtime]", + "debug = \"/foo/bar\"", + }, + expectedConfig: &RuntimeConfig{ + DebugFilePath: "/foo/bar", + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.description, func(t *testing.T) { + reader := strings.NewReader(strings.Join(tc.contents, "\n")) + + cfg, err := getRuntimeConfigFrom(reader) + if tc.expectedError != nil { + require.Error(t, err) + } else { + require.NoError(t, err) + } + + require.EqualValues(t, tc.expectedConfig, cfg) + }) + } +}