From 76b6d4d38f4162a8e7beca52ee3cca5eeb754f89 Mon Sep 17 00:00:00 2001
From: Evan Lezar <elezar@nvidia.com>
Date: Mon, 10 Mar 2025 13:51:53 +0200
Subject: [PATCH] [no-relnote] Add functions to create gpu device nodes

Signed-off-by: Evan Lezar <elezar@nvidia.com>
---
 .../system/nvdevices/control-device-nodes.go  | 23 ++++++
 internal/system/nvdevices/devices.go          | 46 +++++++++++
 internal/system/nvdevices/gpu-device-nodes.go | 76 +++++++++++++++++++
 3 files changed, 145 insertions(+)
 create mode 100644 internal/system/nvdevices/gpu-device-nodes.go

diff --git a/internal/system/nvdevices/control-device-nodes.go b/internal/system/nvdevices/control-device-nodes.go
index 2b5c6c14..b7a3b4a8 100644
--- a/internal/system/nvdevices/control-device-nodes.go
+++ b/internal/system/nvdevices/control-device-nodes.go
@@ -17,11 +17,13 @@
 package nvdevices
 
 import (
+	"errors"
 	"fmt"
 	"path/filepath"
 	"strings"
 
 	"github.com/NVIDIA/nvidia-container-toolkit/internal/info/proc/devices"
+	"github.com/NVIDIA/nvidia-container-toolkit/internal/nvcaps"
 )
 
 // A controlDeviceNode represents an NVIDIA devices node for control or meta devices.
@@ -43,6 +45,27 @@ func (m *Interface) CreateNVIDIAControlDevices() error {
 	return nil
 }
 
+// CreateNVIDIACapsControlDeviceNodes creates the nvidia-caps device nodes for the "config" and
+// "monitor" MIG capabilities at the configured devRoot; a missing nvidia-caps major is a no-op.
+func (m *Interface) CreateNVIDIACapsControlDeviceNodes() error {
+	major, ok := m.Get("nvidia-caps")
+	if !ok {
+		return nil
+	}
+
+	var joined error
+	for _, name := range []nvcaps.MigCap{"config", "monitor"} {
+		// Capabilities without an entry in migCaps are skipped.
+		if minor, found := m.migCaps[name]; found {
+			path := minor.DevicePath()
+			if err := m.createDeviceNode(path, int(major), int(minor)); err != nil {
+				joined = errors.Join(joined, fmt.Errorf("failed to create nvidia-caps device node %v: %w", path, err))
+			}
+		}
+	}
+	return joined
+}
+
 // createControlDeviceNode creates the specified NVIDIA device node at the configured devRoot.
 func (m *Interface) createControlDeviceNode(node controlDeviceNode) error {
 	if !strings.HasPrefix(string(node), "nvidia") {
diff --git a/internal/system/nvdevices/devices.go b/internal/system/nvdevices/devices.go
index ef935078..882af59e 100644
--- a/internal/system/nvdevices/devices.go
+++ b/internal/system/nvdevices/devices.go
@@ -20,9 +20,14 @@ import (
 	"errors"
 	"fmt"
 	"path/filepath"
+	"strconv"
+	"strings"
+
+	"github.com/NVIDIA/go-nvlib/pkg/nvlib/device"
 
 	"github.com/NVIDIA/nvidia-container-toolkit/internal/info/proc/devices"
 	"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
+	"github.com/NVIDIA/nvidia-container-toolkit/internal/nvcaps"
 )
 
 var errInvalidDeviceNode = errors.New("invalid device node")
@@ -37,6 +42,8 @@ type Interface struct {
 	// devRoot is the root directory where device nodes are expected to exist.
 	devRoot string
 
+	migCaps nvcaps.MigCaps
+
 	mknoder
 }
 
@@ -61,6 +68,14 @@ func New(opts ...Option) (*Interface, error) {
 		i.Devices = devices
 	}
 
+	if i.migCaps == nil {
+		migCaps, err := nvcaps.NewMigCaps()
+		if err != nil {
+			return nil, fmt.Errorf("failed to load MIG caps: %w", err)
+		}
+		i.migCaps = migCaps
+	}
+
 	if i.dryRun {
 		i.mknoder = &mknodLogger{i.logger}
 	} else {
@@ -69,6 +84,37 @@ func New(opts ...Option) (*Interface, error) {
 	return i, nil
 }
 
+// CreateDeviceNodes creates the device nodes for the device with the specified identifier.
+// UUIDs and "all" create nodes for every GPU; a MIG index also creates its parent GPU node.
+func (m *Interface) CreateDeviceNodes(id device.Identifier) error {
+	switch {
+	case id.IsGpuIndex():
+		index, err := strconv.Atoi(string(id))
+		if err != nil {
+			return fmt.Errorf("invalid GPU index %v: %w", id, err)
+		}
+		return m.createGPUDeviceNode(index)
+	case id.IsMigIndex():
+		indices := strings.Split(string(id), ":")
+		if len(indices) != 2 {
+			return fmt.Errorf("invalid MIG index %v", id)
+		}
+		gpuIndex, err := strconv.Atoi(indices[0])
+		if err != nil {
+			return fmt.Errorf("invalid parent index %v: %w", indices[0], err)
+		}
+		if err := m.createGPUDeviceNode(gpuIndex); err != nil {
+			return fmt.Errorf("failed to create parent device node: %w", err)
+		}
+		// Only the parent GPU index is passed on; the MIG part of the identifier is not used here.
+		return m.createMigDeviceNodes(gpuIndex)
+	case id.IsGpuUUID(), id.IsMigUUID(), id == "all":
+		return m.createAllGPUDeviceNodes()
+	default:
+		return fmt.Errorf("invalid device identifier: %v", id)
+	}
+}
+
 // createDeviceNode creates the specified device node with the require major and minor numbers.
 // If a devRoot is configured, this is prepended to the path.
 func (m *Interface) createDeviceNode(path string, major int, minor int) error {
diff --git a/internal/system/nvdevices/gpu-device-nodes.go b/internal/system/nvdevices/gpu-device-nodes.go
new file mode 100644
index 00000000..be75f7a9
--- /dev/null
+++ b/internal/system/nvdevices/gpu-device-nodes.go
@@ -0,0 +1,76 @@
+/**
+# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+**/
+
+package nvdevices
+
+import (
+	"errors"
+	"fmt"
+	"path/filepath"
+
+	"github.com/NVIDIA/go-nvlib/pkg/nvpci"
+
+	"github.com/NVIDIA/nvidia-container-toolkit/internal/info/proc/devices"
+)
+
+// createGPUDeviceNode creates the /dev/nvidia<gpuIndex> node, using gpuIndex as the device minor.
+func (m *Interface) createGPUDeviceNode(gpuIndex int) error {
+	gpuMajor, ok := m.Get(devices.NVIDIAGPU)
+	if !ok {
+		return fmt.Errorf("failed to determine device major; nvidia kernel module may not be loaded")
+	}
+	nodePath := fmt.Sprintf("/dev/nvidia%d", gpuIndex)
+	if err := m.createDeviceNode(nodePath, int(gpuMajor), gpuIndex); err != nil {
+		return fmt.Errorf("failed to create device node %v: %w", nodePath, err)
+	}
+	return nil
+}
+
+// createMigDeviceNodes creates the nvidia-caps nodes that migCaps lists for gpuIndex, reporting only failures.
+func (m *Interface) createMigDeviceNodes(gpuIndex int) error {
+	var errs error
+	if capsMajor, exists := m.Get("nvidia-caps"); exists {
+		for _, minor := range m.migCaps.FilterForGPU(gpuIndex) {
+			path := minor.DevicePath()
+			if err := m.createDeviceNode(path, int(capsMajor), int(minor)); err != nil {
+				errs = errors.Join(errs, fmt.Errorf("failed to create %v: %w", path, err))
+			}
+		}
+	}
+	return errs
+}
+
+// createAllGPUDeviceNodes creates the GPU and MIG device nodes for indices 0..n-1, where n is the
+// number of GPUs reported by nvpci. Failures are accumulated so that every index is attempted.
+func (m *Interface) createAllGPUDeviceNodes() error {
+	pciDevicesRoot := filepath.Join(m.devRoot, nvpci.PCIDevicesRoot)
+	pci := nvpci.New(
+		nvpci.WithPCIDevicesRoot(pciDevicesRoot),
+		nvpci.WithLogger(m.logger),
+	)
+	gpus, err := pci.GetGPUs()
+	if err != nil {
+		return fmt.Errorf("failed to get GPU information from PCI: %w", err)
+	}
+
+	// With no GPUs the loop body never runs and the nil accumulator is returned.
+	var errs error
+	for gpuIndex := range gpus {
+		errs = errors.Join(errs, m.createGPUDeviceNode(gpuIndex))
+		errs = errors.Join(errs, m.createMigDeviceNodes(gpuIndex))
+	}
+	return errs
+}