diff --git a/pkg/container_config.go b/pkg/container_config.go index 3c87c307..e515f420 100644 --- a/pkg/container_config.go +++ b/pkg/container_config.go @@ -8,6 +8,8 @@ import ( "path" "strconv" "strings" + + "golang.org/x/mod/semver" ) var envSwarmGPU *string @@ -55,8 +57,8 @@ type Root struct { // github.com/opencontainers/runtime-spec/blob/v1.0.0/specs-go/config.go#L30-L57 type Process struct { - Env []string `json:"env,omitempty"` - Capabilities *LinuxCapabilities `json:"capabilities,omitempty" platform:"linux"` + Env []string `json:"env,omitempty"` + Capabilities *json.RawMessage `json:"capabilities,omitempty" platform:"linux"` } // https://github.com/opencontainers/runtime-spec/blob/v1.0.0/specs-go/config.go#L61 @@ -71,6 +73,7 @@ type LinuxCapabilities struct { // We use pointers to structs, similarly to the latest version of runtime-spec: // https://github.com/opencontainers/runtime-spec/blob/v1.0.0/specs-go/config.go#L5-L28 type Spec struct { + Version *string `json:"ociVersion"` Process *Process `json:"process,omitempty"` Root *Root `json:"root,omitempty"` } @@ -133,6 +136,9 @@ func loadSpec(path string) (spec *Spec) { if err = json.NewDecoder(f).Decode(&spec); err != nil { log.Panicln("could not decode OCI spec:", err) } + if spec.Version == nil { + log.Panicln("Version is empty in OCI spec") + } if spec.Process == nil { log.Panicln("Process is empty in OCI spec") } @@ -142,29 +148,43 @@ func loadSpec(path string) (spec *Spec) { return } -func isPrivileged(caps *LinuxCapabilities) bool { - if caps == nil { +func isPrivileged(s *Spec) bool { + if s.Process.Capabilities == nil { return false } - hasCapSysAdmin := func(caps []string) bool { - for _, c := range caps { - if c == capSysAdmin { - return true - } + var caps []string + // If v1.1.0-rc1 <= OCI version < v1.0.0-rc5 parse s.Process.Capabilities as: + // github.com/opencontainers/runtime-spec/blob/v1.0.0-rc1/specs-go/config.go#L30-L54 + rc1cmp := semver.Compare("v"+*s.Version, "v1.0.0-rc1") + rc5cmp := semver.Compare("v"+*s.Version, "v1.0.0-rc5") + if (rc1cmp == 1 || rc1cmp == 0) && (rc5cmp == -1) { + err := json.Unmarshal(*s.Process.Capabilities, &caps) + if err != nil { + log.Panicln("could not decode Process.Capabilities in OCI spec:", err) } - return false + // Otherwise, parse s.Process.Capabilities as: + // github.com/opencontainers/runtime-spec/blob/v1.0.0/specs-go/config.go#L30-L54 + } else { + var lc LinuxCapabilities + err := json.Unmarshal(*s.Process.Capabilities, &lc) + if err != nil { + log.Panicln("could not decode Process.Capabilities in OCI spec:", err) + } + // We only make sure that the bounding capabibility set has + // CAP_SYS_ADMIN. This allows us to make sure that the container was + // actually started as '--privileged', but also allow non-root users to + // access the priviliged NVIDIA capabilities. + caps = lc.Bounding } - // We only make sure that the bounding capabibility set has - // CAP_SYS_ADMIN. This allows us to make sure that the container was - // actually started as '--privileged', but also allow non-root users to - // access the priviliged NVIDIA capabilities. - if !hasCapSysAdmin(caps.Bounding) { - return false + for _, c := range caps { + if c == capSysAdmin { + return true + } } - return true + return false } func getDevices(env map[string]string) *string { @@ -365,7 +385,7 @@ func getContainerConfig(hook HookConfig) (config containerConfig) { s := loadSpec(path.Join(b, "config.json")) env := getEnvMap(s.Process.Env, hook.NvidiaContainerCLI) - privileged := isPrivileged(s.Process.Capabilities) + privileged := isPrivileged(s) envSwarmGPU = hook.SwarmResource return containerConfig{ Pid: h.Pid, diff --git a/pkg/hook_test.go b/pkg/hook_test.go index 05f2059a..f22a07a3 100644 --- a/pkg/hook_test.go +++ b/pkg/hook_test.go @@ -2,6 +2,7 @@ package main import ( "testing" + "encoding/json" ) func TestParseCudaVersionValid(t *testing.T) { @@ -58,3 +59,85 @@ func TestParseCudaVersionInvalid(t *testing.T) { }) } } + +func TestIsPrivileged(t *testing.T) { + var tests = []struct { + spec string + expected bool + }{ + { + ` + { + "ociVersion": "1.0.0", + "process": { + "capabilities": { + "bounding": [ "CAP_SYS_ADMIN" ] + } + } + } + `, + true, + }, + { + ` + { + "ociVersion": "1.0.0", + "process": { + "capabilities": { + "bounding": [ "CAP_SYS_OTHER" ] + } + } + } + `, + false, + }, + { + ` + { + "ociVersion": "1.0.0", + "process": {} + } + `, + false, + }, + { + ` + { + "ociVersion": "1.0.0-rc2-dev", + "process": { + "capabilities": [ "CAP_SYS_ADMIN" ] + } + } + `, + true, + }, + { + ` + { + "ociVersion": "1.0.0-rc2-dev", + "process": { + "capabilities": [ "CAP_SYS_OTHER" ] + } + } + `, + false, + }, + { + ` + { + "ociVersion": "1.0.0-rc2-dev", + "process": {} + } + `, + false, + }, + } + for _, tc := range tests { + var spec Spec + _ = json.Unmarshal([]byte(tc.spec), &spec) + privileged := isPrivileged(&spec) + if privileged != tc.expected { + t.Errorf("isPrivileged() returned unexpectred value (privileged: %v, tc.expected: %v)", privileged, tc.expected) + } + } +}