diff --git a/cmd/nvidia-ctk/runtime/configure/configure.go b/cmd/nvidia-ctk/runtime/configure/configure.go index b4b53ed3..6f8f42f3 100644 --- a/cmd/nvidia-ctk/runtime/configure/configure.go +++ b/cmd/nvidia-ctk/runtime/configure/configure.go @@ -17,6 +17,7 @@ package configure import ( + "encoding/json" "fmt" "path/filepath" @@ -66,6 +67,8 @@ type config struct { mode string hookFilePath string + runtimeConfigOverrideJSON string + nvidiaRuntime struct { name string path string @@ -153,6 +156,13 @@ func (m command) build() *cli.Command { Usage: "Enable CDI in the configured runtime", Destination: &config.cdi.enabled, }, + &cli.StringFlag{ + Name: "runtime-config-override", + Destination: &config.runtimeConfigOverrideJSON, + Usage: "specify additional runtime options as a JSON string. The paths are relative to the runtime config.", + Value: "{}", + EnvVars: []string{"RUNTIME_CONFIG_OVERRIDE"}, + }, } return &configure @@ -194,6 +204,11 @@ func (m command) validateFlags(c *cli.Context, config *config) error { config.cdi.enabled = false } + if config.runtimeConfigOverrideJSON != "" && config.runtime != "containerd" { + m.logger.Warningf("Ignoring runtime-config-override flag for %v", config.runtime) + config.runtimeConfigOverrideJSON = "" + } + return nil } @@ -237,10 +252,16 @@ func (m command) configureConfigFile(c *cli.Context, config *config) error { return fmt.Errorf("unable to load config for runtime %v: %v", config.runtime, err) } + runtimeConfigOverride, err := config.runtimeConfigOverride() + if err != nil { + return fmt.Errorf("unable to parse config overrides: %w", err) + } + err = cfg.AddRuntime( config.nvidiaRuntime.name, config.nvidiaRuntime.path, config.nvidiaRuntime.setAsDefault, + runtimeConfigOverride, ) if err != nil { return fmt.Errorf("unable to update config: %v", err) @@ -293,6 +314,20 @@ func (c *config) getOuputConfigPath() string { return c.resolveConfigFilePath() } +// runtimeConfigOverride converts the specified runtimeConfigOverride JSON string to a map. +func (o *config) runtimeConfigOverride() (map[string]interface{}, error) { + if o.runtimeConfigOverrideJSON == "" { + return nil, nil + } + + runtimeOptions := make(map[string]interface{}) + if err := json.Unmarshal([]byte(o.runtimeConfigOverrideJSON), &runtimeOptions); err != nil { + return nil, fmt.Errorf("failed to read %v as JSON: %w", o.runtimeConfigOverrideJSON, err) + } + + return runtimeOptions, nil +} + // configureOCIHook creates and configures the OCI hook for the NVIDIA runtime func (m *command) configureOCIHook(c *cli.Context, config *config) error { err := ocihook.CreateHook(config.hookFilePath, config.nvidiaRuntime.hookPath) diff --git a/tools/container/container.go b/tools/container/container.go index c2c50c5b..8d09e734 100644 --- a/tools/container/container.go +++ b/tools/container/container.go @@ -67,8 +67,8 @@ func ParseArgs(c *cli.Context, o *Options) error { } // Configure applies the options to the specified config -func (o Options) Configure(cfg engine.Interface) error { - err := o.UpdateConfig(cfg) +func (o Options) Configure(cfg engine.Interface, configOverrides ...map[string]interface{}) error { + err := o.UpdateConfig(cfg, configOverrides...) if err != nil { return fmt.Errorf("unable to update config: %v", err) } @@ -98,14 +98,14 @@ func (o Options) flush(cfg engine.Interface) error { } // UpdateConfig updates the specified config to include the nvidia runtimes -func (o Options) UpdateConfig(cfg engine.Interface) error { +func (o Options) UpdateConfig(cfg engine.Interface, configOverrides ...map[string]interface{}) error { runtimes := operator.GetRuntimes( operator.WithNvidiaRuntimeName(o.RuntimeName), operator.WithSetAsDefault(o.SetAsDefault), operator.WithRoot(o.RuntimeDir), ) for name, runtime := range runtimes { - err := cfg.AddRuntime(name, runtime.Path, runtime.SetAsDefault) + err := cfg.AddRuntime(name, runtime.Path, runtime.SetAsDefault, configOverrides...) if err != nil { return fmt.Errorf("failed to update runtime %q: %v", name, err) } diff --git a/tools/container/containerd/containerd.go b/tools/container/containerd/containerd.go index bfe055e5..b5dff882 100644 --- a/tools/container/containerd/containerd.go +++ b/tools/container/containerd/containerd.go @@ -17,6 +17,7 @@ package main import ( + "encoding/json" "fmt" "os" @@ -47,6 +48,8 @@ type options struct { runtimeType string ContainerRuntimeModesCDIAnnotationPrefixes cli.StringSlice + + runtimeConfigOverrideJSON string } func main() { @@ -162,6 +165,13 @@ func main() { Destination: &options.ContainerRuntimeModesCDIAnnotationPrefixes, EnvVars: []string{"NVIDIA_CONTAINER_RUNTIME_MODES_CDI_ANNOTATION_PREFIXES"}, }, + &cli.StringFlag{ + Name: "runtime-config-override", + Destination: &options.runtimeConfigOverrideJSON, + Usage: "specify additional runtime options as a JSON string. The paths are relative to the runtime config.", + Value: "{}", + EnvVars: []string{"RUNTIME_CONFIG_OVERRIDE", "CONTAINERD_RUNTIME_CONFIG_OVERRIDE"}, + }, } // Update the subcommand flags with the common subcommand flags @@ -170,7 +180,7 @@ func main() { // Run the top-level CLI if err := c.Run(os.Args); err != nil { - log.Fatal(fmt.Errorf("Error: %v", err)) + log.Fatal(fmt.Errorf("error: %v", err)) } } @@ -188,7 +198,11 @@ func Setup(c *cli.Context, o *options) error { return fmt.Errorf("unable to load config: %v", err) } - err = o.Configure(cfg) + runtimeConfigOverride, err := o.runtimeConfigOverride() + if err != nil { + return fmt.Errorf("unable to parse config overrides: %w", err) + } + err = o.Configure(cfg, runtimeConfigOverride) if err != nil { return fmt.Errorf("unable to configure containerd: %v", err) } @@ -246,3 +260,16 @@ func (o *options) containerAnnotationsFromCDIPrefixes() []string { return annotations } + +func (o *options) runtimeConfigOverride() (map[string]interface{}, error) { + if o.runtimeConfigOverrideJSON == "" { + return nil, nil + } + + runtimeOptions := make(map[string]interface{}) + if err := json.Unmarshal([]byte(o.runtimeConfigOverrideJSON), &runtimeOptions); err != nil { + return nil, fmt.Errorf("failed to read %v as JSON: %w", o.runtimeConfigOverrideJSON, err) + } + + return runtimeOptions, nil +} diff --git a/tools/container/containerd/containerd_test.go b/tools/container/containerd/containerd_test.go new file mode 100644 index 00000000..7240a4dc --- /dev/null +++ b/tools/container/containerd/containerd_test.go @@ -0,0 +1,72 @@ +/** +# Copyright 2024 NVIDIA CORPORATION +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package main + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestRuntimeOptions(t *testing.T) { + testCases := []struct { + description string + options options + expected map[string]interface{} + expectedError error + }{ + { + description: "empty is nil", + }, + { + description: "empty json", + options: options{ + runtimeConfigOverrideJSON: "{}", + }, + expected: map[string]interface{}{}, + expectedError: nil, + }, + { + description: "SystemdCgroup is true", + options: options{ + runtimeConfigOverrideJSON: "{\"SystemdCgroup\": true}", + }, + expected: map[string]interface{}{ + "SystemdCgroup": true, + }, + expectedError: nil, + }, + { + description: "SystemdCgroup is false", + options: options{ + runtimeConfigOverrideJSON: "{\"SystemdCgroup\": false}", + }, + expected: map[string]interface{}{ + "SystemdCgroup": false, + }, + expectedError: nil, + }, + } + + for _, tc := range testCases { + t.Run(tc.description, func(t *testing.T) { + runtimeOptions, err := tc.options.runtimeConfigOverride() + require.ErrorIs(t, tc.expectedError, err) + require.EqualValues(t, tc.expected, runtimeOptions) + }) + } +}