diff --git a/cmd/nvidia-container-toolkit/container_config.go b/cmd/nvidia-container-toolkit/container_config.go index ee3492b3..5920f471 100644 --- a/cmd/nvidia-container-toolkit/container_config.go +++ b/cmd/nvidia-container-toolkit/container_config.go @@ -7,9 +7,9 @@ import ( "os" "path" "path/filepath" - "strconv" "strings" + "github.com/NVIDIA/nvidia-container-toolkit/internal/config/image" "golang.org/x/mod/semver" ) @@ -104,45 +104,6 @@ type HookState struct { BundlePath string `json:"bundlePath"` } -func parseCudaVersion(cudaVersion string) (uint32, uint32) { - major, minor, err := parseMajorMinorVersion(cudaVersion) - if err != nil { - log.Panicln("invalid CUDA Version", cudaVersion, err) - } - return major, minor -} - -func parseMajorMinorVersion(version string) (uint32, uint32, error) { - if !semver.IsValid("v" + version) { - return 0, 0, fmt.Errorf("invalid version string") - } - - majorMinor := strings.TrimPrefix(semver.MajorMinor("v"+version), "v") - parts := strings.Split(majorMinor, ".") - - major, err := strconv.ParseUint(parts[0], 10, 32) - if err != nil { - return 0, 0, fmt.Errorf("invalid major version") - } - minor, err := strconv.ParseUint(parts[1], 10, 32) - if err != nil { - return 0, 0, fmt.Errorf("invalid minor version") - } - return uint32(major), uint32(minor), nil -} - -func getEnvMap(e []string) (m map[string]string) { - m = make(map[string]string) - for _, s := range e { - p := strings.SplitN(s, "=", 2) - if len(p) != 2 { - log.Panicln("environment error") - } - m[p[0]] = p[1] - } - return -} - func loadSpec(path string) (spec *Spec) { f, err := os.Open(path) if err != nil { @@ -204,12 +165,6 @@ func isPrivileged(s *Spec) bool { return false } -func isLegacyCUDAImage(env map[string]string) bool { - legacyCudaVersion := env[envCUDAVersion] - cudaRequire := env[envNVRequireCUDA] - return len(legacyCudaVersion) > 0 && len(cudaRequire) == 0 -} - func getDevicesFromEnvvar(env map[string]string, legacyImage bool) *string { // Build a list of envvars to consider. envVars := []string{envNVVisibleDevices} @@ -348,27 +303,11 @@ func getDriverCapabilities(env map[string]string, supportedDriverCapabilities Dr return capabilities } -func getRequirements(env map[string]string, legacyImage bool) []string { - // All variables with the "NVIDIA_REQUIRE_" prefix are passed to nvidia-container-cli - var requirements []string - for name, value := range env { - if strings.HasPrefix(name, envNVRequirePrefix) { - requirements = append(requirements, value) - } - } - if legacyImage { - vmaj, vmin := parseCudaVersion(env[envCUDAVersion]) - cudaRequire := fmt.Sprintf("cuda>=%d.%d", vmaj, vmin) - requirements = append(requirements, cudaRequire) - } - return requirements -} - -func getNvidiaConfig(hookConfig *HookConfig, env map[string]string, mounts []Mount, privileged bool) *nvidiaConfig { - legacyImage := isLegacyCUDAImage(env) +func getNvidiaConfig(hookConfig *HookConfig, image image.CUDA, mounts []Mount, privileged bool) *nvidiaConfig { + legacyImage := image.IsLegacy() var devices string - if d := getDevices(hookConfig, env, mounts, privileged, legacyImage); d != nil { + if d := getDevices(hookConfig, image, mounts, privileged, legacyImage); d != nil { devices = *d } else { // 'nil' devices means this is not a GPU container. @@ -376,7 +315,7 @@ func getNvidiaConfig(hookConfig *HookConfig, env map[string]string, mounts []Mou } var migConfigDevices string - if d := getMigConfigDevices(env); d != nil { + if d := getMigConfigDevices(image); d != nil { migConfigDevices = *d } if !privileged && migConfigDevices != "" { @@ -384,19 +323,21 @@ func getNvidiaConfig(hookConfig *HookConfig, env map[string]string, mounts []Mou } var migMonitorDevices string - if d := getMigMonitorDevices(env); d != nil { + if d := getMigMonitorDevices(image); d != nil { migMonitorDevices = *d } if !privileged && migMonitorDevices != "" { log.Panicln("cannot set MIG_MONITOR_DEVICES in non privileged container") } - driverCapabilities := getDriverCapabilities(env, hookConfig.SupportedDriverCapabilities, legacyImage).String() + driverCapabilities := getDriverCapabilities(image, hookConfig.SupportedDriverCapabilities, legacyImage).String() - requirements := getRequirements(env, legacyImage) + requirements, err := image.GetRequirements() + if err != nil { + log.Panicln("failed to get requirements", err) + } - // Don't fail on invalid values. - disableRequire, _ := strconv.ParseBool(env[envNVDisableRequire]) + disableRequire := image.HasDisableRequire() return &nvidiaConfig{ Devices: devices, @@ -422,13 +363,17 @@ func getContainerConfig(hook HookConfig) (config containerConfig) { s := loadSpec(path.Join(b, "config.json")) - env := getEnvMap(s.Process.Env) + image, err := image.NewCUDAImageFromEnv(s.Process.Env) + if err != nil { + log.Panicln(err) + } + privileged := isPrivileged(s) envSwarmGPU = hook.SwarmResource return containerConfig{ Pid: h.Pid, Rootfs: s.Root.Path, - Env: env, - Nvidia: getNvidiaConfig(&hook, env, s.Mounts, privileged), + Env: image, + Nvidia: getNvidiaConfig(&hook, image, s.Mounts, privileged), } } diff --git a/cmd/nvidia-container-toolkit/hook_test.go b/cmd/nvidia-container-toolkit/hook_test.go index 0e955737..d5449bab 100644 --- a/cmd/nvidia-container-toolkit/hook_test.go +++ b/cmd/nvidia-container-toolkit/hook_test.go @@ -7,51 +7,6 @@ import ( "github.com/stretchr/testify/require" ) -func TestParseCudaVersionValid(t *testing.T) { - var tests = []struct { - version string - expected [2]uint32 - }{ - {"0", [2]uint32{0, 0}}, - {"8", [2]uint32{8, 0}}, - {"7.5", [2]uint32{7, 5}}, - {"9.0.116", [2]uint32{9, 0}}, - {"4294967295.4294967295.4294967295", [2]uint32{4294967295, 4294967295}}, - } - for i, c := range tests { - vmaj, vmin := parseCudaVersion(c.version) - - version := [2]uint32{vmaj, vmin} - - require.Equal(t, c.expected, version, "%d: %v", i, c) - } -} - -func TestParseCudaVersionInvalid(t *testing.T) { - var tests = []string{ - "foo", - "foo.5.10", - "9.0.116.50", - "9.0.116foo", - "7.foo", - "9.0.bar", - "9.4294967296", - "9.0.116.", - "9..0", - "9.", - ".5.10", - "-9", - "+9", - "-9.1.116", - "-9.-1.-116", - } - for _, c := range tests { - require.Panics(t, func() { - parseCudaVersion(c) - }, "parseCudaVersion(%v)", c) - } -} - func TestIsPrivileged(t *testing.T) { var tests = []struct { spec string