mirror of
				https://github.com/NVIDIA/nvidia-container-toolkit
				synced 2025-06-26 18:18:24 +00:00 
			
		
		
		
	Merge branch 'device-namer' into 'main'
Refactor device namer See merge request nvidia/container-toolkit/container-toolkit!453
This commit is contained in:
		
						commit
						80a78e60d1
					
				| @ -39,7 +39,7 @@ func (l *nvmllib) GetGPUDeviceSpecs(i int, d device.Device) (*specs.Device, erro | ||||
| 		return nil, fmt.Errorf("failed to get edits for device: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	name, err := l.deviceNamer.GetDeviceName(i, d) | ||||
| 	name, err := l.deviceNamer.GetDeviceName(i, convert{d}) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("failed to get device name: %v", err) | ||||
| 	} | ||||
|  | ||||
| @ -53,8 +53,13 @@ func (l *csvlib) GetAllDeviceSpecs() ([]specs.Device, error) { | ||||
| 		return nil, fmt.Errorf("failed to create container edits for CSV files: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	name, err := l.deviceNamer.GetDeviceName(0, uuidUnsupported{}) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("failed to get device name: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	deviceSpec := specs.Device{ | ||||
| 		Name:           "all", | ||||
| 		Name:           name, | ||||
| 		ContainerEdits: *e.ContainerEdits, | ||||
| 	} | ||||
| 	return []specs.Device{deviceSpec}, nil | ||||
|  | ||||
| @ -36,7 +36,7 @@ func (l *nvmllib) GetMIGDeviceSpecs(i int, d device.Device, j int, mig device.Mi | ||||
| 		return nil, fmt.Errorf("failed to get edits for device: %v", err) | ||||
| 	} | ||||
| 
 | ||||
| 	name, err := l.deviceNamer.GetMigDeviceName(i, d, j, mig) | ||||
| 	name, err := l.deviceNamer.GetMigDeviceName(i, convert{d}, j, convert{mig}) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("failed to get device name: %v", err) | ||||
| 	} | ||||
|  | ||||
| @ -17,16 +17,21 @@ | ||||
| package nvcdi | ||||
| 
 | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 
 | ||||
| 	"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/device" | ||||
| 	"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml" | ||||
| ) | ||||
| 
 | ||||
| // UUIDer is an interface for getting UUIDs.
 | ||||
| type UUIDer interface { | ||||
| 	GetUUID() (string, error) | ||||
| } | ||||
| 
 | ||||
| // DeviceNamer is an interface for getting device names
 | ||||
| type DeviceNamer interface { | ||||
| 	GetDeviceName(int, device.Device) (string, error) | ||||
| 	GetMigDeviceName(int, device.Device, int, device.MigDevice) (string, error) | ||||
| 	GetDeviceName(int, UUIDer) (string, error) | ||||
| 	GetMigDeviceName(int, UUIDer, int, UUIDer) (string, error) | ||||
| } | ||||
| 
 | ||||
| // Supported device naming strategies
 | ||||
| @ -61,29 +66,57 @@ func NewDeviceNamer(strategy string) (DeviceNamer, error) { | ||||
| } | ||||
| 
 | ||||
| // GetDeviceName returns the name for the specified device based on the naming strategy
 | ||||
| func (s deviceNameIndex) GetDeviceName(i int, d device.Device) (string, error) { | ||||
| func (s deviceNameIndex) GetDeviceName(i int, _ UUIDer) (string, error) { | ||||
| 	return fmt.Sprintf("%s%d", s.gpuPrefix, i), nil | ||||
| } | ||||
| 
 | ||||
| // GetMigDeviceName returns the name for the specified device based on the naming strategy
 | ||||
| func (s deviceNameIndex) GetMigDeviceName(i int, d device.Device, j int, mig device.MigDevice) (string, error) { | ||||
| func (s deviceNameIndex) GetMigDeviceName(i int, _ UUIDer, j int, _ UUIDer) (string, error) { | ||||
| 	return fmt.Sprintf("%s%d:%d", s.migPrefix, i, j), nil | ||||
| } | ||||
| 
 | ||||
| // GetDeviceName returns the name for the specified device based on the naming strategy
 | ||||
| func (s deviceNameUUID) GetDeviceName(i int, d device.Device) (string, error) { | ||||
| 	uuid, ret := d.GetUUID() | ||||
| 	if ret != nvml.SUCCESS { | ||||
| 		return "", fmt.Errorf("failed to get device UUID: %v", ret) | ||||
| func (s deviceNameUUID) GetDeviceName(i int, d UUIDer) (string, error) { | ||||
| 	uuid, err := d.GetUUID() | ||||
| 	if err != nil { | ||||
| 		return "", fmt.Errorf("failed to get device UUID: %v", err) | ||||
| 	} | ||||
| 	return uuid, nil | ||||
| } | ||||
| 
 | ||||
| // GetMigDeviceName returns the name for the specified device based on the naming strategy
 | ||||
| func (s deviceNameUUID) GetMigDeviceName(i int, d device.Device, j int, mig device.MigDevice) (string, error) { | ||||
| 	uuid, ret := mig.GetUUID() | ||||
| 	if ret != nvml.SUCCESS { | ||||
| 		return "", fmt.Errorf("failed to get device UUID: %v", ret) | ||||
| func (s deviceNameUUID) GetMigDeviceName(i int, _ UUIDer, j int, mig UUIDer) (string, error) { | ||||
| 	uuid, err := mig.GetUUID() | ||||
| 	if err != nil { | ||||
| 		return "", fmt.Errorf("failed to get device UUID: %v", err) | ||||
| 	} | ||||
| 	return uuid, nil | ||||
| } | ||||
| 
 | ||||
| //go:generate moq -stub -out namer_nvml_mock.go . nvmlUUIDer
 | ||||
| type nvmlUUIDer interface { | ||||
| 	GetUUID() (string, nvml.Return) | ||||
| } | ||||
| 
 | ||||
| type convert struct { | ||||
| 	nvmlUUIDer | ||||
| } | ||||
| 
 | ||||
| type uuidUnsupported struct{} | ||||
| 
 | ||||
| func (m convert) GetUUID() (string, error) { | ||||
| 	if m.nvmlUUIDer == nil { | ||||
| 		return uuidUnsupported{}.GetUUID() | ||||
| 	} | ||||
| 	uuid, ret := m.nvmlUUIDer.GetUUID() | ||||
| 	if ret != nvml.SUCCESS { | ||||
| 		return "", ret | ||||
| 	} | ||||
| 	return uuid, nil | ||||
| } | ||||
| 
 | ||||
| var errUUIDUnsupported = errors.New("GetUUID is not supported") | ||||
| 
 | ||||
| func (m uuidUnsupported) GetUUID() (string, error) { | ||||
| 	return "", errUUIDUnsupported | ||||
| } | ||||
|  | ||||
							
								
								
									
										72
									
								
								pkg/nvcdi/namer_nvml_mock.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										72
									
								
								pkg/nvcdi/namer_nvml_mock.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,72 @@ | ||||
| // Code generated by moq; DO NOT EDIT.
 | ||||
| // github.com/matryer/moq
 | ||||
| 
 | ||||
| package nvcdi | ||||
| 
 | ||||
| import ( | ||||
| 	"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml" | ||||
| 	"sync" | ||||
| ) | ||||
| 
 | ||||
| // Ensure, that nvmlUUIDerMock does implement nvmlUUIDer.
 | ||||
| // If this is not the case, regenerate this file with moq.
 | ||||
| var _ nvmlUUIDer = &nvmlUUIDerMock{} | ||||
| 
 | ||||
| // nvmlUUIDerMock is a mock implementation of nvmlUUIDer.
 | ||||
| //
 | ||||
| //	func TestSomethingThatUsesnvmlUUIDer(t *testing.T) {
 | ||||
| //
 | ||||
| //		// make and configure a mocked nvmlUUIDer
 | ||||
| //		mockednvmlUUIDer := &nvmlUUIDerMock{
 | ||||
| //			GetUUIDFunc: func() (string, nvml.Return) {
 | ||||
| //				panic("mock out the GetUUID method")
 | ||||
| //			},
 | ||||
| //		}
 | ||||
| //
 | ||||
| //		// use mockednvmlUUIDer in code that requires nvmlUUIDer
 | ||||
| //		// and then make assertions.
 | ||||
| //
 | ||||
| //	}
 | ||||
| type nvmlUUIDerMock struct { | ||||
| 	// GetUUIDFunc mocks the GetUUID method.
 | ||||
| 	GetUUIDFunc func() (string, nvml.Return) | ||||
| 
 | ||||
| 	// calls tracks calls to the methods.
 | ||||
| 	calls struct { | ||||
| 		// GetUUID holds details about calls to the GetUUID method.
 | ||||
| 		GetUUID []struct { | ||||
| 		} | ||||
| 	} | ||||
| 	lockGetUUID sync.RWMutex | ||||
| } | ||||
| 
 | ||||
| // GetUUID calls GetUUIDFunc.
 | ||||
| func (mock *nvmlUUIDerMock) GetUUID() (string, nvml.Return) { | ||||
| 	callInfo := struct { | ||||
| 	}{} | ||||
| 	mock.lockGetUUID.Lock() | ||||
| 	mock.calls.GetUUID = append(mock.calls.GetUUID, callInfo) | ||||
| 	mock.lockGetUUID.Unlock() | ||||
| 	if mock.GetUUIDFunc == nil { | ||||
| 		var ( | ||||
| 			sOut      string | ||||
| 			returnOut nvml.Return | ||||
| 		) | ||||
| 		return sOut, returnOut | ||||
| 	} | ||||
| 	return mock.GetUUIDFunc() | ||||
| } | ||||
| 
 | ||||
| // GetUUIDCalls gets all the calls that were made to GetUUID.
 | ||||
| // Check the length with:
 | ||||
| //
 | ||||
| //	len(mockednvmlUUIDer.GetUUIDCalls())
 | ||||
| func (mock *nvmlUUIDerMock) GetUUIDCalls() []struct { | ||||
| } { | ||||
| 	var calls []struct { | ||||
| 	} | ||||
| 	mock.lockGetUUID.RLock() | ||||
| 	calls = mock.calls.GetUUID | ||||
| 	mock.lockGetUUID.RUnlock() | ||||
| 	return calls | ||||
| } | ||||
							
								
								
									
										67
									
								
								pkg/nvcdi/namer_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										67
									
								
								pkg/nvcdi/namer_test.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,67 @@ | ||||
| /** | ||||
| # Copyright (c) NVIDIA CORPORATION.  All rights reserved. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| **/ | ||||
| 
 | ||||
| package nvcdi | ||||
| 
 | ||||
| import ( | ||||
| 	"testing" | ||||
| 
 | ||||
| 	"github.com/stretchr/testify/require" | ||||
| 	"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml" | ||||
| ) | ||||
| 
 | ||||
| func TestConvert(t *testing.T) { | ||||
| 	testCases := []struct { | ||||
| 		description   string | ||||
| 		nvml          nvmlUUIDer | ||||
| 		expectedError error | ||||
| 		expecteUUID   string | ||||
| 	}{ | ||||
| 		{ | ||||
| 			description:   "empty UUIDer returns error", | ||||
| 			expectedError: errUUIDUnsupported, | ||||
| 			expecteUUID:   "", | ||||
| 		}, | ||||
| 		{ | ||||
| 			description: "nvmlUUIDer returns UUID", | ||||
| 			nvml: &nvmlUUIDerMock{ | ||||
| 				GetUUIDFunc: func() (string, nvml.Return) { | ||||
| 					return "SOME_UUID", nvml.SUCCESS | ||||
| 				}, | ||||
| 			}, | ||||
| 			expectedError: nil, | ||||
| 			expecteUUID:   "SOME_UUID", | ||||
| 		}, | ||||
| 		{ | ||||
| 			description: "nvmlUUIDer returns error", | ||||
| 			nvml: &nvmlUUIDerMock{ | ||||
| 				GetUUIDFunc: func() (string, nvml.Return) { | ||||
| 					return "SOME_UUID", nvml.ERROR_UNKNOWN | ||||
| 				}, | ||||
| 			}, | ||||
| 			expectedError: nvml.ERROR_UNKNOWN, | ||||
| 			expecteUUID:   "", | ||||
| 		}, | ||||
| 	} | ||||
| 
 | ||||
| 	for _, tc := range testCases { | ||||
| 		t.Run(tc.description, func(t *testing.T) { | ||||
| 			uuid, err := convert{tc.nvml}.GetUUID() | ||||
| 			require.ErrorIs(t, err, tc.expectedError) | ||||
| 			require.Equal(t, tc.expecteUUID, uuid) | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user