mirror of
https://github.com/NVIDIA/nvidia-container-toolkit
synced 2024-11-24 04:54:00 +00:00
94 lines
2.2 KiB
Go
94 lines
2.2 KiB
Go
|
/**
|
||
|
# Copyright (c) NVIDIA CORPORATION. All rights reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
**/
|
||
|
|
||
|
package nvmodules
|
||
|
|
||
|
import (
|
||
|
"fmt"
|
||
|
"strings"
|
||
|
|
||
|
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
|
||
|
)
|
||
|
|
||
|
// Interface provides a set of utilities for interacting with NVIDIA modules on the system.
|
||
|
type Interface struct {
|
||
|
logger logger.Interface
|
||
|
|
||
|
dryRun bool
|
||
|
root string
|
||
|
|
||
|
cmder
|
||
|
}
|
||
|
|
||
|
// New constructs a new Interface struct with the specified options.
|
||
|
func New(opts ...Option) *Interface {
|
||
|
m := &Interface{}
|
||
|
for _, opt := range opts {
|
||
|
opt(m)
|
||
|
}
|
||
|
|
||
|
if m.logger == nil {
|
||
|
m.logger = logger.New()
|
||
|
}
|
||
|
if m.root == "" {
|
||
|
m.root = "/"
|
||
|
}
|
||
|
|
||
|
if m.dryRun {
|
||
|
m.cmder = &cmderLogger{m.logger}
|
||
|
} else {
|
||
|
m.cmder = &cmderExec{}
|
||
|
}
|
||
|
return m
|
||
|
}
|
||
|
|
||
|
// LoadAll loads all the NVIDIA kernel modules.
|
||
|
func (m *Interface) LoadAll() error {
|
||
|
modules := []string{"nvidia", "nvidia-uvm", "nvidia-modeset"}
|
||
|
|
||
|
for _, module := range modules {
|
||
|
if err := m.Load(module); err != nil {
|
||
|
return fmt.Errorf("failed to load module %s: %w", module, err)
|
||
|
}
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
var errInvalidModule = fmt.Errorf("invalid module")
|
||
|
|
||
|
// Load loads the specified NVIDIA kernel module.
|
||
|
// If the root is specified we first chroot into this root.
|
||
|
func (m *Interface) Load(module string) error {
|
||
|
if !strings.HasPrefix(module, "nvidia") {
|
||
|
return errInvalidModule
|
||
|
}
|
||
|
|
||
|
var args []string
|
||
|
if m.root != "/" {
|
||
|
args = append(args, "chroot", m.root)
|
||
|
}
|
||
|
args = append(args, "/sbin/modprobe", module)
|
||
|
|
||
|
m.logger.Debugf("Loading kernel module %s: %v", module, args)
|
||
|
err := m.Run(args[0], args[1:]...)
|
||
|
if err != nil {
|
||
|
m.logger.Debugf("Failed to load kernel module %s: %v", module, err)
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|