diff --git a/cmd/nvidia-container-runtime-hook/hook_config.go b/cmd/nvidia-container-runtime-hook/hook_config.go index c622d52b..ffb08b88 100644 --- a/cmd/nvidia-container-runtime-hook/hook_config.go +++ b/cmd/nvidia-container-runtime-hook/hook_config.go @@ -140,6 +140,7 @@ func (c *hookConfig) assertModeIsLegacy() error { mr := info.NewRuntimeModeResolver( info.WithLogger(&logInterceptor{}), info.WithImage(&c.containerConfig.Image), + info.WithDefaultMode(info.RuntimeModeLegacy), ) mode := mr.ResolveRuntimeMode(c.NVIDIAContainerRuntimeConfig.Mode) diff --git a/internal/info/auto.go b/internal/info/auto.go index 021a12f7..c321b9a7 100644 --- a/internal/info/auto.go +++ b/internal/info/auto.go @@ -42,10 +42,17 @@ type modeResolver struct { // TODO: This only needs to consider the requested devices. image *image.CUDA propertyExtractor info.PropertyExtractor + defaultMode RuntimeMode } type Option func(*modeResolver) +func WithDefaultMode(defaultMode RuntimeMode) Option { + return func(mr *modeResolver) { + mr.defaultMode = defaultMode + } +} + func WithLogger(logger logger.Interface) Option { return func(mr *modeResolver) { mr.logger = logger @@ -65,7 +72,9 @@ func WithPropertyExtractor(propertyExtractor info.PropertyExtractor) Option { } func NewRuntimeModeResolver(opts ...Option) RuntimeModeResolver { - r := &modeResolver{} + r := &modeResolver{ + defaultMode: RuntimeModeJitCDI, + } for _, opt := range opts { opt(r) } @@ -106,9 +115,9 @@ func (m *modeResolver) ResolveRuntimeMode(mode string) (rmode RuntimeMode) { switch nvinfo.ResolvePlatform() { case info.PlatformNVML, info.PlatformWSL: - return RuntimeModeJitCDI + return m.defaultMode case info.PlatformTegra: return RuntimeModeCSV } - return RuntimeModeJitCDI + return m.defaultMode }