Use functional options to construct runtime mode resolver

Signed-off-by: Evan Lezar <elezar@nvidia.com>
This commit is contained in:
Evan Lezar 2025-06-13 15:47:47 +02:00
parent b33d475ff3
commit 8bfce9488d
No known key found for this signature in database
8 changed files with 110 additions and 22 deletions

View File

@ -242,7 +242,14 @@ func (hookConfig *hookConfig) getNvidiaConfig(image image.CUDA, privileged bool)
} }
} }
func (hookConfig *hookConfig) getContainerConfig() (config containerConfig) { func (hookConfig *hookConfig) getContainerConfig() (config *containerConfig) {
hookConfig.Lock()
defer hookConfig.Unlock()
if hookConfig.containerConfig != nil {
return hookConfig.containerConfig
}
var h HookState var h HookState
d := json.NewDecoder(os.Stdin) d := json.NewDecoder(os.Stdin)
if err := d.Decode(&h); err != nil { if err := d.Decode(&h); err != nil {
@ -271,10 +278,13 @@ func (hookConfig *hookConfig) getContainerConfig() (config containerConfig) {
log.Panicln(err) log.Panicln(err)
} }
return containerConfig{ cc := containerConfig{
Pid: h.Pid, Pid: h.Pid,
Rootfs: s.Root.Path, Rootfs: s.Root.Path,
Image: i, Image: i,
Nvidia: hookConfig.getNvidiaConfig(i, privileged), Nvidia: hookConfig.getNvidiaConfig(i, privileged),
} }
hookConfig.containerConfig = &cc
return hookConfig.containerConfig
} }

View File

@ -487,7 +487,7 @@ func TestGetNvidiaConfig(t *testing.T) {
hookCfg := tc.hookConfig hookCfg := tc.hookConfig
if hookCfg == nil { if hookCfg == nil {
defaultConfig, _ := config.GetDefault() defaultConfig, _ := config.GetDefault()
hookCfg = &hookConfig{defaultConfig} hookCfg = &hookConfig{Config: defaultConfig}
} }
cfg = hookCfg.getNvidiaConfig(image, tc.privileged) cfg = hookCfg.getNvidiaConfig(image, tc.privileged)
} }

View File

@ -7,9 +7,11 @@ import (
"path" "path"
"reflect" "reflect"
"strings" "strings"
"sync"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config" "github.com/NVIDIA/nvidia-container-toolkit/internal/config"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image" "github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
"github.com/NVIDIA/nvidia-container-toolkit/internal/info"
) )
const ( const (
@ -20,7 +22,9 @@ const (
// hookConfig wraps the toolkit config. // hookConfig wraps the toolkit config.
// This allows for functions to be defined on the local type. // This allows for functions to be defined on the local type.
type hookConfig struct { type hookConfig struct {
sync.Mutex
*config.Config *config.Config
containerConfig *containerConfig
} }
// loadConfig loads the required paths for the hook config. // loadConfig loads the required paths for the hook config.
@ -55,7 +59,7 @@ func getHookConfig() (*hookConfig, error) {
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to load config: %v", err) return nil, fmt.Errorf("failed to load config: %v", err)
} }
config := &hookConfig{cfg} config := &hookConfig{Config: cfg}
allSupportedDriverCapabilities := image.SupportedDriverCapabilities allSupportedDriverCapabilities := image.SupportedDriverCapabilities
if config.SupportedDriverCapabilities == "all" { if config.SupportedDriverCapabilities == "all" {
@ -73,8 +77,8 @@ func getHookConfig() (*hookConfig, error) {
// getConfigOption returns the toml config option associated with the // getConfigOption returns the toml config option associated with the
// specified struct field. // specified struct field.
func (c hookConfig) getConfigOption(fieldName string) string { func (c *hookConfig) getConfigOption(fieldName string) string {
t := reflect.TypeOf(c) t := reflect.TypeOf(&c)
f, ok := t.FieldByName(fieldName) f, ok := t.FieldByName(fieldName)
if !ok { if !ok {
return fieldName return fieldName
@ -127,3 +131,20 @@ func (c *hookConfig) nvidiaContainerCliCUDACompatModeFlags() []string {
} }
return []string{flag} return []string{flag}
} }
// assertModeIsLegacy checks whether the hook is allowed to proceed.
//
// It returns nil when mode detection is disabled through
// NVIDIAContainerRuntimeHookConfig.SkipModeDetection, or when the configured
// runtime mode resolves to "legacy" for the container image. In every other
// case it returns an error directing the user to the NVIDIA Container Runtime.
func (c *hookConfig) assertModeIsLegacy() error {
	if c.NVIDIAContainerRuntimeHookConfig.SkipModeDetection {
		return nil
	}

	// Use the accessor instead of reading c.containerConfig directly: the
	// field stays nil until getContainerConfig has populated it, and
	// dereferencing a nil pointer here would panic. The accessor returns the
	// cached value when it is already set.
	container := c.getContainerConfig()

	mr := info.NewRuntimeModeResolver(
		info.WithLogger(&logInterceptor{}),
		info.WithImage(&container.Image),
	)

	mode := mr.ResolveRuntimeMode(c.NVIDIAContainerRuntimeConfig.Mode)
	if mode == "legacy" {
		return nil
	}
	return fmt.Errorf("invoking the NVIDIA Container Runtime Hook directly (e.g. specifying the docker --gpus flag) is not supported. Please use the NVIDIA Container Runtime (e.g. specify the --runtime=nvidia flag) instead")
}

View File

@ -90,10 +90,10 @@ func TestGetHookConfig(t *testing.T) {
} }
} }
var cfg hookConfig var cfg *hookConfig
getHookConfig := func() { getHookConfig := func() {
c, _ := getHookConfig() c, _ := getHookConfig()
cfg = *c cfg = c
} }
if tc.expectedPanic { if tc.expectedPanic {

View File

@ -55,7 +55,7 @@ func getCLIPath(config config.ContainerCLIConfig) string {
} }
// getRootfsPath returns an absolute path. We don't need to resolve symlinks for now. // getRootfsPath returns an absolute path. We don't need to resolve symlinks for now.
func getRootfsPath(config containerConfig) string { func getRootfsPath(config *containerConfig) string {
rootfs, err := filepath.Abs(config.Rootfs) rootfs, err := filepath.Abs(config.Rootfs)
if err != nil { if err != nil {
log.Panicln(err) log.Panicln(err)
@ -82,8 +82,8 @@ func doPrestart() {
return return
} }
if !hook.NVIDIAContainerRuntimeHookConfig.SkipModeDetection && info.ResolveAutoMode(&logInterceptor{}, hook.NVIDIAContainerRuntimeConfig.Mode, container.Image) != "legacy" { if err := hook.assertModeIsLegacy(); err != nil {
log.Panicln("invoking the NVIDIA Container Runtime Hook directly (e.g. specifying the docker --gpus flag) is not supported. Please use the NVIDIA Container Runtime (e.g. specify the --runtime=nvidia flag) instead.") log.Panicf("%v", err)
} }
rootfs := getRootfsPath(container) rootfs := getRootfsPath(container)

View File

@ -23,27 +23,75 @@ import (
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
) )
// ResolveAutoMode determines the correct mode for the platform if set to "auto" type RuntimeModeResolver interface {
func ResolveAutoMode(logger logger.Interface, mode string, image image.CUDA) (rmode string) { ResolveRuntimeMode(string) string
return resolveMode(logger, mode, image, nil)
} }
func resolveMode(logger logger.Interface, mode string, image image.CUDA, propertyExtractor info.PropertyExtractor) (rmode string) { type modeResolver struct {
logger logger.Interface
// TODO: This only needs to consider the requested devices.
image *image.CUDA
propertyExtractor info.PropertyExtractor
}
// Option is a functional option that mutates a modeResolver while it is being
// constructed by NewRuntimeModeResolver.
type Option func(*modeResolver)
// WithLogger returns an Option that stores the given logger on the resolver.
func WithLogger(logger logger.Interface) Option {
	apply := func(resolver *modeResolver) {
		resolver.logger = logger
	}
	return apply
}
// WithImage returns an Option that stores the given CUDA image pointer on the
// resolver.
func WithImage(image *image.CUDA) Option {
	apply := func(resolver *modeResolver) {
		resolver.image = image
	}
	return apply
}
// WithPropertyExtractor returns an Option that stores the given property
// extractor on the resolver.
func WithPropertyExtractor(propertyExtractor info.PropertyExtractor) Option {
	apply := func(resolver *modeResolver) {
		resolver.propertyExtractor = propertyExtractor
	}
	return apply
}
// NewRuntimeModeResolver builds a RuntimeModeResolver by applying each of the
// supplied options in order. If no logger is set once all options have run,
// a logger.NullLogger is installed.
func NewRuntimeModeResolver(opts ...Option) RuntimeModeResolver {
	resolver := new(modeResolver)
	for i := range opts {
		opts[i](resolver)
	}

	// The default is applied after the options so that an explicitly supplied
	// nil logger is replaced as well.
	if resolver.logger != nil {
		return resolver
	}
	resolver.logger = &logger.NullLogger{}
	return resolver
}
// ResolveAutoMode determines the correct mode for the platform if set to "auto".
//
// It is a convenience wrapper that builds a resolver from the given logger and
// image (with no property extractor) and resolves mode with it. Construction
// goes through NewRuntimeModeResolver so that this entry point gets the same
// defaults as the option-based API, including the fallback for a nil logger.
func ResolveAutoMode(logger logger.Interface, mode string, image image.CUDA) (rmode string) {
	resolver := NewRuntimeModeResolver(
		WithLogger(logger),
		WithImage(&image),
	)
	return resolver.ResolveRuntimeMode(mode)
}
func (m *modeResolver) ResolveRuntimeMode(mode string) (rmode string) {
if mode != "auto" { if mode != "auto" {
logger.Infof("Using requested mode '%s'", mode) m.logger.Infof("Using requested mode '%s'", mode)
return mode return mode
} }
defer func() { defer func() {
logger.Infof("Auto-detected mode as '%v'", rmode) m.logger.Infof("Auto-detected mode as '%v'", rmode)
}() }()
if image.OnlyFullyQualifiedCDIDevices() { if m.image.OnlyFullyQualifiedCDIDevices() {
return "cdi" return "cdi"
} }
nvinfo := info.New( nvinfo := info.New(
info.WithLogger(logger), info.WithLogger(m.logger),
info.WithPropertyExtractor(propertyExtractor), info.WithPropertyExtractor(m.propertyExtractor),
) )
switch nvinfo.ResolvePlatform() { switch nvinfo.ResolvePlatform() {

View File

@ -251,7 +251,12 @@ func TestResolveAutoMode(t *testing.T) {
image.WithAcceptDeviceListAsVolumeMounts(true), image.WithAcceptDeviceListAsVolumeMounts(true),
image.WithAcceptEnvvarUnprivileged(true), image.WithAcceptEnvvarUnprivileged(true),
) )
mode := resolveMode(logger, tc.mode, image, properties) mr := NewRuntimeModeResolver(
WithLogger(logger),
WithImage(&image),
WithPropertyExtractor(properties),
)
mode := mr.ResolveRuntimeMode(tc.mode)
require.EqualValues(t, tc.expectedMode, mode) require.EqualValues(t, tc.expectedMode, mode)
}) })
} }

View File

@ -136,7 +136,11 @@ func initRuntimeModeAndImage(logger logger.Interface, cfg *config.Config, ociSpe
return "", nil, err return "", nil, err
} }
mode := info.ResolveAutoMode(logger, cfg.NVIDIAContainerRuntimeConfig.Mode, image) modeResolver := info.NewRuntimeModeResolver(
info.WithLogger(logger),
info.WithImage(&image),
)
mode := modeResolver.ResolveRuntimeMode(cfg.NVIDIAContainerRuntimeConfig.Mode)
// We update the mode here so that we can continue passing just the config to other functions. // We update the mode here so that we can continue passing just the config to other functions.
cfg.NVIDIAContainerRuntimeConfig.Mode = mode cfg.NVIDIAContainerRuntimeConfig.Mode = mode