Compare commits

..

1 Commits

Author SHA1 Message Date
Carlos Eduardo Arango Gutierrez
9674787e7e [no-relnote] Add E2E for libnvidia-container
Some checks failed
CI Pipeline / code-scanning (push) Has been cancelled
CI Pipeline / variables (push) Has been cancelled
CI Pipeline / golang (push) Has been cancelled
CI Pipeline / image (push) Has been cancelled
CI Pipeline / e2e-test (push) Has been cancelled
Signed-off-by: Carlos Eduardo Arango Gutierrez <eduardoa@nvidia.com>
2025-06-05 18:18:04 +02:00
130 changed files with 4037 additions and 12684 deletions

View File

@@ -22,7 +22,15 @@ variables:
BUILD_MULTI_ARCH_IMAGES: "true" BUILD_MULTI_ARCH_IMAGES: "true"
stages: stages:
- pull - trigger
- image
- lint
- go-checks
- go-build
- unit-tests
- package-build
- image-build
- test
- scan - scan
- release - release
- sign - sign
@@ -45,6 +53,108 @@ workflow:
# We then add all the regular triggers # We then add all the regular triggers
- !reference [.pipeline-trigger-rules, rules] - !reference [.pipeline-trigger-rules, rules]
# .main-or-manual is a shared rule set for jobs (distributions or
# architectures) that are not required on every build. It reuses the regular
# pipeline trigger rules and makes the job manual on scheduled pipelines.
.main-or-manual:
  rules:
    - !reference [.pipeline-trigger-rules, rules]
    - if: '$CI_PIPELINE_SOURCE == "schedule"'
      when: manual
# trigger-pipeline adds a manually triggered job to the pipeline on merge
# requests. The merge request rule sets allow_failure to false; in every other
# case the job always runs.
trigger-pipeline:
  stage: trigger
  script:
    - 'echo "starting pipeline"'
  rules:
    - !reference [.main-or-manual, rules]
    - if: '$CI_PIPELINE_SOURCE == "merge_request_event"'
      when: manual
      allow_failure: false
    - when: always
# Distribution targets. Each hidden job sets DIST for the jobs that extend it.
# centos7 and ubi8 additionally carry the .main-or-manual rules.
.dist-centos7:
  rules:
    - !reference [.main-or-manual, rules]
  variables:
    DIST: "centos7"

.dist-centos8:
  variables:
    DIST: "centos8"

.dist-ubi8:
  rules:
    - !reference [.main-or-manual, rules]
  variables:
    DIST: "ubi8"

.dist-ubuntu18.04:
  variables:
    DIST: "ubuntu18.04"

.dist-ubuntu20.04:
  variables:
    DIST: "ubuntu20.04"

.dist-packaging:
  variables:
    DIST: "packaging"
# Architecture targets. Each hidden job sets ARCH for the jobs that extend it.
# ppc64le additionally carries the .main-or-manual rules.
.arch-aarch64:
  variables:
    ARCH: "aarch64"

.arch-amd64:
  variables:
    ARCH: "amd64"

.arch-arm64:
  variables:
    ARCH: "arm64"

.arch-ppc64le:
  rules:
    - !reference [.main-or-manual, rules]
  variables:
    ARCH: "ppc64le"

.arch-x86_64:
  variables:
    ARCH: "x86_64"
# Platform targets. Each hidden job sets PLATFORM for the jobs that extend it.
.platform-amd64:
  variables:
    PLATFORM: "linux/amd64"

.platform-arm64:
  variables:
    PLATFORM: "linux/arm64"
# Define test helpers
# .integration is the base for the integration test jobs: it logs in to the CI
# registry, pulls the image built for this commit and DIST, and runs the
# matching test-${DIST} make target.
.integration:
  stage: test
  variables:
    IMAGE_NAME: "${CI_REGISTRY_IMAGE}/container-toolkit"
    VERSION: "${CI_COMMIT_SHORT_SHA}"
  before_script:
    - apk add --no-cache make bash jq
    # The password is passed on stdin instead of with -p so that it does not
    # appear in the docker command line.
    - 'echo "${CI_REGISTRY_PASSWORD}" | docker login -u "${CI_REGISTRY_USER}" --password-stdin "${CI_REGISTRY}"'
    - docker pull "${IMAGE_NAME}:${VERSION}-${DIST}"
  script:
    - make -f deployments/container/Makefile test-${DIST}
# Test targets.
# test-packaging runs the integration steps with DIST=packaging once the
# packaging image has been built.
test-packaging:
  extends:
    - .integration
    - .dist-packaging
  needs:
    - image-packaging
# Download the regctl binary for use in the release steps # Download the regctl binary for use in the release steps
.regctl-setup: .regctl-setup:
before_script: before_script:
@@ -54,3 +164,84 @@ workflow:
- curl -sSLo bin/regctl https://github.com/regclient/regclient/releases/download/${REGCTL_VERSION}/regctl-linux-amd64 - curl -sSLo bin/regctl https://github.com/regclient/regclient/releases/download/${REGCTL_VERSION}/regctl-linux-amd64
- chmod a+x bin/regctl - chmod a+x bin/regctl
- export PATH=$(pwd)/bin:${PATH} - export PATH=$(pwd)/bin:${PATH}
# .release forms the base of the deployment jobs which push images to the CI registry.
# This is extended with the version to be deployed (e.g. the SHA or TAG) and the
# target os.
.release:
  stage: release
  variables:
    # Define the source image for the release
    IMAGE_NAME: "${CI_REGISTRY_IMAGE}/container-toolkit"
    VERSION: "${CI_COMMIT_SHORT_SHA}"
    # OUT_IMAGE_VERSION is overridden for external releases
    OUT_IMAGE_VERSION: "${CI_COMMIT_SHORT_SHA}"
  before_script:
    - !reference [.regctl-setup, before_script]
    # We ensure that the components of the output image are set:
    - 'echo Image Name: ${OUT_IMAGE_NAME} ; [[ -n "${OUT_IMAGE_NAME}" ]] || exit 1'
    - 'echo Version: ${OUT_IMAGE_VERSION} ; [[ -n "${OUT_IMAGE_VERSION}" ]] || exit 1'
    - apk add --no-cache make bash
  script:
    # Log in to the "output" registry, tag the image and push the image
    - 'echo "Logging in to CI registry ${CI_REGISTRY}"'
    - regctl registry login "${CI_REGISTRY}" -u "${CI_REGISTRY_USER}" -p "${CI_REGISTRY_PASSWORD}"
    # The registry names are quoted so that an empty or unset value still
    # yields a well-formed comparison instead of a `[` syntax error.
    - '[ "${CI_REGISTRY}" = "${OUT_REGISTRY}" ] || echo "Logging in to output registry ${OUT_REGISTRY}"'
    - '[ "${CI_REGISTRY}" = "${OUT_REGISTRY}" ] || regctl registry login "${OUT_REGISTRY}" -u "${OUT_REGISTRY_USER}" -p "${OUT_REGISTRY_TOKEN}"'
    # Since OUT_IMAGE_NAME and OUT_IMAGE_VERSION are set, this will push the CI image to the
    # Target
    - make -f deployments/container/Makefile push-${DIST}
# .release:staging pushes an image to an internal "staging" repository.
# It is triggered for all pipelines (i.e. not only tags) so that the pipeline
# steps are exercised outside of the release process.
# The output registry, credentials and image name are taken from the NGC_*
# variables.
.release:staging:
  extends:
    - .release
  variables:
    OUT_REGISTRY: "${NGC_REGISTRY}"
    OUT_REGISTRY_USER: "${NGC_REGISTRY_USER}"
    OUT_REGISTRY_TOKEN: "${NGC_REGISTRY_TOKEN}"
    OUT_IMAGE_NAME: "${NGC_REGISTRY_STAGING_IMAGE_NAME}"
# .release:external pushes an image to an external repository.
# Tagged commits are published under the tag name; commits on the branch named
# by RELEASE_DEVEL_BRANCH are published as a development image under
# DEVEL_RELEASE_IMAGE_VERSION.
.release:external:
  extends:
    - .release
  variables:
    FORCE_PUBLISH_IMAGES: "yes"
  rules:
    - if: '$CI_COMMIT_TAG'
      variables:
        OUT_IMAGE_VERSION: "${CI_COMMIT_TAG}"
    - if: '$CI_COMMIT_BRANCH == $RELEASE_DEVEL_BRANCH'
      variables:
        OUT_IMAGE_VERSION: "${DEVEL_RELEASE_IMAGE_VERSION}"
# Staging release jobs, one per distribution.
# ubi8 is released as soon as its image is built; ubuntu20.04 and packaging
# wait for their test jobs.
release:staging-ubi8:
  extends:
    - .release:staging
    - .dist-ubi8
  needs:
    - image-ubi8

release:staging-ubuntu20.04:
  extends:
    - .release:staging
    - .dist-ubuntu20.04
  needs:
    - test-toolkit-ubuntu20.04
    - test-containerd-ubuntu20.04
    - test-crio-ubuntu20.04
    - test-docker-ubuntu20.04

release:staging-packaging:
  extends:
    - .release:staging
    - .dist-packaging
  needs:
    - test-packaging

View File

@@ -72,7 +72,7 @@ jobs:
env: env:
E2E_INSTALL_CTK: "true" E2E_INSTALL_CTK: "true"
E2E_IMAGE_NAME: ghcr.io/nvidia/container-toolkit E2E_IMAGE_NAME: ghcr.io/nvidia/container-toolkit
E2E_IMAGE_TAG: ${{ inputs.version }} E2E_IMAGE_TAG: ${{ inputs.version }}-ubuntu20.04
E2E_SSH_USER: ${{ secrets.E2E_SSH_USER }} E2E_SSH_USER: ${{ secrets.E2E_SSH_USER }}
E2E_SSH_HOST: ${{ steps.holodeck_public_dns_name.outputs.result }} E2E_SSH_HOST: ${{ steps.holodeck_public_dns_name.outputs.result }}
run: | run: |

View File

@@ -80,8 +80,14 @@ jobs:
strategy: strategy:
matrix: matrix:
dist: dist:
- ubi9 - ubuntu20.04
- ubi8
- packaging - packaging
ispr:
- ${{ github.ref_name != 'main' && !startsWith( github.ref_name, 'release-' ) }}
exclude:
- ispr: true
dist: ubi8
needs: packages needs: packages
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4

228
.gitlab-ci.yml Normal file
View File

@@ -0,0 +1,228 @@
# Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
include:
- .common-ci.yml
# Package build helpers.
# .multi-arch-build installs the build tools and, unless SKIP_QEMU_SETUP is
# set, runs the multiarch/qemu-user-static container to set up QEMU.
.multi-arch-build:
  before_script:
    - apk add --no-cache coreutils build-base sed git bash make
    - >-
      [[ -n "${SKIP_QEMU_SETUP}" ]] || docker run --rm --privileged multiarch/qemu-user-static --reset -p yes -c yes
# .package-artifacts names the per-pipeline artifact bundle and the directory
# (under the project directory) that holds it.
.package-artifacts:
  variables:
    ARTIFACTS_NAME: "toolkit-container-${CI_PIPELINE_ID}"
    ARTIFACTS_ROOT: "toolkit-container-${CI_PIPELINE_ID}"
    DIST_DIR: "${CI_PROJECT_DIR}/${ARTIFACTS_ROOT}"
# .package-build is the base for the per-distribution package jobs. It builds
# the ${DIST}-${ARCH} packages after the meta packages and publishes the
# artifact directory.
.package-build:
  extends:
    - .multi-arch-build
    - .package-artifacts
  stage: package-build
  timeout: 3h
  needs:
    - job: package-meta-packages
      artifacts: true
  script:
    - ./scripts/build-packages.sh ${DIST}-${ARCH}
  artifacts:
    name: "${ARTIFACTS_NAME}"
    paths:
      - "${ARTIFACTS_ROOT}"
# Package build targets.
# package-meta-packages runs the build script once per packaging format
# (deb, rpm) with both SKIP_* variables set to "yes".
package-meta-packages:
  extends:
    - .package-artifacts
  stage: package-build
  variables:
    SKIP_LIBNVIDIA_CONTAINER: "yes"
    SKIP_NVIDIA_CONTAINER_TOOLKIT: "yes"
  parallel:
    matrix:
      - PACKAGING:
          - deb
          - rpm
  before_script:
    - apk add --no-cache coreutils build-base sed git bash make
  script:
    - ./scripts/build-packages.sh ${PACKAGING}
  artifacts:
    name: "${ARTIFACTS_NAME}"
    paths:
      - "${ARTIFACTS_ROOT}"
# Per-distribution, per-architecture package jobs. Each one combines the
# .package-build base with one .dist-* and one .arch-* target.
package-centos7-aarch64:
  extends:
    - .package-build
    - .dist-centos7
    - .arch-aarch64

package-centos7-x86_64:
  extends:
    - .package-build
    - .dist-centos7
    - .arch-x86_64

package-centos8-ppc64le:
  extends:
    - .package-build
    - .dist-centos8
    - .arch-ppc64le

package-ubuntu18.04-amd64:
  extends:
    - .package-build
    - .dist-ubuntu18.04
    - .arch-amd64

package-ubuntu18.04-arm64:
  extends:
    - .package-build
    - .dist-ubuntu18.04
    - .arch-arm64

package-ubuntu18.04-ppc64le:
  extends:
    - .package-build
    - .dist-ubuntu18.04
    - .arch-ppc64le
# .buildx-setup downloads the pinned docker buildx CLI plugin, creates a
# builder for linux/amd64 and linux/arm64 and, unless SKIP_QEMU_SETUP is set,
# runs the multiarch/qemu-user-static container to set up QEMU.
.buildx-setup:
  before_script:
    - export BUILDX_VERSION=v0.6.3
    - apk add --no-cache curl
    - mkdir -p ~/.docker/cli-plugins
    - >-
      curl -sSLo ~/.docker/cli-plugins/docker-buildx "https://github.com/docker/buildx/releases/download/${BUILDX_VERSION}/buildx-${BUILDX_VERSION}.linux-amd64"
    - chmod a+x ~/.docker/cli-plugins/docker-buildx
    - docker buildx create --use --platform=linux/amd64,linux/arm64
    - >-
      [[ -n "${SKIP_QEMU_SETUP}" ]] || docker run --rm --privileged multiarch/qemu-user-static --reset -p yes
# Define the image build targets
# .image-build is the base for the image jobs: it sets up buildx, logs in to
# the CI registry and runs the build-${DIST} make target with PUSH_ON_BUILD
# set to "true".
.image-build:
  stage: image-build
  variables:
    IMAGE_NAME: "${CI_REGISTRY_IMAGE}/container-toolkit"
    VERSION: "${CI_COMMIT_SHORT_SHA}"
    PUSH_ON_BUILD: "true"
  before_script:
    - !reference [.buildx-setup, before_script]
    - apk add --no-cache bash make git
    - 'echo "Logging in to CI registry ${CI_REGISTRY}"'
    # The password is passed on stdin instead of with -p so that it does not
    # appear in the docker command line.
    - 'echo "${CI_REGISTRY_PASSWORD}" | docker login -u "${CI_REGISTRY_USER}" --password-stdin "${CI_REGISTRY}"'
  script:
    - make -f deployments/container/Makefile build-${DIST}
# image-ubi8 builds the ubi8 image once both centos7 package jobs are done.
image-ubi8:
  extends:
    - .image-build
    - .package-artifacts
    - .dist-ubi8
  needs:
    # Note: The ubi8 image uses the centos7 packages
    - package-centos7-aarch64
    - package-centos7-x86_64
# image-ubuntu20.04 builds the ubuntu20.04 image from the ubuntu18.04
# packages; the ppc64le package job is an optional dependency.
image-ubuntu20.04:
  extends:
    - .image-build
    - .package-artifacts
    - .dist-ubuntu20.04
  needs:
    - package-ubuntu18.04-amd64
    - package-ubuntu18.04-arm64
    - job: package-ubuntu18.04-ppc64le
      optional: true
# image-packaging (DIST=packaging) creates an image containing all built
# packages. Only the ubuntu18.04 amd64 and arm64 package jobs are mandatory;
# every other package job listed here is optional.
image-packaging:
  extends:
    - .image-build
    - .package-artifacts
    - .dist-packaging
  needs:
    - job: package-ubuntu18.04-amd64
    - job: package-ubuntu18.04-arm64
    - job: package-amazonlinux2-aarch64
      optional: true
    - job: package-amazonlinux2-x86_64
      optional: true
    - job: package-centos7-aarch64
      optional: true
    - job: package-centos7-x86_64
      optional: true
    - job: package-centos8-ppc64le
      optional: true
    - job: package-debian10-amd64
      optional: true
    - job: package-opensuse-leap15.1-x86_64
      optional: true
    - job: package-ubuntu18.04-ppc64le
      optional: true
# Define publish test helpers
# Each helper extends .integration and selects the test cases to run through
# TEST_CASES.

# NOTE(review): test-toolkit-ubuntu20.04 extends .test:toolkit, which had no
# definition alongside the other helpers. The TEST_CASES value follows the
# naming of the sibling helpers - confirm it against the test Makefile.
.test:toolkit:
  extends:
    - .integration
  variables:
    TEST_CASES: "toolkit"

.test:docker:
  extends:
    - .integration
  variables:
    TEST_CASES: "docker"

.test:containerd:
  # TODO: The containerd tests fail due to issues with SIGHUP.
  # Until this is resolved we retry up to twice and allow failure here.
  retry: 2
  allow_failure: true
  extends:
    - .integration
  variables:
    TEST_CASES: "containerd"

.test:crio:
  extends:
    - .integration
  variables:
    TEST_CASES: "crio"
# Test targets for the ubuntu20.04 image. Each job pairs one .test:* helper
# with the ubuntu20.04 distribution and waits for image-ubuntu20.04.
test-toolkit-ubuntu20.04:
  extends:
    - .test:toolkit
    - .dist-ubuntu20.04
  needs:
    - image-ubuntu20.04

test-containerd-ubuntu20.04:
  extends:
    - .test:containerd
    - .dist-ubuntu20.04
  needs:
    - image-ubuntu20.04

test-crio-ubuntu20.04:
  extends:
    - .test:crio
    - .dist-ubuntu20.04
  needs:
    - image-ubuntu20.04

test-docker-ubuntu20.04:
  extends:
    - .test:docker
    - .dist-ubuntu20.04
  needs:
    - image-ubuntu20.04

View File

@@ -39,62 +39,19 @@ variables:
KITMAKER_RELEASE_FOLDER: "kitmaker" KITMAKER_RELEASE_FOLDER: "kitmaker"
PACKAGE_ARCHIVE_RELEASE_FOLDER: "releases" PACKAGE_ARCHIVE_RELEASE_FOLDER: "releases"
# .copy-images copies the required application and packaging images from the .image-pull:
# IN_IMAGE="${IN_IMAGE_NAME}:${IN_IMAGE_TAG}${TAG_SUFFIX}" stage: image-build
# to
# OUT_IMAGE="${OUT_IMAGE_NAME}:${OUT_IMAGE_TAG}${TAG_SUFFIX}"
# The script also logs into IN_REGISTRY and OUT_REGISTRY using the supplied
# username and tokens.
.copy-images:
parallel:
matrix:
- TAG_SUFFIX: ["", "-packaging"]
before_script:
- !reference [.regctl-setup, before_script]
- apk add --no-cache make bash
variables:
REGCTL: regctl
script:
- |
if [ -n ${IN_REGISTRY} ] && [ -n ${IN_REGISTRY_USER} ]; then
echo "Logging in to ${IN_REGISTRY}"
${REGCTL} registry login "${IN_REGISTRY}" -u "${IN_REGISTRY_USER}" -p "${IN_REGISTRY_TOKEN}" || exit 1
fi
if [ -n ${OUT_REGISTRY} ] && [ -n ${OUT_REGISTRY_USER} ] && [ "${IN_REGISTRY}" != "${OUT_REGISTRY}" ]; then
echo "Logging in to ${OUT_REGISTRY}"
${REGCTL} registry login "${OUT_REGISTRY}" -u "${OUT_REGISTRY_USER}" -p "${OUT_REGISTRY_TOKEN}" || exit 1
fi
export IN_IMAGE="${IN_IMAGE_NAME}:${IN_IMAGE_TAG}${TAG_SUFFIX}"
export OUT_IMAGE="${OUT_IMAGE_NAME}:${OUT_IMAGE_TAG}${TAG_SUFFIX}"
echo "Copying ${IN_IMAGE} to ${OUT_IMAGE}"
${REGCTL} image copy ${IN_IMAGE} ${OUT_IMAGE}
# pull-images pulls images from the public CI registry to the internal CI registry.
pull-images:
extends:
- .copy-images
stage: pull
variables: variables:
IN_REGISTRY: "${STAGING_REGISTRY}" IN_REGISTRY: "${STAGING_REGISTRY}"
IN_IMAGE_NAME: ${STAGING_REGISTRY}/container-toolkit IN_IMAGE_NAME: container-toolkit
IN_IMAGE_TAG: "${STAGING_VERSION}" IN_VERSION: "${STAGING_VERSION}"
OUT_REGISTRY: "${CI_REGISTRY}"
OUT_REGISTRY_USER: "${CI_REGISTRY_USER}" OUT_REGISTRY_USER: "${CI_REGISTRY_USER}"
OUT_REGISTRY_TOKEN: "${CI_REGISTRY_PASSWORD}" OUT_REGISTRY_TOKEN: "${CI_REGISTRY_PASSWORD}"
OUT_REGISTRY: "${CI_REGISTRY}"
OUT_IMAGE_NAME: "${CI_REGISTRY_IMAGE}/container-toolkit" OUT_IMAGE_NAME: "${CI_REGISTRY_IMAGE}/container-toolkit"
OUT_IMAGE_TAG: "${CI_COMMIT_SHORT_SHA}" PUSH_MULTIPLE_TAGS: "false"
# We delay the job start to allow the public pipeline to generate the required images. # We delay the job start to allow the public pipeline to generate the required images.
rules: rules:
# If the pipeline is triggered from a tag or the WEB UI we don't delay the
# start of the pipeline.
- if: $CI_COMMIT_TAG || $CI_PIPELINE_SOURCE == "web"
# If the pipeline is triggered through other means (i.e. a branch or MR)
# we add a 30 minute delay to ensure that the images are available in the
# public CI registry.
- when: delayed - when: delayed
start_in: 30 minutes start_in: 30 minutes
timeout: 30 minutes timeout: 30 minutes
@@ -103,6 +60,30 @@ pull-images:
when: when:
- job_execution_timeout - job_execution_timeout
- stuck_or_timeout_failure - stuck_or_timeout_failure
before_script:
- !reference [.regctl-setup, before_script]
- apk add --no-cache make bash
- >
regctl manifest get ${IN_REGISTRY}/${IN_IMAGE_NAME}:${IN_VERSION}-${DIST} --list > /dev/null && echo "${IN_REGISTRY}/${IN_IMAGE_NAME}:${IN_VERSION}-${DIST}" || ( echo "${IN_REGISTRY}/${IN_IMAGE_NAME}:${IN_VERSION}-${DIST} does not exist" && sleep infinity )
script:
- regctl registry login "${OUT_REGISTRY}" -u "${OUT_REGISTRY_USER}" -p "${OUT_REGISTRY_TOKEN}"
- make -f deployments/container/Makefile IMAGE=${IN_REGISTRY}/${IN_IMAGE_NAME}:${IN_VERSION}-${DIST} OUT_IMAGE=${OUT_IMAGE_NAME}:${CI_COMMIT_SHORT_SHA}-${DIST} push-${DIST}
image-ubi8:
extends:
- .dist-ubi8
- .image-pull
image-ubuntu20.04:
extends:
- .dist-ubuntu20.04
- .image-pull
# The DIST=packaging target creates an image containing all built packages
image-packaging:
extends:
- .dist-packaging
- .image-pull
# We skip the integration tests for the internal CI: # We skip the integration tests for the internal CI:
.integration: .integration:
@@ -114,37 +95,27 @@ pull-images:
# The .scan step forms the base of the image scan operation performed before releasing # The .scan step forms the base of the image scan operation performed before releasing
# images. # images.
scan-images: .scan:
stage: scan stage: scan
needs:
- pull-images
image: "${PULSE_IMAGE}" image: "${PULSE_IMAGE}"
parallel:
matrix:
- TAG_SUFFIX: [""]
PLATFORM: ["linux/amd64", "linux/arm64"]
- TAG_SUFFIX: "-packaging"
PLATFORM: "linux/amd64"
variables: variables:
IMAGE: "${CI_REGISTRY_IMAGE}/container-toolkit:${CI_COMMIT_SHORT_SHA}" IMAGE: "${CI_REGISTRY_IMAGE}/container-toolkit:${CI_COMMIT_SHORT_SHA}-${DIST}"
IMAGE_ARCHIVE: "container-toolkit-${CI_JOB_ID}.tar" IMAGE_ARCHIVE: "container-toolkit-${DIST}-${ARCH}-${CI_JOB_ID}.tar"
rules: rules:
- if: $IGNORE_SCANS == "yes" - if: $SKIP_SCANS != "yes"
allow_failure: true - when: manual
- when: on_success before_script:
script: - docker login -u "${CI_REGISTRY_USER}" -p "${CI_REGISTRY_PASSWORD}" "${CI_REGISTRY}"
- | # TODO: We should specify the architecture here and scan all architectures
docker login -u "${CI_REGISTRY_USER}" -p "${CI_REGISTRY_PASSWORD}" "${CI_REGISTRY}" - docker pull --platform="${PLATFORM}" "${IMAGE}"
export SCAN_IMAGE=${IMAGE}${TAG_SUFFIX} - docker save "${IMAGE}" -o "${IMAGE_ARCHIVE}"
echo "Scanning image ${SCAN_IMAGE} for ${PLATFORM}" - AuthHeader=$(echo -n $SSA_CLIENT_ID:$SSA_CLIENT_SECRET | base64 -w0)
docker pull --platform="${PLATFORM}" "${SCAN_IMAGE}" - >
docker save "${SCAN_IMAGE}" -o "${IMAGE_ARCHIVE}"
AuthHeader=$(echo -n $SSA_CLIENT_ID:$SSA_CLIENT_SECRET | base64 -w0)
export SSA_TOKEN=$(curl --request POST --header "Authorization: Basic $AuthHeader" --header "Content-Type: application/x-www-form-urlencoded" ${SSA_ISSUER_URL} | jq ".access_token" | tr -d '"') export SSA_TOKEN=$(curl --request POST --header "Authorization: Basic $AuthHeader" --header "Content-Type: application/x-www-form-urlencoded" ${SSA_ISSUER_URL} | jq ".access_token" | tr -d '"')
if [ -z "$SSA_TOKEN" ]; then exit 1; else echo "SSA_TOKEN set!"; fi - if [ -z "$SSA_TOKEN" ]; then exit 1; else echo "SSA_TOKEN set!"; fi
script:
pulse-cli -n $NSPECT_ID --ssa $SSA_TOKEN scan -i $IMAGE_ARCHIVE -p $CONTAINER_POLICY -o - pulse-cli -n $NSPECT_ID --ssa $SSA_TOKEN scan -i $IMAGE_ARCHIVE -p $CONTAINER_POLICY -o
rm -f "${IMAGE_ARCHIVE}" - rm -f "${IMAGE_ARCHIVE}"
artifacts: artifacts:
when: always when: always
expire_in: 1 week expire_in: 1 week
@@ -155,10 +126,62 @@ scan-images:
- vulns.json - vulns.json
- policy_evaluation.json - policy_evaluation.json
upload-kitmaker-packages: # Define the scan targets
scan-ubuntu20.04-amd64:
extends:
- .dist-ubuntu20.04
- .platform-amd64
- .scan
needs:
- image-ubuntu20.04
scan-ubuntu20.04-arm64:
extends:
- .dist-ubuntu20.04
- .platform-arm64
- .scan
needs:
- image-ubuntu20.04
- scan-ubuntu20.04-amd64
scan-ubi8-amd64:
extends:
- .dist-ubi8
- .platform-amd64
- .scan
needs:
- image-ubi8
scan-ubi8-arm64:
extends:
- .dist-ubi8
- .platform-arm64
- .scan
needs:
- image-ubi8
- scan-ubi8-amd64
scan-packaging:
extends:
- .dist-packaging
- .scan
needs:
- image-packaging
# Define external release helpers
.release:ngc:
extends:
- .release:external
variables:
OUT_REGISTRY_USER: "${NGC_REGISTRY_USER}"
OUT_REGISTRY_TOKEN: "${NGC_REGISTRY_TOKEN}"
OUT_REGISTRY: "${NGC_REGISTRY}"
OUT_IMAGE_NAME: "${NGC_REGISTRY_IMAGE}"
.release:packages:
stage: release stage: release
needs: needs:
- pull-images - image-packaging
variables: variables:
VERSION: "${CI_COMMIT_SHORT_SHA}" VERSION: "${CI_COMMIT_SHORT_SHA}"
PACKAGE_REGISTRY: "${CI_REGISTRY}" PACKAGE_REGISTRY: "${CI_REGISTRY}"
@@ -176,81 +199,34 @@ upload-kitmaker-packages:
- ./scripts/release-kitmaker-artifactory.sh "${KITMAKER_ARTIFACTORY_REPO}" - ./scripts/release-kitmaker-artifactory.sh "${KITMAKER_ARTIFACTORY_REPO}"
- rm -rf ${ARTIFACTS_DIR} - rm -rf ${ARTIFACTS_DIR}
push-images-to-staging: # Define the package release targets
release:packages:kitmaker:
extends: extends:
- .copy-images - .release:packages
stage: release
release:staging-ubuntu20.04:
extends:
- .release:staging
- .dist-ubuntu20.04
needs: needs:
- scan-images - image-ubuntu20.04
variables:
IN_REGISTRY: "${CI_REGISTRY}"
IN_REGISTRY_USER: "${CI_REGISTRY_USER}"
IN_REGISTRY_TOKEN: "${CI_REGISTRY_PASSWORD}"
IN_IMAGE_NAME: "${CI_REGISTRY_IMAGE}/container-toolkit"
IN_IMAGE_TAG: "${CI_COMMIT_SHORT_SHA}"
OUT_REGISTRY: "${NGC_REGISTRY}" # Define the external release targets
OUT_REGISTRY_USER: "${NGC_REGISTRY_USER}" # Release to NGC
OUT_REGISTRY_TOKEN: "${NGC_REGISTRY_TOKEN}" release:ngc-ubuntu20.04:
OUT_IMAGE_NAME: "${NGC_STAGING_REGISTRY}/container-toolkit"
OUT_IMAGE_TAG: "${CI_COMMIT_SHORT_SHA}"
.release-images:
extends: extends:
- .copy-images - .dist-ubuntu20.04
stage: release - .release:ngc
needs:
- scan-images
- push-images-to-staging
variables:
IN_REGISTRY: "${CI_REGISTRY}"
IN_REGISTRY_USER: "${CI_REGISTRY_USER}"
IN_REGISTRY_TOKEN: "${CI_REGISTRY_PASSWORD}"
IN_IMAGE_NAME: "${CI_REGISTRY_IMAGE}/container-toolkit"
IN_IMAGE_TAG: "${CI_COMMIT_SHORT_SHA}"
OUT_REGISTRY: "${NGC_REGISTRY}" release:ngc-ubi8:
OUT_REGISTRY_USER: "${NGC_REGISTRY_USER}"
OUT_REGISTRY_TOKEN: "${NGC_REGISTRY_TOKEN}"
OUT_IMAGE_NAME: "${NGC_REGISTRY_IMAGE}"
OUT_IMAGE_TAG: "${CI_COMMIT_TAG}"
release-images-to-ngc:
extends: extends:
- .release-images - .dist-ubi8
rules: - .release:ngc
- if: $CI_COMMIT_TAG
release-images-dummy: release:ngc-packaging:
extends: extends:
- .release-images - .dist-packaging
variables: - .release:ngc
REGCTL: "echo [DUMMY] regctl"
rules:
- if: $CI_COMMIT_TAG == null || $CI_COMMIT_TAG == ""
# .sign-images forms the base of the jobs which sign images in the NGC registry.
.sign-images:
stage: sign
image: ubuntu:latest
parallel:
matrix:
- TAG_SUFFIX: ["", "-packaging"]
variables:
IMAGE_NAME: "${NGC_REGISTRY_IMAGE}"
IMAGE_TAG: "${CI_COMMIT_TAG}"
NGC_CLI: "ngc-cli/ngc"
before_script:
- !reference [.ngccli-setup, before_script]
script:
- |
# We ensure that the IMAGE_NAME and IMAGE_TAG is set
echo Image Name: ${IMAGE_NAME} && [[ -n "${IMAGE_NAME}" ]] || exit 1
echo Image Tag: ${IMAGE_TAG} && [[ -n "${IMAGE_TAG}" ]] || exit 1
export IMAGE=${IMAGE_NAME}:${IMAGE_TAG}${TAG_SUFFIX}
echo "Signing the image ${IMAGE}"
${NGC_CLI} registry image publish --source ${IMAGE} ${IMAGE} --public --discoverable --allow-guest --sign --org nvidia
# Define the external image signing steps for NGC # Define the external image signing steps for NGC
# Download the ngc cli binary for use in the sign steps # Download the ngc cli binary for use in the sign steps
@@ -268,24 +244,45 @@ release-images-dummy:
- unzip ngccli_linux.zip - unzip ngccli_linux.zip
- chmod u+x ngc-cli/ngc - chmod u+x ngc-cli/ngc
sign-ngc-images: # .sign forms the base of the deployment jobs which signs images in the CI registry.
extends: # This is extended with the image name and version to be deployed.
- .sign-images .sign:ngc:
needs: image: ubuntu:latest
- release-images-to-ngc stage: sign
rules: rules:
- if: $CI_COMMIT_TAG - if: $CI_COMMIT_TAG
variables: variables:
NGC_CLI_API_KEY: "${NGC_REGISTRY_TOKEN}" NGC_CLI_API_KEY: "${NGC_REGISTRY_TOKEN}"
IMAGE_NAME: "${NGC_REGISTRY_IMAGE}"
IMAGE_TAG: "${CI_COMMIT_TAG}-${DIST}"
retry: retry:
max: 2 max: 2
before_script:
- !reference [.ngccli-setup, before_script]
# We ensure that the IMAGE_NAME and IMAGE_TAG is set
- 'echo Image Name: ${IMAGE_NAME} && [[ -n "${IMAGE_NAME}" ]] || exit 1'
- 'echo Image Tag: ${IMAGE_TAG} && [[ -n "${IMAGE_TAG}" ]] || exit 1'
script:
- 'echo "Signing the image ${IMAGE_NAME}:${IMAGE_TAG}"'
- ngc-cli/ngc registry image publish --source ${IMAGE_NAME}:${IMAGE_TAG} ${IMAGE_NAME}:${IMAGE_TAG} --public --discoverable --allow-guest --sign --org nvidia
sign-images-dummy: sign:ngc-ubuntu20.04:
extends: extends:
- .sign-images - .dist-ubuntu20.04
- .sign:ngc
needs: needs:
- release-images-dummy - release:ngc-ubuntu20.04
variables:
NGC_CLI: "echo [DUMMY] ngc-cli/ngc" sign:ngc-ubi8:
rules: extends:
- if: $CI_COMMIT_TAG == null || $CI_COMMIT_TAG == "" - .dist-ubi8
- .sign:ngc
needs:
- release:ngc-ubi8
sign:ngc-packaging:
extends:
- .dist-packaging
- .sign:ngc
needs:
- release:ngc-packaging

View File

@@ -22,7 +22,6 @@ import (
"github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-cdi-hook/chmod" "github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-cdi-hook/chmod"
symlinks "github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-cdi-hook/create-symlinks" symlinks "github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-cdi-hook/create-symlinks"
"github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-cdi-hook/cudacompat" "github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-cdi-hook/cudacompat"
disabledevicenodemodification "github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-cdi-hook/disable-device-node-modification"
ldcache "github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-cdi-hook/update-ldcache" ldcache "github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-cdi-hook/update-ldcache"
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
) )
@@ -35,7 +34,6 @@ func New(logger logger.Interface) []*cli.Command {
symlinks.NewCommand(logger), symlinks.NewCommand(logger),
chmod.NewCommand(logger), chmod.NewCommand(logger),
cudacompat.NewCommand(logger), cudacompat.NewCommand(logger),
disabledevicenodemodification.NewCommand(logger),
} }
} }

View File

@@ -1,144 +0,0 @@
/**
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package disabledevicenodemodification
import (
"bufio"
"bytes"
"errors"
"fmt"
"io"
"os"
"strings"
"github.com/urfave/cli/v2"
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
)
// nvidiaDriverParamsPath is the NVIDIA driver params file. It is read on the
// host and used as the mount target path inside the container root.
const nvidiaDriverParamsPath = "/proc/driver/nvidia/params"

// options holds the values bound to the command line flags of the subcommand.
type options struct {
	// containerSpec is the destination of the hidden container-spec flag.
	containerSpec string
}
// NewCommand constructs the disable-device-node-modification subcommand with the specified logger.
//
// The returned command validates its flags in Before and performs the work in
// Action. Both closures share one options value, which the hidden
// container-spec flag writes into.
func NewCommand(logger logger.Interface) *cli.Command {
	opts := options{}

	return &cli.Command{
		Name:  "disable-device-node-modification",
		Usage: "Ensure that the /proc/driver/nvidia/params file present in the container does not allow device node modifications.",
		Before: func(ctx *cli.Context) error {
			return validateFlags(ctx, &opts)
		},
		Action: func(ctx *cli.Context) error {
			return run(ctx, &opts)
		},
		Flags: []cli.Flag{
			&cli.StringFlag{
				Name:        "container-spec",
				Hidden:      true,
				Usage:       "Specify the path to the OCI container spec. If empty or '-' the spec will be read from STDIN",
				Destination: &opts.containerSpec,
			},
		},
	}
}
// validateFlags is the Before hook of the subcommand. No validation is
// performed at present, so it always returns nil.
func validateFlags(_ *cli.Context, _ *options) error {
	return nil
}
// run is the Action of the subcommand.
//
// It obtains the modified contents of the host params file and returns nil
// immediately when there is nothing to modify. Otherwise it loads the
// container state from cfg.containerSpec, determines the container root, and
// creates the modified params file in the container.
func run(_ *cli.Context, cfg *options) error {
	modifiedParamsFileContents, err := getModifiedNVIDIAParamsContents()
	if err != nil {
		return fmt.Errorf("failed to get modified params file contents: %w", err)
	}
	// Empty contents mean that no modification is required.
	if len(modifiedParamsFileContents) == 0 {
		return nil
	}

	s, err := oci.LoadContainerState(cfg.containerSpec)
	if err != nil {
		return fmt.Errorf("failed to load container state: %w", err)
	}

	containerRootDirPath, err := s.GetContainerRoot()
	if err != nil {
		return fmt.Errorf("failed to determine container root: %w", err)
	}

	return createParamsFileInContainer(containerRootDirPath, modifiedParamsFileContents)
}
// getModifiedNVIDIAParamsContents opens the host params file and returns its
// modified contents.
//
// It returns (nil, nil) when the params file does not exist. Any other error
// from opening or processing the file is wrapped and returned.
func getModifiedNVIDIAParamsContents() ([]byte, error) {
	hostNvidiaParamsFile, err := os.Open(nvidiaDriverParamsPath)
	if errors.Is(err, os.ErrNotExist) {
		return nil, nil
	}
	if err != nil {
		return nil, fmt.Errorf("failed to load params file: %w", err)
	}
	defer hostNvidiaParamsFile.Close()

	modifiedContents, err := getModifiedParamsFileContentsFromReader(hostNvidiaParamsFile)
	if err != nil {
		return nil, fmt.Errorf("failed to get modified params file contents: %w", err)
	}

	return modifiedContents, nil
}
// getModifiedParamsFileContentsFromReader returns the contents of a modified params file from the specified reader.
func getModifiedParamsFileContentsFromReader(r io.Reader) ([]byte, error) {
var modified bytes.Buffer
scanner := bufio.NewScanner(r)
var requiresModification bool
for scanner.Scan() {
line := scanner.Text()
if strings.HasPrefix(line, "ModifyDeviceFiles: ") {
if line == "ModifyDeviceFiles: 0" {
return nil, nil
}
if line == "ModifyDeviceFiles: 1" {
line = "ModifyDeviceFiles: 0"
requiresModification = true
}
}
if _, err := modified.WriteString(line + "\n"); err != nil {
return nil, fmt.Errorf("failed to create output buffer: %w", err)
}
}
if err := scanner.Err(); err != nil {
return nil, fmt.Errorf("failed to read params file: %w", err)
}
if !requiresModification {
return nil, nil
}
return modified.Bytes(), nil
}

View File

@@ -1,91 +0,0 @@
/**
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package disabledevicenodemodification
import (
"bytes"
"testing"
"github.com/stretchr/testify/require"
)
// TestGetModifiedParamsFileContentsFromReader is a table-driven test for
// getModifiedParamsFileContentsFromReader. Each case feeds contents through a
// bytes.Reader and compares the returned error and the returned bytes (as a
// string) against the expected values.
func TestGetModifiedParamsFileContentsFromReader(t *testing.T) {
testCases := map[string]struct {
contents []byte
expectedError error
expectedContents []byte
}{
"no contents": {
contents: nil,
expectedError: nil,
expectedContents: nil,
},
"other contents are ignored": {
contents: []byte(`# Some other content
that we don't care about
`),
expectedError: nil,
expectedContents: nil,
},
"already zero requires no modification": {
contents: []byte("ModifyDeviceFiles: 0"),
expectedError: nil,
expectedContents: nil,
},
// The next three cases leave expectedError and expectedContents at their
// zero values, so they expect no error and empty output.
"leading spaces require no modification": {
contents: []byte(" ModifyDeviceFiles: 1"),
},
"Trailing spaces require no modification": {
contents: []byte("ModifyDeviceFiles: 1 "),
},
"Not 1 require no modification": {
contents: []byte("ModifyDeviceFiles: 11"),
},
// The remaining cases expect the setting to be rewritten to 0 and every
// output line to be newline-terminated.
"single line requires modification": {
contents: []byte("ModifyDeviceFiles: 1"),
expectedError: nil,
expectedContents: []byte("ModifyDeviceFiles: 0\n"),
},
"single line with trailing newline requires modification": {
contents: []byte("ModifyDeviceFiles: 1\n"),
expectedError: nil,
expectedContents: []byte("ModifyDeviceFiles: 0\n"),
},
"other content is maintained": {
contents: []byte(`ModifyDeviceFiles: 1
other content
that
is maintained`),
expectedError: nil,
expectedContents: []byte(`ModifyDeviceFiles: 0
other content
that
is maintained
`),
},
}
for description, tc := range testCases {
t.Run(description, func(t *testing.T) {
contents, err := getModifiedParamsFileContentsFromReader(bytes.NewReader(tc.contents))
require.EqualValues(t, tc.expectedError, err)
// Comparing as strings makes a nil result equal to an empty expectation.
require.EqualValues(t, string(tc.expectedContents), string(contents))
})
}
}

View File

@@ -1,63 +0,0 @@
//go:build linux
/**
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package disabledevicenodemodification
import (
"fmt"
"os"
"path/filepath"
"github.com/opencontainers/runc/libcontainer/utils"
"golang.org/x/sys/unix"
)
func createParamsFileInContainer(containerRootDirPath string, contents []byte) error {
tmpRoot, err := os.MkdirTemp("", "nvct-empty-dir*")
if err != nil {
return fmt.Errorf("failed to create temp root: %w", err)
}
if err := createTmpFs(tmpRoot, len(contents)); err != nil {
return fmt.Errorf("failed to create tmpfs mount for params file: %w", err)
}
modifiedParamsFile, err := os.OpenFile(filepath.Join(tmpRoot, "nvct-params"), os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0444)
if err != nil {
return fmt.Errorf("failed to open modified params file: %w", err)
}
defer modifiedParamsFile.Close()
if _, err := modifiedParamsFile.Write(contents); err != nil {
return fmt.Errorf("failed to write temporary params file: %w", err)
}
err = utils.WithProcfd(containerRootDirPath, nvidiaDriverParamsPath, func(nvidiaDriverParamsFdPath string) error {
return unix.Mount(modifiedParamsFile.Name(), nvidiaDriverParamsFdPath, "", unix.MS_BIND|unix.MS_RDONLY|unix.MS_NODEV|unix.MS_PRIVATE|unix.MS_NOSYMFOLLOW, "")
})
if err != nil {
return fmt.Errorf("failed to mount modified params file: %w", err)
}
return nil
}
// createTmpFs mounts a tmpfs filesystem at target, passing "size=<size>" as
// the mount data and no mount flags. It returns the error from unix.Mount.
func createTmpFs(target string, size int) error {
	mountData := fmt.Sprintf("size=%d", size)
	return unix.Mount("tmpfs", target, "tmpfs", 0, mountData)
}

View File

@@ -1,27 +0,0 @@
//go:build !linux
// +build !linux
/**
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package disabledevicenodemodification
import "fmt"
// createParamsFileInContainer is the variant built on non-Linux platforms.
// It ignores both arguments and always returns a "not supported" error.
func createParamsFileInContainer(containerRootDirPath string, contents []byte) error {
	unsupported := fmt.Errorf("not supported")
	return unsupported
}

View File

@@ -13,6 +13,10 @@ import (
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image" "github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
) )
const (
capSysAdmin = "CAP_SYS_ADMIN"
)
type nvidiaConfig struct { type nvidiaConfig struct {
Devices []string Devices []string
MigConfigDevices string MigConfigDevices string
@@ -99,9 +103,9 @@ func loadSpec(path string) (spec *Spec) {
return return
} }
func (s *Spec) GetCapabilities() []string { func isPrivileged(s *Spec) bool {
if s == nil || s.Process == nil || s.Process.Capabilities == nil { if s.Process.Capabilities == nil {
return nil return false
} }
var caps []string var caps []string
@@ -114,22 +118,67 @@ func (s *Spec) GetCapabilities() []string {
if err != nil { if err != nil {
log.Panicln("could not decode Process.Capabilities in OCI spec:", err) log.Panicln("could not decode Process.Capabilities in OCI spec:", err)
} }
return caps for _, c := range caps {
if c == capSysAdmin {
return true
}
}
return false
} }
// Otherwise, parse s.Process.Capabilities as: // Otherwise, parse s.Process.Capabilities as:
// github.com/opencontainers/runtime-spec/blob/v1.0.0/specs-go/config.go#L30-L54 // github.com/opencontainers/runtime-spec/blob/v1.0.0/specs-go/config.go#L30-L54
capabilities := specs.LinuxCapabilities{} process := specs.Process{
err := json.Unmarshal(*s.Process.Capabilities, &capabilities) Env: s.Process.Env,
}
err := json.Unmarshal(*s.Process.Capabilities, &process.Capabilities)
if err != nil { if err != nil {
log.Panicln("could not decode Process.Capabilities in OCI spec:", err) log.Panicln("could not decode Process.Capabilities in OCI spec:", err)
} }
return image.OCISpecCapabilities(capabilities).GetCapabilities() fullSpec := specs.Spec{
Version: *s.Version,
Process: &process,
}
return image.IsPrivileged(&fullSpec)
} }
func isPrivileged(s *Spec) bool { func getDevicesFromEnvvar(containerImage image.CUDA, swarmResourceEnvvars []string) []string {
return image.IsPrivileged(s) // We check if the image has at least one of the Swarm resource envvars defined and use this
// if specified.
for _, envvar := range swarmResourceEnvvars {
if containerImage.HasEnvvar(envvar) {
return containerImage.DevicesFromEnvvars(swarmResourceEnvvars...).List()
}
}
return containerImage.VisibleDevicesFromEnvVar()
}
func (hookConfig *hookConfig) getDevices(image image.CUDA, privileged bool) []string {
// If enabled, try and get the device list from volume mounts first
if hookConfig.AcceptDeviceListAsVolumeMounts {
devices := image.VisibleDevicesFromMounts()
if len(devices) > 0 {
return devices
}
}
// Fallback to reading from the environment variable if privileges are correct
devices := getDevicesFromEnvvar(image, hookConfig.getSwarmResourceEnvvars())
if len(devices) == 0 {
return nil
}
if privileged || hookConfig.AcceptEnvvarUnprivileged {
return devices
}
configName := hookConfig.getConfigOption("AcceptEnvvarUnprivileged")
log.Printf("Ignoring devices specified in NVIDIA_VISIBLE_DEVICES (privileged=%v, %v=%v) ", privileged, configName, hookConfig.AcceptEnvvarUnprivileged)
return nil
} }
func getMigConfigDevices(i image.CUDA) *string { func getMigConfigDevices(i image.CUDA) *string {
@@ -176,6 +225,7 @@ func (hookConfig *hookConfig) getDriverCapabilities(cudaImage image.CUDA, legacy
// We use the default driver capabilities by default. This is filtered to only include the // We use the default driver capabilities by default. This is filtered to only include the
// supported capabilities // supported capabilities
supportedDriverCapabilities := image.NewDriverCapabilities(hookConfig.SupportedDriverCapabilities) supportedDriverCapabilities := image.NewDriverCapabilities(hookConfig.SupportedDriverCapabilities)
capabilities := supportedDriverCapabilities.Intersection(image.DefaultDriverCapabilities) capabilities := supportedDriverCapabilities.Intersection(image.DefaultDriverCapabilities)
capsEnvSpecified := cudaImage.HasEnvvar(image.EnvVarNvidiaDriverCapabilities) capsEnvSpecified := cudaImage.HasEnvvar(image.EnvVarNvidiaDriverCapabilities)
@@ -201,7 +251,7 @@ func (hookConfig *hookConfig) getDriverCapabilities(cudaImage image.CUDA, legacy
func (hookConfig *hookConfig) getNvidiaConfig(image image.CUDA, privileged bool) *nvidiaConfig { func (hookConfig *hookConfig) getNvidiaConfig(image image.CUDA, privileged bool) *nvidiaConfig {
legacyImage := image.IsLegacy() legacyImage := image.IsLegacy()
devices := image.VisibleDevices() devices := hookConfig.getDevices(image, privileged)
if len(devices) == 0 { if len(devices) == 0 {
// empty devices means this is not a GPU container. // empty devices means this is not a GPU container.
return nil return nil
@@ -256,25 +306,20 @@ func (hookConfig *hookConfig) getContainerConfig() (config containerConfig) {
s := loadSpec(path.Join(b, "config.json")) s := loadSpec(path.Join(b, "config.json"))
privileged := isPrivileged(s) image, err := image.New(
i, err := image.New(
image.WithEnv(s.Process.Env), image.WithEnv(s.Process.Env),
image.WithMounts(s.Mounts), image.WithMounts(s.Mounts),
image.WithPrivileged(privileged),
image.WithDisableRequire(hookConfig.DisableRequire), image.WithDisableRequire(hookConfig.DisableRequire),
image.WithAcceptDeviceListAsVolumeMounts(hookConfig.AcceptDeviceListAsVolumeMounts),
image.WithAcceptEnvvarUnprivileged(hookConfig.AcceptEnvvarUnprivileged),
image.WithPreferredVisibleDevicesEnvVars(hookConfig.getSwarmResourceEnvvars()...),
) )
if err != nil { if err != nil {
log.Panicln(err) log.Panicln(err)
} }
privileged := isPrivileged(s)
return containerConfig{ return containerConfig{
Pid: h.Pid, Pid: h.Pid,
Rootfs: s.Root.Path, Rootfs: s.Root.Path,
Image: i, Image: image,
Nvidia: hookConfig.getNvidiaConfig(i, privileged), Nvidia: hookConfig.getNvidiaConfig(image, privileged),
} }
} }

View File

@@ -1,8 +1,10 @@
package main package main
import ( import (
"path/filepath"
"testing" "testing"
"github.com/opencontainers/runtime-spec/specs-go"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config" "github.com/NVIDIA/nvidia-container-toolkit/internal/config"
@@ -477,10 +479,7 @@ func TestGetNvidiaConfig(t *testing.T) {
t.Run(tc.description, func(t *testing.T) { t.Run(tc.description, func(t *testing.T) {
image, _ := image.New( image, _ := image.New(
image.WithEnvMap(tc.env), image.WithEnvMap(tc.env),
image.WithPrivileged(tc.privileged),
image.WithPreferredVisibleDevicesEnvVars(tc.hookConfig.getSwarmResourceEnvvars()...),
) )
// Wrap the call to getNvidiaConfig() in a closure. // Wrap the call to getNvidiaConfig() in a closure.
var cfg *nvidiaConfig var cfg *nvidiaConfig
getConfig := func() { getConfig := func() {
@@ -519,6 +518,340 @@ func TestGetNvidiaConfig(t *testing.T) {
} }
} }
// TestDeviceListSourcePriority checks which device list hookConfig.getDevices
// returns when devices are requested both through volume mounts and through
// the NVIDIA_VISIBLE_DEVICES environment variable, for combinations of the
// privileged flag and the AcceptEnvvarUnprivileged /
// AcceptDeviceListAsVolumeMounts config options.
func TestDeviceListSourcePriority(t *testing.T) {
	// gpuMounts returns a fresh mount list requesting GPU0 and GPU1 so that no
	// slice is shared between test cases.
	gpuMounts := func() []specs.Mount {
		return []specs.Mount{
			{
				Source:      "/dev/null",
				Destination: filepath.Join(image.DeviceListAsVolumeMountsRoot, "GPU0"),
			},
			{
				Source:      "/dev/null",
				Destination: filepath.Join(image.DeviceListAsVolumeMountsRoot, "GPU1"),
			},
		}
	}

	var tests = []struct {
		description        string
		mountDevices       []specs.Mount
		envvarDevices      string
		privileged         bool
		acceptUnprivileged bool
		acceptMounts       bool
		expectedDevices    []string
	}{
		{
			description:        "Mount devices, unprivileged, no accept unprivileged",
			mountDevices:       gpuMounts(),
			envvarDevices:      "GPU2,GPU3",
			privileged:         false,
			acceptUnprivileged: false,
			acceptMounts:       true,
			expectedDevices:    []string{"GPU0", "GPU1"},
		},
		{
			description:        "No mount devices, unprivileged, no accept unprivileged",
			mountDevices:       nil,
			envvarDevices:      "GPU0,GPU1",
			privileged:         false,
			acceptUnprivileged: false,
			acceptMounts:       true,
			expectedDevices:    nil,
		},
		{
			description:        "No mount devices, privileged, no accept unprivileged",
			mountDevices:       nil,
			envvarDevices:      "GPU0,GPU1",
			privileged:         true,
			acceptUnprivileged: false,
			acceptMounts:       true,
			expectedDevices:    []string{"GPU0", "GPU1"},
		},
		{
			description:        "No mount devices, unprivileged, accept unprivileged",
			mountDevices:       nil,
			envvarDevices:      "GPU0,GPU1",
			privileged:         false,
			acceptUnprivileged: true,
			acceptMounts:       true,
			expectedDevices:    []string{"GPU0", "GPU1"},
		},
		{
			description:        "Mount devices, unprivileged, accept unprivileged, no accept mounts",
			mountDevices:       gpuMounts(),
			envvarDevices:      "GPU2,GPU3",
			privileged:         false,
			acceptUnprivileged: true,
			acceptMounts:       false,
			expectedDevices:    []string{"GPU2", "GPU3"},
		},
		{
			description:        "Mount devices, unprivileged, no accept unprivileged, no accept mounts",
			mountDevices:       gpuMounts(),
			envvarDevices:      "GPU2,GPU3",
			privileged:         false,
			acceptUnprivileged: false,
			acceptMounts:       false,
			expectedDevices:    nil,
		},
	}
	for _, tc := range tests {
		t.Run(tc.description, func(t *testing.T) {
			// Fail fast on setup errors instead of asserting against a
			// partially constructed image or config.
			cudaImage, err := image.New(
				image.WithEnvMap(
					map[string]string{
						image.EnvVarNvidiaVisibleDevices: tc.envvarDevices,
					},
				),
				image.WithMounts(tc.mountDevices),
			)
			require.NoError(t, err)

			defaultConfig, err := config.GetDefault()
			require.NoError(t, err)

			cfg := &hookConfig{defaultConfig}
			cfg.AcceptEnvvarUnprivileged = tc.acceptUnprivileged
			cfg.AcceptDeviceListAsVolumeMounts = tc.acceptMounts

			devices := cfg.getDevices(cudaImage, tc.privileged)
			require.Equal(t, tc.expectedDevices, devices)
		})
	}
}
// TestGetDevicesFromEnvvar checks the device list returned by
// getDevicesFromEnvvar for images built from different environments.
//
// The cases are grouped as follows:
//   - only NVIDIA_VISIBLE_DEVICES (and optionally CUDA_VERSION) in play,
//   - DOCKER_RESOURCE_GPUS present in the environment but not passed as a
//     swarm resource envvar, where the expectations match the first group,
//   - one or more swarm resource envvars passed, where the expected devices
//     come from those envvars when they are present in the environment.
//
// Every description is unique so that a failing subtest can be identified and
// selected with -run unambiguously.
func TestGetDevicesFromEnvvar(t *testing.T) {
	envDockerResourceGPUs := "DOCKER_RESOURCE_GPUS"
	gpuID := "GPU-12345"
	anotherGPUID := "GPU-67890"
	thirdGPUID := "MIG-12345"

	var tests = []struct {
		description          string
		swarmResourceEnvvars []string
		env                  map[string]string
		expectedDevices      []string
	}{
		{
			description: "empty env returns nil for non-legacy image",
		},
		{
			description: "blank NVIDIA_VISIBLE_DEVICES returns nil for non-legacy image",
			env: map[string]string{
				image.EnvVarNvidiaVisibleDevices: "",
			},
		},
		{
			description: "'void' NVIDIA_VISIBLE_DEVICES returns nil for non-legacy image",
			env: map[string]string{
				image.EnvVarNvidiaVisibleDevices: "void",
			},
		},
		{
			description: "'none' NVIDIA_VISIBLE_DEVICES returns empty for non-legacy image",
			env: map[string]string{
				image.EnvVarNvidiaVisibleDevices: "none",
			},
			expectedDevices: []string{""},
		},
		{
			description: "NVIDIA_VISIBLE_DEVICES set returns value for non-legacy image",
			env: map[string]string{
				image.EnvVarNvidiaVisibleDevices: gpuID,
			},
			expectedDevices: []string{gpuID},
		},
		{
			description: "NVIDIA_VISIBLE_DEVICES set returns value for legacy image",
			env: map[string]string{
				image.EnvVarNvidiaVisibleDevices: gpuID,
				image.EnvVarCudaVersion:          "legacy",
			},
			expectedDevices: []string{gpuID},
		},
		{
			description: "empty env returns all for legacy image",
			env: map[string]string{
				image.EnvVarCudaVersion: "legacy",
			},
			expectedDevices: []string{"all"},
		},
		// Add the `DOCKER_RESOURCE_GPUS` envvar and ensure that this is ignored when
		// not enabled
		{
			description: "missing NVIDIA_VISIBLE_DEVICES returns nil for non-legacy image",
			env: map[string]string{
				envDockerResourceGPUs: anotherGPUID,
			},
		},
		{
			description: "blank NVIDIA_VISIBLE_DEVICES returns nil for non-legacy image with DOCKER_RESOURCE_GPUS set but not enabled",
			env: map[string]string{
				image.EnvVarNvidiaVisibleDevices: "",
				envDockerResourceGPUs:            anotherGPUID,
			},
		},
		{
			description: "'void' NVIDIA_VISIBLE_DEVICES returns nil for non-legacy image with DOCKER_RESOURCE_GPUS set but not enabled",
			env: map[string]string{
				image.EnvVarNvidiaVisibleDevices: "void",
				envDockerResourceGPUs:            anotherGPUID,
			},
		},
		{
			description: "'none' NVIDIA_VISIBLE_DEVICES returns empty for non-legacy image with DOCKER_RESOURCE_GPUS set but not enabled",
			env: map[string]string{
				image.EnvVarNvidiaVisibleDevices: "none",
				envDockerResourceGPUs:            anotherGPUID,
			},
			expectedDevices: []string{""},
		},
		{
			description: "NVIDIA_VISIBLE_DEVICES set returns value for non-legacy image with DOCKER_RESOURCE_GPUS set but not enabled",
			env: map[string]string{
				image.EnvVarNvidiaVisibleDevices: gpuID,
				envDockerResourceGPUs:            anotherGPUID,
			},
			expectedDevices: []string{gpuID},
		},
		{
			description: "NVIDIA_VISIBLE_DEVICES set returns value for legacy image with DOCKER_RESOURCE_GPUS set but not enabled",
			env: map[string]string{
				image.EnvVarNvidiaVisibleDevices: gpuID,
				envDockerResourceGPUs:            anotherGPUID,
				image.EnvVarCudaVersion:          "legacy",
			},
			expectedDevices: []string{gpuID},
		},
		{
			description: "empty env returns all for legacy image with DOCKER_RESOURCE_GPUS set but not enabled",
			env: map[string]string{
				envDockerResourceGPUs:   anotherGPUID,
				image.EnvVarCudaVersion: "legacy",
			},
			expectedDevices: []string{"all"},
		},
		// Add the `DOCKER_RESOURCE_GPUS` envvar and ensure that this is selected when
		// enabled
		{
			description:          "empty env returns nil for non-legacy image with DOCKER_RESOURCE_GPUS enabled",
			swarmResourceEnvvars: []string{envDockerResourceGPUs},
		},
		{
			description:          "blank DOCKER_RESOURCE_GPUS returns nil for non-legacy image",
			swarmResourceEnvvars: []string{envDockerResourceGPUs},
			env: map[string]string{
				envDockerResourceGPUs: "",
			},
		},
		{
			description:          "'void' DOCKER_RESOURCE_GPUS returns nil for non-legacy image",
			swarmResourceEnvvars: []string{envDockerResourceGPUs},
			env: map[string]string{
				envDockerResourceGPUs: "void",
			},
		},
		{
			description:          "'none' DOCKER_RESOURCE_GPUS returns empty for non-legacy image",
			swarmResourceEnvvars: []string{envDockerResourceGPUs},
			env: map[string]string{
				envDockerResourceGPUs: "none",
			},
			expectedDevices: []string{""},
		},
		{
			description:          "DOCKER_RESOURCE_GPUS set returns value for non-legacy image",
			swarmResourceEnvvars: []string{envDockerResourceGPUs},
			env: map[string]string{
				envDockerResourceGPUs: gpuID,
			},
			expectedDevices: []string{gpuID},
		},
		{
			description:          "DOCKER_RESOURCE_GPUS set returns value for legacy image",
			swarmResourceEnvvars: []string{envDockerResourceGPUs},
			env: map[string]string{
				envDockerResourceGPUs:   gpuID,
				image.EnvVarCudaVersion: "legacy",
			},
			expectedDevices: []string{gpuID},
		},
		{
			description:          "DOCKER_RESOURCE_GPUS is selected if present",
			swarmResourceEnvvars: []string{envDockerResourceGPUs},
			env: map[string]string{
				envDockerResourceGPUs: anotherGPUID,
			},
			expectedDevices: []string{anotherGPUID},
		},
		{
			description:          "DOCKER_RESOURCE_GPUS overrides NVIDIA_VISIBLE_DEVICES if present",
			swarmResourceEnvvars: []string{envDockerResourceGPUs},
			env: map[string]string{
				image.EnvVarNvidiaVisibleDevices: gpuID,
				envDockerResourceGPUs:            anotherGPUID,
			},
			expectedDevices: []string{anotherGPUID},
		},
		{
			description:          "DOCKER_RESOURCE_GPUS_ADDITIONAL overrides NVIDIA_VISIBLE_DEVICES if present",
			swarmResourceEnvvars: []string{"DOCKER_RESOURCE_GPUS_ADDITIONAL"},
			env: map[string]string{
				image.EnvVarNvidiaVisibleDevices:  gpuID,
				"DOCKER_RESOURCE_GPUS_ADDITIONAL": anotherGPUID,
			},
			expectedDevices: []string{anotherGPUID},
		},
		{
			description:          "All available swarm resource envvars are selected and override NVIDIA_VISIBLE_DEVICES if present",
			swarmResourceEnvvars: []string{"DOCKER_RESOURCE_GPUS", "DOCKER_RESOURCE_GPUS_ADDITIONAL"},
			env: map[string]string{
				image.EnvVarNvidiaVisibleDevices:  gpuID,
				"DOCKER_RESOURCE_GPUS":            thirdGPUID,
				"DOCKER_RESOURCE_GPUS_ADDITIONAL": anotherGPUID,
			},
			expectedDevices: []string{thirdGPUID, anotherGPUID},
		},
		{
			description:          "DOCKER_RESOURCE_GPUS_ADDITIONAL or DOCKER_RESOURCE_GPUS override NVIDIA_VISIBLE_DEVICES if present",
			swarmResourceEnvvars: []string{"DOCKER_RESOURCE_GPUS", "DOCKER_RESOURCE_GPUS_ADDITIONAL"},
			env: map[string]string{
				image.EnvVarNvidiaVisibleDevices:  gpuID,
				"DOCKER_RESOURCE_GPUS_ADDITIONAL": anotherGPUID,
			},
			expectedDevices: []string{anotherGPUID},
		},
	}

	for _, tc := range tests {
		t.Run(tc.description, func(t *testing.T) {
			// Fail on construction errors rather than testing a zero-value image.
			cudaImage, err := image.New(
				image.WithEnvMap(tc.env),
			)
			require.NoError(t, err)

			devices := getDevicesFromEnvvar(cudaImage, tc.swarmResourceEnvvars)
			require.EqualValues(t, tc.expectedDevices, devices)
		})
	}
}
func TestGetDriverCapabilities(t *testing.T) { func TestGetDriverCapabilities(t *testing.T) {
supportedCapabilities := "compute,display,utility,video" supportedCapabilities := "compute,display,utility,video"

View File

@@ -88,7 +88,7 @@ func (c hookConfig) getConfigOption(fieldName string) string {
// getSwarmResourceEnvvars returns the swarm resource envvars for the config. // getSwarmResourceEnvvars returns the swarm resource envvars for the config.
func (c *hookConfig) getSwarmResourceEnvvars() []string { func (c *hookConfig) getSwarmResourceEnvvars() []string {
if c == nil || c.SwarmResource == "" { if c.SwarmResource == "" {
return nil return nil
} }

View File

@@ -21,8 +21,8 @@ The `runtimes` config option allows for the low-level runtime to be specified. T
The default value for this setting is: The default value for this setting is:
```toml ```toml
runtimes = [ runtimes = [
"docker-runc",
"runc", "runc",
"crun",
] ]
``` ```

View File

@@ -14,7 +14,6 @@ import (
"github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-ctk-installer/toolkit" "github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-ctk-installer/toolkit"
"github.com/NVIDIA/nvidia-container-toolkit/internal/info" "github.com/NVIDIA/nvidia-container-toolkit/internal/info"
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup"
) )
const ( const (
@@ -28,7 +27,7 @@ const (
) )
var availableRuntimes = map[string]struct{}{"docker": {}, "crio": {}, "containerd": {}} var availableRuntimes = map[string]struct{}{"docker": {}, "crio": {}, "containerd": {}}
var defaultLowLevelRuntimes = []string{"runc", "crun"} var defaultLowLevelRuntimes = []string{"docker-runc", "runc", "crun"}
var waitingForSignal = make(chan bool, 1) var waitingForSignal = make(chan bool, 1)
var signalReceived = make(chan bool, 1) var signalReceived = make(chan bool, 1)
@@ -37,11 +36,10 @@ var signalReceived = make(chan bool, 1)
type options struct { type options struct {
toolkitInstallDir string toolkitInstallDir string
noDaemon bool noDaemon bool
runtime string runtime string
pidFile string pidFile string
sourceRoot string sourceRoot string
packageType string
toolkitOptions toolkit.Options toolkitOptions toolkit.Options
runtimeOptions runtime.Options runtimeOptions runtime.Options
@@ -125,17 +123,11 @@ func (a app) build() *cli.App {
EnvVars: []string{"TOOLKIT_INSTALL_DIR", "ROOT"}, EnvVars: []string{"TOOLKIT_INSTALL_DIR", "ROOT"},
}, },
&cli.StringFlag{ &cli.StringFlag{
Name: "toolkit-source-root", Name: "source-root",
Usage: "The folder where the required toolkit artifacts can be found. If this is not specified, the path /artifacts/{{ .ToolkitPackageType }} is used where ToolkitPackageType is the resolved package type", Value: "/",
Usage: "The folder where the required toolkit artifacts can be found",
Destination: &options.sourceRoot, Destination: &options.sourceRoot,
EnvVars: []string{"TOOLKIT_SOURCE_ROOT"}, EnvVars: []string{"SOURCE_ROOT"},
},
&cli.StringFlag{
Name: "toolkit-package-type",
Usage: "specify the package type to use for the toolkit. One of ['deb', 'rpm', 'auto', '']. If 'auto' or '' are used, the type is inferred automatically.",
Value: "auto",
Destination: &options.packageType,
EnvVars: []string{"TOOLKIT_PACKAGE_TYPE"},
}, },
&cli.StringFlag{ &cli.StringFlag{
Name: "pid-file", Name: "pid-file",
@@ -153,15 +145,6 @@ func (a app) build() *cli.App {
} }
func (a *app) Before(c *cli.Context, o *options) error { func (a *app) Before(c *cli.Context, o *options) error {
if o.sourceRoot == "" {
sourceRoot, err := a.resolveSourceRoot(o.runtimeOptions.HostRootMount, o.packageType)
if err != nil {
return fmt.Errorf("failed to resolve source root: %v", err)
}
a.logger.Infof("Resolved source root to %v", sourceRoot)
o.sourceRoot = sourceRoot
}
a.toolkit = toolkit.NewInstaller( a.toolkit = toolkit.NewInstaller(
toolkit.WithLogger(a.logger), toolkit.WithLogger(a.logger),
toolkit.WithSourceRoot(o.sourceRoot), toolkit.WithSourceRoot(o.sourceRoot),
@@ -294,35 +277,3 @@ func (a *app) shutdown(pidFile string) {
a.logger.Warningf("Unable to remove pidfile: %v", err) a.logger.Warningf("Unable to remove pidfile: %v", err)
} }
} }
// resolveSourceRoot returns the artifact directory to use as the toolkit
// source root for the given package type.
//
// The package type is first resolved through resolvePackageType; "deb" maps
// to "/artifacts/deb" and "rpm" maps to "/artifacts/rpm". Any other resolved
// value yields an "invalid package type" error, and an error from
// resolvePackageType is returned unchanged.
func (a *app) resolveSourceRoot(hostRoot string, packageType string) (string, error) {
	pkgType, err := a.resolvePackageType(hostRoot, packageType)
	if err != nil {
		return "", err
	}

	if pkgType == "deb" || pkgType == "rpm" {
		return "/artifacts/" + pkgType, nil
	}
	return "", fmt.Errorf("invalid package type: %v", pkgType)
}
// resolvePackageType determines the package type to use.
//
// A packageType other than "" or "auto" is returned as-is without validation.
// Otherwise an executable locator rooted at hostRoot is asked for
// /usr/bin/rpm and then /usr/bin/dpkg, in that order; the first lookup that
// succeeds with at least one candidate selects "rpm" or "deb" respectively.
// If neither is found, "deb" is returned. The returned error is always nil.
func (a *app) resolvePackageType(hostRoot string, packageType string) (rPackageTypes string, rerr error) {
	autoDetect := packageType == "" || packageType == "auto"
	if !autoDetect {
		return packageType, nil
	}

	locator := lookup.NewExecutableLocator(a.logger, hostRoot)

	// Probed in order; the first executable that is found decides the type.
	probes := []struct {
		executable string
		pkgType    string
	}{
		{executable: "/usr/bin/rpm", pkgType: "rpm"},
		{executable: "/usr/bin/dpkg", pkgType: "deb"},
	}
	for _, probe := range probes {
		found, err := locator.Locate(probe.executable)
		if err == nil && len(found) > 0 {
			return probe.pkgType, nil
		}
	}

	// Fall back to deb when no package manager executable was located.
	return "deb", nil
}

View File

@@ -67,7 +67,7 @@ swarm-resource = ""
debug = "/dev/null" debug = "/dev/null"
log-level = "info" log-level = "info"
mode = "auto" mode = "auto"
runtimes = ["runc", "crun"] runtimes = ["docker-runc", "runc", "crun"]
[nvidia-container-runtime.modes] [nvidia-container-runtime.modes]
@@ -131,7 +131,7 @@ swarm-resource = ""
debug = "/dev/null" debug = "/dev/null"
log-level = "info" log-level = "info"
mode = "auto" mode = "auto"
runtimes = ["runc", "crun"] runtimes = ["docker-runc", "runc", "crun"]
[nvidia-container-runtime.modes] [nvidia-container-runtime.modes]
@@ -198,7 +198,7 @@ swarm-resource = ""
debug = "/dev/null" debug = "/dev/null"
log-level = "info" log-level = "info"
mode = "auto" mode = "auto"
runtimes = ["runc", "crun"] runtimes = ["docker-runc", "runc", "crun"]
[nvidia-container-runtime.modes] [nvidia-container-runtime.modes]
@@ -262,7 +262,7 @@ swarm-resource = ""
debug = "/dev/null" debug = "/dev/null"
log-level = "info" log-level = "info"
mode = "auto" mode = "auto"
runtimes = ["runc", "crun"] runtimes = ["docker-runc", "runc", "crun"]
[nvidia-container-runtime.modes] [nvidia-container-runtime.modes]
@@ -348,7 +348,7 @@ swarm-resource = ""
debug = "/dev/null" debug = "/dev/null"
log-level = "info" log-level = "info"
mode = "auto" mode = "auto"
runtimes = ["runc", "crun"] runtimes = ["docker-runc", "runc", "crun"]
[nvidia-container-runtime.modes] [nvidia-container-runtime.modes]
@@ -433,7 +433,7 @@ swarm-resource = ""
"--driver-root-ctr-path=" + hostRoot, "--driver-root-ctr-path=" + hostRoot,
"--pid-file=" + filepath.Join(testRoot, "toolkit.pid"), "--pid-file=" + filepath.Join(testRoot, "toolkit.pid"),
"--restart-mode=none", "--restart-mode=none",
"--toolkit-source-root=" + filepath.Join(artifactRoot, "deb"), "--source-root=" + filepath.Join(artifactRoot, "deb"),
} }
err := app.Run(append(testArgs, tc.args...)) err := app.Run(append(testArgs, tc.args...))

View File

@@ -47,9 +47,7 @@ var _ Installer = (*toolkitInstaller)(nil)
// New creates a toolkit installer with the specified options. // New creates a toolkit installer with the specified options.
func New(opts ...Option) (Installer, error) { func New(opts ...Option) (Installer, error) {
t := &toolkitInstaller{ t := &toolkitInstaller{}
sourceRoot: "/",
}
for _, opt := range opts { for _, opt := range opts {
opt(t) opt(t)
} }
@@ -57,6 +55,9 @@ func New(opts ...Option) (Installer, error) {
if t.logger == nil { if t.logger == nil {
t.logger = logger.New() t.logger = logger.New()
} }
if t.sourceRoot == "" {
t.sourceRoot = "/"
}
if t.artifactRoot == nil { if t.artifactRoot == nil {
artifactRoot, err := newArtifactRoot(t.logger, t.sourceRoot) artifactRoot, err := newArtifactRoot(t.logger, t.sourceRoot)
if err != nil { if err != nil {

View File

@@ -215,8 +215,7 @@ func Flags(opts *Options) []cli.Flag {
// An Installer is used to install the NVIDIA Container Toolkit from the toolkit container. // An Installer is used to install the NVIDIA Container Toolkit from the toolkit container.
type Installer struct { type Installer struct {
logger logger.Interface logger logger.Interface
sourceRoot string sourceRoot string
// toolkitRoot specifies the destination path at which the toolkit is installed. // toolkitRoot specifies the destination path at which the toolkit is installed.
toolkitRoot string toolkitRoot string

View File

@@ -57,7 +57,6 @@ type options struct {
configSearchPaths cli.StringSlice configSearchPaths cli.StringSlice
librarySearchPaths cli.StringSlice librarySearchPaths cli.StringSlice
disabledHooks cli.StringSlice
csv struct { csv struct {
files cli.StringSlice files cli.StringSlice
@@ -97,20 +96,17 @@ func (m command) build() *cli.Command {
Name: "config-search-path", Name: "config-search-path",
Usage: "Specify the path to search for config files when discovering the entities that should be included in the CDI specification.", Usage: "Specify the path to search for config files when discovering the entities that should be included in the CDI specification.",
Destination: &opts.configSearchPaths, Destination: &opts.configSearchPaths,
EnvVars: []string{"NVIDIA_CTK_CDI_GENERATE_CONFIG_SEARCH_PATHS"},
}, },
&cli.StringFlag{ &cli.StringFlag{
Name: "output", Name: "output",
Usage: "Specify the file to output the generated CDI specification to. If this is '' the specification is output to STDOUT", Usage: "Specify the file to output the generated CDI specification to. If this is '' the specification is output to STDOUT",
Destination: &opts.output, Destination: &opts.output,
EnvVars: []string{"NVIDIA_CTK_CDI_OUTPUT_FILE_PATH"},
}, },
&cli.StringFlag{ &cli.StringFlag{
Name: "format", Name: "format",
Usage: "The output format for the generated spec [json | yaml]. This overrides the format defined by the output file extension (if specified).", Usage: "The output format for the generated spec [json | yaml]. This overrides the format defined by the output file extension (if specified).",
Value: spec.FormatYAML, Value: spec.FormatYAML,
Destination: &opts.format, Destination: &opts.format,
EnvVars: []string{"NVIDIA_CTK_CDI_GENERATE_OUTPUT_FORMAT"},
}, },
&cli.StringFlag{ &cli.StringFlag{
Name: "mode", Name: "mode",
@@ -120,32 +116,27 @@ func (m command) build() *cli.Command {
"If mode is set to 'auto' the mode will be determined based on the system configuration.", "If mode is set to 'auto' the mode will be determined based on the system configuration.",
Value: string(nvcdi.ModeAuto), Value: string(nvcdi.ModeAuto),
Destination: &opts.mode, Destination: &opts.mode,
EnvVars: []string{"NVIDIA_CTK_CDI_GENERATE_MODE"},
}, },
&cli.StringFlag{ &cli.StringFlag{
Name: "dev-root", Name: "dev-root",
Usage: "Specify the root where `/dev` is located. If this is not specified, the driver-root is assumed.", Usage: "Specify the root where `/dev` is located. If this is not specified, the driver-root is assumed.",
Destination: &opts.devRoot, Destination: &opts.devRoot,
EnvVars: []string{"NVIDIA_CTK_DEV_ROOT"},
}, },
&cli.StringSliceFlag{ &cli.StringSliceFlag{
Name: "device-name-strategy", Name: "device-name-strategy",
Usage: "Specify the strategy for generating device names. If this is specified multiple times, the devices will be duplicated for each strategy. One of [index | uuid | type-index]", Usage: "Specify the strategy for generating device names. If this is specified multiple times, the devices will be duplicated for each strategy. One of [index | uuid | type-index]",
Value: cli.NewStringSlice(nvcdi.DeviceNameStrategyIndex, nvcdi.DeviceNameStrategyUUID), Value: cli.NewStringSlice(nvcdi.DeviceNameStrategyIndex, nvcdi.DeviceNameStrategyUUID),
Destination: &opts.deviceNameStrategies, Destination: &opts.deviceNameStrategies,
EnvVars: []string{"NVIDIA_CTK_CDI_GENERATE_DEVICE_NAME_STRATEGIES"},
}, },
&cli.StringFlag{ &cli.StringFlag{
Name: "driver-root", Name: "driver-root",
Usage: "Specify the NVIDIA GPU driver root to use when discovering the entities that should be included in the CDI specification.", Usage: "Specify the NVIDIA GPU driver root to use when discovering the entities that should be included in the CDI specification.",
Destination: &opts.driverRoot, Destination: &opts.driverRoot,
EnvVars: []string{"NVIDIA_CTK_DRIVER_ROOT"},
}, },
&cli.StringSliceFlag{ &cli.StringSliceFlag{
Name: "library-search-path", Name: "library-search-path",
Usage: "Specify the path to search for libraries when discovering the entities that should be included in the CDI specification.\n\tNote: This option only applies to CSV mode.", Usage: "Specify the path to search for libraries when discovering the entities that should be included in the CDI specification.\n\tNote: This option only applies to CSV mode.",
Destination: &opts.librarySearchPaths, Destination: &opts.librarySearchPaths,
EnvVars: []string{"NVIDIA_CTK_CDI_GENERATE_LIBRARY_SEARCH_PATHS"},
}, },
&cli.StringFlag{ &cli.StringFlag{
Name: "nvidia-cdi-hook-path", Name: "nvidia-cdi-hook-path",
@@ -154,13 +145,11 @@ func (m command) build() *cli.Command {
"If not specified, the PATH will be searched for `nvidia-cdi-hook`. " + "If not specified, the PATH will be searched for `nvidia-cdi-hook`. " +
"NOTE: That if this is specified as `nvidia-ctk`, the PATH will be searched for `nvidia-ctk` instead.", "NOTE: That if this is specified as `nvidia-ctk`, the PATH will be searched for `nvidia-ctk` instead.",
Destination: &opts.nvidiaCDIHookPath, Destination: &opts.nvidiaCDIHookPath,
EnvVars: []string{"NVIDIA_CTK_CDI_HOOK_PATH"},
}, },
&cli.StringFlag{ &cli.StringFlag{
Name: "ldconfig-path", Name: "ldconfig-path",
Usage: "Specify the path to use for ldconfig in the generated CDI specification", Usage: "Specify the path to use for ldconfig in the generated CDI specification",
Destination: &opts.ldconfigPath, Destination: &opts.ldconfigPath,
EnvVars: []string{"NVIDIA_CTK_CDI_GENERATE_LDCONFIG_PATH"},
}, },
&cli.StringFlag{ &cli.StringFlag{
Name: "vendor", Name: "vendor",
@@ -168,7 +157,6 @@ func (m command) build() *cli.Command {
Usage: "the vendor string to use for the generated CDI specification.", Usage: "the vendor string to use for the generated CDI specification.",
Value: "nvidia.com", Value: "nvidia.com",
Destination: &opts.vendor, Destination: &opts.vendor,
EnvVars: []string{"NVIDIA_CTK_CDI_GENERATE_VENDOR"},
}, },
&cli.StringFlag{ &cli.StringFlag{
Name: "class", Name: "class",
@@ -176,30 +164,17 @@ func (m command) build() *cli.Command {
Usage: "the class string to use for the generated CDI specification.", Usage: "the class string to use for the generated CDI specification.",
Value: "gpu", Value: "gpu",
Destination: &opts.class, Destination: &opts.class,
EnvVars: []string{"NVIDIA_CTK_CDI_GENERATE_CLASS"},
}, },
&cli.StringSliceFlag{ &cli.StringSliceFlag{
Name: "csv.file", Name: "csv.file",
Usage: "The path to the list of CSV files to use when generating the CDI specification in CSV mode.", Usage: "The path to the list of CSV files to use when generating the CDI specification in CSV mode.",
Value: cli.NewStringSlice(csv.DefaultFileList()...), Value: cli.NewStringSlice(csv.DefaultFileList()...),
Destination: &opts.csv.files, Destination: &opts.csv.files,
EnvVars: []string{"NVIDIA_CTK_CDI_GENERATE_CSV_FILES"},
}, },
&cli.StringSliceFlag{ &cli.StringSliceFlag{
Name: "csv.ignore-pattern", Name: "csv.ignore-pattern",
Usage: "specify a pattern the CSV mount specifications.", Usage: "Specify a pattern the CSV mount specifications.",
Destination: &opts.csv.ignorePatterns, Destination: &opts.csv.ignorePatterns,
EnvVars: []string{"NVIDIA_CTK_CDI_GENERATE_CSV_IGNORE_PATTERNS"},
},
&cli.StringSliceFlag{
Name: "disable-hook",
Aliases: []string{"disable-hooks"},
Usage: "specify a specific hook to skip when generating CDI " +
"specifications. This can be specified multiple times and the " +
"special hook name 'all' can be used ensure that the generated " +
"CDI specification does not include any hooks.",
Destination: &opts.disabledHooks,
EnvVars: []string{"NVIDIA_CTK_CDI_GENERATE_DISABLED_HOOKS"},
}, },
} }
@@ -287,7 +262,7 @@ func (m command) generateSpec(opts *options) (spec.Interface, error) {
deviceNamers = append(deviceNamers, deviceNamer) deviceNamers = append(deviceNamers, deviceNamer)
} }
cdiOptions := []nvcdi.Option{ cdilib, err := nvcdi.New(
nvcdi.WithLogger(m.logger), nvcdi.WithLogger(m.logger),
nvcdi.WithDriverRoot(opts.driverRoot), nvcdi.WithDriverRoot(opts.driverRoot),
nvcdi.WithDevRoot(opts.devRoot), nvcdi.WithDevRoot(opts.devRoot),
@@ -301,13 +276,7 @@ func (m command) generateSpec(opts *options) (spec.Interface, error) {
nvcdi.WithCSVIgnorePatterns(opts.csv.ignorePatterns.Value()), nvcdi.WithCSVIgnorePatterns(opts.csv.ignorePatterns.Value()),
// We set the following to allow for dependency injection: // We set the following to allow for dependency injection:
nvcdi.WithNvmlLib(opts.nvmllib), nvcdi.WithNvmlLib(opts.nvmllib),
} )
for _, hook := range opts.disabledHooks.Value() {
cdiOptions = append(cdiOptions, nvcdi.WithDisabledHook(hook))
}
cdilib, err := nvcdi.New(cdiOptions...)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create CDI library: %v", err) return nil, fmt.Errorf("failed to create CDI library: %v", err)
} }

View File

@@ -26,7 +26,6 @@ import (
"github.com/NVIDIA/go-nvml/pkg/nvml/mock/dgxa100" "github.com/NVIDIA/go-nvml/pkg/nvml/mock/dgxa100"
testlog "github.com/sirupsen/logrus/hooks/test" testlog "github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/urfave/cli/v2"
"github.com/NVIDIA/nvidia-container-toolkit/internal/test" "github.com/NVIDIA/nvidia-container-toolkit/internal/test"
) )
@@ -111,206 +110,6 @@ containerEdits:
- /lib/x86_64-linux-gnu - /lib/x86_64-linux-gnu
env: env:
- NVIDIA_CTK_DEBUG=false - NVIDIA_CTK_DEBUG=false
- hookName: createContainer
path: /usr/bin/nvidia-cdi-hook
args:
- nvidia-cdi-hook
- disable-device-node-modification
env:
- NVIDIA_CTK_DEBUG=false
mounts:
- hostPath: {{ .driverRoot }}/lib/x86_64-linux-gnu/libcuda.so.999.88.77
containerPath: /lib/x86_64-linux-gnu/libcuda.so.999.88.77
options:
- ro
- nosuid
- nodev
- rbind
- rprivate
`,
},
{
description: "disableHooks1",
options: options{
format: "yaml",
mode: "nvml",
vendor: "example.com",
class: "device",
driverRoot: driverRoot,
disabledHooks: valueOf(cli.NewStringSlice("enable-cuda-compat")),
},
expectedOptions: options{
format: "yaml",
mode: "nvml",
vendor: "example.com",
class: "device",
nvidiaCDIHookPath: "/usr/bin/nvidia-cdi-hook",
driverRoot: driverRoot,
disabledHooks: valueOf(cli.NewStringSlice("enable-cuda-compat")),
},
expectedSpec: `---
cdiVersion: 0.5.0
kind: example.com/device
devices:
- name: "0"
containerEdits:
deviceNodes:
- path: /dev/nvidia0
hostPath: {{ .driverRoot }}/dev/nvidia0
- name: all
containerEdits:
deviceNodes:
- path: /dev/nvidia0
hostPath: {{ .driverRoot }}/dev/nvidia0
containerEdits:
env:
- NVIDIA_VISIBLE_DEVICES=void
deviceNodes:
- path: /dev/nvidiactl
hostPath: {{ .driverRoot }}/dev/nvidiactl
hooks:
- hookName: createContainer
path: /usr/bin/nvidia-cdi-hook
args:
- nvidia-cdi-hook
- create-symlinks
- --link
- libcuda.so.1::/lib/x86_64-linux-gnu/libcuda.so
env:
- NVIDIA_CTK_DEBUG=false
- hookName: createContainer
path: /usr/bin/nvidia-cdi-hook
args:
- nvidia-cdi-hook
- update-ldcache
- --folder
- /lib/x86_64-linux-gnu
env:
- NVIDIA_CTK_DEBUG=false
- hookName: createContainer
path: /usr/bin/nvidia-cdi-hook
args:
- nvidia-cdi-hook
- disable-device-node-modification
env:
- NVIDIA_CTK_DEBUG=false
mounts:
- hostPath: {{ .driverRoot }}/lib/x86_64-linux-gnu/libcuda.so.999.88.77
containerPath: /lib/x86_64-linux-gnu/libcuda.so.999.88.77
options:
- ro
- nosuid
- nodev
- rbind
- rprivate
`,
},
{
description: "disableHooks2",
options: options{
format: "yaml",
mode: "nvml",
vendor: "example.com",
class: "device",
driverRoot: driverRoot,
disabledHooks: valueOf(cli.NewStringSlice("enable-cuda-compat", "update-ldcache")),
},
expectedOptions: options{
format: "yaml",
mode: "nvml",
vendor: "example.com",
class: "device",
nvidiaCDIHookPath: "/usr/bin/nvidia-cdi-hook",
driverRoot: driverRoot,
disabledHooks: valueOf(cli.NewStringSlice("enable-cuda-compat", "update-ldcache")),
},
expectedSpec: `---
cdiVersion: 0.5.0
kind: example.com/device
devices:
- name: "0"
containerEdits:
deviceNodes:
- path: /dev/nvidia0
hostPath: {{ .driverRoot }}/dev/nvidia0
- name: all
containerEdits:
deviceNodes:
- path: /dev/nvidia0
hostPath: {{ .driverRoot }}/dev/nvidia0
containerEdits:
env:
- NVIDIA_VISIBLE_DEVICES=void
deviceNodes:
- path: /dev/nvidiactl
hostPath: {{ .driverRoot }}/dev/nvidiactl
hooks:
- hookName: createContainer
path: /usr/bin/nvidia-cdi-hook
args:
- nvidia-cdi-hook
- create-symlinks
- --link
- libcuda.so.1::/lib/x86_64-linux-gnu/libcuda.so
env:
- NVIDIA_CTK_DEBUG=false
- hookName: createContainer
path: /usr/bin/nvidia-cdi-hook
args:
- nvidia-cdi-hook
- disable-device-node-modification
env:
- NVIDIA_CTK_DEBUG=false
mounts:
- hostPath: {{ .driverRoot }}/lib/x86_64-linux-gnu/libcuda.so.999.88.77
containerPath: /lib/x86_64-linux-gnu/libcuda.so.999.88.77
options:
- ro
- nosuid
- nodev
- rbind
- rprivate
`,
},
{
description: "disableHooksAll",
options: options{
format: "yaml",
mode: "nvml",
vendor: "example.com",
class: "device",
driverRoot: driverRoot,
disabledHooks: valueOf(cli.NewStringSlice("all")),
},
expectedOptions: options{
format: "yaml",
mode: "nvml",
vendor: "example.com",
class: "device",
nvidiaCDIHookPath: "/usr/bin/nvidia-cdi-hook",
driverRoot: driverRoot,
disabledHooks: valueOf(cli.NewStringSlice("all")),
},
expectedSpec: `---
cdiVersion: 0.5.0
kind: example.com/device
devices:
- name: "0"
containerEdits:
deviceNodes:
- path: /dev/nvidia0
hostPath: {{ .driverRoot }}/dev/nvidia0
- name: all
containerEdits:
deviceNodes:
- path: /dev/nvidia0
hostPath: {{ .driverRoot }}/dev/nvidia0
containerEdits:
env:
- NVIDIA_VISIBLE_DEVICES=void
deviceNodes:
- path: /dev/nvidiactl
hostPath: {{ .driverRoot }}/dev/nvidiactl
mounts: mounts:
- hostPath: {{ .driverRoot }}/lib/x86_64-linux-gnu/libcuda.so.999.88.77 - hostPath: {{ .driverRoot }}/lib/x86_64-linux-gnu/libcuda.so.999.88.77
containerPath: /lib/x86_64-linux-gnu/libcuda.so.999.88.77 containerPath: /lib/x86_64-linux-gnu/libcuda.so.999.88.77
@@ -363,9 +162,3 @@ containerEdits:
}) })
} }
} }
// valueOf returns the value of a pointer.
// Note that this does not check for a nil pointer and is only used for testing.
func valueOf[T any](v *T) T {
return *v
}

View File

@@ -64,7 +64,6 @@ func (m command) build() *cli.Command {
Usage: "specify the directories to scan for CDI specifications", Usage: "specify the directories to scan for CDI specifications",
Value: cli.NewStringSlice(cdi.DefaultSpecDirs...), Value: cli.NewStringSlice(cdi.DefaultSpecDirs...),
Destination: &cfg.cdiSpecDirs, Destination: &cfg.cdiSpecDirs,
EnvVars: []string{"NVIDIA_CTK_CDI_SPEC_DIRS"},
}, },
} }

View File

@@ -1,149 +0,0 @@
# SPDX-FileCopyrightText: Copyright (c) 2019 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
ARG GOLANG_VERSION=x.x.x
ARG VERSION="N/A"
FROM nvcr.io/nvidia/cuda:12.9.0-base-ubi9 AS build
RUN dnf install -y \
wget make git gcc \
&& \
rm -rf /var/cache/yum/*
ARG GOLANG_VERSION=x.x.x
RUN set -eux; \
\
arch="$(uname -m)"; \
case "${arch##*-}" in \
x86_64 | amd64) ARCH='amd64' ;; \
ppc64el | ppc64le) ARCH='ppc64le' ;; \
aarch64 | arm64) ARCH='arm64' ;; \
*) echo "unsupported architecture" ; exit 1 ;; \
esac; \
wget -nv -O - https://storage.googleapis.com/golang/go${GOLANG_VERSION}.linux-${ARCH}.tar.gz \
| tar -C /usr/local -xz
ENV GOPATH=/go
ENV PATH=$GOPATH/bin:/usr/local/go/bin:$PATH
WORKDIR /build
COPY . .
RUN mkdir -p /artifacts/bin
ARG VERSION="N/A"
ARG GIT_COMMIT="unknown"
RUN make PREFIX=/artifacts/bin cmd-nvidia-ctk-installer
# The packaging stage collects the deb and rpm packages built for supported
# architectures.
FROM nvcr.io/nvidia/cuda:12.9.0-base-ubi9 AS packaging
ARG ARTIFACTS_ROOT
COPY ${ARTIFACTS_ROOT} /artifacts/packages/
WORKDIR /artifacts/packages
# build-args are added to the manifest.txt file below.
ARG PACKAGE_VERSION
ARG GIT_BRANCH
ARG GIT_COMMIT
ARG GIT_COMMIT_SHORT
ARG SOURCE_DATE_EPOCH
ARG VERSION
# Create a manifest.txt file with the absolute paths of all deb and rpm packages in the container
RUN echo "#IMAGE_EPOCH=$(date '+%s')" > /artifacts/manifest.txt && \
env | sed 's/^/#/g' >> /artifacts/manifest.txt && \
find /artifacts/packages -iname '*.deb' -o -iname '*.rpm' >> /artifacts/manifest.txt
RUN mkdir /licenses && mv /NGC-DL-CONTAINER-LICENSE /licenses/NGC-DL-CONTAINER-LICENSE
# The debpackages stage is used to extract the contents of deb packages.
FROM nvcr.io/nvidia/cuda:12.9.0-base-ubuntu20.04 AS debpackages
ARG TARGETARCH
ARG PACKAGE_DIST_DEB=ubuntu18.04
COPY --from=packaging /artifacts/packages/${PACKAGE_DIST_DEB} /deb-packages
RUN mkdir -p /artifacts/deb
RUN set -eux; \
\
case "${TARGETARCH}" in \
x86_64 | amd64) ARCH='amd64' ;; \
ppc64el | ppc64le) ARCH='ppc64le' ;; \
aarch64 | arm64) ARCH='arm64' ;; \
*) echo "unsupported architecture" ; exit 1 ;; \
esac; \
for p in $(ls /deb-packages/${ARCH}/*.deb); do dpkg-deb -xv $p /artifacts/deb/; done
# The rpmpackages stage is used to extract the contents of the rpm packages.
FROM nvcr.io/nvidia/cuda:12.9.0-base-ubi9 AS rpmpackages
RUN dnf install -y cpio
ARG TARGETARCH
ARG PACKAGE_DIST_RPM=centos7
COPY --from=packaging /artifacts/packages/${PACKAGE_DIST_RPM} /rpm-packages
RUN mkdir -p /artifacts/rpm
RUN set -eux; \
\
case "${TARGETARCH}" in \
x86_64 | amd64) ARCH='x86_64' ;; \
ppc64el | ppc64le) ARCH='ppc64le' ;; \
aarch64 | arm64) ARCH='aarch64' ;; \
*) echo "unsupported architecture" ; exit 1 ;; \
esac; \
for p in $(ls /rpm-packages/${ARCH}/*.rpm); do rpm2cpio $p | cpio -idmv -D /artifacts/rpm; done
# The artifacts image serves as an intermediate stage to collect the artifacts
# From the previous stages:
# - The extracted deb packages
# - The extracted rpm packages
# - The nvidia-ctk-installer binary
FROM nvcr.io/nvidia/cuda:12.9.0-base-ubi9 AS artifacts
COPY --from=rpmpackages /artifacts/rpm /artifacts/rpm
COPY --from=debpackages /artifacts/deb /artifacts/deb
COPY --from=build /artifacts/bin /artifacts/build
FROM nvcr.io/nvidia/cuda:12.9.0-base-ubi9
ENV NVIDIA_DISABLE_REQUIRE="true"
ENV NVIDIA_VISIBLE_DEVICES=void
ENV NVIDIA_DRIVER_CAPABILITIES=utility
COPY --from=artifacts /artifacts/rpm /artifacts/rpm
COPY --from=artifacts /artifacts/deb /artifacts/deb
COPY --from=artifacts /artifacts/build /work
WORKDIR /work
ENV PATH=/work:$PATH
ARG VERSION
LABEL io.k8s.display-name="NVIDIA Container Runtime Config"
LABEL name="NVIDIA Container Runtime Config"
LABEL vendor="NVIDIA"
LABEL version="${VERSION}"
LABEL release="N/A"
LABEL summary="Automatically Configure your Container Runtime for GPU support."
LABEL description="See summary"
RUN mkdir /licenses && mv /NGC-DL-CONTAINER-LICENSE /licenses/NGC-DL-CONTAINER-LICENSE
ENTRYPOINT ["/work/nvidia-ctk-installer"]

View File

@@ -0,0 +1,38 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
ARG GOLANG_VERSION=x.x.x
FROM nvcr.io/nvidia/cuda:12.9.0-base-ubuntu20.04
ARG ARTIFACTS_ROOT
COPY ${ARTIFACTS_ROOT} /artifacts/packages/
WORKDIR /artifacts/packages
# build-args are added to the manifest.txt file below.
ARG PACKAGE_DIST
ARG PACKAGE_VERSION
ARG GIT_BRANCH
ARG GIT_COMMIT
ARG GIT_COMMIT_SHORT
ARG SOURCE_DATE_EPOCH
ARG VERSION
# Create a manifest.txt file with the absolute paths of all deb and rpm packages in the container
RUN echo "#IMAGE_EPOCH=$(date '+%s')" > /artifacts/manifest.txt && \
env | sed 's/^/#/g' >> /artifacts/manifest.txt && \
find /artifacts/packages -iname '*.deb' -o -iname '*.rpm' >> /artifacts/manifest.txt
RUN mkdir /licenses && mv /NGC-DL-CONTAINER-LICENSE /licenses/NGC-DL-CONTAINER-LICENSE

View File

@@ -0,0 +1,90 @@
# Copyright (c) 2019-2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
ARG GOLANG_VERSION=x.x.x
ARG VERSION="N/A"
FROM nvcr.io/nvidia/cuda:12.9.0-base-ubi8 AS build
RUN yum install -y \
wget make git gcc \
&& \
rm -rf /var/cache/yum/*
ARG GOLANG_VERSION=x.x.x
RUN set -eux; \
\
arch="$(uname -m)"; \
case "${arch##*-}" in \
x86_64 | amd64) ARCH='amd64' ;; \
ppc64el | ppc64le) ARCH='ppc64le' ;; \
aarch64 | arm64) ARCH='arm64' ;; \
*) echo "unsupported architecture" ; exit 1 ;; \
esac; \
wget -nv -O - https://storage.googleapis.com/golang/go${GOLANG_VERSION}.linux-${ARCH}.tar.gz \
| tar -C /usr/local -xz
ENV GOPATH=/go
ENV PATH=$GOPATH/bin:/usr/local/go/bin:$PATH
WORKDIR /build
COPY . .
RUN mkdir /artifacts
ARG VERSION="N/A"
ARG GIT_COMMIT="unknown"
RUN make PREFIX=/artifacts cmd-nvidia-ctk-installer
FROM nvcr.io/nvidia/cuda:12.9.0-base-ubi8
ENV NVIDIA_DISABLE_REQUIRE="true"
ENV NVIDIA_VISIBLE_DEVICES=void
ENV NVIDIA_DRIVER_CAPABILITIES=utility
ARG ARTIFACTS_ROOT
ARG PACKAGE_DIST
COPY ${ARTIFACTS_ROOT}/${PACKAGE_DIST} /artifacts/packages/${PACKAGE_DIST}
WORKDIR /artifacts/packages
ARG PACKAGE_VERSION
ARG TARGETARCH
ENV PACKAGE_ARCH=${TARGETARCH}
RUN PACKAGE_ARCH=${PACKAGE_ARCH/amd64/x86_64} && PACKAGE_ARCH=${PACKAGE_ARCH/arm64/aarch64} && \
yum localinstall -y \
${PACKAGE_DIST}/${PACKAGE_ARCH}/libnvidia-container1-1.*.rpm \
${PACKAGE_DIST}/${PACKAGE_ARCH}/libnvidia-container-tools-1.*.rpm \
${PACKAGE_DIST}/${PACKAGE_ARCH}/nvidia-container-toolkit*-${PACKAGE_VERSION}*.rpm
WORKDIR /work
COPY --from=build /artifacts/nvidia-ctk-installer /work/nvidia-ctk-installer
RUN ln -s nvidia-ctk-installer nvidia-toolkit
ENV PATH=/work:$PATH
ARG VERSION
LABEL io.k8s.display-name="NVIDIA Container Runtime Config"
LABEL name="NVIDIA Container Runtime Config"
LABEL vendor="NVIDIA"
LABEL version="${VERSION}"
LABEL release="N/A"
LABEL summary="Automatically Configure your Container Runtime for GPU support."
LABEL description="See summary"
RUN mkdir /licenses && mv /NGC-DL-CONTAINER-LICENSE /licenses/NGC-DL-CONTAINER-LICENSE
ENTRYPOINT ["/work/nvidia-ctk-installer"]

View File

@@ -0,0 +1,98 @@
# Copyright (c) 2019-2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
ARG GOLANG_VERSION=x.x.x
ARG VERSION="N/A"
FROM nvcr.io/nvidia/cuda:12.9.0-base-ubuntu20.04 AS build
RUN apt-get update && \
apt-get install -y wget make git gcc \
&& \
rm -rf /var/lib/apt/lists/*
ARG GOLANG_VERSION=x.x.x
RUN set -eux; \
\
arch="$(uname -m)"; \
case "${arch##*-}" in \
x86_64 | amd64) ARCH='amd64' ;; \
ppc64el | ppc64le) ARCH='ppc64le' ;; \
aarch64 | arm64) ARCH='arm64' ;; \
*) echo "unsupported architecture" ; exit 1 ;; \
esac; \
wget -nv -O - https://storage.googleapis.com/golang/go${GOLANG_VERSION}.linux-${ARCH}.tar.gz \
| tar -C /usr/local -xz
ENV GOPATH=/go
ENV PATH=$GOPATH/bin:/usr/local/go/bin:$PATH
WORKDIR /build
COPY . .
RUN mkdir /artifacts
ARG VERSION="N/A"
ARG GIT_COMMIT="unknown"
RUN make PREFIX=/artifacts cmd-nvidia-ctk-installer
FROM nvcr.io/nvidia/cuda:12.9.0-base-ubuntu20.04
# Remove the CUDA repository configurations to avoid issues with rotated GPG keys
RUN rm -f /etc/apt/sources.list.d/cuda.list
ARG DEBIAN_FRONTEND=noninteractive
RUN apt-get update && apt-get install -y --no-install-recommends \
libcap2 \
curl \
&& \
rm -rf /var/lib/apt/lists/*
ENV NVIDIA_DISABLE_REQUIRE="true"
ENV NVIDIA_VISIBLE_DEVICES=void
ENV NVIDIA_DRIVER_CAPABILITIES=utility
ARG ARTIFACTS_ROOT
ARG PACKAGE_DIST
COPY ${ARTIFACTS_ROOT}/${PACKAGE_DIST} /artifacts/packages/${PACKAGE_DIST}
WORKDIR /artifacts/packages
ARG PACKAGE_VERSION
ARG TARGETARCH
ENV PACKAGE_ARCH=${TARGETARCH}
RUN dpkg -i \
${PACKAGE_DIST}/${PACKAGE_ARCH}/libnvidia-container1_1.*.deb \
${PACKAGE_DIST}/${PACKAGE_ARCH}/libnvidia-container-tools_1.*.deb \
${PACKAGE_DIST}/${PACKAGE_ARCH}/nvidia-container-toolkit*_${PACKAGE_VERSION}*.deb
WORKDIR /work
COPY --from=build /artifacts/nvidia-ctk-installer /work/nvidia-ctk-installer
RUN ln -s nvidia-ctk-installer nvidia-toolkit
ENV PATH=/work:$PATH
ARG VERSION
LABEL io.k8s.display-name="NVIDIA Container Runtime Config"
LABEL name="NVIDIA Container Runtime Config"
LABEL vendor="NVIDIA"
LABEL version="${VERSION}"
LABEL release="N/A"
LABEL summary="Automatically Configure your Container Runtime for GPU support."
LABEL description="See summary"
RUN mkdir /licenses && mv /NGC-DL-CONTAINER-LICENSE /licenses/NGC-DL-CONTAINER-LICENSE
ENTRYPOINT ["/work/nvidia-ctk-installer"]

View File

@@ -29,17 +29,17 @@ include $(CURDIR)/versions.mk
IMAGE_VERSION := $(VERSION) IMAGE_VERSION := $(VERSION)
IMAGE_TAG ?= $(VERSION) IMAGE_TAG ?= $(VERSION)-$(DIST)
IMAGE = $(IMAGE_NAME):$(IMAGE_TAG) IMAGE = $(IMAGE_NAME):$(IMAGE_TAG)
OUT_IMAGE_NAME ?= $(IMAGE_NAME) OUT_IMAGE_NAME ?= $(IMAGE_NAME)
OUT_IMAGE_VERSION ?= $(IMAGE_VERSION) OUT_IMAGE_VERSION ?= $(IMAGE_VERSION)
OUT_IMAGE_TAG = $(OUT_IMAGE_VERSION) OUT_IMAGE_TAG = $(OUT_IMAGE_VERSION)-$(DIST)
OUT_IMAGE = $(OUT_IMAGE_NAME):$(OUT_IMAGE_TAG) OUT_IMAGE = $(OUT_IMAGE_NAME):$(OUT_IMAGE_TAG)
##### Public rules ##### ##### Public rules #####
DEFAULT_PUSH_TARGET := ubi9 DEFAULT_PUSH_TARGET := ubuntu20.04
DISTRIBUTIONS := $(DEFAULT_PUSH_TARGET) DISTRIBUTIONS := ubuntu20.04 ubi8
META_TARGETS := packaging META_TARGETS := packaging
@@ -56,16 +56,30 @@ else
include $(CURDIR)/deployments/container/multi-arch.mk include $(CURDIR)/deployments/container/multi-arch.mk
endif endif
# For the default push target we also push a short tag equal to the version.
# We skip this for the development release
DEVEL_RELEASE_IMAGE_VERSION ?= devel
PUSH_MULTIPLE_TAGS ?= true
ifeq ($(strip $(OUT_IMAGE_VERSION)),$(DEVEL_RELEASE_IMAGE_VERSION))
PUSH_MULTIPLE_TAGS = false
endif
ifeq ($(PUSH_MULTIPLE_TAGS),true)
push-$(DEFAULT_PUSH_TARGET): push-short
endif
push-%: DIST = $(*)
push-short: DIST = $(DEFAULT_PUSH_TARGET)
# Define the push targets # Define the push targets
$(PUSH_TARGETS): push-%: $(PUSH_TARGETS): push-%:
$(CURDIR)/scripts/publish-image.sh $(IMAGE) $(OUT_IMAGE) $(CURDIR)/scripts/publish-image.sh $(IMAGE) $(OUT_IMAGE)
DOCKERFILE = $(CURDIR)/deployments/container/Dockerfile push-short:
$(CURDIR)/scripts/publish-image.sh $(IMAGE) $(OUT_IMAGE)
# For packaging targets we set the output image tag to include the -packaging suffix.
%-packaging: INTERMEDIATE_TARGET := --target=packaging build-%: DIST = $(*)
%-packaging: IMAGE_TAG = $(IMAGE_VERSION)-packaging build-%: DOCKERFILE = $(CURDIR)/deployments/container/Dockerfile.$(DOCKERFILE_SUFFIX)
%-packaging: OUT_IMAGE_TAG = $(IMAGE_VERSION)-packaging
ARTIFACTS_ROOT ?= $(shell realpath --relative-to=$(CURDIR) $(DIST_DIR)) ARTIFACTS_ROOT ?= $(shell realpath --relative-to=$(CURDIR) $(DIST_DIR))
@@ -76,12 +90,10 @@ $(IMAGE_TARGETS): image-%: $(ARTIFACTS_ROOT)
--provenance=false --sbom=false \ --provenance=false --sbom=false \
$(DOCKER_BUILD_OPTIONS) \ $(DOCKER_BUILD_OPTIONS) \
$(DOCKER_BUILD_PLATFORM_OPTIONS) \ $(DOCKER_BUILD_PLATFORM_OPTIONS) \
$(INTERMEDIATE_TARGET) \
--tag $(IMAGE) \ --tag $(IMAGE) \
--build-arg ARTIFACTS_ROOT="$(ARTIFACTS_ROOT)" \ --build-arg ARTIFACTS_ROOT="$(ARTIFACTS_ROOT)" \
--build-arg GOLANG_VERSION="$(GOLANG_VERSION)" \ --build-arg GOLANG_VERSION="$(GOLANG_VERSION)" \
--build-arg PACKAGE_DIST_DEB="$(PACKAGE_DIST_DEB)" \ --build-arg PACKAGE_DIST="$(PACKAGE_DIST)" \
--build-arg PACKAGE_DIST_RPM="$(PACKAGE_DIST_RPM)" \
--build-arg PACKAGE_VERSION="$(PACKAGE_VERSION)" \ --build-arg PACKAGE_VERSION="$(PACKAGE_VERSION)" \
--build-arg VERSION="$(VERSION)" \ --build-arg VERSION="$(VERSION)" \
--build-arg GIT_COMMIT="$(GIT_COMMIT)" \ --build-arg GIT_COMMIT="$(GIT_COMMIT)" \
@@ -91,19 +103,25 @@ $(IMAGE_TARGETS): image-%: $(ARTIFACTS_ROOT)
-f $(DOCKERFILE) \ -f $(DOCKERFILE) \
$(CURDIR) $(CURDIR)
build-ubuntu%: DOCKERFILE_SUFFIX := ubuntu
build-ubuntu%: PACKAGE_DIST = ubuntu18.04
PACKAGE_DIST_DEB = ubuntu18.04 build-ubi8: DOCKERFILE_SUFFIX := ubi8
# TODO: This needs to be set to centos8 for ppc64le builds build-ubi8: PACKAGE_DIST = centos7
PACKAGE_DIST_RPM = centos7
# Handle the default build target. build-packaging: DOCKERFILE_SUFFIX := packaging
.PHONY: build push build-packaging: PACKAGE_ARCH := amd64
build: build-$(DEFAULT_PUSH_TARGET) build-packaging: PACKAGE_DIST = all
push: push-$(DEFAULT_PUSH_TARGET)
# Test targets # Test targets
test-%: DIST = $(*) test-%: DIST = $(*)
# Handle the default build target.
.PHONY: build
build: $(DEFAULT_PUSH_TARGET)
$(DEFAULT_PUSH_TARGET): build-$(DEFAULT_PUSH_TARGET)
$(DEFAULT_PUSH_TARGET): DIST = $(DEFAULT_PUSH_TARGET)
TEST_CASES ?= docker crio containerd TEST_CASES ?= docker crio containerd
$(TEST_TARGETS): test-%: $(TEST_TARGETS): test-%:
TEST_CASES="$(TEST_CASES)" bash -x $(CURDIR)/test/container/main.sh run \ TEST_CASES="$(TEST_CASES)" bash -x $(CURDIR)/test/container/main.sh run \

View File

@@ -23,3 +23,11 @@ $(BUILD_TARGETS): build-%: image-%
else else
$(BUILD_TARGETS): build-%: image-% $(BUILD_TARGETS): build-%: image-%
endif endif
# For the default distribution we also retag the image.
# Note: This needs to be updated for multi-arch images.
ifeq ($(IMAGE_TAG),$(VERSION)-$(DIST))
$(DEFAULT_PUSH_TARGET):
$(DOCKER) image inspect $(IMAGE) > /dev/null || $(DOCKER) pull $(IMAGE)
$(DOCKER) tag $(IMAGE) $(subst :$(IMAGE_TAG),:$(VERSION),$(IMAGE))
endif

View File

@@ -14,7 +14,7 @@
# This Dockerfile is also used to define the golang version used in this project # This Dockerfile is also used to define the golang version used in this project
# This allows dependabot to manage this version in addition to other images. # This allows dependabot to manage this version in addition to other images.
FROM golang:1.24.4 FROM golang:1.24.3
WORKDIR /work WORKDIR /work
COPY * . COPY * .

View File

@@ -1,28 +0,0 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
[Unit]
Description=Refresh NVIDIA CDI specification file
ConditionPathExists=/usr/bin/nvidia-smi
ConditionPathExists=/usr/bin/nvidia-ctk
[Service]
Type=oneshot
EnvironmentFile=-/etc/nvidia-container-toolkit/cdi-refresh.env
ExecCondition=/usr/bin/grep -qE '/nvidia.ko' /lib/modules/%v/modules.dep
ExecStart=/usr/bin/nvidia-ctk cdi generate --output=/var/run/cdi/nvidia.yaml
CapabilityBoundingSet=CAP_SYS_MODULE CAP_SYS_ADMIN CAP_MKNOD
[Install]
WantedBy=multi-user.target

View File

@@ -55,7 +55,6 @@ RUN make PREFIX=${DIST_DIR} cmds
WORKDIR $DIST_DIR WORKDIR $DIST_DIR
COPY packaging/debian ./debian COPY packaging/debian ./debian
COPY deployments/systemd/ .
ARG LIBNVIDIA_CONTAINER_TOOLS_VERSION ARG LIBNVIDIA_CONTAINER_TOOLS_VERSION
ENV LIBNVIDIA_CONTAINER_TOOLS_VERSION ${LIBNVIDIA_CONTAINER_TOOLS_VERSION} ENV LIBNVIDIA_CONTAINER_TOOLS_VERSION ${LIBNVIDIA_CONTAINER_TOOLS_VERSION}

View File

@@ -46,7 +46,6 @@ RUN make PREFIX=${DIST_DIR} cmds
WORKDIR $DIST_DIR/.. WORKDIR $DIST_DIR/..
COPY packaging/rpm . COPY packaging/rpm .
COPY deployments/systemd/ .
ARG LIBNVIDIA_CONTAINER_TOOLS_VERSION ARG LIBNVIDIA_CONTAINER_TOOLS_VERSION
ENV LIBNVIDIA_CONTAINER_TOOLS_VERSION ${LIBNVIDIA_CONTAINER_TOOLS_VERSION} ENV LIBNVIDIA_CONTAINER_TOOLS_VERSION ${LIBNVIDIA_CONTAINER_TOOLS_VERSION}

View File

@@ -71,7 +71,6 @@ RUN make PREFIX=${DIST_DIR} cmds
WORKDIR $DIST_DIR/.. WORKDIR $DIST_DIR/..
COPY packaging/rpm . COPY packaging/rpm .
COPY deployments/systemd/ ${DIST_DIR}/
ARG LIBNVIDIA_CONTAINER_TOOLS_VERSION ARG LIBNVIDIA_CONTAINER_TOOLS_VERSION
ENV LIBNVIDIA_CONTAINER_TOOLS_VERSION ${LIBNVIDIA_CONTAINER_TOOLS_VERSION} ENV LIBNVIDIA_CONTAINER_TOOLS_VERSION ${LIBNVIDIA_CONTAINER_TOOLS_VERSION}

View File

@@ -53,7 +53,6 @@ RUN make PREFIX=${DIST_DIR} cmds
WORKDIR $DIST_DIR WORKDIR $DIST_DIR
COPY packaging/debian ./debian COPY packaging/debian ./debian
COPY deployments/systemd/ .
ARG LIBNVIDIA_CONTAINER_TOOLS_VERSION ARG LIBNVIDIA_CONTAINER_TOOLS_VERSION
ENV LIBNVIDIA_CONTAINER_TOOLS_VERSION ${LIBNVIDIA_CONTAINER_TOOLS_VERSION} ENV LIBNVIDIA_CONTAINER_TOOLS_VERSION ${LIBNVIDIA_CONTAINER_TOOLS_VERSION}

11
go.mod
View File

@@ -3,8 +3,8 @@ module github.com/NVIDIA/nvidia-container-toolkit
go 1.23.0 go 1.23.0
require ( require (
github.com/NVIDIA/go-nvlib v0.7.3 github.com/NVIDIA/go-nvlib v0.7.2
github.com/NVIDIA/go-nvml v0.12.9-0 github.com/NVIDIA/go-nvml v0.12.4-1
github.com/cyphar/filepath-securejoin v0.4.1 github.com/cyphar/filepath-securejoin v0.4.1
github.com/moby/sys/reexec v0.1.0 github.com/moby/sys/reexec v0.1.0
github.com/moby/sys/symlink v0.3.0 github.com/moby/sys/symlink v0.3.0
@@ -13,21 +13,20 @@ require (
github.com/pelletier/go-toml v1.9.5 github.com/pelletier/go-toml v1.9.5
github.com/sirupsen/logrus v1.9.3 github.com/sirupsen/logrus v1.9.3
github.com/stretchr/testify v1.10.0 github.com/stretchr/testify v1.10.0
github.com/urfave/cli/v2 v2.27.7 github.com/urfave/cli/v2 v2.27.6
golang.org/x/mod v0.25.0 golang.org/x/mod v0.24.0
golang.org/x/sys v0.33.0 golang.org/x/sys v0.33.0
tags.cncf.io/container-device-interface v1.0.1 tags.cncf.io/container-device-interface v1.0.1
tags.cncf.io/container-device-interface/specs-go v1.0.0 tags.cncf.io/container-device-interface/specs-go v1.0.0
) )
require ( require (
github.com/cpuguy83/go-md2man/v2 v2.0.7 // indirect github.com/cpuguy83/go-md2man/v2 v2.0.5 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect
github.com/fsnotify/fsnotify v1.7.0 // indirect github.com/fsnotify/fsnotify v1.7.0 // indirect
github.com/google/uuid v1.6.0 // indirect github.com/google/uuid v1.6.0 // indirect
github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect
github.com/kr/pretty v0.3.1 // indirect github.com/kr/pretty v0.3.1 // indirect
github.com/opencontainers/cgroups v0.0.1 // indirect
github.com/opencontainers/runtime-tools v0.9.1-0.20221107090550-2e043c6bd626 // indirect github.com/opencontainers/runtime-tools v0.9.1-0.20221107090550-2e043c6bd626 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rogpeppe/go-internal v1.11.0 // indirect github.com/rogpeppe/go-internal v1.11.0 // indirect

22
go.sum
View File

@@ -1,11 +1,11 @@
github.com/NVIDIA/go-nvlib v0.7.3 h1:kXc8PkWUlrwedSpM4fR8xT/DAq1NKy8HqhpgteFcGAw= github.com/NVIDIA/go-nvlib v0.7.2 h1:7sy/NVUa4sM9FLKwH6CjBfHSWrJUmv8emVyxLTzjfOA=
github.com/NVIDIA/go-nvlib v0.7.3/go.mod h1:i95Je7GinMy/+BDs++DAdbPmT2TubjNP8i8joC7DD7I= github.com/NVIDIA/go-nvlib v0.7.2/go.mod h1:2Kh2kYSP5IJ8EKf0/SYDzHiQKb9EJkwOf2LQzu6pXzY=
github.com/NVIDIA/go-nvml v0.12.9-0 h1:e344UK8ZkeMeeLkdQtRhmXRxNf+u532LDZPGMtkdus0= github.com/NVIDIA/go-nvml v0.12.4-1 h1:WKUvqshhWSNTfm47ETRhv0A0zJyr1ncCuHiXwoTrBEc=
github.com/NVIDIA/go-nvml v0.12.9-0/go.mod h1:+KNA7c7gIBH7SKSJ1ntlwkfN80zdx8ovl4hrK3LmPt4= github.com/NVIDIA/go-nvml v0.12.4-1/go.mod h1:8Llmj+1Rr+9VGGwZuRer5N/aCjxGuR5nPb/9ebBiIEQ=
github.com/blang/semver/v4 v4.0.0 h1:1PFHFE6yCCTv8C1TeyNNarDzntLi7wMI5i/pzqYIsAM= github.com/blang/semver/v4 v4.0.0 h1:1PFHFE6yCCTv8C1TeyNNarDzntLi7wMI5i/pzqYIsAM=
github.com/blang/semver/v4 v4.0.0/go.mod h1:IbckMUScFkM3pff0VJDNKRiT6TG/YpiHIM2yvyW5YoQ= github.com/blang/semver/v4 v4.0.0/go.mod h1:IbckMUScFkM3pff0VJDNKRiT6TG/YpiHIM2yvyW5YoQ=
github.com/cpuguy83/go-md2man/v2 v2.0.7 h1:zbFlGlXEAKlwXpmvle3d8Oe3YnkKIK4xSRTd3sHPnBo= github.com/cpuguy83/go-md2man/v2 v2.0.5 h1:ZtcqGrnekaHpVLArFSe4HK5DoKx1T0rq2DwVB0alcyc=
github.com/cpuguy83/go-md2man/v2 v2.0.7/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/cpuguy83/go-md2man/v2 v2.0.5/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/cyphar/filepath-securejoin v0.4.1 h1:JyxxyPEaktOD+GAnqIqTf9A8tHyAG22rowi7HkoSU1s= github.com/cyphar/filepath-securejoin v0.4.1 h1:JyxxyPEaktOD+GAnqIqTf9A8tHyAG22rowi7HkoSU1s=
github.com/cyphar/filepath-securejoin v0.4.1/go.mod h1:Sdj7gXlvMcPZsbhwhQ33GguGLDGQL7h7bg04C/+u9jI= github.com/cyphar/filepath-securejoin v0.4.1/go.mod h1:Sdj7gXlvMcPZsbhwhQ33GguGLDGQL7h7bg04C/+u9jI=
@@ -37,8 +37,6 @@ github.com/moby/sys/reexec v0.1.0/go.mod h1:EqjBg8F3X7iZe5pU6nRZnYCMUTXoxsjiIfHu
github.com/moby/sys/symlink v0.3.0 h1:GZX89mEZ9u53f97npBy4Rc3vJKj7JBDj/PN2I22GrNU= github.com/moby/sys/symlink v0.3.0 h1:GZX89mEZ9u53f97npBy4Rc3vJKj7JBDj/PN2I22GrNU=
github.com/moby/sys/symlink v0.3.0/go.mod h1:3eNdhduHmYPcgsJtZXW1W4XUJdZGBIkttZ8xKqPUJq0= github.com/moby/sys/symlink v0.3.0/go.mod h1:3eNdhduHmYPcgsJtZXW1W4XUJdZGBIkttZ8xKqPUJq0=
github.com/mrunalp/fileutils v0.5.0/go.mod h1:M1WthSahJixYnrXQl/DFQuteStB1weuxD2QJNHXfbSQ= github.com/mrunalp/fileutils v0.5.0/go.mod h1:M1WthSahJixYnrXQl/DFQuteStB1weuxD2QJNHXfbSQ=
github.com/opencontainers/cgroups v0.0.1 h1:MXjMkkFpKv6kpuirUa4USFBas573sSAY082B4CiHEVA=
github.com/opencontainers/cgroups v0.0.1/go.mod h1:s8lktyhlGUqM7OSRL5P7eAW6Wb+kWPNvt4qvVfzA5vs=
github.com/opencontainers/runc v1.3.0 h1:cvP7xbEvD0QQAs0nZKLzkVog2OPZhI/V2w3WmTmUSXI= github.com/opencontainers/runc v1.3.0 h1:cvP7xbEvD0QQAs0nZKLzkVog2OPZhI/V2w3WmTmUSXI=
github.com/opencontainers/runc v1.3.0/go.mod h1:9wbWt42gV+KRxKRVVugNP6D5+PQciRbenB4fLVsqGPs= github.com/opencontainers/runc v1.3.0/go.mod h1:9wbWt42gV+KRxKRVVugNP6D5+PQciRbenB4fLVsqGPs=
github.com/opencontainers/runtime-spec v1.0.3-0.20220825212826-86290f6a00fb/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0= github.com/opencontainers/runtime-spec v1.0.3-0.20220825212826-86290f6a00fb/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0=
@@ -71,8 +69,8 @@ github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf
github.com/syndtr/gocapability v0.0.0-20200815063812-42c35b437635 h1:kdXcSzyDtseVEc4yCz2qF8ZrQvIDBJLl4S1c3GCXmoI= github.com/syndtr/gocapability v0.0.0-20200815063812-42c35b437635 h1:kdXcSzyDtseVEc4yCz2qF8ZrQvIDBJLl4S1c3GCXmoI=
github.com/syndtr/gocapability v0.0.0-20200815063812-42c35b437635/go.mod h1:hkRG7XYTFWNJGYcbNJQlaLq0fg1yr4J4t/NcTQtrfww= github.com/syndtr/gocapability v0.0.0-20200815063812-42c35b437635/go.mod h1:hkRG7XYTFWNJGYcbNJQlaLq0fg1yr4J4t/NcTQtrfww=
github.com/urfave/cli v1.19.1/go.mod h1:70zkFmudgCuE/ngEzBv17Jvp/497gISqfk5gWijbERA= github.com/urfave/cli v1.19.1/go.mod h1:70zkFmudgCuE/ngEzBv17Jvp/497gISqfk5gWijbERA=
github.com/urfave/cli/v2 v2.27.7 h1:bH59vdhbjLv3LAvIu6gd0usJHgoTTPhCFib8qqOwXYU= github.com/urfave/cli/v2 v2.27.6 h1:VdRdS98FNhKZ8/Az8B7MTyGQmpIr36O1EHybx/LaZ4g=
github.com/urfave/cli/v2 v2.27.7/go.mod h1:CyNAG/xg+iAOg0N4MPGZqVmv2rCoP267496AOXUZjA4= github.com/urfave/cli/v2 v2.27.6/go.mod h1:3Sevf16NykTbInEnD0yKkjDAeZDS0A6bzhBH5hrMvTQ=
github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU= github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU=
github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb h1:zGWFAtiMcyryUHoUjUJX0/lt1H2+i2Ka2n+D3DImSNo= github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb h1:zGWFAtiMcyryUHoUjUJX0/lt1H2+i2Ka2n+D3DImSNo=
github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU= github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU=
@@ -82,8 +80,8 @@ github.com/xeipuuv/gojsonschema v1.2.0 h1:LhYJRs+L4fBtjZUfuSZIKGeVu0QRy8e5Xi7D17
github.com/xeipuuv/gojsonschema v1.2.0/go.mod h1:anYRn/JVcOK2ZgGU+IjEV4nwlhoK5sQluxsYJ78Id3Y= github.com/xeipuuv/gojsonschema v1.2.0/go.mod h1:anYRn/JVcOK2ZgGU+IjEV4nwlhoK5sQluxsYJ78Id3Y=
github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1 h1:gEOO8jv9F4OT7lGCjxCBTO/36wtF6j2nSip77qHd4x4= github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1 h1:gEOO8jv9F4OT7lGCjxCBTO/36wtF6j2nSip77qHd4x4=
github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1/go.mod h1:Ohn+xnUBiLI6FVj/9LpzZWtj1/D6lUovWYBkxHVV3aM= github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1/go.mod h1:Ohn+xnUBiLI6FVj/9LpzZWtj1/D6lUovWYBkxHVV3aM=
golang.org/x/mod v0.25.0 h1:n7a+ZbQKQA/Ysbyb0/6IbB1H/X41mKgbhfv7AfG/44w= golang.org/x/mod v0.24.0 h1:ZfthKaKaT4NrhGVZHO1/WDTwGES4De8KtWO0SIbNJMU=
golang.org/x/mod v0.25.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww= golang.org/x/mod v0.24.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww=
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191115151921-52ab43148777/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191115151921-52ab43148777/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=

View File

@@ -110,7 +110,7 @@ func GetDefault() (*Config, error) {
NVIDIAContainerRuntimeConfig: RuntimeConfig{ NVIDIAContainerRuntimeConfig: RuntimeConfig{
DebugFilePath: "/dev/null", DebugFilePath: "/dev/null",
LogLevel: "info", LogLevel: "info",
Runtimes: []string{"runc", "crun"}, Runtimes: []string{"docker-runc", "runc", "crun"},
Mode: "auto", Mode: "auto",
Modes: modesConfig{ Modes: modesConfig{
CSV: csvModeConfig{ CSV: csvModeConfig{

View File

@@ -63,7 +63,7 @@ func TestGetConfig(t *testing.T) {
NVIDIAContainerRuntimeConfig: RuntimeConfig{ NVIDIAContainerRuntimeConfig: RuntimeConfig{
DebugFilePath: "/dev/null", DebugFilePath: "/dev/null",
LogLevel: "info", LogLevel: "info",
Runtimes: []string{"runc", "crun"}, Runtimes: []string{"docker-runc", "runc", "crun"},
Mode: "auto", Mode: "auto",
Modes: modesConfig{ Modes: modesConfig{
CSV: csvModeConfig{ CSV: csvModeConfig{
@@ -170,7 +170,7 @@ func TestGetConfig(t *testing.T) {
NVIDIAContainerRuntimeConfig: RuntimeConfig{ NVIDIAContainerRuntimeConfig: RuntimeConfig{
DebugFilePath: "/dev/null", DebugFilePath: "/dev/null",
LogLevel: "info", LogLevel: "info",
Runtimes: []string{"runc", "crun"}, Runtimes: []string{"docker-runc", "runc", "crun"},
Mode: "auto", Mode: "auto",
Modes: modesConfig{ Modes: modesConfig{
CSV: csvModeConfig{ CSV: csvModeConfig{
@@ -289,7 +289,7 @@ func TestGetConfig(t *testing.T) {
NVIDIAContainerRuntimeConfig: RuntimeConfig{ NVIDIAContainerRuntimeConfig: RuntimeConfig{
DebugFilePath: "/dev/null", DebugFilePath: "/dev/null",
LogLevel: "info", LogLevel: "info",
Runtimes: []string{"runc", "crun"}, Runtimes: []string{"docker-runc", "runc", "crun"},
Mode: "auto", Mode: "auto",
Modes: modesConfig{ Modes: modesConfig{
CSV: csvModeConfig{ CSV: csvModeConfig{
@@ -331,7 +331,7 @@ func TestGetConfig(t *testing.T) {
NVIDIAContainerRuntimeConfig: RuntimeConfig{ NVIDIAContainerRuntimeConfig: RuntimeConfig{
DebugFilePath: "/dev/null", DebugFilePath: "/dev/null",
LogLevel: "info", LogLevel: "info",
Runtimes: []string{"runc", "crun"}, Runtimes: []string{"docker-runc", "runc", "crun"},
Mode: "auto", Mode: "auto",
Modes: modesConfig{ Modes: modesConfig{
CSV: csvModeConfig{ CSV: csvModeConfig{

View File

@@ -21,35 +21,22 @@ import (
"strings" "strings"
"github.com/opencontainers/runtime-spec/specs-go" "github.com/opencontainers/runtime-spec/specs-go"
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
) )
type builder struct { type builder struct {
CUDA env map[string]string
mounts []specs.Mount
disableRequire bool disableRequire bool
} }
// Option is a functional option for creating a CUDA image.
type Option func(*builder) error
// New creates a new CUDA image from the input options. // New creates a new CUDA image from the input options.
func New(opt ...Option) (CUDA, error) { func New(opt ...Option) (CUDA, error) {
b := &builder{ b := &builder{}
CUDA: CUDA{
acceptEnvvarUnprivileged: true,
},
}
for _, o := range opt { for _, o := range opt {
if err := o(b); err != nil { if err := o(b); err != nil {
return CUDA{}, err return CUDA{}, err
} }
} }
if b.logger == nil {
b.logger = logger.New()
}
if b.env == nil { if b.env == nil {
b.env = make(map[string]string) b.env = make(map[string]string)
} }
@@ -63,36 +50,15 @@ func (b builder) build() (CUDA, error) {
b.env[EnvVarNvidiaDisableRequire] = "true" b.env[EnvVarNvidiaDisableRequire] = "true"
} }
return b.CUDA, nil c := CUDA{
env: b.env,
mounts: b.mounts,
}
return c, nil
} }
func WithAcceptDeviceListAsVolumeMounts(acceptDeviceListAsVolumeMounts bool) Option { // Option is a functional option for creating a CUDA image.
return func(b *builder) error { type Option func(*builder) error
b.acceptDeviceListAsVolumeMounts = acceptDeviceListAsVolumeMounts
return nil
}
}
func WithAcceptEnvvarUnprivileged(acceptEnvvarUnprivileged bool) Option {
return func(b *builder) error {
b.acceptEnvvarUnprivileged = acceptEnvvarUnprivileged
return nil
}
}
func WithAnnotations(annotations map[string]string) Option {
return func(b *builder) error {
b.annotations = annotations
return nil
}
}
func WithAnnotationsPrefixes(annotationsPrefixes []string) Option {
return func(b *builder) error {
b.annotationsPrefixes = annotationsPrefixes
return nil
}
}
// WithDisableRequire sets the disable require option. // WithDisableRequire sets the disable require option.
func WithDisableRequire(disableRequire bool) Option { func WithDisableRequire(disableRequire bool) Option {
@@ -127,14 +93,6 @@ func WithEnvMap(env map[string]string) Option {
} }
} }
// WithLogger sets the logger to use when creating the CUDA image.
func WithLogger(logger logger.Interface) Option {
return func(b *builder) error {
b.logger = logger
return nil
}
}
// WithMounts sets the mounts associated with the CUDA image. // WithMounts sets the mounts associated with the CUDA image.
func WithMounts(mounts []specs.Mount) Option { func WithMounts(mounts []specs.Mount) Option {
return func(b *builder) error { return func(b *builder) error {
@@ -142,20 +100,3 @@ func WithMounts(mounts []specs.Mount) Option {
return nil return nil
} }
} }
// WithPreferredVisibleDevicesEnvVars sets the environment variables that
// should take precedence over the default NVIDIA_VISIBLE_DEVICES.
func WithPreferredVisibleDevicesEnvVars(preferredVisibleDeviceEnvVars ...string) Option {
return func(b *builder) error {
b.preferredVisibleDeviceEnvVars = preferredVisibleDeviceEnvVars
return nil
}
}
// WithPrivileged sets whether an image is privileged or not.
func WithPrivileged(isPrivileged bool) Option {
return func(b *builder) error {
b.isPrivileged = isPrivileged
return nil
}
}

View File

@@ -19,15 +19,12 @@ package image
import ( import (
"fmt" "fmt"
"path/filepath" "path/filepath"
"slices"
"strconv" "strconv"
"strings" "strings"
"github.com/opencontainers/runtime-spec/specs-go" "github.com/opencontainers/runtime-spec/specs-go"
"golang.org/x/mod/semver" "golang.org/x/mod/semver"
"tags.cncf.io/container-device-interface/pkg/parser" "tags.cncf.io/container-device-interface/pkg/parser"
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
) )
const ( const (
@@ -41,44 +38,27 @@ const (
// a map of environment variable to values that can be used to perform lookups // a map of environment variable to values that can be used to perform lookups
// such as requirements. // such as requirements.
type CUDA struct { type CUDA struct {
logger logger.Interface env map[string]string
mounts []specs.Mount
annotations map[string]string
env map[string]string
isPrivileged bool
mounts []specs.Mount
annotationsPrefixes []string
acceptDeviceListAsVolumeMounts bool
acceptEnvvarUnprivileged bool
preferredVisibleDeviceEnvVars []string
} }
// NewCUDAImageFromSpec creates a CUDA image from the input OCI runtime spec. // NewCUDAImageFromSpec creates a CUDA image from the input OCI runtime spec.
// The process environment is read (if present) to construct the CUDA Image. // The process environment is read (if present) to construct the CUDA Image.
func NewCUDAImageFromSpec(spec *specs.Spec, opts ...Option) (CUDA, error) { func NewCUDAImageFromSpec(spec *specs.Spec) (CUDA, error) {
if spec == nil {
return New(opts...)
}
var env []string var env []string
if spec.Process != nil { if spec != nil && spec.Process != nil {
env = spec.Process.Env env = spec.Process.Env
} }
specOpts := []Option{ return New(
WithAnnotations(spec.Annotations),
WithEnv(env), WithEnv(env),
WithMounts(spec.Mounts), WithMounts(spec.Mounts),
WithPrivileged(IsPrivileged((*OCISpec)(spec))), )
}
return New(append(opts, specOpts...)...)
} }
// newCUDAImageFromEnv creates a CUDA image from the input environment. The environment // NewCUDAImageFromEnv creates a CUDA image from the input environment. The environment
// is a list of strings of the form ENVAR=VALUE. // is a list of strings of the form ENVAR=VALUE.
func newCUDAImageFromEnv(env []string) (CUDA, error) { func NewCUDAImageFromEnv(env []string) (CUDA, error) {
return New(WithEnv(env)) return New(WithEnv(env))
} }
@@ -103,10 +83,6 @@ func (i CUDA) IsLegacy() bool {
return len(legacyCudaVersion) > 0 && len(cudaRequire) == 0 return len(legacyCudaVersion) > 0 && len(cudaRequire) == 0
} }
func (i CUDA) IsPrivileged() bool {
return i.isPrivileged
}
// GetRequirements returns the requirements from all NVIDIA_REQUIRE_ environment // GetRequirements returns the requirements from all NVIDIA_REQUIRE_ environment
// variables. // variables.
func (i CUDA) GetRequirements() ([]string, error) { func (i CUDA) GetRequirements() ([]string, error) {
@@ -144,8 +120,8 @@ func (i CUDA) HasDisableRequire() bool {
return false return false
} }
// devicesFromEnvvars returns the devices requested by the image through environment variables // DevicesFromEnvvars returns the devices requested by the image through environment variables
func (i CUDA) devicesFromEnvvars(envVars ...string) []string { func (i CUDA) DevicesFromEnvvars(envVars ...string) VisibleDevices {
// We concatenate all the devices from the specified env. // We concatenate all the devices from the specified env.
var isSet bool var isSet bool
var devices []string var devices []string
@@ -166,15 +142,15 @@ func (i CUDA) devicesFromEnvvars(envVars ...string) []string {
// Environment variable unset with legacy image: default to "all". // Environment variable unset with legacy image: default to "all".
if !isSet && len(devices) == 0 && i.IsLegacy() { if !isSet && len(devices) == 0 && i.IsLegacy() {
devices = []string{"all"} return NewVisibleDevices("all")
} }
// Environment variable unset or empty or "void": return nil // Environment variable unset or empty or "void": return nil
if len(devices) == 0 || requested["void"] { if len(devices) == 0 || requested["void"] {
devices = []string{"void"} return NewVisibleDevices("void")
} }
return NewVisibleDevices(devices...).List() return NewVisibleDevices(devices...)
} }
// GetDriverCapabilities returns the requested driver capabilities. // GetDriverCapabilities returns the requested driver capabilities.
@@ -224,137 +200,46 @@ func parseMajorMinorVersion(version string) (string, error) {
// OnlyFullyQualifiedCDIDevices returns true if all devices requested in the image are requested as CDI devices. // OnlyFullyQualifiedCDIDevices returns true if all devices requested in the image are requested as CDI devices.
func (i CUDA) OnlyFullyQualifiedCDIDevices() bool { func (i CUDA) OnlyFullyQualifiedCDIDevices() bool {
var hasCDIdevice bool var hasCDIdevice bool
for _, device := range i.VisibleDevices() { for _, device := range i.VisibleDevicesFromEnvVar() {
if !parser.IsQualifiedName(device) { if !parser.IsQualifiedName(device) {
return false return false
} }
hasCDIdevice = true hasCDIdevice = true
} }
for _, device := range i.DevicesFromMounts() {
if !strings.HasPrefix(device, "cdi/") {
return false
}
hasCDIdevice = true
}
return hasCDIdevice return hasCDIdevice
} }
// visibleEnvVars returns the environment variables that are used to determine device visibility. // VisibleDevicesFromEnvVar returns the set of visible devices requested through
// It returns the preferred environment variables that are set, or NVIDIA_VISIBLE_DEVICES if none are set. // the NVIDIA_VISIBLE_DEVICES environment variable.
func (i CUDA) visibleEnvVars() []string { func (i CUDA) VisibleDevicesFromEnvVar() []string {
var envVars []string return i.DevicesFromEnvvars(EnvVarNvidiaVisibleDevices).List()
for _, envVar := range i.preferredVisibleDeviceEnvVars {
if !i.HasEnvvar(envVar) {
continue
}
envVars = append(envVars, envVar)
}
if len(envVars) > 0 {
return envVars
}
return []string{EnvVarNvidiaVisibleDevices}
} }
// VisibleDevices returns a list of devices requested in the container image. // VisibleDevicesFromMounts returns the set of visible devices requested as mounts.
// If volume mount requests are enabled these are returned if requested, func (i CUDA) VisibleDevicesFromMounts() []string {
// otherwise device requests through environment variables are considered.
// In cases where environment variable requests required privileged containers,
// such devices requests are ignored.
func (i CUDA) VisibleDevices() []string {
// If annotation device requests are present, these are preferred.
annotationDeviceRequests := i.cdiDeviceRequestsFromAnnotations()
if len(annotationDeviceRequests) > 0 {
return annotationDeviceRequests
}
// If enabled, try and get the device list from volume mounts first
if i.acceptDeviceListAsVolumeMounts {
volumeMountDeviceRequests := i.visibleDevicesFromMounts()
if len(volumeMountDeviceRequests) > 0 {
return volumeMountDeviceRequests
}
}
// Fall back to reading from the environment variable if privileges are correct
envVarDeviceRequests := i.visibleDevicesFromEnvVar()
if len(envVarDeviceRequests) == 0 {
return nil
}
// If the container is privileged, or environment variable requests are
// allowed for unprivileged containers, these devices are returned.
if i.isPrivileged || i.acceptEnvvarUnprivileged {
return envVarDeviceRequests
}
// We log a warning if we are ignoring the environment variable requests.
envVars := i.visibleEnvVars()
if len(envVars) > 0 {
i.logger.Warningf("Ignoring devices requested by environment variable(s) in unprivileged container: %v", envVars)
}
return nil
}
// cdiDeviceRequestsFromAnnotations returns a list of devices specified in the
// annotations.
// Keys starting with the specified prefixes are considered and expected to
// contain a comma-separated list of fully-qualified CDI devices names.
// The format of the requested devices is not checked and the list is not
// deduplicated.
func (i CUDA) cdiDeviceRequestsFromAnnotations() []string {
if len(i.annotationsPrefixes) == 0 || len(i.annotations) == 0 {
return nil
}
var annotationKeys []string
for key := range i.annotations {
for _, prefix := range i.annotationsPrefixes {
if strings.HasPrefix(key, prefix) {
annotationKeys = append(annotationKeys, key)
// There is no need to check additional prefixes since we
// typically deduplicate devices in any case.
break
}
}
}
// We sort the annotationKeys for consistent results.
slices.Sort(annotationKeys)
var devices []string var devices []string
for _, key := range annotationKeys { for _, device := range i.DevicesFromMounts() {
devices = append(devices, strings.Split(i.annotations[key], ",")...)
}
return devices
}
// visibleDevicesFromEnvVar returns the set of visible devices requested through environment variables.
// If any of the preferredVisibleDeviceEnvVars are present in the image, they
// are used to determine the visible devices. If this is not the case, the
// NVIDIA_VISIBLE_DEVICES environment variable is used.
func (i CUDA) visibleDevicesFromEnvVar() []string {
envVars := i.visibleEnvVars()
return i.devicesFromEnvvars(envVars...)
}
// visibleDevicesFromMounts returns the set of visible devices requested as mounts.
func (i CUDA) visibleDevicesFromMounts() []string {
var devices []string
for _, device := range i.requestsFromMounts() {
switch { switch {
case strings.HasPrefix(device, volumeMountDevicePrefixCDI):
continue
case strings.HasPrefix(device, volumeMountDevicePrefixImex): case strings.HasPrefix(device, volumeMountDevicePrefixImex):
continue continue
case strings.HasPrefix(device, volumeMountDevicePrefixCDI):
name, err := cdiDeviceMountRequest(device).qualifiedName()
if err != nil {
i.logger.Warningf("Ignoring invalid mount request for CDI device %v: %v", device, err)
continue
}
devices = append(devices, name)
default:
devices = append(devices, device)
} }
devices = append(devices, device)
} }
return devices return devices
} }
// requestsFromMounts returns a list of devices specified as mounts. // DevicesFromMounts returns a list of devices specified as mounts.
func (i CUDA) requestsFromMounts() []string { // TODO: This should be merged with getDevicesFromMounts used in the NVIDIA Container Runtime
func (i CUDA) DevicesFromMounts() []string {
root := filepath.Clean(DeviceListAsVolumeMountsRoot) root := filepath.Clean(DeviceListAsVolumeMountsRoot)
seen := make(map[string]bool) seen := make(map[string]bool)
var devices []string var devices []string
@@ -386,35 +271,28 @@ func (i CUDA) requestsFromMounts() []string {
return devices return devices
} }
// a cdiDeviceMountRequest represents a CDI device requests as a mount. // CDIDevicesFromMounts returns a list of CDI devices specified as mounts on the image.
// Here the host path /dev/null is mounted to a particular path in the container. func (i CUDA) CDIDevicesFromMounts() []string {
// The container path has the form: var devices []string
// /var/run/nvidia-container-devices/cdi/<vendor>/<class>/<device> for _, mountDevice := range i.DevicesFromMounts() {
// or if !strings.HasPrefix(mountDevice, volumeMountDevicePrefixCDI) {
// /var/run/nvidia-container-devices/cdi/<vendor>/<class>=<device> continue
type cdiDeviceMountRequest string }
parts := strings.SplitN(strings.TrimPrefix(mountDevice, volumeMountDevicePrefixCDI), "/", 3)
// qualifiedName returns the fully-qualified name of the CDI device. if len(parts) != 3 {
func (m cdiDeviceMountRequest) qualifiedName() (string, error) { continue
if !strings.HasPrefix(string(m), volumeMountDevicePrefixCDI) { }
return "", fmt.Errorf("invalid mount CDI device request: %s", m) vendor := parts[0]
class := parts[1]
device := parts[2]
devices = append(devices, fmt.Sprintf("%s/%s=%s", vendor, class, device))
} }
return devices
requestedDevice := strings.TrimPrefix(string(m), volumeMountDevicePrefixCDI)
if parser.IsQualifiedName(requestedDevice) {
return requestedDevice, nil
}
parts := strings.SplitN(requestedDevice, "/", 3)
if len(parts) != 3 {
return "", fmt.Errorf("invalid mount CDI device request: %s", m)
}
return fmt.Sprintf("%s/%s=%s", parts[0], parts[1], parts[2]), nil
} }
// ImexChannelsFromEnvVar returns the list of IMEX channels requested for the image. // ImexChannelsFromEnvVar returns the list of IMEX channels requested for the image.
func (i CUDA) ImexChannelsFromEnvVar() []string { func (i CUDA) ImexChannelsFromEnvVar() []string {
imexChannels := i.devicesFromEnvvars(EnvVarNvidiaImexChannels) imexChannels := i.DevicesFromEnvvars(EnvVarNvidiaImexChannels).List()
if len(imexChannels) == 1 && imexChannels[0] == "all" { if len(imexChannels) == 1 && imexChannels[0] == "all" {
return nil return nil
} }
@@ -424,7 +302,7 @@ func (i CUDA) ImexChannelsFromEnvVar() []string {
// ImexChannelsFromMounts returns the list of IMEX channels requested for the image. // ImexChannelsFromMounts returns the list of IMEX channels requested for the image.
func (i CUDA) ImexChannelsFromMounts() []string { func (i CUDA) ImexChannelsFromMounts() []string {
var channels []string var channels []string
for _, mountDevice := range i.requestsFromMounts() { for _, mountDevice := range i.DevicesFromMounts() {
if !strings.HasPrefix(mountDevice, volumeMountDevicePrefixImex) { if !strings.HasPrefix(mountDevice, volumeMountDevicePrefixImex) {
continue continue
} }

View File

@@ -21,91 +21,9 @@ import (
"testing" "testing"
"github.com/opencontainers/runtime-spec/specs-go" "github.com/opencontainers/runtime-spec/specs-go"
testlog "github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestNewCUDAImageFromSpec(t *testing.T) {
logger, _ := testlog.NewNullLogger()
testCases := []struct {
description string
spec *specs.Spec
options []Option
expected CUDA
}{
{
description: "no env vars",
spec: &specs.Spec{
Process: &specs.Process{
Env: []string{},
},
},
expected: CUDA{
logger: logger,
env: map[string]string{},
acceptEnvvarUnprivileged: true,
},
},
{
description: "NVIDIA_VISIBLE_DEVICES=all",
spec: &specs.Spec{
Process: &specs.Process{
Env: []string{"NVIDIA_VISIBLE_DEVICES=all"},
},
},
expected: CUDA{
logger: logger,
env: map[string]string{"NVIDIA_VISIBLE_DEVICES": "all"},
acceptEnvvarUnprivileged: true,
},
},
{
description: "Spec overrides options",
spec: &specs.Spec{
Process: &specs.Process{
Env: []string{"NVIDIA_VISIBLE_DEVICES=all"},
},
Mounts: []specs.Mount{
{
Source: "/spec-source",
Destination: "/spec-destination",
},
},
},
options: []Option{
WithEnvMap(map[string]string{"OTHER": "value"}),
WithMounts([]specs.Mount{
{
Source: "/option-source",
Destination: "/option-destination",
},
}),
},
expected: CUDA{
logger: logger,
env: map[string]string{"NVIDIA_VISIBLE_DEVICES": "all"},
mounts: []specs.Mount{
{
Source: "/spec-source",
Destination: "/spec-destination",
},
},
acceptEnvvarUnprivileged: true,
},
},
}
for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) {
options := append([]Option{WithLogger(logger)}, tc.options...)
image, err := NewCUDAImageFromSpec(tc.spec, options...)
require.NoError(t, err)
require.EqualValues(t, tc.expected, image)
})
}
}
func TestParseMajorMinorVersionValid(t *testing.T) { func TestParseMajorMinorVersionValid(t *testing.T) {
var tests = []struct { var tests = []struct {
version string version string
@@ -204,7 +122,7 @@ func TestGetRequirements(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) { t.Run(tc.description, func(t *testing.T) {
image, err := newCUDAImageFromEnv(tc.env) image, err := NewCUDAImageFromEnv(tc.env)
require.NoError(t, err) require.NoError(t, err)
requirements, err := image.GetRequirements() requirements, err := image.GetRequirements()
@@ -215,226 +133,6 @@ func TestGetRequirements(t *testing.T) {
} }
} }
func TestGetDevicesFromEnvvar(t *testing.T) {
envDockerResourceGPUs := "DOCKER_RESOURCE_GPUS"
gpuID := "GPU-12345"
anotherGPUID := "GPU-67890"
thirdGPUID := "MIG-12345"
var tests = []struct {
description string
preferredVisibleDeviceEnvVars []string
env map[string]string
expectedDevices []string
}{
{
description: "empty env returns nil for non-legacy image",
},
{
description: "blank NVIDIA_VISIBLE_DEVICES returns nil for non-legacy image",
env: map[string]string{
EnvVarNvidiaVisibleDevices: "",
},
},
{
description: "'void' NVIDIA_VISIBLE_DEVICES returns nil for non-legacy image",
env: map[string]string{
EnvVarNvidiaVisibleDevices: "void",
},
},
{
description: "'none' NVIDIA_VISIBLE_DEVICES returns empty for non-legacy image",
env: map[string]string{
EnvVarNvidiaVisibleDevices: "none",
},
expectedDevices: []string{""},
},
{
description: "NVIDIA_VISIBLE_DEVICES set returns value for non-legacy image",
env: map[string]string{
EnvVarNvidiaVisibleDevices: gpuID,
},
expectedDevices: []string{gpuID},
},
{
description: "NVIDIA_VISIBLE_DEVICES set returns value for legacy image",
env: map[string]string{
EnvVarNvidiaVisibleDevices: gpuID,
EnvVarCudaVersion: "legacy",
},
expectedDevices: []string{gpuID},
},
{
description: "empty env returns all for legacy image",
env: map[string]string{
EnvVarCudaVersion: "legacy",
},
expectedDevices: []string{"all"},
},
// Add the `DOCKER_RESOURCE_GPUS` envvar and ensure that this is ignored when
// not enabled
{
description: "missing NVIDIA_VISIBLE_DEVICES returns nil for non-legacy image",
env: map[string]string{
envDockerResourceGPUs: anotherGPUID,
},
},
{
description: "blank NVIDIA_VISIBLE_DEVICES returns nil for non-legacy image",
env: map[string]string{
EnvVarNvidiaVisibleDevices: "",
envDockerResourceGPUs: anotherGPUID,
},
},
{
description: "'void' NVIDIA_VISIBLE_DEVICES returns nil for non-legacy image",
env: map[string]string{
EnvVarNvidiaVisibleDevices: "void",
envDockerResourceGPUs: anotherGPUID,
},
},
{
description: "'none' NVIDIA_VISIBLE_DEVICES returns empty for non-legacy image",
env: map[string]string{
EnvVarNvidiaVisibleDevices: "none",
envDockerResourceGPUs: anotherGPUID,
},
expectedDevices: []string{""},
},
{
description: "NVIDIA_VISIBLE_DEVICES set returns value for non-legacy image",
env: map[string]string{
EnvVarNvidiaVisibleDevices: gpuID,
envDockerResourceGPUs: anotherGPUID,
},
expectedDevices: []string{gpuID},
},
{
description: "NVIDIA_VISIBLE_DEVICES set returns value for legacy image",
env: map[string]string{
EnvVarNvidiaVisibleDevices: gpuID,
envDockerResourceGPUs: anotherGPUID,
EnvVarCudaVersion: "legacy",
},
expectedDevices: []string{gpuID},
},
{
description: "empty env returns all for legacy image",
env: map[string]string{
envDockerResourceGPUs: anotherGPUID,
EnvVarCudaVersion: "legacy",
},
expectedDevices: []string{"all"},
},
// Add the `DOCKER_RESOURCE_GPUS` envvar and ensure that this is selected when
// enabled
{
description: "empty env returns nil for non-legacy image",
preferredVisibleDeviceEnvVars: []string{envDockerResourceGPUs},
},
{
description: "blank DOCKER_RESOURCE_GPUS returns nil for non-legacy image",
preferredVisibleDeviceEnvVars: []string{envDockerResourceGPUs},
env: map[string]string{
envDockerResourceGPUs: "",
},
},
{
description: "'void' DOCKER_RESOURCE_GPUS returns nil for non-legacy image",
preferredVisibleDeviceEnvVars: []string{envDockerResourceGPUs},
env: map[string]string{
envDockerResourceGPUs: "void",
},
},
{
description: "'none' DOCKER_RESOURCE_GPUS returns empty for non-legacy image",
preferredVisibleDeviceEnvVars: []string{envDockerResourceGPUs},
env: map[string]string{
envDockerResourceGPUs: "none",
},
expectedDevices: []string{""},
},
{
description: "DOCKER_RESOURCE_GPUS set returns value for non-legacy image",
preferredVisibleDeviceEnvVars: []string{envDockerResourceGPUs},
env: map[string]string{
envDockerResourceGPUs: gpuID,
},
expectedDevices: []string{gpuID},
},
{
description: "DOCKER_RESOURCE_GPUS set returns value for legacy image",
preferredVisibleDeviceEnvVars: []string{envDockerResourceGPUs},
env: map[string]string{
envDockerResourceGPUs: gpuID,
EnvVarCudaVersion: "legacy",
},
expectedDevices: []string{gpuID},
},
{
description: "DOCKER_RESOURCE_GPUS is selected if present",
preferredVisibleDeviceEnvVars: []string{envDockerResourceGPUs},
env: map[string]string{
envDockerResourceGPUs: anotherGPUID,
},
expectedDevices: []string{anotherGPUID},
},
{
description: "DOCKER_RESOURCE_GPUS overrides NVIDIA_VISIBLE_DEVICES if present",
preferredVisibleDeviceEnvVars: []string{envDockerResourceGPUs},
env: map[string]string{
EnvVarNvidiaVisibleDevices: gpuID,
envDockerResourceGPUs: anotherGPUID,
},
expectedDevices: []string{anotherGPUID},
},
{
description: "DOCKER_RESOURCE_GPUS_ADDITIONAL overrides NVIDIA_VISIBLE_DEVICES if present",
preferredVisibleDeviceEnvVars: []string{"DOCKER_RESOURCE_GPUS_ADDITIONAL"},
env: map[string]string{
EnvVarNvidiaVisibleDevices: gpuID,
"DOCKER_RESOURCE_GPUS_ADDITIONAL": anotherGPUID,
},
expectedDevices: []string{anotherGPUID},
},
{
description: "All available swarm resource envvars are selected and override NVIDIA_VISIBLE_DEVICES if present",
preferredVisibleDeviceEnvVars: []string{"DOCKER_RESOURCE_GPUS", "DOCKER_RESOURCE_GPUS_ADDITIONAL"},
env: map[string]string{
EnvVarNvidiaVisibleDevices: gpuID,
"DOCKER_RESOURCE_GPUS": thirdGPUID,
"DOCKER_RESOURCE_GPUS_ADDITIONAL": anotherGPUID,
},
expectedDevices: []string{thirdGPUID, anotherGPUID},
},
{
description: "DOCKER_RESOURCE_GPUS_ADDITIONAL or DOCKER_RESOURCE_GPUS override NVIDIA_VISIBLE_DEVICES if present",
preferredVisibleDeviceEnvVars: []string{"DOCKER_RESOURCE_GPUS", "DOCKER_RESOURCE_GPUS_ADDITIONAL"},
env: map[string]string{
EnvVarNvidiaVisibleDevices: gpuID,
"DOCKER_RESOURCE_GPUS_ADDITIONAL": anotherGPUID,
},
expectedDevices: []string{anotherGPUID},
},
}
for _, tc := range tests {
t.Run(tc.description, func(t *testing.T) {
image, err := New(
WithEnvMap(tc.env),
WithPrivileged(true),
WithAcceptDeviceListAsVolumeMounts(false),
WithAcceptEnvvarUnprivileged(false),
WithPreferredVisibleDevicesEnvVars(tc.preferredVisibleDeviceEnvVars...),
)
require.NoError(t, err)
devices := image.visibleDevicesFromEnvVar()
require.EqualValues(t, tc.expectedDevices, devices)
})
}
}
func TestGetVisibleDevicesFromMounts(t *testing.T) { func TestGetVisibleDevicesFromMounts(t *testing.T) {
var tests = []struct { var tests = []struct {
description string description string
@@ -487,9 +185,9 @@ func TestGetVisibleDevicesFromMounts(t *testing.T) {
expectedDevices: []string{"GPU0-MIG0/0/1", "GPU1-MIG0/0/1"}, expectedDevices: []string{"GPU0-MIG0/0/1", "GPU1-MIG0/0/1"},
}, },
{ {
description: "cdi devices are included", description: "cdi devices are ignored",
mounts: makeTestMounts("GPU0", "nvidia.com/gpu=all", "GPU1"), mounts: makeTestMounts("GPU0", "cdi/nvidia.com/gpu=all", "GPU1"),
expectedDevices: []string{"GPU0", "nvidia.com/gpu=all", "GPU1"}, expectedDevices: []string{"GPU0", "GPU1"},
}, },
{ {
description: "imex devices are ignored", description: "imex devices are ignored",
@@ -499,195 +197,8 @@ func TestGetVisibleDevicesFromMounts(t *testing.T) {
} }
for _, tc := range tests { for _, tc := range tests {
t.Run(tc.description, func(t *testing.T) { t.Run(tc.description, func(t *testing.T) {
image, err := New(WithMounts(tc.mounts)) image, _ := New(WithMounts(tc.mounts))
require.NoError(t, err) require.Equal(t, tc.expectedDevices, image.VisibleDevicesFromMounts())
require.Equal(t, tc.expectedDevices, image.visibleDevicesFromMounts())
})
}
}
func TestVisibleDevices(t *testing.T) {
var tests = []struct {
description string
mountDevices []specs.Mount
envvarDevices string
privileged bool
acceptUnprivileged bool
acceptMounts bool
preferredVisibleDeviceEnvVars []string
env map[string]string
expectedDevices []string
}{
{
description: "Mount devices, unprivileged, no accept unprivileged",
mountDevices: []specs.Mount{
{
Source: "/dev/null",
Destination: filepath.Join(DeviceListAsVolumeMountsRoot, "GPU0"),
},
{
Source: "/dev/null",
Destination: filepath.Join(DeviceListAsVolumeMountsRoot, "GPU1"),
},
},
envvarDevices: "GPU2,GPU3",
privileged: false,
acceptUnprivileged: false,
acceptMounts: true,
expectedDevices: []string{"GPU0", "GPU1"},
},
{
description: "No mount devices, unprivileged, no accept unprivileged",
mountDevices: nil,
envvarDevices: "GPU0,GPU1",
privileged: false,
acceptUnprivileged: false,
acceptMounts: true,
expectedDevices: nil,
},
{
description: "No mount devices, privileged, no accept unprivileged",
mountDevices: nil,
envvarDevices: "GPU0,GPU1",
privileged: true,
acceptUnprivileged: false,
acceptMounts: true,
expectedDevices: []string{"GPU0", "GPU1"},
},
{
description: "No mount devices, unprivileged, accept unprivileged",
mountDevices: nil,
envvarDevices: "GPU0,GPU1",
privileged: false,
acceptUnprivileged: true,
acceptMounts: true,
expectedDevices: []string{"GPU0", "GPU1"},
},
{
description: "Mount devices, unprivileged, accept unprivileged, no accept mounts",
mountDevices: []specs.Mount{
{
Source: "/dev/null",
Destination: filepath.Join(DeviceListAsVolumeMountsRoot, "GPU0"),
},
{
Source: "/dev/null",
Destination: filepath.Join(DeviceListAsVolumeMountsRoot, "GPU1"),
},
},
envvarDevices: "GPU2,GPU3",
privileged: false,
acceptUnprivileged: true,
acceptMounts: false,
expectedDevices: []string{"GPU2", "GPU3"},
},
{
description: "Mount devices, unprivileged, no accept unprivileged, no accept mounts",
mountDevices: []specs.Mount{
{
Source: "/dev/null",
Destination: filepath.Join(DeviceListAsVolumeMountsRoot, "GPU0"),
},
{
Source: "/dev/null",
Destination: filepath.Join(DeviceListAsVolumeMountsRoot, "GPU1"),
},
},
envvarDevices: "GPU2,GPU3",
privileged: false,
acceptUnprivileged: false,
acceptMounts: false,
expectedDevices: nil,
},
// New test cases for visibleEnvVars functionality
{
description: "preferred env var set and present in env, privileged",
mountDevices: nil,
envvarDevices: "",
privileged: true,
acceptUnprivileged: false,
acceptMounts: true,
preferredVisibleDeviceEnvVars: []string{"DOCKER_RESOURCE_GPUS"},
env: map[string]string{
"DOCKER_RESOURCE_GPUS": "GPU-12345",
},
expectedDevices: []string{"GPU-12345"},
},
{
description: "preferred env var set and present in env, unprivileged but accepted",
mountDevices: nil,
envvarDevices: "",
privileged: false,
acceptUnprivileged: true,
acceptMounts: true,
preferredVisibleDeviceEnvVars: []string{"DOCKER_RESOURCE_GPUS"},
env: map[string]string{
"DOCKER_RESOURCE_GPUS": "GPU-12345",
},
expectedDevices: []string{"GPU-12345"},
},
{
description: "preferred env var set and present in env, unprivileged and not accepted",
mountDevices: nil,
envvarDevices: "",
privileged: false,
acceptUnprivileged: false,
acceptMounts: true,
preferredVisibleDeviceEnvVars: []string{"DOCKER_RESOURCE_GPUS"},
env: map[string]string{
"DOCKER_RESOURCE_GPUS": "GPU-12345",
},
expectedDevices: nil,
},
{
description: "multiple preferred env vars, both present, privileged",
mountDevices: nil,
envvarDevices: "",
privileged: true,
acceptUnprivileged: false,
acceptMounts: true,
preferredVisibleDeviceEnvVars: []string{"DOCKER_RESOURCE_GPUS", "DOCKER_RESOURCE_GPUS_ADDITIONAL"},
env: map[string]string{
"DOCKER_RESOURCE_GPUS": "GPU-12345",
"DOCKER_RESOURCE_GPUS_ADDITIONAL": "GPU-67890",
},
expectedDevices: []string{"GPU-12345", "GPU-67890"},
},
{
description: "preferred env var not present, fallback to NVIDIA_VISIBLE_DEVICES, privileged",
mountDevices: nil,
envvarDevices: "GPU-12345",
privileged: true,
acceptUnprivileged: false,
acceptMounts: true,
preferredVisibleDeviceEnvVars: []string{"DOCKER_RESOURCE_GPUS"},
env: map[string]string{
EnvVarNvidiaVisibleDevices: "GPU-12345",
},
expectedDevices: []string{"GPU-12345"},
},
}
for _, tc := range tests {
t.Run(tc.description, func(t *testing.T) {
// Create env map with both NVIDIA_VISIBLE_DEVICES and any additional env vars
env := make(map[string]string)
if tc.envvarDevices != "" {
env[EnvVarNvidiaVisibleDevices] = tc.envvarDevices
}
for k, v := range tc.env {
env[k] = v
}
image, err := New(
WithEnvMap(env),
WithMounts(tc.mountDevices),
WithPrivileged(tc.privileged),
WithAcceptDeviceListAsVolumeMounts(tc.acceptMounts),
WithAcceptEnvvarUnprivileged(tc.acceptUnprivileged),
WithPreferredVisibleDevicesEnvVars(tc.preferredVisibleDeviceEnvVars...),
)
require.NoError(t, err)
require.Equal(t, tc.expectedDevices, image.VisibleDevices())
}) })
} }
} }
@@ -713,7 +224,7 @@ func TestImexChannelsFromEnvVar(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
for id, baseEnvvars := range map[string][]string{"": nil, "legacy": {"CUDA_VERSION=1.2.3"}} { for id, baseEnvvars := range map[string][]string{"": nil, "legacy": {"CUDA_VERSION=1.2.3"}} {
t.Run(tc.description+id, func(t *testing.T) { t.Run(tc.description+id, func(t *testing.T) {
i, err := newCUDAImageFromEnv(append(baseEnvvars, tc.env...)) i, err := NewCUDAImageFromEnv(append(baseEnvvars, tc.env...))
require.NoError(t, err) require.NoError(t, err)
channels := i.ImexChannelsFromEnvVar() channels := i.ImexChannelsFromEnvVar()
@@ -723,73 +234,6 @@ func TestImexChannelsFromEnvVar(t *testing.T) {
} }
} }
// TestCDIDeviceRequestsFromAnnotations exercises the extraction of CDI device
// requests from container annotations for a configured set of key prefixes.
// Each scenario builds an image from the prefixes and annotations and compares
// the returned device names with the expected ones, ignoring order.
func TestCDIDeviceRequestsFromAnnotations(t *testing.T) {
	type scenario struct {
		name        string
		prefixes    []string
		annotations map[string]string
		want        []string
	}

	scenarios := []scenario{
		{name: "no annotations"},
		{
			name:        "no matching annotations",
			prefixes:    []string{"not-prefix/"},
			annotations: map[string]string{"prefix/foo": "example.com/device=bar"},
		},
		{
			name:        "single matching annotation",
			prefixes:    []string{"prefix/"},
			annotations: map[string]string{"prefix/foo": "example.com/device=bar"},
			want:        []string{"example.com/device=bar"},
		},
		{
			name:     "multiple matching annotations",
			prefixes: []string{"prefix/", "another-prefix/"},
			annotations: map[string]string{
				"prefix/foo":         "example.com/device=bar",
				"another-prefix/bar": "example.com/device=baz",
			},
			want: []string{"example.com/device=bar", "example.com/device=baz"},
		},
		{
			// Duplicates are expected to be reported once per annotation.
			name:     "multiple matching annotations with duplicate devices",
			prefixes: []string{"prefix/", "another-prefix/"},
			annotations: map[string]string{
				"prefix/foo":         "example.com/device=bar",
				"another-prefix/bar": "example.com/device=bar",
			},
			want: []string{"example.com/device=bar", "example.com/device=bar"},
		},
		{
			// A value that is not a qualified device name is passed through.
			name:        "invalid devices are returned as is",
			prefixes:    []string{"prefix/"},
			annotations: map[string]string{"prefix/foo": "example.com/device"},
			want:        []string{"example.com/device"},
		},
	}

	for idx := range scenarios {
		sc := scenarios[idx]
		t.Run(sc.name, func(t *testing.T) {
			cuda, err := New(
				WithAnnotationsPrefixes(sc.prefixes),
				WithAnnotations(sc.annotations),
			)
			require.NoError(t, err)

			got := cuda.cdiDeviceRequestsFromAnnotations()
			require.ElementsMatch(t, sc.want, got)
		})
	}
}
func makeTestMounts(paths ...string) []specs.Mount { func makeTestMounts(paths ...string) []specs.Mount {
var mounts []specs.Mount var mounts []specs.Mount
for _, path := range paths { for _, path := range paths {

View File

@@ -24,39 +24,20 @@ const (
capSysAdmin = "CAP_SYS_ADMIN" capSysAdmin = "CAP_SYS_ADMIN"
) )
// CapabilitiesGetter is implemented by any type that can report a list of
// capability names.
type CapabilitiesGetter interface {
// GetCapabilities returns the capability names to be inspected.
GetCapabilities() []string
}
// OCISpec is a named type over specs.Spec so that a GetCapabilities method
// can be defined on it.
type OCISpec specs.Spec
// OCISpecCapabilities is a named type over specs.LinuxCapabilities so that a
// GetCapabilities method can be defined on it.
type OCISpecCapabilities specs.LinuxCapabilities
// IsPrivileged returns true if the container is a privileged container. // IsPrivileged returns true if the container is a privileged container.
func IsPrivileged(s CapabilitiesGetter) bool { func IsPrivileged(s *specs.Spec) bool {
if s == nil { if s.Process.Capabilities == nil {
return false return false
} }
for _, c := range s.GetCapabilities() {
// We only make sure that the bounding capability set has
// CAP_SYS_ADMIN. This allows us to make sure that the container was
// actually started as '--privileged', but also allow non-root users to
// access the privileged NVIDIA capabilities.
for _, c := range s.Process.Capabilities.Bounding {
if c == capSysAdmin { if c == capSysAdmin {
return true return true
} }
} }
return false return false
} }
// GetCapabilities returns the capabilities reported for the process section of
// the spec. It returns nil when the spec has no process or the process has no
// capabilities defined.
func (s OCISpec) GetCapabilities() []string {
	if s.Process != nil && s.Process.Capabilities != nil {
		caps := (*OCISpecCapabilities)(s.Process.Capabilities)
		return caps.GetCapabilities()
	}
	return nil
}
// GetCapabilities returns the bounding capability set of c.
func (c OCISpecCapabilities) GetCapabilities() []string {
// We only make sure that the bounding capability set has
// CAP_SYS_ADMIN. This allows us to make sure that the container was
// actually started as '--privileged', but also allow non-root users to
// access the privileged NVIDIA capabilities.
return c.Bounding
}

View File

@@ -1,57 +0,0 @@
/**
# Copyright (c) NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package image
import (
"testing"
"github.com/opencontainers/runtime-spec/specs-go"
"github.com/stretchr/testify/require"
)
// TestIsPrivileged verifies that IsPrivileged reports true only when the
// bounding capability set of the spec contains CAP_SYS_ADMIN.
func TestIsPrivileged(t *testing.T) {
	// boundedBy builds a minimal spec whose process bounding set holds caps.
	boundedBy := func(caps ...string) specs.Spec {
		return specs.Spec{
			Process: &specs.Process{
				Capabilities: &specs.LinuxCapabilities{Bounding: caps},
			},
		}
	}

	cases := []struct {
		spec     specs.Spec
		expected bool
	}{
		{spec: boundedBy("CAP_SYS_ADMIN"), expected: true},
		{spec: boundedBy("CAP_SYS_FOO"), expected: false},
	}

	for idx := range cases {
		entry := cases[idx]
		got := IsPrivileged((*OCISpec)(&entry.spec))
		require.Equal(t, entry.expected, got, "%d: %v", idx, entry)
	}
}

View File

@@ -62,7 +62,7 @@ load-kmods = true
#debug = "/var/log/nvidia-container-runtime.log" #debug = "/var/log/nvidia-container-runtime.log"
log-level = "info" log-level = "info"
mode = "auto" mode = "auto"
runtimes = ["runc", "crun"] runtimes = ["docker-runc", "runc", "crun"]
[nvidia-container-runtime.modes] [nvidia-container-runtime.modes]

View File

@@ -20,7 +20,6 @@ import (
"fmt" "fmt"
"os" "os"
"path/filepath" "path/filepath"
"runtime"
"strings" "strings"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image" "github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
@@ -82,25 +81,14 @@ func NewGraphicsMountsDiscoverer(logger logger.Interface, driver *root.Driver, h
// vulkan ICD files are at {{ .driverRoot }}/vulkan instead of in /etc/vulkan. // vulkan ICD files are at {{ .driverRoot }}/vulkan instead of in /etc/vulkan.
func newVulkanConfigsDiscover(logger logger.Interface, driver *root.Driver) Discover { func newVulkanConfigsDiscover(logger logger.Interface, driver *root.Driver) Discover {
locator := lookup.First(driver.Configs(), driver.Files()) locator := lookup.First(driver.Configs(), driver.Files())
required := []string{
"vulkan/icd.d/nvidia_icd.json",
"vulkan/icd.d/nvidia_layers.json",
"vulkan/implicit_layer.d/nvidia_layers.json",
}
// For some RPM-based driver packages, the vulkan ICD files are installed to
// /usr/share/vulkan/icd.d/nvidia_icd.%{_target_cpu}.json
// We also include this in the list of candidates for the ICD file.
switch runtime.GOARCH {
case "amd64":
required = append(required, "vulkan/icd.d/nvidia_icd.x86_64.json")
case "arm64":
required = append(required, "vulkan/icd.d/nvidia_icd.aarch64.json")
}
return &mountsToContainerPath{ return &mountsToContainerPath{
logger: logger, logger: logger,
locator: locator, locator: locator,
required: required, required: []string{
"vulkan/icd.d/nvidia_icd.json",
"vulkan/icd.d/nvidia_layers.json",
"vulkan/implicit_layer.d/nvidia_layers.json",
},
containerRoot: "/etc", containerRoot: "/etc",
} }
} }

View File

@@ -25,7 +25,7 @@ import (
func TestGraphicsLibrariesDiscoverer(t *testing.T) { func TestGraphicsLibrariesDiscoverer(t *testing.T) {
logger, _ := testlog.NewNullLogger() logger, _ := testlog.NewNullLogger()
hookCreator := NewHookCreator() hookCreator := NewHookCreator("/usr/bin/nvidia-cdi-hook", false)
testCases := []struct { testCases := []struct {
description string description string

View File

@@ -23,33 +23,6 @@ import (
"tags.cncf.io/container-device-interface/pkg/cdi" "tags.cncf.io/container-device-interface/pkg/cdi"
) )
// A HookName identifies one of the supported CDI hooks.
type HookName string
const (
// AllHooks is a special hook name that allows all hooks to be matched.
AllHooks = HookName("all")
// A ChmodHook is used to set the file mode of the specified paths.
//
// Deprecated: The chmod hook is deprecated and will be removed in a future release.
ChmodHook = HookName("chmod")
// A CreateSymlinksHook is used to create symlinks in the container.
CreateSymlinksHook = HookName("create-symlinks")
// DisableDeviceNodeModificationHook refers to the hook used to ensure that
// device nodes are not created by libnvidia-ml.so or nvidia-smi in a
// container.
// Added in v1.17.8
DisableDeviceNodeModificationHook = HookName("disable-device-node-modification")
// An EnableCudaCompatHook is used to enable CUDA Forward Compatibility.
// Added in v1.17.5
EnableCudaCompatHook = HookName("enable-cuda-compat")
// An UpdateLDCacheHook is the hook used to update the ldcache in the
// container. This allows injected libraries to be discoverable.
UpdateLDCacheHook = HookName("update-ldcache")
// defaultNvidiaCDIHookPath is the hook executable path that NewHookCreator
// starts from before any options are applied.
defaultNvidiaCDIHookPath = "/usr/bin/nvidia-cdi-hook"
)
var _ Discover = (*Hook)(nil) var _ Discover = (*Hook)(nil)
// Devices returns an empty list of devices for a Hook discoverer. // Devices returns an empty list of devices for a Hook discoverer.
@@ -72,130 +45,52 @@ func (h *Hook) Hooks() ([]Hook, error) {
return []Hook{*h}, nil return []Hook{*h}, nil
} }
type Option func(*cdiHookCreator) // Option is a function that configures the nvcdilib
type Option func(*CDIHook)
type cdiHookCreator struct { type CDIHook struct {
nvidiaCDIHookPath string nvidiaCDIHookPath string
disabledHooks map[HookName]bool debugLogging bool
fixedArgs []string
debugLogging bool
} }
// An allDisabledHookCreator is a HookCreator that does not create any hooks.
type allDisabledHookCreator struct{}
// Create returns nil for all hooks for an allDisabledHookCreator.
func (a *allDisabledHookCreator) Create(name HookName, args ...string) *Hook {
return nil
}
// A HookCreator defines an interface for creating discover hooks.
type HookCreator interface { type HookCreator interface {
Create(HookName, ...string) *Hook Create(string, ...string) *Hook
} }
// WithDisabledHooks sets the set of hooks that are disabled for the CDI hook creator. func NewHookCreator(nvidiaCDIHookPath string, debugLogging bool) HookCreator {
// This can be specified multiple times. CDIHook := &CDIHook{
func WithDisabledHooks(hooks ...HookName) Option { nvidiaCDIHookPath: nvidiaCDIHookPath,
return func(c *cdiHookCreator) { debugLogging: debugLogging,
for _, hook := range hooks { }
c.disabledHooks[hook] = true
return CDIHook
}
func (c CDIHook) Create(name string, args ...string) *Hook {
if name == "create-symlinks" {
if len(args) == 0 {
return nil
} }
}
}
// WithNVIDIACDIHookPath sets the path to the nvidia-cdi-hook binary. links := []string{}
func WithNVIDIACDIHookPath(nvidiaCDIHookPath string) Option { for _, arg := range args {
return func(c *cdiHookCreator) { links = append(links, "--link", arg)
c.nvidiaCDIHookPath = nvidiaCDIHookPath }
} args = links
}
// NewHookCreator builds a HookCreator from the supplied options.
// It starts from the default nvidia-cdi-hook path with no hooks disabled and
// applies opts in the order given. If AllHooks ends up disabled, a creator
// that never produces hooks is returned; otherwise the fixed CLI arguments are
// derived from the configured hook path.
func NewHookCreator(opts ...Option) HookCreator {
	creator := cdiHookCreator{
		disabledHooks:     map[HookName]bool{},
		nvidiaCDIHookPath: defaultNvidiaCDIHookPath,
	}
	for i := range opts {
		opts[i](&creator)
	}

	// Disabling everything short-circuits to the no-op implementation.
	if allOff := creator.disabledHooks[AllHooks]; allOff {
		return &allDisabledHookCreator{}
	}

	creator.fixedArgs = getFixedArgsForCDIHookCLI(creator.nvidiaCDIHookPath)
	return &creator
}
// Create creates a new hook with the given name and arguments.
// If a hook is disabled, a nil hook is returned.
func (c cdiHookCreator) Create(name HookName, args ...string) *Hook {
if c.isDisabled(name, args...) {
return nil
} }
return &Hook{ return &Hook{
Lifecycle: cdi.CreateContainerHook, Lifecycle: cdi.CreateContainerHook,
Path: c.nvidiaCDIHookPath, Path: c.nvidiaCDIHookPath,
Args: append(c.requiredArgs(name), c.transformArgs(name, args...)...), Args: append(c.requiredArgs(name), args...),
Env: []string{fmt.Sprintf("NVIDIA_CTK_DEBUG=%v", c.debugLogging)}, Env: []string{fmt.Sprintf("NVIDIA_CTK_DEBUG=%v", c.debugLogging)},
} }
} }
// isDisabled checks if the specified hook name is disabled. func (c CDIHook) requiredArgs(name string) []string {
func (c cdiHookCreator) isDisabled(name HookName, args ...string) bool { base := filepath.Base(c.nvidiaCDIHookPath)
if c.disabledHooks[name] {
return true
}
switch name {
case CreateSymlinksHook:
if len(args) == 0 {
return true
}
case ChmodHook:
if len(args) == 0 {
return true
}
}
return false
}
func (c cdiHookCreator) requiredArgs(name HookName) []string {
return append(c.fixedArgs, string(name))
}
// transformArgs converts the caller-supplied arguments into the command line
// form expected by the named hook:
//   - CreateSymlinksHook: every argument is emitted as "--link <arg>".
//   - ChmodHook: "--mode 755" followed by "--path <arg>" for every argument.
//   - any other hook: the arguments are returned unchanged.
func (c cdiHookCreator) transformArgs(name HookName, args ...string) []string {
	// flagged appends "<flag> <arg>" to seed for each argument in args.
	flagged := func(seed []string, flag string) []string {
		out := seed
		for i := range args {
			out = append(out, flag, args[i])
		}
		return out
	}

	if name == CreateSymlinksHook {
		return flagged(nil, "--link")
	}
	if name == ChmodHook {
		return flagged([]string{"--mode", "755"}, "--path")
	}
	return args
}
// getFixedArgsForCDIHookCLI returns the fixed arguments for the hook CLI.
// If the nvidia-ctk binary is used, hooks are implemented under the hook
// subcommand.
// For the nvidia-cdi-hook binary, the hooks are implemented as subcommands of
// the top-level CLI.
func getFixedArgsForCDIHookCLI(nvidiaCDIHookPath string) []string {
base := filepath.Base(nvidiaCDIHookPath)
if base == "nvidia-ctk" { if base == "nvidia-ctk" {
return []string{base, "hook"} return []string{base, "hook", name}
} }
return []string{base} return []string{base, name}
} }

View File

@@ -72,7 +72,7 @@ func createLDCacheUpdateHook(hookCreator HookCreator, ldconfig string, libraries
args = append(args, "--folder", f) args = append(args, "--folder", f)
} }
return hookCreator.Create(UpdateLDCacheHook, args...) return hookCreator.Create("update-ldcache", args...)
} }
// getLibraryPaths extracts the library dirs from the specified mounts // getLibraryPaths extracts the library dirs from the specified mounts

View File

@@ -31,7 +31,7 @@ const (
func TestLDCacheUpdateHook(t *testing.T) { func TestLDCacheUpdateHook(t *testing.T) {
logger, _ := testlog.NewNullLogger() logger, _ := testlog.NewNullLogger()
hookCreator := NewHookCreator(WithNVIDIACDIHookPath(testNvidiaCDIHookPath)) hookCreator := NewHookCreator(testNvidiaCDIHookPath, false)
testCases := []struct { testCases := []struct {
description string description string

View File

@@ -113,7 +113,7 @@ func TestWithWithDriverDotSoSymlinks(t *testing.T) {
expectedHooks: []Hook{ expectedHooks: []Hook{
{ {
Lifecycle: "createContainer", Lifecycle: "createContainer",
Path: "/usr/bin/nvidia-cdi-hook", Path: "/path/to/nvidia-cdi-hook",
Args: []string{"nvidia-cdi-hook", "create-symlinks", "--link", "libcuda.so.1::/usr/lib/libcuda.so"}, Args: []string{"nvidia-cdi-hook", "create-symlinks", "--link", "libcuda.so.1::/usr/lib/libcuda.so"},
Env: []string{"NVIDIA_CTK_DEBUG=false"}, Env: []string{"NVIDIA_CTK_DEBUG=false"},
}, },
@@ -146,7 +146,7 @@ func TestWithWithDriverDotSoSymlinks(t *testing.T) {
expectedHooks: []Hook{ expectedHooks: []Hook{
{ {
Lifecycle: "createContainer", Lifecycle: "createContainer",
Path: "/usr/bin/nvidia-cdi-hook", Path: "/path/to/nvidia-cdi-hook",
Args: []string{"nvidia-cdi-hook", "create-symlinks", "--link", "libcuda.so.1::/usr/lib/libcuda.so"}, Args: []string{"nvidia-cdi-hook", "create-symlinks", "--link", "libcuda.so.1::/usr/lib/libcuda.so"},
Env: []string{"NVIDIA_CTK_DEBUG=false"}, Env: []string{"NVIDIA_CTK_DEBUG=false"},
}, },
@@ -178,7 +178,7 @@ func TestWithWithDriverDotSoSymlinks(t *testing.T) {
expectedHooks: []Hook{ expectedHooks: []Hook{
{ {
Lifecycle: "createContainer", Lifecycle: "createContainer",
Path: "/usr/bin/nvidia-cdi-hook", Path: "/path/to/nvidia-cdi-hook",
Args: []string{"nvidia-cdi-hook", "create-symlinks", "--link", "libcuda.so.1::/usr/lib/libcuda.so"}, Args: []string{"nvidia-cdi-hook", "create-symlinks", "--link", "libcuda.so.1::/usr/lib/libcuda.so"},
Env: []string{"NVIDIA_CTK_DEBUG=false"}, Env: []string{"NVIDIA_CTK_DEBUG=false"},
}, },
@@ -248,7 +248,7 @@ func TestWithWithDriverDotSoSymlinks(t *testing.T) {
}, },
{ {
Lifecycle: "createContainer", Lifecycle: "createContainer",
Path: "/usr/bin/nvidia-cdi-hook", Path: "/path/to/nvidia-cdi-hook",
Args: []string{"nvidia-cdi-hook", "create-symlinks", "--link", "libcuda.so.1::/usr/lib/libcuda.so"}, Args: []string{"nvidia-cdi-hook", "create-symlinks", "--link", "libcuda.so.1::/usr/lib/libcuda.so"},
Env: []string{"NVIDIA_CTK_DEBUG=false"}, Env: []string{"NVIDIA_CTK_DEBUG=false"},
}, },
@@ -298,7 +298,7 @@ func TestWithWithDriverDotSoSymlinks(t *testing.T) {
expectedHooks: []Hook{ expectedHooks: []Hook{
{ {
Lifecycle: "createContainer", Lifecycle: "createContainer",
Path: "/usr/bin/nvidia-cdi-hook", Path: "/path/to/nvidia-cdi-hook",
Args: []string{ Args: []string{
"nvidia-cdi-hook", "create-symlinks", "nvidia-cdi-hook", "create-symlinks",
"--link", "libcuda.so.1::/usr/lib/libcuda.so", "--link", "libcuda.so.1::/usr/lib/libcuda.so",
@@ -311,7 +311,7 @@ func TestWithWithDriverDotSoSymlinks(t *testing.T) {
}, },
} }
hookCreator := NewHookCreator() hookCreator := NewHookCreator("/path/to/nvidia-cdi-hook", false)
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) { t.Run(tc.description, func(t *testing.T) {
d := WithDriverDotSoSymlinks( d := WithDriverDotSoSymlinks(

View File

@@ -20,8 +20,6 @@ import (
"tags.cncf.io/container-device-interface/pkg/cdi" "tags.cncf.io/container-device-interface/pkg/cdi"
"tags.cncf.io/container-device-interface/specs-go" "tags.cncf.io/container-device-interface/specs-go"
"github.com/opencontainers/runc/libcontainer/devices"
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover" "github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
) )
@@ -45,37 +43,19 @@ func (d device) toEdits() (*cdi.ContainerEdits, error) {
// toSpec converts a discovered Device to a CDI Spec Device. Note // toSpec converts a discovered Device to a CDI Spec Device. Note
// that missing info is filled in when edits are applied by querying the Device node. // that missing info is filled in when edits are applied by querying the Device node.
func (d device) toSpec() (*specs.DeviceNode, error) { func (d device) toSpec() (*specs.DeviceNode, error) {
s := d.fromPathOrDefault()
// The HostPath field was added in the v0.5.0 CDI specification. // The HostPath field was added in the v0.5.0 CDI specification.
// The cdi package uses strict unmarshalling when loading specs from file causing failures for // The cdi package uses strict unmarshalling when loading specs from file causing failures for
// unexpected fields. // unexpected fields.
// Since the behaviour for HostPath == "" and HostPath == Path are equivalent, we clear HostPath // Since the behaviour for HostPath == "" and HostPath == Path are equivalent, we clear HostPath
// if it is equal to Path to ensure compatibility with the widest range of specs. // if it is equal to Path to ensure compatibility with the widest range of specs.
if s.HostPath == d.Path { hostPath := d.HostPath
s.HostPath = "" if hostPath == d.Path {
hostPath = ""
} }
s := specs.DeviceNode{
return s, nil HostPath: hostPath,
}
// fromPathOrDefault attempts to return the information about the
// CDI device from the specified host path.
// If this fails a minimal device is returned so that this information can be
// queried by the container runtime such as containerd.
func (d device) fromPathOrDefault() *specs.DeviceNode {
dn, err := devices.DeviceFromPath(d.HostPath, "rwm")
if err != nil {
return &specs.DeviceNode{
HostPath: d.HostPath,
Path: d.Path,
}
}
return &specs.DeviceNode{
HostPath: d.HostPath,
Path: d.Path, Path: d.Path,
Major: dn.Major,
Minor: dn.Minor,
FileMode: &dn.FileMode,
} }
return &s, nil
} }

View File

@@ -184,7 +184,7 @@ func TestResolveAutoMode(t *testing.T) {
expectedMode: "legacy", expectedMode: "legacy",
}, },
{ {
description: "cdi mount and non-CDI envvar resolves to cdi", description: "cdi mount and non-CDI envvar resolves to legacy",
mode: "auto", mode: "auto",
envmap: map[string]string{ envmap: map[string]string{
"NVIDIA_VISIBLE_DEVICES": "0", "NVIDIA_VISIBLE_DEVICES": "0",
@@ -197,22 +197,6 @@ func TestResolveAutoMode(t *testing.T) {
"tegra": false, "tegra": false,
"nvgpu": false, "nvgpu": false,
}, },
expectedMode: "cdi",
},
{
description: "non-cdi mount and CDI envvar resolves to legacy",
mode: "auto",
envmap: map[string]string{
"NVIDIA_VISIBLE_DEVICES": "nvidia.com/gpu=0",
},
mounts: []string{
"/var/run/nvidia-container-devices/0",
},
info: map[string]bool{
"nvml": true,
"tegra": false,
"nvgpu": false,
},
expectedMode: "legacy", expectedMode: "legacy",
}, },
} }
@@ -248,8 +232,6 @@ func TestResolveAutoMode(t *testing.T) {
image, _ := image.New( image, _ := image.New(
image.WithEnvMap(tc.envmap), image.WithEnvMap(tc.envmap),
image.WithMounts(mounts), image.WithMounts(mounts),
image.WithAcceptDeviceListAsVolumeMounts(true),
image.WithAcceptEnvvarUnprivileged(true),
) )
mode := resolveMode(logger, tc.mode, image, properties) mode := resolveMode(logger, tc.mode, image, properties)
require.EqualValues(t, tc.expectedMode, mode) require.EqualValues(t, tc.expectedMode, mode)

View File

@@ -18,6 +18,7 @@ package modifier
import ( import (
"fmt" "fmt"
"strings"
"tags.cncf.io/container-device-interface/pkg/parser" "tags.cncf.io/container-device-interface/pkg/parser"
@@ -33,13 +34,11 @@ import (
// NewCDIModifier creates an OCI spec modifier that determines the modifications to make based on the // NewCDIModifier creates an OCI spec modifier that determines the modifications to make based on the
// CDI specifications available on the system. The NVIDIA_VISIBLE_DEVICES environment variable is // CDI specifications available on the system. The NVIDIA_VISIBLE_DEVICES environment variable is
// used to select the devices to include. // used to select the devices to include.
func NewCDIModifier(logger logger.Interface, cfg *config.Config, image image.CUDA) (oci.SpecModifier, error) { func NewCDIModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Spec) (oci.SpecModifier, error) {
deviceRequestor := newCDIDeviceRequestor( devices, err := getDevicesFromSpec(logger, ociSpec, cfg)
logger, if err != nil {
image, return nil, fmt.Errorf("failed to get required devices from OCI specification: %v", err)
cfg.NVIDIAContainerRuntimeConfig.Modes.CDI.DefaultKind, }
)
devices := deviceRequestor.DeviceRequests()
if len(devices) == 0 { if len(devices) == 0 {
logger.Debugf("No devices requested; no modification required.") logger.Debugf("No devices requested; no modification required.")
return nil, nil return nil, nil
@@ -66,38 +65,87 @@ func NewCDIModifier(logger logger.Interface, cfg *config.Config, image image.CUD
) )
} }
type deviceRequestor interface { func getDevicesFromSpec(logger logger.Interface, ociSpec oci.Spec, cfg *config.Config) ([]string, error) {
DeviceRequests() []string rawSpec, err := ociSpec.Load()
} if err != nil {
return nil, fmt.Errorf("failed to load OCI spec: %v", err)
type cdiDeviceRequestor struct {
image image.CUDA
logger logger.Interface
defaultKind string
}
func newCDIDeviceRequestor(logger logger.Interface, image image.CUDA, defaultKind string) deviceRequestor {
c := &cdiDeviceRequestor{
logger: logger,
image: image,
defaultKind: defaultKind,
} }
return withUniqueDevices(c)
}
func (c *cdiDeviceRequestor) DeviceRequests() []string { annotationDevices, err := getAnnotationDevices(cfg.NVIDIAContainerRuntimeConfig.Modes.CDI.AnnotationPrefixes, rawSpec.Annotations)
if c == nil { if err != nil {
return nil return nil, fmt.Errorf("failed to parse container annotations: %v", err)
} }
if len(annotationDevices) > 0 {
return annotationDevices, nil
}
container, err := image.NewCUDAImageFromSpec(rawSpec)
if err != nil {
return nil, err
}
if cfg.AcceptDeviceListAsVolumeMounts {
mountDevices := container.CDIDevicesFromMounts()
if len(mountDevices) > 0 {
return mountDevices, nil
}
}
var devices []string var devices []string
for _, name := range c.image.VisibleDevices() { seen := make(map[string]bool)
for _, name := range container.VisibleDevicesFromEnvVar() {
if !parser.IsQualifiedName(name) { if !parser.IsQualifiedName(name) {
name = fmt.Sprintf("%s=%s", c.defaultKind, name) name = fmt.Sprintf("%s=%s", cfg.NVIDIAContainerRuntimeConfig.Modes.CDI.DefaultKind, name)
}
if seen[name] {
logger.Debugf("Ignoring duplicate device %q", name)
continue
} }
devices = append(devices, name) devices = append(devices, name)
} }
return devices if len(devices) == 0 {
return nil, nil
}
if cfg.AcceptEnvvarUnprivileged || image.IsPrivileged(rawSpec) {
return devices, nil
}
logger.Warningf("Ignoring devices specified in NVIDIA_VISIBLE_DEVICES: %v", devices)
return nil, nil
}
// getAnnotationDevices returns a list of devices specified in the annotations.
// Keys starting with the specified prefixes are considered and expected to contain a comma-separated list of
// fully-qualified CDI device names. If any device name is not fully-qualified an error is returned.
// The list of returned devices is deduplicated.
func getAnnotationDevices(prefixes []string, annotations map[string]string) ([]string, error) {
devicesByKey := make(map[string][]string)
for key, value := range annotations {
for _, prefix := range prefixes {
if strings.HasPrefix(key, prefix) {
devicesByKey[key] = strings.Split(value, ",")
}
}
}
seen := make(map[string]bool)
var annotationDevices []string
for key, devices := range devicesByKey {
for _, device := range devices {
if !parser.IsQualifiedName(device) {
return nil, fmt.Errorf("invalid device name %q in annotation %q", device, key)
}
if seen[device] {
continue
}
annotationDevices = append(annotationDevices, device)
seen[device] = true
}
}
return annotationDevices, nil
} }
// filterAutomaticDevices searches for "automatic" device names in the input slice. // filterAutomaticDevices searches for "automatic" device names in the input slice.
@@ -121,7 +169,7 @@ func newAutomaticCDISpecModifier(logger logger.Interface, cfg *config.Config, de
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to generate CDI spec: %w", err) return nil, fmt.Errorf("failed to generate CDI spec: %w", err)
} }
cdiDeviceRequestor, err := cdi.New( cdiModifier, err := cdi.New(
cdi.WithLogger(logger), cdi.WithLogger(logger),
cdi.WithSpec(spec.Raw()), cdi.WithSpec(spec.Raw()),
) )
@@ -129,7 +177,7 @@ func newAutomaticCDISpecModifier(logger logger.Interface, cfg *config.Config, de
return nil, fmt.Errorf("failed to construct CDI modifier: %w", err) return nil, fmt.Errorf("failed to construct CDI modifier: %w", err)
} }
return cdiDeviceRequestor, nil return cdiModifier, nil
} }
func generateAutomaticCDISpec(logger logger.Interface, cfg *config.Config, devices []string) (spec.Interface, error) { func generateAutomaticCDISpec(logger logger.Interface, cfg *config.Config, devices []string) (spec.Interface, error) {
@@ -144,35 +192,26 @@ func generateAutomaticCDISpec(logger logger.Interface, cfg *config.Config, devic
return nil, fmt.Errorf("failed to construct CDI library: %w", err) return nil, fmt.Errorf("failed to construct CDI library: %w", err)
} }
var identifiers []string identifiers := []string{}
for _, device := range devices { for _, device := range devices {
_, _, id := parser.ParseDevice(device) _, _, id := parser.ParseDevice(device)
identifiers = append(identifiers, id) identifiers = append(identifiers, id)
} }
return cdilib.GetSpec(identifiers...) deviceSpecs, err := cdilib.GetDeviceSpecsByID(identifiers...)
} if err != nil {
return nil, fmt.Errorf("failed to get CDI device specs: %w", err)
type deduplicatedDeviceRequestor struct {
deviceRequestor
}
func withUniqueDevices(deviceRequestor deviceRequestor) deviceRequestor {
return &deduplicatedDeviceRequestor{deviceRequestor: deviceRequestor}
}
func (d *deduplicatedDeviceRequestor) DeviceRequests() []string {
if d == nil {
return nil
} }
seen := make(map[string]bool)
var devices []string commonEdits, err := cdilib.GetCommonEdits()
for _, device := range d.deviceRequestor.DeviceRequests() { if err != nil {
if seen[device] { return nil, fmt.Errorf("failed to get common CDI spec edits: %w", err)
continue
}
seen[device] = true
devices = append(devices, device)
} }
return devices
return spec.New(
spec.WithDeviceSpecs(deviceSpecs),
spec.WithEdits(*commonEdits.ContainerEdits),
spec.WithVendor("runtime.nvidia.com"),
spec.WithClass("gpu"),
)
} }

View File

@@ -17,144 +17,76 @@
package modifier package modifier
import ( import (
"fmt"
"testing" "testing"
"github.com/opencontainers/runtime-spec/specs-go"
testlog "github.com/sirupsen/logrus/hooks/test"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
) )
func TestDeviceRequests(t *testing.T) { func TestGetAnnotationDevices(t *testing.T) {
logger, _ := testlog.NewNullLogger()
testCases := []struct { testCases := []struct {
description string description string
input cdiDeviceRequestor
spec *specs.Spec
prefixes []string prefixes []string
annotations map[string]string
expectedDevices []string expectedDevices []string
expectedError error
}{ }{
{ {
description: "empty spec yields no devices", description: "no annotations",
},
{
description: "cdi devices from mounts",
input: cdiDeviceRequestor{
defaultKind: "nvidia.com/gpu",
},
spec: &specs.Spec{
Mounts: []specs.Mount{
{
Destination: "/var/run/nvidia-container-devices/cdi/nvidia.com/gpu/0",
Source: "/dev/null",
},
{
Destination: "/var/run/nvidia-container-devices/cdi/nvidia.com/gpu/1",
Source: "/dev/null",
},
},
},
expectedDevices: []string{"nvidia.com/gpu=0", "nvidia.com/gpu=1"},
},
{
description: "cdi devices from envvar",
input: cdiDeviceRequestor{
defaultKind: "nvidia.com/gpu",
},
spec: &specs.Spec{
Process: &specs.Process{
Env: []string{"NVIDIA_VISIBLE_DEVICES=0,example.com/class=device"},
},
},
expectedDevices: []string{"nvidia.com/gpu=0", "example.com/class=device"},
}, },
{ {
description: "no matching annotations", description: "no matching annotations",
prefixes: []string{"not-prefix/"}, prefixes: []string{"not-prefix/"},
spec: &specs.Spec{ annotations: map[string]string{
Annotations: map[string]string{ "prefix/foo": "example.com/device=bar",
"prefix/foo": "example.com/device=bar",
},
}, },
}, },
{ {
description: "single matching annotation", description: "single matching annotation",
prefixes: []string{"prefix/"}, prefixes: []string{"prefix/"},
spec: &specs.Spec{ annotations: map[string]string{
Annotations: map[string]string{ "prefix/foo": "example.com/device=bar",
"prefix/foo": "example.com/device=bar",
},
}, },
expectedDevices: []string{"example.com/device=bar"}, expectedDevices: []string{"example.com/device=bar"},
}, },
{ {
description: "multiple matching annotations", description: "multiple matching annotations",
prefixes: []string{"prefix/", "another-prefix/"}, prefixes: []string{"prefix/", "another-prefix/"},
spec: &specs.Spec{ annotations: map[string]string{
Annotations: map[string]string{ "prefix/foo": "example.com/device=bar",
"prefix/foo": "example.com/device=bar", "another-prefix/bar": "example.com/device=baz",
"another-prefix/bar": "example.com/device=baz",
},
}, },
expectedDevices: []string{"example.com/device=baz", "example.com/device=bar"}, expectedDevices: []string{"example.com/device=bar", "example.com/device=baz"},
}, },
{ {
description: "multiple matching annotations with duplicate devices", description: "multiple matching annotations with duplicate devices",
prefixes: []string{"prefix/", "another-prefix/"}, prefixes: []string{"prefix/", "another-prefix/"},
spec: &specs.Spec{ annotations: map[string]string{
Annotations: map[string]string{ "prefix/foo": "example.com/device=bar",
"prefix/foo": "example.com/device=bar", "another-prefix/bar": "example.com/device=bar",
"another-prefix/bar": "example.com/device=bar",
},
}, },
expectedDevices: []string{"example.com/device=bar", "example.com/device=bar"}, expectedDevices: []string{"example.com/device=bar"},
}, },
{ {
description: "devices in annotations are expanded", description: "invalid devices",
input: cdiDeviceRequestor{ prefixes: []string{"prefix/"},
defaultKind: "nvidia.com/gpu", annotations: map[string]string{
"prefix/foo": "example.com/device",
}, },
prefixes: []string{"prefix/"}, expectedError: fmt.Errorf("invalid device %q", "example.com/device"),
spec: &specs.Spec{
Annotations: map[string]string{
"prefix/foo": "device",
},
},
expectedDevices: []string{"nvidia.com/gpu=device"},
},
{
description: "invalid devices in annotations are treated as strings",
input: cdiDeviceRequestor{
defaultKind: "nvidia.com/gpu",
},
prefixes: []string{"prefix/"},
spec: &specs.Spec{
Annotations: map[string]string{
"prefix/foo": "example.com/device",
},
},
expectedDevices: []string{"nvidia.com/gpu=example.com/device"},
}, },
} }
for _, tc := range testCases { for _, tc := range testCases {
tc.input.logger = logger
image, err := image.NewCUDAImageFromSpec(
tc.spec,
image.WithAcceptDeviceListAsVolumeMounts(true),
image.WithAcceptEnvvarUnprivileged(true),
image.WithAnnotationsPrefixes(tc.prefixes),
)
require.NoError(t, err)
tc.input.image = image
t.Run(tc.description, func(t *testing.T) { t.Run(tc.description, func(t *testing.T) {
devices := tc.input.DeviceRequests() devices, err := getAnnotationDevices(tc.prefixes, tc.annotations)
if tc.expectedError != nil {
require.Error(t, err)
return
}
require.NoError(t, err) require.NoError(t, err)
require.EqualValues(t, tc.expectedDevices, devices) require.ElementsMatch(t, tc.expectedDevices, devices)
}) })
} }
} }

View File

@@ -33,7 +33,7 @@ import (
// NewCSVModifier creates a modifier that applies modifications to an OCI spec if required by the runtime wrapper. // NewCSVModifier creates a modifier that applies modifications to an OCI spec if required by the runtime wrapper.
// The modifications are defined by CSV MountSpecs. // The modifications are defined by CSV MountSpecs.
func NewCSVModifier(logger logger.Interface, cfg *config.Config, container image.CUDA) (oci.SpecModifier, error) { func NewCSVModifier(logger logger.Interface, cfg *config.Config, container image.CUDA) (oci.SpecModifier, error) {
if devices := container.VisibleDevices(); len(devices) == 0 { if devices := container.VisibleDevicesFromEnvVar(); len(devices) == 0 {
logger.Infof("No modification required; no devices requested") logger.Infof("No modification required; no devices requested")
return nil, nil return nil, nil
} }

View File

@@ -37,7 +37,7 @@ import (
// //
// If no devices are selected, no changes are made. // If no devices are selected, no changes are made.
func NewFeatureGatedModifier(logger logger.Interface, cfg *config.Config, image image.CUDA, driver *root.Driver, hookCreator discover.HookCreator) (oci.SpecModifier, error) { func NewFeatureGatedModifier(logger logger.Interface, cfg *config.Config, image image.CUDA, driver *root.Driver, hookCreator discover.HookCreator) (oci.SpecModifier, error) {
if devices := image.VisibleDevices(); len(devices) == 0 { if devices := image.VisibleDevicesFromEnvVar(); len(devices) == 0 {
logger.Infof("No modification required; no devices requested") logger.Infof("No modification required; no devices requested")
return nil, nil return nil, nil
} }

View File

@@ -29,10 +29,9 @@ import (
// NewGraphicsModifier constructs a modifier that injects graphics-related modifications into an OCI runtime specification. // NewGraphicsModifier constructs a modifier that injects graphics-related modifications into an OCI runtime specification.
// The value of the NVIDIA_DRIVER_CAPABILITIES environment variable is checked to determine if this modification should be made. // The value of the NVIDIA_DRIVER_CAPABILITIES environment variable is checked to determine if this modification should be made.
func NewGraphicsModifier(logger logger.Interface, cfg *config.Config, container image.CUDA, driver *root.Driver, hookCreator discover.HookCreator) (oci.SpecModifier, error) { func NewGraphicsModifier(logger logger.Interface, cfg *config.Config, containerImage image.CUDA, driver *root.Driver, hookCreator discover.HookCreator) (oci.SpecModifier, error) {
devices, reason := requiresGraphicsModifier(container) if required, reason := requiresGraphicsModifier(containerImage); !required {
if len(devices) == 0 { logger.Infof("No graphics modifier required: %v", reason)
logger.Infof("No graphics modifier required; %v", reason)
return nil, nil return nil, nil
} }
@@ -49,7 +48,7 @@ func NewGraphicsModifier(logger logger.Interface, cfg *config.Config, container
devRoot := driver.Root devRoot := driver.Root
drmNodes, err := discover.NewDRMNodesDiscoverer( drmNodes, err := discover.NewDRMNodesDiscoverer(
logger, logger,
image.NewVisibleDevices(devices...), containerImage.DevicesFromEnvvars(image.EnvVarNvidiaVisibleDevices),
devRoot, devRoot,
hookCreator, hookCreator,
) )
@@ -65,15 +64,14 @@ func NewGraphicsModifier(logger logger.Interface, cfg *config.Config, container
} }
// requiresGraphicsModifier determines whether a graphics modifier is required. // requiresGraphicsModifier determines whether a graphics modifier is required.
func requiresGraphicsModifier(cudaImage image.CUDA) ([]string, string) { func requiresGraphicsModifier(cudaImage image.CUDA) (bool, string) {
devices := cudaImage.VisibleDevices() if devices := cudaImage.VisibleDevicesFromEnvVar(); len(devices) == 0 {
if len(devices) == 0 { return false, "no devices requested"
return nil, "no devices requested"
} }
if !cudaImage.GetDriverCapabilities().Any(image.DriverCapabilityGraphics, image.DriverCapabilityDisplay) { if !cudaImage.GetDriverCapabilities().Any(image.DriverCapabilityGraphics, image.DriverCapabilityDisplay) {
return nil, "no required capabilities requested" return false, "no required capabilities requested"
} }
return devices, "" return true, ""
} }

View File

@@ -26,9 +26,9 @@ import (
func TestGraphicsModifier(t *testing.T) { func TestGraphicsModifier(t *testing.T) {
testCases := []struct { testCases := []struct {
description string description string
envmap map[string]string envmap map[string]string
expectedDevices []string expectedRequired bool
}{ }{
{ {
description: "empty image does not create modifier", description: "empty image does not create modifier",
@@ -52,7 +52,7 @@ func TestGraphicsModifier(t *testing.T) {
"NVIDIA_VISIBLE_DEVICES": "all", "NVIDIA_VISIBLE_DEVICES": "all",
"NVIDIA_DRIVER_CAPABILITIES": "all", "NVIDIA_DRIVER_CAPABILITIES": "all",
}, },
expectedDevices: []string{"all"}, expectedRequired: true,
}, },
{ {
description: "devices with graphics capability creates modifier", description: "devices with graphics capability creates modifier",
@@ -60,7 +60,7 @@ func TestGraphicsModifier(t *testing.T) {
"NVIDIA_VISIBLE_DEVICES": "all", "NVIDIA_VISIBLE_DEVICES": "all",
"NVIDIA_DRIVER_CAPABILITIES": "graphics", "NVIDIA_DRIVER_CAPABILITIES": "graphics",
}, },
expectedDevices: []string{"all"}, expectedRequired: true,
}, },
{ {
description: "devices with compute,graphics capability creates modifier", description: "devices with compute,graphics capability creates modifier",
@@ -68,7 +68,7 @@ func TestGraphicsModifier(t *testing.T) {
"NVIDIA_VISIBLE_DEVICES": "all", "NVIDIA_VISIBLE_DEVICES": "all",
"NVIDIA_DRIVER_CAPABILITIES": "compute,graphics", "NVIDIA_DRIVER_CAPABILITIES": "compute,graphics",
}, },
expectedDevices: []string{"all"}, expectedRequired: true,
}, },
{ {
description: "devices with display capability creates modifier", description: "devices with display capability creates modifier",
@@ -76,7 +76,7 @@ func TestGraphicsModifier(t *testing.T) {
"NVIDIA_VISIBLE_DEVICES": "all", "NVIDIA_VISIBLE_DEVICES": "all",
"NVIDIA_DRIVER_CAPABILITIES": "display", "NVIDIA_DRIVER_CAPABILITIES": "display",
}, },
expectedDevices: []string{"all"}, expectedRequired: true,
}, },
{ {
description: "devices with display,graphics capability creates modifier", description: "devices with display,graphics capability creates modifier",
@@ -84,7 +84,7 @@ func TestGraphicsModifier(t *testing.T) {
"NVIDIA_VISIBLE_DEVICES": "all", "NVIDIA_VISIBLE_DEVICES": "all",
"NVIDIA_DRIVER_CAPABILITIES": "display,graphics", "NVIDIA_DRIVER_CAPABILITIES": "display,graphics",
}, },
expectedDevices: []string{"all"}, expectedRequired: true,
}, },
} }
@@ -94,7 +94,7 @@ func TestGraphicsModifier(t *testing.T) {
image.WithEnvMap(tc.envmap), image.WithEnvMap(tc.envmap),
) )
required, _ := requiresGraphicsModifier(image) required, _ := requiresGraphicsModifier(image)
require.EqualValues(t, tc.expectedDevices, required) require.EqualValues(t, tc.expectedRequired, required)
}) })
} }
} }

View File

@@ -183,7 +183,7 @@ func TestDiscovererFromCSVFiles(t *testing.T) {
}, },
} }
hookCreator := discover.NewHookCreator() hookCreator := discover.NewHookCreator("/usr/bin/nvidia-cdi-hook", false)
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) { t.Run(tc.description, func(t *testing.T) {
defer setGetTargetsFromCSVFiles(tc.moutSpecs)() defer setGetTargetsFromCSVFiles(tc.moutSpecs)()

View File

@@ -18,6 +18,7 @@ package runtime
import ( import (
"fmt" "fmt"
"os"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config" "github.com/NVIDIA/nvidia-container-toolkit/internal/config"
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image" "github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
@@ -65,17 +66,29 @@ func newNVIDIAContainerRuntime(logger logger.Interface, cfg *config.Config, argv
// newSpecModifier is a factory method that constructs an OCI spec modifier based on the provided config. // newSpecModifier is a factory method that constructs an OCI spec modifier based on the provided config.
func newSpecModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Spec, driver *root.Driver) (oci.SpecModifier, error) { func newSpecModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Spec, driver *root.Driver) (oci.SpecModifier, error) {
mode, image, err := initRuntimeModeAndImage(logger, cfg, ociSpec) rawSpec, err := ociSpec.Load()
if err != nil {
return nil, fmt.Errorf("failed to load OCI spec: %v", err)
}
image, err := image.NewCUDAImageFromSpec(rawSpec)
if err != nil { if err != nil {
return nil, err return nil, err
} }
modeModifier, err := newModeModifier(logger, mode, cfg, *image) hookCreator := discover.NewHookCreator(
cfg.NVIDIACTKConfig.Path,
cfg.NVIDIAContainerRuntimeConfig.DebugFilePath == "" || cfg.NVIDIAContainerRuntimeConfig.DebugFilePath == os.DevNull,
)
mode := info.ResolveAutoMode(logger, cfg.NVIDIAContainerRuntimeConfig.Mode, image)
// We update the mode here so that we can continue passing just the config to other functions.
cfg.NVIDIAContainerRuntimeConfig.Mode = mode
modeModifier, err := newModeModifier(logger, mode, cfg, ociSpec, image)
if err != nil { if err != nil {
return nil, err return nil, err
} }
hookCreator := discover.NewHookCreator(discover.WithNVIDIACDIHookPath(cfg.NVIDIACTKConfig.Path))
var modifiers modifier.List var modifiers modifier.List
for _, modifierType := range supportedModifierTypes(mode) { for _, modifierType := range supportedModifierTypes(mode) {
switch modifierType { switch modifierType {
@@ -84,13 +97,13 @@ func newSpecModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Sp
case "nvidia-hook-remover": case "nvidia-hook-remover":
modifiers = append(modifiers, modifier.NewNvidiaContainerRuntimeHookRemover(logger)) modifiers = append(modifiers, modifier.NewNvidiaContainerRuntimeHookRemover(logger))
case "graphics": case "graphics":
graphicsModifier, err := modifier.NewGraphicsModifier(logger, cfg, *image, driver, hookCreator) graphicsModifier, err := modifier.NewGraphicsModifier(logger, cfg, image, driver, hookCreator)
if err != nil { if err != nil {
return nil, err return nil, err
} }
modifiers = append(modifiers, graphicsModifier) modifiers = append(modifiers, graphicsModifier)
case "feature-gated": case "feature-gated":
featureGatedModifier, err := modifier.NewFeatureGatedModifier(logger, cfg, *image, driver, hookCreator) featureGatedModifier, err := modifier.NewFeatureGatedModifier(logger, cfg, image, driver, hookCreator)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -101,58 +114,19 @@ func newSpecModifier(logger logger.Interface, cfg *config.Config, ociSpec oci.Sp
return modifiers, nil return modifiers, nil
} }
func newModeModifier(logger logger.Interface, mode string, cfg *config.Config, image image.CUDA) (oci.SpecModifier, error) { func newModeModifier(logger logger.Interface, mode string, cfg *config.Config, ociSpec oci.Spec, image image.CUDA) (oci.SpecModifier, error) {
switch mode { switch mode {
case "legacy": case "legacy":
return modifier.NewStableRuntimeModifier(logger, cfg.NVIDIAContainerRuntimeHookConfig.Path), nil return modifier.NewStableRuntimeModifier(logger, cfg.NVIDIAContainerRuntimeHookConfig.Path), nil
case "csv": case "csv":
return modifier.NewCSVModifier(logger, cfg, image) return modifier.NewCSVModifier(logger, cfg, image)
case "cdi": case "cdi":
return modifier.NewCDIModifier(logger, cfg, image) return modifier.NewCDIModifier(logger, cfg, ociSpec)
} }
return nil, fmt.Errorf("invalid runtime mode: %v", cfg.NVIDIAContainerRuntimeConfig.Mode) return nil, fmt.Errorf("invalid runtime mode: %v", cfg.NVIDIAContainerRuntimeConfig.Mode)
} }
// initRuntimeModeAndImage loads the OCI runtime specification, builds the CUDA
// image that describes the container from it, and resolves the runtime mode
// that should be applied for that image.
//
// The resolved mode is written back to cfg so that later calls can keep
// receiving only the config. Annotation devices are only meaningful in "cdi"
// mode: when another mode is resolved while annotation prefixes are configured,
// the prefixes are cleared in cfg and the image is built a second time.
//
// It returns the resolved mode and a pointer to the constructed image, or an
// error if the spec cannot be loaded or the image cannot be constructed.
func initRuntimeModeAndImage(logger logger.Interface, cfg *config.Config, ociSpec oci.Spec) (string, *image.CUDA, error) {
	rawSpec, loadErr := ociSpec.Load()
	if loadErr != nil {
		return "", nil, fmt.Errorf("failed to load OCI spec: %v", loadErr)
	}

	annotationPrefixes := cfg.NVIDIAContainerRuntimeConfig.Modes.CDI.AnnotationPrefixes

	cudaImage, err := image.NewCUDAImageFromSpec(
		rawSpec,
		image.WithLogger(logger),
		image.WithAcceptDeviceListAsVolumeMounts(cfg.AcceptDeviceListAsVolumeMounts),
		image.WithAcceptEnvvarUnprivileged(cfg.AcceptEnvvarUnprivileged),
		image.WithAnnotationsPrefixes(annotationPrefixes),
	)
	if err != nil {
		return "", nil, err
	}

	resolvedMode := info.ResolveAutoMode(logger, cfg.NVIDIAContainerRuntimeConfig.Mode, cudaImage)
	// Persist the resolved mode so that callers can keep passing just the config around.
	cfg.NVIDIAContainerRuntimeConfig.Mode = resolvedMode

	// A rebuild is only needed when a non-cdi mode was resolved for an image
	// that was constructed with annotation prefixes.
	needsRebuild := resolvedMode != "cdi" && len(annotationPrefixes) != 0
	if needsRebuild {
		// Dropping the prefixes and recursing forces the image to be
		// reconstructed without annotation devices. The mode is explicitly set
		// at this point, so mode resolution is effectively skipped on the
		// second pass.
		cfg.NVIDIAContainerRuntimeConfig.Modes.CDI.AnnotationPrefixes = nil
		return initRuntimeModeAndImage(logger, cfg, ociSpec)
	}

	return resolvedMode, &cudaImage, nil
}
// supportedModifierTypes returns the modifiers supported for a specific runtime mode. // supportedModifierTypes returns the modifiers supported for a specific runtime mode.
func supportedModifierTypes(mode string) []string { func supportedModifierTypes(mode string) []string {
switch mode { switch mode {

View File

@@ -1,5 +1,3 @@
nvidia-container-runtime /usr/bin nvidia-container-runtime /usr/bin
nvidia-ctk /usr/bin nvidia-ctk /usr/bin
nvidia-cdi-hook /usr/bin nvidia-cdi-hook /usr/bin
nvidia-cdi-refresh.service /etc/systemd/system/
nvidia-cdi-refresh.path /etc/systemd/system/

View File

@@ -5,16 +5,6 @@ set -e
case "$1" in case "$1" in
configure) configure)
/usr/bin/nvidia-ctk --quiet config --config-file=/etc/nvidia-container-runtime/config.toml --in-place /usr/bin/nvidia-ctk --quiet config --config-file=/etc/nvidia-container-runtime/config.toml --in-place
if command -v systemctl >/dev/null 2>&1 \
&& systemctl --quiet is-system-running 2>/dev/null; then
systemctl daemon-reload || true
if [ -z "$2" ]; then # $2 empty → first install
systemctl enable --now nvidia-cdi-refresh.path || true
fi
fi
;; ;;
abort-upgrade|abort-remove|abort-deconfigure) abort-upgrade|abort-remove|abort-deconfigure)

View File

@@ -5,14 +5,3 @@
%: %:
dh $@ dh $@
override_dh_fixperms:
dh_fixperms
chmod 755 debian/$(shell dh_listpackages)/usr/bin/nvidia-container-runtime-hook || true
chmod 755 debian/$(shell dh_listpackages)/usr/bin/nvidia-container-runtime || true
chmod 755 debian/$(shell dh_listpackages)/usr/bin/nvidia-container-runtime.cdi || true
chmod 755 debian/$(shell dh_listpackages)/usr/bin/nvidia-container-runtime.legacy || true
chmod 755 debian/$(shell dh_listpackages)/usr/bin/nvidia-ctk || true
chmod 755 debian/$(shell dh_listpackages)/usr/bin/nvidia-cdi-hook || true
chmod 644 debian/$(shell dh_listpackages)/etc/systemd/system/nvidia-cdi-refresh.service || true
chmod 644 debian/$(shell dh_listpackages)/etc/systemd/system/nvidia-cdi-refresh.path || true

View File

@@ -17,8 +17,6 @@ Source3: nvidia-container-runtime
Source4: nvidia-container-runtime.cdi Source4: nvidia-container-runtime.cdi
Source5: nvidia-container-runtime.legacy Source5: nvidia-container-runtime.legacy
Source6: nvidia-cdi-hook Source6: nvidia-cdi-hook
Source7: nvidia-cdi-refresh.service
Source8: nvidia-cdi-refresh.path
Obsoletes: nvidia-container-runtime <= 3.5.0-1, nvidia-container-runtime-hook <= 1.4.0-2 Obsoletes: nvidia-container-runtime <= 3.5.0-1, nvidia-container-runtime-hook <= 1.4.0-2
Provides: nvidia-container-runtime Provides: nvidia-container-runtime
@@ -30,20 +28,16 @@ Requires: nvidia-container-toolkit-base == %{version}-%{release}
Provides tools and utilities to enable GPU support in containers. Provides tools and utilities to enable GPU support in containers.
%prep %prep
cp %{SOURCE0} %{SOURCE1} %{SOURCE2} %{SOURCE3} %{SOURCE4} %{SOURCE5} %{SOURCE6} %{SOURCE7} %{SOURCE8} . cp %{SOURCE0} %{SOURCE1} %{SOURCE2} %{SOURCE3} %{SOURCE4} %{SOURCE5} %{SOURCE6} .
%install %install
mkdir -p %{buildroot}%{_bindir} mkdir -p %{buildroot}%{_bindir}
mkdir -p %{buildroot}%{_sysconfdir}/systemd/system/
install -m 755 -t %{buildroot}%{_bindir} nvidia-container-runtime-hook install -m 755 -t %{buildroot}%{_bindir} nvidia-container-runtime-hook
install -m 755 -t %{buildroot}%{_bindir} nvidia-container-runtime install -m 755 -t %{buildroot}%{_bindir} nvidia-container-runtime
install -m 755 -t %{buildroot}%{_bindir} nvidia-container-runtime.cdi install -m 755 -t %{buildroot}%{_bindir} nvidia-container-runtime.cdi
install -m 755 -t %{buildroot}%{_bindir} nvidia-container-runtime.legacy install -m 755 -t %{buildroot}%{_bindir} nvidia-container-runtime.legacy
install -m 755 -t %{buildroot}%{_bindir} nvidia-ctk install -m 755 -t %{buildroot}%{_bindir} nvidia-ctk
install -m 755 -t %{buildroot}%{_bindir} nvidia-cdi-hook install -m 755 -t %{buildroot}%{_bindir} nvidia-cdi-hook
install -m 644 -t %{buildroot}%{_sysconfdir}/systemd/system nvidia-cdi-refresh.service
install -m 644 -t %{buildroot}%{_sysconfdir}/systemd/system nvidia-cdi-refresh.path
%post %post
if [ $1 -gt 1 ]; then # only on package upgrade if [ $1 -gt 1 ]; then # only on package upgrade
@@ -51,17 +45,6 @@ if [ $1 -gt 1 ]; then # only on package upgrade
cp -af %{_bindir}/nvidia-container-runtime-hook %{_localstatedir}/lib/rpm-state/nvidia-container-toolkit cp -af %{_bindir}/nvidia-container-runtime-hook %{_localstatedir}/lib/rpm-state/nvidia-container-toolkit
fi fi
# Reload systemd unit cache
if command -v systemctl >/dev/null 2>&1 \
&& systemctl --quiet is-system-running 2>/dev/null; then
systemctl daemon-reload || true
# On fresh install ($1 == 1) enable the path unit so it starts at boot
if [ "$1" -eq 1 ]; then
systemctl enable --now nvidia-cdi-refresh.path || true
fi
fi
%posttrans %posttrans
if [ ! -e %{_bindir}/nvidia-container-runtime-hook ]; then if [ ! -e %{_bindir}/nvidia-container-runtime-hook ]; then
# repairing lost file nvidia-container-runtime-hook # repairing lost file nvidia-container-runtime-hook
@@ -106,8 +89,6 @@ Provides tools such as the NVIDIA Container Runtime and NVIDIA Container Toolkit
%{_bindir}/nvidia-container-runtime %{_bindir}/nvidia-container-runtime
%{_bindir}/nvidia-ctk %{_bindir}/nvidia-ctk
%{_bindir}/nvidia-cdi-hook %{_bindir}/nvidia-cdi-hook
%{_sysconfdir}/systemd/system/nvidia-cdi-refresh.service
%{_sysconfdir}/systemd/system/nvidia-cdi-refresh.path
# The OPERATOR EXTENSIONS package consists of components that are required to enable GPU support in Kubernetes. # The OPERATOR EXTENSIONS package consists of components that are required to enable GPU support in Kubernetes.
# This package is not distributed as part of the NVIDIA Container Toolkit RPMs. # This package is not distributed as part of the NVIDIA Container Toolkit RPMs.

View File

@@ -21,13 +21,12 @@ import (
"tags.cncf.io/container-device-interface/pkg/cdi" "tags.cncf.io/container-device-interface/pkg/cdi"
"tags.cncf.io/container-device-interface/specs-go" "tags.cncf.io/container-device-interface/specs-go"
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec" "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec"
) )
// Interface defines the API for the nvcdi package // Interface defines the API for the nvcdi package
type Interface interface { type Interface interface {
GetSpec(...string) (spec.Interface, error) GetSpec() (spec.Interface, error)
GetCommonEdits() (*cdi.ContainerEdits, error) GetCommonEdits() (*cdi.ContainerEdits, error)
GetAllDeviceSpecs() ([]specs.Device, error) GetAllDeviceSpecs() ([]specs.Device, error)
GetGPUDeviceEdits(device.Device) (*cdi.ContainerEdits, error) GetGPUDeviceEdits(device.Device) (*cdi.ContainerEdits, error)
@@ -37,32 +36,14 @@ type Interface interface {
GetDeviceSpecsByID(...string) ([]specs.Device, error) GetDeviceSpecsByID(...string) ([]specs.Device, error)
} }
// A HookName represents one of the predefined NVIDIA CDI hooks. // A HookName refers to one of the predefined set of CDI hooks that may be
type HookName = discover.HookName // included in the generated CDI specification.
type HookName string
const ( const (
// AllHooks is a special hook name that allows all hooks to be matched. // HookEnableCudaCompat refers to the hook used to enable CUDA Forward Compatibility.
AllHooks = discover.AllHooks // This was added with v1.17.5 of the NVIDIA Container Toolkit.
HookEnableCudaCompat = HookName("enable-cuda-compat")
// A CreateSymlinksHook is used to create symlinks in the container.
CreateSymlinksHook = discover.CreateSymlinksHook
// DisableDeviceNodeModificationHook refers to the hook used to ensure that
// device nodes are not created by libnvidia-ml.so or nvidia-smi in a
// container.
// Added in v1.17.8
DisableDeviceNodeModificationHook = discover.DisableDeviceNodeModificationHook
// An EnableCudaCompatHook is used to enabled CUDA Forward Compatibility.
// Added in v1.17.5
EnableCudaCompatHook = discover.EnableCudaCompatHook
// An UpdateLDCacheHook is used to update the ldcache in the container.
UpdateLDCacheHook = discover.UpdateLDCacheHook
// Deprecated: Use CreateSymlinksHook instead.
HookCreateSymlinks = CreateSymlinksHook
// Deprecated: Use EnableCudaCompatHook instead.
HookEnableCudaCompat = EnableCudaCompatHook
// Deprecated: Use UpdateLDCacheHook instead.
HookUpdateLDCache = UpdateLDCacheHook
) )
// A FeatureFlag refers to a specific feature that can be toggled in the CDI api. // A FeatureFlag refers to a specific feature that can be toggled in the CDI api.

View File

@@ -106,16 +106,15 @@ func (l *nvcdilib) NewDriverLibraryDiscoverer(version string) (discover.Discover
) )
discoverers = append(discoverers, driverDotSoSymlinksDiscoverer) discoverers = append(discoverers, driverDotSoSymlinksDiscoverer)
// TODO: The following should use the version directly. if l.HookIsSupported(HookEnableCudaCompat) {
cudaCompatLibHookDiscoverer := discover.NewCUDACompatHookDiscoverer(l.logger, l.hookCreator, l.driver) // TODO: The following should use the version directly.
discoverers = append(discoverers, cudaCompatLibHookDiscoverer) cudaCompatLibHookDiscoverer := discover.NewCUDACompatHookDiscoverer(l.logger, l.hookCreator, l.driver)
discoverers = append(discoverers, cudaCompatLibHookDiscoverer)
}
updateLDCache, _ := discover.NewLDCacheUpdateHook(l.logger, libraries, l.hookCreator, l.ldconfigPath) updateLDCache, _ := discover.NewLDCacheUpdateHook(l.logger, libraries, l.hookCreator, l.ldconfigPath)
discoverers = append(discoverers, updateLDCache) discoverers = append(discoverers, updateLDCache)
disableDeviceNodeModification := l.hookCreator.Create(DisableDeviceNodeModificationHook)
discoverers = append(discoverers, disableDeviceNodeModification)
d := discover.Merge(discoverers...) d := discover.Merge(discoverers...)
return d, nil return d, nil

View File

@@ -40,12 +40,13 @@ var requiredDriverStoreFiles = []string{
// newWSLDriverDiscoverer returns a Discoverer for WSL2 drivers. // newWSLDriverDiscoverer returns a Discoverer for WSL2 drivers.
func newWSLDriverDiscoverer(logger logger.Interface, driverRoot string, hookCreator discover.HookCreator, ldconfigPath string) (discover.Discover, error) { func newWSLDriverDiscoverer(logger logger.Interface, driverRoot string, hookCreator discover.HookCreator, ldconfigPath string) (discover.Discover, error) {
if err := dxcore.Init(); err != nil { err := dxcore.Init()
return nil, fmt.Errorf("failed to initialize dxcore: %w", err) if err != nil {
return nil, fmt.Errorf("failed to initialize dxcore: %v", err)
} }
defer func() { defer func() {
if err := dxcore.Shutdown(); err != nil { if err := dxcore.Shutdown(); err != nil {
logger.Warningf("failed to shutdown dxcore: %w", err) logger.Warningf("failed to shutdown dxcore: %v", err)
} }
}() }()
@@ -53,19 +54,32 @@ func newWSLDriverDiscoverer(logger logger.Interface, driverRoot string, hookCrea
if len(driverStorePaths) == 0 { if len(driverStorePaths) == 0 {
return nil, fmt.Errorf("no driver store paths found") return nil, fmt.Errorf("no driver store paths found")
} }
if len(driverStorePaths) > 1 {
logger.Warningf("Found multiple driver store paths: %v", driverStorePaths)
}
logger.Infof("Using WSL driver store paths: %v", driverStorePaths) logger.Infof("Using WSL driver store paths: %v", driverStorePaths)
driverStorePaths = append(driverStorePaths, "/usr/lib/wsl/lib") return newWSLDriverStoreDiscoverer(logger, driverRoot, hookCreator, ldconfigPath, driverStorePaths)
}
driverStoreMounts := discover.NewMounts( // newWSLDriverStoreDiscoverer returns a Discoverer for WSL2 drivers in the driver store associated with a dxcore adapter.
func newWSLDriverStoreDiscoverer(logger logger.Interface, driverRoot string, hookCreator discover.HookCreator, ldconfigPath string, driverStorePaths []string) (discover.Discover, error) {
// Deduplicate the driver store paths while preserving their original order.
var searchPaths []string
seen := make(map[string]bool)
for _, path := range driverStorePaths {
	if seen[path] {
		continue
	}
	// Record the path so that later occurrences are actually skipped;
	// without this the seen map stays empty and nothing is deduplicated.
	seen[path] = true
	searchPaths = append(searchPaths, path)
}
if len(searchPaths) > 1 {
logger.Warningf("Found multiple driver store paths: %v", searchPaths)
}
searchPaths = append(searchPaths, "/usr/lib/wsl/lib")
libraries := discover.NewMounts(
logger, logger,
lookup.NewFileLocator( lookup.NewFileLocator(
lookup.WithLogger(logger), lookup.WithLogger(logger),
lookup.WithSearchPaths( lookup.WithSearchPaths(
driverStorePaths..., searchPaths...,
), ),
lookup.WithCount(1), lookup.WithCount(1),
), ),
@@ -75,14 +89,14 @@ func newWSLDriverDiscoverer(logger logger.Interface, driverRoot string, hookCrea
symlinkHook := nvidiaSMISimlinkHook{ symlinkHook := nvidiaSMISimlinkHook{
logger: logger, logger: logger,
mountsFrom: driverStoreMounts, mountsFrom: libraries,
hookCreator: hookCreator, hookCreator: hookCreator,
} }
ldcacheHook, _ := discover.NewLDCacheUpdateHook(logger, driverStoreMounts, hookCreator, ldconfigPath) ldcacheHook, _ := discover.NewLDCacheUpdateHook(logger, libraries, hookCreator, ldconfigPath)
d := discover.Merge( d := discover.Merge(
driverStoreMounts, libraries,
symlinkHook, symlinkHook,
ldcacheHook, ldcacheHook,
) )
@@ -121,7 +135,7 @@ func (m nvidiaSMISimlinkHook) Hooks() ([]discover.Hook, error) {
} }
link := "/usr/bin/nvidia-smi" link := "/usr/bin/nvidia-smi"
links := []string{fmt.Sprintf("%s::%s", target, link)} links := []string{fmt.Sprintf("%s::%s", target, link)}
symlinkHook := m.hookCreator.Create(CreateSymlinksHook, links...) symlinkHook := m.hookCreator.Create("create-symlinks", links...)
return symlinkHook.Hooks() return symlinkHook.Hooks()
} }

View File

@@ -29,7 +29,7 @@ import (
func TestNvidiaSMISymlinkHook(t *testing.T) { func TestNvidiaSMISymlinkHook(t *testing.T) {
logger, _ := testlog.NewNullLogger() logger, _ := testlog.NewNullLogger()
hookCreator := discover.NewHookCreator() hookCreator := discover.NewHookCreator("nvidia-cdi-hook", false)
errMounts := errors.New("mounts error") errMounts := errors.New("mounts error")
@@ -93,7 +93,7 @@ func TestNvidiaSMISymlinkHook(t *testing.T) {
expectedHooks: []discover.Hook{ expectedHooks: []discover.Hook{
{ {
Lifecycle: "createContainer", Lifecycle: "createContainer",
Path: "/usr/bin/nvidia-cdi-hook", Path: "nvidia-cdi-hook",
Args: []string{"nvidia-cdi-hook", "create-symlinks", Args: []string{"nvidia-cdi-hook", "create-symlinks",
"--link", "nvidia-smi::/usr/bin/nvidia-smi"}, "--link", "nvidia-smi::/usr/bin/nvidia-smi"},
Env: []string{"NVIDIA_CTK_DEBUG=false"}, Env: []string{"NVIDIA_CTK_DEBUG=false"},
@@ -114,7 +114,7 @@ func TestNvidiaSMISymlinkHook(t *testing.T) {
expectedHooks: []discover.Hook{ expectedHooks: []discover.Hook{
{ {
Lifecycle: "createContainer", Lifecycle: "createContainer",
Path: "/usr/bin/nvidia-cdi-hook", Path: "nvidia-cdi-hook",
Args: []string{"nvidia-cdi-hook", "create-symlinks", Args: []string{"nvidia-cdi-hook", "create-symlinks",
"--link", "/some/path/nvidia-smi::/usr/bin/nvidia-smi"}, "--link", "/some/path/nvidia-smi::/usr/bin/nvidia-smi"},
Env: []string{"NVIDIA_CTK_DEBUG=false"}, Env: []string{"NVIDIA_CTK_DEBUG=false"},
@@ -135,7 +135,7 @@ func TestNvidiaSMISymlinkHook(t *testing.T) {
expectedHooks: []discover.Hook{ expectedHooks: []discover.Hook{
{ {
Lifecycle: "createContainer", Lifecycle: "createContainer",
Path: "/usr/bin/nvidia-cdi-hook", Path: "nvidia-cdi-hook",
Args: []string{"nvidia-cdi-hook", "create-symlinks", Args: []string{"nvidia-cdi-hook", "create-symlinks",
"--link", "/some/path/nvidia-smi::/usr/bin/nvidia-smi"}, "--link", "/some/path/nvidia-smi::/usr/bin/nvidia-smi"},
Env: []string{"NVIDIA_CTK_DEBUG=false"}, Env: []string{"NVIDIA_CTK_DEBUG=false"},

View File

@@ -58,7 +58,7 @@ func (l *gdslib) GetCommonEdits() (*cdi.ContainerEdits, error) {
// GetSpec is unsupported for the gdslib specs. // GetSpec is unsupported for the gdslib specs.
// gdslib is typically wrapped by a spec that implements GetSpec. // gdslib is typically wrapped by a spec that implements GetSpec.
func (l *gdslib) GetSpec(...string) (spec.Interface, error) { func (l *gdslib) GetSpec() (spec.Interface, error) {
return nil, fmt.Errorf("GetSpec is not supported") return nil, fmt.Errorf("GetSpec is not supported")
} }

View File

@@ -1,3 +1,4 @@
/**
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@@ -11,13 +12,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
**/
[Unit] package nvcdi
Description=Trigger CDI refresh on NVIDIA driver install / uninstall events
[Path] // disabledHooks allows individual hooks to be disabled.
PathChanged=/lib/modules/%v/modules.dep type disabledHooks map[HookName]bool
PathChanged=/lib/modules/%v/modules.dep.bin
[Install] // HookIsSupported checks whether a hook of the specified name is supported.
WantedBy=multi-user.target // Hooks must be explicitly disabled, meaning that if no disabled hooks are
// specified, all hooks are supported.
func (l *nvcdilib) HookIsSupported(h HookName) bool {
	// A lookup in a nil or empty map yields false, so every hook is reported
	// as supported unless it has been explicitly marked as disabled.
	return !l.disabledHooks[h]
}

View File

@@ -34,7 +34,7 @@ type csvlib nvcdilib
var _ Interface = (*csvlib)(nil) var _ Interface = (*csvlib)(nil)
// GetSpec should not be called for csvlib // GetSpec should not be called for csvlib
func (l *csvlib) GetSpec(...string) (spec.Interface, error) { func (l *csvlib) GetSpec() (spec.Interface, error) {
return nil, fmt.Errorf("unexpected call to csvlib.GetSpec()") return nil, fmt.Errorf("unexpected call to csvlib.GetSpec()")
} }

View File

@@ -41,7 +41,7 @@ const (
) )
// GetSpec should not be called for imexlib. // GetSpec should not be called for imexlib.
func (l *imexlib) GetSpec(...string) (spec.Interface, error) { func (l *imexlib) GetSpec() (spec.Interface, error) {
return nil, fmt.Errorf("unexpected call to imexlib.GetSpec()") return nil, fmt.Errorf("unexpected call to imexlib.GetSpec()")
} }

View File

@@ -36,7 +36,7 @@ type nvmllib nvcdilib
var _ Interface = (*nvmllib)(nil) var _ Interface = (*nvmllib)(nil)
// GetSpec should not be called for nvmllib // GetSpec should not be called for nvmllib
func (l *nvmllib) GetSpec(...string) (spec.Interface, error) { func (l *nvmllib) GetSpec() (spec.Interface, error) {
return nil, fmt.Errorf("unexpected call to nvmllib.GetSpec()") return nil, fmt.Errorf("unexpected call to nvmllib.GetSpec()")
} }

View File

@@ -32,7 +32,7 @@ type wsllib nvcdilib
var _ Interface = (*wsllib)(nil) var _ Interface = (*wsllib)(nil)
// GetSpec should not be called for wsllib // GetSpec should not be called for wsllib
func (l *wsllib) GetSpec(...string) (spec.Interface, error) { func (l *wsllib) GetSpec() (spec.Interface, error) {
return nil, fmt.Errorf("unexpected call to wsllib.GetSpec()") return nil, fmt.Errorf("unexpected call to wsllib.GetSpec()")
} }

View File

@@ -58,13 +58,16 @@ type nvcdilib struct {
featureFlags map[FeatureFlag]bool featureFlags map[FeatureFlag]bool
disabledHooks []discover.HookName disabledHooks disabledHooks
hookCreator discover.HookCreator hookCreator discover.HookCreator
} }
// New creates a new nvcdi library // New creates a new nvcdi library
func New(opts ...Option) (Interface, error) { func New(opts ...Option) (Interface, error) {
l := &nvcdilib{} l := &nvcdilib{
disabledHooks: make(disabledHooks),
featureFlags: make(map[FeatureFlag]bool),
}
for _, opt := range opts { for _, opt := range opts {
opt(l) opt(l)
} }
@@ -81,6 +84,9 @@ func New(opts ...Option) (Interface, error) {
if l.nvidiaCDIHookPath == "" { if l.nvidiaCDIHookPath == "" {
l.nvidiaCDIHookPath = "/usr/bin/nvidia-cdi-hook" l.nvidiaCDIHookPath = "/usr/bin/nvidia-cdi-hook"
} }
// create hookCreator
l.hookCreator = discover.NewHookCreator(l.nvidiaCDIHookPath, false)
if l.driverRoot == "" { if l.driverRoot == "" {
l.driverRoot = "/" l.driverRoot = "/"
} }
@@ -130,7 +136,7 @@ func New(opts ...Option) (Interface, error) {
l.vendor = "management.nvidia.com" l.vendor = "management.nvidia.com"
} }
// Management containers in general do not require CUDA Forward compatibility. // Management containers in general do not require CUDA Forward compatibility.
l.disabledHooks = append(l.disabledHooks, HookEnableCudaCompat, DisableDeviceNodeModificationHook) l.disabledHooks[HookEnableCudaCompat] = true
lib = (*managementlib)(l) lib = (*managementlib)(l)
case ModeNvml: case ModeNvml:
lib = (*nvmllib)(l) lib = (*nvmllib)(l)
@@ -155,12 +161,6 @@ func New(opts ...Option) (Interface, error) {
return nil, fmt.Errorf("unknown mode %q", l.mode) return nil, fmt.Errorf("unknown mode %q", l.mode)
} }
// create hookCreator
l.hookCreator = discover.NewHookCreator(
discover.WithNVIDIACDIHookPath(l.nvidiaCDIHookPath),
discover.WithDisabledHooks(l.disabledHooks...),
)
w := wrapper{ w := wrapper{
Interface: lib, Interface: lib,
vendor: l.vendor, vendor: l.vendor,

View File

@@ -180,7 +180,7 @@ func (m managementDiscoverer) nodeIsBlocked(path string) bool {
// GetSpec is unsupported for the managementlib specs. // GetSpec is unsupported for the managementlib specs.
// managementlib is typically wrapped by a spec that implements GetSpec. // managementlib is typically wrapped by a spec that implements GetSpec.
func (m *managementlib) GetSpec(...string) (spec.Interface, error) { func (m *managementlib) GetSpec() (spec.Interface, error) {
return nil, fmt.Errorf("GetSpec is not supported") return nil, fmt.Errorf("GetSpec is not supported")
} }

View File

@@ -58,7 +58,7 @@ func (l *mofedlib) GetCommonEdits() (*cdi.ContainerEdits, error) {
// GetSpec is unsupported for the mofedlib specs. // GetSpec is unsupported for the mofedlib specs.
// mofedlib is typically wrapped by a spec that implements GetSpec. // mofedlib is typically wrapped by a spec that implements GetSpec.
func (l *mofedlib) GetSpec(...string) (spec.Interface, error) { func (l *mofedlib) GetSpec() (spec.Interface, error) {
return nil, fmt.Errorf("GetSpec is not supported") return nil, fmt.Errorf("GetSpec is not supported")
} }

View File

@@ -21,7 +21,6 @@ import (
"github.com/NVIDIA/go-nvlib/pkg/nvlib/info" "github.com/NVIDIA/go-nvlib/pkg/nvlib/info"
"github.com/NVIDIA/go-nvml/pkg/nvml" "github.com/NVIDIA/go-nvml/pkg/nvml"
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/transform" "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/transform"
) )
@@ -159,9 +158,12 @@ func WithLibrarySearchPaths(paths []string) Option {
// WithDisabledHook allows specific hooks to be disabled. // WithDisabledHook allows specific hooks to be disabled.
// This option can be specified multiple times for each hook. // This option can be specified multiple times for each hook.
func WithDisabledHook[T string | HookName](hook T) Option { func WithDisabledHook(hook HookName) Option {
return func(o *nvcdilib) { return func(o *nvcdilib) {
o.disabledHooks = append(o.disabledHooks, discover.HookName(hook)) if o.disabledHooks == nil {
o.disabledHooks = make(map[HookName]bool)
}
o.disabledHooks[hook] = true
} }
} }

View File

@@ -61,9 +61,18 @@ func (d *deviceFolderPermissions) Hooks() ([]discover.Hook, error) {
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get device subfolders: %v", err) return nil, fmt.Errorf("failed to get device subfolders: %v", err)
} }
if len(folders) == 0 {
return nil, nil
}
//nolint:staticcheck // The ChmodHook is deprecated and will be removed in a future release. args := []string{"--mode", "755"}
return d.hookCreator.Create(discover.ChmodHook, folders...).Hooks() for _, folder := range folders {
args = append(args, "--path", folder)
}
hook := d.hookCreator.Create("chmod", args...)
return []discover.Hook{*hook}, nil
} }
func (d *deviceFolderPermissions) getDeviceSubfolders() ([]string, error) { func (d *deviceFolderPermissions) getDeviceSubfolders() ([]string, error) {

View File

@@ -35,11 +35,8 @@ type wrapper struct {
} }
// GetSpec combines the device specs and common edits from the wrapped Interface to a single spec.Interface. // GetSpec combines the device specs and common edits from the wrapped Interface to a single spec.Interface.
func (l *wrapper) GetSpec(devices ...string) (spec.Interface, error) { func (l *wrapper) GetSpec() (spec.Interface, error) {
if len(devices) == 0 { deviceSpecs, err := l.GetAllDeviceSpecs()
devices = append(devices, "all")
}
deviceSpecs, err := l.GetDeviceSpecsByID(devices...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -58,16 +55,6 @@ func (l *wrapper) GetSpec(devices ...string) (spec.Interface, error) {
) )
} }
func (l *wrapper) GetDeviceSpecsByID(devices ...string) ([]specs.Device, error) {
for _, device := range devices {
if device != "all" {
continue
}
return l.GetAllDeviceSpecs()
}
return l.Interface.GetDeviceSpecsByID(devices...)
}
// GetAllDeviceSpecs returns the device specs for all available devices. // GetAllDeviceSpecs returns the device specs for all available devices.
func (l *wrapper) GetAllDeviceSpecs() ([]specs.Device, error) { func (l *wrapper) GetAllDeviceSpecs() ([]specs.Device, error) {
return l.Interface.GetAllDeviceSpecs() return l.Interface.GetAllDeviceSpecs()

View File

@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@@ -20,8 +20,14 @@ LOG_ARTIFACTS_DIR ?= $(CURDIR)/e2e_logs
GINKGO_BIN := $(CURDIR)/bin/ginkgo GINKGO_BIN := $(CURDIR)/bin/ginkgo
# If GINKGO_FOCUS is not set, run all tests
# currently available tests:
# - nvidia-container-cli
# - docker
GINKGO_FOCUS ?=
test: $(GINKGO_BIN) test: $(GINKGO_BIN)
$(GINKGO_BIN) $(GINKGO_ARGS) -v --json-report ginkgo.json ./tests/e2e/... $(GINKGO_BIN) $(GINKGO_ARGS) -v --json-report ginkgo.json --focus="$(GINKGO_FOCUS)" ./tests/e2e/...
$(GINKGO_BIN): $(GINKGO_BIN):
mkdir -p $(CURDIR)/bin mkdir -p $(CURDIR)/bin

View File

@@ -28,11 +28,20 @@ var dockerInstallTemplate = `
#! /usr/bin/env bash #! /usr/bin/env bash
set -xe set -xe
: ${IMAGE:={{.Image}}} # if the TEMP_DIR is already set, use it
if [ -f /tmp/ctk_e2e_temp_dir.txt ]; then
TEMP_DIR=$(cat /tmp/ctk_e2e_temp_dir.txt)
else
TEMP_DIR="/tmp/ctk_e2e.$(date +%s)_$RANDOM"
echo "$TEMP_DIR" > /tmp/ctk_e2e_temp_dir.txt
fi
# Create a temporary directory # if TEMP_DIR does not exist, create it
TEMP_DIR="/tmp/ctk_e2e.$(date +%s)_$RANDOM" if [ ! -d "$TEMP_DIR" ]; then
mkdir -p "$TEMP_DIR" mkdir -p "$TEMP_DIR"
fi
: ${IMAGE:={{.Image}}}
# Given that docker has an init function that checks for the existence of the # Given that docker has an init function that checks for the existence of the
# nvidia-container-toolkit, we need to create a symlink to the nvidia-container-runtime-hook # nvidia-container-toolkit, we need to create a symlink to the nvidia-container-runtime-hook

View File

@@ -0,0 +1,208 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package e2e
import (
"context"
"fmt"
"strings"
"text/template"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
const (
	// dockerDindTemplate is the command used to start a detached, privileged
	// Docker-in-Docker container. The etc/docker, run/nvidia and
	// usr/local/nvidia folders below {{.SharedDir}} are mounted into it and
	// the inner daemon listens on the unix socket {{.DockerSocket}}.
	dockerDindTemplate = `docker run -d --rm --privileged \
-v {{.SharedDir}}/etc/docker:/etc/docker \
-v {{.SharedDir}}/run/nvidia:/run/nvidia \
-v {{.SharedDir}}/usr/local/nvidia:/usr/local/nvidia \
--name {{.ContainerName}} \
docker:dind -H unix://{{.DockerSocket}}`

	// dockerToolkitTemplate is the command used to start the toolkit image
	// with the test script mounted at /usr/local/bin/libnvidia-container-cli.sh.
	// The container shares the volumes and the PID namespace of the
	// Docker-in-Docker container.
	//
	// The container is intentionally NOT started with --rm: the test reads the
	// result through `docker logs` after the script has finished, and an
	// auto-removed container no longer has logs to read. The suite removes
	// the container explicitly with `docker rm -f` during cleanup.
	dockerToolkitTemplate = `docker run -d --privileged \
--volumes-from {{.DindContainerName}} \
--pid "container:{{.DindContainerName}}" \
-e RUNTIME_ARGS="--socket {{.DockerSocket}}" \
-v {{.TestScriptPath}}:/usr/local/bin/libnvidia-container-cli.sh \
--name {{.ContainerName}} \
{{.ToolkitImage}} /usr/local/bin/libnvidia-container-cli.sh`

	// dockerDefaultConfigTemplate is the daemon.json written for the
	// Docker-in-Docker daemon; it only configures a registry mirror.
	dockerDefaultConfigTemplate = `
{
"registry-mirrors": ["https://mirror.gcr.io"]
}`

	// libnvidiaContainerCliTestTemplate is the bash script executed in the
	// toolkit container. In new mount and PID namespaces it unpacks an Ubuntu
	// base rootfs, mounts minimal virtual filesystems, configures the rootfs
	// for GPU 0 using nvidia-container-cli, pivots into the rootfs and finally
	// runs `nvidia-smi -L`, whose output is the last thing the script prints.
	libnvidiaContainerCliTestTemplate = `#!/usr/bin/env bash
set -euo pipefail
apt-get update -y && apt-get install -y curl gnupg2
WORKDIR="$(mktemp -d)"
ROOTFS="${WORKDIR}/rootfs"
mkdir -p "${ROOTFS}"
export WORKDIR ROOTFS # make them visible in the child shell
unshare --mount --pid --fork --propagation private -- bash -eux <<'IN_NS'
: "${ROOTFS:?}" "${WORKDIR:?}" # abort if either is empty
# 1 Populate minimal Ubuntu base
curl -L http://cdimage.ubuntu.com/ubuntu-base/releases/22.04/release/ubuntu-base-22.04-base-amd64.tar.gz \
| tar -C "$ROOTFS" -xz
# 2 Add non-root user
useradd -R "$ROOTFS" -U -u 1000 -s /bin/bash nvidia
# 3 Bind-mount new root and unshare mounts
mount --bind "$ROOTFS" "$ROOTFS"
mount --make-private "$ROOTFS"
cd "$ROOTFS"
# 4 Minimal virtual filesystems
mount -t proc proc proc
mount -t sysfs sys sys
mount -t tmpfs tmp tmp
mount -t tmpfs run run
# 5 GPU setup
nvidia-container-cli --load-kmods --debug=container-cli.log \
configure --ldconfig=@/sbin/ldconfig.real \
--no-cgroups --utility --device=0 "$(pwd)"
# 6 Switch root
mkdir -p mnt
pivot_root . mnt
umount -l /mnt
exec nvidia-smi -L
IN_NS
`
)
// End-to-end test for nvidia-container-cli. The test script is executed in a
// toolkit container that shares volumes and the PID namespace with a
// Docker-in-Docker container, and the last line it prints is compared with the
// `nvidia-smi -L` output collected on the host.
var _ = Describe("nvidia-container-cli", Ordered, ContinueOnFailure, func() {
	var (
		runner               Runner
		sharedDir            string
		dindContainerName    string
		toolkitContainerName string
		dockerSocket         string
		hostOutput           string
	)

	// Prepare the host: record the reference nvidia-smi output, create the
	// shared directory tree and start the Docker-in-Docker container.
	BeforeAll(func(ctx context.Context) {
		runner = NewRunner(
			WithHost(sshHost),
			WithPort(sshPort),
			WithSshKey(sshKey),
			WithSshUser(sshUser),
		)

		// Setup shared directory and container names
		sharedDir = "/tmp/nvidia-container-toolkit-test"
		dindContainerName = "nvidia-container-toolkit-dind"
		toolkitContainerName = "nvidia-container-toolkit-test"
		dockerSocket = "/run/nvidia/docker.sock"

		// The container names are fixed, so containers left behind by an
		// interrupted run would make `docker run --name` fail. Removal is
		// best-effort: an error only means there was nothing to remove.
		_, _, _ = runner.Run(fmt.Sprintf("docker rm -f %s", toolkitContainerName))
		_, _, _ = runner.Run(fmt.Sprintf("docker rm -f %s", dindContainerName))

		// Get host nvidia-smi output
		var err error
		hostOutput, _, err = runner.Run("nvidia-smi -L")
		Expect(err).ToNot(HaveOccurred())

		// Pull ubuntu image
		_, _, err = runner.Run("docker pull ubuntu")
		Expect(err).ToNot(HaveOccurred())

		// Create shared directory structure
		_, _, err = runner.Run(fmt.Sprintf("mkdir -p %s/{etc/docker,run/nvidia,usr/local/nvidia}", sharedDir))
		Expect(err).ToNot(HaveOccurred())

		// Copy docker default config
		createDockerConfigCmd := fmt.Sprintf("cat > %s/etc/docker/daemon.json <<'EOF'\n%s\nEOF",
			sharedDir, dockerDefaultConfigTemplate)
		_, _, err = runner.Run(createDockerConfigCmd)
		Expect(err).ToNot(HaveOccurred())

		// Start Docker-in-Docker container
		tmpl, err := template.New("dockerDind").Parse(dockerDindTemplate)
		Expect(err).ToNot(HaveOccurred())

		var dindCmdBuilder strings.Builder
		err = tmpl.Execute(&dindCmdBuilder, map[string]string{
			"SharedDir":     sharedDir,
			"ContainerName": dindContainerName,
			"DockerSocket":  dockerSocket,
		})
		Expect(err).ToNot(HaveOccurred())

		_, _, err = runner.Run(dindCmdBuilder.String())
		Expect(err).ToNot(HaveOccurred())
	})

	AfterAll(func(ctx context.Context) {
		// Cleanup containers. This is best-effort: a container may already be
		// gone, so the results are discarded explicitly.
		_, _, _ = runner.Run(fmt.Sprintf("docker rm -f %s", toolkitContainerName))
		_, _, _ = runner.Run(fmt.Sprintf("docker rm -f %s", dindContainerName))

		// Cleanup shared directory
		_, _, err := runner.Run(fmt.Sprintf("rm -rf %s", sharedDir))
		Expect(err).ToNot(HaveOccurred())
	})

	When("running nvidia-smi -L", Ordered, func() {
		It("should support NVIDIA_VISIBLE_DEVICES and NVIDIA_DRIVER_CAPABILITIES", func(ctx context.Context) {
			// 1. Create the test script
			testScriptPath := fmt.Sprintf("%s/libnvidia-container-cli.sh", sharedDir)
			createScriptCmd := fmt.Sprintf("cat > %s <<'EOF'\n%s\nEOF\nchmod +x %s",
				testScriptPath, libnvidiaContainerCliTestTemplate, testScriptPath)
			_, _, err := runner.Run(createScriptCmd)
			Expect(err).ToNot(HaveOccurred())

			// 2. Start the toolkit container
			tmpl, err := template.New("dockerToolkit").Parse(dockerToolkitTemplate)
			Expect(err).ToNot(HaveOccurred())

			var toolkitCmdBuilder strings.Builder
			err = tmpl.Execute(&toolkitCmdBuilder, map[string]string{
				"DindContainerName": dindContainerName,
				"ContainerName":     toolkitContainerName,
				"DockerSocket":      dockerSocket,
				"TestScriptPath":    testScriptPath,
				"ToolkitImage":      imageName + ":" + imageTag,
			})
			Expect(err).ToNot(HaveOccurred())

			_, _, err = runner.Run(toolkitCmdBuilder.String())
			Expect(err).ToNot(HaveOccurred())

			// 3. Wait for and verify the output.
			// The script configures the rootfs with --device=0 and only the
			// last log line is inspected, so the reference value is the first
			// line of the host output. On a single-GPU host this is the whole
			// output, as before.
			hostLines := strings.Split(strings.TrimSpace(strings.ReplaceAll(hostOutput, "\r", "")), "\n")
			expected := strings.TrimSpace(hostLines[0])
			// The poll below yields "" on errors; an empty reference value
			// would therefore let the test pass without any GPU being seen.
			Expect(expected).ToNot(BeEmpty(), "nvidia-smi -L returned no output on the host")

			Eventually(func() string {
				logs, _, err := runner.Run(fmt.Sprintf("docker logs %s | tail -n 20", toolkitContainerName))
				if err != nil {
					return ""
				}
				// strings.Split always returns at least one element for a
				// non-empty separator, so indexing the last line is safe.
				logLines := strings.Split(strings.TrimSpace(logs), "\n")
				return strings.TrimSpace(strings.ReplaceAll(logLines[len(logLines)-1], "\r", ""))
			}, "5m", "5s").Should(Equal(expected))
		})
	})
})

View File

@@ -216,23 +216,4 @@ var _ = Describe("docker", Ordered, ContinueOnFailure, func() {
Expect(ldconfigOut).To(ContainSubstring("/usr/lib64")) Expect(ldconfigOut).To(ContainSubstring("/usr/lib64"))
}) })
}) })
Describe("Disabling device node creation", Ordered, func() {
BeforeAll(func(ctx context.Context) {
_, _, err := runner.Run("docker pull ubuntu")
Expect(err).ToNot(HaveOccurred())
})
It("should work with nvidia-container-runtime-hook", func(ctx context.Context) {
output, _, err := runner.Run("docker run --rm -i --runtime=runc --gpus=all ubuntu bash -c \"grep ModifyDeviceFiles: /proc/driver/nvidia/params\"")
Expect(err).ToNot(HaveOccurred())
Expect(output).To(Equal("ModifyDeviceFiles: 0\n"))
})
It("should work with automatic CDI spec generation", func(ctx context.Context) {
output, _, err := runner.Run("docker run --rm -i --runtime=nvidia -e NVIDIA_VISIBLE_DEVICES=runtime.nvidia.com/gpu=all ubuntu bash -c \"grep ModifyDeviceFiles: /proc/driver/nvidia/params\"")
Expect(err).ToNot(HaveOccurred())
Expect(output).To(Equal("ModifyDeviceFiles: 0\n"))
})
})
}) })

View File

@@ -7,7 +7,7 @@ toolchain go1.24.1
require ( require (
github.com/onsi/ginkgo/v2 v2.23.4 github.com/onsi/ginkgo/v2 v2.23.4
github.com/onsi/gomega v1.37.0 github.com/onsi/gomega v1.37.0
golang.org/x/crypto v0.39.0 golang.org/x/crypto v0.38.0
) )
require ( require (
@@ -16,9 +16,9 @@ require (
github.com/google/go-cmp v0.7.0 // indirect github.com/google/go-cmp v0.7.0 // indirect
github.com/google/pprof v0.0.0-20250403155104-27863c87afa6 // indirect github.com/google/pprof v0.0.0-20250403155104-27863c87afa6 // indirect
go.uber.org/automaxprocs v1.6.0 // indirect go.uber.org/automaxprocs v1.6.0 // indirect
golang.org/x/net v0.40.0 // indirect golang.org/x/net v0.38.0 // indirect
golang.org/x/sys v0.33.0 // indirect golang.org/x/sys v0.33.0 // indirect
golang.org/x/text v0.26.0 // indirect golang.org/x/text v0.25.0 // indirect
golang.org/x/tools v0.33.0 // indirect golang.org/x/tools v0.31.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect
) )

View File

@@ -24,18 +24,18 @@ github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcU
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
go.uber.org/automaxprocs v1.6.0 h1:O3y2/QNTOdbF+e/dpXNNW7Rx2hZ4sTIPyybbxyNqTUs= go.uber.org/automaxprocs v1.6.0 h1:O3y2/QNTOdbF+e/dpXNNW7Rx2hZ4sTIPyybbxyNqTUs=
go.uber.org/automaxprocs v1.6.0/go.mod h1:ifeIMSnPZuznNm6jmdzmU3/bfk01Fe2fotchwEFJ8r8= go.uber.org/automaxprocs v1.6.0/go.mod h1:ifeIMSnPZuznNm6jmdzmU3/bfk01Fe2fotchwEFJ8r8=
golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM= golang.org/x/crypto v0.38.0 h1:jt+WWG8IZlBnVbomuhg2Mdq0+BBQaHbtqHEFEigjUV8=
golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U= golang.org/x/crypto v0.38.0/go.mod h1:MvrbAqul58NNYPKnOra203SB9vpuZW0e+RRZV+Ggqjw=
golang.org/x/net v0.40.0 h1:79Xs7wF06Gbdcg4kdCCIQArK11Z1hr5POQ6+fIYHNuY= golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8=
golang.org/x/net v0.40.0/go.mod h1:y0hY0exeL2Pku80/zKK7tpntoX23cqL3Oa6njdgRtds= golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/term v0.32.0 h1:DR4lr0TjUs3epypdhTOkMmuF5CDFJ/8pOnbzMZPQ7bg= golang.org/x/term v0.32.0 h1:DR4lr0TjUs3epypdhTOkMmuF5CDFJ/8pOnbzMZPQ7bg=
golang.org/x/term v0.32.0/go.mod h1:uZG1FhGx848Sqfsq4/DlJr3xGGsYMu/L5GW4abiaEPQ= golang.org/x/term v0.32.0/go.mod h1:uZG1FhGx848Sqfsq4/DlJr3xGGsYMu/L5GW4abiaEPQ=
golang.org/x/text v0.26.0 h1:P42AVeLghgTYr4+xUnTRKDMqpar+PtX7KWuNQL21L8M= golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4=
golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA= golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA=
golang.org/x/tools v0.33.0 h1:4qz2S3zmRxbGIhDIAgjxvFutSvH5EfnsYrRBj0UI0bc= golang.org/x/tools v0.31.0 h1:0EedkvKDbh+qistFTd0Bcwe/YLh4vHwWEkiI0toFIBU=
golang.org/x/tools v0.33.0/go.mod h1:CIJMaWEY88juyUfo7UbgPqbC8rU2OqfAV1h2Qp0oMYI= golang.org/x/tools v0.31.0/go.mod h1:naFTU+Cev749tSJRXJlna0T3WxKvb1kWEx15xA4SdmQ=
google.golang.org/protobuf v1.36.5 h1:tPhr+woSbjfYvY6/GPufUoYizxw1cF/yFoxJ2fmpwlM= google.golang.org/protobuf v1.36.5 h1:tPhr+woSbjfYvY6/GPufUoYizxw1cF/yFoxJ2fmpwlM=
google.golang.org/protobuf v1.36.5/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= google.golang.org/protobuf v1.36.5/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=

View File

@@ -0,0 +1 @@
{"ociVersion":"1.0.1-dev","process":{"terminal":true,"user":{"uid":0,"gid":0},"args":["sh"],"env":["PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin","TERM=xterm"],"cwd":"/","capabilities":{"bounding":["CAP_AUDIT_WRITE","CAP_KILL","CAP_NET_BIND_SERVICE"],"effective":["CAP_AUDIT_WRITE","CAP_KILL","CAP_NET_BIND_SERVICE"],"inheritable":["CAP_AUDIT_WRITE","CAP_KILL","CAP_NET_BIND_SERVICE"],"permitted":["CAP_AUDIT_WRITE","CAP_KILL","CAP_NET_BIND_SERVICE"],"ambient":["CAP_AUDIT_WRITE","CAP_KILL","CAP_NET_BIND_SERVICE"]},"rlimits":[{"type":"RLIMIT_NOFILE","hard":1024,"soft":1024}],"noNewPrivileges":true},"root":{"path":"rootfs","readonly":true},"hostname":"runc","mounts":[{"destination":"/proc","type":"proc","source":"proc"},{"destination":"/dev","type":"tmpfs","source":"tmpfs","options":["nosuid","strictatime","mode=755","size=65536k"]},{"destination":"/dev/pts","type":"devpts","source":"devpts","options":["nosuid","noexec","newinstance","ptmxmode=0666","mode=0620","gid=5"]},{"destination":"/dev/shm","type":"tmpfs","source":"shm","options":["nosuid","noexec","nodev","mode=1777","size=65536k"]},{"destination":"/dev/mqueue","type":"mqueue","source":"mqueue","options":["nosuid","noexec","nodev"]},{"destination":"/sys","type":"sysfs","source":"sysfs","options":["nosuid","noexec","nodev","ro"]},{"destination":"/sys/fs/cgroup","type":"cgroup","source":"cgroup","options":["nosuid","noexec","nodev","relatime","ro"]}],"hooks":{"prestart":[{"path":"nvidia-container-runtime-hook","args":["nvidia-container-runtime-hook","prestart"]}]},"linux":{"resources":{"devices":[{"allow":false,"access":"rwm"}]},"namespaces":[{"type":"pid"},{"type":"network"},{"type":"ipc"},{"type":"uts"},{"type":"mount"}],"maskedPaths":["/proc/kcore","/proc/latency_stats","/proc/timer_list","/proc/timer_stats","/proc/sched_debug","/sys/firmware","/proc/scsi"],"readonlyPaths":["/proc/asound","/proc/bus","/proc/fs","/proc/irq","/proc/sys","/proc/sysrq-trigger"]}}

View File

@@ -20,19 +20,14 @@ import (
// returned by MultiAlgorithmSigner and don't appear in the Signature.Format // returned by MultiAlgorithmSigner and don't appear in the Signature.Format
// field. // field.
const ( const (
CertAlgoRSAv01 = "ssh-rsa-cert-v01@openssh.com" CertAlgoRSAv01 = "ssh-rsa-cert-v01@openssh.com"
// Deprecated: DSA is only supported at insecure key sizes, and was removed CertAlgoDSAv01 = "ssh-dss-cert-v01@openssh.com"
// from major implementations. CertAlgoECDSA256v01 = "ecdsa-sha2-nistp256-cert-v01@openssh.com"
CertAlgoDSAv01 = InsecureCertAlgoDSAv01 CertAlgoECDSA384v01 = "ecdsa-sha2-nistp384-cert-v01@openssh.com"
// Deprecated: DSA is only supported at insecure key sizes, and was removed CertAlgoECDSA521v01 = "ecdsa-sha2-nistp521-cert-v01@openssh.com"
// from major implementations. CertAlgoSKECDSA256v01 = "sk-ecdsa-sha2-nistp256-cert-v01@openssh.com"
InsecureCertAlgoDSAv01 = "ssh-dss-cert-v01@openssh.com" CertAlgoED25519v01 = "ssh-ed25519-cert-v01@openssh.com"
CertAlgoECDSA256v01 = "ecdsa-sha2-nistp256-cert-v01@openssh.com" CertAlgoSKED25519v01 = "sk-ssh-ed25519-cert-v01@openssh.com"
CertAlgoECDSA384v01 = "ecdsa-sha2-nistp384-cert-v01@openssh.com"
CertAlgoECDSA521v01 = "ecdsa-sha2-nistp521-cert-v01@openssh.com"
CertAlgoSKECDSA256v01 = "sk-ecdsa-sha2-nistp256-cert-v01@openssh.com"
CertAlgoED25519v01 = "ssh-ed25519-cert-v01@openssh.com"
CertAlgoSKED25519v01 = "sk-ssh-ed25519-cert-v01@openssh.com"
// CertAlgoRSASHA256v01 and CertAlgoRSASHA512v01 can't appear as a // CertAlgoRSASHA256v01 and CertAlgoRSASHA512v01 can't appear as a
// Certificate.Type (or PublicKey.Type), but only in // Certificate.Type (or PublicKey.Type), but only in
@@ -490,16 +485,16 @@ func (c *Certificate) SignCert(rand io.Reader, authority Signer) error {
// //
// This map must be kept in sync with the one in agent/client.go. // This map must be kept in sync with the one in agent/client.go.
var certKeyAlgoNames = map[string]string{ var certKeyAlgoNames = map[string]string{
CertAlgoRSAv01: KeyAlgoRSA, CertAlgoRSAv01: KeyAlgoRSA,
CertAlgoRSASHA256v01: KeyAlgoRSASHA256, CertAlgoRSASHA256v01: KeyAlgoRSASHA256,
CertAlgoRSASHA512v01: KeyAlgoRSASHA512, CertAlgoRSASHA512v01: KeyAlgoRSASHA512,
InsecureCertAlgoDSAv01: InsecureKeyAlgoDSA, CertAlgoDSAv01: KeyAlgoDSA,
CertAlgoECDSA256v01: KeyAlgoECDSA256, CertAlgoECDSA256v01: KeyAlgoECDSA256,
CertAlgoECDSA384v01: KeyAlgoECDSA384, CertAlgoECDSA384v01: KeyAlgoECDSA384,
CertAlgoECDSA521v01: KeyAlgoECDSA521, CertAlgoECDSA521v01: KeyAlgoECDSA521,
CertAlgoSKECDSA256v01: KeyAlgoSKECDSA256, CertAlgoSKECDSA256v01: KeyAlgoSKECDSA256,
CertAlgoED25519v01: KeyAlgoED25519, CertAlgoED25519v01: KeyAlgoED25519,
CertAlgoSKED25519v01: KeyAlgoSKED25519, CertAlgoSKED25519v01: KeyAlgoSKED25519,
} }
// underlyingAlgo returns the signature algorithm associated with algo (which is // underlyingAlgo returns the signature algorithm associated with algo (which is

View File

@@ -58,11 +58,11 @@ func newRC4(key, iv []byte) (cipher.Stream, error) {
type cipherMode struct { type cipherMode struct {
keySize int keySize int
ivSize int ivSize int
create func(key, iv []byte, macKey []byte, algs DirectionAlgorithms) (packetCipher, error) create func(key, iv []byte, macKey []byte, algs directionAlgorithms) (packetCipher, error)
} }
func streamCipherMode(skip int, createFunc func(key, iv []byte) (cipher.Stream, error)) func(key, iv []byte, macKey []byte, algs DirectionAlgorithms) (packetCipher, error) { func streamCipherMode(skip int, createFunc func(key, iv []byte) (cipher.Stream, error)) func(key, iv []byte, macKey []byte, algs directionAlgorithms) (packetCipher, error) {
return func(key, iv, macKey []byte, algs DirectionAlgorithms) (packetCipher, error) { return func(key, iv, macKey []byte, algs directionAlgorithms) (packetCipher, error) {
stream, err := createFunc(key, iv) stream, err := createFunc(key, iv)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -98,36 +98,36 @@ func streamCipherMode(skip int, createFunc func(key, iv []byte) (cipher.Stream,
var cipherModes = map[string]*cipherMode{ var cipherModes = map[string]*cipherMode{
// Ciphers from RFC 4344, which introduced many CTR-based ciphers. Algorithms // Ciphers from RFC 4344, which introduced many CTR-based ciphers. Algorithms
// are defined in the order specified in the RFC. // are defined in the order specified in the RFC.
CipherAES128CTR: {16, aes.BlockSize, streamCipherMode(0, newAESCTR)}, "aes128-ctr": {16, aes.BlockSize, streamCipherMode(0, newAESCTR)},
CipherAES192CTR: {24, aes.BlockSize, streamCipherMode(0, newAESCTR)}, "aes192-ctr": {24, aes.BlockSize, streamCipherMode(0, newAESCTR)},
CipherAES256CTR: {32, aes.BlockSize, streamCipherMode(0, newAESCTR)}, "aes256-ctr": {32, aes.BlockSize, streamCipherMode(0, newAESCTR)},
// Ciphers from RFC 4345, which introduces security-improved arcfour ciphers. // Ciphers from RFC 4345, which introduces security-improved arcfour ciphers.
// They are defined in the order specified in the RFC. // They are defined in the order specified in the RFC.
InsecureCipherRC4128: {16, 0, streamCipherMode(1536, newRC4)}, "arcfour128": {16, 0, streamCipherMode(1536, newRC4)},
InsecureCipherRC4256: {32, 0, streamCipherMode(1536, newRC4)}, "arcfour256": {32, 0, streamCipherMode(1536, newRC4)},
// Cipher defined in RFC 4253, which describes SSH Transport Layer Protocol. // Cipher defined in RFC 4253, which describes SSH Transport Layer Protocol.
// Note that this cipher is not safe, as stated in RFC 4253: "Arcfour (and // Note that this cipher is not safe, as stated in RFC 4253: "Arcfour (and
// RC4) has problems with weak keys, and should be used with caution." // RC4) has problems with weak keys, and should be used with caution."
// RFC 4345 introduces improved versions of Arcfour. // RFC 4345 introduces improved versions of Arcfour.
InsecureCipherRC4: {16, 0, streamCipherMode(0, newRC4)}, "arcfour": {16, 0, streamCipherMode(0, newRC4)},
// AEAD ciphers // AEAD ciphers
CipherAES128GCM: {16, 12, newGCMCipher}, gcm128CipherID: {16, 12, newGCMCipher},
CipherAES256GCM: {32, 12, newGCMCipher}, gcm256CipherID: {32, 12, newGCMCipher},
CipherChaCha20Poly1305: {64, 0, newChaCha20Cipher}, chacha20Poly1305ID: {64, 0, newChaCha20Cipher},
// CBC mode is insecure and so is not included in the default config. // CBC mode is insecure and so is not included in the default config.
// (See https://www.ieee-security.org/TC/SP2013/papers/4977a526.pdf). If absolutely // (See https://www.ieee-security.org/TC/SP2013/papers/4977a526.pdf). If absolutely
// needed, it's possible to specify a custom Config to enable it. // needed, it's possible to specify a custom Config to enable it.
// You should expect that an active attacker can recover plaintext if // You should expect that an active attacker can recover plaintext if
// you do. // you do.
InsecureCipherAES128CBC: {16, aes.BlockSize, newAESCBCCipher}, aes128cbcID: {16, aes.BlockSize, newAESCBCCipher},
// 3des-cbc is insecure and is not included in the default // 3des-cbc is insecure and is not included in the default
// config. // config.
InsecureCipherTripleDESCBC: {24, des.BlockSize, newTripleDESCBCCipher}, tripledescbcID: {24, des.BlockSize, newTripleDESCBCCipher},
} }
// prefixLen is the length of the packet prefix that contains the packet length // prefixLen is the length of the packet prefix that contains the packet length
@@ -307,7 +307,7 @@ type gcmCipher struct {
buf []byte buf []byte
} }
func newGCMCipher(key, iv, unusedMacKey []byte, unusedAlgs DirectionAlgorithms) (packetCipher, error) { func newGCMCipher(key, iv, unusedMacKey []byte, unusedAlgs directionAlgorithms) (packetCipher, error) {
c, err := aes.NewCipher(key) c, err := aes.NewCipher(key)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -429,7 +429,7 @@ type cbcCipher struct {
oracleCamouflage uint32 oracleCamouflage uint32
} }
func newCBCCipher(c cipher.Block, key, iv, macKey []byte, algs DirectionAlgorithms) (packetCipher, error) { func newCBCCipher(c cipher.Block, key, iv, macKey []byte, algs directionAlgorithms) (packetCipher, error) {
cbc := &cbcCipher{ cbc := &cbcCipher{
mac: macModes[algs.MAC].new(macKey), mac: macModes[algs.MAC].new(macKey),
decrypter: cipher.NewCBCDecrypter(c, iv), decrypter: cipher.NewCBCDecrypter(c, iv),
@@ -443,7 +443,7 @@ func newCBCCipher(c cipher.Block, key, iv, macKey []byte, algs DirectionAlgorith
return cbc, nil return cbc, nil
} }
func newAESCBCCipher(key, iv, macKey []byte, algs DirectionAlgorithms) (packetCipher, error) { func newAESCBCCipher(key, iv, macKey []byte, algs directionAlgorithms) (packetCipher, error) {
c, err := aes.NewCipher(key) c, err := aes.NewCipher(key)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -457,7 +457,7 @@ func newAESCBCCipher(key, iv, macKey []byte, algs DirectionAlgorithms) (packetCi
return cbc, nil return cbc, nil
} }
func newTripleDESCBCCipher(key, iv, macKey []byte, algs DirectionAlgorithms) (packetCipher, error) { func newTripleDESCBCCipher(key, iv, macKey []byte, algs directionAlgorithms) (packetCipher, error) {
c, err := des.NewTripleDESCipher(key) c, err := des.NewTripleDESCipher(key)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -635,6 +635,8 @@ func (c *cbcCipher) writeCipherPacket(seqNum uint32, w io.Writer, rand io.Reader
return nil return nil
} }
const chacha20Poly1305ID = "chacha20-poly1305@openssh.com"
// chacha20Poly1305Cipher implements the chacha20-poly1305@openssh.com // chacha20Poly1305Cipher implements the chacha20-poly1305@openssh.com
// AEAD, which is described here: // AEAD, which is described here:
// //
@@ -648,7 +650,7 @@ type chacha20Poly1305Cipher struct {
buf []byte buf []byte
} }
func newChaCha20Cipher(key, unusedIV, unusedMACKey []byte, unusedAlgs DirectionAlgorithms) (packetCipher, error) { func newChaCha20Cipher(key, unusedIV, unusedMACKey []byte, unusedAlgs directionAlgorithms) (packetCipher, error) {
if len(key) != 64 { if len(key) != 64 {
panic(len(key)) panic(len(key))
} }

View File

@@ -110,7 +110,6 @@ func (c *connection) clientHandshake(dialAddress string, config *ClientConfig) e
} }
c.sessionID = c.transport.getSessionID() c.sessionID = c.transport.getSessionID()
c.algorithms = c.transport.getAlgorithms()
return c.clientAuthenticate(config) return c.clientAuthenticate(config)
} }

View File

@@ -10,7 +10,6 @@ import (
"fmt" "fmt"
"io" "io"
"math" "math"
"slices"
"sync" "sync"
_ "crypto/sha1" _ "crypto/sha1"
@@ -25,258 +24,69 @@ const (
serviceSSH = "ssh-connection" serviceSSH = "ssh-connection"
) )
// The ciphers currently or previously implemented by this library, to use in // supportedCiphers lists ciphers we support but might not recommend.
// [Config.Ciphers]. For a list, see the [Algorithms.Ciphers] returned by var supportedCiphers = []string{
// [SupportedAlgorithms] or [InsecureAlgorithms]. "aes128-ctr", "aes192-ctr", "aes256-ctr",
const ( "aes128-gcm@openssh.com", gcm256CipherID,
CipherAES128GCM = "aes128-gcm@openssh.com" chacha20Poly1305ID,
CipherAES256GCM = "aes256-gcm@openssh.com" "arcfour256", "arcfour128", "arcfour",
CipherChaCha20Poly1305 = "chacha20-poly1305@openssh.com" aes128cbcID,
CipherAES128CTR = "aes128-ctr" tripledescbcID,
CipherAES192CTR = "aes192-ctr"
CipherAES256CTR = "aes256-ctr"
InsecureCipherAES128CBC = "aes128-cbc"
InsecureCipherTripleDESCBC = "3des-cbc"
InsecureCipherRC4 = "arcfour"
InsecureCipherRC4128 = "arcfour128"
InsecureCipherRC4256 = "arcfour256"
)
// The key exchanges currently or previously implemented by this library, to use
// in [Config.KeyExchanges]. For a list, see the
// [Algorithms.KeyExchanges] returned by [SupportedAlgorithms] or
// [InsecureAlgorithms].
const (
InsecureKeyExchangeDH1SHA1 = "diffie-hellman-group1-sha1"
InsecureKeyExchangeDH14SHA1 = "diffie-hellman-group14-sha1"
KeyExchangeDH14SHA256 = "diffie-hellman-group14-sha256"
KeyExchangeDH16SHA512 = "diffie-hellman-group16-sha512"
KeyExchangeECDHP256 = "ecdh-sha2-nistp256"
KeyExchangeECDHP384 = "ecdh-sha2-nistp384"
KeyExchangeECDHP521 = "ecdh-sha2-nistp521"
KeyExchangeCurve25519 = "curve25519-sha256"
InsecureKeyExchangeDHGEXSHA1 = "diffie-hellman-group-exchange-sha1"
KeyExchangeDHGEXSHA256 = "diffie-hellman-group-exchange-sha256"
// KeyExchangeMLKEM768X25519 is supported from Go 1.24.
KeyExchangeMLKEM768X25519 = "mlkem768x25519-sha256"
// An alias for KeyExchangeCurve25519SHA256. This kex ID will be added if
// KeyExchangeCurve25519SHA256 is requested for backward compatibility with
// OpenSSH versions up to 7.2.
keyExchangeCurve25519LibSSH = "curve25519-sha256@libssh.org"
)
// The message authentication code (MAC) currently or previously implemented by
// this library, to use in [Config.MACs]. For a list, see the
// [Algorithms.MACs] returned by [SupportedAlgorithms] or
// [InsecureAlgorithms].
const (
HMACSHA256ETM = "hmac-sha2-256-etm@openssh.com"
HMACSHA512ETM = "hmac-sha2-512-etm@openssh.com"
HMACSHA256 = "hmac-sha2-256"
HMACSHA512 = "hmac-sha2-512"
HMACSHA1 = "hmac-sha1"
InsecureHMACSHA196 = "hmac-sha1-96"
)
var (
// supportedKexAlgos specifies key-exchange algorithms implemented by this
// package in preference order, excluding those with security issues.
supportedKexAlgos = []string{
KeyExchangeCurve25519,
KeyExchangeECDHP256,
KeyExchangeECDHP384,
KeyExchangeECDHP521,
KeyExchangeDH14SHA256,
KeyExchangeDH16SHA512,
KeyExchangeDHGEXSHA256,
}
// defaultKexAlgos specifies the default preference for key-exchange
// algorithms in preference order.
defaultKexAlgos = []string{
KeyExchangeCurve25519,
KeyExchangeECDHP256,
KeyExchangeECDHP384,
KeyExchangeECDHP521,
KeyExchangeDH14SHA256,
InsecureKeyExchangeDH14SHA1,
}
// insecureKexAlgos specifies key-exchange algorithms implemented by this
// package and which have security issues.
insecureKexAlgos = []string{
InsecureKeyExchangeDH14SHA1,
InsecureKeyExchangeDH1SHA1,
InsecureKeyExchangeDHGEXSHA1,
}
// supportedCiphers specifies cipher algorithms implemented by this package
// in preference order, excluding those with security issues.
supportedCiphers = []string{
CipherAES128GCM,
CipherAES256GCM,
CipherChaCha20Poly1305,
CipherAES128CTR,
CipherAES192CTR,
CipherAES256CTR,
}
// defaultCiphers specifies the default preference for ciphers algorithms
// in preference order.
defaultCiphers = supportedCiphers
// insecureCiphers specifies cipher algorithms implemented by this
// package and which have security issues.
insecureCiphers = []string{
InsecureCipherAES128CBC,
InsecureCipherTripleDESCBC,
InsecureCipherRC4256,
InsecureCipherRC4128,
InsecureCipherRC4,
}
// supportedMACs specifies MAC algorithms implemented by this package in
// preference order, excluding those with security issues.
supportedMACs = []string{
HMACSHA256ETM,
HMACSHA512ETM,
HMACSHA256,
HMACSHA512,
HMACSHA1,
}
// defaultMACs specifies the default preference for MAC algorithms in
// preference order.
defaultMACs = []string{
HMACSHA256ETM,
HMACSHA512ETM,
HMACSHA256,
HMACSHA512,
HMACSHA1,
InsecureHMACSHA196,
}
// insecureMACs specifies MAC algorithms implemented by this
// package and which have security issues.
insecureMACs = []string{
InsecureHMACSHA196,
}
// supportedHostKeyAlgos specifies the supported host-key algorithms (i.e.
// methods of authenticating servers) implemented by this package in
// preference order, excluding those with security issues.
supportedHostKeyAlgos = []string{
CertAlgoRSASHA256v01,
CertAlgoRSASHA512v01,
CertAlgoECDSA256v01,
CertAlgoECDSA384v01,
CertAlgoECDSA521v01,
CertAlgoED25519v01,
KeyAlgoRSASHA256,
KeyAlgoRSASHA512,
KeyAlgoECDSA256,
KeyAlgoECDSA384,
KeyAlgoECDSA521,
KeyAlgoED25519,
}
// defaultHostKeyAlgos specifies the default preference for host-key
// algorithms in preference order.
defaultHostKeyAlgos = []string{
CertAlgoRSASHA256v01,
CertAlgoRSASHA512v01,
CertAlgoRSAv01,
InsecureCertAlgoDSAv01,
CertAlgoECDSA256v01,
CertAlgoECDSA384v01,
CertAlgoECDSA521v01,
CertAlgoED25519v01,
KeyAlgoECDSA256,
KeyAlgoECDSA384,
KeyAlgoECDSA521,
KeyAlgoRSASHA256,
KeyAlgoRSASHA512,
KeyAlgoRSA,
InsecureKeyAlgoDSA,
KeyAlgoED25519,
}
// insecureHostKeyAlgos specifies host-key algorithms implemented by this
// package and which have security issues.
insecureHostKeyAlgos = []string{
KeyAlgoRSA,
InsecureKeyAlgoDSA,
CertAlgoRSAv01,
InsecureCertAlgoDSAv01,
}
// supportedPubKeyAuthAlgos specifies the supported client public key
// authentication algorithms. Note that this doesn't include certificate
// types since those use the underlying algorithm. Order is irrelevant.
supportedPubKeyAuthAlgos = []string{
KeyAlgoED25519,
KeyAlgoSKED25519,
KeyAlgoSKECDSA256,
KeyAlgoECDSA256,
KeyAlgoECDSA384,
KeyAlgoECDSA521,
KeyAlgoRSASHA256,
KeyAlgoRSASHA512,
}
// defaultPubKeyAuthAlgos specifies the preferred client public key
// authentication algorithms. This list is sent to the client if it supports
// the server-sig-algs extension. Order is irrelevant.
defaultPubKeyAuthAlgos = []string{
KeyAlgoED25519,
KeyAlgoSKED25519,
KeyAlgoSKECDSA256,
KeyAlgoECDSA256,
KeyAlgoECDSA384,
KeyAlgoECDSA521,
KeyAlgoRSASHA256,
KeyAlgoRSASHA512,
KeyAlgoRSA,
InsecureKeyAlgoDSA,
}
// insecurePubKeyAuthAlgos specifies client public key authentication
// algorithms implemented by this package and which have security issues.
insecurePubKeyAuthAlgos = []string{
KeyAlgoRSA,
InsecureKeyAlgoDSA,
}
)
// NegotiatedAlgorithms defines algorithms negotiated between client and server.
type NegotiatedAlgorithms struct {
KeyExchange string
HostKey string
Read DirectionAlgorithms
Write DirectionAlgorithms
} }
// Algorithms defines a set of algorithms that can be configured in the client // preferredCiphers specifies the default preference for ciphers.
// or server config for negotiation during a handshake. var preferredCiphers = []string{
type Algorithms struct { "aes128-gcm@openssh.com", gcm256CipherID,
KeyExchanges []string chacha20Poly1305ID,
Ciphers []string "aes128-ctr", "aes192-ctr", "aes256-ctr",
MACs []string
HostKeys []string
PublicKeyAuths []string
} }
// SupportedAlgorithms returns algorithms currently implemented by this package, // supportedKexAlgos specifies the supported key-exchange algorithms in
// excluding those with security issues, which are returned by // preference order.
// InsecureAlgorithms. The algorithms listed here are in preference order. var supportedKexAlgos = []string{
func SupportedAlgorithms() Algorithms { kexAlgoCurve25519SHA256, kexAlgoCurve25519SHA256LibSSH,
return Algorithms{ // P384 and P521 are not constant-time yet, but since we don't
Ciphers: slices.Clone(supportedCiphers), // reuse ephemeral keys, using them for ECDH should be OK.
MACs: slices.Clone(supportedMACs), kexAlgoECDH256, kexAlgoECDH384, kexAlgoECDH521,
KeyExchanges: slices.Clone(supportedKexAlgos), kexAlgoDH14SHA256, kexAlgoDH16SHA512, kexAlgoDH14SHA1,
HostKeys: slices.Clone(supportedHostKeyAlgos), kexAlgoDH1SHA1,
PublicKeyAuths: slices.Clone(supportedPubKeyAuthAlgos),
}
} }
// InsecureAlgorithms returns algorithms currently implemented by this package // serverForbiddenKexAlgos contains key exchange algorithms, that are forbidden
// and which have security issues. // for the server half.
func InsecureAlgorithms() Algorithms { var serverForbiddenKexAlgos = map[string]struct{}{
return Algorithms{ kexAlgoDHGEXSHA1: {}, // server half implementation is only minimal to satisfy the automated tests
KeyExchanges: slices.Clone(insecureKexAlgos), kexAlgoDHGEXSHA256: {}, // server half implementation is only minimal to satisfy the automated tests
Ciphers: slices.Clone(insecureCiphers), }
MACs: slices.Clone(insecureMACs),
HostKeys: slices.Clone(insecureHostKeyAlgos), // preferredKexAlgos specifies the default preference for key-exchange
PublicKeyAuths: slices.Clone(insecurePubKeyAuthAlgos), // algorithms in preference order. The diffie-hellman-group16-sha512 algorithm
} // is disabled by default because it is a bit slower than the others.
var preferredKexAlgos = []string{
kexAlgoCurve25519SHA256, kexAlgoCurve25519SHA256LibSSH,
kexAlgoECDH256, kexAlgoECDH384, kexAlgoECDH521,
kexAlgoDH14SHA256, kexAlgoDH14SHA1,
}
// supportedHostKeyAlgos specifies the supported host-key algorithms (i.e. methods
// of authenticating servers) in preference order.
var supportedHostKeyAlgos = []string{
CertAlgoRSASHA256v01, CertAlgoRSASHA512v01,
CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01,
CertAlgoECDSA384v01, CertAlgoECDSA521v01, CertAlgoED25519v01,
KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521,
KeyAlgoRSASHA256, KeyAlgoRSASHA512,
KeyAlgoRSA, KeyAlgoDSA,
KeyAlgoED25519,
}
// supportedMACs specifies a default set of MAC algorithms in preference order.
// This is based on RFC 4253, section 6.4, but with hmac-md5 variants removed
// because they have reached the end of their useful life.
var supportedMACs = []string{
"hmac-sha2-256-etm@openssh.com", "hmac-sha2-512-etm@openssh.com", "hmac-sha2-256", "hmac-sha2-512", "hmac-sha1", "hmac-sha1-96",
} }
var supportedCompressions = []string{compressionNone} var supportedCompressions = []string{compressionNone}
@@ -284,13 +94,13 @@ var supportedCompressions = []string{compressionNone}
// hashFuncs keeps the mapping of supported signature algorithms to their // hashFuncs keeps the mapping of supported signature algorithms to their
// respective hashes needed for signing and verification. // respective hashes needed for signing and verification.
var hashFuncs = map[string]crypto.Hash{ var hashFuncs = map[string]crypto.Hash{
KeyAlgoRSA: crypto.SHA1, KeyAlgoRSA: crypto.SHA1,
KeyAlgoRSASHA256: crypto.SHA256, KeyAlgoRSASHA256: crypto.SHA256,
KeyAlgoRSASHA512: crypto.SHA512, KeyAlgoRSASHA512: crypto.SHA512,
InsecureKeyAlgoDSA: crypto.SHA1, KeyAlgoDSA: crypto.SHA1,
KeyAlgoECDSA256: crypto.SHA256, KeyAlgoECDSA256: crypto.SHA256,
KeyAlgoECDSA384: crypto.SHA384, KeyAlgoECDSA384: crypto.SHA384,
KeyAlgoECDSA521: crypto.SHA512, KeyAlgoECDSA521: crypto.SHA512,
// KeyAlgoED25519 doesn't pre-hash. // KeyAlgoED25519 doesn't pre-hash.
KeyAlgoSKECDSA256: crypto.SHA256, KeyAlgoSKECDSA256: crypto.SHA256,
KeyAlgoSKED25519: crypto.SHA256, KeyAlgoSKED25519: crypto.SHA256,
@@ -325,6 +135,18 @@ func isRSACert(algo string) bool {
return isRSA(algo) return isRSA(algo)
} }
// supportedPubKeyAuthAlgos specifies the supported client public key
// authentication algorithms. Note that this doesn't include certificate types
// since those use the underlying algorithm. This list is sent to the client if
// it supports the server-sig-algs extension. Order is irrelevant.
var supportedPubKeyAuthAlgos = []string{
KeyAlgoED25519,
KeyAlgoSKED25519, KeyAlgoSKECDSA256,
KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521,
KeyAlgoRSASHA256, KeyAlgoRSASHA512, KeyAlgoRSA,
KeyAlgoDSA,
}
// unexpectedMessageError results when the SSH message that we received didn't // unexpectedMessageError results when the SSH message that we received didn't
// match what we wanted. // match what we wanted.
func unexpectedMessageError(expected, got uint8) error { func unexpectedMessageError(expected, got uint8) error {
@@ -347,21 +169,20 @@ func findCommon(what string, client []string, server []string) (common string, e
return "", fmt.Errorf("ssh: no common algorithm for %s; client offered: %v, server offered: %v", what, client, server) return "", fmt.Errorf("ssh: no common algorithm for %s; client offered: %v, server offered: %v", what, client, server)
} }
// DirectionAlgorithms defines the algorithms negotiated in one direction // directionAlgorithms records algorithm choices in one direction (either read or write)
// (either read or write). type directionAlgorithms struct {
type DirectionAlgorithms struct {
Cipher string Cipher string
MAC string MAC string
compression string Compression string
} }
// rekeyBytes returns a rekeying intervals in bytes. // rekeyBytes returns a rekeying intervals in bytes.
func (a *DirectionAlgorithms) rekeyBytes() int64 { func (a *directionAlgorithms) rekeyBytes() int64 {
// According to RFC 4344 block ciphers should rekey after // According to RFC 4344 block ciphers should rekey after
// 2^(BLOCKSIZE/4) blocks. For all AES flavors BLOCKSIZE is // 2^(BLOCKSIZE/4) blocks. For all AES flavors BLOCKSIZE is
// 128. // 128.
switch a.Cipher { switch a.Cipher {
case CipherAES128CTR, CipherAES192CTR, CipherAES256CTR, CipherAES128GCM, CipherAES256GCM, InsecureCipherAES128CBC: case "aes128-ctr", "aes192-ctr", "aes256-ctr", gcm128CipherID, gcm256CipherID, aes128cbcID:
return 16 * (1 << 32) return 16 * (1 << 32)
} }
@@ -371,25 +192,32 @@ func (a *DirectionAlgorithms) rekeyBytes() int64 {
} }
var aeadCiphers = map[string]bool{ var aeadCiphers = map[string]bool{
CipherAES128GCM: true, gcm128CipherID: true,
CipherAES256GCM: true, gcm256CipherID: true,
CipherChaCha20Poly1305: true, chacha20Poly1305ID: true,
} }
func findAgreedAlgorithms(isClient bool, clientKexInit, serverKexInit *kexInitMsg) (algs *NegotiatedAlgorithms, err error) { type algorithms struct {
result := &NegotiatedAlgorithms{} kex string
hostKey string
w directionAlgorithms
r directionAlgorithms
}
result.KeyExchange, err = findCommon("key exchange", clientKexInit.KexAlgos, serverKexInit.KexAlgos) func findAgreedAlgorithms(isClient bool, clientKexInit, serverKexInit *kexInitMsg) (algs *algorithms, err error) {
result := &algorithms{}
result.kex, err = findCommon("key exchange", clientKexInit.KexAlgos, serverKexInit.KexAlgos)
if err != nil { if err != nil {
return return
} }
result.HostKey, err = findCommon("host key", clientKexInit.ServerHostKeyAlgos, serverKexInit.ServerHostKeyAlgos) result.hostKey, err = findCommon("host key", clientKexInit.ServerHostKeyAlgos, serverKexInit.ServerHostKeyAlgos)
if err != nil { if err != nil {
return return
} }
stoc, ctos := &result.Write, &result.Read stoc, ctos := &result.w, &result.r
if isClient { if isClient {
ctos, stoc = stoc, ctos ctos, stoc = stoc, ctos
} }
@@ -418,12 +246,12 @@ func findAgreedAlgorithms(isClient bool, clientKexInit, serverKexInit *kexInitMs
} }
} }
ctos.compression, err = findCommon("client to server compression", clientKexInit.CompressionClientServer, serverKexInit.CompressionClientServer) ctos.Compression, err = findCommon("client to server compression", clientKexInit.CompressionClientServer, serverKexInit.CompressionClientServer)
if err != nil { if err != nil {
return return
} }
stoc.compression, err = findCommon("server to client compression", clientKexInit.CompressionServerClient, serverKexInit.CompressionServerClient) stoc.Compression, err = findCommon("server to client compression", clientKexInit.CompressionServerClient, serverKexInit.CompressionServerClient)
if err != nil { if err != nil {
return return
} }
@@ -469,7 +297,7 @@ func (c *Config) SetDefaults() {
c.Rand = rand.Reader c.Rand = rand.Reader
} }
if c.Ciphers == nil { if c.Ciphers == nil {
c.Ciphers = defaultCiphers c.Ciphers = preferredCiphers
} }
var ciphers []string var ciphers []string
for _, c := range c.Ciphers { for _, c := range c.Ciphers {
@@ -481,22 +309,19 @@ func (c *Config) SetDefaults() {
c.Ciphers = ciphers c.Ciphers = ciphers
if c.KeyExchanges == nil { if c.KeyExchanges == nil {
c.KeyExchanges = defaultKexAlgos c.KeyExchanges = preferredKexAlgos
} }
var kexs []string var kexs []string
for _, k := range c.KeyExchanges { for _, k := range c.KeyExchanges {
if kexAlgoMap[k] != nil { if kexAlgoMap[k] != nil {
// Ignore the KEX if we have no kexAlgoMap definition. // Ignore the KEX if we have no kexAlgoMap definition.
kexs = append(kexs, k) kexs = append(kexs, k)
if k == KeyExchangeCurve25519 && !contains(c.KeyExchanges, keyExchangeCurve25519LibSSH) {
kexs = append(kexs, keyExchangeCurve25519LibSSH)
}
} }
} }
c.KeyExchanges = kexs c.KeyExchanges = kexs
if c.MACs == nil { if c.MACs == nil {
c.MACs = defaultMACs c.MACs = supportedMACs
} }
var macs []string var macs []string
for _, m := range c.MACs { for _, m := range c.MACs {

View File

@@ -74,13 +74,6 @@ type Conn interface {
// Disconnect // Disconnect
} }
// AlgorithmsConnMetadata is a ConnMetadata that can return the algorithms
// negotiated between client and server.
type AlgorithmsConnMetadata interface {
ConnMetadata
Algorithms() NegotiatedAlgorithms
}
// DiscardRequests consumes and rejects all requests from the // DiscardRequests consumes and rejects all requests from the
// passed-in channel. // passed-in channel.
func DiscardRequests(in <-chan *Request) { func DiscardRequests(in <-chan *Request) {
@@ -113,7 +106,6 @@ type sshConn struct {
sessionID []byte sessionID []byte
clientVersion []byte clientVersion []byte
serverVersion []byte serverVersion []byte
algorithms NegotiatedAlgorithms
} }
func dup(src []byte) []byte { func dup(src []byte) []byte {
@@ -149,7 +141,3 @@ func (c *sshConn) ClientVersion() []byte {
func (c *sshConn) ServerVersion() []byte { func (c *sshConn) ServerVersion() []byte {
return dup(c.serverVersion) return dup(c.serverVersion)
} }
func (c *sshConn) Algorithms() NegotiatedAlgorithms {
return c.algorithms
}

View File

@@ -38,7 +38,7 @@ type keyingTransport interface {
// prepareKeyChange sets up a key change. The key change for a // prepareKeyChange sets up a key change. The key change for a
// direction will be effected if a msgNewKeys message is sent // direction will be effected if a msgNewKeys message is sent
// or received. // or received.
prepareKeyChange(*NegotiatedAlgorithms, *kexResult) error prepareKeyChange(*algorithms, *kexResult) error
// setStrictMode sets the strict KEX mode, notably triggering // setStrictMode sets the strict KEX mode, notably triggering
// sequence number resets on sending or receiving msgNewKeys. // sequence number resets on sending or receiving msgNewKeys.
@@ -115,7 +115,7 @@ type handshakeTransport struct {
bannerCallback BannerCallback bannerCallback BannerCallback
// Algorithms agreed in the last key exchange. // Algorithms agreed in the last key exchange.
algorithms *NegotiatedAlgorithms algorithms *algorithms
// Counters exclusively owned by readLoop. // Counters exclusively owned by readLoop.
readPacketsLeft uint32 readPacketsLeft uint32
@@ -164,7 +164,7 @@ func newClientTransport(conn keyingTransport, clientVersion, serverVersion []byt
if config.HostKeyAlgorithms != nil { if config.HostKeyAlgorithms != nil {
t.hostKeyAlgorithms = config.HostKeyAlgorithms t.hostKeyAlgorithms = config.HostKeyAlgorithms
} else { } else {
t.hostKeyAlgorithms = defaultHostKeyAlgos t.hostKeyAlgorithms = supportedHostKeyAlgos
} }
go t.readLoop() go t.readLoop()
go t.kexLoop() go t.kexLoop()
@@ -184,10 +184,6 @@ func (t *handshakeTransport) getSessionID() []byte {
return t.sessionID return t.sessionID
} }
func (t *handshakeTransport) getAlgorithms() NegotiatedAlgorithms {
return *t.algorithms
}
// waitSession waits for the session to be established. This should be // waitSession waits for the session to be established. This should be
// the first thing to call after instantiating handshakeTransport. // the first thing to call after instantiating handshakeTransport.
func (t *handshakeTransport) waitSession() error { func (t *handshakeTransport) waitSession() error {
@@ -294,7 +290,7 @@ func (t *handshakeTransport) resetWriteThresholds() {
if t.config.RekeyThreshold > 0 { if t.config.RekeyThreshold > 0 {
t.writeBytesLeft = int64(t.config.RekeyThreshold) t.writeBytesLeft = int64(t.config.RekeyThreshold)
} else if t.algorithms != nil { } else if t.algorithms != nil {
t.writeBytesLeft = t.algorithms.Write.rekeyBytes() t.writeBytesLeft = t.algorithms.w.rekeyBytes()
} else { } else {
t.writeBytesLeft = 1 << 30 t.writeBytesLeft = 1 << 30
} }
@@ -411,7 +407,7 @@ func (t *handshakeTransport) resetReadThresholds() {
if t.config.RekeyThreshold > 0 { if t.config.RekeyThreshold > 0 {
t.readBytesLeft = int64(t.config.RekeyThreshold) t.readBytesLeft = int64(t.config.RekeyThreshold)
} else if t.algorithms != nil { } else if t.algorithms != nil {
t.readBytesLeft = t.algorithms.Read.rekeyBytes() t.readBytesLeft = t.algorithms.r.rekeyBytes()
} else { } else {
t.readBytesLeft = 1 << 30 t.readBytesLeft = 1 << 30
} }
@@ -704,9 +700,9 @@ func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error {
} }
} }
kex, ok := kexAlgoMap[t.algorithms.KeyExchange] kex, ok := kexAlgoMap[t.algorithms.kex]
if !ok { if !ok {
return fmt.Errorf("ssh: unexpected key exchange algorithm %v", t.algorithms.KeyExchange) return fmt.Errorf("ssh: unexpected key exchange algorithm %v", t.algorithms.kex)
} }
var result *kexResult var result *kexResult
@@ -813,12 +809,12 @@ func pickHostKey(hostKeys []Signer, algo string) AlgorithmSigner {
} }
func (t *handshakeTransport) server(kex kexAlgorithm, magics *handshakeMagics) (*kexResult, error) { func (t *handshakeTransport) server(kex kexAlgorithm, magics *handshakeMagics) (*kexResult, error) {
hostKey := pickHostKey(t.hostKeys, t.algorithms.HostKey) hostKey := pickHostKey(t.hostKeys, t.algorithms.hostKey)
if hostKey == nil { if hostKey == nil {
return nil, errors.New("ssh: internal error: negotiated unsupported signature type") return nil, errors.New("ssh: internal error: negotiated unsupported signature type")
} }
r, err := kex.Server(t.conn, t.config.Rand, magics, hostKey, t.algorithms.HostKey) r, err := kex.Server(t.conn, t.config.Rand, magics, hostKey, t.algorithms.hostKey)
return r, err return r, err
} }
@@ -833,7 +829,7 @@ func (t *handshakeTransport) client(kex kexAlgorithm, magics *handshakeMagics) (
return nil, err return nil, err
} }
if err := verifyHostKeySignature(hostKey, t.algorithms.HostKey, result); err != nil { if err := verifyHostKeySignature(hostKey, t.algorithms.hostKey, result); err != nil {
return nil, err return nil, err
} }

View File

@@ -20,18 +20,21 @@ import (
) )
const ( const (
// This is the group called diffie-hellman-group1-sha1 in RFC 4253 and kexAlgoDH1SHA1 = "diffie-hellman-group1-sha1"
// Oakley Group 2 in RFC 2409. kexAlgoDH14SHA1 = "diffie-hellman-group14-sha1"
oakleyGroup2 = "FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF" kexAlgoDH14SHA256 = "diffie-hellman-group14-sha256"
// This is the group called diffie-hellman-group14-sha1 in RFC 4253 and kexAlgoDH16SHA512 = "diffie-hellman-group16-sha512"
// Oakley Group 14 in RFC 3526. kexAlgoECDH256 = "ecdh-sha2-nistp256"
oakleyGroup14 = "FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF" kexAlgoECDH384 = "ecdh-sha2-nistp384"
// This is the group called diffie-hellman-group15-sha512 in RFC 8268 and kexAlgoECDH521 = "ecdh-sha2-nistp521"
// Oakley Group 15 in RFC 3526. kexAlgoCurve25519SHA256LibSSH = "curve25519-sha256@libssh.org"
oakleyGroup15 = "FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6BF12FFA06D98A0864D87602733EC86A64521F2B18177B200CBBE117577A615D6C770988C0BAD946E208E24FA074E5AB3143DB5BFCE0FD108E4B82D120A93AD2CAFFFFFFFFFFFFFFFF" kexAlgoCurve25519SHA256 = "curve25519-sha256"
// This is the group called diffie-hellman-group16-sha512 in RFC 8268 and
// Oakley Group 16 in RFC 3526. // For the following kex only the client half contains a production
oakleyGroup16 = "FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6BF12FFA06D98A0864D87602733EC86A64521F2B18177B200CBBE117577A615D6C770988C0BAD946E208E24FA074E5AB3143DB5BFCE0FD108E4B82D120A92108011A723C12A787E6D788719A10BDBA5B2699C327186AF4E23C1A946834B6150BDA2583E9CA2AD44CE8DBBBC2DB04DE8EF92E8EFC141FBECAA6287C59474E6BC05D99B2964FA090C3A2233BA186515BE7ED1F612970CEE2D7AFB81BDD762170481CD0069127D5B05AA993B4EA988D8FDDC186FFB7DC90A6C08F4DF435C934063199FFFFFFFFFFFFFFFF" // ready implementation. The server half only consists of a minimal
// implementation to satisfy the automated tests.
kexAlgoDHGEXSHA1 = "diffie-hellman-group-exchange-sha1"
kexAlgoDHGEXSHA256 = "diffie-hellman-group-exchange-sha256"
) )
// kexResult captures the outcome of a key exchange. // kexResult captures the outcome of a key exchange.
@@ -399,46 +402,53 @@ func ecHash(curve elliptic.Curve) crypto.Hash {
var kexAlgoMap = map[string]kexAlgorithm{} var kexAlgoMap = map[string]kexAlgorithm{}
func init() { func init() {
p, _ := new(big.Int).SetString(oakleyGroup2, 16) // This is the group called diffie-hellman-group1-sha1 in
kexAlgoMap[InsecureKeyExchangeDH1SHA1] = &dhGroup{ // RFC 4253 and Oakley Group 2 in RFC 2409.
p, _ := new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF", 16)
kexAlgoMap[kexAlgoDH1SHA1] = &dhGroup{
g: new(big.Int).SetInt64(2), g: new(big.Int).SetInt64(2),
p: p, p: p,
pMinus1: new(big.Int).Sub(p, bigOne), pMinus1: new(big.Int).Sub(p, bigOne),
hashFunc: crypto.SHA1, hashFunc: crypto.SHA1,
} }
p, _ = new(big.Int).SetString(oakleyGroup14, 16) // This are the groups called diffie-hellman-group14-sha1 and
// diffie-hellman-group14-sha256 in RFC 4253 and RFC 8268,
// and Oakley Group 14 in RFC 3526.
p, _ = new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF", 16)
group14 := &dhGroup{ group14 := &dhGroup{
g: new(big.Int).SetInt64(2), g: new(big.Int).SetInt64(2),
p: p, p: p,
pMinus1: new(big.Int).Sub(p, bigOne), pMinus1: new(big.Int).Sub(p, bigOne),
} }
kexAlgoMap[InsecureKeyExchangeDH14SHA1] = &dhGroup{ kexAlgoMap[kexAlgoDH14SHA1] = &dhGroup{
g: group14.g, p: group14.p, pMinus1: group14.pMinus1, g: group14.g, p: group14.p, pMinus1: group14.pMinus1,
hashFunc: crypto.SHA1, hashFunc: crypto.SHA1,
} }
kexAlgoMap[KeyExchangeDH14SHA256] = &dhGroup{ kexAlgoMap[kexAlgoDH14SHA256] = &dhGroup{
g: group14.g, p: group14.p, pMinus1: group14.pMinus1, g: group14.g, p: group14.p, pMinus1: group14.pMinus1,
hashFunc: crypto.SHA256, hashFunc: crypto.SHA256,
} }
p, _ = new(big.Int).SetString(oakleyGroup16, 16) // This is the group called diffie-hellman-group16-sha512 in RFC
// 8268 and Oakley Group 16 in RFC 3526.
p, _ = new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6BF12FFA06D98A0864D87602733EC86A64521F2B18177B200CBBE117577A615D6C770988C0BAD946E208E24FA074E5AB3143DB5BFCE0FD108E4B82D120A92108011A723C12A787E6D788719A10BDBA5B2699C327186AF4E23C1A946834B6150BDA2583E9CA2AD44CE8DBBBC2DB04DE8EF92E8EFC141FBECAA6287C59474E6BC05D99B2964FA090C3A2233BA186515BE7ED1F612970CEE2D7AFB81BDD762170481CD0069127D5B05AA993B4EA988D8FDDC186FFB7DC90A6C08F4DF435C934063199FFFFFFFFFFFFFFFF", 16)
kexAlgoMap[KeyExchangeDH16SHA512] = &dhGroup{ kexAlgoMap[kexAlgoDH16SHA512] = &dhGroup{
g: new(big.Int).SetInt64(2), g: new(big.Int).SetInt64(2),
p: p, p: p,
pMinus1: new(big.Int).Sub(p, bigOne), pMinus1: new(big.Int).Sub(p, bigOne),
hashFunc: crypto.SHA512, hashFunc: crypto.SHA512,
} }
kexAlgoMap[KeyExchangeECDHP521] = &ecdh{elliptic.P521()} kexAlgoMap[kexAlgoECDH521] = &ecdh{elliptic.P521()}
kexAlgoMap[KeyExchangeECDHP384] = &ecdh{elliptic.P384()} kexAlgoMap[kexAlgoECDH384] = &ecdh{elliptic.P384()}
kexAlgoMap[KeyExchangeECDHP256] = &ecdh{elliptic.P256()} kexAlgoMap[kexAlgoECDH256] = &ecdh{elliptic.P256()}
kexAlgoMap[KeyExchangeCurve25519] = &curve25519sha256{} kexAlgoMap[kexAlgoCurve25519SHA256] = &curve25519sha256{}
kexAlgoMap[keyExchangeCurve25519LibSSH] = &curve25519sha256{} kexAlgoMap[kexAlgoCurve25519SHA256LibSSH] = &curve25519sha256{}
kexAlgoMap[InsecureKeyExchangeDHGEXSHA1] = &dhGEXSHA{hashFunc: crypto.SHA1} kexAlgoMap[kexAlgoDHGEXSHA1] = &dhGEXSHA{hashFunc: crypto.SHA1}
kexAlgoMap[KeyExchangeDHGEXSHA256] = &dhGEXSHA{hashFunc: crypto.SHA256} kexAlgoMap[kexAlgoDHGEXSHA256] = &dhGEXSHA{hashFunc: crypto.SHA256}
} }
// curve25519sha256 implements the curve25519-sha256 (formerly known as // curve25519sha256 implements the curve25519-sha256 (formerly known as
@@ -591,9 +601,9 @@ const (
func (gex *dhGEXSHA) Client(c packetConn, randSource io.Reader, magics *handshakeMagics) (*kexResult, error) { func (gex *dhGEXSHA) Client(c packetConn, randSource io.Reader, magics *handshakeMagics) (*kexResult, error) {
// Send GexRequest // Send GexRequest
kexDHGexRequest := kexDHGexRequestMsg{ kexDHGexRequest := kexDHGexRequestMsg{
MinBits: dhGroupExchangeMinimumBits, MinBits: dhGroupExchangeMinimumBits,
PreferredBits: dhGroupExchangePreferredBits, PreferedBits: dhGroupExchangePreferredBits,
MaxBits: dhGroupExchangeMaximumBits, MaxBits: dhGroupExchangeMaximumBits,
} }
if err := c.writePacket(Marshal(&kexDHGexRequest)); err != nil { if err := c.writePacket(Marshal(&kexDHGexRequest)); err != nil {
return nil, err return nil, err
@@ -680,7 +690,9 @@ func (gex *dhGEXSHA) Client(c packetConn, randSource io.Reader, magics *handshak
} }
// Server half implementation of the Diffie Hellman Key Exchange with SHA1 and SHA256. // Server half implementation of the Diffie Hellman Key Exchange with SHA1 and SHA256.
func (gex *dhGEXSHA) Server(c packetConn, randSource io.Reader, magics *handshakeMagics, priv AlgorithmSigner, algo string) (result *kexResult, err error) { //
// This is a minimal implementation to satisfy the automated tests.
func (gex dhGEXSHA) Server(c packetConn, randSource io.Reader, magics *handshakeMagics, priv AlgorithmSigner, algo string) (result *kexResult, err error) {
// Receive GexRequest // Receive GexRequest
packet, err := c.readPacket() packet, err := c.readPacket()
if err != nil { if err != nil {
@@ -690,32 +702,13 @@ func (gex *dhGEXSHA) Server(c packetConn, randSource io.Reader, magics *handshak
if err = Unmarshal(packet, &kexDHGexRequest); err != nil { if err = Unmarshal(packet, &kexDHGexRequest); err != nil {
return return
} }
// We check that the request received is valid and that the MaxBits
// requested are at least equal to our supported minimum. This is the same
// check done in OpenSSH:
// https://github.com/openssh/openssh-portable/blob/80a2f64b/kexgexs.c#L94
//
// Furthermore, we also check that the required MinBits are less than or
// equal to 4096 because we can use up to Oakley Group 16.
if kexDHGexRequest.MaxBits < kexDHGexRequest.MinBits || kexDHGexRequest.PreferredBits < kexDHGexRequest.MinBits ||
kexDHGexRequest.MaxBits < kexDHGexRequest.PreferredBits || kexDHGexRequest.MaxBits < dhGroupExchangeMinimumBits ||
kexDHGexRequest.MinBits > 4096 {
return nil, fmt.Errorf("ssh: DH GEX request out of range, min: %d, max: %d, preferred: %d", kexDHGexRequest.MinBits,
kexDHGexRequest.MaxBits, kexDHGexRequest.PreferredBits)
}
var p *big.Int
// We hardcode sending Oakley Group 14 (2048 bits), Oakley Group 15 (3072
// bits) or Oakley Group 16 (4096 bits), based on the requested max size.
if kexDHGexRequest.MaxBits < 3072 {
p, _ = new(big.Int).SetString(oakleyGroup14, 16)
} else if kexDHGexRequest.MaxBits < 4096 {
p, _ = new(big.Int).SetString(oakleyGroup15, 16)
} else {
p, _ = new(big.Int).SetString(oakleyGroup16, 16)
}
// Send GexGroup
// This is the group called diffie-hellman-group14-sha1 in RFC
// 4253 and Oakley Group 14 in RFC 3526.
p, _ := new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF", 16)
g := big.NewInt(2) g := big.NewInt(2)
msg := &kexDHGexGroupMsg{ msg := &kexDHGexGroupMsg{
P: p, P: p,
G: g, G: g,
@@ -753,9 +746,9 @@ func (gex *dhGEXSHA) Server(c packetConn, randSource io.Reader, magics *handshak
h := gex.hashFunc.New() h := gex.hashFunc.New()
magics.write(h) magics.write(h)
writeString(h, hostKeyBytes) writeString(h, hostKeyBytes)
binary.Write(h, binary.BigEndian, kexDHGexRequest.MinBits) binary.Write(h, binary.BigEndian, uint32(dhGroupExchangeMinimumBits))
binary.Write(h, binary.BigEndian, kexDHGexRequest.PreferredBits) binary.Write(h, binary.BigEndian, uint32(dhGroupExchangePreferredBits))
binary.Write(h, binary.BigEndian, kexDHGexRequest.MaxBits) binary.Write(h, binary.BigEndian, uint32(dhGroupExchangeMaximumBits))
writeInt(h, p) writeInt(h, p)
writeInt(h, g) writeInt(h, g)
writeInt(h, kexDHGexInit.X) writeInt(h, kexDHGexInit.X)

View File

@@ -36,19 +36,14 @@ import (
// ClientConfig.HostKeyAlgorithms, Signature.Format, or as AlgorithmSigner // ClientConfig.HostKeyAlgorithms, Signature.Format, or as AlgorithmSigner
// arguments. // arguments.
const ( const (
KeyAlgoRSA = "ssh-rsa" KeyAlgoRSA = "ssh-rsa"
// Deprecated: DSA is only supported at insecure key sizes, and was removed KeyAlgoDSA = "ssh-dss"
// from major implementations. KeyAlgoECDSA256 = "ecdsa-sha2-nistp256"
KeyAlgoDSA = InsecureKeyAlgoDSA KeyAlgoSKECDSA256 = "sk-ecdsa-sha2-nistp256@openssh.com"
// Deprecated: DSA is only supported at insecure key sizes, and was removed KeyAlgoECDSA384 = "ecdsa-sha2-nistp384"
// from major implementations. KeyAlgoECDSA521 = "ecdsa-sha2-nistp521"
InsecureKeyAlgoDSA = "ssh-dss" KeyAlgoED25519 = "ssh-ed25519"
KeyAlgoECDSA256 = "ecdsa-sha2-nistp256" KeyAlgoSKED25519 = "sk-ssh-ed25519@openssh.com"
KeyAlgoSKECDSA256 = "sk-ecdsa-sha2-nistp256@openssh.com"
KeyAlgoECDSA384 = "ecdsa-sha2-nistp384"
KeyAlgoECDSA521 = "ecdsa-sha2-nistp521"
KeyAlgoED25519 = "ssh-ed25519"
KeyAlgoSKED25519 = "sk-ssh-ed25519@openssh.com"
// KeyAlgoRSASHA256 and KeyAlgoRSASHA512 are only public key algorithms, not // KeyAlgoRSASHA256 and KeyAlgoRSASHA512 are only public key algorithms, not
// public key formats, so they can't appear as a PublicKey.Type. The // public key formats, so they can't appear as a PublicKey.Type. The
@@ -72,7 +67,7 @@ func parsePubKey(in []byte, algo string) (pubKey PublicKey, rest []byte, err err
switch algo { switch algo {
case KeyAlgoRSA: case KeyAlgoRSA:
return parseRSA(in) return parseRSA(in)
case InsecureKeyAlgoDSA: case KeyAlgoDSA:
return parseDSA(in) return parseDSA(in)
case KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521: case KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521:
return parseECDSA(in) return parseECDSA(in)
@@ -82,7 +77,7 @@ func parsePubKey(in []byte, algo string) (pubKey PublicKey, rest []byte, err err
return parseED25519(in) return parseED25519(in)
case KeyAlgoSKED25519: case KeyAlgoSKED25519:
return parseSKEd25519(in) return parseSKEd25519(in)
case CertAlgoRSAv01, InsecureCertAlgoDSAv01, CertAlgoECDSA256v01, CertAlgoECDSA384v01, CertAlgoECDSA521v01, CertAlgoSKECDSA256v01, CertAlgoED25519v01, CertAlgoSKED25519v01: case CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, CertAlgoECDSA384v01, CertAlgoECDSA521v01, CertAlgoSKECDSA256v01, CertAlgoED25519v01, CertAlgoSKED25519v01:
cert, err := parseCert(in, certKeyAlgoNames[algo]) cert, err := parseCert(in, certKeyAlgoNames[algo])
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err

View File

@@ -47,22 +47,22 @@ func (t truncatingMAC) Size() int {
func (t truncatingMAC) BlockSize() int { return t.hmac.BlockSize() } func (t truncatingMAC) BlockSize() int { return t.hmac.BlockSize() }
var macModes = map[string]*macMode{ var macModes = map[string]*macMode{
HMACSHA512ETM: {64, true, func(key []byte) hash.Hash { "hmac-sha2-512-etm@openssh.com": {64, true, func(key []byte) hash.Hash {
return hmac.New(sha512.New, key) return hmac.New(sha512.New, key)
}}, }},
HMACSHA256ETM: {32, true, func(key []byte) hash.Hash { "hmac-sha2-256-etm@openssh.com": {32, true, func(key []byte) hash.Hash {
return hmac.New(sha256.New, key) return hmac.New(sha256.New, key)
}}, }},
HMACSHA512: {64, false, func(key []byte) hash.Hash { "hmac-sha2-512": {64, false, func(key []byte) hash.Hash {
return hmac.New(sha512.New, key) return hmac.New(sha512.New, key)
}}, }},
HMACSHA256: {32, false, func(key []byte) hash.Hash { "hmac-sha2-256": {32, false, func(key []byte) hash.Hash {
return hmac.New(sha256.New, key) return hmac.New(sha256.New, key)
}}, }},
HMACSHA1: {20, false, func(key []byte) hash.Hash { "hmac-sha1": {20, false, func(key []byte) hash.Hash {
return hmac.New(sha1.New, key) return hmac.New(sha1.New, key)
}}, }},
InsecureHMACSHA196: {20, false, func(key []byte) hash.Hash { "hmac-sha1-96": {20, false, func(key []byte) hash.Hash {
return truncatingMAC{12, hmac.New(sha1.New, key)} return truncatingMAC{12, hmac.New(sha1.New, key)}
}}, }},
} }

View File

@@ -122,9 +122,9 @@ type kexDHGexReplyMsg struct {
const msgKexDHGexRequest = 34 const msgKexDHGexRequest = 34
type kexDHGexRequestMsg struct { type kexDHGexRequestMsg struct {
MinBits uint32 `sshtype:"34"` MinBits uint32 `sshtype:"34"`
PreferredBits uint32 PreferedBits uint32
MaxBits uint32 MaxBits uint32
} }
// See RFC 4253, section 10. // See RFC 4253, section 10.

View File

@@ -19,15 +19,19 @@ import (
"golang.org/x/crypto/curve25519" "golang.org/x/crypto/curve25519"
) )
const (
kexAlgoMLKEM768xCurve25519SHA256 = "mlkem768x25519-sha256"
)
func init() { func init() {
// After Go 1.24rc1 mlkem swapped the order of return values of Encapsulate. // After Go 1.24rc1 mlkem swapped the order of return values of Encapsulate.
// See #70950. // See #70950.
if runtime.Version() == "go1.24rc1" { if runtime.Version() == "go1.24rc1" {
return return
} }
supportedKexAlgos = slices.Insert(supportedKexAlgos, 0, KeyExchangeMLKEM768X25519) supportedKexAlgos = slices.Insert(supportedKexAlgos, 0, kexAlgoMLKEM768xCurve25519SHA256)
defaultKexAlgos = slices.Insert(defaultKexAlgos, 0, KeyExchangeMLKEM768X25519) preferredKexAlgos = slices.Insert(preferredKexAlgos, 0, kexAlgoMLKEM768xCurve25519SHA256)
kexAlgoMap[KeyExchangeMLKEM768X25519] = &mlkem768WithCurve25519sha256{} kexAlgoMap[kexAlgoMLKEM768xCurve25519SHA256] = &mlkem768WithCurve25519sha256{}
} }
// mlkem768WithCurve25519sha256 implements the hybrid ML-KEM768 with // mlkem768WithCurve25519sha256 implements the hybrid ML-KEM768 with

View File

@@ -243,15 +243,22 @@ func NewServerConn(c net.Conn, config *ServerConfig) (*ServerConn, <-chan NewCha
fullConf.MaxAuthTries = 6 fullConf.MaxAuthTries = 6
} }
if len(fullConf.PublicKeyAuthAlgorithms) == 0 { if len(fullConf.PublicKeyAuthAlgorithms) == 0 {
fullConf.PublicKeyAuthAlgorithms = defaultPubKeyAuthAlgos fullConf.PublicKeyAuthAlgorithms = supportedPubKeyAuthAlgos
} else { } else {
for _, algo := range fullConf.PublicKeyAuthAlgorithms { for _, algo := range fullConf.PublicKeyAuthAlgorithms {
if !contains(SupportedAlgorithms().PublicKeyAuths, algo) && !contains(InsecureAlgorithms().PublicKeyAuths, algo) { if !contains(supportedPubKeyAuthAlgos, algo) {
c.Close() c.Close()
return nil, nil, nil, fmt.Errorf("ssh: unsupported public key authentication algorithm %s", algo) return nil, nil, nil, fmt.Errorf("ssh: unsupported public key authentication algorithm %s", algo)
} }
} }
} }
// Check if the config contains any unsupported key exchanges
for _, kex := range fullConf.KeyExchanges {
if _, ok := serverForbiddenKexAlgos[kex]; ok {
c.Close()
return nil, nil, nil, fmt.Errorf("ssh: unsupported key exchange %s for server", kex)
}
}
s := &connection{ s := &connection{
sshConn: sshConn{conn: c}, sshConn: sshConn{conn: c},
@@ -308,7 +315,6 @@ func (s *connection) serverHandshake(config *ServerConfig) (*Permissions, error)
// We just did the key change, so the session ID is established. // We just did the key change, so the session ID is established.
s.sessionID = s.transport.getSessionID() s.sessionID = s.transport.getSessionID()
s.algorithms = s.transport.getAlgorithms()
var packet []byte var packet []byte
if packet, err = s.transport.readPacket(); err != nil { if packet, err = s.transport.readPacket(); err != nil {

View File

@@ -16,6 +16,13 @@ import (
// wire. No message decoding is done, to minimize the impact on timing. // wire. No message decoding is done, to minimize the impact on timing.
const debugTransport = false const debugTransport = false
const (
gcm128CipherID = "aes128-gcm@openssh.com"
gcm256CipherID = "aes256-gcm@openssh.com"
aes128cbcID = "aes128-cbc"
tripledescbcID = "3des-cbc"
)
// packetConn represents a transport that implements packet based // packetConn represents a transport that implements packet based
// operations. // operations.
type packetConn interface { type packetConn interface {
@@ -85,14 +92,14 @@ func (t *transport) setInitialKEXDone() {
// prepareKeyChange sets up key material for a keychange. The key changes in // prepareKeyChange sets up key material for a keychange. The key changes in
// both directions are triggered by reading and writing a msgNewKey packet // both directions are triggered by reading and writing a msgNewKey packet
// respectively. // respectively.
func (t *transport) prepareKeyChange(algs *NegotiatedAlgorithms, kexResult *kexResult) error { func (t *transport) prepareKeyChange(algs *algorithms, kexResult *kexResult) error {
ciph, err := newPacketCipher(t.reader.dir, algs.Read, kexResult) ciph, err := newPacketCipher(t.reader.dir, algs.r, kexResult)
if err != nil { if err != nil {
return err return err
} }
t.reader.pendingKeyChange <- ciph t.reader.pendingKeyChange <- ciph
ciph, err = newPacketCipher(t.writer.dir, algs.Write, kexResult) ciph, err = newPacketCipher(t.writer.dir, algs.w, kexResult)
if err != nil { if err != nil {
return err return err
} }
@@ -252,7 +259,7 @@ var (
// setupKeys sets the cipher and MAC keys from kex.K, kex.H and sessionId, as // setupKeys sets the cipher and MAC keys from kex.K, kex.H and sessionId, as
// described in RFC 4253, section 6.4. direction should either be serverKeys // described in RFC 4253, section 6.4. direction should either be serverKeys
// (to setup server->client keys) or clientKeys (for client->server keys). // (to setup server->client keys) or clientKeys (for client->server keys).
func newPacketCipher(d direction, algs DirectionAlgorithms, kex *kexResult) (packetCipher, error) { func newPacketCipher(d direction, algs directionAlgorithms, kex *kexResult) (packetCipher, error) {
cipherMode := cipherModes[algs.Cipher] cipherMode := cipherModes[algs.Cipher]
iv := make([]byte, cipherMode.ivSize) iv := make([]byte, cipherMode.ivSize)

View File

@@ -10,7 +10,6 @@
// builds a list of push/pop events and their node type. Subsequent // builds a list of push/pop events and their node type. Subsequent
// method calls that request a traversal scan this list, rather than walk // method calls that request a traversal scan this list, rather than walk
// the AST, and perform type filtering using efficient bit sets. // the AST, and perform type filtering using efficient bit sets.
// This representation is sometimes called a "balanced parenthesis tree."
// //
// Experiments suggest the inspector's traversals are about 2.5x faster // Experiments suggest the inspector's traversals are about 2.5x faster
// than ast.Inspect, but it may take around 5 traversals for this // than ast.Inspect, but it may take around 5 traversals for this
@@ -48,10 +47,9 @@ type Inspector struct {
events []event events []event
} }
//go:linkname events golang.org/x/tools/go/ast/inspector.events //go:linkname events
func events(in *Inspector) []event { return in.events } func events(in *Inspector) []event { return in.events }
//go:linkname packEdgeKindAndIndex golang.org/x/tools/go/ast/inspector.packEdgeKindAndIndex
func packEdgeKindAndIndex(ek edge.Kind, index int) int32 { func packEdgeKindAndIndex(ek edge.Kind, index int) int32 {
return int32(uint32(index+1)<<7 | uint32(ek)) return int32(uint32(index+1)<<7 | uint32(ek))
} }
@@ -59,7 +57,7 @@ func packEdgeKindAndIndex(ek edge.Kind, index int) int32 {
// unpackEdgeKindAndIndex unpacks the edge kind and edge index (within // unpackEdgeKindAndIndex unpacks the edge kind and edge index (within
// an []ast.Node slice) from the parent field of a pop event. // an []ast.Node slice) from the parent field of a pop event.
// //
//go:linkname unpackEdgeKindAndIndex golang.org/x/tools/go/ast/inspector.unpackEdgeKindAndIndex //go:linkname unpackEdgeKindAndIndex
func unpackEdgeKindAndIndex(x int32) (edge.Kind, int) { func unpackEdgeKindAndIndex(x int32) (edge.Kind, int) {
// The "parent" field of a pop node holds the // The "parent" field of a pop node holds the
// edge Kind in the lower 7 bits and the index+1 // edge Kind in the lower 7 bits and the index+1

Some files were not shown because too many files have changed in this diff Show More