Merge branch 'device-namer' into 'main'

Refactor device namer

See merge request nvidia/container-toolkit/container-toolkit!453
This commit is contained in:
Evan Lezar 2023-07-18 14:16:01 +00:00
commit 80a78e60d1
6 changed files with 193 additions and 16 deletions

View File

@ -39,7 +39,7 @@ func (l *nvmllib) GetGPUDeviceSpecs(i int, d device.Device) (*specs.Device, erro
return nil, fmt.Errorf("failed to get edits for device: %v", err)
}
name, err := l.deviceNamer.GetDeviceName(i, d)
name, err := l.deviceNamer.GetDeviceName(i, convert{d})
if err != nil {
return nil, fmt.Errorf("failed to get device name: %v", err)
}

View File

@ -53,8 +53,13 @@ func (l *csvlib) GetAllDeviceSpecs() ([]specs.Device, error) {
return nil, fmt.Errorf("failed to create container edits for CSV files: %v", err)
}
name, err := l.deviceNamer.GetDeviceName(0, uuidUnsupported{})
if err != nil {
return nil, fmt.Errorf("failed to get device name: %v", err)
}
deviceSpec := specs.Device{
Name: "all",
Name: name,
ContainerEdits: *e.ContainerEdits,
}
return []specs.Device{deviceSpec}, nil

View File

@ -36,7 +36,7 @@ func (l *nvmllib) GetMIGDeviceSpecs(i int, d device.Device, j int, mig device.Mi
return nil, fmt.Errorf("failed to get edits for device: %v", err)
}
name, err := l.deviceNamer.GetMigDeviceName(i, d, j, mig)
name, err := l.deviceNamer.GetMigDeviceName(i, convert{d}, j, convert{mig})
if err != nil {
return nil, fmt.Errorf("failed to get device name: %v", err)
}

View File

@ -17,16 +17,21 @@
package nvcdi
import (
"errors"
"fmt"
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/device"
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml"
)
// UUIDer is an interface for getting UUIDs.
type UUIDer interface {
GetUUID() (string, error)
}
// DeviceNamer is an interface for getting device names
type DeviceNamer interface {
GetDeviceName(int, device.Device) (string, error)
GetMigDeviceName(int, device.Device, int, device.MigDevice) (string, error)
GetDeviceName(int, UUIDer) (string, error)
GetMigDeviceName(int, UUIDer, int, UUIDer) (string, error)
}
// Supported device naming strategies
@ -61,29 +66,57 @@ func NewDeviceNamer(strategy string) (DeviceNamer, error) {
}
// GetDeviceName returns the name for the specified device based on the naming strategy
func (s deviceNameIndex) GetDeviceName(i int, d device.Device) (string, error) {
func (s deviceNameIndex) GetDeviceName(i int, _ UUIDer) (string, error) {
return fmt.Sprintf("%s%d", s.gpuPrefix, i), nil
}
// GetMigDeviceName returns the name for the specified device based on the naming strategy
func (s deviceNameIndex) GetMigDeviceName(i int, d device.Device, j int, mig device.MigDevice) (string, error) {
func (s deviceNameIndex) GetMigDeviceName(i int, _ UUIDer, j int, _ UUIDer) (string, error) {
return fmt.Sprintf("%s%d:%d", s.migPrefix, i, j), nil
}
// GetDeviceName returns the name for the specified device based on the naming strategy
func (s deviceNameUUID) GetDeviceName(i int, d device.Device) (string, error) {
uuid, ret := d.GetUUID()
if ret != nvml.SUCCESS {
return "", fmt.Errorf("failed to get device UUID: %v", ret)
func (s deviceNameUUID) GetDeviceName(i int, d UUIDer) (string, error) {
uuid, err := d.GetUUID()
if err != nil {
return "", fmt.Errorf("failed to get device UUID: %v", err)
}
return uuid, nil
}
// GetMigDeviceName returns the name for the specified device based on the naming strategy
func (s deviceNameUUID) GetMigDeviceName(i int, d device.Device, j int, mig device.MigDevice) (string, error) {
uuid, ret := mig.GetUUID()
if ret != nvml.SUCCESS {
return "", fmt.Errorf("failed to get device UUID: %v", ret)
func (s deviceNameUUID) GetMigDeviceName(i int, _ UUIDer, j int, mig UUIDer) (string, error) {
uuid, err := mig.GetUUID()
if err != nil {
return "", fmt.Errorf("failed to get device UUID: %v", err)
}
return uuid, nil
}
//go:generate moq -stub -out namer_nvml_mock.go . nvmlUUIDer
type nvmlUUIDer interface {
GetUUID() (string, nvml.Return)
}
type convert struct {
nvmlUUIDer
}
type uuidUnsupported struct{}
func (m convert) GetUUID() (string, error) {
if m.nvmlUUIDer == nil {
return uuidUnsupported{}.GetUUID()
}
uuid, ret := m.nvmlUUIDer.GetUUID()
if ret != nvml.SUCCESS {
return "", ret
}
return uuid, nil
}
var errUUIDUnsupported = errors.New("GetUUID is not supported")
func (m uuidUnsupported) GetUUID() (string, error) {
return "", errUUIDUnsupported
}

View File

@ -0,0 +1,72 @@
// Code generated by moq; DO NOT EDIT.
// github.com/matryer/moq
package nvcdi
import (
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml"
"sync"
)
// Ensure, that nvmlUUIDerMock does implement nvmlUUIDer.
// If this is not the case, regenerate this file with moq.
var _ nvmlUUIDer = &nvmlUUIDerMock{}
// nvmlUUIDerMock is a mock implementation of nvmlUUIDer.
//
// func TestSomethingThatUsesnvmlUUIDer(t *testing.T) {
//
// // make and configure a mocked nvmlUUIDer
// mockednvmlUUIDer := &nvmlUUIDerMock{
// GetUUIDFunc: func() (string, nvml.Return) {
// panic("mock out the GetUUID method")
// },
// }
//
// // use mockednvmlUUIDer in code that requires nvmlUUIDer
// // and then make assertions.
//
// }
type nvmlUUIDerMock struct {
// GetUUIDFunc mocks the GetUUID method.
GetUUIDFunc func() (string, nvml.Return)
// calls tracks calls to the methods.
calls struct {
// GetUUID holds details about calls to the GetUUID method.
GetUUID []struct {
}
}
lockGetUUID sync.RWMutex
}
// GetUUID calls GetUUIDFunc.
func (mock *nvmlUUIDerMock) GetUUID() (string, nvml.Return) {
callInfo := struct {
}{}
mock.lockGetUUID.Lock()
mock.calls.GetUUID = append(mock.calls.GetUUID, callInfo)
mock.lockGetUUID.Unlock()
if mock.GetUUIDFunc == nil {
var (
sOut string
returnOut nvml.Return
)
return sOut, returnOut
}
return mock.GetUUIDFunc()
}
// GetUUIDCalls gets all the calls that were made to GetUUID.
// Check the length with:
//
// len(mockednvmlUUIDer.GetUUIDCalls())
func (mock *nvmlUUIDerMock) GetUUIDCalls() []struct {
} {
var calls []struct {
}
mock.lockGetUUID.RLock()
calls = mock.calls.GetUUID
mock.lockGetUUID.RUnlock()
return calls
}

67
pkg/nvcdi/namer_test.go Normal file
View File

@ -0,0 +1,67 @@
/**
# Copyright (c) NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package nvcdi
import (
"testing"
"github.com/stretchr/testify/require"
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml"
)
func TestConvert(t *testing.T) {
testCases := []struct {
description string
nvml nvmlUUIDer
expectedError error
expecteUUID string
}{
{
description: "empty UUIDer returns error",
expectedError: errUUIDUnsupported,
expecteUUID: "",
},
{
description: "nvmlUUIDer returns UUID",
nvml: &nvmlUUIDerMock{
GetUUIDFunc: func() (string, nvml.Return) {
return "SOME_UUID", nvml.SUCCESS
},
},
expectedError: nil,
expecteUUID: "SOME_UUID",
},
{
description: "nvmlUUIDer returns error",
nvml: &nvmlUUIDerMock{
GetUUIDFunc: func() (string, nvml.Return) {
return "SOME_UUID", nvml.ERROR_UNKNOWN
},
},
expectedError: nvml.ERROR_UNKNOWN,
expecteUUID: "",
},
}
for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) {
uuid, err := convert{tc.nvml}.GetUUID()
require.ErrorIs(t, err, tc.expectedError)
require.Equal(t, tc.expecteUUID, uuid)
})
}
}