Use functional options to construct runtime mode resolver

Signed-off-by: Evan Lezar <elezar@nvidia.com>
This commit is contained in:
Evan Lezar 2025-06-13 15:47:47 +02:00
parent b33d475ff3
commit 8bfce9488d
No known key found for this signature in database
8 changed files with 110 additions and 22 deletions

View File

@ -242,7 +242,14 @@ func (hookConfig *hookConfig) getNvidiaConfig(image image.CUDA, privileged bool)
}
}
func (hookConfig *hookConfig) getContainerConfig() (config containerConfig) {
func (hookConfig *hookConfig) getContainerConfig() (config *containerConfig) {
hookConfig.Lock()
defer hookConfig.Unlock()
if hookConfig.containerConfig != nil {
return hookConfig.containerConfig
}
var h HookState
d := json.NewDecoder(os.Stdin)
if err := d.Decode(&h); err != nil {
@ -271,10 +278,13 @@ func (hookConfig *hookConfig) getContainerConfig() (config containerConfig) {
log.Panicln(err)
}
return containerConfig{
cc := containerConfig{
Pid: h.Pid,
Rootfs: s.Root.Path,
Image: i,
Nvidia: hookConfig.getNvidiaConfig(i, privileged),
}
hookConfig.containerConfig = &cc
return hookConfig.containerConfig
}

View File

@ -487,7 +487,7 @@ func TestGetNvidiaConfig(t *testing.T) {
hookCfg := tc.hookConfig
if hookCfg == nil {
defaultConfig, _ := config.GetDefault()
hookCfg = &hookConfig{defaultConfig}
hookCfg = &hookConfig{Config: defaultConfig}
}
cfg = hookCfg.getNvidiaConfig(image, tc.privileged)
}

View File

@ -7,9 +7,11 @@ import (
"path"
"reflect"
"strings"
"sync"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
"github.com/NVIDIA/nvidia-container-toolkit/internal/info"
)
const (
@ -20,7 +22,9 @@ const (
// hookConfig wraps the toolkit config.
// This allows for functions to be defined on the local type.
type hookConfig struct {
sync.Mutex
*config.Config
containerConfig *containerConfig
}
// loadConfig loads the required paths for the hook config.
@ -55,7 +59,7 @@ func getHookConfig() (*hookConfig, error) {
if err != nil {
return nil, fmt.Errorf("failed to load config: %v", err)
}
config := &hookConfig{cfg}
config := &hookConfig{Config: cfg}
allSupportedDriverCapabilities := image.SupportedDriverCapabilities
if config.SupportedDriverCapabilities == "all" {
@ -73,8 +77,8 @@ func getHookConfig() (*hookConfig, error) {
// getConfigOption returns the toml config option associated with the
// specified struct field.
func (c hookConfig) getConfigOption(fieldName string) string {
t := reflect.TypeOf(c)
func (c *hookConfig) getConfigOption(fieldName string) string {
t := reflect.TypeOf(&c)
f, ok := t.FieldByName(fieldName)
if !ok {
return fieldName
@ -127,3 +131,20 @@ func (c *hookConfig) nvidiaContainerCliCUDACompatModeFlags() []string {
}
return []string{flag}
}
func (c *hookConfig) assertModeIsLegacy() error {
if c.NVIDIAContainerRuntimeHookConfig.SkipModeDetection {
return nil
}
mr := info.NewRuntimeModeResolver(
info.WithLogger(&logInterceptor{}),
info.WithImage(&c.containerConfig.Image),
)
mode := mr.ResolveRuntimeMode(c.NVIDIAContainerRuntimeConfig.Mode)
if mode == "legacy" {
return nil
}
return fmt.Errorf("invoking the NVIDIA Container Runtime Hook directly (e.g. specifying the docker --gpus flag) is not supported. Please use the NVIDIA Container Runtime (e.g. specify the --runtime=nvidia flag) instead")
}

View File

@ -90,10 +90,10 @@ func TestGetHookConfig(t *testing.T) {
}
}
var cfg hookConfig
var cfg *hookConfig
getHookConfig := func() {
c, _ := getHookConfig()
cfg = *c
cfg = c
}
if tc.expectedPanic {

View File

@ -55,7 +55,7 @@ func getCLIPath(config config.ContainerCLIConfig) string {
}
// getRootfsPath returns an absolute path. We don't need to resolve symlinks for now.
func getRootfsPath(config containerConfig) string {
func getRootfsPath(config *containerConfig) string {
rootfs, err := filepath.Abs(config.Rootfs)
if err != nil {
log.Panicln(err)
@ -82,8 +82,8 @@ func doPrestart() {
return
}
if !hook.NVIDIAContainerRuntimeHookConfig.SkipModeDetection && info.ResolveAutoMode(&logInterceptor{}, hook.NVIDIAContainerRuntimeConfig.Mode, container.Image) != "legacy" {
log.Panicln("invoking the NVIDIA Container Runtime Hook directly (e.g. specifying the docker --gpus flag) is not supported. Please use the NVIDIA Container Runtime (e.g. specify the --runtime=nvidia flag) instead.")
if err := hook.assertModeIsLegacy(); err != nil {
log.Panicf("%v", err)
}
rootfs := getRootfsPath(container)

View File

@ -23,27 +23,75 @@ import (
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
)
// ResolveAutoMode determines the correct mode for the platform if set to "auto"
func ResolveAutoMode(logger logger.Interface, mode string, image image.CUDA) (rmode string) {
return resolveMode(logger, mode, image, nil)
type RuntimeModeResolver interface {
ResolveRuntimeMode(string) string
}
func resolveMode(logger logger.Interface, mode string, image image.CUDA, propertyExtractor info.PropertyExtractor) (rmode string) {
type modeResolver struct {
logger logger.Interface
// TODO: This only needs to consider the requested devices.
image *image.CUDA
propertyExtractor info.PropertyExtractor
}
type Option func(*modeResolver)
func WithLogger(logger logger.Interface) Option {
return func(mr *modeResolver) {
mr.logger = logger
}
}
func WithImage(image *image.CUDA) Option {
return func(mr *modeResolver) {
mr.image = image
}
}
func WithPropertyExtractor(propertyExtractor info.PropertyExtractor) Option {
return func(mr *modeResolver) {
mr.propertyExtractor = propertyExtractor
}
}
func NewRuntimeModeResolver(opts ...Option) RuntimeModeResolver {
r := &modeResolver{}
for _, opt := range opts {
opt(r)
}
if r.logger == nil {
r.logger = &logger.NullLogger{}
}
return r
}
// ResolveAutoMode determines the correct mode for the platform if set to "auto"
func ResolveAutoMode(logger logger.Interface, mode string, image image.CUDA) (rmode string) {
r := modeResolver{
logger: logger,
image: &image,
propertyExtractor: nil,
}
return r.ResolveRuntimeMode(mode)
}
func (m *modeResolver) ResolveRuntimeMode(mode string) (rmode string) {
if mode != "auto" {
logger.Infof("Using requested mode '%s'", mode)
m.logger.Infof("Using requested mode '%s'", mode)
return mode
}
defer func() {
logger.Infof("Auto-detected mode as '%v'", rmode)
m.logger.Infof("Auto-detected mode as '%v'", rmode)
}()
if image.OnlyFullyQualifiedCDIDevices() {
if m.image.OnlyFullyQualifiedCDIDevices() {
return "cdi"
}
nvinfo := info.New(
info.WithLogger(logger),
info.WithPropertyExtractor(propertyExtractor),
info.WithLogger(m.logger),
info.WithPropertyExtractor(m.propertyExtractor),
)
switch nvinfo.ResolvePlatform() {

View File

@ -251,7 +251,12 @@ func TestResolveAutoMode(t *testing.T) {
image.WithAcceptDeviceListAsVolumeMounts(true),
image.WithAcceptEnvvarUnprivileged(true),
)
mode := resolveMode(logger, tc.mode, image, properties)
mr := NewRuntimeModeResolver(
WithLogger(logger),
WithImage(&image),
WithPropertyExtractor(properties),
)
mode := mr.ResolveRuntimeMode(tc.mode)
require.EqualValues(t, tc.expectedMode, mode)
})
}

View File

@ -136,7 +136,11 @@ func initRuntimeModeAndImage(logger logger.Interface, cfg *config.Config, ociSpe
return "", nil, err
}
mode := info.ResolveAutoMode(logger, cfg.NVIDIAContainerRuntimeConfig.Mode, image)
modeResolver := info.NewRuntimeModeResolver(
info.WithLogger(logger),
info.WithImage(&image),
)
mode := modeResolver.ResolveRuntimeMode(cfg.NVIDIAContainerRuntimeConfig.Mode)
// We update the mode here so that we can continue passing just the config to other functions.
cfg.NVIDIAContainerRuntimeConfig.Mode = mode