diff --git a/internal/runtime/logger.go b/internal/runtime/logger.go index 2b5c9730..879c6bc0 100644 --- a/internal/runtime/logger.go +++ b/internal/runtime/logger.go @@ -44,9 +44,15 @@ func NewLogger() *Logger { // Update constructs a Logger with a preddefined formatter func (l *Logger) Update(filename string, logLevel string, argv []string) error { + configFromArgs := parseArgs(argv) level, logLevelError := configFromArgs.getLevel(logLevel) + defer func() { + if logLevelError != nil { + l.Warn(logLevelError) + } + }() var logFiles []*os.File var argLogFileError error @@ -67,15 +73,15 @@ func (l *Logger) Update(filename string, logLevel string, argv []string) error { } argLogFileError = err } + defer func() { + if argLogFileError != nil { + l.Warnf("Failed to open log file: %v", argLogFileError) + } + }() - previous := l.Logger - l = &Logger{ - Logger: logrus.New(), - previousLogger: previous, - logFiles: logFiles, - } + newLogger := logrus.New() - l.SetLevel(level) + newLogger.SetLevel(level) if level == logrus.DebugLevel { logrus.SetReportCaller(true) // Shorten function and file names reported by the logger, by @@ -93,27 +99,25 @@ func (l *Logger) Update(filename string, logLevel string, argv []string) error { } if configFromArgs.format == "json" { - l.SetFormatter(new(logrus.JSONFormatter)) + newLogger.SetFormatter(new(logrus.JSONFormatter)) } if len(logFiles) == 0 { - l.SetOutput(io.Discard) + newLogger.SetOutput(io.Discard) } else if len(logFiles) == 1 { - l.SetOutput(logFiles[0]) + newLogger.SetOutput(logFiles[0]) } else if len(logFiles) > 1 { var writers []io.Writer for _, f := range logFiles { writers = append(writers, f) } - l.SetOutput(io.MultiWriter(writers...)) + newLogger.SetOutput(io.MultiWriter(writers...)) } - if logLevelError != nil { - l.Warn(logLevelError) - } - - if argLogFileError != nil { - l.Warnf("Failed to open log file: %v", argLogFileError) + *l = Logger{ + Logger: newLogger, + previousLogger: l.Logger, + logFiles: logFiles, } return nil @@ -127,7 +131,9 @@ func (l *Logger) Reset() error { if previous == nil { previous = logrus.New() } - l = &Logger{Logger: previous} + l.Logger = previous + l.previousLogger = nil + l.logFiles = nil }() var errs []error diff --git a/internal/runtime/logger_test.go b/internal/runtime/logger_test.go new file mode 100644 index 00000000..6999b9d4 --- /dev/null +++ b/internal/runtime/logger_test.go @@ -0,0 +1,34 @@ +/** +# Copyright (c) NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package runtime + +import ( + "testing" + + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" +) + +func TestLogger(t *testing.T) { + l := NewLogger() + + l.Update("", "debug", nil) + + require.Equal(t, logrus.DebugLevel, l.Logger.Level) + require.Equal(t, logrus.InfoLevel, l.previousLogger.Level) + +}