From 65ae6f1dabcf6c3c15e664139cfb93ebe08b6ae0 Mon Sep 17 00:00:00 2001 From: Evan Lezar Date: Fri, 30 Jun 2023 12:47:19 +0200 Subject: [PATCH] Fix generation of default config This change ensures that the nvidia-ctk config default command generates a config file that is compatible with the official documentation to, for example, disable cgroups in the NVIDIA Container CLI. This requires that whitespace around comments is stripped before outputing the contets. This also adds an option to load a config and modify it in-place instead. This can be triggered as a post-install step, for example. Signed-off-by: Evan Lezar --- .../config/create-default/create-default.go | 110 +++++++++++++++--- .../create-default/create-default_test.go | 82 +++++++++++++ internal/config/config.go | 54 ++++++--- internal/config/config_test.go | 2 +- internal/config/hook.go | 2 +- internal/config/runtime.go | 2 +- 6 files changed, 217 insertions(+), 35 deletions(-) create mode 100644 cmd/nvidia-ctk/config/create-default/create-default_test.go diff --git a/cmd/nvidia-ctk/config/create-default/create-default.go b/cmd/nvidia-ctk/config/create-default/create-default.go index db89345a..a14fda05 100644 --- a/cmd/nvidia-ctk/config/create-default/create-default.go +++ b/cmd/nvidia-ctk/config/create-default/create-default.go @@ -17,12 +17,16 @@ package defaultsubcommand import ( + "bytes" "fmt" "io" "os" + "path/filepath" + "regexp" - nvctkConfig "github.com/NVIDIA/nvidia-container-toolkit/internal/config" + "github.com/NVIDIA/nvidia-container-toolkit/internal/config" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" + "github.com/pelletier/go-toml" "github.com/urfave/cli/v2" ) @@ -32,7 +36,9 @@ type command struct { // options stores the subcommand options type options struct { - output string + config string + output string + inPlace bool } // NewCommand constructs a default command with the specified logger @@ -61,9 +67,20 @@ func (m command) build() *cli.Command { } c.Flags = []cli.Flag{ + &cli.StringFlag{ + Name: "config", + Usage: "Specify the config file to process; The contents of this file overrides the default config", + Destination: &opts.config, + }, + &cli.BoolFlag{ + Name: "in-place", + Aliases: []string{"i"}, + Usage: "Modify the config file in-place", + Destination: &opts.inPlace, + }, &cli.StringFlag{ Name: "output", - Usage: "Specify the file to output the generated configuration for to. If this is '' the configuration is ouput to STDOUT.", + Usage: "Specify the output file to write to; If not specified, the output is written to stdout", Destination: &opts.output, }, } @@ -72,31 +89,98 @@ func (m command) build() *cli.Command { } func (m command) validateFlags(c *cli.Context, opts *options) error { + if opts.inPlace { + if opts.output != "" { + return fmt.Errorf("cannot specify both --in-place and --output") + } + opts.output = opts.config + } return nil } func (m command) run(c *cli.Context, opts *options) error { - defaultConfig, err := nvctkConfig.GetDefaultConfigToml() - if err != nil { - return fmt.Errorf("unable to get default config: %v", err) + if err := opts.ensureOutputFolder(); err != nil { + return fmt.Errorf("unable to create output directory: %v", err) } + contents, err := opts.getFormattedConfig() + if err != nil { + return fmt.Errorf("unable to fix comments: %v", err) + } + + if _, err := opts.Write(contents); err != nil { + return fmt.Errorf("unable to write to output: %v", err) + } + + return nil +} + +// getFormattedConfig returns the default config formatted as required from the specified config file. +// The config is then formatted as required. +// No indentation is used and comments are modified so that there is no space +// after the '#' character. +func (opts options) getFormattedConfig() ([]byte, error) { + cfg, err := config.Load(opts.config) + if err != nil { + return nil, fmt.Errorf("unable to load or create config: %v", err) + } + + buffer := bytes.NewBuffer(nil) + enc := toml.NewEncoder(buffer).Indentation("") + + if err := enc.Encode(cfg); err != nil { + return nil, fmt.Errorf("invalid config: %v", err) + } + + return fixComments(buffer.Bytes()) +} + +func fixComments(contents []byte) ([]byte, error) { + r, err := regexp.Compile(`(\n*)\s*?#\s*(\S.*)`) + if err != nil { + return nil, fmt.Errorf("unable to compile regexp: %v", err) + } + replaced := r.ReplaceAll(contents, []byte("$1#$2")) + + return replaced, nil +} + +func (opts options) outputExists() (bool, error) { + if opts.output == "" { + return false, nil + } + _, err := os.Stat(opts.output) + if err == nil { + return true, nil + } else if !os.IsNotExist(err) { + return false, fmt.Errorf("unable to stat output file: %v", err) + } + return false, nil +} + +func (opts options) ensureOutputFolder() error { + if opts.output == "" { + return nil + } + if dir := filepath.Dir(opts.output); dir != "" { + return os.MkdirAll(dir, 0755) + } + return nil +} + +// Write writes the contents to the output file specified in the options. +func (opts options) Write(contents []byte) (int, error) { var output io.Writer if opts.output == "" { output = os.Stdout } else { outputFile, err := os.Create(opts.output) if err != nil { - return fmt.Errorf("unable to create output file: %v", err) + return 0, fmt.Errorf("unable to create output file: %v", err) } defer outputFile.Close() output = outputFile } - _, err = defaultConfig.WriteTo(output) - if err != nil { - return fmt.Errorf("unable to write to output: %v", err) - } - - return nil + return output.Write(contents) } diff --git a/cmd/nvidia-ctk/config/create-default/create-default_test.go b/cmd/nvidia-ctk/config/create-default/create-default_test.go new file mode 100644 index 00000000..65c940ca --- /dev/null +++ b/cmd/nvidia-ctk/config/create-default/create-default_test.go @@ -0,0 +1,82 @@ +/** +# Copyright (c) NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package defaultsubcommand + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestFixComment(t *testing.T) { + testCases := []struct { + input string + expected string + }{ + { + input: "# comment", + expected: "#comment", + }, + { + input: " #comment", + expected: "#comment", + }, + { + input: " # comment", + expected: "#comment", + }, + { + input: strings.Join([]string{ + "some", + "# comment", + " # comment", + " #comment", + "other"}, "\n"), + expected: strings.Join([]string{ + "some", + "#comment", + "#comment", + "#comment", + "other"}, "\n"), + }, + } + + for _, tc := range testCases { + t.Run(tc.input, func(t *testing.T) { + actual, _ := fixComments([]byte(tc.input)) + require.Equal(t, tc.expected, string(actual)) + }) + } +} + +func TestGetFormattedConfig(t *testing.T) { + expectedLines := []string{ + "#no-cgroups = false", + "#debug = \"/var/log/nvidia-container-toolkit.log\"", + "#debug = \"/var/log/nvidia-container-runtime.log\"", + } + + opts := &options{} + contents, err := opts.getFormattedConfig() + require.NoError(t, err) + lines := strings.Split(string(contents), "\n") + + for _, line := range expectedLines { + require.Contains(t, lines, line) + } +} diff --git a/internal/config/config.go b/internal/config/config.go index c25aa19d..4381c188 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -74,13 +74,22 @@ func GetConfig() (*Config, error) { configFilePath := path.Join(configDir, configFilePath) + return Load(configFilePath) +} + +// Load loads the config from the specified file path. +func Load(configFilePath string) (*Config, error) { + if configFilePath == "" { + return getDefault() + } + tomlFile, err := os.Open(configFilePath) if err != nil { - return getDefaultConfig() + return getDefault() } defer tomlFile.Close() - cfg, err := loadConfigFrom(tomlFile) + cfg, err := LoadFrom(tomlFile) if err != nil { return nil, fmt.Errorf("failed to read config values: %v", err) } @@ -88,21 +97,28 @@ func GetConfig() (*Config, error) { return cfg, nil } -// loadRuntimeConfigFrom reads the config from the specified Reader -func loadConfigFrom(reader io.Reader) (*Config, error) { - toml, err := toml.LoadReader(reader) +// LoadFrom reads the config from the specified Reader +func LoadFrom(reader io.Reader) (*Config, error) { + var tree *toml.Tree + if reader != nil { + toml, err := toml.LoadReader(reader) + if err != nil { + return nil, err + } + tree = toml + } + + return getFromTree(tree) +} + +// getFromTree reads the nvidia container runtime config from the specified toml Tree. +func getFromTree(toml *toml.Tree) (*Config, error) { + cfg, err := getDefault() if err != nil { return nil, err } - - return getConfigFrom(toml) -} - -// getConfigFrom reads the nvidia container runtime config from the specified toml Tree. -func getConfigFrom(toml *toml.Tree) (*Config, error) { - cfg, err := getDefaultConfig() - if err != nil { - return nil, err + if toml == nil { + return cfg, nil } if err := toml.Unmarshal(cfg); err != nil { @@ -112,9 +128,9 @@ func getConfigFrom(toml *toml.Tree) (*Config, error) { return cfg, nil } -// getDefaultConfig defines the default values for the config -func getDefaultConfig() (*Config, error) { - tomlConfig, err := GetDefaultConfigToml() +// getDefault defines the default values for the config +func getDefault() (*Config, error) { + tomlConfig, err := GetDefaultToml() if err != nil { return nil, err } @@ -149,8 +165,8 @@ func getDefaultConfig() (*Config, error) { return &d, nil } -// GetDefaultConfigToml returns the default config as a toml Tree. -func GetDefaultConfigToml() (*toml.Tree, error) { +// GetDefaultToml returns the default config as a toml Tree. +func GetDefaultToml() (*toml.Tree, error) { tree, err := toml.TreeFromMap(nil) if err != nil { return nil, err diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 867edf79..f69f2736 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -192,7 +192,7 @@ func TestGetConfig(t *testing.T) { t.Run(tc.description, func(t *testing.T) { reader := strings.NewReader(strings.Join(tc.contents, "\n")) - cfg, err := loadConfigFrom(reader) + cfg, err := LoadFrom(reader) if tc.expectedError != nil { require.Error(t, err) } else { diff --git a/internal/config/hook.go b/internal/config/hook.go index 1222a4bb..99c52061 100644 --- a/internal/config/hook.go +++ b/internal/config/hook.go @@ -27,7 +27,7 @@ type RuntimeHookConfig struct { // GetDefaultRuntimeHookConfig defines the default values for the config func GetDefaultRuntimeHookConfig() (*RuntimeHookConfig, error) { - cfg, err := getDefaultConfig() + cfg, err := getDefault() if err != nil { return nil, err } diff --git a/internal/config/runtime.go b/internal/config/runtime.go index 4dc89e2d..ba9fc83c 100644 --- a/internal/config/runtime.go +++ b/internal/config/runtime.go @@ -48,7 +48,7 @@ type csvModeConfig struct { // GetDefaultRuntimeConfig defines the default values for the config func GetDefaultRuntimeConfig() (*RuntimeConfig, error) { - cfg, err := getDefaultConfig() + cfg, err := getDefault() if err != nil { return nil, err }