mirror of
https://github.com/NVIDIA/nvidia-container-toolkit
synced 2024-11-22 00:08:11 +00:00
Fix generation of default config
This change ensures that the nvidia-ctk config default command generates a config file that is compatible with the official documentation to, for example, disable cgroups in the NVIDIA Container CLI. This requires that whitespace around comments is stripped before outputing the contets. This also adds an option to load a config and modify it in-place instead. This can be triggered as a post-install step, for example. Signed-off-by: Evan Lezar <elezar@nvidia.com>
This commit is contained in:
parent
ba24338122
commit
65ae6f1dab
@ -17,12 +17,16 @@
|
||||
package defaultsubcommand
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
|
||||
nvctkConfig "github.com/NVIDIA/nvidia-container-toolkit/internal/config"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/config"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
|
||||
"github.com/pelletier/go-toml"
|
||||
"github.com/urfave/cli/v2"
|
||||
)
|
||||
|
||||
@ -32,7 +36,9 @@ type command struct {
|
||||
|
||||
// options stores the subcommand options
|
||||
type options struct {
|
||||
output string
|
||||
config string
|
||||
output string
|
||||
inPlace bool
|
||||
}
|
||||
|
||||
// NewCommand constructs a default command with the specified logger
|
||||
@ -61,9 +67,20 @@ func (m command) build() *cli.Command {
|
||||
}
|
||||
|
||||
c.Flags = []cli.Flag{
|
||||
&cli.StringFlag{
|
||||
Name: "config",
|
||||
Usage: "Specify the config file to process; The contents of this file overrides the default config",
|
||||
Destination: &opts.config,
|
||||
},
|
||||
&cli.BoolFlag{
|
||||
Name: "in-place",
|
||||
Aliases: []string{"i"},
|
||||
Usage: "Modify the config file in-place",
|
||||
Destination: &opts.inPlace,
|
||||
},
|
||||
&cli.StringFlag{
|
||||
Name: "output",
|
||||
Usage: "Specify the file to output the generated configuration for to. If this is '' the configuration is ouput to STDOUT.",
|
||||
Usage: "Specify the output file to write to; If not specified, the output is written to stdout",
|
||||
Destination: &opts.output,
|
||||
},
|
||||
}
|
||||
@ -72,31 +89,98 @@ func (m command) build() *cli.Command {
|
||||
}
|
||||
|
||||
func (m command) validateFlags(c *cli.Context, opts *options) error {
|
||||
if opts.inPlace {
|
||||
if opts.output != "" {
|
||||
return fmt.Errorf("cannot specify both --in-place and --output")
|
||||
}
|
||||
opts.output = opts.config
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m command) run(c *cli.Context, opts *options) error {
|
||||
defaultConfig, err := nvctkConfig.GetDefaultConfigToml()
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to get default config: %v", err)
|
||||
if err := opts.ensureOutputFolder(); err != nil {
|
||||
return fmt.Errorf("unable to create output directory: %v", err)
|
||||
}
|
||||
|
||||
contents, err := opts.getFormattedConfig()
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to fix comments: %v", err)
|
||||
}
|
||||
|
||||
if _, err := opts.Write(contents); err != nil {
|
||||
return fmt.Errorf("unable to write to output: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getFormattedConfig returns the default config formatted as required from the specified config file.
|
||||
// The config is then formatted as required.
|
||||
// No indentation is used and comments are modified so that there is no space
|
||||
// after the '#' character.
|
||||
func (opts options) getFormattedConfig() ([]byte, error) {
|
||||
cfg, err := config.Load(opts.config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to load or create config: %v", err)
|
||||
}
|
||||
|
||||
buffer := bytes.NewBuffer(nil)
|
||||
enc := toml.NewEncoder(buffer).Indentation("")
|
||||
|
||||
if err := enc.Encode(cfg); err != nil {
|
||||
return nil, fmt.Errorf("invalid config: %v", err)
|
||||
}
|
||||
|
||||
return fixComments(buffer.Bytes())
|
||||
}
|
||||
|
||||
func fixComments(contents []byte) ([]byte, error) {
|
||||
r, err := regexp.Compile(`(\n*)\s*?#\s*(\S.*)`)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to compile regexp: %v", err)
|
||||
}
|
||||
replaced := r.ReplaceAll(contents, []byte("$1#$2"))
|
||||
|
||||
return replaced, nil
|
||||
}
|
||||
|
||||
func (opts options) outputExists() (bool, error) {
|
||||
if opts.output == "" {
|
||||
return false, nil
|
||||
}
|
||||
_, err := os.Stat(opts.output)
|
||||
if err == nil {
|
||||
return true, nil
|
||||
} else if !os.IsNotExist(err) {
|
||||
return false, fmt.Errorf("unable to stat output file: %v", err)
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (opts options) ensureOutputFolder() error {
|
||||
if opts.output == "" {
|
||||
return nil
|
||||
}
|
||||
if dir := filepath.Dir(opts.output); dir != "" {
|
||||
return os.MkdirAll(dir, 0755)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Write writes the contents to the output file specified in the options.
|
||||
func (opts options) Write(contents []byte) (int, error) {
|
||||
var output io.Writer
|
||||
if opts.output == "" {
|
||||
output = os.Stdout
|
||||
} else {
|
||||
outputFile, err := os.Create(opts.output)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to create output file: %v", err)
|
||||
return 0, fmt.Errorf("unable to create output file: %v", err)
|
||||
}
|
||||
defer outputFile.Close()
|
||||
output = outputFile
|
||||
}
|
||||
|
||||
_, err = defaultConfig.WriteTo(output)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to write to output: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
return output.Write(contents)
|
||||
}
|
||||
|
82
cmd/nvidia-ctk/config/create-default/create-default_test.go
Normal file
82
cmd/nvidia-ctk/config/create-default/create-default_test.go
Normal file
@ -0,0 +1,82 @@
|
||||
/**
|
||||
# Copyright (c) NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
**/
|
||||
|
||||
package defaultsubcommand
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestFixComment(t *testing.T) {
|
||||
testCases := []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
input: "# comment",
|
||||
expected: "#comment",
|
||||
},
|
||||
{
|
||||
input: " #comment",
|
||||
expected: "#comment",
|
||||
},
|
||||
{
|
||||
input: " # comment",
|
||||
expected: "#comment",
|
||||
},
|
||||
{
|
||||
input: strings.Join([]string{
|
||||
"some",
|
||||
"# comment",
|
||||
" # comment",
|
||||
" #comment",
|
||||
"other"}, "\n"),
|
||||
expected: strings.Join([]string{
|
||||
"some",
|
||||
"#comment",
|
||||
"#comment",
|
||||
"#comment",
|
||||
"other"}, "\n"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.input, func(t *testing.T) {
|
||||
actual, _ := fixComments([]byte(tc.input))
|
||||
require.Equal(t, tc.expected, string(actual))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetFormattedConfig(t *testing.T) {
|
||||
expectedLines := []string{
|
||||
"#no-cgroups = false",
|
||||
"#debug = \"/var/log/nvidia-container-toolkit.log\"",
|
||||
"#debug = \"/var/log/nvidia-container-runtime.log\"",
|
||||
}
|
||||
|
||||
opts := &options{}
|
||||
contents, err := opts.getFormattedConfig()
|
||||
require.NoError(t, err)
|
||||
lines := strings.Split(string(contents), "\n")
|
||||
|
||||
for _, line := range expectedLines {
|
||||
require.Contains(t, lines, line)
|
||||
}
|
||||
}
|
@ -74,13 +74,22 @@ func GetConfig() (*Config, error) {
|
||||
|
||||
configFilePath := path.Join(configDir, configFilePath)
|
||||
|
||||
return Load(configFilePath)
|
||||
}
|
||||
|
||||
// Load loads the config from the specified file path.
|
||||
func Load(configFilePath string) (*Config, error) {
|
||||
if configFilePath == "" {
|
||||
return getDefault()
|
||||
}
|
||||
|
||||
tomlFile, err := os.Open(configFilePath)
|
||||
if err != nil {
|
||||
return getDefaultConfig()
|
||||
return getDefault()
|
||||
}
|
||||
defer tomlFile.Close()
|
||||
|
||||
cfg, err := loadConfigFrom(tomlFile)
|
||||
cfg, err := LoadFrom(tomlFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read config values: %v", err)
|
||||
}
|
||||
@ -88,21 +97,28 @@ func GetConfig() (*Config, error) {
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// loadRuntimeConfigFrom reads the config from the specified Reader
|
||||
func loadConfigFrom(reader io.Reader) (*Config, error) {
|
||||
toml, err := toml.LoadReader(reader)
|
||||
// LoadFrom reads the config from the specified Reader
|
||||
func LoadFrom(reader io.Reader) (*Config, error) {
|
||||
var tree *toml.Tree
|
||||
if reader != nil {
|
||||
toml, err := toml.LoadReader(reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tree = toml
|
||||
}
|
||||
|
||||
return getFromTree(tree)
|
||||
}
|
||||
|
||||
// getFromTree reads the nvidia container runtime config from the specified toml Tree.
|
||||
func getFromTree(toml *toml.Tree) (*Config, error) {
|
||||
cfg, err := getDefault()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return getConfigFrom(toml)
|
||||
}
|
||||
|
||||
// getConfigFrom reads the nvidia container runtime config from the specified toml Tree.
|
||||
func getConfigFrom(toml *toml.Tree) (*Config, error) {
|
||||
cfg, err := getDefaultConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
if toml == nil {
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
if err := toml.Unmarshal(cfg); err != nil {
|
||||
@ -112,9 +128,9 @@ func getConfigFrom(toml *toml.Tree) (*Config, error) {
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// getDefaultConfig defines the default values for the config
|
||||
func getDefaultConfig() (*Config, error) {
|
||||
tomlConfig, err := GetDefaultConfigToml()
|
||||
// getDefault defines the default values for the config
|
||||
func getDefault() (*Config, error) {
|
||||
tomlConfig, err := GetDefaultToml()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -149,8 +165,8 @@ func getDefaultConfig() (*Config, error) {
|
||||
return &d, nil
|
||||
}
|
||||
|
||||
// GetDefaultConfigToml returns the default config as a toml Tree.
|
||||
func GetDefaultConfigToml() (*toml.Tree, error) {
|
||||
// GetDefaultToml returns the default config as a toml Tree.
|
||||
func GetDefaultToml() (*toml.Tree, error) {
|
||||
tree, err := toml.TreeFromMap(nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -192,7 +192,7 @@ func TestGetConfig(t *testing.T) {
|
||||
t.Run(tc.description, func(t *testing.T) {
|
||||
reader := strings.NewReader(strings.Join(tc.contents, "\n"))
|
||||
|
||||
cfg, err := loadConfigFrom(reader)
|
||||
cfg, err := LoadFrom(reader)
|
||||
if tc.expectedError != nil {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
|
@ -27,7 +27,7 @@ type RuntimeHookConfig struct {
|
||||
|
||||
// GetDefaultRuntimeHookConfig defines the default values for the config
|
||||
func GetDefaultRuntimeHookConfig() (*RuntimeHookConfig, error) {
|
||||
cfg, err := getDefaultConfig()
|
||||
cfg, err := getDefault()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -48,7 +48,7 @@ type csvModeConfig struct {
|
||||
|
||||
// GetDefaultRuntimeConfig defines the default values for the config
|
||||
func GetDefaultRuntimeConfig() (*RuntimeConfig, error) {
|
||||
cfg, err := getDefaultConfig()
|
||||
cfg, err := getDefault()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user