diff --git a/cmd/nvidia-container-toolkit/container_config.go b/cmd/nvidia-container-toolkit/container_config.go index b129895f..ee3492b3 100644 --- a/cmd/nvidia-container-toolkit/container_config.go +++ b/cmd/nvidia-container-toolkit/container_config.go @@ -104,18 +104,31 @@ type HookState struct { BundlePath string `json:"bundlePath"` } -func parseCudaVersion(cudaVersion string) (vmaj, vmin, vpatch uint32) { - if _, err := fmt.Sscanf(cudaVersion, "%d.%d.%d\n", &vmaj, &vmin, &vpatch); err != nil { - vpatch = 0 - if _, err := fmt.Sscanf(cudaVersion, "%d.%d\n", &vmaj, &vmin); err != nil { - vmin = 0 - if _, err := fmt.Sscanf(cudaVersion, "%d\n", &vmaj); err != nil { - log.Panicln("invalid CUDA version:", cudaVersion) - } - } +func parseCudaVersion(cudaVersion string) (uint32, uint32) { + major, minor, err := parseMajorMinorVersion(cudaVersion) + if err != nil { + log.Panicln("invalid CUDA Version", cudaVersion, err) + } + return major, minor +} + +func parseMajorMinorVersion(version string) (uint32, uint32, error) { + if !semver.IsValid("v" + version) { + return 0, 0, fmt.Errorf("invalid version string") } - return + majorMinor := strings.TrimPrefix(semver.MajorMinor("v"+version), "v") + parts := strings.Split(majorMinor, ".") + + major, err := strconv.ParseUint(parts[0], 10, 32) + if err != nil { + return 0, 0, fmt.Errorf("invalid major version") + } + minor, err := strconv.ParseUint(parts[1], 10, 32) + if err != nil { + return 0, 0, fmt.Errorf("invalid minor version") + } + return uint32(major), uint32(minor), nil } func getEnvMap(e []string) (m map[string]string) { @@ -344,7 +357,7 @@ func getRequirements(env map[string]string, legacyImage bool) []string { } } if legacyImage { - vmaj, vmin, _ := parseCudaVersion(env[envCUDAVersion]) + vmaj, vmin := parseCudaVersion(env[envCUDAVersion]) cudaRequire := fmt.Sprintf("cuda>=%d.%d", vmaj, vmin) requirements = append(requirements, cudaRequire) } diff --git a/cmd/nvidia-container-toolkit/hook_test.go b/cmd/nvidia-container-toolkit/hook_test.go index 07d65bb4..0e955737 100644 --- a/cmd/nvidia-container-toolkit/hook_test.go +++ b/cmd/nvidia-container-toolkit/hook_test.go @@ -10,18 +10,18 @@ import ( func TestParseCudaVersionValid(t *testing.T) { var tests = []struct { version string - expected [3]uint32 + expected [2]uint32 }{ - {"0", [3]uint32{0, 0, 0}}, - {"8", [3]uint32{8, 0, 0}}, - {"7.5", [3]uint32{7, 5, 0}}, - {"9.0.116", [3]uint32{9, 0, 116}}, - {"4294967295.4294967295.4294967295", [3]uint32{4294967295, 4294967295, 4294967295}}, + {"0", [2]uint32{0, 0}}, + {"8", [2]uint32{8, 0}}, + {"7.5", [2]uint32{7, 5}}, + {"9.0.116", [2]uint32{9, 0}}, + {"4294967295.4294967295.4294967295", [2]uint32{4294967295, 4294967295}}, } for i, c := range tests { - vmaj, vmin, vpatch := parseCudaVersion(c.version) + vmaj, vmin := parseCudaVersion(c.version) - version := [3]uint32{vmaj, vmin, vpatch} + version := [2]uint32{vmaj, vmin} require.Equal(t, c.expected, version, "%d: %v", i, c) }