Fix generation of default config

This change ensures that the nvidia-ctk config default command
generates a config file that is compatible with the official documentation
to, for example, disable cgroups in the NVIDIA Container CLI.

This requires that whitespace around comments is stripped before outputing the
contets.

This also adds an option to load a config and modify it in-place instead. This can
be triggered as a post-install step, for example.

Signed-off-by: Evan Lezar <elezar@nvidia.com>
This commit is contained in:
Evan Lezar 2023-06-30 12:47:19 +02:00
parent ba24338122
commit 65ae6f1dab
6 changed files with 217 additions and 35 deletions

View File

@ -17,12 +17,16 @@
package defaultsubcommand package defaultsubcommand
import ( import (
"bytes"
"fmt" "fmt"
"io" "io"
"os" "os"
"path/filepath"
"regexp"
nvctkConfig "github.com/NVIDIA/nvidia-container-toolkit/internal/config" "github.com/NVIDIA/nvidia-container-toolkit/internal/config"
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
"github.com/pelletier/go-toml"
"github.com/urfave/cli/v2" "github.com/urfave/cli/v2"
) )
@ -32,7 +36,9 @@ type command struct {
// options stores the subcommand options // options stores the subcommand options
type options struct { type options struct {
output string config string
output string
inPlace bool
} }
// NewCommand constructs a default command with the specified logger // NewCommand constructs a default command with the specified logger
@ -61,9 +67,20 @@ func (m command) build() *cli.Command {
} }
c.Flags = []cli.Flag{ c.Flags = []cli.Flag{
&cli.StringFlag{
Name: "config",
Usage: "Specify the config file to process; The contents of this file overrides the default config",
Destination: &opts.config,
},
&cli.BoolFlag{
Name: "in-place",
Aliases: []string{"i"},
Usage: "Modify the config file in-place",
Destination: &opts.inPlace,
},
&cli.StringFlag{ &cli.StringFlag{
Name: "output", Name: "output",
Usage: "Specify the file to output the generated configuration for to. If this is '' the configuration is ouput to STDOUT.", Usage: "Specify the output file to write to; If not specified, the output is written to stdout",
Destination: &opts.output, Destination: &opts.output,
}, },
} }
@ -72,31 +89,98 @@ func (m command) build() *cli.Command {
} }
func (m command) validateFlags(c *cli.Context, opts *options) error { func (m command) validateFlags(c *cli.Context, opts *options) error {
if opts.inPlace {
if opts.output != "" {
return fmt.Errorf("cannot specify both --in-place and --output")
}
opts.output = opts.config
}
return nil return nil
} }
func (m command) run(c *cli.Context, opts *options) error { func (m command) run(c *cli.Context, opts *options) error {
defaultConfig, err := nvctkConfig.GetDefaultConfigToml() if err := opts.ensureOutputFolder(); err != nil {
if err != nil { return fmt.Errorf("unable to create output directory: %v", err)
return fmt.Errorf("unable to get default config: %v", err)
} }
contents, err := opts.getFormattedConfig()
if err != nil {
return fmt.Errorf("unable to fix comments: %v", err)
}
if _, err := opts.Write(contents); err != nil {
return fmt.Errorf("unable to write to output: %v", err)
}
return nil
}
// getFormattedConfig returns the default config formatted as required from the specified config file.
// The config is then formatted as required.
// No indentation is used and comments are modified so that there is no space
// after the '#' character.
func (opts options) getFormattedConfig() ([]byte, error) {
cfg, err := config.Load(opts.config)
if err != nil {
return nil, fmt.Errorf("unable to load or create config: %v", err)
}
buffer := bytes.NewBuffer(nil)
enc := toml.NewEncoder(buffer).Indentation("")
if err := enc.Encode(cfg); err != nil {
return nil, fmt.Errorf("invalid config: %v", err)
}
return fixComments(buffer.Bytes())
}
func fixComments(contents []byte) ([]byte, error) {
r, err := regexp.Compile(`(\n*)\s*?#\s*(\S.*)`)
if err != nil {
return nil, fmt.Errorf("unable to compile regexp: %v", err)
}
replaced := r.ReplaceAll(contents, []byte("$1#$2"))
return replaced, nil
}
func (opts options) outputExists() (bool, error) {
if opts.output == "" {
return false, nil
}
_, err := os.Stat(opts.output)
if err == nil {
return true, nil
} else if !os.IsNotExist(err) {
return false, fmt.Errorf("unable to stat output file: %v", err)
}
return false, nil
}
func (opts options) ensureOutputFolder() error {
if opts.output == "" {
return nil
}
if dir := filepath.Dir(opts.output); dir != "" {
return os.MkdirAll(dir, 0755)
}
return nil
}
// Write writes the contents to the output file specified in the options.
func (opts options) Write(contents []byte) (int, error) {
var output io.Writer var output io.Writer
if opts.output == "" { if opts.output == "" {
output = os.Stdout output = os.Stdout
} else { } else {
outputFile, err := os.Create(opts.output) outputFile, err := os.Create(opts.output)
if err != nil { if err != nil {
return fmt.Errorf("unable to create output file: %v", err) return 0, fmt.Errorf("unable to create output file: %v", err)
} }
defer outputFile.Close() defer outputFile.Close()
output = outputFile output = outputFile
} }
_, err = defaultConfig.WriteTo(output) return output.Write(contents)
if err != nil {
return fmt.Errorf("unable to write to output: %v", err)
}
return nil
} }

View File

@ -0,0 +1,82 @@
/**
# Copyright (c) NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package defaultsubcommand
import (
"strings"
"testing"
"github.com/stretchr/testify/require"
)
func TestFixComment(t *testing.T) {
testCases := []struct {
input string
expected string
}{
{
input: "# comment",
expected: "#comment",
},
{
input: " #comment",
expected: "#comment",
},
{
input: " # comment",
expected: "#comment",
},
{
input: strings.Join([]string{
"some",
"# comment",
" # comment",
" #comment",
"other"}, "\n"),
expected: strings.Join([]string{
"some",
"#comment",
"#comment",
"#comment",
"other"}, "\n"),
},
}
for _, tc := range testCases {
t.Run(tc.input, func(t *testing.T) {
actual, _ := fixComments([]byte(tc.input))
require.Equal(t, tc.expected, string(actual))
})
}
}
func TestGetFormattedConfig(t *testing.T) {
expectedLines := []string{
"#no-cgroups = false",
"#debug = \"/var/log/nvidia-container-toolkit.log\"",
"#debug = \"/var/log/nvidia-container-runtime.log\"",
}
opts := &options{}
contents, err := opts.getFormattedConfig()
require.NoError(t, err)
lines := strings.Split(string(contents), "\n")
for _, line := range expectedLines {
require.Contains(t, lines, line)
}
}

View File

@ -74,13 +74,22 @@ func GetConfig() (*Config, error) {
configFilePath := path.Join(configDir, configFilePath) configFilePath := path.Join(configDir, configFilePath)
return Load(configFilePath)
}
// Load loads the config from the specified file path.
func Load(configFilePath string) (*Config, error) {
if configFilePath == "" {
return getDefault()
}
tomlFile, err := os.Open(configFilePath) tomlFile, err := os.Open(configFilePath)
if err != nil { if err != nil {
return getDefaultConfig() return getDefault()
} }
defer tomlFile.Close() defer tomlFile.Close()
cfg, err := loadConfigFrom(tomlFile) cfg, err := LoadFrom(tomlFile)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to read config values: %v", err) return nil, fmt.Errorf("failed to read config values: %v", err)
} }
@ -88,21 +97,28 @@ func GetConfig() (*Config, error) {
return cfg, nil return cfg, nil
} }
// loadRuntimeConfigFrom reads the config from the specified Reader // LoadFrom reads the config from the specified Reader
func loadConfigFrom(reader io.Reader) (*Config, error) { func LoadFrom(reader io.Reader) (*Config, error) {
toml, err := toml.LoadReader(reader) var tree *toml.Tree
if reader != nil {
toml, err := toml.LoadReader(reader)
if err != nil {
return nil, err
}
tree = toml
}
return getFromTree(tree)
}
// getFromTree reads the nvidia container runtime config from the specified toml Tree.
func getFromTree(toml *toml.Tree) (*Config, error) {
cfg, err := getDefault()
if err != nil { if err != nil {
return nil, err return nil, err
} }
if toml == nil {
return getConfigFrom(toml) return cfg, nil
}
// getConfigFrom reads the nvidia container runtime config from the specified toml Tree.
func getConfigFrom(toml *toml.Tree) (*Config, error) {
cfg, err := getDefaultConfig()
if err != nil {
return nil, err
} }
if err := toml.Unmarshal(cfg); err != nil { if err := toml.Unmarshal(cfg); err != nil {
@ -112,9 +128,9 @@ func getConfigFrom(toml *toml.Tree) (*Config, error) {
return cfg, nil return cfg, nil
} }
// getDefaultConfig defines the default values for the config // getDefault defines the default values for the config
func getDefaultConfig() (*Config, error) { func getDefault() (*Config, error) {
tomlConfig, err := GetDefaultConfigToml() tomlConfig, err := GetDefaultToml()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -149,8 +165,8 @@ func getDefaultConfig() (*Config, error) {
return &d, nil return &d, nil
} }
// GetDefaultConfigToml returns the default config as a toml Tree. // GetDefaultToml returns the default config as a toml Tree.
func GetDefaultConfigToml() (*toml.Tree, error) { func GetDefaultToml() (*toml.Tree, error) {
tree, err := toml.TreeFromMap(nil) tree, err := toml.TreeFromMap(nil)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -192,7 +192,7 @@ func TestGetConfig(t *testing.T) {
t.Run(tc.description, func(t *testing.T) { t.Run(tc.description, func(t *testing.T) {
reader := strings.NewReader(strings.Join(tc.contents, "\n")) reader := strings.NewReader(strings.Join(tc.contents, "\n"))
cfg, err := loadConfigFrom(reader) cfg, err := LoadFrom(reader)
if tc.expectedError != nil { if tc.expectedError != nil {
require.Error(t, err) require.Error(t, err)
} else { } else {

View File

@ -27,7 +27,7 @@ type RuntimeHookConfig struct {
// GetDefaultRuntimeHookConfig defines the default values for the config // GetDefaultRuntimeHookConfig defines the default values for the config
func GetDefaultRuntimeHookConfig() (*RuntimeHookConfig, error) { func GetDefaultRuntimeHookConfig() (*RuntimeHookConfig, error) {
cfg, err := getDefaultConfig() cfg, err := getDefault()
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -48,7 +48,7 @@ type csvModeConfig struct {
// GetDefaultRuntimeConfig defines the default values for the config // GetDefaultRuntimeConfig defines the default values for the config
func GetDefaultRuntimeConfig() (*RuntimeConfig, error) { func GetDefaultRuntimeConfig() (*RuntimeConfig, error) {
cfg, err := getDefaultConfig() cfg, err := getDefault()
if err != nil { if err != nil {
return nil, err return nil, err
} }