From 23f1ba3e9355d066b2dc6077b2d4130579e165e5 Mon Sep 17 00:00:00 2001 From: Christopher Desiniotis Date: Wed, 16 Oct 2024 14:50:59 -0700 Subject: [PATCH] Add unit tests for create-symlinks hook Signed-off-by: Christopher Desiniotis Signed-off-by: Evan Lezar --- .../create-symlinks/create-symlinks.go | 11 +++ .../create-symlinks/create-symlinks_test.go | 96 +++++++++++++++++++ 2 files changed, 107 insertions(+) create mode 100644 cmd/nvidia-cdi-hook/create-symlinks/create-symlinks_test.go diff --git a/cmd/nvidia-cdi-hook/create-symlinks/create-symlinks.go b/cmd/nvidia-cdi-hook/create-symlinks/create-symlinks.go index 1526b049..dd9b5710 100644 --- a/cmd/nvidia-cdi-hook/create-symlinks/create-symlinks.go +++ b/cmd/nvidia-cdi-hook/create-symlinks/create-symlinks.go @@ -107,6 +107,17 @@ func (m command) run(c *cli.Context, cfg *config) error { return nil } +// createLink creates a symbolic link in the specified container root. +// This is equivalent to: +// +// chroot {{ .containerRoot }} ln -s {{ .target }} {{ .link }} +// +// If the specified link already exists and points to the same target, this +// operation is a no-op. If the link points to a different target, an error is +// returned. +// +// Note that if the link path resolves to an absolute path oudside of the +// specified root, this is treated as an absolute path in this root. func (m command) createLink(containerRoot string, targetPath string, link string) error { linkPath := filepath.Join(containerRoot, link) diff --git a/cmd/nvidia-cdi-hook/create-symlinks/create-symlinks_test.go b/cmd/nvidia-cdi-hook/create-symlinks/create-symlinks_test.go new file mode 100644 index 00000000..7ce481cf --- /dev/null +++ b/cmd/nvidia-cdi-hook/create-symlinks/create-symlinks_test.go @@ -0,0 +1,96 @@ +package symlinks + +import ( + "os" + "path/filepath" + "testing" + + testlog "github.com/sirupsen/logrus/hooks/test" + + "github.com/stretchr/testify/require" + + "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/symlinks" +) + +func TestCreateLinkRelativePath(t *testing.T) { + tmpDir := t.TempDir() + hostRoot := filepath.Join(tmpDir, "/host-root/") + containerRoot := filepath.Join(tmpDir, "/container-root") + + require.NoError(t, makeFs(hostRoot)) + require.NoError(t, makeFs(containerRoot, dirOrLink{path: "/lib/"})) + + // nvidia-cdi-hook create-symlinks --link libfoo.so.1::/lib/libfoo.so + err := getTestCommand().createLink(containerRoot, "libfoo.so.1", "/lib/libfoo.so") + require.NoError(t, err) + + target, err := symlinks.Resolve(filepath.Join(containerRoot, "/lib/libfoo.so")) + require.NoError(t, err) + require.Equal(t, "libfoo.so.1", target) +} + +func TestCreateLinkAbsolutePath(t *testing.T) { + tmpDir := t.TempDir() + hostRoot := filepath.Join(tmpDir, "/host-root/") + containerRoot := filepath.Join(tmpDir, "/container-root") + + require.NoError(t, makeFs(hostRoot)) + require.NoError(t, makeFs(containerRoot, dirOrLink{path: "/lib/"})) + + // nvidia-cdi-hook create-symlinks --link /lib/libfoo.so.1::/lib/libfoo.so + err := getTestCommand().createLink(containerRoot, "/lib/libfoo.so.1", "/lib/libfoo.so") + require.NoError(t, err) + + target, err := symlinks.Resolve(filepath.Join(containerRoot, "/lib/libfoo.so")) + require.NoError(t, err) + require.Equal(t, "/lib/libfoo.so.1", target) +} + +func TestCreateLinkAlreadyExists(t *testing.T) { + tmpDir := t.TempDir() + hostRoot := filepath.Join(tmpDir, "/host-root/") + containerRoot := filepath.Join(tmpDir, "/container-root") + + require.NoError(t, makeFs(hostRoot)) + require.NoError(t, makeFs(containerRoot, dirOrLink{path: "/lib/libfoo.so", target: "libfoo.so.1"})) + + // nvidia-cdi-hook create-symlinks --link libfoo.so.1::/lib/libfoo.so + err := getTestCommand().createLink(containerRoot, "libfoo.so.1", "/lib/libfoo.so") + require.Error(t, err) + target, err := symlinks.Resolve(filepath.Join(containerRoot, "lib/libfoo.so")) + require.NoError(t, err) + require.Equal(t, "libfoo.so.1", target) +} + +type dirOrLink struct { + path string + target string +} + +func makeFs(tmpdir string, fs ...dirOrLink) error { + if err := os.MkdirAll(tmpdir, 0o755); err != nil { + return err + } + for _, s := range fs { + s.path = filepath.Join(tmpdir, s.path) + if s.target == "" { + _ = os.MkdirAll(s.path, 0o755) + continue + } + if err := os.MkdirAll(filepath.Dir(s.path), 0o755); err != nil { + return err + } + if err := os.Symlink(s.target, s.path); err != nil && !os.IsExist(err) { + return err + } + } + return nil +} + +// getTestCommand creates a command for running tests against. +func getTestCommand() *command { + logger, _ := testlog.NewNullLogger() + return &command{ + logger: logger, + } +}