diff --git a/cmd/nvidia-ctk/cdi/generate/ipc.go b/cmd/nvidia-ctk/cdi/generate/ipc.go index 0f5157ee..e7c940b1 100644 --- a/cmd/nvidia-ctk/cdi/generate/ipc.go +++ b/cmd/nvidia-ctk/cdi/generate/ipc.go @@ -26,7 +26,10 @@ import ( func NewIPCDiscoverer(logger *logrus.Logger, root string) (discover.Discover, error) { d := discover.NewMounts( logger, - lookup.NewFileLocator(logger, root), + lookup.NewFileLocator( + lookup.WithLogger(logger), + lookup.WithRoot(root), + ), root, []string{ "/var/run/nvidia-persistenced/socket", diff --git a/internal/discover/gds.go b/internal/discover/gds.go index dab2f735..7da5e07d 100644 --- a/internal/discover/gds.go +++ b/internal/discover/gds.go @@ -45,7 +45,10 @@ func NewGDSDiscoverer(logger *logrus.Logger, root string) (Discover, error) { cufile := NewMounts( logger, - lookup.NewFileLocator(logger, root), + lookup.NewFileLocator( + lookup.WithLogger(logger), + lookup.WithRoot(root), + ), root, []string{"/etc/cufile.json"}, ) diff --git a/internal/discover/graphics.go b/internal/discover/graphics.go index c24cb4ad..ab858405 100644 --- a/internal/discover/graphics.go +++ b/internal/discover/graphics.go @@ -70,17 +70,18 @@ func NewGraphicsMountsDiscoverer(logger *logrus.Logger, root string) (Discover, jsonMounts := NewMounts( logger, - lookup.NewFileLocator(logger, root), + lookup.NewFileLocator( + lookup.WithLogger(logger), + lookup.WithRoot(root), + lookup.WithSearchPaths("/etc", "/usr/share"), + ), root, []string{ - // TODO: We should handle this more cleanly - "/etc/glvnd/egl_vendor.d/10_nvidia.json", - "/etc/vulkan/icd.d/nvidia_icd.json", - "/etc/vulkan/implicit_layer.d/nvidia_layers.json", - "/usr/share/glvnd/egl_vendor.d/10_nvidia.json", - "/usr/share/vulkan/icd.d/nvidia_icd.json", - "/usr/share/vulkan/implicit_layer.d/nvidia_layers.json", - "/usr/share/egl/egl_external_platform.d/15_nvidia_gbm.json", + "glvnd/egl_vendor.d/10_nvidia.json", + "vulkan/icd.d/nvidia_icd.json", + "vulkan/implicit_layer.d/nvidia_layers.json", + "egl/egl_external_platform.d/15_nvidia_gbm.json", + "egl/egl_external_platform.d/10_nvidia_wayland.json", }, ) @@ -164,7 +165,10 @@ func (d drmDevicesByPath) getSpecificLinkArgs(devices []Device) ([]string, error selectedDevices[filepath.Base(d.HostPath)] = true } - linkLocator := lookup.NewFileLocator(d.logger, d.root) + linkLocator := lookup.NewFileLocator( + lookup.WithLogger(d.logger), + lookup.WithRoot(d.root), + ) candidates, err := linkLocator.Locate("/dev/dri/by-path/pci-*-*") if err != nil { d.logger.Warningf("Failed to locate by-path links: %v; ignoring", err) diff --git a/internal/lookup/device.go b/internal/lookup/device.go index 29738cc7..dfa774ea 100644 --- a/internal/lookup/device.go +++ b/internal/lookup/device.go @@ -19,7 +19,6 @@ package lookup import ( "fmt" "os" - "path/filepath" "github.com/sirupsen/logrus" ) @@ -31,13 +30,12 @@ const ( // NewCharDeviceLocator creates a Locator that can be used to find char devices at the specified root. A logger is // also specified. func NewCharDeviceLocator(logger *logrus.Logger, root string) Locator { - l := file{ - logger: logger, - prefixes: []string{root, filepath.Join(root, devRoot)}, - filter: assertCharDevice, - } - - return &l + return NewFileLocator( + WithLogger(logger), + WithRoot(root), + WithSearchPaths("", devRoot), + WithFilter(assertCharDevice), + ) } // assertCharDevice checks whether the specified path is a char device and returns an error if this is not the case. diff --git a/internal/lookup/device_test.go b/internal/lookup/device_test.go new file mode 100644 index 00000000..7062e51e --- /dev/null +++ b/internal/lookup/device_test.go @@ -0,0 +1,55 @@ +/** +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package lookup + +import ( + "fmt" + "testing" + + testlog "github.com/sirupsen/logrus/hooks/test" + "github.com/stretchr/testify/require" +) + +func TestCharDeviceLocator(t *testing.T) { + logger, _ := testlog.NewNullLogger() + + testCases := []struct { + root string + expectedPrefixes []string + }{ + { + root: "", + expectedPrefixes: []string{"", "/dev"}, + }, + { + root: "/", + expectedPrefixes: []string{"/", "/dev"}, + }, + { + root: "/some/root", + expectedPrefixes: []string{"/some/root", "/some/root/dev"}, + }, + } + + for i, tc := range testCases { + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + f := NewCharDeviceLocator(logger, tc.root).(*file) + + require.EqualValues(t, tc.expectedPrefixes, f.prefixes) + }) + } +} diff --git a/internal/lookup/dir.go b/internal/lookup/dir.go index 695025a5..87b32507 100644 --- a/internal/lookup/dir.go +++ b/internal/lookup/dir.go @@ -26,13 +26,11 @@ import ( // NewDirectoryLocator creates a Locator that can be used to find directories at the specified root. A logger // is also specified. func NewDirectoryLocator(logger *log.Logger, root string) Locator { - l := file{ - logger: logger, - prefixes: []string{root}, - filter: assertDirectory, - } - - return &l + return NewFileLocator( + WithLogger(logger), + WithRoot(root), + WithFilter(assertDirectory), + ) } // assertDirectory checks wither the specified path is a directory. diff --git a/internal/lookup/executable.go b/internal/lookup/executable.go index 2bfe04ea..e7ddf450 100644 --- a/internal/lookup/executable.go +++ b/internal/lookup/executable.go @@ -19,7 +19,6 @@ package lookup import ( "fmt" "os" - "path/filepath" "strings" log "github.com/sirupsen/logrus" @@ -33,17 +32,21 @@ type executable struct { func NewExecutableLocator(logger *log.Logger, root string) Locator { paths := GetPaths(root) - var prefixes []string - for _, dir := range paths { - prefixes = append(prefixes, filepath.Join(root, dir)) - } + return newExecutableLocator(logger, root, paths...) +} + +func newExecutableLocator(logger *log.Logger, root string, paths ...string) *executable { + f := newFileLocator( + WithLogger(logger), + WithRoot(root), + WithSearchPaths(paths...), + WithFilter(assertExecutable), + ) + l := executable{ - file: file{ - logger: logger, - prefixes: prefixes, - filter: assertExecutable, - }, + file: *f, } + return &l } diff --git a/internal/lookup/executable_test.go b/internal/lookup/executable_test.go new file mode 100644 index 00000000..09bbf2d9 --- /dev/null +++ b/internal/lookup/executable_test.go @@ -0,0 +1,77 @@ +/** +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package lookup + +import ( + "fmt" + "testing" + + testlog "github.com/sirupsen/logrus/hooks/test" + "github.com/stretchr/testify/require" +) + +func TestExecutableLocator(t *testing.T) { + logger, _ := testlog.NewNullLogger() + + testCases := []struct { + root string + paths []string + expectedPrefixes []string + }{ + { + root: "", + expectedPrefixes: []string{""}, + }, + { + root: "", + paths: []string{"/"}, + expectedPrefixes: []string{"/"}, + }, + { + root: "", + paths: []string{"/", "/bin"}, + expectedPrefixes: []string{"/", "/bin"}, + }, + { + root: "/", + expectedPrefixes: []string{"/"}, + }, + { + root: "/", + paths: []string{"/"}, + expectedPrefixes: []string{"/"}, + }, + { + root: "/", + paths: []string{"/", "/bin"}, + expectedPrefixes: []string{"/", "/bin"}, + }, + { + root: "/some/path", + paths: []string{"/", "/bin"}, + expectedPrefixes: []string{"/some/path", "/some/path/bin"}, + }, + } + + for i, tc := range testCases { + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + e := newExecutableLocator(logger, tc.root, tc.paths...) + + require.EqualValues(t, tc.expectedPrefixes, e.prefixes) + }) + } +} diff --git a/internal/lookup/file.go b/internal/lookup/file.go index ab52d03d..94702174 100644 --- a/internal/lookup/file.go +++ b/internal/lookup/file.go @@ -28,26 +28,98 @@ import ( // prefixes. The validity of a file is determined by a filter function. type file struct { logger *log.Logger + root string prefixes []string filter func(string) error } -// NewFileLocator creates a Locator that can be used to find files at the specified root. A logger -// can also be specified. -func NewFileLocator(logger *log.Logger, root string) Locator { - l := newFileLocator(logger, root) +// Option defines a function for passing options to the NewFileLocator() call +type Option func(*file) - return &l +// WithRoot sets the root for the file locator +func WithRoot(root string) Option { + return func(f *file) { + f.root = root + } } -func newFileLocator(logger *log.Logger, root string) file { - return file{ - logger: logger, - prefixes: []string{root}, - filter: assertFile, +// WithLogger sets the logger for the file locator +func WithLogger(logger *log.Logger) Option { + return func(f *file) { + f.logger = logger } } +// WithSearchPaths sets the search paths for the file locator. +func WithSearchPaths(paths ...string) Option { + return func(f *file) { + f.prefixes = paths + } +} + +// WithFilter sets the filter for the file locator +// The filter is called for each candidate file and candidates that return nil are considered. +func WithFilter(assert func(string) error) Option { + return func(f *file) { + f.filter = assert + } +} + +// NewFileLocator creates a Locator that can be used to find files with the specified options. +func NewFileLocator(opts ...Option) Locator { + return newFileLocator(opts...) +} + +func newFileLocator(opts ...Option) *file { + f := &file{} + for _, opt := range opts { + opt(f) + } + if f.logger == nil { + f.logger = log.StandardLogger() + } + if f.filter == nil { + f.filter = assertFile + } + // Since the `Locate` implementations rely on the root already being specified we update + // the prefixes to include the root. + f.prefixes = getSearchPrefixes(f.root, f.prefixes...) + + return f +} + +// getSearchPrefixes generates a list of unique paths to be searched by a file locator. +// +// For each of the unique prefixes
specified, the path is searched, where