mirror of
				https://github.com/NVIDIA/nvidia-container-toolkit
				synced 2025-06-26 18:18:24 +00:00 
			
		
		
		
	Add CUDA image abstraction
This change adds a CUDA image abstraction that encapsulates the queries performed on a container image (e.g. envvars) to check certain CUDA properties. Signed-off-by: Evan Lezar <elezar@nvidia.com>
This commit is contained in:
		
							parent
							
								
									2e319b5b08
								
							
						
					
					
						commit
						8f0e1906c2
					
				
							
								
								
									
										143
									
								
								internal/config/image/cuda_image.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										143
									
								
								internal/config/image/cuda_image.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,143 @@ | |||||||
|  | /** | ||||||
|  | # Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved. | ||||||
|  | # | ||||||
|  | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  | # you may not use this file except in compliance with the License. | ||||||
|  | # You may obtain a copy of the License at | ||||||
|  | # | ||||||
|  | #     http://www.apache.org/licenses/LICENSE-2.0
 | ||||||
|  | # | ||||||
|  | # Unless required by applicable law or agreed to in writing, software | ||||||
|  | # distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  | # See the License for the specific language governing permissions and | ||||||
|  | # limitations under the License. | ||||||
|  | **/ | ||||||
|  | 
 | ||||||
|  | package image | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 	"strconv" | ||||||
|  | 	"strings" | ||||||
|  | 
 | ||||||
|  | 	"github.com/opencontainers/runtime-spec/specs-go" | ||||||
|  | 	"golang.org/x/mod/semver" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | const ( | ||||||
|  | 	envCUDAVersion      = "CUDA_VERSION" | ||||||
|  | 	envNVRequirePrefix  = "NVIDIA_REQUIRE_" | ||||||
|  | 	envNVRequireCUDA    = envNVRequirePrefix + "CUDA" | ||||||
|  | 	envNVDisableRequire = "NVIDIA_DISABLE_REQUIRE" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | // CUDA represents a CUDA image that can be used for GPU computing. This wraps
 | ||||||
|  | // a map of environment variable to values that can be used to perform lookups
 | ||||||
|  | // such as requirements.
 | ||||||
|  | type CUDA map[string]string | ||||||
|  | 
 | ||||||
|  | // NewCUDAImageFromSpec creates a CUDA image from the input OCI runtime spec.
 | ||||||
|  | // The process environment is read (if present) to construc the CUDA Image.
 | ||||||
|  | func NewCUDAImageFromSpec(spec *specs.Spec) (CUDA, error) { | ||||||
|  | 	if spec == nil || spec.Process == nil { | ||||||
|  | 		return NewCUDAImageFromEnv(nil) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return NewCUDAImageFromEnv(spec.Process.Env) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // NewCUDAImageFromEnv creates a CUDA image from the input environment. The environment
 | ||||||
|  | // is a list of strings of the form ENVAR=VALUE.
 | ||||||
|  | func NewCUDAImageFromEnv(env []string) (CUDA, error) { | ||||||
|  | 	c := make(CUDA) | ||||||
|  | 
 | ||||||
|  | 	for _, e := range env { | ||||||
|  | 		parts := strings.SplitN(e, "=", 2) | ||||||
|  | 		if len(parts) != 2 { | ||||||
|  | 			return nil, fmt.Errorf("invalid environment variable: %v", e) | ||||||
|  | 		} | ||||||
|  | 		c[parts[0]] = parts[1] | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return c, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // IsLegacy returns whether the associated CUDA image is a "legacy" image. An
 | ||||||
|  | // image is considered legacy if it has a CUDA_VERSION environment variable defined
 | ||||||
|  | // and no NVIDIA_REQUIRE_CUDA environment variable defined.
 | ||||||
|  | func (i CUDA) IsLegacy() bool { | ||||||
|  | 	legacyCudaVersion := i[envCUDAVersion] | ||||||
|  | 	cudaRequire := i[envNVRequireCUDA] | ||||||
|  | 	return len(legacyCudaVersion) > 0 && len(cudaRequire) == 0 | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // GetRequirements returns the requirements fomr all NVIDIA_REQUIRE_ environment
 | ||||||
|  | // variables.
 | ||||||
|  | func (i CUDA) GetRequirements() ([]string, error) { | ||||||
|  | 	// TODO: We need not process this if disable require is set, but this will be done
 | ||||||
|  | 	// in a single follow-up to ensure that the behavioural change is accurately captured.
 | ||||||
|  | 	// if i.HasDisableRequire() {
 | ||||||
|  | 	// 	return nil, nil
 | ||||||
|  | 	// }
 | ||||||
|  | 
 | ||||||
|  | 	// All variables with the "NVIDIA_REQUIRE_" prefix are passed to nvidia-container-cli
 | ||||||
|  | 	var requirements []string | ||||||
|  | 	for name, value := range i { | ||||||
|  | 		if strings.HasPrefix(name, envNVRequirePrefix) { | ||||||
|  | 			requirements = append(requirements, value) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	if i.IsLegacy() { | ||||||
|  | 		v, err := i.legacyVersion() | ||||||
|  | 		if err != nil { | ||||||
|  | 			return nil, fmt.Errorf("failed to get version: %v", err) | ||||||
|  | 		} | ||||||
|  | 		cudaRequire := fmt.Sprintf("cuda>=%s", v) | ||||||
|  | 		requirements = append(requirements, cudaRequire) | ||||||
|  | 	} | ||||||
|  | 	return requirements, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // HasDisableRequire checks for the value of the NVIDIA_DISABLE_REQUIRE. If set
 | ||||||
|  | // to a valid (true) boolean value this can be used to disable the requirement checks
 | ||||||
|  | func (i CUDA) HasDisableRequire() bool { | ||||||
|  | 	if disable, exists := i[envNVDisableRequire]; exists { | ||||||
|  | 		// i.logger.Debugf("NVIDIA_DISABLE_REQUIRE=%v; skipping requirement checks", disable)
 | ||||||
|  | 		d, _ := strconv.ParseBool(disable) | ||||||
|  | 		return d | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return false | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (i CUDA) legacyVersion() (string, error) { | ||||||
|  | 	majorMinor, err := parseMajorMinorVersion(i[envCUDAVersion]) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return "", fmt.Errorf("invalid CUDA version: %v", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return majorMinor, nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func parseMajorMinorVersion(version string) (string, error) { | ||||||
|  | 	vVersion := "v" + strings.TrimPrefix(version, "v") | ||||||
|  | 
 | ||||||
|  | 	if !semver.IsValid(vVersion) { | ||||||
|  | 		return "", fmt.Errorf("invalid version string") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	majorMinor := strings.TrimPrefix(semver.MajorMinor(vVersion), "v") | ||||||
|  | 	parts := strings.Split(majorMinor, ".") | ||||||
|  | 
 | ||||||
|  | 	var err error | ||||||
|  | 	_, err = strconv.ParseUint(parts[0], 10, 32) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return "", fmt.Errorf("invalid major version") | ||||||
|  | 	} | ||||||
|  | 	_, err = strconv.ParseUint(parts[1], 10, 32) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return "", fmt.Errorf("invalid minor version") | ||||||
|  | 	} | ||||||
|  | 	return majorMinor, nil | ||||||
|  | } | ||||||
							
								
								
									
										71
									
								
								internal/config/image/cuda_image_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										71
									
								
								internal/config/image/cuda_image_test.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,71 @@ | |||||||
|  | /** | ||||||
|  | # Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved. | ||||||
|  | # | ||||||
|  | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  | # you may not use this file except in compliance with the License. | ||||||
|  | # You may obtain a copy of the License at | ||||||
|  | # | ||||||
|  | #     http://www.apache.org/licenses/LICENSE-2.0
 | ||||||
|  | # | ||||||
|  | # Unless required by applicable law or agreed to in writing, software | ||||||
|  | # distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  | # See the License for the specific language governing permissions and | ||||||
|  | # limitations under the License. | ||||||
|  | **/ | ||||||
|  | 
 | ||||||
|  | package image | ||||||
|  | 
 | ||||||
|  | import ( | ||||||
|  | 	"testing" | ||||||
|  | 
 | ||||||
|  | 	"github.com/stretchr/testify/require" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | func TestParseMajorMinorVersionValid(t *testing.T) { | ||||||
|  | 	var tests = []struct { | ||||||
|  | 		version  string | ||||||
|  | 		expected string | ||||||
|  | 	}{ | ||||||
|  | 		{"0", "0.0"}, | ||||||
|  | 		{"8", "8.0"}, | ||||||
|  | 		{"7.5", "7.5"}, | ||||||
|  | 		{"9.0.116", "9.0"}, | ||||||
|  | 		{"4294967295.4294967295.4294967295", "4294967295.4294967295"}, | ||||||
|  | 		{"v11.6", "11.6"}, | ||||||
|  | 	} | ||||||
|  | 	for _, c := range tests { | ||||||
|  | 		t.Run(c.version, func(t *testing.T) { | ||||||
|  | 			version, err := parseMajorMinorVersion(c.version) | ||||||
|  | 
 | ||||||
|  | 			require.NoError(t, err) | ||||||
|  | 			require.Equal(t, c.expected, version) | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestParseMajorMinorVersionInvalid(t *testing.T) { | ||||||
|  | 	var tests = []string{ | ||||||
|  | 		"foo", | ||||||
|  | 		"foo.5.10", | ||||||
|  | 		"9.0.116.50", | ||||||
|  | 		"9.0.116foo", | ||||||
|  | 		"7.foo", | ||||||
|  | 		"9.0.bar", | ||||||
|  | 		"9.4294967296", | ||||||
|  | 		"9.0.116.", | ||||||
|  | 		"9..0", | ||||||
|  | 		"9.", | ||||||
|  | 		".5.10", | ||||||
|  | 		"-9", | ||||||
|  | 		"+9", | ||||||
|  | 		"-9.1.116", | ||||||
|  | 		"-9.-1.-116", | ||||||
|  | 	} | ||||||
|  | 	for _, c := range tests { | ||||||
|  | 		t.Run(c, func(t *testing.T) { | ||||||
|  | 			_, err := parseMajorMinorVersion(c) | ||||||
|  | 			require.Error(t, err) | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | } | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user