Unify GetDevices logic at internal/config/image
Some checks are pending
CI Pipeline / code-scanning (push) Waiting to run
CI Pipeline / variables (push) Waiting to run
CI Pipeline / golang (push) Waiting to run
CI Pipeline / image (push) Blocked by required conditions
CI Pipeline / e2e-test (push) Blocked by required conditions

Signed-off-by: Carlos Eduardo Arango Gutierrez <eduardoa@nvidia.com>
Signed-off-by: Evan Lezar <elezar@nvidia.com>
This commit is contained in:
Carlos Eduardo Arango Gutierrez 2025-05-28 07:33:27 +02:00
parent bb3a54f7f4
commit 02a486e3a3
No known key found for this signature in database
GPG Key ID: 42D9CB42F300A852
10 changed files with 268 additions and 204 deletions

View File

@ -13,10 +13,6 @@ import (
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image" "github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
) )
const (
capSysAdmin = "CAP_SYS_ADMIN"
)
type nvidiaConfig struct { type nvidiaConfig struct {
Devices []string Devices []string
MigConfigDevices string MigConfigDevices string
@ -103,9 +99,9 @@ func loadSpec(path string) (spec *Spec) {
return return
} }
func isPrivileged(s *Spec) bool { func (s *Spec) GetCapabilities() []string {
if s.Process.Capabilities == nil { if s == nil || s.Process == nil || s.Process.Capabilities == nil {
return false return nil
} }
var caps []string var caps []string
@ -118,67 +114,22 @@ func isPrivileged(s *Spec) bool {
if err != nil { if err != nil {
log.Panicln("could not decode Process.Capabilities in OCI spec:", err) log.Panicln("could not decode Process.Capabilities in OCI spec:", err)
} }
for _, c := range caps { return caps
if c == capSysAdmin {
return true
}
}
return false
} }
// Otherwise, parse s.Process.Capabilities as: // Otherwise, parse s.Process.Capabilities as:
// github.com/opencontainers/runtime-spec/blob/v1.0.0/specs-go/config.go#L30-L54 // github.com/opencontainers/runtime-spec/blob/v1.0.0/specs-go/config.go#L30-L54
process := specs.Process{ capabilities := specs.LinuxCapabilities{}
Env: s.Process.Env, err := json.Unmarshal(*s.Process.Capabilities, &capabilities)
}
err := json.Unmarshal(*s.Process.Capabilities, &process.Capabilities)
if err != nil { if err != nil {
log.Panicln("could not decode Process.Capabilities in OCI spec:", err) log.Panicln("could not decode Process.Capabilities in OCI spec:", err)
} }
fullSpec := specs.Spec{ return image.OCISpecCapabilities(capabilities).GetCapabilities()
Version: *s.Version,
Process: &process,
}
return image.IsPrivileged(&fullSpec)
} }
func getDevicesFromEnvvar(containerImage image.CUDA, swarmResourceEnvvars []string) []string { func isPrivileged(s *Spec) bool {
// We check if the image has at least one of the Swarm resource envvars defined and use this return image.IsPrivileged(s)
// if specified.
for _, envvar := range swarmResourceEnvvars {
if containerImage.HasEnvvar(envvar) {
return containerImage.DevicesFromEnvvars(swarmResourceEnvvars...).List()
}
}
return containerImage.VisibleDevicesFromEnvVar()
}
func (hookConfig *hookConfig) getDevices(image image.CUDA, privileged bool) []string {
// If enabled, try and get the device list from volume mounts first
if hookConfig.AcceptDeviceListAsVolumeMounts {
devices := image.VisibleDevicesFromMounts()
if len(devices) > 0 {
return devices
}
}
// Fallback to reading from the environment variable if privileges are correct
devices := getDevicesFromEnvvar(image, hookConfig.getSwarmResourceEnvvars())
if len(devices) == 0 {
return nil
}
if privileged || hookConfig.AcceptEnvvarUnprivileged {
return devices
}
configName := hookConfig.getConfigOption("AcceptEnvvarUnprivileged")
log.Printf("Ignoring devices specified in NVIDIA_VISIBLE_DEVICES (privileged=%v, %v=%v) ", privileged, configName, hookConfig.AcceptEnvvarUnprivileged)
return nil
} }
func getMigConfigDevices(i image.CUDA) *string { func getMigConfigDevices(i image.CUDA) *string {
@ -225,7 +176,6 @@ func (hookConfig *hookConfig) getDriverCapabilities(cudaImage image.CUDA, legacy
// We use the default driver capabilities by default. This is filtered to only include the // We use the default driver capabilities by default. This is filtered to only include the
// supported capabilities // supported capabilities
supportedDriverCapabilities := image.NewDriverCapabilities(hookConfig.SupportedDriverCapabilities) supportedDriverCapabilities := image.NewDriverCapabilities(hookConfig.SupportedDriverCapabilities)
capabilities := supportedDriverCapabilities.Intersection(image.DefaultDriverCapabilities) capabilities := supportedDriverCapabilities.Intersection(image.DefaultDriverCapabilities)
capsEnvSpecified := cudaImage.HasEnvvar(image.EnvVarNvidiaDriverCapabilities) capsEnvSpecified := cudaImage.HasEnvvar(image.EnvVarNvidiaDriverCapabilities)
@ -251,7 +201,7 @@ func (hookConfig *hookConfig) getDriverCapabilities(cudaImage image.CUDA, legacy
func (hookConfig *hookConfig) getNvidiaConfig(image image.CUDA, privileged bool) *nvidiaConfig { func (hookConfig *hookConfig) getNvidiaConfig(image image.CUDA, privileged bool) *nvidiaConfig {
legacyImage := image.IsLegacy() legacyImage := image.IsLegacy()
devices := hookConfig.getDevices(image, privileged) devices := image.VisibleDevices()
if len(devices) == 0 { if len(devices) == 0 {
// empty devices means this is not a GPU container. // empty devices means this is not a GPU container.
return nil return nil
@ -306,20 +256,30 @@ func (hookConfig *hookConfig) getContainerConfig() (config containerConfig) {
s := loadSpec(path.Join(b, "config.json")) s := loadSpec(path.Join(b, "config.json"))
image, err := image.New( privileged := isPrivileged(s)
opts := []image.Option{
image.WithEnv(s.Process.Env), image.WithEnv(s.Process.Env),
image.WithMounts(s.Mounts), image.WithMounts(s.Mounts),
image.WithPrivileged(privileged),
image.WithDisableRequire(hookConfig.DisableRequire), image.WithDisableRequire(hookConfig.DisableRequire),
) image.WithAcceptDeviceListAsVolumeMounts(hookConfig.AcceptDeviceListAsVolumeMounts),
image.WithAcceptEnvvarUnprivileged(hookConfig.AcceptEnvvarUnprivileged),
}
if len(hookConfig.getSwarmResourceEnvvars()) > 0 {
opts = append(opts, image.WithVisibleDevicesEnvVars(hookConfig.getSwarmResourceEnvvars()...))
}
i, err := image.New(opts...)
if err != nil { if err != nil {
log.Panicln(err) log.Panicln(err)
} }
privileged := isPrivileged(s)
return containerConfig{ return containerConfig{
Pid: h.Pid, Pid: h.Pid,
Rootfs: s.Root.Path, Rootfs: s.Root.Path,
Image: image, Image: i,
Nvidia: hookConfig.getNvidiaConfig(image, privileged), Nvidia: hookConfig.getNvidiaConfig(i, privileged),
} }
} }

View File

@ -477,9 +477,19 @@ func TestGetNvidiaConfig(t *testing.T) {
} }
for _, tc := range tests { for _, tc := range tests {
t.Run(tc.description, func(t *testing.T) { t.Run(tc.description, func(t *testing.T) {
image, _ := image.New( opts := []image.Option{
image.WithEnvMap(tc.env), image.WithEnvMap(tc.env),
) image.WithPrivileged(tc.privileged),
image.WithAcceptEnvvarUnprivileged(true),
}
if tc.hookConfig != nil {
if tc.hookConfig.SwarmResource != "" {
opts = append(opts, image.WithVisibleDevicesEnvVars(tc.hookConfig.SwarmResource))
}
}
image, _ := image.New(opts...)
// Wrap the call to getNvidiaConfig() in a closure. // Wrap the call to getNvidiaConfig() in a closure.
var cfg *nvidiaConfig var cfg *nvidiaConfig
getConfig := func() { getConfig := func() {
@ -622,12 +632,15 @@ func TestDeviceListSourcePriority(t *testing.T) {
}, },
), ),
image.WithMounts(tc.mountDevices), image.WithMounts(tc.mountDevices),
image.WithPrivileged(tc.privileged),
image.WithAcceptDeviceListAsVolumeMounts(tc.acceptMounts),
image.WithAcceptEnvvarUnprivileged(tc.acceptUnprivileged),
) )
defaultConfig, _ := config.GetDefault() defaultConfig, _ := config.GetDefault()
cfg := &hookConfig{defaultConfig} cfg := &hookConfig{defaultConfig}
cfg.AcceptEnvvarUnprivileged = tc.acceptUnprivileged cfg.AcceptEnvvarUnprivileged = tc.acceptUnprivileged
cfg.AcceptDeviceListAsVolumeMounts = tc.acceptMounts cfg.AcceptDeviceListAsVolumeMounts = tc.acceptMounts
devices = cfg.getDevices(image, tc.privileged) devices = image.VisibleDevices()
} }
// For all other tests, just grab the devices and check the results // For all other tests, just grab the devices and check the results
@ -843,10 +856,20 @@ func TestGetDevicesFromEnvvar(t *testing.T) {
for _, tc := range tests { for _, tc := range tests {
t.Run(tc.description, func(t *testing.T) { t.Run(tc.description, func(t *testing.T) {
image, _ := image.New( opts := []image.Option{
image.WithEnvMap(tc.env), image.WithEnvMap(tc.env),
) image.WithPrivileged(true),
devices := getDevicesFromEnvvar(image, tc.swarmResourceEnvvars) image.WithAcceptDeviceListAsVolumeMounts(false),
image.WithAcceptEnvvarUnprivileged(false),
}
if len(tc.swarmResourceEnvvars) > 0 {
opts = append(opts, image.WithVisibleDevicesEnvVars(tc.swarmResourceEnvvars...))
}
image, _ := image.New(opts...)
devices := image.VisibleDevices()
require.EqualValues(t, tc.expectedDevices, devices) require.EqualValues(t, tc.expectedDevices, devices)
}) })
} }

View File

@ -1,89 +0,0 @@
package main
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/require"
)
func TestIsPrivileged(t *testing.T) {
var tests = []struct {
spec string
expected bool
}{
{
`
{
"ociVersion": "1.0.0",
"process": {
"capabilities": {
"bounding": [ "CAP_SYS_ADMIN" ]
}
}
}
`,
true,
},
{
`
{
"ociVersion": "1.0.0",
"process": {
"capabilities": {
"bounding": [ "CAP_SYS_OTHER" ]
}
}
}
`,
false,
},
{
`
{
"ociVersion": "1.0.0",
"process": {}
}
`,
false,
},
{
`
{
"ociVersion": "1.0.0-rc2-dev",
"process": {
"capabilities": [ "CAP_SYS_ADMIN" ]
}
}
`,
true,
},
{
`
{
"ociVersion": "1.0.0-rc2-dev",
"process": {
"capabilities": [ "CAP_SYS_OTHER" ]
}
}
`,
false,
},
{
`
{
"ociVersion": "1.0.0-rc2-dev",
"process": {}
}
`,
false,
},
}
for i, tc := range tests {
var spec Spec
_ = json.Unmarshal([]byte(tc.spec), &spec)
privileged := isPrivileged(&spec)
require.Equal(t, tc.expected, privileged, "%d: %v", i, tc)
}
}

View File

@ -24,11 +24,14 @@ import (
) )
type builder struct { type builder struct {
env map[string]string CUDA
mounts []specs.Mount
disableRequire bool disableRequire bool
} }
// Option is a functional option for creating a CUDA image.
type Option func(*builder) error
// New creates a new CUDA image from the input options. // New creates a new CUDA image from the input options.
func New(opt ...Option) (CUDA, error) { func New(opt ...Option) (CUDA, error) {
b := &builder{} b := &builder{}
@ -50,15 +53,36 @@ func (b builder) build() (CUDA, error) {
b.env[EnvVarNvidiaDisableRequire] = "true" b.env[EnvVarNvidiaDisableRequire] = "true"
} }
c := CUDA{ return b.CUDA, nil
env: b.env,
mounts: b.mounts,
}
return c, nil
} }
// Option is a functional option for creating a CUDA image. func WithAnnotationPrefixes(annotationPrefixes []string) Option {
type Option func(*builder) error return func(b *builder) error {
b.annotationPrefixes = annotationPrefixes
return nil
}
}
func WithAnnotations(annotations map[string]string) Option {
return func(b *builder) error {
b.annotations = annotations
return nil
}
}
func WithAcceptDeviceListAsVolumeMounts(acceptDeviceListAsVolumeMounts bool) Option {
return func(b *builder) error {
b.acceptDeviceListAsVolumeMounts = acceptDeviceListAsVolumeMounts
return nil
}
}
func WithAcceptEnvvarUnprivileged(acceptEnvvarUnprivileged bool) Option {
return func(b *builder) error {
b.acceptEnvvarUnprivileged = acceptEnvvarUnprivileged
return nil
}
}
// WithDisableRequire sets the disable require option. // WithDisableRequire sets the disable require option.
func WithDisableRequire(disableRequire bool) Option { func WithDisableRequire(disableRequire bool) Option {
@ -100,3 +124,35 @@ func WithMounts(mounts []specs.Mount) Option {
return nil return nil
} }
} }
// WithPrivileged sets whether an image is privileged or not.
func WithPrivileged(isPrivileged bool) Option {
return func(b *builder) error {
b.isPrivileged = isPrivileged
return nil
}
}
// WithVisibleDevicesEnvVars sets the swarm resource for the CUDA image.
func WithVisibleDevicesEnvVars(visibleDevicesEnvVars ...string) Option {
return func(b *builder) error {
if len(visibleDevicesEnvVars) == 0 {
return fmt.Errorf("visible devices env vars cannot be empty")
}
b.visibleDevicesEnvVars = []string{}
// if resource is a single string, split it by comma
if len(visibleDevicesEnvVars) == 1 && strings.Contains(visibleDevicesEnvVars[0], ",") {
candidates := strings.Split(visibleDevicesEnvVars[0], ",")
for _, c := range candidates {
trimmed := strings.TrimSpace(c)
if len(trimmed) > 0 {
b.visibleDevicesEnvVars = append(b.visibleDevicesEnvVars, trimmed)
}
}
return nil
}
b.visibleDevicesEnvVars = append(b.visibleDevicesEnvVars, visibleDevicesEnvVars...)
return nil
}
}

View File

@ -38,8 +38,17 @@ const (
// a map of environment variable to values that can be used to perform lookups // a map of environment variable to values that can be used to perform lookups
// such as requirements. // such as requirements.
type CUDA struct { type CUDA struct {
env map[string]string visibleDevicesEnvVars []string
annotationPrefixes []string
env map[string]string
annotations map[string]string
mounts []specs.Mount mounts []specs.Mount
acceptDeviceListAsVolumeMounts bool
acceptEnvvarUnprivileged bool
isPrivileged bool
} }
// NewCUDAImageFromSpec creates a CUDA image from the input OCI runtime spec. // NewCUDAImageFromSpec creates a CUDA image from the input OCI runtime spec.
@ -51,14 +60,16 @@ func NewCUDAImageFromSpec(spec *specs.Spec) (CUDA, error) {
} }
return New( return New(
WithAnnotations(spec.Annotations),
WithEnv(env), WithEnv(env),
WithMounts(spec.Mounts), WithMounts(spec.Mounts),
WithPrivileged(IsPrivileged((*OCISpec)(spec))),
) )
} }
// NewCUDAImageFromEnv creates a CUDA image from the input environment. The environment // newCUDAImageFromEnv creates a CUDA image from the input environment. The environment
// is a list of strings of the form ENVAR=VALUE. // is a list of strings of the form ENVAR=VALUE.
func NewCUDAImageFromEnv(env []string) (CUDA, error) { func newCUDAImageFromEnv(env []string) (CUDA, error) {
return New(WithEnv(env)) return New(WithEnv(env))
} }
@ -78,9 +89,14 @@ func (i CUDA) HasEnvvar(key string) bool {
// image is considered legacy if it has a CUDA_VERSION environment variable defined // image is considered legacy if it has a CUDA_VERSION environment variable defined
// and no NVIDIA_REQUIRE_CUDA environment variable defined. // and no NVIDIA_REQUIRE_CUDA environment variable defined.
func (i CUDA) IsLegacy() bool { func (i CUDA) IsLegacy() bool {
legacyCudaVersion := i.env[EnvVarCudaVersion] cudaVersion := i.Getenv(EnvVarCudaVersion)
cudaRequire := i.env[EnvVarNvidiaRequireCuda] cudaRequire := i.Getenv(EnvVarNvidiaRequireCuda)
return len(legacyCudaVersion) > 0 && len(cudaRequire) == 0 return len(cudaVersion) > 0 && len(cudaRequire) == 0
}
// IsSwarm returns whether the image is a Docker Swarm image.
func (i CUDA) IsSwarmResource() bool {
return len(i.visibleDevicesEnvVars) > 0
} }
// GetRequirements returns the requirements from all NVIDIA_REQUIRE_ environment // GetRequirements returns the requirements from all NVIDIA_REQUIRE_ environment
@ -111,24 +127,25 @@ func (i CUDA) GetRequirements() ([]string, error) {
// HasDisableRequire checks for the value of the NVIDIA_DISABLE_REQUIRE. If set // HasDisableRequire checks for the value of the NVIDIA_DISABLE_REQUIRE. If set
// to a valid (true) boolean value this can be used to disable the requirement checks // to a valid (true) boolean value this can be used to disable the requirement checks
func (i CUDA) HasDisableRequire() bool { func (i CUDA) HasDisableRequire() bool {
if disable, exists := i.env[EnvVarNvidiaDisableRequire]; exists { disable := i.Getenv(EnvVarNvidiaDisableRequire)
// i.logger.Debugf("NVIDIA_DISABLE_REQUIRE=%v; skipping requirement checks", disable) if disable != "" {
d, _ := strconv.ParseBool(disable) d, _ := strconv.ParseBool(disable)
return d return d
} }
return false return false
} }
// DevicesFromEnvvars returns the devices requested by the image through environment variables // DevicesFromEnvvars returns the devices requested by the image through environment variables
func (i CUDA) DevicesFromEnvvars(envVars ...string) VisibleDevices { func (i CUDA) DevicesFromEnvvars(envVars ...string) VisibleDevices {
// We concantenate all the devices from the specified env.
var isSet bool var isSet bool
var devices []string var devices []string
requested := make(map[string]bool) requested := make(map[string]bool)
for _, envVar := range envVars { for _, envVar := range envVars {
if devs, ok := i.env[envVar]; ok { devs := i.Getenv(envVar)
if i.HasEnvvar(envVar) {
isSet = true isSet = true
}
if devs != "" {
for _, d := range strings.Split(devs, ",") { for _, d := range strings.Split(devs, ",") {
trimmed := strings.TrimSpace(d) trimmed := strings.TrimSpace(d)
if len(trimmed) == 0 { if len(trimmed) == 0 {
@ -140,12 +157,10 @@ func (i CUDA) DevicesFromEnvvars(envVars ...string) VisibleDevices {
} }
} }
// Environment variable unset with legacy image: default to "all".
if !isSet && len(devices) == 0 && i.IsLegacy() { if !isSet && len(devices) == 0 && i.IsLegacy() {
return NewVisibleDevices("all") return NewVisibleDevices("all")
} }
// Environment variable unset or empty or "void": return nil
if len(devices) == 0 || requested["void"] { if len(devices) == 0 || requested["void"] {
return NewVisibleDevices("void") return NewVisibleDevices("void")
} }
@ -155,7 +170,7 @@ func (i CUDA) DevicesFromEnvvars(envVars ...string) VisibleDevices {
// GetDriverCapabilities returns the requested driver capabilities. // GetDriverCapabilities returns the requested driver capabilities.
func (i CUDA) GetDriverCapabilities() DriverCapabilities { func (i CUDA) GetDriverCapabilities() DriverCapabilities {
env := i.env[EnvVarNvidiaDriverCapabilities] env := i.Getenv(EnvVarNvidiaDriverCapabilities)
capabilities := make(DriverCapabilities) capabilities := make(DriverCapabilities)
for _, c := range strings.Split(env, ",") { for _, c := range strings.Split(env, ",") {
@ -166,7 +181,7 @@ func (i CUDA) GetDriverCapabilities() DriverCapabilities {
} }
func (i CUDA) legacyVersion() (string, error) { func (i CUDA) legacyVersion() (string, error) {
cudaVersion := i.env[EnvVarCudaVersion] cudaVersion := i.Getenv(EnvVarCudaVersion)
majorMinor, err := parseMajorMinorVersion(cudaVersion) majorMinor, err := parseMajorMinorVersion(cudaVersion)
if err != nil { if err != nil {
return "", fmt.Errorf("invalid CUDA version %v: %v", cudaVersion, err) return "", fmt.Errorf("invalid CUDA version %v: %v", cudaVersion, err)
@ -219,11 +234,14 @@ func (i CUDA) OnlyFullyQualifiedCDIDevices() bool {
// VisibleDevicesFromEnvVar returns the set of visible devices requested through // VisibleDevicesFromEnvVar returns the set of visible devices requested through
// the NVIDIA_VISIBLE_DEVICES environment variable. // the NVIDIA_VISIBLE_DEVICES environment variable.
func (i CUDA) VisibleDevicesFromEnvVar() []string { func (i CUDA) VisibleDevicesFromEnvVar() []string {
if i.IsSwarmResource() {
return i.DevicesFromEnvvars(i.visibleDevicesEnvVars...).List()
}
return i.DevicesFromEnvvars(EnvVarNvidiaVisibleDevices).List() return i.DevicesFromEnvvars(EnvVarNvidiaVisibleDevices).List()
} }
// VisibleDevicesFromMounts returns the set of visible devices requested as mounts. // visibleDevicesFromMounts returns the set of visible devices requested as mounts.
func (i CUDA) VisibleDevicesFromMounts() []string { func (i CUDA) visibleDevicesFromMounts() []string {
var devices []string var devices []string
for _, device := range i.DevicesFromMounts() { for _, device := range i.DevicesFromMounts() {
switch { switch {
@ -238,7 +256,6 @@ func (i CUDA) VisibleDevicesFromMounts() []string {
} }
// DevicesFromMounts returns a list of device specified as mounts. // DevicesFromMounts returns a list of device specified as mounts.
// TODO: This should be merged with getDevicesFromMounts used in the NVIDIA Container Runtime
func (i CUDA) DevicesFromMounts() []string { func (i CUDA) DevicesFromMounts() []string {
root := filepath.Clean(DeviceListAsVolumeMountsRoot) root := filepath.Clean(DeviceListAsVolumeMountsRoot)
seen := make(map[string]bool) seen := make(map[string]bool)
@ -271,6 +288,28 @@ func (i CUDA) DevicesFromMounts() []string {
return devices return devices
} }
func (i CUDA) VisibleDevices() []string {
// If enabled, try and get the device list from volume mounts first
if i.acceptDeviceListAsVolumeMounts {
devices := i.visibleDevicesFromMounts()
if len(devices) > 0 {
return devices
}
}
// Fallback to reading from the environment variable if privileges are correct
devices := i.VisibleDevicesFromEnvVar()
if len(devices) == 0 {
return nil
}
if i.isPrivileged || i.acceptEnvvarUnprivileged {
return devices
}
return nil
}
// CDIDevicesFromMounts returns a list of CDI devices specified as mounts on the image. // CDIDevicesFromMounts returns a list of CDI devices specified as mounts on the image.
func (i CUDA) CDIDevicesFromMounts() []string { func (i CUDA) CDIDevicesFromMounts() []string {
var devices []string var devices []string

View File

@ -122,7 +122,7 @@ func TestGetRequirements(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) { t.Run(tc.description, func(t *testing.T) {
image, err := NewCUDAImageFromEnv(tc.env) image, err := newCUDAImageFromEnv(tc.env)
require.NoError(t, err) require.NoError(t, err)
requirements, err := image.GetRequirements() requirements, err := image.GetRequirements()
@ -198,7 +198,7 @@ func TestGetVisibleDevicesFromMounts(t *testing.T) {
for _, tc := range tests { for _, tc := range tests {
t.Run(tc.description, func(t *testing.T) { t.Run(tc.description, func(t *testing.T) {
image, _ := New(WithMounts(tc.mounts)) image, _ := New(WithMounts(tc.mounts))
require.Equal(t, tc.expectedDevices, image.VisibleDevicesFromMounts()) require.Equal(t, tc.expectedDevices, image.visibleDevicesFromMounts())
}) })
} }
} }
@ -224,7 +224,7 @@ func TestImexChannelsFromEnvVar(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
for id, baseEnvvars := range map[string][]string{"": nil, "legacy": {"CUDA_VERSION=1.2.3"}} { for id, baseEnvvars := range map[string][]string{"": nil, "legacy": {"CUDA_VERSION=1.2.3"}} {
t.Run(tc.description+id, func(t *testing.T) { t.Run(tc.description+id, func(t *testing.T) {
i, err := NewCUDAImageFromEnv(append(baseEnvvars, tc.env...)) i, err := newCUDAImageFromEnv(append(baseEnvvars, tc.env...))
require.NoError(t, err) require.NoError(t, err)
channels := i.ImexChannelsFromEnvVar() channels := i.ImexChannelsFromEnvVar()

View File

@ -16,28 +16,45 @@
package image package image
import ( import "github.com/opencontainers/runtime-spec/specs-go"
"github.com/opencontainers/runtime-spec/specs-go"
)
const ( const (
capSysAdmin = "CAP_SYS_ADMIN" capSysAdmin = "CAP_SYS_ADMIN"
) )
type CapabilitiesGetter interface {
GetCapabilities() []string
}
type OCISpec specs.Spec
type OCISpecCapabilities specs.LinuxCapabilities
// IsPrivileged returns true if the container is a privileged container. // IsPrivileged returns true if the container is a privileged container.
func IsPrivileged(s *specs.Spec) bool { func IsPrivileged(s CapabilitiesGetter) bool {
if s.Process.Capabilities == nil { if s == nil {
return false return false
} }
for _, c := range s.GetCapabilities() {
// We only make sure that the bounding capabibility set has
// CAP_SYS_ADMIN. This allows us to make sure that the container was
// actually started as '--privileged', but also allow non-root users to
// access the privileged NVIDIA capabilities.
for _, c := range s.Process.Capabilities.Bounding {
if c == capSysAdmin { if c == capSysAdmin {
return true return true
} }
} }
return false return false
} }
func (s OCISpec) GetCapabilities() []string {
if s.Process == nil || s.Process.Capabilities == nil {
return nil
}
return (*OCISpecCapabilities)(s.Process.Capabilities).GetCapabilities()
}
func (c OCISpecCapabilities) GetCapabilities() []string {
// We only make sure that the bounding capabibility set has
// CAP_SYS_ADMIN. This allows us to make sure that the container was
// actually started as '--privileged', but also allow non-root users to
// access the privileged NVIDIA capabilities.
return c.Bounding
}

View File

@ -0,0 +1,57 @@
/**
# Copyright (c) NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package image
import (
"testing"
"github.com/opencontainers/runtime-spec/specs-go"
"github.com/stretchr/testify/require"
)
func TestIsPrivileged(t *testing.T) {
var tests = []struct {
spec specs.Spec
expected bool
}{
{
specs.Spec{
Process: &specs.Process{
Capabilities: &specs.LinuxCapabilities{
Bounding: []string{"CAP_SYS_ADMIN"},
},
},
},
true,
},
{
specs.Spec{
Process: &specs.Process{
Capabilities: &specs.LinuxCapabilities{
Bounding: []string{"CAP_SYS_FOO"},
},
},
},
false,
},
}
for i, tc := range tests {
privileged := IsPrivileged((*OCISpec)(&tc.spec))
require.Equal(t, tc.expected, privileged, "%d: %v", i, tc)
}
}

View File

@ -107,7 +107,7 @@ func getDevicesFromSpec(logger logger.Interface, ociSpec oci.Spec, cfg *config.C
return nil, nil return nil, nil
} }
if cfg.AcceptEnvvarUnprivileged || image.IsPrivileged(rawSpec) { if cfg.AcceptEnvvarUnprivileged || image.IsPrivileged((*image.OCISpec)(rawSpec)) {
return devices, nil return devices, nil
} }

View File

@ -37,8 +37,9 @@ import (
// //
// If not devices are selected, no changes are made. // If not devices are selected, no changes are made.
func NewFeatureGatedModifier(logger logger.Interface, cfg *config.Config, image image.CUDA, driver *root.Driver, hookCreator discover.HookCreator) (oci.SpecModifier, error) { func NewFeatureGatedModifier(logger logger.Interface, cfg *config.Config, image image.CUDA, driver *root.Driver, hookCreator discover.HookCreator) (oci.SpecModifier, error) {
if devices := image.VisibleDevicesFromEnvVar(); len(devices) == 0 { devices := image.VisibleDevices()
logger.Infof("No modification required; no devices requested") if len(devices) == 0 {
logger.Debugf("No modification required; no devices requested")
return nil, nil return nil, nil
} }