From 671d787a421c241987459c67c8a4b4da0a3458db Mon Sep 17 00:00:00 2001 From: Evan Lezar Date: Wed, 22 Nov 2023 12:36:39 +0100 Subject: [PATCH] Switch to reflect package for config updates This change switches to using the reflect package to determine the type of config options instead of inferring the type from the Toml data structure. Signed-off-by: Evan Lezar --- CHANGELOG.md | 1 + cmd/nvidia-ctk/config/config.go | 84 ++++++++++++----- cmd/nvidia-ctk/config/config_test.go | 133 +++++++++------------------ 3 files changed, 107 insertions(+), 111 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9fbe06d5..ec6e543f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ * Add support for `--library-search-paths` to `nvidia-ctk cdi generate` command. * Add support for injecting /dev/nvidia-nvswitch* devices if the NVIDIA_NVSWITCH=enabled envvar is specified. * Added support for `nvidia-ctk runtime configure --enable-cdi` for the `docker` runtime. Note that this requires Docker >= 25. +* Fixed bug in `nvidia-ctk config` command when using `--set`. The types of applied config options are now applied correctly. * [libnvidia-container] Fix device permission check when using cgroupv2 (fixes #227) diff --git a/cmd/nvidia-ctk/config/config.go b/cmd/nvidia-ctk/config/config.go index 17cf1e06..1cd1896e 100644 --- a/cmd/nvidia-ctk/config/config.go +++ b/cmd/nvidia-ctk/config/config.go @@ -19,6 +19,7 @@ package config import ( "errors" "fmt" + "reflect" "strconv" "strings" @@ -103,7 +104,7 @@ func run(c *cli.Context, opts *options) error { } for _, set := range opts.sets.Value() { - key, value, err := (*configToml)(cfgToml).setFlagToKeyValue(set) + key, value, err := setFlagToKeyValue(set) if err != nil { return fmt.Errorf("invalid --set option %v: %w", set, err) } @@ -126,49 +127,86 @@ func run(c *cli.Context, opts *options) error { return nil } -type configToml config.Toml - var errInvalidConfigOption = errors.New("invalid config option") +var errUndefinedField = errors.New("undefined field") var errInvalidFormat = errors.New("invalid format") // setFlagToKeyValue converts a --set flag to a key-value pair. // The set flag is of the form key[=value], with the value being optional if key refers to a // boolean config option. -func (c *configToml) setFlagToKeyValue(setFlag string) (string, interface{}, error) { - if c == nil { - return "", nil, errInvalidConfigOption - } - +func setFlagToKeyValue(setFlag string) (string, interface{}, error) { setParts := strings.SplitN(setFlag, "=", 2) key := setParts[0] - v := (*config.Toml)(c).Get(key) - if v == nil { - return key, nil, errInvalidConfigOption - } - if _, ok := v.(bool); ok { - if len(setParts) == 1 { - return key, true, nil - } + field, err := getField(key) + if err != nil { + return key, nil, fmt.Errorf("%w: %w", errInvalidConfigOption, err) } + kind := field.Kind() if len(setParts) != 2 { + if kind == reflect.Bool { + return key, true, nil + } return key, nil, fmt.Errorf("%w: expected key=value; got %v", errInvalidFormat, setFlag) } value := setParts[1] - switch vt := v.(type) { - case bool: + switch kind { + case reflect.Bool: b, err := strconv.ParseBool(value) if err != nil { return key, value, fmt.Errorf("%w: %w", errInvalidFormat, err) } return key, b, err - case string: + case reflect.String: return key, value, nil - case []string: - return key, strings.Split(value, ","), nil - default: - return key, nil, fmt.Errorf("unsupported type for %v (%v)", setParts, vt) + case reflect.Slice: + valueParts := strings.Split(value, ",") + switch field.Elem().Kind() { + case reflect.String: + return key, valueParts, nil + case reflect.Int: + var output []int64 + for _, v := range valueParts { + vi, err := strconv.ParseInt(v, 10, 0) + if err != nil { + return key, nil, fmt.Errorf("%w: %w", errInvalidFormat, err) + } + output = append(output, vi) + } + return key, output, nil + } } + return key, nil, fmt.Errorf("unsupported type for %v (%v)", setParts, kind) +} + +func getField(key string) (reflect.Type, error) { + s, err := getStruct(reflect.TypeOf(config.Config{}), strings.Split(key, ".")...) + if err != nil { + return nil, err + } + return s.Type, err +} + +func getStruct(current reflect.Type, paths ...string) (reflect.StructField, error) { + if len(paths) < 1 { + return reflect.StructField{}, fmt.Errorf("%w: no fields selected", errUndefinedField) + } + tomlField := paths[0] + for i := 0; i < current.NumField(); i++ { + f := current.Field(i) + v, ok := f.Tag.Lookup("toml") + if !ok { + continue + } + if v != tomlField { + continue + } + if len(paths) == 1 { + return f, nil + } + return getStruct(f.Type, paths[1:]...) + } + return reflect.StructField{}, fmt.Errorf("%w: %q", errUndefinedField, tomlField) } diff --git a/cmd/nvidia-ctk/config/config_test.go b/cmd/nvidia-ctk/config/config_test.go index bab1cb4d..eca474e9 100644 --- a/cmd/nvidia-ctk/config/config_test.go +++ b/cmd/nvidia-ctk/config/config_test.go @@ -19,152 +19,109 @@ package config import ( "testing" - "github.com/NVIDIA/nvidia-container-toolkit/internal/config" - "github.com/pelletier/go-toml" "github.com/stretchr/testify/require" ) func TestSetFlagToKeyValue(t *testing.T) { + // TODO: We need to enable this test again since switching to reflect. testCases := []struct { description string - config map[string]interface{} setFlag string expectedKey string expectedValue interface{} expectedError error }{ { - description: "empty config returns an error", - setFlag: "anykey=value", - expectedKey: "anykey", - expectedError: errInvalidConfigOption, - }, - { - description: "option not present returns an error", - config: map[string]interface{}{ - "defined": "defined-value", - }, + description: "option not present returns an error", setFlag: "undefined=new-value", expectedKey: "undefined", expectedError: errInvalidConfigOption, }, { - description: "boolean option assumes true", - config: map[string]interface{}{ - "boolean": false, - }, - setFlag: "boolean", - expectedKey: "boolean", + description: "undefined nexted option returns error", + setFlag: "nvidia-container-cli.undefined", + expectedKey: "nvidia-container-cli.undefined", + expectedError: errInvalidConfigOption, + }, + { + description: "boolean option assumes true", + setFlag: "disable-require", + expectedKey: "disable-require", expectedValue: true, }, { - description: "boolean option returns true", - config: map[string]interface{}{ - "boolean": false, - }, - setFlag: "boolean=true", - expectedKey: "boolean", + description: "boolean option returns true", + setFlag: "disable-require=true", + expectedKey: "disable-require", expectedValue: true, }, { - description: "boolean option returns false", - config: map[string]interface{}{ - "boolean": false, - }, - setFlag: "boolean=false", - expectedKey: "boolean", + description: "boolean option returns false", + setFlag: "disable-require=false", + expectedKey: "disable-require", expectedValue: false, }, { - description: "invalid boolean option returns error", - config: map[string]interface{}{ - "boolean": false, - }, - setFlag: "boolean=something", - expectedKey: "boolean", + description: "invalid boolean option returns error", + setFlag: "disable-require=something", + expectedKey: "disable-require", expectedValue: "something", expectedError: errInvalidFormat, }, { - description: "string option requires value", - config: map[string]interface{}{ - "string": "value", - }, - setFlag: "string", - expectedKey: "string", + description: "string option requires value", + setFlag: "swarm-resource", + expectedKey: "swarm-resource", expectedValue: nil, expectedError: errInvalidFormat, }, { - description: "string option returns value", - config: map[string]interface{}{ - "string": "value", - }, - setFlag: "string=string-value", - expectedKey: "string", + description: "string option returns value", + setFlag: "swarm-resource=string-value", + expectedKey: "swarm-resource", expectedValue: "string-value", }, { - description: "string option returns value with equals", - config: map[string]interface{}{ - "string": "value", - }, - setFlag: "string=string-value=more", - expectedKey: "string", + description: "string option returns value with equals", + setFlag: "swarm-resource=string-value=more", + expectedKey: "swarm-resource", expectedValue: "string-value=more", }, { - description: "string option treats bool value as string", - config: map[string]interface{}{ - "string": "value", - }, - setFlag: "string=true", - expectedKey: "string", + description: "string option treats bool value as string", + setFlag: "swarm-resource=true", + expectedKey: "swarm-resource", expectedValue: "true", }, { - description: "string option treats int value as string", - config: map[string]interface{}{ - "string": "value", - }, - setFlag: "string=5", - expectedKey: "string", + description: "string option treats int value as string", + setFlag: "swarm-resource=5", + expectedKey: "swarm-resource", expectedValue: "5", }, { - description: "[]string option returns single value", - config: map[string]interface{}{ - "string": []string{"value"}, - }, - setFlag: "string=string-value", - expectedKey: "string", + description: "[]string option returns single value", + setFlag: "nvidia-container-cli.environment=string-value", + expectedKey: "nvidia-container-cli.environment", expectedValue: []string{"string-value"}, }, { - description: "[]string option returns multiple values", - config: map[string]interface{}{ - "string": []string{"value"}, - }, - setFlag: "string=first,second", - expectedKey: "string", + description: "[]string option returns multiple values", + setFlag: "nvidia-container-cli.environment=first,second", + expectedKey: "nvidia-container-cli.environment", expectedValue: []string{"first", "second"}, }, { - description: "[]string option returns values with equals", - config: map[string]interface{}{ - "string": []string{"value"}, - }, - setFlag: "string=first=1,second=2", - expectedKey: "string", + description: "[]string option returns values with equals", + setFlag: "nvidia-container-cli.environment=first=1,second=2", + expectedKey: "nvidia-container-cli.environment", expectedValue: []string{"first=1", "second=2"}, }, } for _, tc := range testCases { t.Run(tc.description, func(t *testing.T) { - tree, _ := toml.TreeFromMap(tc.config) - cfgToml := (*config.Toml)(tree) - k, v, err := (*configToml)(cfgToml).setFlagToKeyValue(tc.setFlag) + k, v, err := setFlagToKeyValue(tc.setFlag) require.ErrorIs(t, err, tc.expectedError) require.EqualValues(t, tc.expectedKey, k) require.EqualValues(t, tc.expectedValue, v)