From 980ca5d1bc4cc3bca2b91dce099538705216defa Mon Sep 17 00:00:00 2001 From: Evan Lezar Date: Fri, 13 Jun 2025 15:47:47 +0200 Subject: [PATCH] Use functional options to construct runtime mode resolver Signed-off-by: Evan Lezar --- .../container_config.go | 14 +++- .../container_config_test.go | 2 +- .../hook_config.go | 27 +++++++- .../hook_config_test.go | 4 +- cmd/nvidia-container-runtime-hook/main.go | 6 +- internal/info/auto.go | 66 ++++++++++++++++--- internal/info/auto_test.go | 7 +- internal/runtime/runtime_factory.go | 6 +- 8 files changed, 110 insertions(+), 22 deletions(-) diff --git a/cmd/nvidia-container-runtime-hook/container_config.go b/cmd/nvidia-container-runtime-hook/container_config.go index cfb84745..16be7179 100644 --- a/cmd/nvidia-container-runtime-hook/container_config.go +++ b/cmd/nvidia-container-runtime-hook/container_config.go @@ -242,7 +242,14 @@ func (hookConfig *hookConfig) getNvidiaConfig(image image.CUDA, privileged bool) } } -func (hookConfig *hookConfig) getContainerConfig() (config containerConfig) { +func (hookConfig *hookConfig) getContainerConfig() (config *containerConfig) { + hookConfig.Lock() + defer hookConfig.Unlock() + + if hookConfig.containerConfig != nil { + return hookConfig.containerConfig + } + var h HookState d := json.NewDecoder(os.Stdin) if err := d.Decode(&h); err != nil { @@ -271,10 +278,13 @@ func (hookConfig *hookConfig) getContainerConfig() (config containerConfig) { log.Panicln(err) } - return containerConfig{ + cc := containerConfig{ Pid: h.Pid, Rootfs: s.Root.Path, Image: i, Nvidia: hookConfig.getNvidiaConfig(i, privileged), } + hookConfig.containerConfig = &cc + + return hookConfig.containerConfig } diff --git a/cmd/nvidia-container-runtime-hook/container_config_test.go b/cmd/nvidia-container-runtime-hook/container_config_test.go index 5247e80f..8803ff86 100644 --- a/cmd/nvidia-container-runtime-hook/container_config_test.go +++ b/cmd/nvidia-container-runtime-hook/container_config_test.go @@ -487,7 +487,7 
@@ func TestGetNvidiaConfig(t *testing.T) { hookCfg := tc.hookConfig if hookCfg == nil { defaultConfig, _ := config.GetDefault() - hookCfg = &hookConfig{defaultConfig} + hookCfg = &hookConfig{Config: defaultConfig} } cfg = hookCfg.getNvidiaConfig(image, tc.privileged) } diff --git a/cmd/nvidia-container-runtime-hook/hook_config.go b/cmd/nvidia-container-runtime-hook/hook_config.go index ec4e0434..c622d52b 100644 --- a/cmd/nvidia-container-runtime-hook/hook_config.go +++ b/cmd/nvidia-container-runtime-hook/hook_config.go @@ -7,9 +7,11 @@ import ( "path" "reflect" "strings" + "sync" "github.com/NVIDIA/nvidia-container-toolkit/internal/config" "github.com/NVIDIA/nvidia-container-toolkit/internal/config/image" + "github.com/NVIDIA/nvidia-container-toolkit/internal/info" ) const ( @@ -20,7 +22,9 @@ const ( // hookConfig wraps the toolkit config. // This allows for functions to be defined on the local type. type hookConfig struct { + sync.Mutex *config.Config + containerConfig *containerConfig } // loadConfig loads the required paths for the hook config. @@ -55,7 +59,7 @@ func getHookConfig() (*hookConfig, error) { if err != nil { return nil, fmt.Errorf("failed to load config: %v", err) } - config := &hookConfig{cfg} + config := &hookConfig{Config: cfg} allSupportedDriverCapabilities := image.SupportedDriverCapabilities if config.SupportedDriverCapabilities == "all" { @@ -73,8 +77,8 @@ func getHookConfig() (*hookConfig, error) { // getConfigOption returns the toml config option associated with the // specified struct field. 
-func (c hookConfig) getConfigOption(fieldName string) string { - t := reflect.TypeOf(c) +func (c *hookConfig) getConfigOption(fieldName string) string { + t := reflect.TypeOf(c).Elem() f, ok := t.FieldByName(fieldName) if !ok { return fieldName @@ -127,3 +131,20 @@ func (c *hookConfig) nvidiaContainerCliCUDACompatModeFlags() []string { } return []string{flag} } + +func (c *hookConfig) assertModeIsLegacy() error { + if c.NVIDIAContainerRuntimeHookConfig.SkipModeDetection { + return nil + } + + mr := info.NewRuntimeModeResolver( + info.WithLogger(&logInterceptor{}), + info.WithImage(&c.containerConfig.Image), + ) + + mode := mr.ResolveRuntimeMode(c.NVIDIAContainerRuntimeConfig.Mode) + if mode == "legacy" { + return nil + } + return fmt.Errorf("invoking the NVIDIA Container Runtime Hook directly (e.g. specifying the docker --gpus flag) is not supported. Please use the NVIDIA Container Runtime (e.g. specify the --runtime=nvidia flag) instead") +} diff --git a/cmd/nvidia-container-runtime-hook/hook_config_test.go b/cmd/nvidia-container-runtime-hook/hook_config_test.go index 19147ecf..149f7004 100644 --- a/cmd/nvidia-container-runtime-hook/hook_config_test.go +++ b/cmd/nvidia-container-runtime-hook/hook_config_test.go @@ -90,10 +90,10 @@ func TestGetHookConfig(t *testing.T) { } } - var cfg hookConfig + var cfg *hookConfig getHookConfig := func() { c, _ := getHookConfig() - cfg = *c + cfg = c } if tc.expectedPanic { diff --git a/cmd/nvidia-container-runtime-hook/main.go b/cmd/nvidia-container-runtime-hook/main.go index c77fa390..7cdb0ad8 100644 --- a/cmd/nvidia-container-runtime-hook/main.go +++ b/cmd/nvidia-container-runtime-hook/main.go @@ -55,7 +55,7 @@ func getCLIPath(config config.ContainerCLIConfig) string { } // getRootfsPath returns an absolute path. We don't need to resolve symlinks for now. 
-func getRootfsPath(config containerConfig) string { +func getRootfsPath(config *containerConfig) string { rootfs, err := filepath.Abs(config.Rootfs) if err != nil { log.Panicln(err) @@ -82,8 +82,8 @@ func doPrestart() { return } - if !hook.NVIDIAContainerRuntimeHookConfig.SkipModeDetection && info.ResolveAutoMode(&logInterceptor{}, hook.NVIDIAContainerRuntimeConfig.Mode, container.Image) != "legacy" { - log.Panicln("invoking the NVIDIA Container Runtime Hook directly (e.g. specifying the docker --gpus flag) is not supported. Please use the NVIDIA Container Runtime (e.g. specify the --runtime=nvidia flag) instead.") + if err := hook.assertModeIsLegacy(); err != nil { + log.Panicf("%v", err) } rootfs := getRootfsPath(container) diff --git a/internal/info/auto.go b/internal/info/auto.go index c6800da1..5246477c 100644 --- a/internal/info/auto.go +++ b/internal/info/auto.go @@ -23,27 +23,75 @@ import ( "github.com/NVIDIA/nvidia-container-toolkit/internal/logger" ) -// ResolveAutoMode determines the correct mode for the platform if set to "auto" -func ResolveAutoMode(logger logger.Interface, mode string, image image.CUDA) (rmode string) { - return resolveMode(logger, mode, image, nil) +type RuntimeModeResolver interface { + ResolveRuntimeMode(string) string } -func resolveMode(logger logger.Interface, mode string, image image.CUDA, propertyExtractor info.PropertyExtractor) (rmode string) { +type modeResolver struct { + logger logger.Interface + // TODO: This only needs to consider the requested devices. 
+ image *image.CUDA + propertyExtractor info.PropertyExtractor +} + +type Option func(*modeResolver) + +func WithLogger(logger logger.Interface) Option { + return func(mr *modeResolver) { + mr.logger = logger + } +} + +func WithImage(image *image.CUDA) Option { + return func(mr *modeResolver) { + mr.image = image + } +} + +func WithPropertyExtractor(propertyExtractor info.PropertyExtractor) Option { + return func(mr *modeResolver) { + mr.propertyExtractor = propertyExtractor + } +} + +func NewRuntimeModeResolver(opts ...Option) RuntimeModeResolver { + r := &modeResolver{} + for _, opt := range opts { + opt(r) + } + if r.logger == nil { + r.logger = &logger.NullLogger{} + } + + return r +} + +// ResolveAutoMode determines the correct mode for the platform if set to "auto" +func ResolveAutoMode(logger logger.Interface, mode string, image image.CUDA) (rmode string) { + r := modeResolver{ + logger: logger, + image: &image, + propertyExtractor: nil, + } + return r.ResolveRuntimeMode(mode) +} + +func (m *modeResolver) ResolveRuntimeMode(mode string) (rmode string) { if mode != "auto" { - logger.Infof("Using requested mode '%s'", mode) + m.logger.Infof("Using requested mode '%s'", mode) return mode } defer func() { - logger.Infof("Auto-detected mode as '%v'", rmode) + m.logger.Infof("Auto-detected mode as '%v'", rmode) }() - if image.OnlyFullyQualifiedCDIDevices() { + if m.image.OnlyFullyQualifiedCDIDevices() { return "cdi" } nvinfo := info.New( - info.WithLogger(logger), - info.WithPropertyExtractor(propertyExtractor), + info.WithLogger(m.logger), + info.WithPropertyExtractor(m.propertyExtractor), ) switch nvinfo.ResolvePlatform() { diff --git a/internal/info/auto_test.go b/internal/info/auto_test.go index c2ab93d7..25f14327 100644 --- a/internal/info/auto_test.go +++ b/internal/info/auto_test.go @@ -251,7 +251,12 @@ func TestResolveAutoMode(t *testing.T) { image.WithAcceptDeviceListAsVolumeMounts(true), image.WithAcceptEnvvarUnprivileged(true), ) - mode := 
resolveMode(logger, tc.mode, image, properties) + mr := NewRuntimeModeResolver( + WithLogger(logger), + WithImage(&image), + WithPropertyExtractor(properties), + ) + mode := mr.ResolveRuntimeMode(tc.mode) require.EqualValues(t, tc.expectedMode, mode) }) } diff --git a/internal/runtime/runtime_factory.go b/internal/runtime/runtime_factory.go index a4e992c1..8bb050a5 100644 --- a/internal/runtime/runtime_factory.go +++ b/internal/runtime/runtime_factory.go @@ -136,7 +136,11 @@ func initRuntimeModeAndImage(logger logger.Interface, cfg *config.Config, ociSpe return "", nil, err } - mode := info.ResolveAutoMode(logger, cfg.NVIDIAContainerRuntimeConfig.Mode, image) + modeResolver := info.NewRuntimeModeResolver( + info.WithLogger(logger), + info.WithImage(&image), + ) + mode := modeResolver.ResolveRuntimeMode(cfg.NVIDIAContainerRuntimeConfig.Mode) // We update the mode here so that we can continue passing just the config to other functions. cfg.NVIDIAContainerRuntimeConfig.Mode = mode