diff --git a/cmd/nvidia-ctk/config/create-default/create-default.go b/cmd/nvidia-ctk/config/create-default/create-default.go index db89345a..a14fda05 100644 --- a/cmd/nvidia-ctk/config/create-default/create-default.go +++ b/cmd/nvidia-ctk/config/create-default/create-default.go @@ -17,12 +17,16 @@ package defaultsubcommand import ( + "bytes" "fmt" "io" "os" + "path/filepath" + "regexp" - nvctkConfig "github.com/NVIDIA/nvidia-container-toolkit/internal/config" + "github.com/NVIDIA/nvidia-container-toolkit/internal/config" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" + "github.com/pelletier/go-toml" "github.com/urfave/cli/v2" ) @@ -32,7 +36,9 @@ type command struct { // options stores the subcommand options type options struct { - output string + config string + output string + inPlace bool } // NewCommand constructs a default command with the specified logger @@ -61,9 +67,20 @@ func (m command) build() *cli.Command { } c.Flags = []cli.Flag{ + &cli.StringFlag{ + Name: "config", + Usage: "Specify the config file to process; The contents of this file overrides the default config", + Destination: &opts.config, + }, + &cli.BoolFlag{ + Name: "in-place", + Aliases: []string{"i"}, + Usage: "Modify the config file in-place", + Destination: &opts.inPlace, + }, &cli.StringFlag{ Name: "output", - Usage: "Specify the file to output the generated configuration for to. If this is '' the configuration is ouput to STDOUT.", + Usage: "Specify the output file to write to; If not specified, the output is written to stdout", Destination: &opts.output, }, } @@ -72,31 +89,98 @@ func (m command) build() *cli.Command { } func (m command) validateFlags(c *cli.Context, opts *options) error { + if opts.inPlace { + if opts.output != "" { + return fmt.Errorf("cannot specify both --in-place and --output") + } + opts.output = opts.config + } return nil } func (m command) run(c *cli.Context, opts *options) error { - defaultConfig, err := nvctkConfig.GetDefaultConfigToml() - if err != nil { - return fmt.Errorf("unable to get default config: %v", err) + if err := opts.ensureOutputFolder(); err != nil { + return fmt.Errorf("unable to create output directory: %v", err) } + contents, err := opts.getFormattedConfig() + if err != nil { + return fmt.Errorf("unable to fix comments: %v", err) + } + + if _, err := opts.Write(contents); err != nil { + return fmt.Errorf("unable to write to output: %v", err) + } + + return nil +} + +// getFormattedConfig returns the default config formatted as required from the specified config file. +// The config is then formatted as required. +// No indentation is used and comments are modified so that there is no space +// after the '#' character. +func (opts options) getFormattedConfig() ([]byte, error) { + cfg, err := config.Load(opts.config) + if err != nil { + return nil, fmt.Errorf("unable to load or create config: %v", err) + } + + buffer := bytes.NewBuffer(nil) + enc := toml.NewEncoder(buffer).Indentation("") + + if err := enc.Encode(cfg); err != nil { + return nil, fmt.Errorf("invalid config: %v", err) + } + + return fixComments(buffer.Bytes()) +} + +func fixComments(contents []byte) ([]byte, error) { + r, err := regexp.Compile(`(\n*)\s*?#\s*(\S.*)`) + if err != nil { + return nil, fmt.Errorf("unable to compile regexp: %v", err) + } + replaced := r.ReplaceAll(contents, []byte("$1#$2")) + + return replaced, nil +} + +func (opts options) outputExists() (bool, error) { + if opts.output == "" { + return false, nil + } + _, err := os.Stat(opts.output) + if err == nil { + return true, nil + } else if !os.IsNotExist(err) { + return false, fmt.Errorf("unable to stat output file: %v", err) + } + return false, nil +} + +func (opts options) ensureOutputFolder() error { + if opts.output == "" { + return nil + } + if dir := filepath.Dir(opts.output); dir != "" { + return os.MkdirAll(dir, 0755) + } + return nil +} + +// Write writes the contents to the output file specified in the options. +func (opts options) Write(contents []byte) (int, error) { var output io.Writer if opts.output == "" { output = os.Stdout } else { outputFile, err := os.Create(opts.output) if err != nil { - return fmt.Errorf("unable to create output file: %v", err) + return 0, fmt.Errorf("unable to create output file: %v", err) } defer outputFile.Close() output = outputFile } - _, err = defaultConfig.WriteTo(output) - if err != nil { - return fmt.Errorf("unable to write to output: %v", err) - } - - return nil + return output.Write(contents) } diff --git a/cmd/nvidia-ctk/config/create-default/create-default_test.go b/cmd/nvidia-ctk/config/create-default/create-default_test.go new file mode 100644 index 00000000..65c940ca --- /dev/null +++ b/cmd/nvidia-ctk/config/create-default/create-default_test.go @@ -0,0 +1,82 @@ +/** +# Copyright (c) NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package defaultsubcommand + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestFixComment(t *testing.T) { + testCases := []struct { + input string + expected string + }{ + { + input: "# comment", + expected: "#comment", + }, + { + input: " #comment", + expected: "#comment", + }, + { + input: " # comment", + expected: "#comment", + }, + { + input: strings.Join([]string{ + "some", + "# comment", + " # comment", + " #comment", + "other"}, "\n"), + expected: strings.Join([]string{ + "some", + "#comment", + "#comment", + "#comment", + "other"}, "\n"), + }, + } + + for _, tc := range testCases { + t.Run(tc.input, func(t *testing.T) { + actual, _ := fixComments([]byte(tc.input)) + require.Equal(t, tc.expected, string(actual)) + }) + } +} + +func TestGetFormattedConfig(t *testing.T) { + expectedLines := []string{ + "#no-cgroups = false", + "#debug = \"/var/log/nvidia-container-toolkit.log\"", + "#debug = \"/var/log/nvidia-container-runtime.log\"", + } + + opts := &options{} + contents, err := opts.getFormattedConfig() + require.NoError(t, err) + lines := strings.Split(string(contents), "\n") + + for _, line := range expectedLines { + require.Contains(t, lines, line) + } +} diff --git a/internal/config/config.go b/internal/config/config.go index c25aa19d..4381c188 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -74,13 +74,22 @@ func GetConfig() (*Config, error) { configFilePath := path.Join(configDir, configFilePath) + return Load(configFilePath) +} + +// Load loads the config from the specified file path. +func Load(configFilePath string) (*Config, error) { + if configFilePath == "" { + return getDefault() + } + tomlFile, err := os.Open(configFilePath) if err != nil { - return getDefaultConfig() + return getDefault() } defer tomlFile.Close() - cfg, err := loadConfigFrom(tomlFile) + cfg, err := LoadFrom(tomlFile) if err != nil { return nil, fmt.Errorf("failed to read config values: %v", err) } @@ -88,21 +97,28 @@ func GetConfig() (*Config, error) { return cfg, nil } -// loadRuntimeConfigFrom reads the config from the specified Reader -func loadConfigFrom(reader io.Reader) (*Config, error) { - toml, err := toml.LoadReader(reader) +// LoadFrom reads the config from the specified Reader +func LoadFrom(reader io.Reader) (*Config, error) { + var tree *toml.Tree + if reader != nil { + toml, err := toml.LoadReader(reader) + if err != nil { + return nil, err + } + tree = toml + } + + return getFromTree(tree) +} + +// getFromTree reads the nvidia container runtime config from the specified toml Tree. +func getFromTree(toml *toml.Tree) (*Config, error) { + cfg, err := getDefault() if err != nil { return nil, err } - - return getConfigFrom(toml) -} - -// getConfigFrom reads the nvidia container runtime config from the specified toml Tree. -func getConfigFrom(toml *toml.Tree) (*Config, error) { - cfg, err := getDefaultConfig() - if err != nil { - return nil, err + if toml == nil { + return cfg, nil } if err := toml.Unmarshal(cfg); err != nil { @@ -112,9 +128,9 @@ func getConfigFrom(toml *toml.Tree) (*Config, error) { return cfg, nil } -// getDefaultConfig defines the default values for the config -func getDefaultConfig() (*Config, error) { - tomlConfig, err := GetDefaultConfigToml() +// getDefault defines the default values for the config +func getDefault() (*Config, error) { + tomlConfig, err := GetDefaultToml() if err != nil { return nil, err } @@ -149,8 +165,8 @@ func getDefaultConfig() (*Config, error) { return &d, nil } -// GetDefaultConfigToml returns the default config as a toml Tree. -func GetDefaultConfigToml() (*toml.Tree, error) { +// GetDefaultToml returns the default config as a toml Tree. +func GetDefaultToml() (*toml.Tree, error) { tree, err := toml.TreeFromMap(nil) if err != nil { return nil, err diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 867edf79..f69f2736 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -192,7 +192,7 @@ func TestGetConfig(t *testing.T) { t.Run(tc.description, func(t *testing.T) { reader := strings.NewReader(strings.Join(tc.contents, "\n")) - cfg, err := loadConfigFrom(reader) + cfg, err := LoadFrom(reader) if tc.expectedError != nil { require.Error(t, err) } else { diff --git a/internal/config/hook.go b/internal/config/hook.go index 1222a4bb..99c52061 100644 --- a/internal/config/hook.go +++ b/internal/config/hook.go @@ -27,7 +27,7 @@ type RuntimeHookConfig struct { // GetDefaultRuntimeHookConfig defines the default values for the config func GetDefaultRuntimeHookConfig() (*RuntimeHookConfig, error) { - cfg, err := getDefaultConfig() + cfg, err := getDefault() if err != nil { return nil, err } diff --git a/internal/config/runtime.go b/internal/config/runtime.go index 4dc89e2d..ba9fc83c 100644 --- a/internal/config/runtime.go +++ b/internal/config/runtime.go @@ -48,7 +48,7 @@ type csvModeConfig struct { // GetDefaultRuntimeConfig defines the default values for the config func GetDefaultRuntimeConfig() (*RuntimeConfig, error) { - cfg, err := getDefaultConfig() + cfg, err := getDefault() if err != nil { return nil, err }