diff --git a/internal/modifier/graphics.go b/internal/modifier/graphics.go index c604dff1..cb8daa71 100644 --- a/internal/modifier/graphics.go +++ b/internal/modifier/graphics.go @@ -18,7 +18,6 @@ package modifier import ( "fmt" - "strings" "github.com/NVIDIA/nvidia-container-toolkit/internal/config" "github.com/NVIDIA/nvidia-container-toolkit/internal/config/image" @@ -59,16 +58,8 @@ func requiresGraphicsModifier(cudaImage image.CUDA) (bool, string) { return false, "no devices requested" } - var hasGraphics bool - for _, c := range strings.Split(cudaImage["NVIDIA_DRIVER_CAPABILITIES"], ",") { - if c == "graphics" || c == "all" { - hasGraphics = true - break - } - } - - if !hasGraphics { - return false, fmt.Sprintf("Capability %q not selected", "graphics") + if !cudaImage.GetDriverCapabilities().Any(image.DriverCapabilityGraphics, image.DriverCapabilityDisplay) { + return false, "no required capabilities requested" } return true, "" diff --git a/internal/modifier/graphics_test.go b/internal/modifier/graphics_test.go index bb2b5cfb..e062763d 100644 --- a/internal/modifier/graphics_test.go +++ b/internal/modifier/graphics_test.go @@ -69,6 +69,14 @@ func TestGraphicsModifier(t *testing.T) { }, expectedRequired: true, }, + { + description: "devices with display capability creates modifier", + cudaImage: image.CUDA{ + "NVIDIA_VISIBLE_DEVICES": "all", + "NVIDIA_DRIVER_CAPABILITIES": "display", + }, + expectedRequired: true, + }, { description: "devices with display,graphics capability creates modifier", cudaImage: image.CUDA{