Implement runtime package for creating runtime CLI

Signed-off-by: Evan Lezar <elezar@nvidia.com>
This commit is contained in:
Evan Lezar
2023-02-07 21:09:53 +01:00
parent f71c419cfb
commit 406a5ec76f
7 changed files with 189 additions and 94 deletions

33
internal/runtime/api.go Normal file
View File

@@ -0,0 +1,33 @@
package runtime
type rt struct {
logger *Logger
modeOverride string
}
// Interface is the interface for the runtime library.
type Interface interface {
Run([]string) error
}
// Option is a function that configures the runtime.
type Option func(*rt)
// New creates a runtime with the specified options.
func New(opts ...Option) Interface {
r := rt{}
for _, opt := range opts {
opt(&r)
}
if r.logger == nil {
r.logger = NewLogger()
}
return &r
}
// WithModeOverride allows for overriding the mode specified in the config.
func WithModeOverride(mode string) Option {
return func(r *rt) {
r.modeOverride = mode
}
}

240
internal/runtime/logger.go Normal file
View File

@@ -0,0 +1,240 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package runtime
import (
"fmt"
"io"
"os"
"path/filepath"
"runtime"
"strconv"
"strings"
"github.com/sirupsen/logrus"
)
// Logger adds a way to manage output to a log file to a logrus.Logger
type Logger struct {
*logrus.Logger
previousLogger *logrus.Logger
logFiles []*os.File
}
// NewLogger creates an empty logger
func NewLogger() *Logger {
return &Logger{
Logger: logrus.New(),
}
}
// Update constructs a Logger with a preddefined formatter
func (l *Logger) Update(filename string, logLevel string, argv []string) error {
configFromArgs := parseArgs(argv)
level, logLevelError := configFromArgs.getLevel(logLevel)
var logFiles []*os.File
var argLogFileError error
// We don't create log files if the version argument is supplied
if !configFromArgs.version {
configLogFile, err := createLogFile(filename)
if err != nil {
return fmt.Errorf("error opening debug log file: %v", err)
}
if configLogFile != nil {
logFiles = append(logFiles, configLogFile)
}
argLogFile, err := createLogFile(configFromArgs.file)
if argLogFile != nil {
logFiles = append(logFiles, argLogFile)
}
argLogFileError = err
}
previous := l.Logger
l = &Logger{
Logger: logrus.New(),
previousLogger: previous,
logFiles: logFiles,
}
l.SetLevel(level)
if level == logrus.DebugLevel {
logrus.SetReportCaller(true)
// Shorten function and file names reported by the logger, by
// trimming common "github.com/opencontainers/runc" prefix.
// This is only done for text formatter.
_, file, _, _ := runtime.Caller(0)
prefix := filepath.Dir(file) + "/"
logrus.SetFormatter(&logrus.TextFormatter{
CallerPrettyfier: func(f *runtime.Frame) (string, string) {
function := strings.TrimPrefix(f.Function, prefix) + "()"
fileLine := strings.TrimPrefix(f.File, prefix) + ":" + strconv.Itoa(f.Line)
return function, fileLine
},
})
}
if configFromArgs.format == "json" {
l.SetFormatter(new(logrus.JSONFormatter))
}
if len(logFiles) == 0 {
l.SetOutput(io.Discard)
} else if len(logFiles) == 1 {
l.SetOutput(logFiles[0])
} else if len(logFiles) > 1 {
var writers []io.Writer
for _, f := range logFiles {
writers = append(writers, f)
}
l.SetOutput(io.MultiWriter(writers...))
}
if logLevelError != nil {
l.Warn(logLevelError)
}
if argLogFileError != nil {
l.Warnf("Failed to open log file: %v", argLogFileError)
}
return nil
}
// Reset closes the log file (if any) and resets the logger output to what it
// was before UpdateLogger was called.
func (l *Logger) Reset() error {
defer func() {
previous := l.previousLogger
if previous == nil {
previous = logrus.New()
}
l = &Logger{Logger: previous}
}()
var errs []error
for _, f := range l.logFiles {
err := f.Close()
if err != nil {
errs = append(errs, err)
}
}
var err error
for _, e := range errs {
if err == nil {
err = e
continue
}
return fmt.Errorf("%v; %w", e, err)
}
return err
}
func createLogFile(filename string) (*os.File, error) {
if filename != "" && filename != os.DevNull {
return os.OpenFile(filename, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
}
return nil, nil
}
type loggerConfig struct {
file string
format string
debug bool
version bool
}
func (c loggerConfig) getLevel(logLevel string) (logrus.Level, error) {
if c.debug {
return logrus.DebugLevel, nil
}
if logLevel, err := logrus.ParseLevel(logLevel); err == nil {
return logLevel, nil
}
return logrus.InfoLevel, fmt.Errorf("invalid log-level '%v'", logLevel)
}
// Informed by Taken from https://github.com/opencontainers/runc/blob/7fd8b57001f5bfa102e89cb434d96bf71f7c1d35/main.go#L182
func parseArgs(args []string) loggerConfig {
c := loggerConfig{}
expected := map[string]*string{
"log-format": &c.format,
"log": &c.file,
}
found := make(map[string]bool)
for i := 0; i < len(args); i++ {
if len(found) == 4 {
break
}
param := args[i]
parts := strings.SplitN(param, "=", 2)
trimmed := strings.TrimLeft(parts[0], "-")
// If this is not a flag we continue
if parts[0] == trimmed {
continue
}
// Check the version flag
if trimmed == "version" {
c.version = true
found["version"] = true
// For the version flag we don't process any other flags
continue
}
// Check the debug flag
if trimmed == "debug" {
c.debug = true
found["debug"] = true
continue
}
destination, exists := expected[trimmed]
if !exists {
continue
}
var value string
if len(parts) == 2 {
value = parts[2]
} else if i+1 < len(args) {
value = args[i+1]
i++
} else {
continue
}
*destination = value
found[trimmed] = true
}
return c
}

101
internal/runtime/runtime.go Normal file
View File

@@ -0,0 +1,101 @@
/**
# Copyright (c) NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package runtime
import (
"fmt"
"strings"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config"
"github.com/NVIDIA/nvidia-container-toolkit/internal/info"
"github.com/opencontainers/runtime-spec/specs-go"
)
// Run is an entry point that allows for idiomatic handling of errors
// when calling from the main function.
func (r rt) Run(argv []string) (rerr error) {
defer func() {
if rerr != nil {
r.logger.Errorf("%v", rerr)
}
}()
printVersion := hasVersionFlag(argv)
if printVersion {
fmt.Printf("%v version %v\n", "NVIDIA Container Runtime", info.GetVersionString(fmt.Sprintf("spec: %v", specs.Version)))
}
cfg, err := config.GetConfig()
if err != nil {
return fmt.Errorf("error loading config: %v", err)
}
if r.modeOverride != "" {
cfg.NVIDIAContainerRuntimeConfig.Mode = r.modeOverride
}
err = r.logger.Update(
cfg.NVIDIAContainerRuntimeConfig.DebugFilePath,
cfg.NVIDIAContainerRuntimeConfig.LogLevel,
argv,
)
if err != nil {
return fmt.Errorf("failed to set up logger: %v", err)
}
defer func() {
if rerr != nil {
r.logger.Errorf("%v", rerr)
}
r.logger.Reset()
}()
r.logger.Infof("Using config %+v", cfg)
r.logger.Debugf("Command line arguments: %v", argv)
runtime, err := newNVIDIAContainerRuntime(r.logger.Logger, cfg, argv)
if err != nil {
return fmt.Errorf("failed to create NVIDIA Container Runtime: %v", err)
}
if printVersion {
fmt.Print("\n")
}
return runtime.Exec(argv)
}
func (r rt) Errorf(format string, args ...interface{}) {
r.logger.Errorf(format, args...)
}
// TODO: This should be refactored / combined with parseArgs in logger.
func hasVersionFlag(args []string) bool {
for i := 0; i < len(args); i++ {
param := args[i]
parts := strings.SplitN(param, "=", 2)
trimmed := strings.TrimLeft(parts[0], "-")
// If this is not a flag we continue
if parts[0] == trimmed {
continue
}
// Check the version flag
if trimmed == "version" {
return true
}
}
return false
}

View File

@@ -0,0 +1,110 @@
/*
# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package runtime
import (
"fmt"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config"
"github.com/NVIDIA/nvidia-container-toolkit/internal/info"
"github.com/NVIDIA/nvidia-container-toolkit/internal/modifier"
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
"github.com/sirupsen/logrus"
)
// newNVIDIAContainerRuntime is a factory method that constructs a runtime based on the selected configuration and specified logger
func newNVIDIAContainerRuntime(logger *logrus.Logger, cfg *config.Config, argv []string) (oci.Runtime, error) {
lowLevelRuntime, err := oci.NewLowLevelRuntime(logger, cfg.NVIDIAContainerRuntimeConfig.Runtimes)
if err != nil {
return nil, fmt.Errorf("error constructing low-level runtime: %v", err)
}
if !oci.HasCreateSubcommand(argv) {
logger.Debugf("Skipping modifier for non-create subcommand")
return lowLevelRuntime, nil
}
ociSpec, err := oci.NewSpec(logger, argv)
if err != nil {
return nil, fmt.Errorf("error constructing OCI specification: %v", err)
}
specModifier, err := newSpecModifier(logger, cfg, ociSpec, argv)
if err != nil {
return nil, fmt.Errorf("failed to construct OCI spec modifier: %v", err)
}
// Create the wrapping runtime with the specified modifier
r := oci.NewModifyingRuntimeWrapper(
logger,
lowLevelRuntime,
ociSpec,
specModifier,
)
return r, nil
}
// newSpecModifier is a factory method that creates constructs an OCI spec modifer based on the provided config.
func newSpecModifier(logger *logrus.Logger, cfg *config.Config, ociSpec oci.Spec, argv []string) (oci.SpecModifier, error) {
modeModifier, err := newModeModifier(logger, cfg, ociSpec, argv)
if err != nil {
return nil, err
}
graphicsModifier, err := modifier.NewGraphicsModifier(logger, cfg, ociSpec)
if err != nil {
return nil, err
}
gdsModifier, err := modifier.NewGDSModifier(logger, cfg, ociSpec)
if err != nil {
return nil, err
}
mofedModifier, err := modifier.NewMOFEDModifier(logger, cfg, ociSpec)
if err != nil {
return nil, err
}
tegraModifier, err := modifier.NewTegraPlatformFiles(logger)
if err != nil {
return nil, err
}
modifiers := modifier.Merge(
modeModifier,
graphicsModifier,
gdsModifier,
mofedModifier,
tegraModifier,
)
return modifiers, nil
}
func newModeModifier(logger *logrus.Logger, cfg *config.Config, ociSpec oci.Spec, argv []string) (oci.SpecModifier, error) {
switch info.ResolveAutoMode(logger, cfg.NVIDIAContainerRuntimeConfig.Mode) {
case "legacy":
return modifier.NewStableRuntimeModifier(logger), nil
case "csv":
return modifier.NewCSVModifier(logger, cfg, ociSpec)
case "cdi":
return modifier.NewCDIModifier(logger, cfg, ociSpec)
}
return nil, fmt.Errorf("invalid runtime mode: %v", cfg.NVIDIAContainerRuntimeConfig.Mode)
}

View File

@@ -0,0 +1,161 @@
/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package runtime
import (
"encoding/json"
"os"
"os/exec"
"path/filepath"
"testing"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config"
"github.com/NVIDIA/nvidia-container-toolkit/internal/test"
"github.com/opencontainers/runtime-spec/specs-go"
"github.com/sirupsen/logrus"
testlog "github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/require"
)
const (
runcExecutableName = "runc"
)
func TestMain(m *testing.M) {
// TEST SETUP
// Determine the module root and the test binary path
var err error
moduleRoot, err := test.GetModuleRoot()
if err != nil {
logrus.Fatalf("error in test setup: could not get module root: %v", err)
}
testBinPath := filepath.Join(moduleRoot, "test", "bin")
// Set the environment variables for the test
os.Setenv("PATH", test.PrependToPath(testBinPath, moduleRoot))
// Confirm that the environment is configured correctly
runcPath, err := exec.LookPath(runcExecutableName)
if err != nil || filepath.Join(testBinPath, runcExecutableName) != runcPath {
logrus.Fatalf("error in test setup: mock runc path set incorrectly in TestMain(): %v", err)
}
// RUN TESTS
exitCode := m.Run()
os.Exit(exitCode)
}
func TestFactoryMethod(t *testing.T) {
logger, _ := testlog.NewNullLogger()
testCases := []struct {
description string
cfg *config.Config
spec *specs.Spec
expectedError bool
}{
{
description: "empty config raises error",
cfg: &config.Config{
NVIDIAContainerRuntimeConfig: config.RuntimeConfig{},
},
expectedError: true,
},
{
description: "config with runtime raises no error",
cfg: &config.Config{
NVIDIAContainerRuntimeConfig: config.RuntimeConfig{
Runtimes: []string{"runc"},
Mode: "legacy",
},
},
},
{
description: "csv mode is supported",
cfg: &config.Config{
NVIDIAContainerRuntimeConfig: config.RuntimeConfig{
Runtimes: []string{"runc"},
Mode: "csv",
},
},
spec: &specs.Spec{
Process: &specs.Process{
Env: []string{
"NVIDIA_VISIBLE_DEVICES=all",
},
},
},
},
{
description: "non-legacy discover mode raises error",
cfg: &config.Config{
NVIDIAContainerRuntimeConfig: config.RuntimeConfig{
Runtimes: []string{"runc"},
Mode: "non-legacy",
},
},
expectedError: true,
},
{
description: "legacy discover mode returns modifier",
cfg: &config.Config{
NVIDIAContainerRuntimeConfig: config.RuntimeConfig{
Runtimes: []string{"runc"},
Mode: "legacy",
},
},
},
{
description: "csv discover mode returns modifier",
cfg: &config.Config{
NVIDIAContainerRuntimeConfig: config.RuntimeConfig{
Runtimes: []string{"runc"},
Mode: "csv",
},
},
},
{
description: "empty mode raises error",
cfg: &config.Config{
NVIDIAContainerRuntimeConfig: config.RuntimeConfig{
Runtimes: []string{"runc"},
},
},
expectedError: true,
},
}
for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) {
bundleDir := t.TempDir()
specFile, err := os.Create(filepath.Join(bundleDir, "config.json"))
require.NoError(t, err)
require.NoError(t, json.NewEncoder(specFile).Encode(tc.spec))
argv := []string{"--bundle", bundleDir, "create"}
_, err = newNVIDIAContainerRuntime(logger, tc.cfg, argv)
if tc.expectedError {
require.Error(t, err)
} else {
require.NoError(t, err)
}
})
}
}