mirror of
				https://github.com/NVIDIA/nvidia-container-toolkit
				synced 2025-06-26 18:18:24 +00:00 
			
		
		
		
	Merge branch 'add-cdi-auto-mode' into 'main'
Add constants for CDI mode to nvcdi API See merge request nvidia/container-toolkit/container-toolkit!302
This commit is contained in:
		
						commit
						882fbb3209
					
				| @ -36,10 +36,6 @@ import ( | ||||
| ) | ||||
| 
 | ||||
| const ( | ||||
| 	discoveryModeAuto = "auto" | ||||
| 	discoveryModeNVML = "nvml" | ||||
| 	discoveryModeWSL  = "wsl" | ||||
| 
 | ||||
| 	formatJSON = "json" | ||||
| 	formatYAML = "yaml" | ||||
| 
 | ||||
| @ -97,8 +93,8 @@ func (m command) build() *cli.Command { | ||||
| 		}, | ||||
| 		&cli.StringFlag{ | ||||
| 			Name:        "discovery-mode", | ||||
| 			Usage:       "The mode to use when discovering the available entities. One of [auto | nvml | wsl]. I mode is set to 'auto' the mode will be determined based on the system configuration.", | ||||
| 			Value:       discoveryModeAuto, | ||||
| 			Usage:       "The mode to use when discovering the available entities. One of [auto | nvml | wsl]. If mode is set to 'auto' the mode will be determined based on the system configuration.", | ||||
| 			Value:       nvcdi.ModeAuto, | ||||
| 			Destination: &cfg.discoveryMode, | ||||
| 		}, | ||||
| 		&cli.StringFlag{ | ||||
| @ -133,9 +129,9 @@ func (m command) validateFlags(r *cli.Context, cfg *config) error { | ||||
| 
 | ||||
| 	cfg.discoveryMode = strings.ToLower(cfg.discoveryMode) | ||||
| 	switch cfg.discoveryMode { | ||||
| 	case discoveryModeAuto: | ||||
| 	case discoveryModeNVML: | ||||
| 	case discoveryModeWSL: | ||||
| 	case nvcdi.ModeAuto: | ||||
| 	case nvcdi.ModeNvml: | ||||
| 	case nvcdi.ModeWsl: | ||||
| 	default: | ||||
| 		return fmt.Errorf("invalid discovery mode: %v", cfg.discoveryMode) | ||||
| 	} | ||||
|  | ||||
| @ -22,6 +22,15 @@ import ( | ||||
| 	"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/device" | ||||
| ) | ||||
| 
 | ||||
| const ( | ||||
| 	// ModeAuto configures the CDI spec generator to automatically detect the system configuration
 | ||||
| 	ModeAuto = "auto" | ||||
| 	// ModeNvml configures the CDI spec generator to use the NVML library.
 | ||||
| 	ModeNvml = "nvml" | ||||
| 	// ModeWsl configures the CDI spec generator to generate a WSL spec.
 | ||||
| 	ModeWsl = "wsl" | ||||
| ) | ||||
| 
 | ||||
| // Interface defines the API for the nvcdi package
 | ||||
| type Interface interface { | ||||
| 	GetCommonEdits() (*cdi.ContainerEdits, error) | ||||
|  | ||||
| @ -31,6 +31,8 @@ type nvcdilib struct { | ||||
| 	deviceNamer   DeviceNamer | ||||
| 	driverRoot    string | ||||
| 	nvidiaCTKPath string | ||||
| 
 | ||||
| 	infolib info.Interface | ||||
| } | ||||
| 
 | ||||
| // New creates a new nvcdi library
 | ||||
| @ -40,7 +42,7 @@ func New(opts ...Option) Interface { | ||||
| 		opt(l) | ||||
| 	} | ||||
| 	if l.mode == "" { | ||||
| 		l.mode = "auto" | ||||
| 		l.mode = ModeAuto | ||||
| 	} | ||||
| 	if l.logger == nil { | ||||
| 		l.logger = logrus.StandardLogger() | ||||
| @ -54,9 +56,12 @@ func New(opts ...Option) Interface { | ||||
| 	if l.nvidiaCTKPath == "" { | ||||
| 		l.nvidiaCTKPath = "/usr/bin/nvidia-ctk" | ||||
| 	} | ||||
| 	if l.infolib == nil { | ||||
| 		l.infolib = info.New() | ||||
| 	} | ||||
| 
 | ||||
| 	switch l.resolveMode() { | ||||
| 	case "nvml": | ||||
| 	case ModeNvml: | ||||
| 		if l.nvmllib == nil { | ||||
| 			l.nvmllib = nvml.New() | ||||
| 		} | ||||
| @ -65,7 +70,7 @@ func New(opts ...Option) Interface { | ||||
| 		} | ||||
| 
 | ||||
| 		return (*nvmllib)(l) | ||||
| 	case "wsl": | ||||
| 	case ModeWsl: | ||||
| 		return (*wsllib)(l) | ||||
| 	} | ||||
| 
 | ||||
| @ -75,21 +80,19 @@ func New(opts ...Option) Interface { | ||||
| 
 | ||||
| // resolveMode resolves the mode for CDI spec generation based on the current system.
 | ||||
| func (l *nvcdilib) resolveMode() (rmode string) { | ||||
| 	if l.mode != "auto" { | ||||
| 	if l.mode != ModeAuto { | ||||
| 		return l.mode | ||||
| 	} | ||||
| 	defer func() { | ||||
| 		l.logger.Infof("Auto-detected mode as %q", rmode) | ||||
| 	}() | ||||
| 
 | ||||
| 	nvinfo := info.New() | ||||
| 
 | ||||
| 	isWSL, reason := nvinfo.HasDXCore() | ||||
| 	isWSL, reason := l.infolib.HasDXCore() | ||||
| 	l.logger.Debugf("Is WSL-based system? %v: %v", isWSL, reason) | ||||
| 
 | ||||
| 	if isWSL { | ||||
| 		return "wsl" | ||||
| 		return ModeWsl | ||||
| 	} | ||||
| 
 | ||||
| 	return "nvml" | ||||
| 	return ModeNvml | ||||
| } | ||||
|  | ||||
							
								
								
									
										88
									
								
								pkg/nvcdi/lib_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										88
									
								
								pkg/nvcdi/lib_test.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,88 @@ | ||||
| /** | ||||
| # Copyright (c) NVIDIA CORPORATION.  All rights reserved. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| **/ | ||||
| 
 | ||||
| package nvcdi | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"testing" | ||||
| 
 | ||||
| 	testlog "github.com/sirupsen/logrus/hooks/test" | ||||
| 	"github.com/stretchr/testify/require" | ||||
| ) | ||||
| 
 | ||||
| func TestResolveMode(t *testing.T) { | ||||
| 	logger, _ := testlog.NewNullLogger() | ||||
| 
 | ||||
| 	testCases := []struct { | ||||
| 		mode string | ||||
| 		// TODO: This should be a proper mock
 | ||||
| 		hasDXCore bool | ||||
| 		expected  string | ||||
| 	}{ | ||||
| 		{ | ||||
| 			mode:      "auto", | ||||
| 			hasDXCore: true, | ||||
| 			expected:  "wsl", | ||||
| 		}, | ||||
| 		{ | ||||
| 			mode:      "auto", | ||||
| 			hasDXCore: false, | ||||
| 			expected:  "nvml", | ||||
| 		}, | ||||
| 		{ | ||||
| 			mode:      "nvml", | ||||
| 			hasDXCore: true, | ||||
| 			expected:  "nvml", | ||||
| 		}, | ||||
| 		{ | ||||
| 			mode:      "wsl", | ||||
| 			hasDXCore: false, | ||||
| 			expected:  "wsl", | ||||
| 		}, | ||||
| 		{ | ||||
| 			mode:      "not-auto", | ||||
| 			hasDXCore: true, | ||||
| 			expected:  "not-auto", | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	for i, tc := range testCases { | ||||
| 		t.Run(fmt.Sprintf("test case %d", i), func(t *testing.T) { | ||||
| 			l := nvcdilib{ | ||||
| 				logger:  logger, | ||||
| 				mode:    tc.mode, | ||||
| 				infolib: infoMock(tc.hasDXCore), | ||||
| 			} | ||||
| 
 | ||||
| 			require.Equal(t, tc.expected, l.resolveMode()) | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| type infoMock bool | ||||
| 
 | ||||
| func (i infoMock) HasDXCore() (bool, string) { | ||||
| 	return bool(i), "" | ||||
| } | ||||
| 
 | ||||
| func (i infoMock) HasNvml() (bool, string) { | ||||
| 	panic("should not be called") | ||||
| } | ||||
| 
 | ||||
| func (i infoMock) IsTegraSystem() (bool, string) { | ||||
| 	panic("should not be called") | ||||
| } | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user