diff --git a/test/container/toolkit_test.sh b/test/container/toolkit_test.sh index 5fcc4a08..7a21169c 100644 --- a/test/container/toolkit_test.sh +++ b/test/container/toolkit_test.sh @@ -31,6 +31,8 @@ testing::toolkit::install() { test -L "${shared_dir}/usr/local/nvidia/toolkit/libnvidia-container.so.1" test -e "$(${READLINK} -f "${shared_dir}/usr/local/nvidia/toolkit/libnvidia-container.so.1")" + test -L "${shared_dir}/usr/local/nvidia/toolkit/libnvidia-container-go.so.1" + test -e "$(${READLINK} -f "${shared_dir}/usr/local/nvidia/toolkit/libnvidia-container-go.so.1")" test -e "${shared_dir}/usr/local/nvidia/toolkit/nvidia-container-cli" test -e "${shared_dir}/usr/local/nvidia/toolkit/nvidia-container-toolkit" diff --git a/tools/container/toolkit/toolkit.go b/tools/container/toolkit/toolkit.go index 6aa3fc3b..bcc8f805 100644 --- a/tools/container/toolkit/toolkit.go +++ b/tools/container/toolkit/toolkit.go @@ -152,7 +152,7 @@ func Install(cli *cli.Context) error { return fmt.Errorf("could not create required directories: %v", err) } - err = installContainerLibrary(toolkitDirArg) + err = installContainerLibraries(toolkitDirArg) if err != nil { return fmt.Errorf("error installing NVIDIA container library: %v", err) } @@ -180,14 +180,31 @@ func Install(cli *cli.Context) error { return nil } -// installContainerLibrary locates and installs the libnvidia-container.so.1 library. +// installContainerLibraries locates and installs the libraries that are part of +// the nvidia-container-toolkit. // A predefined set of library candidates are considered, with the first one // resulting in success being installed to the toolkit folder. The install process // resolves the symlink for the library and copies the versioned library itself. -func installContainerLibrary(toolkitDir string) error { +func installContainerLibraries(toolkitDir string) error { log.Infof("Installing NVIDIA container library to '%v'", toolkitDir) - const libName = "libnvidia-container.so.1" + libs := []string{ + "libnvidia-container.so.1", + "libnvidia-container-go.so.1", + } + + for _, l := range libs { + err := installLibrary(l, toolkitDir) + if err != nil { + return fmt.Errorf("failed to install %s: %v", l, err) + } + } + + return nil +} + +// installLibrary installs the specified library to the toolkit directory. +func installLibrary(libName string, toolkitDir string) error { libraryPath, err := findLibrary("", libName) if err != nil { return fmt.Errorf("error locating NVIDIA container library: %v", err)