diff --git a/cmd/nvidia-ctk/hook/create-dev-char-symlinks/all.go b/cmd/nvidia-ctk/hook/create-dev-char-symlinks/all.go
new file mode 100644
index 00000000..92d3d631
--- /dev/null
+++ b/cmd/nvidia-ctk/hook/create-dev-char-symlinks/all.go
@@ -0,0 +1,175 @@
+/**
+# Copyright (c) NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+**/
+
+package devchar
+
+import (
+	"fmt"
+	"path/filepath"
+
+	"github.com/NVIDIA/nvidia-container-toolkit/internal/info/proc/devices"
+	"github.com/NVIDIA/nvidia-container-toolkit/internal/nvcaps"
+	"github.com/sirupsen/logrus"
+	"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvpci"
+)
+
+type allPossible struct {
+	logger       *logrus.Logger  // receives the informational message logged by DeviceNodes
+	driverRoot   string          // prefix joined onto the PCI devices root and every device node path
+	deviceMajors devices.Devices // major numbers by device name, read via devices.GetNVIDIADevices
+	migCaps      nvcaps.MigCaps  // MIG cap -> minor lookup from nvcaps.NewMigCaps; non-nil when built by newAllPossible
+}
+
+// newAllPossible returns a new allPossible device node lister, or an error if the device majors or MIG caps cannot be read.
+// This lister lists all possible device nodes for NVIDIA GPUs, control devices, and capability devices.
+func newAllPossible(logger *logrus.Logger, driverRoot string) (nodeLister, error) {
+	deviceMajors, err := devices.GetNVIDIADevices()
+	if err != nil {
+		return nil, fmt.Errorf("failed reading device majors: %w", err)
+	}
+	migCaps, err := nvcaps.NewMigCaps()
+	if err != nil {
+		return nil, fmt.Errorf("failed to read MIG caps: %w", err)
+	}
+	if migCaps == nil { // normalize a nil result to an empty map
+		migCaps = make(nvcaps.MigCaps)
+	}
+
+	l := allPossible{
+		logger:       logger,
+		driverRoot:   driverRoot,
+		deviceMajors: deviceMajors,
+		migCaps:      migCaps,
+	}
+
+	return l, nil
+}
+
+// DeviceNodes returns the control, per-GPU, and MIG capability device nodes for the GPUs found under driverRoot; it returns nil when none are found.
+func (m allPossible) DeviceNodes() ([]deviceNode, error) {
+	gpus, err := nvpci.NewFrom(
+		filepath.Join(m.driverRoot, nvpci.PCIDevicesRoot),
+	).GetGPUs()
+	if err != nil {
+		return nil, fmt.Errorf("failed to get GPU information: %w", err)
+	}
+
+	count := len(gpus)
+	if count == 0 {
+		m.logger.Infof("No NVIDIA devices found in %s", m.driverRoot)
+		return nil, nil
+	}
+
+	deviceNodes, err := m.getControlDeviceNodes()
+	if err != nil {
+		return nil, fmt.Errorf("failed to get control device nodes: %w", err)
+	}
+
+	for gpu := 0; gpu < count; gpu++ { // the GPU index is also used as the /dev/nvidia<N> minor number
+		deviceNodes = append(deviceNodes, m.getGPUDeviceNodes(gpu)...)
+		deviceNodes = append(deviceNodes, m.getNVCapDeviceNodes(gpu)...)
+	}
+
+	return deviceNodes, nil
+}
+
+// getControlDeviceNodes lists the fixed control devices plus the MIG "config" and "monitor" caps present in migCaps; the error is always nil.
+func (m allPossible) getControlDeviceNodes() ([]deviceNode, error) {
+	var deviceNodes []deviceNode
+
+	// Define the control devices for standard GPUs.
+	controlDevices := []deviceNode{
+		m.newDeviceNode(devices.NVIDIAGPU, "/dev/nvidia-modeset", devices.NVIDIAModesetMinor),
+		m.newDeviceNode(devices.NVIDIAGPU, "/dev/nvidiactl", devices.NVIDIACTLMinor),
+		m.newDeviceNode(devices.NVIDIAUVM, "/dev/nvidia-uvm", devices.NVIDIAUVMMinor),
+		m.newDeviceNode(devices.NVIDIAUVM, "/dev/nvidia-uvm-tools", devices.NVIDIAUVMToolsMinor),
+	}
+
+	deviceNodes = append(deviceNodes, controlDevices...)
+
+	for _, migControlDevice := range []nvcaps.MigCap{"config", "monitor"} {
+		migControlMinor, exist := m.migCaps[migControlDevice]
+		if !exist {
+			continue // this MIG control cap is not listed; skip it
+		}
+
+		d := m.newDeviceNode(
+			devices.NVIDIACaps,
+			migControlMinor.DevicePath(),
+			int(migControlMinor),
+		)
+
+		deviceNodes = append(deviceNodes, d)
+	}
+
+	return deviceNodes, nil
+}
+
+// getGPUDeviceNodes returns the single /dev/nvidia<gpu> node for a GPU, using the GPU index as the minor number.
+func (m allPossible) getGPUDeviceNodes(gpu int) []deviceNode {
+	d := m.newDeviceNode(
+		devices.NVIDIAGPU,
+		fmt.Sprintf("/dev/nvidia%d", gpu),
+		gpu,
+	)
+
+	return []deviceNode{d}
+}
+
+// getNVCapDeviceNodes returns the cap device nodes of a GPU's GPU instances and their compute instances, scanning indices from 0 until one is missing from migCaps.
+func (m allPossible) getNVCapDeviceNodes(gpu int) []deviceNode {
+	var selectedCapMinors []nvcaps.MigMinor
+	for gi := 0; ; gi++ {
+		giCap := nvcaps.NewGPUInstanceCap(gpu, gi)
+		giMinor, exist := m.migCaps[giCap]
+		if !exist {
+			break // first missing GPU instance index ends the scan for this GPU
+		}
+		selectedCapMinors = append(selectedCapMinors, giMinor)
+		for ci := 0; ; ci++ {
+			ciCap := nvcaps.NewComputeInstanceCap(gpu, gi, ci)
+			ciMinor, exist := m.migCaps[ciCap]
+			if !exist {
+				break // first missing compute instance index ends the scan for this GPU instance
+			}
+			selectedCapMinors = append(selectedCapMinors, ciMinor)
+		}
+	}
+
+	var deviceNodes []deviceNode
+	for _, capMinor := range selectedCapMinors {
+		d := m.newDeviceNode(
+			devices.NVIDIACaps,
+			capMinor.DevicePath(),
+			int(capMinor),
+		)
+		deviceNodes = append(deviceNodes, d)
+	}
+
+	return deviceNodes
+}
+
+// newDeviceNode creates a new device node with the specified path and major/minor numbers.
+// The path is adjusted for the specified driver root. +func (m allPossible) newDeviceNode(deviceName devices.Name, path string, minor int) deviceNode { + major, _ := m.deviceMajors.Get(deviceName) + + return deviceNode{ + path: filepath.Join(m.driverRoot, path), + major: uint32(major), + minor: uint32(minor), + } +} diff --git a/cmd/nvidia-ctk/hook/create-dev-char-symlinks/create-dev-char-symlinks.go b/cmd/nvidia-ctk/hook/create-dev-char-symlinks/create-dev-char-symlinks.go index a1473e98..bec0314e 100644 --- a/cmd/nvidia-ctk/hook/create-dev-char-symlinks/create-dev-char-symlinks.go +++ b/cmd/nvidia-ctk/hook/create-dev-char-symlinks/create-dev-char-symlinks.go @@ -42,6 +42,7 @@ type config struct { driverRoot string dryRun bool watch bool + createAll bool } // NewCommand constructs a hook sub-command with the specified logger @@ -60,6 +61,9 @@ func (m command) build() *cli.Command { c := cli.Command{ Name: "create-dev-char-symlinks", Usage: "A hook to create symlinks to possible /dev/nv* devices in /dev/char", + Before: func(c *cli.Context) error { + return m.validateFlags(c, &cfg) + }, Action: func(c *cli.Context) error { return m.run(c, &cfg) }, @@ -87,6 +91,12 @@ func (m command) build() *cli.Command { Destination: &cfg.watch, EnvVars: []string{"WATCH"}, }, + &cli.BoolFlag{ + Name: "create-all", + Usage: "Create all possible /dev/char symlinks instead of limiting these to existing device nodes.", + Destination: &cfg.createAll, + EnvVars: []string{"CREATE_ALL"}, + }, &cli.BoolFlag{ Name: "dry-run", Usage: "If set, the command will not create any symlinks.", @@ -99,8 +109,15 @@ func (m command) build() *cli.Command { return &c } -func (m command) run(c *cli.Context, cfg *config) error { +func (m command) validateFlags(r *cli.Context, cfg *config) error { + if cfg.createAll && cfg.watch { + return fmt.Errorf("create-all and watch are mutually exclusive") + } + return nil +} + +func (m command) run(c *cli.Context, cfg *config) error { var watcher *fsnotify.Watcher 
var sigs chan os.Signal @@ -114,14 +131,19 @@ func (m command) run(c *cli.Context, cfg *config) error { sigs = newOSWatcher(syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT) } - l := NewSymlinkCreator( + l, err := NewSymlinkCreator( WithLogger(m.logger), WithDevCharPath(cfg.devCharPath), WithDriverRoot(cfg.driverRoot), WithDryRun(cfg.dryRun), + WithCreateAll(cfg.createAll), ) + if err != nil { + return fmt.Errorf("failed to create symlink creator: %v", err) + } + create: - err := l.CreateLinks() + err = l.CreateLinks() if err != nil { return fmt.Errorf("failed to create links: %v", err) } @@ -169,6 +191,7 @@ type linkCreator struct { driverRoot string devCharPath string dryRun bool + createAll bool } // Creator is an interface for creating symlinks to /dev/nv* devices in /dev/char. @@ -180,7 +203,7 @@ type Creator interface { type Option func(*linkCreator) // NewSymlinkCreator creates a new linkCreator. -func NewSymlinkCreator(opts ...Option) Creator { +func NewSymlinkCreator(opts ...Option) (Creator, error) { c := linkCreator{} for _, opt := range opts { opt(&c) @@ -194,10 +217,17 @@ func NewSymlinkCreator(opts ...Option) Creator { if c.devCharPath == "" { c.devCharPath = defaultDevCharPath } - if c.lister == nil { + + if c.createAll { + lister, err := newAllPossible(c.logger, c.driverRoot) + if err != nil { + return nil, fmt.Errorf("failed to create all possible device lister: %v", err) + } + c.lister = lister + } else { c.lister = existing{c.logger, c.driverRoot} } - return c + return c, nil } // WithDriverRoot sets the driver root path. @@ -228,7 +258,14 @@ func WithLogger(logger *logrus.Logger) Option { } } -// CreateLinks creates symlinks for all device nodes returned by the configured lister. +// WithCreateAll sets the createAll flag for the linkCreator. 
+func WithCreateAll(createAll bool) Option { + return func(lc *linkCreator) { + lc.createAll = createAll + } +} + +// CreateLinks creates symlinks for all NVIDIA device nodes found in the driver root. func (m linkCreator) CreateLinks() error { deviceNodes, err := m.lister.DeviceNodes() if err != nil {