diff --git a/pkg/nvpci/logger.go b/pkg/nvpci/logger.go new file mode 100644 index 0000000..8b7b5ab --- /dev/null +++ b/pkg/nvpci/logger.go @@ -0,0 +1,29 @@ +/** +# Copyright (c) NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package nvpci + +import "log" + +type logger interface { + Warningf(string, ...interface{}) +} + +type simpleLogger struct{} + +func (l simpleLogger) Warningf(format string, v ...interface{}) { + log.Printf("WARNING: "+format, v) +} diff --git a/pkg/nvpci/nvpci.go b/pkg/nvpci/nvpci.go index 44642f2..f5a7826 100644 --- a/pkg/nvpci/nvpci.go +++ b/pkg/nvpci/nvpci.go @@ -68,6 +68,7 @@ type ResourceInterface interface { } type nvpci struct { + logger logger pciDevicesRoot string pcidbPath string } @@ -134,6 +135,9 @@ func New(opts ...Option) Interface { for _, opt := range opts { opt(n) } + if n.logger == nil { + n.logger = &simpleLogger{} + } if n.pciDevicesRoot == "" { n.pciDevicesRoot = PCIDevicesRoot } @@ -143,6 +147,13 @@ func New(opts ...Option) Interface { // Option defines a function for passing options to the New() call type Option func(*nvpci) +// WithLogger provides an Option to set the logger for the library +func WithLogger(logger logger) Option { + return func(n *nvpci) { + n.logger = logger + } +} + // WithPCIDevicesRoot provides an Option to set the root path // for PCI devices on the system. func WithPCIDevicesRoot(root string) Option { @@ -308,12 +319,12 @@ func (p *nvpci) GetGPUByPciBusID(address string) (*NvidiaPCIDevice, error) { deviceName, err := pciDB.GetDeviceName(uint16(vendorID), uint16(deviceID)) if err != nil { - fmt.Printf("WARNING: unable to get device name: %v\n", err) + p.logger.Warningf("unable to get device name: %v\n", err) deviceName = UnknownDeviceString } className, err := pciDB.GetClassName(uint32(classID)) if err != nil { - fmt.Printf("WARNING: unable to get class name for device: %v\n", err) + p.logger.Warningf("unable to get class name for device: %v\n", err) className = UnknownClassString }