diff --git a/cmd/nvidia-container-runtime-hook/hook_config.go b/cmd/nvidia-container-runtime-hook/hook_config.go index c622d52b..603dd048 100644 --- a/cmd/nvidia-container-runtime-hook/hook_config.go +++ b/cmd/nvidia-container-runtime-hook/hook_config.go @@ -140,6 +140,7 @@ func (c *hookConfig) assertModeIsLegacy() error { mr := info.NewRuntimeModeResolver( info.WithLogger(&logInterceptor{}), info.WithImage(&c.containerConfig.Image), + info.WithDefaultMode(info.LegacyRuntimeMode), ) mode := mr.ResolveRuntimeMode(c.NVIDIAContainerRuntimeConfig.Mode) diff --git a/internal/info/auto.go b/internal/info/auto.go index ce64fc6e..3d69ad5c 100644 --- a/internal/info/auto.go +++ b/internal/info/auto.go @@ -55,10 +55,17 @@ type modeResolver struct { // TODO: This only needs to consider the requested devices. image *image.CUDA propertyExtractor info.PropertyExtractor + defaultMode RuntimeMode } type Option func(*modeResolver) +func WithDefaultMode(defaultMode RuntimeMode) Option { + return func(mr *modeResolver) { + mr.defaultMode = defaultMode + } +} + func WithLogger(logger logger.Interface) Option { return func(mr *modeResolver) { mr.logger = logger @@ -78,7 +85,9 @@ func WithPropertyExtractor(propertyExtractor info.PropertyExtractor) Option { } func NewRuntimeModeResolver(opts ...Option) RuntimeModeResolver { - r := &modeResolver{} + r := &modeResolver{ + defaultMode: JitCDIRuntimeMode, + } for _, opt := range opts { opt(r) } @@ -119,9 +128,9 @@ func (m *modeResolver) ResolveRuntimeMode(mode string) (rmode RuntimeMode) { switch nvinfo.ResolvePlatform() { case info.PlatformNVML, info.PlatformWSL: - return JitCDIRuntimeMode + return m.defaultMode case info.PlatformTegra: return CSVRuntimeMode } - return JitCDIRuntimeMode + return m.defaultMode }