diff --git a/internal/modifier/graphics.go b/internal/modifier/graphics.go index f094d78f..c604dff1 100644 --- a/internal/modifier/graphics.go +++ b/internal/modifier/graphics.go @@ -40,21 +40,8 @@ func NewGraphicsModifier(logger *logrus.Logger, cfg *config.Config, ociSpec oci. return nil, err } - if devices := image.DevicesFromEnvvars(visibleDevicesEnvvar); len(devices.List()) == 0 { - logger.Infof("No modification required; no devices requested") - return nil, nil - } - - var hasGraphics bool - for _, c := range strings.Split(image["NVIDIA_DRIVER_CAPABILITIES"], ",") { - if c == "graphics" || c == "all" { - hasGraphics = true - break - } - } - - if !hasGraphics { - logger.Debugf("Capability %q not selected", "graphics") + if required, reason := requiresGraphicsModifier(image); !required { + logger.Infof("No graphics modifier required: %v", reason) return nil, nil } @@ -65,3 +52,24 @@ func NewGraphicsModifier(logger *logrus.Logger, cfg *config.Config, ociSpec oci. return NewModifierFromDiscoverer(logger, d) } + +// requiresGraphicsModifier determines whether a graphics modifier is required. +func requiresGraphicsModifier(cudaImage image.CUDA) (bool, string) { + if devices := cudaImage.DevicesFromEnvvars(visibleDevicesEnvvar); len(devices.List()) == 0 { + return false, "no devices requested" + } + + var hasGraphics bool + for _, c := range strings.Split(cudaImage["NVIDIA_DRIVER_CAPABILITIES"], ",") { + if c == "graphics" || c == "all" { + hasGraphics = true + break + } + } + + if !hasGraphics { + return false, fmt.Sprintf("Capability %q not selected", "graphics") + } + + return true, "" +} diff --git a/internal/modifier/graphics_test.go b/internal/modifier/graphics_test.go new file mode 100644 index 00000000..bb2b5cfb --- /dev/null +++ b/internal/modifier/graphics_test.go @@ -0,0 +1,88 @@ +/** +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package modifier + +import ( + "testing" + + "github.com/NVIDIA/nvidia-container-toolkit/internal/config/image" + "github.com/stretchr/testify/require" +) + +func TestGraphicsModifier(t *testing.T) { + testCases := []struct { + description string + cudaImage image.CUDA + expectedRequired bool + }{ + { + description: "empty image does not create modifier", + }, + { + description: "devices with no capabilities does not create modifier", + cudaImage: image.CUDA{ + "NVIDIA_VISIBLE_DEVICES": "all", + }, + }, + { + description: "devices with no non-graphics does not create modifier", + cudaImage: image.CUDA{ + "NVIDIA_VISIBLE_DEVICES": "all", + "NVIDIA_DRIVER_CAPABILITIES": "compute", + }, + }, + { + description: "devices with all capabilities creates modifier", + cudaImage: image.CUDA{ + "NVIDIA_VISIBLE_DEVICES": "all", + "NVIDIA_DRIVER_CAPABILITIES": "all", + }, + expectedRequired: true, + }, + { + description: "devices with graphics capability creates modifier", + cudaImage: image.CUDA{ + "NVIDIA_VISIBLE_DEVICES": "all", + "NVIDIA_DRIVER_CAPABILITIES": "graphics", + }, + expectedRequired: true, + }, + { + description: "devices with compute,graphics capability creates modifier", + cudaImage: image.CUDA{ + "NVIDIA_VISIBLE_DEVICES": "all", + "NVIDIA_DRIVER_CAPABILITIES": "compute,graphics", + }, + expectedRequired: true, + }, + { + description: "devices with display,graphics capability creates modifier", + cudaImage: image.CUDA{ + "NVIDIA_VISIBLE_DEVICES": "all", + "NVIDIA_DRIVER_CAPABILITIES": "display,graphics", + }, + expectedRequired: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.description, func(t *testing.T) { + required, _ := requiresGraphicsModifier(tc.cudaImage) + require.EqualValues(t, tc.expectedRequired, required) + }) + } +}