diff --git a/cmd/nvidia-ctk/cdi/generate/generate.go b/cmd/nvidia-ctk/cdi/generate/generate.go index eb869d28..9daf97ca 100644 --- a/cmd/nvidia-ctk/cdi/generate/generate.go +++ b/cmd/nvidia-ctk/cdi/generate/generate.go @@ -37,9 +37,12 @@ import ( const ( discoveryModeNVML = "nvml" + discoveryModeWSL = "wsl" formatJSON = "json" formatYAML = "yaml" + + allDeviceName = "all" ) type command struct { @@ -93,7 +96,7 @@ func (m command) build() *cli.Command { }, &cli.StringFlag{ Name: "discovery-mode", - Usage: "The mode to use when discovering the available entities. One of [nvml]", + Usage: "The mode to use when discovering the available entities. One of [nvml | wsl]", Value: discoveryModeNVML, Destination: &cfg.discoveryMode, }, @@ -130,6 +133,7 @@ func (m command) validateFlags(r *cli.Context, cfg *config) error { cfg.discoveryMode = strings.ToLower(cfg.discoveryMode) switch cfg.discoveryMode { case discoveryModeNVML: + case discoveryModeWSL: default: return fmt.Errorf("invalid discovery mode: %v", cfg.discoveryMode) } @@ -252,10 +256,20 @@ func (m command) generateSpec(cfg *config) (*specs.Spec, error) { if err != nil { return nil, fmt.Errorf("failed to create device CDI specs: %v", err) } - - allDevice := createAllDevice(deviceSpecs) - - deviceSpecs = append(deviceSpecs, allDevice) + var hasAll bool + for _, deviceSpec := range deviceSpecs { + if deviceSpec.Name == allDeviceName { + hasAll = true + break + } + } + if !hasAll { + allDevice, err := MergeDeviceSpecs(deviceSpecs, allDeviceName) + if err != nil { + return nil, fmt.Errorf("failed to create CDI specification for %q device: %v", allDeviceName, err) + } + deviceSpecs = append(deviceSpecs, allDevice) + } commonEdits, err := cdilib.GetCommonEdits() if err != nil { @@ -287,22 +301,32 @@ func (m command) generateSpec(cfg *config) (*specs.Spec, error) { return &spec, nil } -// createAllDevice creates an 'all' device which combines the edits from the previous devices -func createAllDevice(deviceSpecs []specs.Device) specs.Device { - edits := edits.NewContainerEdits() +// MergeDeviceSpecs creates a device with the specified name which combines the edits from the previous devices. +// If a device of the specified name already exists, an error is returned. +func MergeDeviceSpecs(deviceSpecs []specs.Device, mergedDeviceName string) (specs.Device, error) { + if err := cdi.ValidateDeviceName(mergedDeviceName); err != nil { + return specs.Device{}, fmt.Errorf("invalid device name %q: %v", mergedDeviceName, err) + } + for _, d := range deviceSpecs { + if d.Name == mergedDeviceName { + return specs.Device{}, fmt.Errorf("device %q already exists", mergedDeviceName) + } + } + + mergedEdits := edits.NewContainerEdits() for _, d := range deviceSpecs { edit := cdi.ContainerEdits{ ContainerEdits: &d.ContainerEdits, } - edits.Append(&edit) + mergedEdits.Append(&edit) } - all := specs.Device{ - Name: "all", - ContainerEdits: *edits.ContainerEdits, + merged := specs.Device{ + Name: mergedDeviceName, + ContainerEdits: *mergedEdits.ContainerEdits, } - return all + return merged, nil } // createParentDirsIfRequired creates the parent folders of the specified path if requried. @@ -315,5 +339,3 @@ func createParentDirsIfRequired(filename string) error { } return os.MkdirAll(dir, 0755) } - -type discoveryMode string diff --git a/cmd/nvidia-ctk/cdi/generate/generate_test.go b/cmd/nvidia-ctk/cdi/generate/generate_test.go new file mode 100644 index 00000000..5924480e --- /dev/null +++ b/cmd/nvidia-ctk/cdi/generate/generate_test.go @@ -0,0 +1,117 @@ +/** +# Copyright (c) NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package generate + +import ( + "fmt" + "testing" + + "github.com/container-orchestrated-devices/container-device-interface/specs-go" + "github.com/stretchr/testify/require" +) + +func TestMergeDeviceSpecs(t *testing.T) { + testCases := []struct { + description string + deviceSpecs []specs.Device + mergedDeviceName string + expectedError error + expected specs.Device + }{ + { + description: "no devices", + mergedDeviceName: "all", + expected: specs.Device{ + Name: "all", + }, + }, + { + description: "one device", + mergedDeviceName: "all", + deviceSpecs: []specs.Device{ + { + Name: "gpu0", + ContainerEdits: specs.ContainerEdits{ + Env: []string{"GPU=0"}, + }, + }, + }, + expected: specs.Device{ + Name: "all", + ContainerEdits: specs.ContainerEdits{ + Env: []string{"GPU=0"}, + }, + }, + }, + { + description: "two devices", + mergedDeviceName: "all", + deviceSpecs: []specs.Device{ + { + Name: "gpu0", + ContainerEdits: specs.ContainerEdits{ + Env: []string{"GPU=0"}, + }, + }, + { + Name: "gpu1", + ContainerEdits: specs.ContainerEdits{ + Env: []string{"GPU=1"}, + }, + }, + }, + expected: specs.Device{ + Name: "all", + ContainerEdits: specs.ContainerEdits{ + Env: []string{"GPU=0", "GPU=1"}, + }, + }, + }, + { + description: "has merged device", + mergedDeviceName: "gpu0", + deviceSpecs: []specs.Device{ + { + Name: "gpu0", + ContainerEdits: specs.ContainerEdits{ + Env: []string{"GPU=0"}, + }, + }, + }, + expectedError: fmt.Errorf("device %q already exists", "gpu0"), + }, + { + description: "invalid merged device name", + mergedDeviceName: ".-not-valid", + expectedError: fmt.Errorf("invalid device name %q", ".-not-valid"), + }, + } + + for _, tc := range testCases { + t.Run(tc.description, func(t *testing.T) { + mergedDevice, err := MergeDeviceSpecs(tc.deviceSpecs, tc.mergedDeviceName) + + if tc.expectedError != nil { + require.Error(t, err) + return + } + + require.NoError(t, err) + require.EqualValues(t, tc.expected, mergedDevice) + }) + } +} diff --git a/pkg/nvcdi/device-wsl.go b/pkg/nvcdi/device-wsl.go new file mode 100644 index 00000000..6acbb4b4 --- /dev/null +++ b/pkg/nvcdi/device-wsl.go @@ -0,0 +1,37 @@ +/** +# Copyright (c) NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package nvcdi + +import ( + "github.com/NVIDIA/nvidia-container-toolkit/internal/discover" + "github.com/sirupsen/logrus" +) + +const ( + dxgDeviceNode = "/dev/dxg" +) + +// newDXGDeviceDiscoverer returns a Discoverer for DXG devices under WSL2. +func newDXGDeviceDiscoverer(logger *logrus.Logger, driverRoot string) discover.Discover { + deviceNodes := discover.NewCharDeviceDiscoverer( + logger, + []string{dxgDeviceNode}, + driverRoot, + ) + + return deviceNodes +} diff --git a/pkg/nvcdi/driver-wsl.go b/pkg/nvcdi/driver-wsl.go new file mode 100644 index 00000000..cca4d52c --- /dev/null +++ b/pkg/nvcdi/driver-wsl.go @@ -0,0 +1,106 @@ +/** +# Copyright (c) NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package nvcdi + +import ( + "fmt" + "path/filepath" + + "github.com/NVIDIA/nvidia-container-toolkit/internal/discover" + "github.com/NVIDIA/nvidia-container-toolkit/internal/dxcore" + "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup" + "github.com/sirupsen/logrus" +) + +var requiredDriverStoreFiles = []string{ + "libcuda.so.1.1", /* Core library for cuda support */ + "libcuda_loader.so", /* Core library for cuda support on WSL */ + "libnvidia-ptxjitcompiler.so.1", /* Core library for PTX Jit support */ + "libnvidia-ml.so.1", /* Core library for nvml */ + "libnvidia-ml_loader.so", /* Core library for nvml on WSL */ + "libdxcore.so", /* Core library for dxcore support */ + "nvcubins.bin", /* Binary containing GPU code for cuda */ + "nvidia-smi", /* nvidia-smi binary*/ +} + +// newWSLDriverDiscoverer returns a Discoverer for WSL2 drivers. +func newWSLDriverDiscoverer(logger *logrus.Logger, driverRoot string, nvidiaCTKPath string) (discover.Discover, error) { + err := dxcore.Init() + if err != nil { + return nil, fmt.Errorf("failed to initialize dxcore: %v", err) + } + defer dxcore.Shutdown() + + driverStorePaths := dxcore.GetDriverStorePaths() + if len(driverStorePaths) == 0 { + return nil, fmt.Errorf("no driver store paths found") + } + logger.Infof("Using WSL driver store paths: %v", driverStorePaths) + + return newWSLDriverStoreDiscoverer(logger, driverRoot, nvidiaCTKPath, driverStorePaths) +} + +// newWSLDriverStoreDiscoverer returns a Discoverer for WSL2 drivers in the driver store associated with a dxcore adapter. +func newWSLDriverStoreDiscoverer(logger *logrus.Logger, driverRoot string, nvidiaCTKPath string, driverStorePaths []string) (discover.Discover, error) { + var searchPaths []string + seen := make(map[string]bool) + for _, path := range driverStorePaths { + if seen[path] { + continue + } + searchPaths = append(searchPaths, path) + } + if len(searchPaths) > 1 { + logger.Warnf("Found multiple driver store paths: %v", searchPaths) + } + driverStorePath := searchPaths[0] + searchPaths = append(searchPaths, "/usr/lib/wsl/lib") + + libraries := discover.NewMounts( + logger, + lookup.NewFileLocator( + lookup.WithLogger(logger), + lookup.WithSearchPaths( + searchPaths..., + ), + lookup.WithCount(1), + ), + driverRoot, + requiredDriverStoreFiles, + ) + + // On WSL2 the driver store location is used unchanged. + // For this reason we need to create a symlink from /usr/bin/nvidia-smi to the nvidia-smi binary in the driver store. + target := filepath.Join(driverStorePath, "nvidia-smi") + link := "/usr/bin/nvidia-smi" + links := []string{fmt.Sprintf("%s::%s", target, link)} + symlinkHook := discover.CreateCreateSymlinkHook(nvidiaCTKPath, links) + + cfg := &discover.Config{ + DriverRoot: driverRoot, + NvidiaCTKPath: nvidiaCTKPath, + } + ldcacheHook, _ := discover.NewLDCacheUpdateHook(logger, libraries, cfg) + + d := discover.Merge( + libraries, + symlinkHook, + ldcacheHook, + ) + + return d, nil +} diff --git a/pkg/nvcdi/lib-wsl.go b/pkg/nvcdi/lib-wsl.go new file mode 100644 index 00000000..d901995c --- /dev/null +++ b/pkg/nvcdi/lib-wsl.go @@ -0,0 +1,76 @@ +/** +# Copyright (c) NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package nvcdi + +import ( + "fmt" + + "github.com/NVIDIA/nvidia-container-toolkit/internal/edits" + "github.com/container-orchestrated-devices/container-device-interface/pkg/cdi" + "github.com/container-orchestrated-devices/container-device-interface/specs-go" + "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/device" +) + +type wsllib nvcdilib + +var _ Interface = (*wsllib)(nil) + +// GetAllDeviceSpecs returns the device specs for all available devices. +func (l *wsllib) GetAllDeviceSpecs() ([]specs.Device, error) { + device := newDXGDeviceDiscoverer(l.logger, l.driverRoot) + deviceEdits, err := edits.FromDiscoverer(device) + if err != nil { + return nil, fmt.Errorf("failed to create container edits for DXG device: %v", err) + } + + deviceSpec := specs.Device{ + Name: "all", + ContainerEdits: *deviceEdits.ContainerEdits, + } + + return []specs.Device{deviceSpec}, nil +} + +// GetCommonEdits generates a CDI specification that can be used for ANY devices +func (l *wsllib) GetCommonEdits() (*cdi.ContainerEdits, error) { + driver, err := newWSLDriverDiscoverer(l.logger, l.driverRoot, l.nvidiaCTKPath) + if err != nil { + return nil, fmt.Errorf("failed to create discoverer for WSL driver: %v", err) + } + + return edits.FromDiscoverer(driver) +} + +// GetGPUDeviceEdits generates a CDI specification that can be used for GPU devices +func (l *wsllib) GetGPUDeviceEdits(device.Device) (*cdi.ContainerEdits, error) { + return nil, fmt.Errorf("GetGPUDeviceEdits is not supported on WSL") +} + +// GetGPUDeviceSpecs returns the CDI device specs for the full GPU represented by 'device'. +func (l *wsllib) GetGPUDeviceSpecs(i int, d device.Device) (*specs.Device, error) { + return nil, fmt.Errorf("GetGPUDeviceSpecs is not supported on WSL") +} + +// GetMIGDeviceEdits generates a CDI specification that can be used for MIG devices +func (l *wsllib) GetMIGDeviceEdits(device.Device, device.MigDevice) (*cdi.ContainerEdits, error) { + return nil, fmt.Errorf("GetMIGDeviceEdits is not supported on WSL") +} + +// GetMIGDeviceSpecs returns the CDI device specs for the full MIG represented by 'device'. +func (l *wsllib) GetMIGDeviceSpecs(int, device.Device, int, device.MigDevice) (*specs.Device, error) { + return nil, fmt.Errorf("GetMIGDeviceSpecs is not supported on WSL") +} diff --git a/pkg/nvcdi/lib.go b/pkg/nvcdi/lib.go index 985e6850..4081e524 100644 --- a/pkg/nvcdi/lib.go +++ b/pkg/nvcdi/lib.go @@ -64,6 +64,8 @@ func New(opts ...Option) Interface { } return (*nvmllib)(l) + case "wsl": + return (*wsllib)(l) } // TODO: We want an error here.