nvidia-container-toolkit/internal/runtime/logger.go
Evan Lezar 2c8431c1f8 Fix bug in argument parsing for logger creation
Signed-off-by: Evan Lezar <elezar@nvidia.com>
2024-07-01 12:03:53 +02:00

255 lines
5.5 KiB
Go

/*
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/
package runtime
import (
"errors"
"fmt"
"io"
"os"
"path/filepath"
"runtime"
"strconv"
"strings"
"github.com/sirupsen/logrus"
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
)
// Logger adds a way to manage output to a log file to a logrus.Logger.
//
// All logging calls go to the embedded logger.Interface. Update replaces that
// logger with a newly configured logrus logger (writing to any log files it
// opened) and remembers the logger it replaced; Reset closes those files and
// restores the remembered logger.
type Logger struct {
	logger.Interface
	// previousLogger is the logger that was active before the most recent
	// Update; Reset reinstates it (or a new logrus logger when it is nil).
	previousLogger logger.Interface
	// logFiles are the files opened by the most recent Update; Reset closes them.
	logFiles []*os.File
}
// NewLogger returns a Logger wrapping a freshly constructed logrus logger with
// default settings. It has no previous logger and no open log files; call
// Update to configure output.
func NewLogger() *Logger {
	l := new(Logger)
	l.Interface = logrus.New()
	return l
}
// Update replaces the wrapped logger with a new logrus logger configured from
// the supplied settings, remembering the current logger so that Reset can
// restore it.
//
// filename names a log file to write to, and logLevel is the requested level
// name. argv is the runtime command line, which is scanned for --log,
// --log-format, --debug and --version (see parseArgs). Entries are written to
// every log file that could be opened (filename and the --log value). When no
// file is open, output is discarded. No log files are opened at all when
// --version is present.
//
// Update does not return an error. An invalid log level (which falls back to
// info) and failures to open log files are reported as warnings through the
// new logger once it has been installed.
//
// NOTE: calling Update again without an intervening Reset replaces the
// tracked log files without closing the previously opened ones.
func (l *Logger) Update(filename string, logLevel string, argv []string) {
	configFromArgs := parseArgs(argv)

	level, logLevelError := configFromArgs.getLevel(logLevel)
	// Deferred so that the warning is emitted by the new logger installed at
	// the end of this function.
	defer func() {
		if logLevelError != nil {
			l.Warning(logLevelError)
		}
	}()

	var logFiles []*os.File
	var argLogFileError error
	// We don't create log files if the version argument is supplied.
	if !configFromArgs.version {
		for _, name := range []string{filename, configFromArgs.file} {
			logFile, err := createLogFile(name)
			if err != nil {
				argLogFileError = errors.Join(argLogFileError, err)
			}
			if logFile != nil {
				logFiles = append(logFiles, logFile)
			}
		}
	}
	defer func() {
		if argLogFileError != nil {
			l.Warningf("Failed to open log file: %v", argLogFileError)
		}
	}()

	newLogger := logrus.New()
	newLogger.SetLevel(level)
	if level == logrus.DebugLevel {
		// Configure the logger being built here. The package-level logrus
		// functions only affect the standard logger, which is not the logger
		// installed below.
		newLogger.SetReportCaller(true)
		// Shorten function and file names reported by the logger by trimming
		// the directory that contains this source file.
		// This only applies to the text formatter.
		_, file, _, _ := runtime.Caller(0)
		prefix := filepath.Dir(file) + "/"
		newLogger.SetFormatter(&logrus.TextFormatter{
			CallerPrettyfier: func(f *runtime.Frame) (string, string) {
				function := strings.TrimPrefix(f.Function, prefix) + "()"
				fileLine := strings.TrimPrefix(f.File, prefix) + ":" + strconv.Itoa(f.Line)
				return function, fileLine
			},
		})
	}
	if configFromArgs.format == "json" {
		newLogger.SetFormatter(new(logrus.JSONFormatter))
	}

	switch len(logFiles) {
	case 0:
		newLogger.SetOutput(io.Discard)
	case 1:
		newLogger.SetOutput(logFiles[0])
	default:
		writers := make([]io.Writer, 0, len(logFiles))
		for _, f := range logFiles {
			writers = append(writers, f)
		}
		newLogger.SetOutput(io.MultiWriter(writers...))
	}

	*l = Logger{
		Interface:      newLogger,
		previousLogger: l.Interface,
		logFiles:       logFiles,
	}
}
// Reset closes the log files opened by Update (if any) and restores the logger
// that was active before Update was called. If there is no previous logger, a
// new logrus logger is installed instead.
//
// The logger state is restored even when closing a file fails. Every close
// error is returned, combined with errors.Join; the result is nil if all
// files closed cleanly.
func (l *Logger) Reset() error {
	defer func() {
		previous := l.previousLogger
		if previous == nil {
			previous = logrus.New()
		}
		l.Interface = previous
		l.previousLogger = nil
		l.logFiles = nil
	}()

	var errs []error
	for _, f := range l.logFiles {
		if err := f.Close(); err != nil {
			errs = append(errs, err)
		}
	}
	// errors.Join returns nil for an empty list and keeps every error, so
	// no close failure is dropped.
	return errors.Join(errs...)
}
// createLogFile opens filename for appending (write-only), creating the file
// with mode 0644 and any missing parent directories with mode 0755.
//
// An empty filename or os.DevNull means "no log file": (nil, nil) is returned
// and nothing is opened. Any error from creating the directories or opening
// the file is returned unchanged.
func createLogFile(filename string) (*os.File, error) {
	switch filename {
	case "", os.DevNull:
		return nil, nil
	}

	parent := filepath.Dir(filepath.Clean(filename))
	if parent != "." {
		if err := os.MkdirAll(parent, 0755); err != nil {
			return nil, err
		}
	}

	const flags = os.O_WRONLY | os.O_CREATE | os.O_APPEND
	return os.OpenFile(filename, flags, 0644)
}
// loggerConfig holds the logging-related settings that parseArgs extracts
// from the runtime command line.
type loggerConfig struct {
	// file is the value given for --log and format the value given for
	// --log-format; both are empty when the flag was not supplied.
	file, format string
	// debug and version are set from the --debug and --version flags.
	debug, version bool
}
// getLevel returns the log level to use.
//
// When --debug was passed, DebugLevel is returned regardless of logLevel.
// Otherwise logLevel is parsed with logrus.ParseLevel. If it cannot be parsed,
// InfoLevel is returned together with an error naming the rejected value.
func (c loggerConfig) getLevel(logLevel string) (logrus.Level, error) {
	if c.debug {
		return logrus.DebugLevel, nil
	}

	parsed, err := logrus.ParseLevel(logLevel)
	if err != nil {
		return logrus.InfoLevel, fmt.Errorf("invalid log-level '%v'", logLevel)
	}
	return parsed, nil
}
// parseArgs scans the runtime command line, best effort, for the flags that
// affect logging: --log, --log-format, --debug and --version. It does not
// validate the command line.
//
// Flags may be written with one or more leading dashes.
//   - --log and --log-format take a value, either inline (--log=/path) or as
//     the following argument (--log /path). A value flag at the very end with
//     no value is ignored.
//   - --debug and --version are booleans. They never consume the following
//     argument but may carry an inline value (--debug=false). An inline value
//     that strconv.ParseBool rejects is treated like the bare flag (true).
//
// Positional arguments and unknown flags are skipped. Scanning stops at a bare
// "--", since everything after it is positional by convention, and once every
// flag of interest has been seen.
//
// Informed by https://github.com/opencontainers/runc/blob/7fd8b57001f5bfa102e89cb434d96bf71f7c1d35/main.go#L182
func parseArgs(args []string) loggerConfig {
	c := loggerConfig{}

	stringFlags := map[string]*string{
		"log-format": &c.format,
		"log":        &c.file,
	}
	boolFlags := map[string]*bool{
		"debug":   &c.debug,
		"version": &c.version,
	}
	numFlags := len(stringFlags) + len(boolFlags)

	found := make(map[string]bool)
	for i := 0; i < len(args) && len(found) < numFlags; i++ {
		param := args[i]
		// A bare "--" terminates flag parsing.
		if param == "--" {
			break
		}

		name, value, hasValue := strings.Cut(param, "=")
		trimmed := strings.TrimLeft(name, "-")
		// Arguments without a leading dash are not flags.
		if trimmed == name {
			continue
		}

		if destination, ok := boolFlags[trimmed]; ok {
			// Boolean flags never consume the next argument. Honour an
			// explicit inline value such as --debug=false.
			enabled := true
			if hasValue {
				if b, err := strconv.ParseBool(value); err == nil {
					enabled = b
				}
			}
			*destination = enabled
			found[trimmed] = true
			continue
		}

		destination, ok := stringFlags[trimmed]
		if !ok {
			continue
		}
		if !hasValue {
			// The value is the next argument; without one the flag is ignored.
			if i+1 >= len(args) {
				continue
			}
			i++
			value = args[i]
		}
		*destination = value
		found[trimmed] = true
	}
	return c
}