diff --git a/internal/discover/csv.go b/internal/discover/csv.go new file mode 100644 index 00000000..987da988 --- /dev/null +++ b/internal/discover/csv.go @@ -0,0 +1,98 @@ +/** +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package discover + +import ( + "fmt" + + "github.com/NVIDIA/nvidia-container-toolkit/internal/discover/csv" + "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup" + "github.com/sirupsen/logrus" +) + +type csvDiscoverer struct { + mounts + filename string + mountType csv.MountSpecType +} + +var _ Discover = (*csvDiscoverer)(nil) + +// NewFromCSV creates a discoverer for the CSV files at the specified root. A logger is also supplied. +func NewFromCSV(logger *logrus.Logger, csvRoot string, root string) (Discover, error) { + logger.Debugf("Loading CSV files from: %v", csvRoot) + + files, err := csv.GetFileList(csvRoot) + if err != nil { + return nil, fmt.Errorf("failed to get CSV file from %v: %v", csvRoot, err) + } + if len(files) == 0 { + logger.Warnf("No CSV files found in %v", csvRoot) + return None{}, nil + } + + locators := make(map[csv.MountSpecType]lookup.Locator) + locators[csv.MountSpecDev] = lookup.NewCharDeviceLocator(logger, root) + locators[csv.MountSpecDir] = lookup.NewDirectoryLocator(logger, root) + // Libraries and symlinks are handled in the same way + symlinkLocator := lookup.NewSymlinkLocator(logger, root) + locators[csv.MountSpecLib] = symlinkLocator + locators[csv.MountSpecSym] = symlinkLocator + + var discoverers []Discover + // Create a discoverer for each file-kind combination + for _, file := range files { + targets, err := csv.ParseFile(logger, file) + if err != nil { + logger.Warnf("Skipping failed CSV file %v: %v", file, err) + continue + } + if len(targets) == 0 { + logger.Warnf("Skipping empty CSV file %v", file) + continue + } + + candidatesByType := make(map[csv.MountSpecType][]string) + for _, t := range targets { + candidatesByType[t.Type] = append(candidatesByType[t.Type], t.Path) + } + + for t, candidates := range candidatesByType { + d := csvDiscoverer{ + filename: file, + mountType: t, + mounts: mounts{ + logger: logger, + lookup: locators[t], + required: candidates, + }, + } + discoverers = append(discoverers, &d) + } + + } + + return &list{discoverers: discoverers}, nil +} + +func (d csvDiscoverer) Mounts() ([]Mount, error) { + if d.mountType == csv.MountSpecDev { + return d.None.Mounts() + } + + return d.mounts.Mounts() +} diff --git a/internal/discover/csv_test.go b/internal/discover/csv_test.go new file mode 100644 index 00000000..8970d62e --- /dev/null +++ b/internal/discover/csv_test.go @@ -0,0 +1,17 @@ +/** +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package discover diff --git a/internal/discover/discover.go b/internal/discover/discover.go index 4858b282..767a632a 100644 --- a/internal/discover/discover.go +++ b/internal/discover/discover.go @@ -16,6 +16,11 @@ package discover +// Mount represents a discovered mount. +type Mount struct { + Path string +} + // Hook represents a discovered hook. type Hook struct { Lifecycle string @@ -24,7 +29,8 @@ type Hook struct { } //go:generate moq -stub -out discover_mock.go . Discover -// Discover defines an interface for discovering the devices, mounts, and hooks available on a system +// Discover defines an interface for discovering the hooks and mounts available on a system type Discover interface { + Mounts() ([]Mount, error) Hooks() ([]Hook, error) } diff --git a/internal/discover/discover_mock.go b/internal/discover/discover_mock.go index 0a0f211f..488ba7f4 100644 --- a/internal/discover/discover_mock.go +++ b/internal/discover/discover_mock.go @@ -20,6 +20,9 @@ var _ Discover = &DiscoverMock{} // HooksFunc: func() ([]Hook, error) { // panic("mock out the Hooks method") // }, +// MountsFunc: func() ([]Mount, error) { +// panic("mock out the Mounts method") +// }, // } // // // use mockedDiscover in code that requires Discover @@ -30,13 +33,20 @@ type DiscoverMock struct { // HooksFunc mocks the Hooks method. HooksFunc func() ([]Hook, error) + // MountsFunc mocks the Mounts method. + MountsFunc func() ([]Mount, error) + // calls tracks calls to the methods. calls struct { // Hooks holds details about calls to the Hooks method. Hooks []struct { } + // Mounts holds details about calls to the Mounts method. + Mounts []struct { + } } - lockHooks sync.RWMutex + lockHooks sync.RWMutex + lockMounts sync.RWMutex } // Hooks calls HooksFunc. @@ -68,3 +78,33 @@ func (mock *DiscoverMock) HooksCalls() []struct { mock.lockHooks.RUnlock() return calls } + +// Mounts calls MountsFunc. +func (mock *DiscoverMock) Mounts() ([]Mount, error) { + callInfo := struct { + }{} + mock.lockMounts.Lock() + mock.calls.Mounts = append(mock.calls.Mounts, callInfo) + mock.lockMounts.Unlock() + if mock.MountsFunc == nil { + var ( + mountsOut []Mount + errOut error + ) + return mountsOut, errOut + } + return mock.MountsFunc() +} + +// MountsCalls gets all the calls that were made to Mounts. +// Check the length with: +// len(mockedDiscover.MountsCalls()) +func (mock *DiscoverMock) MountsCalls() []struct { +} { + var calls []struct { + } + mock.lockMounts.RLock() + calls = mock.calls.Mounts + mock.lockMounts.RUnlock() + return calls +} diff --git a/internal/discover/legacy.go b/internal/discover/legacy.go index 53f6ffc1..7f4c5276 100644 --- a/internal/discover/legacy.go +++ b/internal/discover/legacy.go @@ -23,6 +23,7 @@ import ( ) type legacy struct { + None logger *logrus.Logger lookup lookup.Locator } diff --git a/internal/discover/list.go b/internal/discover/list.go new file mode 100644 index 00000000..633b2260 --- /dev/null +++ b/internal/discover/list.go @@ -0,0 +1,58 @@ +/* +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +*/ + +package discover + +import "fmt" + +// list is a discoverer that contains a list of Discoverers. The output of the +// Mounts functions is the concatenation of the output for each of the +// elements in the list. +type list struct { + discoverers []Discover +} + +var _ Discover = (*list)(nil) + +// Mounts returns all mounts from the included discoverers +func (d list) Mounts() ([]Mount, error) { + var allMounts []Mount + + for i, di := range d.discoverers { + mounts, err := di.Mounts() + if err != nil { + return nil, fmt.Errorf("error discovering mounts for discoverer %v: %v", i, err) + } + allMounts = append(allMounts, mounts...) + } + + return allMounts, nil +} + +// Hooks returns all Hooks from the included discoverers +func (d list) Hooks() ([]Hook, error) { + var allHooks []Hook + + for i, di := range d.discoverers { + hooks, err := di.Hooks() + if err != nil { + return nil, fmt.Errorf("error discovering hooks for discoverer %v: %v", i, err) + } + allHooks = append(allHooks, hooks...) + } + + return allHooks, nil +} diff --git a/internal/discover/mounts.go b/internal/discover/mounts.go new file mode 100644 index 00000000..f294e522 --- /dev/null +++ b/internal/discover/mounts.go @@ -0,0 +1,72 @@ +/* +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +*/ + +package discover + +import ( + "fmt" + + "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup" + "github.com/sirupsen/logrus" +) + +// mounts is a generic discoverer for Mounts. It is customized by specifying the +// required entities as a list and a Locator that is used to find the target mounts +// based on the entry in the list. +type mounts struct { + None + logger *logrus.Logger + lookup lookup.Locator + required []string +} + +var _ Discover = (*mounts)(nil) + +func (d mounts) Mounts() ([]Mount, error) { + if d.lookup == nil { + return nil, fmt.Errorf("no lookup defined") + } + + paths := make(map[string]bool) + + for _, candidate := range d.required { + d.logger.Debugf("Locating %v", candidate) + located, err := d.lookup.Locate(candidate) + if err != nil { + d.logger.Warnf("Could not locate %v: %v", candidate, err) + continue + } + if len(located) == 0 { + d.logger.Warnf("Missing %v", candidate) + continue + } + d.logger.Debugf("Located %v as %v", candidate, located) + for _, p := range located { + paths[p] = true + } + } + + var mounts []Mount + for path := range paths { + d.logger.Infof("Selecting %v", path) + mount := Mount{ + Path: path, + } + mounts = append(mounts, mount) + } + + return mounts, nil +} diff --git a/internal/discover/mounts_test.go b/internal/discover/mounts_test.go new file mode 100644 index 00000000..02165a82 --- /dev/null +++ b/internal/discover/mounts_test.go @@ -0,0 +1,156 @@ +/* +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +*/ + +package discover + +import ( + "fmt" + "testing" + + "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup" + "github.com/stretchr/testify/require" + + testlog "github.com/sirupsen/logrus/hooks/test" +) + +func TestMounts(t *testing.T) { + logger, logHook := testlog.NewNullLogger() + + testCases := []struct { + description string + expectedError error + expectedMounts []Mount + input mounts + }{ + { + description: "nill lookup returns error", + expectedError: fmt.Errorf("no lookup defined"), + }, + { + description: "empty required returns no mounts", + expectedError: nil, + input: mounts{ + lookup: &lookup.LocatorMock{ + LocateFunc: func(string) ([]string, error) { + return []string{"located"}, nil + }, + }, + }, + }, + { + description: "required returns located", + expectedError: nil, + input: mounts{ + lookup: &lookup.LocatorMock{ + LocateFunc: func(string) ([]string, error) { + return []string{"located"}, nil + }, + }, + required: []string{"required"}, + }, + expectedMounts: []Mount{{Path: "located"}}, + }, + { + description: "mounts removes located duplicates", + expectedError: nil, + input: mounts{ + lookup: &lookup.LocatorMock{ + LocateFunc: func(string) ([]string, error) { + return []string{"located"}, nil + }, + }, + required: []string{"required0", "required1"}, + }, + expectedMounts: []Mount{{Path: "located"}}, + }, + { + description: "mounts skips located errors", + input: mounts{ + lookup: &lookup.LocatorMock{ + LocateFunc: func(s string) ([]string, error) { + if s == "error" { + return nil, fmt.Errorf(s) + } + return []string{s}, nil + }, + }, + required: []string{"required0", "error", "required1"}, + }, + expectedMounts: []Mount{{Path: "required0"}, {Path: "required1"}}, + }, + { + description: "mounts skips unlocated", + input: mounts{ + lookup: &lookup.LocatorMock{ + LocateFunc: func(s string) ([]string, error) { + if s == "empty" { + return nil, nil + } + return []string{s}, nil + }, + }, + required: []string{"required0", "empty", "required1"}, + }, + expectedMounts: []Mount{{Path: "required0"}, {Path: "required1"}}, + }, + { + description: "mounts skips unlocated", + input: mounts{ + lookup: &lookup.LocatorMock{ + LocateFunc: func(s string) ([]string, error) { + if s == "multiple" { + return []string{"multiple0", "multiple1"}, nil + } + return []string{s}, nil + }, + }, + required: []string{"required0", "multiple", "required1"}, + }, + expectedMounts: []Mount{ + {Path: "required0"}, + {Path: "multiple0"}, + {Path: "multiple1"}, + {Path: "required1"}, + }, + }, + } + + for _, tc := range testCases { + logHook.Reset() + t.Run(tc.description, func(t *testing.T) { + tc.input.logger = logger + mounts, err := tc.input.Mounts() + + if tc.expectedError != nil { + require.Error(t, err) + } else { + require.NoError(t, err) + } + require.ElementsMatch(t, tc.expectedMounts, mounts) + + // We check that the mock is called for each element of required + if tc.input.lookup != nil { + mock := tc.input.lookup.(*lookup.LocatorMock) + require.Len(t, mock.LocateCalls(), len(tc.input.required)) + var args []string + for _, c := range mock.LocateCalls() { + args = append(args, c.S) + } + require.EqualValues(t, args, tc.input.required) + } + }) + } +} diff --git a/internal/discover/none.go b/internal/discover/none.go new file mode 100644 index 00000000..57671ef6 --- /dev/null +++ b/internal/discover/none.go @@ -0,0 +1,33 @@ +/* +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +*/ + +package discover + +// None is a null discoverer that returns an empty list of devices and +// mounts. +type None struct{} + +var _ Discover = (*None)(nil) + +// Mounts returns an empty list of mounts +func (e None) Mounts() ([]Mount, error) { + return []Mount{}, nil +} + +// Hooks returns and empty list of hooks +func (e None) Hooks() ([]Hook, error) { + return []Hook{}, nil +} diff --git a/internal/discover/none_test.go b/internal/discover/none_test.go new file mode 100644 index 00000000..d69bbd53 --- /dev/null +++ b/internal/discover/none_test.go @@ -0,0 +1,31 @@ +/* +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +*/ + +package discover + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNone(t *testing.T) { + d := None{} + + mounts, err := d.Mounts() + require.NoError(t, err) + require.Empty(t, mounts) +}