diff --git a/internal/discover/csv.go b/internal/discover/csv.go index 987da988..37130856 100644 --- a/internal/discover/csv.go +++ b/internal/discover/csv.go @@ -54,41 +54,69 @@ func NewFromCSV(logger *logrus.Logger, csvRoot string, root string) (Discover, e locators[csv.MountSpecSym] = symlinkLocator var discoverers []Discover - // Create a discoverer for each file-kind combination - for _, file := range files { - targets, err := csv.ParseFile(logger, file) + for _, filename := range files { + d, err := NewFromCSVFile(logger, locators, filename) if err != nil { - logger.Warnf("Skipping failed CSV file %v: %v", file, err) + logger.Warnf("Skipping CSV file %v: %v", filename, err) continue } - if len(targets) == 0 { - logger.Warnf("Skipping empty CSV file %v", file) - continue - } - - candidatesByType := make(map[csv.MountSpecType][]string) - for _, t := range targets { - candidatesByType[t.Type] = append(candidatesByType[t.Type], t.Path) - } - - for t, candidates := range candidatesByType { - d := csvDiscoverer{ - filename: file, - mountType: t, - mounts: mounts{ - logger: logger, - lookup: locators[t], - required: candidates, - }, - } - discoverers = append(discoverers, &d) - } - + discoverers = append(discoverers, d) } return &list{discoverers: discoverers}, nil } +// NewFromCSVFile creates a discoverer for the CSV file. A logger is also supplied. +func NewFromCSVFile(logger *logrus.Logger, locators map[csv.MountSpecType]lookup.Locator, filename string) (Discover, error) { + // Create a discoverer for each file-kind combination + targets, err := csv.ParseFile(logger, filename) + if err != nil { + return nil, fmt.Errorf("failed to parse CSV file: %v", err) + } + if len(targets) == 0 { + return nil, fmt.Errorf("CSV file is empty") + } + + csvDiscoverers, err := newFromMountSpecs(logger, locators, targets) + if err != nil { + return nil, err + } + var discoverers []Discover + for _, d := range csvDiscoverers { + d.filename = filename + discoverers = append(discoverers, d) + } + + return &list{discoverers: discoverers}, nil +} + +// newFromMountSpecs creates a discoverer for the CSV file. A logger is also supplied. +func newFromMountSpecs(logger *logrus.Logger, locators map[csv.MountSpecType]lookup.Locator, targets []*csv.MountSpec) ([]*csvDiscoverer, error) { + var discoverers []*csvDiscoverer + candidatesByType := make(map[csv.MountSpecType][]string) + for _, t := range targets { + candidatesByType[t.Type] = append(candidatesByType[t.Type], t.Path) + } + + for t, candidates := range candidatesByType { + locator, exists := locators[t] + if !exists { + return nil, fmt.Errorf("no locator defined for '%v'", t) + } + d := csvDiscoverer{ + mounts: mounts{ + logger: logger, + lookup: locator, + required: candidates, + }, + mountType: t, + } + discoverers = append(discoverers, &d) + } + + return discoverers, nil +} + func (d csvDiscoverer) Mounts() ([]Mount, error) { if d.mountType == csv.MountSpecDev { return d.None.Mounts() @@ -96,3 +124,21 @@ func (d csvDiscoverer) Mounts() ([]Mount, error) { return d.mounts.Mounts() } + +func (d csvDiscoverer) Devices() ([]Device, error) { + if d.mountType != csv.MountSpecDev { + return d.None.Devices() + } + + mounts, err := d.mounts.Mounts() + if err != nil { + return nil, err + } + var devices []Device + for _, mount := range mounts { + device := Device(mount) + devices = append(devices, device) + } + + return devices, nil +} diff --git a/internal/discover/csv_test.go b/internal/discover/csv_test.go index 8970d62e..f11d7af0 100644 --- a/internal/discover/csv_test.go +++ b/internal/discover/csv_test.go @@ -15,3 +15,172 @@ **/ package discover + +import ( + "fmt" + "testing" + + "github.com/NVIDIA/nvidia-container-toolkit/internal/discover/csv" + "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup" + testlog "github.com/sirupsen/logrus/hooks/test" + "github.com/stretchr/testify/require" +) + +func TestCSVDiscoverer(t *testing.T) { + logger, logHook := testlog.NewNullLogger() + + testCases := []struct { + description string + input csvDiscoverer + expectedMounts []Mount + expectedMountsError error + expectedDevicesError error + expectedDevices []Device + }{ + { + description: "dev mounts are empty", + input: csvDiscoverer{ + mounts: mounts{ + lookup: &lookup.LocatorMock{ + LocateFunc: func(string) ([]string, error) { + return []string{"located"}, nil + }, + }, + required: []string{"required"}, + }, + mountType: "dev", + }, + expectedDevices: []Device{{Path: "located"}}, + }, + { + description: "dev devices returns error for nil lookup", + input: csvDiscoverer{ + mountType: "dev", + }, + expectedDevicesError: fmt.Errorf("no lookup defined"), + }, + { + description: "lib devices are empty", + input: csvDiscoverer{ + mounts: mounts{ + lookup: &lookup.LocatorMock{ + LocateFunc: func(string) ([]string, error) { + return []string{"located"}, nil + }, + }, + required: []string{"required"}, + }, + mountType: "lib", + }, + expectedMounts: []Mount{{Path: "located"}}, + }, + { + description: "lib mounts returns error for nil lookup", + input: csvDiscoverer{ + mountType: "lib", + }, + expectedMountsError: fmt.Errorf("no lookup defined"), + }, + } + + for _, tc := range testCases { + logHook.Reset() + + t.Run(tc.description, func(t *testing.T) { + tc.input.logger = logger + + mounts, err := tc.input.Mounts() + if tc.expectedMountsError != nil { + require.Error(t, err) + } else { + require.NoError(t, err) + } + require.ElementsMatch(t, tc.expectedMounts, mounts) + + devices, err := tc.input.Devices() + if tc.expectedDevicesError != nil { + require.Error(t, err) + } else { + require.NoError(t, err) + } + require.ElementsMatch(t, tc.expectedDevices, devices) + }) + } +} + +func TestNewFromMountSpec(t *testing.T) { + logger, _ := testlog.NewNullLogger() + + locators := map[csv.MountSpecType]lookup.Locator{ + "dev": &lookup.LocatorMock{}, + "lib": &lookup.LocatorMock{}, + } + + testCases := []struct { + description string + targets []*csv.MountSpec + expectedError error + expectedCSVDiscoverers []*csvDiscoverer + }{ + { + description: "empty targets returns empyt list", + }, + { + description: "unexpected locator returns error", + targets: []*csv.MountSpec{ + { + Type: "foo", + Path: "bar", + }, + }, + expectedError: fmt.Errorf("no locator defined for foo"), + }, + { + description: "creates discoverers based on type", + targets: []*csv.MountSpec{ + { + Type: "dev", + Path: "dev0", + }, + { + Type: "lib", + Path: "lib0", + }, + { + Type: "dev", + Path: "dev1", + }, + }, + expectedCSVDiscoverers: []*csvDiscoverer{ + { + mountType: "dev", + mounts: mounts{ + logger: logger, + lookup: locators["dev"], + required: []string{"dev0", "dev1"}, + }, + }, + { + mountType: "lib", + mounts: mounts{ + logger: logger, + lookup: locators["lib"], + required: []string{"lib0"}, + }, + }, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.description, func(t *testing.T) { + discoverers, err := newFromMountSpecs(logger, locators, tc.targets) + if tc.expectedError != nil { + require.Error(t, err) + return + } + require.NoError(t, err) + require.ElementsMatch(t, tc.expectedCSVDiscoverers, discoverers) + }) + } +} diff --git a/internal/discover/discover.go b/internal/discover/discover.go index 767a632a..77a8fdd5 100644 --- a/internal/discover/discover.go +++ b/internal/discover/discover.go @@ -16,6 +16,11 @@ package discover +// Device represents a discovered character device. +type Device struct { + Path string +} + // Mount represents a discovered mount. type Mount struct { Path string @@ -29,8 +34,9 @@ type Hook struct { } //go:generate moq -stub -out discover_mock.go . Discover -// Discover defines an interface for discovering the hooks and mounts available on a system +// Discover defines an interface for discovering the devices, mounts, and hooks available on a system type Discover interface { + Devices() ([]Device, error) Mounts() ([]Mount, error) Hooks() ([]Hook, error) } diff --git a/internal/discover/discover_mock.go b/internal/discover/discover_mock.go index 488ba7f4..e3f57579 100644 --- a/internal/discover/discover_mock.go +++ b/internal/discover/discover_mock.go @@ -17,6 +17,9 @@ var _ Discover = &DiscoverMock{} // // // make and configure a mocked Discover // mockedDiscover := &DiscoverMock{ +// DevicesFunc: func() ([]Device, error) { +// panic("mock out the Devices method") +// }, // HooksFunc: func() ([]Hook, error) { // panic("mock out the Hooks method") // }, @@ -30,6 +33,9 @@ var _ Discover = &DiscoverMock{} // // } type DiscoverMock struct { + // DevicesFunc mocks the Devices method. + DevicesFunc func() ([]Device, error) + // HooksFunc mocks the Hooks method. HooksFunc func() ([]Hook, error) @@ -38,6 +44,9 @@ type DiscoverMock struct { // calls tracks calls to the methods. calls struct { + // Devices holds details about calls to the Devices method. + Devices []struct { + } // Hooks holds details about calls to the Hooks method. Hooks []struct { } @@ -45,8 +54,39 @@ type DiscoverMock struct { Mounts []struct { } } - lockHooks sync.RWMutex - lockMounts sync.RWMutex + lockDevices sync.RWMutex + lockHooks sync.RWMutex + lockMounts sync.RWMutex +} + +// Devices calls DevicesFunc. +func (mock *DiscoverMock) Devices() ([]Device, error) { + callInfo := struct { + }{} + mock.lockDevices.Lock() + mock.calls.Devices = append(mock.calls.Devices, callInfo) + mock.lockDevices.Unlock() + if mock.DevicesFunc == nil { + var ( + devicesOut []Device + errOut error + ) + return devicesOut, errOut + } + return mock.DevicesFunc() +} + +// DevicesCalls gets all the calls that were made to Devices. +// Check the length with: +// len(mockedDiscover.DevicesCalls()) +func (mock *DiscoverMock) DevicesCalls() []struct { +} { + var calls []struct { + } + mock.lockDevices.RLock() + calls = mock.calls.Devices + mock.lockDevices.RUnlock() + return calls } // Hooks calls HooksFunc. diff --git a/internal/discover/list.go b/internal/discover/list.go index 633b2260..f19c8c15 100644 --- a/internal/discover/list.go +++ b/internal/discover/list.go @@ -27,6 +27,21 @@ type list struct { var _ Discover = (*list)(nil) +// Devices returns all devices from the included discoverers +func (d list) Devices() ([]Device, error) { + var allDevices []Device + + for i, di := range d.discoverers { + devices, err := di.Devices() + if err != nil { + return nil, fmt.Errorf("error discovering devices for discoverer %v: %v", i, err) + } + allDevices = append(allDevices, devices...) + } + + return allDevices, nil +} + // Mounts returns all mounts from the included discoverers func (d list) Mounts() ([]Mount, error) { var allMounts []Mount diff --git a/internal/discover/mounts_test.go b/internal/discover/mounts_test.go index 02165a82..35e68e76 100644 --- a/internal/discover/mounts_test.go +++ b/internal/discover/mounts_test.go @@ -26,6 +26,14 @@ import ( testlog "github.com/sirupsen/logrus/hooks/test" ) +func TestMountsReturnsEmptyDevices(t *testing.T) { + d := mounts{} + devices, err := d.Devices() + + require.NoError(t, err) + require.Empty(t, devices) +} + func TestMounts(t *testing.T) { logger, logHook := testlog.NewNullLogger() diff --git a/internal/discover/none.go b/internal/discover/none.go index 57671ef6..989a2e16 100644 --- a/internal/discover/none.go +++ b/internal/discover/none.go @@ -22,6 +22,11 @@ type None struct{} var _ Discover = (*None)(nil) +// Devices returns an empty list of devices +func (e None) Devices() ([]Device, error) { + return []Device{}, nil +} + // Mounts returns an empty list of mounts func (e None) Mounts() ([]Mount, error) { return []Mount{}, nil