mirror of
				https://github.com/NVIDIA/nvidia-container-toolkit
				synced 2025-06-26 18:18:24 +00:00 
			
		
		
		
	Use semver package to parse CUDA version
This avoids the use of scanf on a user-provided string which is flagged as a security vulnerability. Signed-off-by: Evan Lezar <elezar@nvidia.com>
This commit is contained in:
		
							parent
							
								
									4ce932e7a7
								
							
						
					
					
						commit
						bdf997c761
					
				| @ -104,18 +104,31 @@ type HookState struct { | |||||||
| 	BundlePath string `json:"bundlePath"` | 	BundlePath string `json:"bundlePath"` | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func parseCudaVersion(cudaVersion string) (vmaj, vmin, vpatch uint32) { | func parseCudaVersion(cudaVersion string) (uint32, uint32) { | ||||||
| 	if _, err := fmt.Sscanf(cudaVersion, "%d.%d.%d\n", &vmaj, &vmin, &vpatch); err != nil { | 	major, minor, err := parseMajorMinorVersion(cudaVersion) | ||||||
| 		vpatch = 0 | 	if err != nil { | ||||||
| 		if _, err := fmt.Sscanf(cudaVersion, "%d.%d\n", &vmaj, &vmin); err != nil { | 		log.Panicln("invalid CUDA Version", cudaVersion, err) | ||||||
| 			vmin = 0 | 	} | ||||||
| 			if _, err := fmt.Sscanf(cudaVersion, "%d\n", &vmaj); err != nil { | 	return major, minor | ||||||
| 				log.Panicln("invalid CUDA version:", cudaVersion) | } | ||||||
| 			} | 
 | ||||||
| 		} | func parseMajorMinorVersion(version string) (uint32, uint32, error) { | ||||||
|  | 	if !semver.IsValid("v" + version) { | ||||||
|  | 		return 0, 0, fmt.Errorf("invalid version string") | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	return | 	majorMinor := strings.TrimPrefix(semver.MajorMinor("v"+version), "v") | ||||||
|  | 	parts := strings.Split(majorMinor, ".") | ||||||
|  | 
 | ||||||
|  | 	major, err := strconv.ParseUint(parts[0], 10, 32) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return 0, 0, fmt.Errorf("invalid major version") | ||||||
|  | 	} | ||||||
|  | 	minor, err := strconv.ParseUint(parts[1], 10, 32) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return 0, 0, fmt.Errorf("invalid minor version") | ||||||
|  | 	} | ||||||
|  | 	return uint32(major), uint32(minor), nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func getEnvMap(e []string) (m map[string]string) { | func getEnvMap(e []string) (m map[string]string) { | ||||||
| @ -344,7 +357,7 @@ func getRequirements(env map[string]string, legacyImage bool) []string { | |||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 	if legacyImage { | 	if legacyImage { | ||||||
| 		vmaj, vmin, _ := parseCudaVersion(env[envCUDAVersion]) | 		vmaj, vmin := parseCudaVersion(env[envCUDAVersion]) | ||||||
| 		cudaRequire := fmt.Sprintf("cuda>=%d.%d", vmaj, vmin) | 		cudaRequire := fmt.Sprintf("cuda>=%d.%d", vmaj, vmin) | ||||||
| 		requirements = append(requirements, cudaRequire) | 		requirements = append(requirements, cudaRequire) | ||||||
| 	} | 	} | ||||||
|  | |||||||
| @ -10,18 +10,18 @@ import ( | |||||||
| func TestParseCudaVersionValid(t *testing.T) { | func TestParseCudaVersionValid(t *testing.T) { | ||||||
| 	var tests = []struct { | 	var tests = []struct { | ||||||
| 		version  string | 		version  string | ||||||
| 		expected [3]uint32 | 		expected [2]uint32 | ||||||
| 	}{ | 	}{ | ||||||
| 		{"0", [3]uint32{0, 0, 0}}, | 		{"0", [2]uint32{0, 0}}, | ||||||
| 		{"8", [3]uint32{8, 0, 0}}, | 		{"8", [2]uint32{8, 0}}, | ||||||
| 		{"7.5", [3]uint32{7, 5, 0}}, | 		{"7.5", [2]uint32{7, 5}}, | ||||||
| 		{"9.0.116", [3]uint32{9, 0, 116}}, | 		{"9.0.116", [2]uint32{9, 0}}, | ||||||
| 		{"4294967295.4294967295.4294967295", [3]uint32{4294967295, 4294967295, 4294967295}}, | 		{"4294967295.4294967295.4294967295", [2]uint32{4294967295, 4294967295}}, | ||||||
| 	} | 	} | ||||||
| 	for i, c := range tests { | 	for i, c := range tests { | ||||||
| 		vmaj, vmin, vpatch := parseCudaVersion(c.version) | 		vmaj, vmin := parseCudaVersion(c.version) | ||||||
| 
 | 
 | ||||||
| 		version := [3]uint32{vmaj, vmin, vpatch} | 		version := [2]uint32{vmaj, vmin} | ||||||
| 
 | 
 | ||||||
| 		require.Equal(t, c.expected, version, "%d: %v", i, c) | 		require.Equal(t, c.expected, version, "%d: %v", i, c) | ||||||
| 	} | 	} | ||||||
|  | |||||||
		Loading…
	
		Reference in New Issue
	
	Block a user