diff --git a/CHANGELOG.md b/CHANGELOG.md index 9fbe06d5..ec6e543f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ * Add support for `--library-search-paths` to `nvidia-ctk cdi generate` command. * Add support for injecting /dev/nvidia-nvswitch* devices if the NVIDIA_NVSWITCH=enabled envvar is specified. * Added support for `nvidia-ctk runtime configure --enable-cdi` for the `docker` runtime. Note that this requires Docker >= 25. +* Fixed bug in `nvidia-ctk config` command when using `--set`. The types of applied config options are now applied correctly. * [libnvidia-container] Fix device permission check when using cgroupv2 (fixes #227) diff --git a/cmd/nvidia-ctk/config/config.go b/cmd/nvidia-ctk/config/config.go index 17cf1e06..1cd1896e 100644 --- a/cmd/nvidia-ctk/config/config.go +++ b/cmd/nvidia-ctk/config/config.go @@ -19,6 +19,7 @@ package config import ( "errors" "fmt" + "reflect" "strconv" "strings" @@ -103,7 +104,7 @@ func run(c *cli.Context, opts *options) error { } for _, set := range opts.sets.Value() { - key, value, err := (*configToml)(cfgToml).setFlagToKeyValue(set) + key, value, err := setFlagToKeyValue(set) if err != nil { return fmt.Errorf("invalid --set option %v: %w", set, err) } @@ -126,49 +127,86 @@ func run(c *cli.Context, opts *options) error { return nil } -type configToml config.Toml - var errInvalidConfigOption = errors.New("invalid config option") +var errUndefinedField = errors.New("undefined field") var errInvalidFormat = errors.New("invalid format") // setFlagToKeyValue converts a --set flag to a key-value pair. // The set flag is of the form key[=value], with the value being optional if key refers to a // boolean config option. -func (c *configToml) setFlagToKeyValue(setFlag string) (string, interface{}, error) { - if c == nil { - return "", nil, errInvalidConfigOption - } - +func setFlagToKeyValue(setFlag string) (string, interface{}, error) { setParts := strings.SplitN(setFlag, "=", 2) key := setParts[0] - v := (*config.Toml)(c).Get(key) - if v == nil { - return key, nil, errInvalidConfigOption - } - if _, ok := v.(bool); ok { - if len(setParts) == 1 { - return key, true, nil - } + field, err := getField(key) + if err != nil { + return key, nil, fmt.Errorf("%w: %w", errInvalidConfigOption, err) } + kind := field.Kind() if len(setParts) != 2 { + if kind == reflect.Bool { + return key, true, nil + } return key, nil, fmt.Errorf("%w: expected key=value; got %v", errInvalidFormat, setFlag) } value := setParts[1] - switch vt := v.(type) { - case bool: + switch kind { + case reflect.Bool: b, err := strconv.ParseBool(value) if err != nil { return key, value, fmt.Errorf("%w: %w", errInvalidFormat, err) } return key, b, err - case string: + case reflect.String: return key, value, nil - case []string: - return key, strings.Split(value, ","), nil - default: - return key, nil, fmt.Errorf("unsupported type for %v (%v)", setParts, vt) + case reflect.Slice: + valueParts := strings.Split(value, ",") + switch field.Elem().Kind() { + case reflect.String: + return key, valueParts, nil + case reflect.Int: + var output []int64 + for _, v := range valueParts { + vi, err := strconv.ParseInt(v, 10, 0) + if err != nil { + return key, nil, fmt.Errorf("%w: %w", errInvalidFormat, err) + } + output = append(output, vi) + } + return key, output, nil + } } + return key, nil, fmt.Errorf("unsupported type for %v (%v)", setParts, kind) +} + +func getField(key string) (reflect.Type, error) { + s, err := getStruct(reflect.TypeOf(config.Config{}), strings.Split(key, ".")...) + if err != nil { + return nil, err + } + return s.Type, err +} + +func getStruct(current reflect.Type, paths ...string) (reflect.StructField, error) { + if len(paths) < 1 { + return reflect.StructField{}, fmt.Errorf("%w: no fields selected", errUndefinedField) + } + tomlField := paths[0] + for i := 0; i < current.NumField(); i++ { + f := current.Field(i) + v, ok := f.Tag.Lookup("toml") + if !ok { + continue + } + if v != tomlField { + continue + } + if len(paths) == 1 { + return f, nil + } + return getStruct(f.Type, paths[1:]...) + } + return reflect.StructField{}, fmt.Errorf("%w: %q", errUndefinedField, tomlField) } diff --git a/cmd/nvidia-ctk/config/config_test.go b/cmd/nvidia-ctk/config/config_test.go index bab1cb4d..eca474e9 100644 --- a/cmd/nvidia-ctk/config/config_test.go +++ b/cmd/nvidia-ctk/config/config_test.go @@ -19,152 +19,109 @@ package config import ( "testing" - "github.com/NVIDIA/nvidia-container-toolkit/internal/config" - "github.com/pelletier/go-toml" "github.com/stretchr/testify/require" ) func TestSetFlagToKeyValue(t *testing.T) { + // TODO: We need to enable this test again since switching to reflect. testCases := []struct { description string - config map[string]interface{} setFlag string expectedKey string expectedValue interface{} expectedError error }{ { - description: "empty config returns an error", - setFlag: "anykey=value", - expectedKey: "anykey", - expectedError: errInvalidConfigOption, - }, - { - description: "option not present returns an error", - config: map[string]interface{}{ - "defined": "defined-value", - }, + description: "option not present returns an error", setFlag: "undefined=new-value", expectedKey: "undefined", expectedError: errInvalidConfigOption, }, { - description: "boolean option assumes true", - config: map[string]interface{}{ - "boolean": false, - }, - setFlag: "boolean", - expectedKey: "boolean", + description: "undefined nexted option returns error", + setFlag: "nvidia-container-cli.undefined", + expectedKey: "nvidia-container-cli.undefined", + expectedError: errInvalidConfigOption, + }, + { + description: "boolean option assumes true", + setFlag: "disable-require", + expectedKey: "disable-require", expectedValue: true, }, { - description: "boolean option returns true", - config: map[string]interface{}{ - "boolean": false, - }, - setFlag: "boolean=true", - expectedKey: "boolean", + description: "boolean option returns true", + setFlag: "disable-require=true", + expectedKey: "disable-require", expectedValue: true, }, { - description: "boolean option returns false", - config: map[string]interface{}{ - "boolean": false, - }, - setFlag: "boolean=false", - expectedKey: "boolean", + description: "boolean option returns false", + setFlag: "disable-require=false", + expectedKey: "disable-require", expectedValue: false, }, { - description: "invalid boolean option returns error", - config: map[string]interface{}{ - "boolean": false, - }, - setFlag: "boolean=something", - expectedKey: "boolean", + description: "invalid boolean option returns error", + setFlag: "disable-require=something", + expectedKey: "disable-require", expectedValue: "something", expectedError: errInvalidFormat, }, { - description: "string option requires value", - config: map[string]interface{}{ - "string": "value", - }, - setFlag: "string", - expectedKey: "string", + description: "string option requires value", + setFlag: "swarm-resource", + expectedKey: "swarm-resource", expectedValue: nil, expectedError: errInvalidFormat, }, { - description: "string option returns value", - config: map[string]interface{}{ - "string": "value", - }, - setFlag: "string=string-value", - expectedKey: "string", + description: "string option returns value", + setFlag: "swarm-resource=string-value", + expectedKey: "swarm-resource", expectedValue: "string-value", }, { - description: "string option returns value with equals", - config: map[string]interface{}{ - "string": "value", - }, - setFlag: "string=string-value=more", - expectedKey: "string", + description: "string option returns value with equals", + setFlag: "swarm-resource=string-value=more", + expectedKey: "swarm-resource", expectedValue: "string-value=more", }, { - description: "string option treats bool value as string", - config: map[string]interface{}{ - "string": "value", - }, - setFlag: "string=true", - expectedKey: "string", + description: "string option treats bool value as string", + setFlag: "swarm-resource=true", + expectedKey: "swarm-resource", expectedValue: "true", }, { - description: "string option treats int value as string", - config: map[string]interface{}{ - "string": "value", - }, - setFlag: "string=5", - expectedKey: "string", + description: "string option treats int value as string", + setFlag: "swarm-resource=5", + expectedKey: "swarm-resource", expectedValue: "5", }, { - description: "[]string option returns single value", - config: map[string]interface{}{ - "string": []string{"value"}, - }, - setFlag: "string=string-value", - expectedKey: "string", + description: "[]string option returns single value", + setFlag: "nvidia-container-cli.environment=string-value", + expectedKey: "nvidia-container-cli.environment", expectedValue: []string{"string-value"}, }, { - description: "[]string option returns multiple values", - config: map[string]interface{}{ - "string": []string{"value"}, - }, - setFlag: "string=first,second", - expectedKey: "string", + description: "[]string option returns multiple values", + setFlag: "nvidia-container-cli.environment=first,second", + expectedKey: "nvidia-container-cli.environment", expectedValue: []string{"first", "second"}, }, { - description: "[]string option returns values with equals", - config: map[string]interface{}{ - "string": []string{"value"}, - }, - setFlag: "string=first=1,second=2", - expectedKey: "string", + description: "[]string option returns values with equals", + setFlag: "nvidia-container-cli.environment=first=1,second=2", + expectedKey: "nvidia-container-cli.environment", expectedValue: []string{"first=1", "second=2"}, }, } for _, tc := range testCases { t.Run(tc.description, func(t *testing.T) { - tree, _ := toml.TreeFromMap(tc.config) - cfgToml := (*config.Toml)(tree) - k, v, err := (*configToml)(cfgToml).setFlagToKeyValue(tc.setFlag) + k, v, err := setFlagToKeyValue(tc.setFlag) require.ErrorIs(t, err, tc.expectedError) require.EqualValues(t, tc.expectedKey, k) require.EqualValues(t, tc.expectedValue, v)