From 1267c1d9a2929ac36da06cb078a3a23aef2a6cba Mon Sep 17 00:00:00 2001 From: Evan Lezar Date: Thu, 14 Jul 2022 16:19:19 +0200 Subject: [PATCH] Refactor docker config update This change updates the docker config update for simplicitly. This also allows for the API to match the crio update code. Signed-off-by: Evan Lezar --- cmd/nvidia-ctk/runtime/configure/configure.go | 9 ++- internal/config/docker/docker.go | 16 ++--- internal/config/docker/docker_test.go | 61 ++++++++----------- tools/container/docker/docker.go | 24 +++----- 4 files changed, 47 insertions(+), 63 deletions(-) diff --git a/cmd/nvidia-ctk/runtime/configure/configure.go b/cmd/nvidia-ctk/runtime/configure/configure.go index 13689986..a0d66afc 100644 --- a/cmd/nvidia-ctk/runtime/configure/configure.go +++ b/cmd/nvidia-ctk/runtime/configure/configure.go @@ -132,9 +132,12 @@ func (m command) configureDocker(c *cli.Context, config *config) error { return fmt.Errorf("unable to load config: %v", err) } - defaultRuntime := config.nvidiaOptions.DefaultRuntime() - runtimeConfig := config.nvidiaOptions.Runtime().DockerRuntimesConfig() - err = docker.UpdateConfig(cfg, defaultRuntime, runtimeConfig) + err = docker.UpdateConfig( + cfg, + config.nvidiaOptions.RuntimeName, + config.nvidiaOptions.RuntimePath, + config.nvidiaOptions.SetAsDefault, + ) if err != nil { return fmt.Errorf("unable to update config: %v", err) } diff --git a/internal/config/docker/docker.go b/internal/config/docker/docker.go index efe65aef..707e7923 100644 --- a/internal/config/docker/docker.go +++ b/internal/config/docker/docker.go @@ -57,11 +57,7 @@ func LoadConfig(configFilePath string) (map[string]interface{}, error) { } // UpdateConfig updates the docker config to include the nvidia runtimes -func UpdateConfig(config map[string]interface{}, defaultRuntime string, newRuntimes map[string]interface{}) error { - if defaultRuntime != "" { - config["default-runtime"] = defaultRuntime - } - +func UpdateConfig(config map[string]interface{}, runtimeName string, runtimePath string, setAsDefault bool) error { // Read the existing runtimes runtimes := make(map[string]interface{}) if _, exists := config["runtimes"]; exists { @@ -69,14 +65,20 @@ func UpdateConfig(config map[string]interface{}, defaultRuntime string, newRunti } // Add / update the runtime definitions - for name, rt := range newRuntimes { - runtimes[name] = rt + runtimes[runtimeName] = map[string]interface{}{ + "path": runtimePath, + "args": []string{}, } // Update the runtimes definition if len(runtimes) > 0 { config["runtimes"] = runtimes } + + if setAsDefault { + config["default-runtime"] = runtimeName + } + return nil } diff --git a/internal/config/docker/docker_test.go b/internal/config/docker/docker_test.go index 9c4e961e..c5b7128f 100644 --- a/internal/config/docker/docker_test.go +++ b/internal/config/docker/docker_test.go @@ -27,30 +27,33 @@ import ( func TestUpdateConfigDefaultRuntime(t *testing.T) { testCases := []struct { config map[string]interface{} - defaultRuntime string runtimeName string + setAsDefault bool expectedDefaultRuntimeName interface{} }{ { - defaultRuntime: "", + setAsDefault: false, expectedDefaultRuntimeName: nil, }, { - defaultRuntime: "NAME", + runtimeName: "NAME", + setAsDefault: true, expectedDefaultRuntimeName: "NAME", }, { config: map[string]interface{}{ "default-runtime": "ALREADY_SET", }, - defaultRuntime: "", + runtimeName: "NAME", + setAsDefault: false, expectedDefaultRuntimeName: "ALREADY_SET", }, { config: map[string]interface{}{ "default-runtime": "ALREADY_SET", }, - defaultRuntime: "NAME", + runtimeName: "NAME", + setAsDefault: true, expectedDefaultRuntimeName: "NAME", }, } @@ -60,7 +63,7 @@ func TestUpdateConfigDefaultRuntime(t *testing.T) { if tc.config == nil { tc.config = make(map[string]interface{}) } - err := UpdateConfig(tc.config, tc.defaultRuntime, nil) + err := UpdateConfig(tc.config, tc.runtimeName, "", tc.setAsDefault) require.NoError(t, err) defaultRuntimeName := tc.config["default-runtime"] @@ -72,20 +75,14 @@ func TestUpdateConfigDefaultRuntime(t *testing.T) { func TestUpdateConfigRuntimes(t *testing.T) { testCases := []struct { config map[string]interface{} - runtimes map[string]interface{} + runtimes map[string]string expectedConfig map[string]interface{} }{ { config: map[string]interface{}{}, - runtimes: map[string]interface{}{ - "runtime1": map[string]interface{}{ - "path": "/test/runtime/dir/runtime1", - "args": []string{}, - }, - "runtime2": map[string]interface{}{ - "path": "/test/runtime/dir/runtime2", - "args": []string{}, - }, + runtimes: map[string]string{ + "runtime1": "/test/runtime/dir/runtime1", + "runtime2": "/test/runtime/dir/runtime2", }, expectedConfig: map[string]interface{}{ "runtimes": map[string]interface{}{ @@ -109,15 +106,9 @@ func TestUpdateConfigRuntimes(t *testing.T) { }, }, }, - runtimes: map[string]interface{}{ - "runtime1": map[string]interface{}{ - "path": "/test/runtime/dir/runtime1", - "args": []string{}, - }, - "runtime2": map[string]interface{}{ - "path": "/test/runtime/dir/runtime2", - "args": []string{}, - }, + runtimes: map[string]string{ + "runtime1": "/test/runtime/dir/runtime1", + "runtime2": "/test/runtime/dir/runtime2", }, expectedConfig: map[string]interface{}{ "runtimes": map[string]interface{}{ @@ -141,11 +132,8 @@ func TestUpdateConfigRuntimes(t *testing.T) { }, }, }, - runtimes: map[string]interface{}{ - "runtime1": map[string]interface{}{ - "path": "/test/runtime/dir/runtime1", - "args": []string{}, - }, + runtimes: map[string]string{ + "runtime1": "/test/runtime/dir/runtime1", }, expectedConfig: map[string]interface{}{ "runtimes": map[string]interface{}{ @@ -169,11 +157,8 @@ func TestUpdateConfigRuntimes(t *testing.T) { }, "storage-driver": "overlay2", }, - runtimes: map[string]interface{}{ - "runtime1": map[string]interface{}{ - "path": "/test/runtime/dir/runtime1", - "args": []string{}, - }, + runtimes: map[string]string{ + "runtime1": "/test/runtime/dir/runtime1", }, expectedConfig: map[string]interface{}{ "exec-opts": []string{"native.cgroupdriver=systemd"}, @@ -212,8 +197,10 @@ func TestUpdateConfigRuntimes(t *testing.T) { for i, tc := range testCases { t.Run(fmt.Sprintf("test case %d", i), func(t *testing.T) { - err := UpdateConfig(tc.config, "", tc.runtimes) - require.NoError(t, err) + for runtimeName, runtimePath := range tc.runtimes { + err := UpdateConfig(tc.config, runtimeName, runtimePath, false) + require.NoError(t, err) + } configContent, err := json.MarshalIndent(tc.config, "", " ") require.NoError(t, err) diff --git a/tools/container/docker/docker.go b/tools/container/docker/docker.go index ce6e7dca..e64b8ba9 100644 --- a/tools/container/docker/docker.go +++ b/tools/container/docker/docker.go @@ -250,10 +250,15 @@ func LoadConfig(config string) (map[string]interface{}, error) { // UpdateConfig updates the docker config to include the nvidia runtimes func UpdateConfig(config map[string]interface{}, o *options) error { - defaultRuntime := o.getDefaultRuntime() - runtimes := o.runtimes() + for runtimeName, runtimePath := range o.getRuntimeBinaries() { + setAsDefault := runtimeName == o.getDefaultRuntime() + err := docker.UpdateConfig(config, runtimeName, runtimePath, setAsDefault) + if err != nil { + return fmt.Errorf("failed to update runtime %q: %v", runtimeName, err) + } + } - return docker.UpdateConfig(config, defaultRuntime, runtimes) + return nil } //RevertConfig reverts the docker config to remove the nvidia runtime @@ -392,19 +397,6 @@ func (o options) getDefaultRuntime() string { return o.runtimeName } -// runtimes returns the docker runtime definitions for the supported nvidia runtimes -// for the given options. This includes the path with the options runtimeDir applied -func (o options) runtimes() map[string]interface{} { - runtimes := make(map[string]interface{}) - for r, bin := range o.getRuntimeBinaries() { - runtimes[r] = map[string]interface{}{ - "path": bin, - "args": []string{}, - } - } - return runtimes -} - // getRuntimeBinaries returns a map of runtime names to binary paths. This includes the // renaming of the `nvidia` runtime as per the --runtime-class command line flag. func (o options) getRuntimeBinaries() map[string]string {