mirror of
https://github.com/NVIDIA/nvidia-container-toolkit
synced 2024-11-29 07:21:46 +00:00
Merge branch 'fix-version-parsing' into 'main'
Use semver package to parse CUDA version See merge request nvidia/container-toolkit/container-toolkit!140
This commit is contained in:
commit
6e60b24828
@ -104,18 +104,31 @@ type HookState struct {
|
||||
BundlePath string `json:"bundlePath"`
|
||||
}
|
||||
|
||||
func parseCudaVersion(cudaVersion string) (vmaj, vmin, vpatch uint32) {
|
||||
if _, err := fmt.Sscanf(cudaVersion, "%d.%d.%d\n", &vmaj, &vmin, &vpatch); err != nil {
|
||||
vpatch = 0
|
||||
if _, err := fmt.Sscanf(cudaVersion, "%d.%d\n", &vmaj, &vmin); err != nil {
|
||||
vmin = 0
|
||||
if _, err := fmt.Sscanf(cudaVersion, "%d\n", &vmaj); err != nil {
|
||||
log.Panicln("invalid CUDA version:", cudaVersion)
|
||||
}
|
||||
func parseCudaVersion(cudaVersion string) (uint32, uint32) {
|
||||
major, minor, err := parseMajorMinorVersion(cudaVersion)
|
||||
if err != nil {
|
||||
log.Panicln("invalid CUDA Version", cudaVersion, err)
|
||||
}
|
||||
return major, minor
|
||||
}
|
||||
|
||||
return
|
||||
func parseMajorMinorVersion(version string) (uint32, uint32, error) {
|
||||
if !semver.IsValid("v" + version) {
|
||||
return 0, 0, fmt.Errorf("invalid version string")
|
||||
}
|
||||
|
||||
majorMinor := strings.TrimPrefix(semver.MajorMinor("v"+version), "v")
|
||||
parts := strings.Split(majorMinor, ".")
|
||||
|
||||
major, err := strconv.ParseUint(parts[0], 10, 32)
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("invalid major version")
|
||||
}
|
||||
minor, err := strconv.ParseUint(parts[1], 10, 32)
|
||||
if err != nil {
|
||||
return 0, 0, fmt.Errorf("invalid minor version")
|
||||
}
|
||||
return uint32(major), uint32(minor), nil
|
||||
}
|
||||
|
||||
func getEnvMap(e []string) (m map[string]string) {
|
||||
@ -344,7 +357,7 @@ func getRequirements(env map[string]string, legacyImage bool) []string {
|
||||
}
|
||||
}
|
||||
if legacyImage {
|
||||
vmaj, vmin, _ := parseCudaVersion(env[envCUDAVersion])
|
||||
vmaj, vmin := parseCudaVersion(env[envCUDAVersion])
|
||||
cudaRequire := fmt.Sprintf("cuda>=%d.%d", vmaj, vmin)
|
||||
requirements = append(requirements, cudaRequire)
|
||||
}
|
||||
|
@ -10,18 +10,18 @@ import (
|
||||
func TestParseCudaVersionValid(t *testing.T) {
|
||||
var tests = []struct {
|
||||
version string
|
||||
expected [3]uint32
|
||||
expected [2]uint32
|
||||
}{
|
||||
{"0", [3]uint32{0, 0, 0}},
|
||||
{"8", [3]uint32{8, 0, 0}},
|
||||
{"7.5", [3]uint32{7, 5, 0}},
|
||||
{"9.0.116", [3]uint32{9, 0, 116}},
|
||||
{"4294967295.4294967295.4294967295", [3]uint32{4294967295, 4294967295, 4294967295}},
|
||||
{"0", [2]uint32{0, 0}},
|
||||
{"8", [2]uint32{8, 0}},
|
||||
{"7.5", [2]uint32{7, 5}},
|
||||
{"9.0.116", [2]uint32{9, 0}},
|
||||
{"4294967295.4294967295.4294967295", [2]uint32{4294967295, 4294967295}},
|
||||
}
|
||||
for i, c := range tests {
|
||||
vmaj, vmin, vpatch := parseCudaVersion(c.version)
|
||||
vmaj, vmin := parseCudaVersion(c.version)
|
||||
|
||||
version := [3]uint32{vmaj, vmin, vpatch}
|
||||
version := [2]uint32{vmaj, vmin}
|
||||
|
||||
require.Equal(t, c.expected, version, "%d: %v", i, c)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user