diff --git a/tools/container/runtime/containerd/containerd.go b/tools/container/runtime/containerd/containerd.go index 6a5df5ad..a53a0655 100644 --- a/tools/container/runtime/containerd/containerd.go +++ b/tools/container/runtime/containerd/containerd.go @@ -25,6 +25,7 @@ import ( "github.com/NVIDIA/nvidia-container-toolkit/pkg/config/engine" "github.com/NVIDIA/nvidia-container-toolkit/pkg/config/engine/containerd" + "github.com/NVIDIA/nvidia-container-toolkit/pkg/config/toml" "github.com/NVIDIA/nvidia-container-toolkit/tools/container" ) @@ -84,13 +85,7 @@ func Flags(opts *Options) []cli.Flag { func Setup(c *cli.Context, o *container.Options, co *Options) error { log.Infof("Starting 'setup' for %v", c.App.Name) - cfg, err := containerd.New( - containerd.WithPath(o.Config), - containerd.WithConfigSource(containerd.CommandLineSource(o.HostRootMount)), - containerd.WithRuntimeType(co.runtimeType), - containerd.WithUseLegacyConfig(co.useLegacyConfig), - containerd.WithContainerAnnotations(co.containerAnnotationsFromCDIPrefixes()...), - ) + cfg, err := getRuntimeConfig(o, co) if err != nil { return fmt.Errorf("unable to load config: %v", err) } @@ -114,13 +109,7 @@ func Setup(c *cli.Context, o *container.Options, co *Options) error { func Cleanup(c *cli.Context, o *container.Options, co *Options) error { log.Infof("Starting 'cleanup' for %v", c.App.Name) - cfg, err := containerd.New( - containerd.WithPath(o.Config), - containerd.WithConfigSource(containerd.CommandLineSource(o.HostRootMount)), - containerd.WithRuntimeType(co.runtimeType), - containerd.WithUseLegacyConfig(co.useLegacyConfig), - containerd.WithContainerAnnotations(co.containerAnnotationsFromCDIPrefixes()...), - ) + cfg, err := getRuntimeConfig(o, co) if err != nil { return fmt.Errorf("unable to load config: %v", err) } @@ -169,13 +158,24 @@ func (o *Options) runtimeConfigOverride() (map[string]interface{}, error) { } func GetLowlevelRuntimePaths(o *container.Options, co *Options) ([]string, error) { - cfg, err := containerd.New( - containerd.WithConfigSource(containerd.CommandLineSource(o.HostRootMount)), - containerd.WithRuntimeType(co.runtimeType), - containerd.WithUseLegacyConfig(co.useLegacyConfig), - ) + cfg, err := getRuntimeConfig(o, co) if err != nil { return nil, fmt.Errorf("unable to load containerd config: %w", err) } return engine.GetBinaryPathsForRuntimes(cfg), nil } + +func getRuntimeConfig(o *container.Options, co *Options) (engine.Interface, error) { + return containerd.New( + containerd.WithPath(o.Config), + containerd.WithConfigSource( + toml.LoadFirst( + containerd.CommandLineSource(o.HostRootMount), + toml.FromFile(o.Config), + ), + ), + containerd.WithRuntimeType(co.runtimeType), + containerd.WithUseLegacyConfig(co.useLegacyConfig), + containerd.WithContainerAnnotations(co.containerAnnotationsFromCDIPrefixes()...), + ) +} diff --git a/tools/container/runtime/crio/crio.go b/tools/container/runtime/crio/crio.go index e3fa77f6..e37fc4ba 100644 --- a/tools/container/runtime/crio/crio.go +++ b/tools/container/runtime/crio/crio.go @@ -28,6 +28,7 @@ import ( "github.com/NVIDIA/nvidia-container-toolkit/pkg/config/engine" "github.com/NVIDIA/nvidia-container-toolkit/pkg/config/engine/crio" "github.com/NVIDIA/nvidia-container-toolkit/pkg/config/ocihook" + "github.com/NVIDIA/nvidia-container-toolkit/pkg/config/toml" "github.com/NVIDIA/nvidia-container-toolkit/tools/container" ) @@ -116,10 +117,7 @@ func setupHook(o *container.Options, co *Options) error { func setupConfig(o *container.Options) error { log.Infof("Updating config file") - cfg, err := crio.New( - crio.WithPath(o.Config), - crio.WithConfigSource(crio.CommandLineSource(o.HostRootMount)), - ) + cfg, err := getRuntimeConfig(o) if err != nil { return fmt.Errorf("unable to load config: %v", err) } @@ -168,10 +166,7 @@ func cleanupHook(co *Options) error { func cleanupConfig(o *container.Options) error { log.Infof("Reverting config file modifications") - cfg, err := crio.New( - crio.WithPath(o.Config), - crio.WithConfigSource(crio.CommandLineSource(o.HostRootMount)), - ) + cfg, err := getRuntimeConfig(o) if err != nil { return fmt.Errorf("unable to load config: %v", err) } @@ -195,11 +190,21 @@ func RestartCrio(o *container.Options) error { } func GetLowlevelRuntimePaths(o *container.Options) ([]string, error) { - cfg, err := crio.New( - crio.WithConfigSource(crio.CommandLineSource(o.HostRootMount)), - ) + cfg, err := getRuntimeConfig(o) if err != nil { return nil, fmt.Errorf("unable to load crio config: %w", err) } return engine.GetBinaryPathsForRuntimes(cfg), nil } + +func getRuntimeConfig(o *container.Options) (engine.Interface, error) { + return crio.New( + crio.WithPath(o.Config), + crio.WithConfigSource( + toml.LoadFirst( + crio.CommandLineSource(o.HostRootMount), + toml.FromFile(o.Config), + ), + ), + ) +} diff --git a/tools/container/runtime/docker/docker.go b/tools/container/runtime/docker/docker.go index 0d6a4ff8..fd5a2750 100644 --- a/tools/container/runtime/docker/docker.go +++ b/tools/container/runtime/docker/docker.go @@ -45,9 +45,7 @@ func Flags(opts *Options) []cli.Flag { func Setup(c *cli.Context, o *container.Options) error { log.Infof("Starting 'setup' for %v", c.App.Name) - cfg, err := docker.New( - docker.WithPath(o.Config), - ) + cfg, err := getRuntimeConfig(o) if err != nil { return fmt.Errorf("unable to load config: %v", err) } @@ -71,9 +69,7 @@ func Setup(c *cli.Context, o *container.Options) error { func Cleanup(c *cli.Context, o *container.Options) error { log.Infof("Starting 'cleanup' for %v", c.App.Name) - cfg, err := docker.New( - docker.WithPath(o.Config), - ) + cfg, err := getRuntimeConfig(o) if err != nil { return fmt.Errorf("unable to load config: %v", err) } @@ -107,3 +103,9 @@ func GetLowlevelRuntimePaths(o *container.Options) ([]string, error) { } return engine.GetBinaryPathsForRuntimes(cfg), nil } + +func getRuntimeConfig(o *container.Options) (engine.Interface, error) { + return docker.New( + docker.WithPath(o.Config), + ) +}