diff --git a/pkg/nvcdi/full-gpu-nvml.go b/pkg/nvcdi/full-gpu-nvml.go index 174c5d29..65a01b2e 100644 --- a/pkg/nvcdi/full-gpu-nvml.go +++ b/pkg/nvcdi/full-gpu-nvml.go @@ -39,7 +39,7 @@ func (l *nvmllib) GetGPUDeviceSpecs(i int, d device.Device) (*specs.Device, erro return nil, fmt.Errorf("failed to get edits for device: %v", err) } - name, err := l.deviceNamer.GetDeviceName(i, d) + name, err := l.deviceNamer.GetDeviceName(i, convert{d}) if err != nil { return nil, fmt.Errorf("failed to get device name: %v", err) } diff --git a/pkg/nvcdi/mig-device-nvml.go b/pkg/nvcdi/mig-device-nvml.go index d5d2bdec..28f21892 100644 --- a/pkg/nvcdi/mig-device-nvml.go +++ b/pkg/nvcdi/mig-device-nvml.go @@ -36,7 +36,7 @@ func (l *nvmllib) GetMIGDeviceSpecs(i int, d device.Device, j int, mig device.Mi return nil, fmt.Errorf("failed to get edits for device: %v", err) } - name, err := l.deviceNamer.GetMigDeviceName(i, d, j, mig) + name, err := l.deviceNamer.GetMigDeviceName(i, convert{d}, j, convert{mig}) if err != nil { return nil, fmt.Errorf("failed to get device name: %v", err) } diff --git a/pkg/nvcdi/namer.go b/pkg/nvcdi/namer.go index e7b850da..8e443cdb 100644 --- a/pkg/nvcdi/namer.go +++ b/pkg/nvcdi/namer.go @@ -17,16 +17,21 @@ package nvcdi import ( + "errors" "fmt" - "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/device" "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml" ) +// UUIDer is an interface for getting UUIDs. +type UUIDer interface { + GetUUID() (string, error) +} + // DeviceNamer is an interface for getting device names type DeviceNamer interface { - GetDeviceName(int, device.Device) (string, error) - GetMigDeviceName(int, device.Device, int, device.MigDevice) (string, error) + GetDeviceName(int, UUIDer) (string, error) + GetMigDeviceName(int, UUIDer, int, UUIDer) (string, error) } // Supported device naming strategies @@ -61,29 +66,57 @@ func NewDeviceNamer(strategy string) (DeviceNamer, error) { } // GetDeviceName returns the name for the specified device based on the naming strategy -func (s deviceNameIndex) GetDeviceName(i int, d device.Device) (string, error) { +func (s deviceNameIndex) GetDeviceName(i int, _ UUIDer) (string, error) { return fmt.Sprintf("%s%d", s.gpuPrefix, i), nil } // GetMigDeviceName returns the name for the specified device based on the naming strategy -func (s deviceNameIndex) GetMigDeviceName(i int, d device.Device, j int, mig device.MigDevice) (string, error) { +func (s deviceNameIndex) GetMigDeviceName(i int, _ UUIDer, j int, _ UUIDer) (string, error) { return fmt.Sprintf("%s%d:%d", s.migPrefix, i, j), nil } // GetDeviceName returns the name for the specified device based on the naming strategy -func (s deviceNameUUID) GetDeviceName(i int, d device.Device) (string, error) { - uuid, ret := d.GetUUID() - if ret != nvml.SUCCESS { - return "", fmt.Errorf("failed to get device UUID: %v", ret) +func (s deviceNameUUID) GetDeviceName(i int, d UUIDer) (string, error) { + uuid, err := d.GetUUID() + if err != nil { + return "", fmt.Errorf("failed to get device UUID: %v", err) } return uuid, nil } // GetMigDeviceName returns the name for the specified device based on the naming strategy -func (s deviceNameUUID) GetMigDeviceName(i int, d device.Device, j int, mig device.MigDevice) (string, error) { - uuid, ret := mig.GetUUID() - if ret != nvml.SUCCESS { - return "", fmt.Errorf("failed to get device UUID: %v", ret) +func (s deviceNameUUID) GetMigDeviceName(i int, _ UUIDer, j int, mig UUIDer) (string, error) { + uuid, err := mig.GetUUID() + if err != nil { + return "", fmt.Errorf("failed to get device UUID: %v", err) } return uuid, nil } + +//go:generate moq -stub -out namer_nvml_mock.go . nvmlUUIDer +type nvmlUUIDer interface { + GetUUID() (string, nvml.Return) +} + +type convert struct { + nvmlUUIDer +} + +type uuidUnsupported struct{} + +func (m convert) GetUUID() (string, error) { + if m.nvmlUUIDer == nil { + return uuidUnsupported{}.GetUUID() + } + uuid, ret := m.nvmlUUIDer.GetUUID() + if ret != nvml.SUCCESS { + return "", ret + } + return uuid, nil +} + +var errUUIDUnsupported = errors.New("GetUUID is not supported") + +func (m uuidUnsupported) GetUUID() (string, error) { + return "", errUUIDUnsupported +} diff --git a/pkg/nvcdi/namer_nvml_mock.go b/pkg/nvcdi/namer_nvml_mock.go new file mode 100644 index 00000000..f87a0fcb --- /dev/null +++ b/pkg/nvcdi/namer_nvml_mock.go @@ -0,0 +1,72 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package nvcdi + +import ( + "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml" + "sync" +) + +// Ensure, that nvmlUUIDerMock does implement nvmlUUIDer. +// If this is not the case, regenerate this file with moq. +var _ nvmlUUIDer = &nvmlUUIDerMock{} + +// nvmlUUIDerMock is a mock implementation of nvmlUUIDer. +// +// func TestSomethingThatUsesnvmlUUIDer(t *testing.T) { +// +// // make and configure a mocked nvmlUUIDer +// mockednvmlUUIDer := &nvmlUUIDerMock{ +// GetUUIDFunc: func() (string, nvml.Return) { +// panic("mock out the GetUUID method") +// }, +// } +// +// // use mockednvmlUUIDer in code that requires nvmlUUIDer +// // and then make assertions. +// +// } +type nvmlUUIDerMock struct { + // GetUUIDFunc mocks the GetUUID method. + GetUUIDFunc func() (string, nvml.Return) + + // calls tracks calls to the methods. + calls struct { + // GetUUID holds details about calls to the GetUUID method. + GetUUID []struct { + } + } + lockGetUUID sync.RWMutex +} + +// GetUUID calls GetUUIDFunc. +func (mock *nvmlUUIDerMock) GetUUID() (string, nvml.Return) { + callInfo := struct { + }{} + mock.lockGetUUID.Lock() + mock.calls.GetUUID = append(mock.calls.GetUUID, callInfo) + mock.lockGetUUID.Unlock() + if mock.GetUUIDFunc == nil { + var ( + sOut string + returnOut nvml.Return + ) + return sOut, returnOut + } + return mock.GetUUIDFunc() +} + +// GetUUIDCalls gets all the calls that were made to GetUUID. +// Check the length with: +// +// len(mockednvmlUUIDer.GetUUIDCalls()) +func (mock *nvmlUUIDerMock) GetUUIDCalls() []struct { +} { + var calls []struct { + } + mock.lockGetUUID.RLock() + calls = mock.calls.GetUUID + mock.lockGetUUID.RUnlock() + return calls +} diff --git a/pkg/nvcdi/namer_test.go b/pkg/nvcdi/namer_test.go new file mode 100644 index 00000000..30169267 --- /dev/null +++ b/pkg/nvcdi/namer_test.go @@ -0,0 +1,67 @@ +/** +# Copyright (c) NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package nvcdi + +import ( + "testing" + + "github.com/stretchr/testify/require" + "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml" +) + +func TestConvert(t *testing.T) { + testCases := []struct { + description string + nvml nvmlUUIDer + expectedError error + expecteUUID string + }{ + { + description: "empty UUIDer returns error", + expectedError: errUUIDUnsupported, + expecteUUID: "", + }, + { + description: "nvmlUUIDer returns UUID", + nvml: &nvmlUUIDerMock{ + GetUUIDFunc: func() (string, nvml.Return) { + return "SOME_UUID", nvml.SUCCESS + }, + }, + expectedError: nil, + expecteUUID: "SOME_UUID", + }, + { + description: "nvmlUUIDer returns error", + nvml: &nvmlUUIDerMock{ + GetUUIDFunc: func() (string, nvml.Return) { + return "SOME_UUID", nvml.ERROR_UNKNOWN + }, + }, + expectedError: nvml.ERROR_UNKNOWN, + expecteUUID: "", + }, + } + + for _, tc := range testCases { + t.Run(tc.description, func(t *testing.T) { + uuid, err := convert{tc.nvml}.GetUUID() + require.ErrorIs(t, err, tc.expectedError) + require.Equal(t, tc.expecteUUID, uuid) + }) + } +}