From 5d011c13337c189e4726b4950f42f21be32c6c59 Mon Sep 17 00:00:00 2001 From: Evan Lezar Date: Tue, 7 Feb 2023 23:32:04 +0100 Subject: [PATCH 1/6] Add Discoverer to create a single symlink Signed-off-by: Evan Lezar --- internal/discover/hooks.go | 41 +++++++++++++++++++++++++++++++++++--- 1 file changed, 38 insertions(+), 3 deletions(-) diff --git a/internal/discover/hooks.go b/internal/discover/hooks.go index dfe98988..87202e7c 100644 --- a/internal/discover/hooks.go +++ b/internal/discover/hooks.go @@ -29,12 +29,47 @@ const ( nvidiaCTKDefaultFilePath = "/usr/bin/nvidia-ctk" ) +var _ Discover = (*Hook)(nil) + +// Devices returns an empty list of devices for a Hook discoverer. +func (h Hook) Devices() ([]Device, error) { + return nil, nil +} + +// Mounts returns an empty list of mounts for a Hook discoverer. +func (h Hook) Mounts() ([]Mount, error) { + return nil, nil +} + +// Hooks allows the Hook type to also implement the Discoverer interface. +// It returns a single hook +func (h Hook) Hooks() ([]Hook, error) { + return []Hook{h}, nil +} + +// CreateCreateSymlinkHook creates a hook which creates a symlink from link -> target. +func CreateCreateSymlinkHook(nvidiaCTKPath string, links []string) Discover { + if len(links) == 0 { + return None{} + } + + var args []string + for _, link := range links { + args = append(args, "--link", link) + } + return CreateNvidiaCTKHook( + nvidiaCTKPath, + "create-symlinks", + args..., + ) +} + // CreateNvidiaCTKHook creates a hook which invokes the NVIDIA Container CLI hook subcommand. 
-func CreateNvidiaCTKHook(executable string, hookName string, additionalArgs ...string) Hook { +func CreateNvidiaCTKHook(nvidiaCTKPath string, hookName string, additionalArgs ...string) Hook { return Hook{ Lifecycle: cdi.CreateContainerHook, - Path: executable, - Args: append([]string{filepath.Base(executable), "hook", hookName}, additionalArgs...), + Path: nvidiaCTKPath, + Args: append([]string{filepath.Base(nvidiaCTKPath), "hook", hookName}, additionalArgs...), } } From 7eb435eb730f28e2cd95baeb6e827e18400051e9 Mon Sep 17 00:00:00 2001 From: Evan Lezar Date: Wed, 1 Feb 2023 16:17:57 +0100 Subject: [PATCH 2/6] Add basic dxcore bindings This change copies dxcore.h and dxcore.c from libnvidia-container to allow for the driver store path to be queried. Modifications are made to dxcore to remove the code associated with checking the components in the driver store path. Signed-off-by: Evan Lezar --- internal/dxcore/api.go | 58 +++++++ internal/dxcore/dxcore.c | 334 ++++++++++++++++++++++++++++++++++++++ internal/dxcore/dxcore.go | 59 +++++++ internal/dxcore/dxcore.h | 39 +++++ 4 files changed, 490 insertions(+) create mode 100644 internal/dxcore/api.go create mode 100644 internal/dxcore/dxcore.c create mode 100644 internal/dxcore/dxcore.go create mode 100644 internal/dxcore/dxcore.h diff --git a/internal/dxcore/api.go b/internal/dxcore/api.go new file mode 100644 index 00000000..4408c29a --- /dev/null +++ b/internal/dxcore/api.go @@ -0,0 +1,58 @@ +/** +# Copyright (c) NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package dxcore + +import ( + "github.com/NVIDIA/go-nvml/pkg/dl" +) + +const ( + libraryName = "libdxcore.so" + libraryLoadFlags = dl.RTLD_LAZY | dl.RTLD_GLOBAL +) + +// dxcore stores a reference to the dxcore dynamic library +var dxcore *context + +// Init initializes the dxcore dynamic library +func Init() error { + c, err := initContext() + if err != nil { + return err + } + dxcore = c + return nil +} + +// Shutdown closes the dxcore dynamic library +func Shutdown() error { + if dxcore != nil && dxcore.initialized != 0 { + dxcore.deinitContext() + } + return nil +} + +// GetDriverStorePaths returns the list of driver store paths +func GetDriverStorePaths() []string { + var paths []string + for i := 0; i < dxcore.getAdapterCount(); i++ { + adapter := dxcore.getAdapter(i) + paths = append(paths, adapter.getDriverStorePath()) + } + + return paths +} diff --git a/internal/dxcore/dxcore.c b/internal/dxcore/dxcore.c new file mode 100644 index 00000000..0b61143f --- /dev/null +++ b/internal/dxcore/dxcore.c @@ -0,0 +1,334 @@ +/* + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + */ + +#include +#include + +#include "dxcore.h" + +// We define log_write as an empty macro to allow dxcore to remain unchanged. +#define log_write(...) + +// We define the following macros to allow dxcore to remain largely unchanged. +#define log_info(msg) log_write('I', __FILE__, __LINE__, msg) +#define log_warn(msg) log_write('W', __FILE__, __LINE__, msg) +#define log_err(msg) log_write('E', __FILE__, __LINE__, msg) +#define log_infof(fmt, ...) log_write('I', __FILE__, __LINE__, fmt, __VA_ARGS__) +#define log_warnf(fmt, ...) log_write('W', __FILE__, __LINE__, fmt, __VA_ARGS__) +#define log_errf(fmt, ...) 
log_write('E', __FILE__, __LINE__, fmt, __VA_ARGS__) + + +#define DXCORE_MAX_PATH 260 + +/* + * List of components we expect to find in the driver store that we need to mount + */ +static const char * const dxcore_nvidia_driver_store_components[] = { + "libcuda.so.1.1", /* Core library for cuda support */ + "libcuda_loader.so", /* Core library for cuda support on WSL */ + "libnvidia-ptxjitcompiler.so.1", /* Core library for PTX Jit support */ + "libnvidia-ml.so.1", /* Core library for nvml */ + "libnvidia-ml_loader.so", /* Core library for nvml on WSL */ + "nvidia-smi", /* nvidia-smi binary*/ + "nvcubins.bin", /* Binary containing GPU code for cuda */ +}; + + +/* + * List of functions and structures we need to communicate with libdxcore. + * Documentation on these functions can be found on docs.microsoft.com in d3dkmthk. + */ + +struct dxcore_enumAdapters2; +struct dxcore_queryAdapterInfo; + +typedef int(*pfnDxcoreEnumAdapters2)(struct dxcore_enumAdapters2* pParams); +typedef int(*pfnDxcoreQueryAdapterInfo)(struct dxcore_queryAdapterInfo* pParams); + +struct dxcore_lib { + void* hDxcoreLib; + pfnDxcoreEnumAdapters2 pDxcoreEnumAdapters2; + pfnDxcoreQueryAdapterInfo pDxcoreQueryAdapterInfo; +}; + +struct dxcore_adapterInfo +{ + unsigned int hAdapter; + struct dxcore_luid AdapterLuid; + unsigned int NumOfSources; + unsigned int bPresentMoveRegionsPreferred; +}; + +struct dxcore_enumAdapters2 +{ + unsigned int NumAdapters; + struct dxcore_adapterInfo *pAdapters; +}; + +enum dxcore_kmtqueryAdapterInfoType +{ + DXCORE_QUERYDRIVERVERSION = 13, + DXCORE_QUERYREGISTRY = 48, +}; + +enum dxcore_queryregistry_type { + DXCORE_QUERYREGISTRY_DRIVERSTOREPATH = 2, + DXCORE_QUERYREGISTRY_DRIVERIMAGEPATH = 3, +}; + +enum dxcore_queryregistry_status { + DXCORE_QUERYREGISTRY_STATUS_SUCCESS = 0, + DXCORE_QUERYREGISTRY_STATUS_BUFFER_OVERFLOW = 1, + DXCORE_QUERYREGISTRY_STATUS_FAIL = 2, +}; + +struct dxcore_queryregistry_info { + enum dxcore_queryregistry_type QueryType; + unsigned int 
QueryFlags; + wchar_t ValueName[DXCORE_MAX_PATH]; + unsigned int ValueType; + unsigned int PhysicalAdapterIndex; + unsigned int OutputValueSize; + enum dxcore_queryregistry_status Status; + union { + unsigned long long OutputQword; + wchar_t Output; + }; +}; + +struct dxcore_queryAdapterInfo +{ + unsigned int hAdapter; + enum dxcore_kmtqueryAdapterInfoType Type; + void *pPrivateDriverData; + unsigned int PrivateDriverDataSize; +}; + +static int dxcore_query_adapter_info_helper(struct dxcore_lib* pLib, + unsigned int hAdapter, + enum dxcore_kmtqueryAdapterInfoType type, + void* pPrivateDriverDate, + unsigned int privateDriverDataSize) +{ + struct dxcore_queryAdapterInfo queryAdapterInfo = { 0 }; + + queryAdapterInfo.hAdapter = hAdapter; + queryAdapterInfo.Type = type; + queryAdapterInfo.pPrivateDriverData = pPrivateDriverDate; + queryAdapterInfo.PrivateDriverDataSize = privateDriverDataSize; + + return pLib->pDxcoreQueryAdapterInfo(&queryAdapterInfo); +} + +static int dxcore_query_adapter_wddm_version(struct dxcore_lib* pLib, unsigned int hAdapter, unsigned int* version) +{ + return dxcore_query_adapter_info_helper(pLib, + hAdapter, + DXCORE_QUERYDRIVERVERSION, + (void*)version, + sizeof(*version)); +} + +static int dxcore_query_adapter_driverstore(struct dxcore_lib* pLib, unsigned int hAdapter, char** ppDriverStorePath) +{ + struct dxcore_queryregistry_info params = {0}; + struct dxcore_queryregistry_info* pValue = NULL; + wchar_t* pOutput; + size_t outputSizeInBytes; + size_t outputSize; + + params.QueryType = DXCORE_QUERYREGISTRY_DRIVERSTOREPATH; + + if (dxcore_query_adapter_info_helper(pLib, + hAdapter, + DXCORE_QUERYREGISTRY, + (void*)¶ms, + sizeof(params))) + { + log_err("Failed to query driver store path size for the WDDM Adapter"); + return (-1); + } + + if (params.OutputValueSize > DXCORE_MAX_PATH * sizeof(wchar_t)) { + log_err("The driver store path size returned by dxcore is not valid"); + return (-1); + } + + outputSizeInBytes = 
(size_t)params.OutputValueSize; + outputSize = outputSizeInBytes / sizeof(wchar_t); + + pValue = calloc(sizeof(struct dxcore_queryregistry_info) + outputSizeInBytes + sizeof(wchar_t), 1); + if (!pValue) { + log_err("Out of memory while allocating temp buffer to query adapter info"); + return (-1); + } + + pValue->QueryType = DXCORE_QUERYREGISTRY_DRIVERSTOREPATH; + pValue->OutputValueSize = (unsigned int)outputSizeInBytes; + + if (dxcore_query_adapter_info_helper(pLib, + hAdapter, + DXCORE_QUERYREGISTRY, + (void*)pValue, + (unsigned int)(sizeof(struct dxcore_queryregistry_info) + outputSizeInBytes))) + { + log_err("Failed to query driver store path data for the WDDM Adapter"); + free(pValue); + return (-1); + } + pOutput = (wchar_t*)(&pValue->Output); + + // Make sure no matter what happened the wchar_t string is null terminated + pOutput[outputSize] = L'\0'; + + // Convert the output into a regular c string + *ppDriverStorePath = (char*)calloc(outputSize + 1, sizeof(char)); + if (!*ppDriverStorePath) { + log_err("Out of memory while allocating the buffer for the driver store path"); + free(pValue); + return (-1); + } + wcstombs(*ppDriverStorePath, pOutput, outputSize); + + free(pValue); + + return 0; +} + +static void dxcore_add_adapter(struct dxcore_context* pCtx, struct dxcore_lib* pLib, struct dxcore_adapterInfo *pAdapterInfo) +{ + unsigned int wddmVersion = 0; + char* driverStorePath = NULL; + + log_infof("Creating a new WDDM Adapter for hAdapter:%x luid:%llx", pAdapterInfo->hAdapter, *((unsigned long long*)&pAdapterInfo->AdapterLuid)); + + if (dxcore_query_adapter_wddm_version(pLib, pAdapterInfo->hAdapter, &wddmVersion)) { + log_err("Failed to query the WDDM version for the specified adapter. Skipping it."); + return; + } + + if (wddmVersion < 2700) { + log_err("Found a WDDM adapter running a driver with pre-WDDM 2.7 . 
Skipping it."); + return; + } + + if (dxcore_query_adapter_driverstore(pLib, pAdapterInfo->hAdapter, &driverStorePath)) { + log_err("Failed to query driver store path for the WDDM Adapter . Skipping it."); + return; + } + + // We got all the info we needed. Adding it to the tracking structure. + { + struct dxcore_adapter* newList; + newList = realloc(pCtx->adapterList, sizeof(struct dxcore_adapter) * (pCtx->adapterCount + 1)); + if (!newList) { + log_err("Out of memory when trying to add a new WDDM Adapter to the list of valid adapters"); + free(driverStorePath); + return; + } + + pCtx->adapterList = newList; + + pCtx->adapterList[pCtx->adapterCount].hAdapter = pAdapterInfo->hAdapter; + pCtx->adapterList[pCtx->adapterCount].pDriverStorePath = driverStorePath; + pCtx->adapterList[pCtx->adapterCount].wddmVersion = wddmVersion; + pCtx->adapterCount++; + } + + log_infof("Adding new adapter via dxcore hAdapter:%x luid:%llx wddm version:%d", pAdapterInfo->hAdapter, *((unsigned long long*)&pAdapterInfo->AdapterLuid), wddmVersion); +} + +static void dxcore_enum_adapters(struct dxcore_context* pCtx, struct dxcore_lib* pLib) +{ + struct dxcore_enumAdapters2 params = {0}; + unsigned int adapterIndex = 0; + + params.NumAdapters = 0; + params.pAdapters = NULL; + + if (pLib->pDxcoreEnumAdapters2(¶ms)) { + log_err("Failed to enumerate adapters via dxcore"); + return; + } + + params.pAdapters = malloc(sizeof(struct dxcore_adapterInfo) * params.NumAdapters); + if (pLib->pDxcoreEnumAdapters2(¶ms)) { + free(params.pAdapters); + log_err("Failed to enumerate adapters via dxcore"); + return; + } + + for (adapterIndex = 0; adapterIndex < params.NumAdapters; adapterIndex++) { + dxcore_add_adapter(pCtx, pLib, ¶ms.pAdapters[adapterIndex]); + } + + free(params.pAdapters); +} + +int dxcore_init_context(struct dxcore_context* pCtx) +{ + struct dxcore_lib lib = {0}; + + pCtx->initialized = 0; + pCtx->adapterCount = 0; + pCtx->adapterList = NULL; + + lib.hDxcoreLib = dlopen("libdxcore.so", 
RTLD_LAZY); + if (!lib.hDxcoreLib) { + goto error; + } + + lib.pDxcoreEnumAdapters2 = (pfnDxcoreEnumAdapters2)dlsym(lib.hDxcoreLib, "D3DKMTEnumAdapters2"); + if (!lib.pDxcoreEnumAdapters2) { + log_err("dxcore library is present but the symbol D3DKMTEnumAdapters2 is missing"); + goto error; + } + + lib.pDxcoreQueryAdapterInfo = (pfnDxcoreQueryAdapterInfo)dlsym(lib.hDxcoreLib, "D3DKMTQueryAdapterInfo"); + if (!lib.pDxcoreQueryAdapterInfo) { + log_err("dxcore library is present but the symbol D3DKMTQueryAdapterInfo is missing"); + goto error; + } + + dxcore_enum_adapters(pCtx, &lib); + + log_info("dxcore layer initialized successfully"); + pCtx->initialized = 1; + + dlclose(lib.hDxcoreLib); + + return 0; + +error: + dxcore_deinit_context(pCtx); + + if (lib.hDxcoreLib) + dlclose(lib.hDxcoreLib); + + return (-1); +} + +static void dxcore_deinit_adapter(struct dxcore_adapter* pAdapter) +{ + if (!pAdapter) + return; + + free(pAdapter->pDriverStorePath); +} + +void dxcore_deinit_context(struct dxcore_context* pCtx) +{ + unsigned int adapterIndex = 0; + + if (!pCtx) + return; + + for (adapterIndex = 0; adapterIndex < pCtx->adapterCount; adapterIndex++) { + dxcore_deinit_adapter(&pCtx->adapterList[adapterIndex]); + } + + free(pCtx->adapterList); + + pCtx->initialized = 0; +} diff --git a/internal/dxcore/dxcore.go b/internal/dxcore/dxcore.go new file mode 100644 index 00000000..76cc53f8 --- /dev/null +++ b/internal/dxcore/dxcore.go @@ -0,0 +1,59 @@ +/** +# Copyright (c) NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package dxcore + +/* +#cgo LDFLAGS: -Wl,--unresolved-symbols=ignore-in-object-files +#include +*/ +import "C" +import ( + "fmt" + "unsafe" +) + +type context C.struct_dxcore_context +type adapter C.struct_dxcore_adapter + +// initContext initializes the dxcore context and populates the list of adapters. +func initContext() (*context, error) { + cContext := C.struct_dxcore_context{} + if C.dxcore_init_context(&cContext) != 0 { + return nil, fmt.Errorf("failed to initialize dxcore context") + } + c := (*context)(&cContext) + return c, nil +} + +// deinitContext deinitializes the dxcore context and frees the list of adapters. +func (c context) deinitContext() { + cContext := C.struct_dxcore_context(c) + C.dxcore_deinit_context(&cContext) +} + +func (c context) getAdapterCount() int { + return int(c.adapterCount) +} + +func (c context) getAdapter(index int) adapter { + arrayPointer := (*[1 << 30]C.struct_dxcore_adapter)(unsafe.Pointer(c.adapterList)) + return adapter(arrayPointer[index]) +} + +func (a adapter) getDriverStorePath() string { + return C.GoString(a.pDriverStorePath) +} diff --git a/internal/dxcore/dxcore.h b/internal/dxcore/dxcore.h new file mode 100644 index 00000000..9c044fee --- /dev/null +++ b/internal/dxcore/dxcore.h @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 
+ */ + +#ifndef HEADER_DXCORE_H_ +#define HEADER_DXCORE_H_ + +#define MAX_DXCORE_DRIVERSTORE_LIBRAIRIES (16) + +struct dxcore_luid +{ + unsigned int lowPart; + int highPart; +}; + +struct dxcore_adapter +{ + unsigned int hAdapter; + unsigned int wddmVersion; + char* pDriverStorePath; + unsigned int driverStoreComponentCount; + const char* pDriverStoreComponents[MAX_DXCORE_DRIVERSTORE_LIBRAIRIES]; + struct dxcore_context *pContext; +}; + +struct dxcore_context +{ + unsigned int adapterCount; + struct dxcore_adapter *adapterList; + + int initialized; +}; + + + +int dxcore_init_context(struct dxcore_context* pDxcore_context); +void dxcore_deinit_context(struct dxcore_context* pDxcore_context); + +#endif // HEADER_DXCORE_H_ From 5103adab89a82b87ba8a66df3b8313876502cc5d Mon Sep 17 00:00:00 2001 From: Evan Lezar Date: Thu, 16 Feb 2023 15:52:39 +0100 Subject: [PATCH 3/6] Add mode option to nvcdi API Signed-off-by: Evan Lezar --- pkg/nvcdi/lib.go | 1 + pkg/nvcdi/options.go | 7 +++++++ 2 files changed, 8 insertions(+) diff --git a/pkg/nvcdi/lib.go b/pkg/nvcdi/lib.go index 0ea1797f..592ff186 100644 --- a/pkg/nvcdi/lib.go +++ b/pkg/nvcdi/lib.go @@ -28,6 +28,7 @@ import ( type nvcdilib struct { logger *logrus.Logger nvmllib nvml.Interface + mode string devicelib device.Interface deviceNamer DeviceNamer driverRoot string diff --git a/pkg/nvcdi/options.go b/pkg/nvcdi/options.go index 50b699a6..317cace2 100644 --- a/pkg/nvcdi/options.go +++ b/pkg/nvcdi/options.go @@ -66,3 +66,10 @@ func WithNvmlLib(nvmllib nvml.Interface) Option { l.nvmllib = nvmllib } } + +// WithMode sets the discovery mode for the library +func WithMode(mode string) Option { + return func(l *nvcdilib) { + l.mode = mode + } +} From 20d6e9af04086328c955d68c66bcc14d5565a153 Mon Sep 17 00:00:00 2001 From: Evan Lezar Date: Wed, 8 Feb 2023 15:29:35 +0100 Subject: [PATCH 4/6] Add --discovery-mode to nvidia-ctk cdi generate command This change adds --discovery-mode flag to the nvidia-ctk cdi generate command and 
plumbs this through to the CDI API. Signed-off-by: Evan Lezar --- cmd/nvidia-ctk/cdi/generate/generate.go | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/cmd/nvidia-ctk/cdi/generate/generate.go b/cmd/nvidia-ctk/cdi/generate/generate.go index 77165e71..eb869d28 100644 --- a/cmd/nvidia-ctk/cdi/generate/generate.go +++ b/cmd/nvidia-ctk/cdi/generate/generate.go @@ -36,6 +36,8 @@ import ( ) const ( + discoveryModeNVML = "nvml" + formatJSON = "json" formatYAML = "yaml" ) @@ -50,6 +52,7 @@ type config struct { deviceNameStrategy string driverRoot string nvidiaCTKPath string + discoveryMode string } // NewCommand constructs a generate-cdi command with the specified logger @@ -88,6 +91,12 @@ func (m command) build() *cli.Command { Value: formatYAML, Destination: &cfg.format, }, + &cli.StringFlag{ + Name: "discovery-mode", + Usage: "The mode to use when discovering the available entities. One of [nvml]", + Value: discoveryModeNVML, + Destination: &cfg.discoveryMode, + }, &cli.StringFlag{ Name: "device-name-strategy", Usage: "Specify the strategy for generating device names. 
One of [index | uuid | type-index]", @@ -118,6 +127,13 @@ func (m command) validateFlags(r *cli.Context, cfg *config) error { return fmt.Errorf("invalid output format: %v", cfg.format) } + cfg.discoveryMode = strings.ToLower(cfg.discoveryMode) + switch cfg.discoveryMode { + case discoveryModeNVML: + default: + return fmt.Errorf("invalid discovery mode: %v", cfg.discoveryMode) + } + _, err := nvcdi.NewDeviceNamer(cfg.deviceNameStrategy) if err != nil { return err @@ -229,6 +245,7 @@ func (m command) generateSpec(cfg *config) (*specs.Spec, error) { nvcdi.WithDeviceNamer(deviceNamer), nvcdi.WithDeviceLib(devicelib), nvcdi.WithNvmlLib(nvmllib), + nvcdi.WithMode(string(cfg.discoveryMode)), ) deviceSpecs, err := cdilib.GetAllDeviceSpecs() @@ -298,3 +315,5 @@ func createParentDirsIfRequired(filename string) error { } return os.MkdirAll(dir, 0755) } + +type discoveryMode string From d226925fe771fdfc2a8a5697dafc130f2ae70181 Mon Sep 17 00:00:00 2001 From: Evan Lezar Date: Thu, 16 Feb 2023 17:29:53 +0200 Subject: [PATCH 5/6] Construct nvml-based CDI lib based on mode Signed-off-by: Evan Lezar --- pkg/nvcdi/{common.go => common-nvml.go} | 16 +--- pkg/nvcdi/{driver.go => driver-nvml.go} | 0 pkg/nvcdi/{full-gpu.go => full-gpu-nvml.go} | 4 +- pkg/nvcdi/lib-nvml.go | 93 +++++++++++++++++++ pkg/nvcdi/lib.go | 71 +++----------- .../{mig-device.go => mig-device-nvml.go} | 4 +- 6 files changed, 111 insertions(+), 77 deletions(-) rename pkg/nvcdi/{common.go => common-nvml.go} (68%) rename pkg/nvcdi/{driver.go => driver-nvml.go} (100%) rename pkg/nvcdi/{full-gpu.go => full-gpu-nvml.go} (97%) create mode 100644 pkg/nvcdi/lib-nvml.go rename pkg/nvcdi/{mig-device.go => mig-device-nvml.go} (94%) diff --git a/pkg/nvcdi/common.go b/pkg/nvcdi/common-nvml.go similarity index 68% rename from pkg/nvcdi/common.go rename to pkg/nvcdi/common-nvml.go index 1d04d420..df81fc29 100644 --- a/pkg/nvcdi/common.go +++ b/pkg/nvcdi/common-nvml.go @@ -20,27 +20,15 @@ import ( "fmt" 
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover" - "github.com/NVIDIA/nvidia-container-toolkit/internal/edits" "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup" - "github.com/container-orchestrated-devices/container-device-interface/pkg/cdi" "github.com/sirupsen/logrus" "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml" ) -// GetCommonEdits generates a CDI specification that can be used for ANY devices -func (l *nvcdilib) GetCommonEdits() (*cdi.ContainerEdits, error) { - common, err := newCommonDiscoverer(l.logger, l.driverRoot, l.nvidiaCTKPath, l.nvmllib) - if err != nil { - return nil, fmt.Errorf("failed to create discoverer for common entities: %v", err) - } - - return edits.FromDiscoverer(common) -} - -// newCommonDiscoverer returns a discoverer for entities that are not associated with a specific CDI device. +// newCommonNVMLDiscoverer returns a discoverer for entities that are not associated with a specific CDI device. // This includes driver libraries and meta devices, for example. -func newCommonDiscoverer(logger *logrus.Logger, driverRoot string, nvidiaCTKPath string, nvmllib nvml.Interface) (discover.Discover, error) { +func newCommonNVMLDiscoverer(logger *logrus.Logger, driverRoot string, nvidiaCTKPath string, nvmllib nvml.Interface) (discover.Discover, error) { metaDevices := discover.NewDeviceDiscoverer( logger, lookup.NewCharDeviceLocator( diff --git a/pkg/nvcdi/driver.go b/pkg/nvcdi/driver-nvml.go similarity index 100% rename from pkg/nvcdi/driver.go rename to pkg/nvcdi/driver-nvml.go diff --git a/pkg/nvcdi/full-gpu.go b/pkg/nvcdi/full-gpu-nvml.go similarity index 97% rename from pkg/nvcdi/full-gpu.go rename to pkg/nvcdi/full-gpu-nvml.go index 7e61477c..9dc6780e 100644 --- a/pkg/nvcdi/full-gpu.go +++ b/pkg/nvcdi/full-gpu-nvml.go @@ -33,7 +33,7 @@ import ( ) // GetGPUDeviceSpecs returns the CDI device specs for the full GPU represented by 'device'. 
-func (l *nvcdilib) GetGPUDeviceSpecs(i int, d device.Device) (*specs.Device, error) { +func (l *nvmllib) GetGPUDeviceSpecs(i int, d device.Device) (*specs.Device, error) { edits, err := l.GetGPUDeviceEdits(d) if err != nil { return nil, fmt.Errorf("failed to get edits for device: %v", err) @@ -53,7 +53,7 @@ func (l *nvcdilib) GetGPUDeviceSpecs(i int, d device.Device) (*specs.Device, err } // GetGPUDeviceEdits returns the CDI edits for the full GPU represented by 'device'. -func (l *nvcdilib) GetGPUDeviceEdits(d device.Device) (*cdi.ContainerEdits, error) { +func (l *nvmllib) GetGPUDeviceEdits(d device.Device) (*cdi.ContainerEdits, error) { device, err := newFullGPUDiscoverer(l.logger, l.driverRoot, l.nvidiaCTKPath, d) if err != nil { return nil, fmt.Errorf("failed to create device discoverer: %v", err) diff --git a/pkg/nvcdi/lib-nvml.go b/pkg/nvcdi/lib-nvml.go new file mode 100644 index 00000000..aaca382e --- /dev/null +++ b/pkg/nvcdi/lib-nvml.go @@ -0,0 +1,93 @@ +/** +# Copyright (c) NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+**/ + +package nvcdi + +import ( + "fmt" + + "github.com/NVIDIA/nvidia-container-toolkit/internal/edits" + "github.com/container-orchestrated-devices/container-device-interface/pkg/cdi" + "github.com/container-orchestrated-devices/container-device-interface/specs-go" + "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/device" +) + +type nvmllib nvcdilib + +var _ Interface = (*nvmllib)(nil) + +// GetAllDeviceSpecs returns the device specs for all available devices. +func (l *nvmllib) GetAllDeviceSpecs() ([]specs.Device, error) { + var deviceSpecs []specs.Device + + gpuDeviceSpecs, err := l.getGPUDeviceSpecs() + if err != nil { + return nil, err + } + deviceSpecs = append(deviceSpecs, gpuDeviceSpecs...) + + migDeviceSpecs, err := l.getMigDeviceSpecs() + if err != nil { + return nil, err + } + deviceSpecs = append(deviceSpecs, migDeviceSpecs...) + + return deviceSpecs, nil +} + +// GetCommonEdits generates a CDI specification that can be used for ANY devices +func (l *nvmllib) GetCommonEdits() (*cdi.ContainerEdits, error) { + common, err := newCommonNVMLDiscoverer(l.logger, l.driverRoot, l.nvidiaCTKPath, l.nvmllib) + if err != nil { + return nil, fmt.Errorf("failed to create discoverer for common entities: %v", err) + } + + return edits.FromDiscoverer(common) +} + +func (l *nvmllib) getGPUDeviceSpecs() ([]specs.Device, error) { + var deviceSpecs []specs.Device + err := l.devicelib.VisitDevices(func(i int, d device.Device) error { + deviceSpec, err := l.GetGPUDeviceSpecs(i, d) + if err != nil { + return err + } + deviceSpecs = append(deviceSpecs, *deviceSpec) + + return nil + }) + if err != nil { + return nil, fmt.Errorf("failed to generate CDI edits for GPU devices: %v", err) + } + return deviceSpecs, err +} + +func (l *nvmllib) getMigDeviceSpecs() ([]specs.Device, error) { + var deviceSpecs []specs.Device + err := l.devicelib.VisitMigDevices(func(i int, d device.Device, j int, mig device.MigDevice) error { + deviceSpec, err := l.GetMIGDeviceSpecs(i, d, j, mig) + if 
err != nil { + return err + } + deviceSpecs = append(deviceSpecs, *deviceSpec) + + return nil + }) + if err != nil { + return nil, fmt.Errorf("failed to generate CDI edits for GPU devices: %v", err) + } + return deviceSpecs, err +} diff --git a/pkg/nvcdi/lib.go b/pkg/nvcdi/lib.go index 592ff186..985e6850 100644 --- a/pkg/nvcdi/lib.go +++ b/pkg/nvcdi/lib.go @@ -17,9 +17,6 @@ package nvcdi import ( - "fmt" - - "github.com/container-orchestrated-devices/container-device-interface/specs-go" "github.com/sirupsen/logrus" "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/device" "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml" @@ -41,12 +38,8 @@ func New(opts ...Option) Interface { for _, opt := range opts { opt(l) } - - if l.nvmllib == nil { - l.nvmllib = nvml.New() - } - if l.devicelib == nil { - l.devicelib = device.New(device.WithNvml(l.nvmllib)) + if l.mode == "" { + l.mode = "nvml" } if l.logger == nil { l.logger = logrus.StandardLogger() @@ -61,58 +54,18 @@ func New(opts ...Option) Interface { l.nvidiaCTKPath = "/usr/bin/nvidia-ctk" } - return l -} - -// GetAllDeviceSpecs returns the device specs for all available devices. -func (l *nvcdilib) GetAllDeviceSpecs() ([]specs.Device, error) { - var deviceSpecs []specs.Device - - gpuDeviceSpecs, err := l.getGPUDeviceSpecs() - if err != nil { - return nil, err - } - deviceSpecs = append(deviceSpecs, gpuDeviceSpecs...) - - migDeviceSpecs, err := l.getMigDeviceSpecs() - if err != nil { - return nil, err - } - deviceSpecs = append(deviceSpecs, migDeviceSpecs...) 
- - return deviceSpecs, nil -} - -func (l *nvcdilib) getGPUDeviceSpecs() ([]specs.Device, error) { - var deviceSpecs []specs.Device - err := l.devicelib.VisitDevices(func(i int, d device.Device) error { - deviceSpec, err := l.GetGPUDeviceSpecs(i, d) - if err != nil { - return err + switch l.mode { + case "nvml": + if l.nvmllib == nil { + l.nvmllib = nvml.New() } - deviceSpecs = append(deviceSpecs, *deviceSpec) - - return nil - }) - if err != nil { - return nil, fmt.Errorf("failed to generate CDI edits for GPU devices: %v", err) - } - return deviceSpecs, err -} - -func (l *nvcdilib) getMigDeviceSpecs() ([]specs.Device, error) { - var deviceSpecs []specs.Device - err := l.devicelib.VisitMigDevices(func(i int, d device.Device, j int, mig device.MigDevice) error { - deviceSpec, err := l.GetMIGDeviceSpecs(i, d, j, mig) - if err != nil { - return err + if l.devicelib == nil { + l.devicelib = device.New(device.WithNvml(l.nvmllib)) } - deviceSpecs = append(deviceSpecs, *deviceSpec) - return nil - }) - if err != nil { - return nil, fmt.Errorf("failed to generate CDI edits for GPU devices: %v", err) + return (*nvmllib)(l) } - return deviceSpecs, err + + // TODO: We want an error here. + return nil } diff --git a/pkg/nvcdi/mig-device.go b/pkg/nvcdi/mig-device-nvml.go similarity index 94% rename from pkg/nvcdi/mig-device.go rename to pkg/nvcdi/mig-device-nvml.go index 3d0a91f2..7864ff91 100644 --- a/pkg/nvcdi/mig-device.go +++ b/pkg/nvcdi/mig-device-nvml.go @@ -30,7 +30,7 @@ import ( ) // GetMIGDeviceSpecs returns the CDI device specs for the full GPU represented by 'device'. 
-func (l *nvcdilib) GetMIGDeviceSpecs(i int, d device.Device, j int, mig device.MigDevice) (*specs.Device, error) { +func (l *nvmllib) GetMIGDeviceSpecs(i int, d device.Device, j int, mig device.MigDevice) (*specs.Device, error) { edits, err := l.GetMIGDeviceEdits(d, mig) if err != nil { return nil, fmt.Errorf("failed to get edits for device: %v", err) @@ -50,7 +50,7 @@ func (l *nvcdilib) GetMIGDeviceSpecs(i int, d device.Device, j int, mig device.M } // GetMIGDeviceEdits returns the CDI edits for the MIG device represented by 'mig' on 'parent'. -func (l *nvcdilib) GetMIGDeviceEdits(parent device.Device, mig device.MigDevice) (*cdi.ContainerEdits, error) { +func (l *nvmllib) GetMIGDeviceEdits(parent device.Device, mig device.MigDevice) (*cdi.ContainerEdits, error) { gpu, ret := parent.GetMinorNumber() if ret != nvml.SUCCESS { return nil, fmt.Errorf("error getting GPU minor: %v", ret) From b21dc929efa336495ee6c1b6bd8c2d155a3ba57b Mon Sep 17 00:00:00 2001 From: Evan Lezar Date: Mon, 6 Feb 2023 18:53:24 +0100 Subject: [PATCH 6/6] Add WSL2 discovery and spec generation These changes add a wsl discovery mode to the nvidia-ctk cdi generate command. If wsl mode is enabled, the driver store for the available devices is used as the source for discovered entities. 
Signed-off-by: Evan Lezar --- cmd/nvidia-ctk/cdi/generate/generate.go | 52 ++++++--- cmd/nvidia-ctk/cdi/generate/generate_test.go | 117 +++++++++++++++++++ pkg/nvcdi/device-wsl.go | 37 ++++++ pkg/nvcdi/driver-wsl.go | 106 +++++++++++++++++ pkg/nvcdi/lib-wsl.go | 76 ++++++++++++ pkg/nvcdi/lib.go | 2 + 6 files changed, 375 insertions(+), 15 deletions(-) create mode 100644 cmd/nvidia-ctk/cdi/generate/generate_test.go create mode 100644 pkg/nvcdi/device-wsl.go create mode 100644 pkg/nvcdi/driver-wsl.go create mode 100644 pkg/nvcdi/lib-wsl.go diff --git a/cmd/nvidia-ctk/cdi/generate/generate.go b/cmd/nvidia-ctk/cdi/generate/generate.go index eb869d28..9daf97ca 100644 --- a/cmd/nvidia-ctk/cdi/generate/generate.go +++ b/cmd/nvidia-ctk/cdi/generate/generate.go @@ -37,9 +37,12 @@ import ( const ( discoveryModeNVML = "nvml" + discoveryModeWSL = "wsl" formatJSON = "json" formatYAML = "yaml" + + allDeviceName = "all" ) type command struct { @@ -93,7 +96,7 @@ func (m command) build() *cli.Command { }, &cli.StringFlag{ Name: "discovery-mode", - Usage: "The mode to use when discovering the available entities. One of [nvml]", + Usage: "The mode to use when discovering the available entities. 
One of [nvml | wsl]", Value: discoveryModeNVML, Destination: &cfg.discoveryMode, }, @@ -130,6 +133,7 @@ func (m command) validateFlags(r *cli.Context, cfg *config) error { cfg.discoveryMode = strings.ToLower(cfg.discoveryMode) switch cfg.discoveryMode { case discoveryModeNVML: + case discoveryModeWSL: default: return fmt.Errorf("invalid discovery mode: %v", cfg.discoveryMode) } @@ -252,10 +256,20 @@ func (m command) generateSpec(cfg *config) (*specs.Spec, error) { if err != nil { return nil, fmt.Errorf("failed to create device CDI specs: %v", err) } - - allDevice := createAllDevice(deviceSpecs) - - deviceSpecs = append(deviceSpecs, allDevice) + var hasAll bool + for _, deviceSpec := range deviceSpecs { + if deviceSpec.Name == allDeviceName { + hasAll = true + break + } + } + if !hasAll { + allDevice, err := MergeDeviceSpecs(deviceSpecs, allDeviceName) + if err != nil { + return nil, fmt.Errorf("failed to create CDI specification for %q device: %v", allDeviceName, err) + } + deviceSpecs = append(deviceSpecs, allDevice) + } commonEdits, err := cdilib.GetCommonEdits() if err != nil { @@ -287,22 +301,32 @@ func (m command) generateSpec(cfg *config) (*specs.Spec, error) { return &spec, nil } -// createAllDevice creates an 'all' device which combines the edits from the previous devices -func createAllDevice(deviceSpecs []specs.Device) specs.Device { - edits := edits.NewContainerEdits() +// MergeDeviceSpecs creates a device with the specified name which combines the edits from the previous devices. +// If a device of the specified name already exists, an error is returned. 
+func MergeDeviceSpecs(deviceSpecs []specs.Device, mergedDeviceName string) (specs.Device, error) { + if err := cdi.ValidateDeviceName(mergedDeviceName); err != nil { + return specs.Device{}, fmt.Errorf("invalid device name %q: %v", mergedDeviceName, err) + } + for _, d := range deviceSpecs { + if d.Name == mergedDeviceName { + return specs.Device{}, fmt.Errorf("device %q already exists", mergedDeviceName) + } + } + + mergedEdits := edits.NewContainerEdits() for _, d := range deviceSpecs { edit := cdi.ContainerEdits{ ContainerEdits: &d.ContainerEdits, } - edits.Append(&edit) + mergedEdits.Append(&edit) } - all := specs.Device{ - Name: "all", - ContainerEdits: *edits.ContainerEdits, + merged := specs.Device{ + Name: mergedDeviceName, + ContainerEdits: *mergedEdits.ContainerEdits, } - return all + return merged, nil } // createParentDirsIfRequired creates the parent folders of the specified path if requried. @@ -315,5 +339,3 @@ func createParentDirsIfRequired(filename string) error { } return os.MkdirAll(dir, 0755) } - -type discoveryMode string diff --git a/cmd/nvidia-ctk/cdi/generate/generate_test.go b/cmd/nvidia-ctk/cdi/generate/generate_test.go new file mode 100644 index 00000000..5924480e --- /dev/null +++ b/cmd/nvidia-ctk/cdi/generate/generate_test.go @@ -0,0 +1,117 @@ +/** +# Copyright (c) NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+**/ + +package generate + +import ( + "fmt" + "testing" + + "github.com/container-orchestrated-devices/container-device-interface/specs-go" + "github.com/stretchr/testify/require" +) + +func TestMergeDeviceSpecs(t *testing.T) { + testCases := []struct { + description string + deviceSpecs []specs.Device + mergedDeviceName string + expectedError error + expected specs.Device + }{ + { + description: "no devices", + mergedDeviceName: "all", + expected: specs.Device{ + Name: "all", + }, + }, + { + description: "one device", + mergedDeviceName: "all", + deviceSpecs: []specs.Device{ + { + Name: "gpu0", + ContainerEdits: specs.ContainerEdits{ + Env: []string{"GPU=0"}, + }, + }, + }, + expected: specs.Device{ + Name: "all", + ContainerEdits: specs.ContainerEdits{ + Env: []string{"GPU=0"}, + }, + }, + }, + { + description: "two devices", + mergedDeviceName: "all", + deviceSpecs: []specs.Device{ + { + Name: "gpu0", + ContainerEdits: specs.ContainerEdits{ + Env: []string{"GPU=0"}, + }, + }, + { + Name: "gpu1", + ContainerEdits: specs.ContainerEdits{ + Env: []string{"GPU=1"}, + }, + }, + }, + expected: specs.Device{ + Name: "all", + ContainerEdits: specs.ContainerEdits{ + Env: []string{"GPU=0", "GPU=1"}, + }, + }, + }, + { + description: "has merged device", + mergedDeviceName: "gpu0", + deviceSpecs: []specs.Device{ + { + Name: "gpu0", + ContainerEdits: specs.ContainerEdits{ + Env: []string{"GPU=0"}, + }, + }, + }, + expectedError: fmt.Errorf("device %q already exists", "gpu0"), + }, + { + description: "invalid merged device name", + mergedDeviceName: ".-not-valid", + expectedError: fmt.Errorf("invalid device name %q", ".-not-valid"), + }, + } + + for _, tc := range testCases { + t.Run(tc.description, func(t *testing.T) { + mergedDevice, err := MergeDeviceSpecs(tc.deviceSpecs, tc.mergedDeviceName) + + if tc.expectedError != nil { + require.Error(t, err) + return + } + + require.NoError(t, err) + require.EqualValues(t, tc.expected, mergedDevice) + }) + } +} diff --git 
a/pkg/nvcdi/device-wsl.go b/pkg/nvcdi/device-wsl.go new file mode 100644 index 00000000..6acbb4b4 --- /dev/null +++ b/pkg/nvcdi/device-wsl.go @@ -0,0 +1,37 @@ +/** +# Copyright (c) NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package nvcdi + +import ( + "github.com/NVIDIA/nvidia-container-toolkit/internal/discover" + "github.com/sirupsen/logrus" +) + +const ( + dxgDeviceNode = "/dev/dxg" +) + +// newDXGDeviceDiscoverer returns a Discoverer for DXG devices under WSL2. +func newDXGDeviceDiscoverer(logger *logrus.Logger, driverRoot string) discover.Discover { + deviceNodes := discover.NewCharDeviceDiscoverer( + logger, + []string{dxgDeviceNode}, + driverRoot, + ) + + return deviceNodes +} diff --git a/pkg/nvcdi/driver-wsl.go b/pkg/nvcdi/driver-wsl.go new file mode 100644 index 00000000..cca4d52c --- /dev/null +++ b/pkg/nvcdi/driver-wsl.go @@ -0,0 +1,106 @@ +/** +# Copyright (c) NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package nvcdi + +import ( + "fmt" + "path/filepath" + + "github.com/NVIDIA/nvidia-container-toolkit/internal/discover" + "github.com/NVIDIA/nvidia-container-toolkit/internal/dxcore" + "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup" + "github.com/sirupsen/logrus" +) + +var requiredDriverStoreFiles = []string{ + "libcuda.so.1.1", /* Core library for cuda support */ + "libcuda_loader.so", /* Core library for cuda support on WSL */ + "libnvidia-ptxjitcompiler.so.1", /* Core library for PTX Jit support */ + "libnvidia-ml.so.1", /* Core library for nvml */ + "libnvidia-ml_loader.so", /* Core library for nvml on WSL */ + "libdxcore.so", /* Core library for dxcore support */ + "nvcubins.bin", /* Binary containing GPU code for cuda */ + "nvidia-smi", /* nvidia-smi binary*/ +} + +// newWSLDriverDiscoverer returns a Discoverer for WSL2 drivers. +func newWSLDriverDiscoverer(logger *logrus.Logger, driverRoot string, nvidiaCTKPath string) (discover.Discover, error) { + err := dxcore.Init() + if err != nil { + return nil, fmt.Errorf("failed to initialize dxcore: %v", err) + } + defer dxcore.Shutdown() + + driverStorePaths := dxcore.GetDriverStorePaths() + if len(driverStorePaths) == 0 { + return nil, fmt.Errorf("no driver store paths found") + } + logger.Infof("Using WSL driver store paths: %v", driverStorePaths) + + return newWSLDriverStoreDiscoverer(logger, driverRoot, nvidiaCTKPath, driverStorePaths) +} + +// newWSLDriverStoreDiscoverer returns a Discoverer for WSL2 drivers in the driver store associated with a dxcore adapter. 
+func newWSLDriverStoreDiscoverer(logger *logrus.Logger, driverRoot string, nvidiaCTKPath string, driverStorePaths []string) (discover.Discover, error) {
+	// De-duplicate the driver store paths while preserving their order.
+	var searchPaths []string
+	seen := make(map[string]bool)
+	for _, path := range driverStorePaths {
+		if !seen[path] {
+			searchPaths = append(searchPaths, path)
+		}
+		seen[path] = true
+	}
+	if len(searchPaths) == 0 {
+		return nil, fmt.Errorf("no driver store paths specified")
+	}
+	if len(searchPaths) > 1 {
+		logger.Warnf("Found multiple driver store paths: %v", searchPaths)
+	}
+	driverStorePath := searchPaths[0]
+	// Also search the standard WSL library directory for the required files.
+	searchPaths = append(searchPaths, "/usr/lib/wsl/lib")
+
+	libraries := discover.NewMounts(
+		logger,
+		lookup.NewFileLocator(
+			lookup.WithLogger(logger),
+			lookup.WithSearchPaths(searchPaths...),
+			lookup.WithCount(1),
+		),
+		driverRoot,
+		requiredDriverStoreFiles,
+	)
+
+	// On WSL2 the driver store location is used unchanged.
+	// For this reason we need to create a symlink from /usr/bin/nvidia-smi to the nvidia-smi binary in the driver store.
+	target := filepath.Join(driverStorePath, "nvidia-smi")
+	link := "/usr/bin/nvidia-smi"
+	links := []string{fmt.Sprintf("%s::%s", target, link)}
+	symlinkHook := discover.CreateCreateSymlinkHook(nvidiaCTKPath, links)
+
+	cfg := &discover.Config{
+		DriverRoot:    driverRoot,
+		NvidiaCTKPath: nvidiaCTKPath,
+	}
+	ldcacheHook, err := discover.NewLDCacheUpdateHook(logger, libraries, cfg)
+	if err != nil {
+		return nil, fmt.Errorf("failed to create ldcache update hook: %v", err)
+	}
+
+	return discover.Merge(libraries, symlinkHook, ldcacheHook), nil
+}
diff --git a/pkg/nvcdi/lib-wsl.go b/pkg/nvcdi/lib-wsl.go
new file mode 100644
index 00000000..d901995c
--- /dev/null
+++ b/pkg/nvcdi/lib-wsl.go
@@ -0,0 +1,76 @@
+/**
+# Copyright (c) NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package nvcdi + +import ( + "fmt" + + "github.com/NVIDIA/nvidia-container-toolkit/internal/edits" + "github.com/container-orchestrated-devices/container-device-interface/pkg/cdi" + "github.com/container-orchestrated-devices/container-device-interface/specs-go" + "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/device" +) + +type wsllib nvcdilib + +var _ Interface = (*wsllib)(nil) + +// GetAllDeviceSpecs returns the device specs for all available devices. +func (l *wsllib) GetAllDeviceSpecs() ([]specs.Device, error) { + device := newDXGDeviceDiscoverer(l.logger, l.driverRoot) + deviceEdits, err := edits.FromDiscoverer(device) + if err != nil { + return nil, fmt.Errorf("failed to create container edits for DXG device: %v", err) + } + + deviceSpec := specs.Device{ + Name: "all", + ContainerEdits: *deviceEdits.ContainerEdits, + } + + return []specs.Device{deviceSpec}, nil +} + +// GetCommonEdits generates a CDI specification that can be used for ANY devices +func (l *wsllib) GetCommonEdits() (*cdi.ContainerEdits, error) { + driver, err := newWSLDriverDiscoverer(l.logger, l.driverRoot, l.nvidiaCTKPath) + if err != nil { + return nil, fmt.Errorf("failed to create discoverer for WSL driver: %v", err) + } + + return edits.FromDiscoverer(driver) +} + +// GetGPUDeviceEdits generates a CDI specification that can be used for GPU devices +func (l *wsllib) GetGPUDeviceEdits(device.Device) (*cdi.ContainerEdits, error) { + return nil, fmt.Errorf("GetGPUDeviceEdits is not supported on WSL") +} + +// GetGPUDeviceSpecs returns the CDI device specs for the full GPU 
represented by 'device'. +func (l *wsllib) GetGPUDeviceSpecs(i int, d device.Device) (*specs.Device, error) { + return nil, fmt.Errorf("GetGPUDeviceSpecs is not supported on WSL") +} + +// GetMIGDeviceEdits generates a CDI specification that can be used for MIG devices +func (l *wsllib) GetMIGDeviceEdits(device.Device, device.MigDevice) (*cdi.ContainerEdits, error) { + return nil, fmt.Errorf("GetMIGDeviceEdits is not supported on WSL") +} + +// GetMIGDeviceSpecs returns the CDI device specs for the full MIG represented by 'device'. +func (l *wsllib) GetMIGDeviceSpecs(int, device.Device, int, device.MigDevice) (*specs.Device, error) { + return nil, fmt.Errorf("GetMIGDeviceSpecs is not supported on WSL") +} diff --git a/pkg/nvcdi/lib.go b/pkg/nvcdi/lib.go index 985e6850..4081e524 100644 --- a/pkg/nvcdi/lib.go +++ b/pkg/nvcdi/lib.go @@ -64,6 +64,8 @@ func New(opts ...Option) Interface { } return (*nvmllib)(l) + case "wsl": + return (*wsllib)(l) } // TODO: We want an error here.