mirror of
https://github.com/NVIDIA/nvidia-container-toolkit
synced 2025-04-10 15:25:33 +00:00
Merge branch 'wsl2-wip' into 'main'
Add CDI Spec generation on WSL2 See merge request nvidia/container-toolkit/container-toolkit!289
This commit is contained in:
commit
f36c775d50
@ -36,8 +36,13 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
discoveryModeNVML = "nvml"
|
||||
discoveryModeWSL = "wsl"
|
||||
|
||||
formatJSON = "json"
|
||||
formatYAML = "yaml"
|
||||
|
||||
allDeviceName = "all"
|
||||
)
|
||||
|
||||
type command struct {
|
||||
@ -50,6 +55,7 @@ type config struct {
|
||||
deviceNameStrategy string
|
||||
driverRoot string
|
||||
nvidiaCTKPath string
|
||||
discoveryMode string
|
||||
}
|
||||
|
||||
// NewCommand constructs a generate-cdi command with the specified logger
|
||||
@ -88,6 +94,12 @@ func (m command) build() *cli.Command {
|
||||
Value: formatYAML,
|
||||
Destination: &cfg.format,
|
||||
},
|
||||
&cli.StringFlag{
|
||||
Name: "discovery-mode",
|
||||
Usage: "The mode to use when discovering the available entities. One of [nvml | wsl]",
|
||||
Value: discoveryModeNVML,
|
||||
Destination: &cfg.discoveryMode,
|
||||
},
|
||||
&cli.StringFlag{
|
||||
Name: "device-name-strategy",
|
||||
Usage: "Specify the strategy for generating device names. One of [index | uuid | type-index]",
|
||||
@ -118,6 +130,14 @@ func (m command) validateFlags(r *cli.Context, cfg *config) error {
|
||||
return fmt.Errorf("invalid output format: %v", cfg.format)
|
||||
}
|
||||
|
||||
cfg.discoveryMode = strings.ToLower(cfg.discoveryMode)
|
||||
switch cfg.discoveryMode {
|
||||
case discoveryModeNVML:
|
||||
case discoveryModeWSL:
|
||||
default:
|
||||
return fmt.Errorf("invalid discovery mode: %v", cfg.discoveryMode)
|
||||
}
|
||||
|
||||
_, err := nvcdi.NewDeviceNamer(cfg.deviceNameStrategy)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -229,16 +249,27 @@ func (m command) generateSpec(cfg *config) (*specs.Spec, error) {
|
||||
nvcdi.WithDeviceNamer(deviceNamer),
|
||||
nvcdi.WithDeviceLib(devicelib),
|
||||
nvcdi.WithNvmlLib(nvmllib),
|
||||
nvcdi.WithMode(string(cfg.discoveryMode)),
|
||||
)
|
||||
|
||||
deviceSpecs, err := cdilib.GetAllDeviceSpecs()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create device CDI specs: %v", err)
|
||||
}
|
||||
|
||||
allDevice := createAllDevice(deviceSpecs)
|
||||
|
||||
deviceSpecs = append(deviceSpecs, allDevice)
|
||||
var hasAll bool
|
||||
for _, deviceSpec := range deviceSpecs {
|
||||
if deviceSpec.Name == allDeviceName {
|
||||
hasAll = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasAll {
|
||||
allDevice, err := MergeDeviceSpecs(deviceSpecs, allDeviceName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create CDI specification for %q device: %v", allDeviceName, err)
|
||||
}
|
||||
deviceSpecs = append(deviceSpecs, allDevice)
|
||||
}
|
||||
|
||||
commonEdits, err := cdilib.GetCommonEdits()
|
||||
if err != nil {
|
||||
@ -270,22 +301,32 @@ func (m command) generateSpec(cfg *config) (*specs.Spec, error) {
|
||||
return &spec, nil
|
||||
}
|
||||
|
||||
// createAllDevice creates an 'all' device which combines the edits from the previous devices
|
||||
func createAllDevice(deviceSpecs []specs.Device) specs.Device {
|
||||
edits := edits.NewContainerEdits()
|
||||
// MergeDeviceSpecs creates a device with the specified name which combines the edits from the previous devices.
|
||||
// If a device of the specified name already exists, an error is returned.
|
||||
func MergeDeviceSpecs(deviceSpecs []specs.Device, mergedDeviceName string) (specs.Device, error) {
|
||||
if err := cdi.ValidateDeviceName(mergedDeviceName); err != nil {
|
||||
return specs.Device{}, fmt.Errorf("invalid device name %q: %v", mergedDeviceName, err)
|
||||
}
|
||||
for _, d := range deviceSpecs {
|
||||
if d.Name == mergedDeviceName {
|
||||
return specs.Device{}, fmt.Errorf("device %q already exists", mergedDeviceName)
|
||||
}
|
||||
}
|
||||
|
||||
mergedEdits := edits.NewContainerEdits()
|
||||
|
||||
for _, d := range deviceSpecs {
|
||||
edit := cdi.ContainerEdits{
|
||||
ContainerEdits: &d.ContainerEdits,
|
||||
}
|
||||
edits.Append(&edit)
|
||||
mergedEdits.Append(&edit)
|
||||
}
|
||||
|
||||
all := specs.Device{
|
||||
Name: "all",
|
||||
ContainerEdits: *edits.ContainerEdits,
|
||||
merged := specs.Device{
|
||||
Name: mergedDeviceName,
|
||||
ContainerEdits: *mergedEdits.ContainerEdits,
|
||||
}
|
||||
return all
|
||||
return merged, nil
|
||||
}
|
||||
|
||||
// createParentDirsIfRequired creates the parent folders of the specified path if required.
|
||||
|
117
cmd/nvidia-ctk/cdi/generate/generate_test.go
Normal file
117
cmd/nvidia-ctk/cdi/generate/generate_test.go
Normal file
@ -0,0 +1,117 @@
|
||||
/**
|
||||
# Copyright (c) NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
**/
|
||||
|
||||
package generate
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/container-orchestrated-devices/container-device-interface/specs-go"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMergeDeviceSpecs(t *testing.T) {
|
||||
testCases := []struct {
|
||||
description string
|
||||
deviceSpecs []specs.Device
|
||||
mergedDeviceName string
|
||||
expectedError error
|
||||
expected specs.Device
|
||||
}{
|
||||
{
|
||||
description: "no devices",
|
||||
mergedDeviceName: "all",
|
||||
expected: specs.Device{
|
||||
Name: "all",
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "one device",
|
||||
mergedDeviceName: "all",
|
||||
deviceSpecs: []specs.Device{
|
||||
{
|
||||
Name: "gpu0",
|
||||
ContainerEdits: specs.ContainerEdits{
|
||||
Env: []string{"GPU=0"},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: specs.Device{
|
||||
Name: "all",
|
||||
ContainerEdits: specs.ContainerEdits{
|
||||
Env: []string{"GPU=0"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "two devices",
|
||||
mergedDeviceName: "all",
|
||||
deviceSpecs: []specs.Device{
|
||||
{
|
||||
Name: "gpu0",
|
||||
ContainerEdits: specs.ContainerEdits{
|
||||
Env: []string{"GPU=0"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "gpu1",
|
||||
ContainerEdits: specs.ContainerEdits{
|
||||
Env: []string{"GPU=1"},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: specs.Device{
|
||||
Name: "all",
|
||||
ContainerEdits: specs.ContainerEdits{
|
||||
Env: []string{"GPU=0", "GPU=1"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "has merged device",
|
||||
mergedDeviceName: "gpu0",
|
||||
deviceSpecs: []specs.Device{
|
||||
{
|
||||
Name: "gpu0",
|
||||
ContainerEdits: specs.ContainerEdits{
|
||||
Env: []string{"GPU=0"},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedError: fmt.Errorf("device %q already exists", "gpu0"),
|
||||
},
|
||||
{
|
||||
description: "invalid merged device name",
|
||||
mergedDeviceName: ".-not-valid",
|
||||
expectedError: fmt.Errorf("invalid device name %q", ".-not-valid"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.description, func(t *testing.T) {
|
||||
mergedDevice, err := MergeDeviceSpecs(tc.deviceSpecs, tc.mergedDeviceName)
|
||||
|
||||
if tc.expectedError != nil {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, tc.expected, mergedDevice)
|
||||
})
|
||||
}
|
||||
}
|
@ -29,12 +29,47 @@ const (
|
||||
nvidiaCTKDefaultFilePath = "/usr/bin/nvidia-ctk"
|
||||
)
|
||||
|
||||
var _ Discover = (*Hook)(nil)
|
||||
|
||||
// Devices returns an empty list of devices for a Hook discoverer.
|
||||
func (h Hook) Devices() ([]Device, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Mounts returns an empty list of mounts for a Hook discoverer.
|
||||
func (h Hook) Mounts() ([]Mount, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Hooks allows the Hook type to also implement the Discoverer interface.
|
||||
// It returns a single hook
|
||||
func (h Hook) Hooks() ([]Hook, error) {
|
||||
return []Hook{h}, nil
|
||||
}
|
||||
|
||||
// CreateCreateSymlinkHook creates a hook which creates a symlink from link -> target.
|
||||
func CreateCreateSymlinkHook(nvidiaCTKPath string, links []string) Discover {
|
||||
if len(links) == 0 {
|
||||
return None{}
|
||||
}
|
||||
|
||||
var args []string
|
||||
for _, link := range links {
|
||||
args = append(args, "--link", link)
|
||||
}
|
||||
return CreateNvidiaCTKHook(
|
||||
nvidiaCTKPath,
|
||||
"create-symlinks",
|
||||
args...,
|
||||
)
|
||||
}
|
||||
|
||||
// CreateNvidiaCTKHook creates a hook which invokes the NVIDIA Container CLI hook subcommand.
|
||||
func CreateNvidiaCTKHook(executable string, hookName string, additionalArgs ...string) Hook {
|
||||
func CreateNvidiaCTKHook(nvidiaCTKPath string, hookName string, additionalArgs ...string) Hook {
|
||||
return Hook{
|
||||
Lifecycle: cdi.CreateContainerHook,
|
||||
Path: executable,
|
||||
Args: append([]string{filepath.Base(executable), "hook", hookName}, additionalArgs...),
|
||||
Path: nvidiaCTKPath,
|
||||
Args: append([]string{filepath.Base(nvidiaCTKPath), "hook", hookName}, additionalArgs...),
|
||||
}
|
||||
}
|
||||
|
||||
|
58
internal/dxcore/api.go
Normal file
58
internal/dxcore/api.go
Normal file
@ -0,0 +1,58 @@
|
||||
/**
|
||||
# Copyright (c) NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
**/
|
||||
|
||||
package dxcore
|
||||
|
||||
import (
|
||||
"github.com/NVIDIA/go-nvml/pkg/dl"
|
||||
)
|
||||
|
||||
const (
|
||||
libraryName = "libdxcore.so"
|
||||
libraryLoadFlags = dl.RTLD_LAZY | dl.RTLD_GLOBAL
|
||||
)
|
||||
|
||||
// dxcore stores a reference to the dxcore dynamic library
|
||||
var dxcore *context
|
||||
|
||||
// Init initializes the dxcore dynamic library
|
||||
func Init() error {
|
||||
c, err := initContext()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dxcore = c
|
||||
return nil
|
||||
}
|
||||
|
||||
// Shutdown closes the dxcore dynamic library
|
||||
func Shutdown() error {
|
||||
if dxcore != nil && dxcore.initialized != 0 {
|
||||
dxcore.deinitContext()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetDriverStorePaths returns the list of driver store paths
|
||||
func GetDriverStorePaths() []string {
|
||||
var paths []string
|
||||
for i := 0; i < dxcore.getAdapterCount(); i++ {
|
||||
adapter := dxcore.getAdapter(i)
|
||||
paths = append(paths, adapter.getDriverStorePath())
|
||||
}
|
||||
|
||||
return paths
|
||||
}
|
334
internal/dxcore/dxcore.c
Normal file
334
internal/dxcore/dxcore.c
Normal file
@ -0,0 +1,334 @@
|
||||
/*
|
||||
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*/
|
||||
|
||||
#include <dlfcn.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
#include "dxcore.h"
|
||||
|
||||
// We define log_write as an empty macro to allow dxcore to remain unchanged.
|
||||
#define log_write(...)
|
||||
|
||||
// We define the following macros to allow dxcore to remain largely unchanged.
|
||||
#define log_info(msg) log_write('I', __FILE__, __LINE__, msg)
|
||||
#define log_warn(msg) log_write('W', __FILE__, __LINE__, msg)
|
||||
#define log_err(msg) log_write('E', __FILE__, __LINE__, msg)
|
||||
#define log_infof(fmt, ...) log_write('I', __FILE__, __LINE__, fmt, __VA_ARGS__)
|
||||
#define log_warnf(fmt, ...) log_write('W', __FILE__, __LINE__, fmt, __VA_ARGS__)
|
||||
#define log_errf(fmt, ...) log_write('E', __FILE__, __LINE__, fmt, __VA_ARGS__)
|
||||
|
||||
|
||||
#define DXCORE_MAX_PATH 260
|
||||
|
||||
/*
|
||||
* List of components we expect to find in the driver store that we need to mount
|
||||
*/
|
||||
static const char * const dxcore_nvidia_driver_store_components[] = {
|
||||
"libcuda.so.1.1", /* Core library for cuda support */
|
||||
"libcuda_loader.so", /* Core library for cuda support on WSL */
|
||||
"libnvidia-ptxjitcompiler.so.1", /* Core library for PTX Jit support */
|
||||
"libnvidia-ml.so.1", /* Core library for nvml */
|
||||
"libnvidia-ml_loader.so", /* Core library for nvml on WSL */
|
||||
"nvidia-smi", /* nvidia-smi binary*/
|
||||
"nvcubins.bin", /* Binary containing GPU code for cuda */
|
||||
};
|
||||
|
||||
|
||||
/*
|
||||
* List of functions and structures we need to communicate with libdxcore.
|
||||
* Documentation on these functions can be found on docs.microsoft.com in d3dkmthk.
|
||||
*/
|
||||
|
||||
struct dxcore_enumAdapters2;
|
||||
struct dxcore_queryAdapterInfo;
|
||||
|
||||
typedef int(*pfnDxcoreEnumAdapters2)(struct dxcore_enumAdapters2* pParams);
|
||||
typedef int(*pfnDxcoreQueryAdapterInfo)(struct dxcore_queryAdapterInfo* pParams);
|
||||
|
||||
struct dxcore_lib {
|
||||
void* hDxcoreLib;
|
||||
pfnDxcoreEnumAdapters2 pDxcoreEnumAdapters2;
|
||||
pfnDxcoreQueryAdapterInfo pDxcoreQueryAdapterInfo;
|
||||
};
|
||||
|
||||
struct dxcore_adapterInfo
|
||||
{
|
||||
unsigned int hAdapter;
|
||||
struct dxcore_luid AdapterLuid;
|
||||
unsigned int NumOfSources;
|
||||
unsigned int bPresentMoveRegionsPreferred;
|
||||
};
|
||||
|
||||
struct dxcore_enumAdapters2
|
||||
{
|
||||
unsigned int NumAdapters;
|
||||
struct dxcore_adapterInfo *pAdapters;
|
||||
};
|
||||
|
||||
enum dxcore_kmtqueryAdapterInfoType
|
||||
{
|
||||
DXCORE_QUERYDRIVERVERSION = 13,
|
||||
DXCORE_QUERYREGISTRY = 48,
|
||||
};
|
||||
|
||||
enum dxcore_queryregistry_type {
|
||||
DXCORE_QUERYREGISTRY_DRIVERSTOREPATH = 2,
|
||||
DXCORE_QUERYREGISTRY_DRIVERIMAGEPATH = 3,
|
||||
};
|
||||
|
||||
enum dxcore_queryregistry_status {
|
||||
DXCORE_QUERYREGISTRY_STATUS_SUCCESS = 0,
|
||||
DXCORE_QUERYREGISTRY_STATUS_BUFFER_OVERFLOW = 1,
|
||||
DXCORE_QUERYREGISTRY_STATUS_FAIL = 2,
|
||||
};
|
||||
|
||||
struct dxcore_queryregistry_info {
|
||||
enum dxcore_queryregistry_type QueryType;
|
||||
unsigned int QueryFlags;
|
||||
wchar_t ValueName[DXCORE_MAX_PATH];
|
||||
unsigned int ValueType;
|
||||
unsigned int PhysicalAdapterIndex;
|
||||
unsigned int OutputValueSize;
|
||||
enum dxcore_queryregistry_status Status;
|
||||
union {
|
||||
unsigned long long OutputQword;
|
||||
wchar_t Output;
|
||||
};
|
||||
};
|
||||
|
||||
struct dxcore_queryAdapterInfo
|
||||
{
|
||||
unsigned int hAdapter;
|
||||
enum dxcore_kmtqueryAdapterInfoType Type;
|
||||
void *pPrivateDriverData;
|
||||
unsigned int PrivateDriverDataSize;
|
||||
};
|
||||
|
||||
static int dxcore_query_adapter_info_helper(struct dxcore_lib* pLib,
|
||||
unsigned int hAdapter,
|
||||
enum dxcore_kmtqueryAdapterInfoType type,
|
||||
void* pPrivateDriverDate,
|
||||
unsigned int privateDriverDataSize)
|
||||
{
|
||||
struct dxcore_queryAdapterInfo queryAdapterInfo = { 0 };
|
||||
|
||||
queryAdapterInfo.hAdapter = hAdapter;
|
||||
queryAdapterInfo.Type = type;
|
||||
queryAdapterInfo.pPrivateDriverData = pPrivateDriverDate;
|
||||
queryAdapterInfo.PrivateDriverDataSize = privateDriverDataSize;
|
||||
|
||||
return pLib->pDxcoreQueryAdapterInfo(&queryAdapterInfo);
|
||||
}
|
||||
|
||||
static int dxcore_query_adapter_wddm_version(struct dxcore_lib* pLib, unsigned int hAdapter, unsigned int* version)
|
||||
{
|
||||
return dxcore_query_adapter_info_helper(pLib,
|
||||
hAdapter,
|
||||
DXCORE_QUERYDRIVERVERSION,
|
||||
(void*)version,
|
||||
sizeof(*version));
|
||||
}
|
||||
|
||||
static int dxcore_query_adapter_driverstore(struct dxcore_lib* pLib, unsigned int hAdapter, char** ppDriverStorePath)
|
||||
{
|
||||
struct dxcore_queryregistry_info params = {0};
|
||||
struct dxcore_queryregistry_info* pValue = NULL;
|
||||
wchar_t* pOutput;
|
||||
size_t outputSizeInBytes;
|
||||
size_t outputSize;
|
||||
|
||||
params.QueryType = DXCORE_QUERYREGISTRY_DRIVERSTOREPATH;
|
||||
|
||||
if (dxcore_query_adapter_info_helper(pLib,
|
||||
hAdapter,
|
||||
DXCORE_QUERYREGISTRY,
|
||||
(void*)¶ms,
|
||||
sizeof(params)))
|
||||
{
|
||||
log_err("Failed to query driver store path size for the WDDM Adapter");
|
||||
return (-1);
|
||||
}
|
||||
|
||||
if (params.OutputValueSize > DXCORE_MAX_PATH * sizeof(wchar_t)) {
|
||||
log_err("The driver store path size returned by dxcore is not valid");
|
||||
return (-1);
|
||||
}
|
||||
|
||||
outputSizeInBytes = (size_t)params.OutputValueSize;
|
||||
outputSize = outputSizeInBytes / sizeof(wchar_t);
|
||||
|
||||
pValue = calloc(sizeof(struct dxcore_queryregistry_info) + outputSizeInBytes + sizeof(wchar_t), 1);
|
||||
if (!pValue) {
|
||||
log_err("Out of memory while allocating temp buffer to query adapter info");
|
||||
return (-1);
|
||||
}
|
||||
|
||||
pValue->QueryType = DXCORE_QUERYREGISTRY_DRIVERSTOREPATH;
|
||||
pValue->OutputValueSize = (unsigned int)outputSizeInBytes;
|
||||
|
||||
if (dxcore_query_adapter_info_helper(pLib,
|
||||
hAdapter,
|
||||
DXCORE_QUERYREGISTRY,
|
||||
(void*)pValue,
|
||||
(unsigned int)(sizeof(struct dxcore_queryregistry_info) + outputSizeInBytes)))
|
||||
{
|
||||
log_err("Failed to query driver store path data for the WDDM Adapter");
|
||||
free(pValue);
|
||||
return (-1);
|
||||
}
|
||||
pOutput = (wchar_t*)(&pValue->Output);
|
||||
|
||||
// Make sure no matter what happened the wchar_t string is null terminated
|
||||
pOutput[outputSize] = L'\0';
|
||||
|
||||
// Convert the output into a regular c string
|
||||
*ppDriverStorePath = (char*)calloc(outputSize + 1, sizeof(char));
|
||||
if (!*ppDriverStorePath) {
|
||||
log_err("Out of memory while allocating the buffer for the driver store path");
|
||||
free(pValue);
|
||||
return (-1);
|
||||
}
|
||||
wcstombs(*ppDriverStorePath, pOutput, outputSize);
|
||||
|
||||
free(pValue);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
static void dxcore_add_adapter(struct dxcore_context* pCtx, struct dxcore_lib* pLib, struct dxcore_adapterInfo *pAdapterInfo)
|
||||
{
|
||||
unsigned int wddmVersion = 0;
|
||||
char* driverStorePath = NULL;
|
||||
|
||||
log_infof("Creating a new WDDM Adapter for hAdapter:%x luid:%llx", pAdapterInfo->hAdapter, *((unsigned long long*)&pAdapterInfo->AdapterLuid));
|
||||
|
||||
if (dxcore_query_adapter_wddm_version(pLib, pAdapterInfo->hAdapter, &wddmVersion)) {
|
||||
log_err("Failed to query the WDDM version for the specified adapter. Skipping it.");
|
||||
return;
|
||||
}
|
||||
|
||||
if (wddmVersion < 2700) {
|
||||
log_err("Found a WDDM adapter running a driver with pre-WDDM 2.7 . Skipping it.");
|
||||
return;
|
||||
}
|
||||
|
||||
if (dxcore_query_adapter_driverstore(pLib, pAdapterInfo->hAdapter, &driverStorePath)) {
|
||||
log_err("Failed to query driver store path for the WDDM Adapter . Skipping it.");
|
||||
return;
|
||||
}
|
||||
|
||||
// We got all the info we needed. Adding it to the tracking structure.
|
||||
{
|
||||
struct dxcore_adapter* newList;
|
||||
newList = realloc(pCtx->adapterList, sizeof(struct dxcore_adapter) * (pCtx->adapterCount + 1));
|
||||
if (!newList) {
|
||||
log_err("Out of memory when trying to add a new WDDM Adapter to the list of valid adapters");
|
||||
free(driverStorePath);
|
||||
return;
|
||||
}
|
||||
|
||||
pCtx->adapterList = newList;
|
||||
|
||||
pCtx->adapterList[pCtx->adapterCount].hAdapter = pAdapterInfo->hAdapter;
|
||||
pCtx->adapterList[pCtx->adapterCount].pDriverStorePath = driverStorePath;
|
||||
pCtx->adapterList[pCtx->adapterCount].wddmVersion = wddmVersion;
|
||||
pCtx->adapterCount++;
|
||||
}
|
||||
|
||||
log_infof("Adding new adapter via dxcore hAdapter:%x luid:%llx wddm version:%d", pAdapterInfo->hAdapter, *((unsigned long long*)&pAdapterInfo->AdapterLuid), wddmVersion);
|
||||
}
|
||||
|
||||
static void dxcore_enum_adapters(struct dxcore_context* pCtx, struct dxcore_lib* pLib)
|
||||
{
|
||||
struct dxcore_enumAdapters2 params = {0};
|
||||
unsigned int adapterIndex = 0;
|
||||
|
||||
params.NumAdapters = 0;
|
||||
params.pAdapters = NULL;
|
||||
|
||||
if (pLib->pDxcoreEnumAdapters2(¶ms)) {
|
||||
log_err("Failed to enumerate adapters via dxcore");
|
||||
return;
|
||||
}
|
||||
|
||||
params.pAdapters = malloc(sizeof(struct dxcore_adapterInfo) * params.NumAdapters);
|
||||
if (pLib->pDxcoreEnumAdapters2(¶ms)) {
|
||||
free(params.pAdapters);
|
||||
log_err("Failed to enumerate adapters via dxcore");
|
||||
return;
|
||||
}
|
||||
|
||||
for (adapterIndex = 0; adapterIndex < params.NumAdapters; adapterIndex++) {
|
||||
dxcore_add_adapter(pCtx, pLib, ¶ms.pAdapters[adapterIndex]);
|
||||
}
|
||||
|
||||
free(params.pAdapters);
|
||||
}
|
||||
|
||||
int dxcore_init_context(struct dxcore_context* pCtx)
|
||||
{
|
||||
struct dxcore_lib lib = {0};
|
||||
|
||||
pCtx->initialized = 0;
|
||||
pCtx->adapterCount = 0;
|
||||
pCtx->adapterList = NULL;
|
||||
|
||||
lib.hDxcoreLib = dlopen("libdxcore.so", RTLD_LAZY);
|
||||
if (!lib.hDxcoreLib) {
|
||||
goto error;
|
||||
}
|
||||
|
||||
lib.pDxcoreEnumAdapters2 = (pfnDxcoreEnumAdapters2)dlsym(lib.hDxcoreLib, "D3DKMTEnumAdapters2");
|
||||
if (!lib.pDxcoreEnumAdapters2) {
|
||||
log_err("dxcore library is present but the symbol D3DKMTEnumAdapters2 is missing");
|
||||
goto error;
|
||||
}
|
||||
|
||||
lib.pDxcoreQueryAdapterInfo = (pfnDxcoreQueryAdapterInfo)dlsym(lib.hDxcoreLib, "D3DKMTQueryAdapterInfo");
|
||||
if (!lib.pDxcoreQueryAdapterInfo) {
|
||||
log_err("dxcore library is present but the symbol D3DKMTQueryAdapterInfo is missing");
|
||||
goto error;
|
||||
}
|
||||
|
||||
dxcore_enum_adapters(pCtx, &lib);
|
||||
|
||||
log_info("dxcore layer initialized successfully");
|
||||
pCtx->initialized = 1;
|
||||
|
||||
dlclose(lib.hDxcoreLib);
|
||||
|
||||
return 0;
|
||||
|
||||
error:
|
||||
dxcore_deinit_context(pCtx);
|
||||
|
||||
if (lib.hDxcoreLib)
|
||||
dlclose(lib.hDxcoreLib);
|
||||
|
||||
return (-1);
|
||||
}
|
||||
|
||||
static void dxcore_deinit_adapter(struct dxcore_adapter* pAdapter)
|
||||
{
|
||||
if (!pAdapter)
|
||||
return;
|
||||
|
||||
free(pAdapter->pDriverStorePath);
|
||||
}
|
||||
|
||||
void dxcore_deinit_context(struct dxcore_context* pCtx)
|
||||
{
|
||||
unsigned int adapterIndex = 0;
|
||||
|
||||
if (!pCtx)
|
||||
return;
|
||||
|
||||
for (adapterIndex = 0; adapterIndex < pCtx->adapterCount; adapterIndex++) {
|
||||
dxcore_deinit_adapter(&pCtx->adapterList[adapterIndex]);
|
||||
}
|
||||
|
||||
free(pCtx->adapterList);
|
||||
|
||||
pCtx->initialized = 0;
|
||||
}
|
59
internal/dxcore/dxcore.go
Normal file
59
internal/dxcore/dxcore.go
Normal file
@ -0,0 +1,59 @@
|
||||
/**
|
||||
# Copyright (c) NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
**/
|
||||
|
||||
package dxcore
|
||||
|
||||
/*
|
||||
#cgo LDFLAGS: -Wl,--unresolved-symbols=ignore-in-object-files
|
||||
#include <dxcore.h>
|
||||
*/
|
||||
import "C"
|
||||
import (
|
||||
"fmt"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
type context C.struct_dxcore_context
|
||||
type adapter C.struct_dxcore_adapter
|
||||
|
||||
// initContext initializes the dxcore context and populates the list of adapters.
|
||||
func initContext() (*context, error) {
|
||||
cContext := C.struct_dxcore_context{}
|
||||
if C.dxcore_init_context(&cContext) != 0 {
|
||||
return nil, fmt.Errorf("failed to initialize dxcore context")
|
||||
}
|
||||
c := (*context)(&cContext)
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// deinitContext deinitializes the dxcore context and frees the list of adapters.
|
||||
func (c context) deinitContext() {
|
||||
cContext := C.struct_dxcore_context(c)
|
||||
C.dxcore_deinit_context(&cContext)
|
||||
}
|
||||
|
||||
func (c context) getAdapterCount() int {
|
||||
return int(c.adapterCount)
|
||||
}
|
||||
|
||||
func (c context) getAdapter(index int) adapter {
|
||||
arrayPointer := (*[1 << 30]C.struct_dxcore_adapter)(unsafe.Pointer(c.adapterList))
|
||||
return adapter(arrayPointer[index])
|
||||
}
|
||||
|
||||
func (a adapter) getDriverStorePath() string {
|
||||
return C.GoString(a.pDriverStorePath)
|
||||
}
|
39
internal/dxcore/dxcore.h
Normal file
39
internal/dxcore/dxcore.h
Normal file
@ -0,0 +1,39 @@
|
||||
/*
|
||||
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
||||
*/
|
||||
|
||||
#ifndef HEADER_DXCORE_H_
|
||||
#define HEADER_DXCORE_H_
|
||||
|
||||
#define MAX_DXCORE_DRIVERSTORE_LIBRAIRIES (16)
|
||||
|
||||
struct dxcore_luid
|
||||
{
|
||||
unsigned int lowPart;
|
||||
int highPart;
|
||||
};
|
||||
|
||||
struct dxcore_adapter
|
||||
{
|
||||
unsigned int hAdapter;
|
||||
unsigned int wddmVersion;
|
||||
char* pDriverStorePath;
|
||||
unsigned int driverStoreComponentCount;
|
||||
const char* pDriverStoreComponents[MAX_DXCORE_DRIVERSTORE_LIBRAIRIES];
|
||||
struct dxcore_context *pContext;
|
||||
};
|
||||
|
||||
struct dxcore_context
|
||||
{
|
||||
unsigned int adapterCount;
|
||||
struct dxcore_adapter *adapterList;
|
||||
|
||||
int initialized;
|
||||
};
|
||||
|
||||
|
||||
|
||||
int dxcore_init_context(struct dxcore_context* pDxcore_context);
|
||||
void dxcore_deinit_context(struct dxcore_context* pDxcore_context);
|
||||
|
||||
#endif // HEADER_DXCORE_H_
|
@ -20,27 +20,15 @@ import (
|
||||
"fmt"
|
||||
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/edits"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup"
|
||||
"github.com/container-orchestrated-devices/container-device-interface/pkg/cdi"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml"
|
||||
)
|
||||
|
||||
// GetCommonEdits generates a CDI specification that can be used for ANY devices
|
||||
func (l *nvcdilib) GetCommonEdits() (*cdi.ContainerEdits, error) {
|
||||
common, err := newCommonDiscoverer(l.logger, l.driverRoot, l.nvidiaCTKPath, l.nvmllib)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create discoverer for common entities: %v", err)
|
||||
}
|
||||
|
||||
return edits.FromDiscoverer(common)
|
||||
}
|
||||
|
||||
// newCommonDiscoverer returns a discoverer for entities that are not associated with a specific CDI device.
|
||||
// newCommonNVMLDiscoverer returns a discoverer for entities that are not associated with a specific CDI device.
|
||||
// This includes driver libraries and meta devices, for example.
|
||||
func newCommonDiscoverer(logger *logrus.Logger, driverRoot string, nvidiaCTKPath string, nvmllib nvml.Interface) (discover.Discover, error) {
|
||||
func newCommonNVMLDiscoverer(logger *logrus.Logger, driverRoot string, nvidiaCTKPath string, nvmllib nvml.Interface) (discover.Discover, error) {
|
||||
metaDevices := discover.NewDeviceDiscoverer(
|
||||
logger,
|
||||
lookup.NewCharDeviceLocator(
|
37
pkg/nvcdi/device-wsl.go
Normal file
37
pkg/nvcdi/device-wsl.go
Normal file
@ -0,0 +1,37 @@
|
||||
/**
|
||||
# Copyright (c) NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
**/
|
||||
|
||||
package nvcdi
|
||||
|
||||
import (
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
dxgDeviceNode = "/dev/dxg"
|
||||
)
|
||||
|
||||
// newDXGDeviceDiscoverer returns a Discoverer for DXG devices under WSL2.
|
||||
func newDXGDeviceDiscoverer(logger *logrus.Logger, driverRoot string) discover.Discover {
|
||||
deviceNodes := discover.NewCharDeviceDiscoverer(
|
||||
logger,
|
||||
[]string{dxgDeviceNode},
|
||||
driverRoot,
|
||||
)
|
||||
|
||||
return deviceNodes
|
||||
}
|
106
pkg/nvcdi/driver-wsl.go
Normal file
106
pkg/nvcdi/driver-wsl.go
Normal file
@ -0,0 +1,106 @@
|
||||
/**
|
||||
# Copyright (c) NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
**/
|
||||
|
||||
package nvcdi
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/dxcore"
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var requiredDriverStoreFiles = []string{
|
||||
"libcuda.so.1.1", /* Core library for cuda support */
|
||||
"libcuda_loader.so", /* Core library for cuda support on WSL */
|
||||
"libnvidia-ptxjitcompiler.so.1", /* Core library for PTX Jit support */
|
||||
"libnvidia-ml.so.1", /* Core library for nvml */
|
||||
"libnvidia-ml_loader.so", /* Core library for nvml on WSL */
|
||||
"libdxcore.so", /* Core library for dxcore support */
|
||||
"nvcubins.bin", /* Binary containing GPU code for cuda */
|
||||
"nvidia-smi", /* nvidia-smi binary*/
|
||||
}
|
||||
|
||||
// newWSLDriverDiscoverer returns a Discoverer for WSL2 drivers.
|
||||
func newWSLDriverDiscoverer(logger *logrus.Logger, driverRoot string, nvidiaCTKPath string) (discover.Discover, error) {
|
||||
err := dxcore.Init()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize dxcore: %v", err)
|
||||
}
|
||||
defer dxcore.Shutdown()
|
||||
|
||||
driverStorePaths := dxcore.GetDriverStorePaths()
|
||||
if len(driverStorePaths) == 0 {
|
||||
return nil, fmt.Errorf("no driver store paths found")
|
||||
}
|
||||
logger.Infof("Using WSL driver store paths: %v", driverStorePaths)
|
||||
|
||||
return newWSLDriverStoreDiscoverer(logger, driverRoot, nvidiaCTKPath, driverStorePaths)
|
||||
}
|
||||
|
||||
// newWSLDriverStoreDiscoverer returns a Discoverer for WSL2 drivers in the driver store associated with a dxcore adapter.
|
||||
func newWSLDriverStoreDiscoverer(logger *logrus.Logger, driverRoot string, nvidiaCTKPath string, driverStorePaths []string) (discover.Discover, error) {
|
||||
var searchPaths []string
|
||||
seen := make(map[string]bool)
|
||||
for _, path := range driverStorePaths {
|
||||
if seen[path] {
|
||||
continue
|
||||
}
|
||||
searchPaths = append(searchPaths, path)
|
||||
}
|
||||
if len(searchPaths) > 1 {
|
||||
logger.Warnf("Found multiple driver store paths: %v", searchPaths)
|
||||
}
|
||||
driverStorePath := searchPaths[0]
|
||||
searchPaths = append(searchPaths, "/usr/lib/wsl/lib")
|
||||
|
||||
libraries := discover.NewMounts(
|
||||
logger,
|
||||
lookup.NewFileLocator(
|
||||
lookup.WithLogger(logger),
|
||||
lookup.WithSearchPaths(
|
||||
searchPaths...,
|
||||
),
|
||||
lookup.WithCount(1),
|
||||
),
|
||||
driverRoot,
|
||||
requiredDriverStoreFiles,
|
||||
)
|
||||
|
||||
// On WSL2 the driver store location is used unchanged.
|
||||
// For this reason we need to create a symlink from /usr/bin/nvidia-smi to the nvidia-smi binary in the driver store.
|
||||
target := filepath.Join(driverStorePath, "nvidia-smi")
|
||||
link := "/usr/bin/nvidia-smi"
|
||||
links := []string{fmt.Sprintf("%s::%s", target, link)}
|
||||
symlinkHook := discover.CreateCreateSymlinkHook(nvidiaCTKPath, links)
|
||||
|
||||
cfg := &discover.Config{
|
||||
DriverRoot: driverRoot,
|
||||
NvidiaCTKPath: nvidiaCTKPath,
|
||||
}
|
||||
ldcacheHook, _ := discover.NewLDCacheUpdateHook(logger, libraries, cfg)
|
||||
|
||||
d := discover.Merge(
|
||||
libraries,
|
||||
symlinkHook,
|
||||
ldcacheHook,
|
||||
)
|
||||
|
||||
return d, nil
|
||||
}
|
@ -33,7 +33,7 @@ import (
|
||||
)
|
||||
|
||||
// GetGPUDeviceSpecs returns the CDI device specs for the full GPU represented by 'device'.
|
||||
func (l *nvcdilib) GetGPUDeviceSpecs(i int, d device.Device) (*specs.Device, error) {
|
||||
func (l *nvmllib) GetGPUDeviceSpecs(i int, d device.Device) (*specs.Device, error) {
|
||||
edits, err := l.GetGPUDeviceEdits(d)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get edits for device: %v", err)
|
||||
@ -53,7 +53,7 @@ func (l *nvcdilib) GetGPUDeviceSpecs(i int, d device.Device) (*specs.Device, err
|
||||
}
|
||||
|
||||
// GetGPUDeviceEdits returns the CDI edits for the full GPU represented by 'device'.
|
||||
func (l *nvcdilib) GetGPUDeviceEdits(d device.Device) (*cdi.ContainerEdits, error) {
|
||||
func (l *nvmllib) GetGPUDeviceEdits(d device.Device) (*cdi.ContainerEdits, error) {
|
||||
device, err := newFullGPUDiscoverer(l.logger, l.driverRoot, l.nvidiaCTKPath, d)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create device discoverer: %v", err)
|
93
pkg/nvcdi/lib-nvml.go
Normal file
93
pkg/nvcdi/lib-nvml.go
Normal file
@ -0,0 +1,93 @@
|
||||
/**
|
||||
# Copyright (c) NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
**/
|
||||
|
||||
package nvcdi
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/edits"
|
||||
"github.com/container-orchestrated-devices/container-device-interface/pkg/cdi"
|
||||
"github.com/container-orchestrated-devices/container-device-interface/specs-go"
|
||||
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/device"
|
||||
)
|
||||
|
||||
// nvmllib is a distinct type with nvcdilib as its underlying type; the methods
// defined on it generate CDI specifications using NVML-based device discovery.
type nvmllib nvcdilib
|
||||
|
||||
// Compile-time check that *nvmllib implements Interface.
var _ Interface = (*nvmllib)(nil)
|
||||
|
||||
// GetAllDeviceSpecs returns the device specs for all available devices.
|
||||
func (l *nvmllib) GetAllDeviceSpecs() ([]specs.Device, error) {
|
||||
var deviceSpecs []specs.Device
|
||||
|
||||
gpuDeviceSpecs, err := l.getGPUDeviceSpecs()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
deviceSpecs = append(deviceSpecs, gpuDeviceSpecs...)
|
||||
|
||||
migDeviceSpecs, err := l.getMigDeviceSpecs()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
deviceSpecs = append(deviceSpecs, migDeviceSpecs...)
|
||||
|
||||
return deviceSpecs, nil
|
||||
}
|
||||
|
||||
// GetCommonEdits generates a CDI specification that can be used for ANY devices
|
||||
func (l *nvmllib) GetCommonEdits() (*cdi.ContainerEdits, error) {
|
||||
common, err := newCommonNVMLDiscoverer(l.logger, l.driverRoot, l.nvidiaCTKPath, l.nvmllib)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create discoverer for common entities: %v", err)
|
||||
}
|
||||
|
||||
return edits.FromDiscoverer(common)
|
||||
}
|
||||
|
||||
func (l *nvmllib) getGPUDeviceSpecs() ([]specs.Device, error) {
|
||||
var deviceSpecs []specs.Device
|
||||
err := l.devicelib.VisitDevices(func(i int, d device.Device) error {
|
||||
deviceSpec, err := l.GetGPUDeviceSpecs(i, d)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
deviceSpecs = append(deviceSpecs, *deviceSpec)
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate CDI edits for GPU devices: %v", err)
|
||||
}
|
||||
return deviceSpecs, err
|
||||
}
|
||||
|
||||
func (l *nvmllib) getMigDeviceSpecs() ([]specs.Device, error) {
|
||||
var deviceSpecs []specs.Device
|
||||
err := l.devicelib.VisitMigDevices(func(i int, d device.Device, j int, mig device.MigDevice) error {
|
||||
deviceSpec, err := l.GetMIGDeviceSpecs(i, d, j, mig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
deviceSpecs = append(deviceSpecs, *deviceSpec)
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate CDI edits for GPU devices: %v", err)
|
||||
}
|
||||
return deviceSpecs, err
|
||||
}
|
76
pkg/nvcdi/lib-wsl.go
Normal file
76
pkg/nvcdi/lib-wsl.go
Normal file
@ -0,0 +1,76 @@
|
||||
/**
|
||||
# Copyright (c) NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
**/
|
||||
|
||||
package nvcdi
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/edits"
|
||||
"github.com/container-orchestrated-devices/container-device-interface/pkg/cdi"
|
||||
"github.com/container-orchestrated-devices/container-device-interface/specs-go"
|
||||
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/device"
|
||||
)
|
||||
|
||||
// wsllib is a distinct type with nvcdilib as its underlying type; the methods
// defined on it generate CDI specifications for WSL2, where per-GPU and
// per-MIG queries are reported as unsupported.
type wsllib nvcdilib
|
||||
|
||||
// Compile-time check that *wsllib implements Interface.
var _ Interface = (*wsllib)(nil)
|
||||
|
||||
// GetAllDeviceSpecs returns the device specs for all available devices.
|
||||
func (l *wsllib) GetAllDeviceSpecs() ([]specs.Device, error) {
|
||||
device := newDXGDeviceDiscoverer(l.logger, l.driverRoot)
|
||||
deviceEdits, err := edits.FromDiscoverer(device)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create container edits for DXG device: %v", err)
|
||||
}
|
||||
|
||||
deviceSpec := specs.Device{
|
||||
Name: "all",
|
||||
ContainerEdits: *deviceEdits.ContainerEdits,
|
||||
}
|
||||
|
||||
return []specs.Device{deviceSpec}, nil
|
||||
}
|
||||
|
||||
// GetCommonEdits generates a CDI specification that can be used for ANY devices
|
||||
func (l *wsllib) GetCommonEdits() (*cdi.ContainerEdits, error) {
|
||||
driver, err := newWSLDriverDiscoverer(l.logger, l.driverRoot, l.nvidiaCTKPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create discoverer for WSL driver: %v", err)
|
||||
}
|
||||
|
||||
return edits.FromDiscoverer(driver)
|
||||
}
|
||||
|
||||
// GetGPUDeviceEdits generates a CDI specification that can be used for GPU devices
|
||||
func (l *wsllib) GetGPUDeviceEdits(device.Device) (*cdi.ContainerEdits, error) {
|
||||
return nil, fmt.Errorf("GetGPUDeviceEdits is not supported on WSL")
|
||||
}
|
||||
|
||||
// GetGPUDeviceSpecs returns the CDI device specs for the full GPU represented by 'device'.
|
||||
func (l *wsllib) GetGPUDeviceSpecs(i int, d device.Device) (*specs.Device, error) {
|
||||
return nil, fmt.Errorf("GetGPUDeviceSpecs is not supported on WSL")
|
||||
}
|
||||
|
||||
// GetMIGDeviceEdits generates a CDI specification that can be used for MIG devices
|
||||
func (l *wsllib) GetMIGDeviceEdits(device.Device, device.MigDevice) (*cdi.ContainerEdits, error) {
|
||||
return nil, fmt.Errorf("GetMIGDeviceEdits is not supported on WSL")
|
||||
}
|
||||
|
||||
// GetMIGDeviceSpecs returns the CDI device specs for the full MIG represented by 'device'.
|
||||
func (l *wsllib) GetMIGDeviceSpecs(int, device.Device, int, device.MigDevice) (*specs.Device, error) {
|
||||
return nil, fmt.Errorf("GetMIGDeviceSpecs is not supported on WSL")
|
||||
}
|
@ -17,9 +17,6 @@
|
||||
package nvcdi
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/container-orchestrated-devices/container-device-interface/specs-go"
|
||||
"github.com/sirupsen/logrus"
|
||||
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/device"
|
||||
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml"
|
||||
@ -28,6 +25,7 @@ import (
|
||||
type nvcdilib struct {
|
||||
logger *logrus.Logger
|
||||
nvmllib nvml.Interface
|
||||
mode string
|
||||
devicelib device.Interface
|
||||
deviceNamer DeviceNamer
|
||||
driverRoot string
|
||||
@ -40,12 +38,8 @@ func New(opts ...Option) Interface {
|
||||
for _, opt := range opts {
|
||||
opt(l)
|
||||
}
|
||||
|
||||
if l.nvmllib == nil {
|
||||
l.nvmllib = nvml.New()
|
||||
}
|
||||
if l.devicelib == nil {
|
||||
l.devicelib = device.New(device.WithNvml(l.nvmllib))
|
||||
if l.mode == "" {
|
||||
l.mode = "nvml"
|
||||
}
|
||||
if l.logger == nil {
|
||||
l.logger = logrus.StandardLogger()
|
||||
@ -60,58 +54,20 @@ func New(opts ...Option) Interface {
|
||||
l.nvidiaCTKPath = "/usr/bin/nvidia-ctk"
|
||||
}
|
||||
|
||||
return l
|
||||
}
|
||||
|
||||
// GetAllDeviceSpecs returns the device specs for all available devices.
|
||||
func (l *nvcdilib) GetAllDeviceSpecs() ([]specs.Device, error) {
|
||||
var deviceSpecs []specs.Device
|
||||
|
||||
gpuDeviceSpecs, err := l.getGPUDeviceSpecs()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
deviceSpecs = append(deviceSpecs, gpuDeviceSpecs...)
|
||||
|
||||
migDeviceSpecs, err := l.getMigDeviceSpecs()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
deviceSpecs = append(deviceSpecs, migDeviceSpecs...)
|
||||
|
||||
return deviceSpecs, nil
|
||||
}
|
||||
|
||||
func (l *nvcdilib) getGPUDeviceSpecs() ([]specs.Device, error) {
|
||||
var deviceSpecs []specs.Device
|
||||
err := l.devicelib.VisitDevices(func(i int, d device.Device) error {
|
||||
deviceSpec, err := l.GetGPUDeviceSpecs(i, d)
|
||||
if err != nil {
|
||||
return err
|
||||
switch l.mode {
|
||||
case "nvml":
|
||||
if l.nvmllib == nil {
|
||||
l.nvmllib = nvml.New()
|
||||
}
|
||||
deviceSpecs = append(deviceSpecs, *deviceSpec)
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate CDI edits for GPU devices: %v", err)
|
||||
}
|
||||
return deviceSpecs, err
|
||||
}
|
||||
|
||||
func (l *nvcdilib) getMigDeviceSpecs() ([]specs.Device, error) {
|
||||
var deviceSpecs []specs.Device
|
||||
err := l.devicelib.VisitMigDevices(func(i int, d device.Device, j int, mig device.MigDevice) error {
|
||||
deviceSpec, err := l.GetMIGDeviceSpecs(i, d, j, mig)
|
||||
if err != nil {
|
||||
return err
|
||||
if l.devicelib == nil {
|
||||
l.devicelib = device.New(device.WithNvml(l.nvmllib))
|
||||
}
|
||||
deviceSpecs = append(deviceSpecs, *deviceSpec)
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate CDI edits for GPU devices: %v", err)
|
||||
return (*nvmllib)(l)
|
||||
case "wsl":
|
||||
return (*wsllib)(l)
|
||||
}
|
||||
return deviceSpecs, err
|
||||
|
||||
// TODO: We want an error here.
|
||||
return nil
|
||||
}
|
||||
|
@ -30,7 +30,7 @@ import (
|
||||
)
|
||||
|
||||
// GetMIGDeviceSpecs returns the CDI device specs for the full GPU represented by 'device'.
|
||||
func (l *nvcdilib) GetMIGDeviceSpecs(i int, d device.Device, j int, mig device.MigDevice) (*specs.Device, error) {
|
||||
func (l *nvmllib) GetMIGDeviceSpecs(i int, d device.Device, j int, mig device.MigDevice) (*specs.Device, error) {
|
||||
edits, err := l.GetMIGDeviceEdits(d, mig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get edits for device: %v", err)
|
||||
@ -50,7 +50,7 @@ func (l *nvcdilib) GetMIGDeviceSpecs(i int, d device.Device, j int, mig device.M
|
||||
}
|
||||
|
||||
// GetMIGDeviceEdits returns the CDI edits for the MIG device represented by 'mig' on 'parent'.
|
||||
func (l *nvcdilib) GetMIGDeviceEdits(parent device.Device, mig device.MigDevice) (*cdi.ContainerEdits, error) {
|
||||
func (l *nvmllib) GetMIGDeviceEdits(parent device.Device, mig device.MigDevice) (*cdi.ContainerEdits, error) {
|
||||
gpu, ret := parent.GetMinorNumber()
|
||||
if ret != nvml.SUCCESS {
|
||||
return nil, fmt.Errorf("error getting GPU minor: %v", ret)
|
@ -66,3 +66,10 @@ func WithNvmlLib(nvmllib nvml.Interface) Option {
|
||||
l.nvmllib = nvmllib
|
||||
}
|
||||
}
|
||||
|
||||
// WithMode sets the discovery mode for the library
|
||||
func WithMode(mode string) Option {
|
||||
return func(l *nvcdilib) {
|
||||
l.mode = mode
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user