diff --git a/pkg/nvcdi/transform/root/driver-root.go b/pkg/nvcdi/transform/root/driver-root.go new file mode 100644 index 00000000..931e73ef --- /dev/null +++ b/pkg/nvcdi/transform/root/driver-root.go @@ -0,0 +1,99 @@ +/** +# Copyright 2024 NVIDIA CORPORATION +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package root + +import ( + "path/filepath" + "strings" + + "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/transform" +) + +type DriverOption func(*driverOptions) + +func WithDriverRoot(root string) DriverOption { + return func(do *driverOptions) { + do.driverRoot = root + } +} + +func WithTargetDriverRoot(root string) DriverOption { + return func(do *driverOptions) { + do.targetDriverRoot = root + } +} + +func WithDevRoot(root string) DriverOption { + return func(do *driverOptions) { + do.devRoot = root + } +} + +func WithTargetDevRoot(root string) DriverOption { + return func(do *driverOptions) { + do.targetDevRoot = root + } +} + +type driverOptions struct { + driverRoot string + targetDriverRoot string + devRoot string + targetDevRoot string +} + +// NewDriverTransformer creates a transformer for transforming driver specifications. +func NewDriverTransformer(opts ...DriverOption) transform.Transformer { + d := &driverOptions{} + for _, opt := range opts { + opt(d) + } + if d.driverRoot == "" { + d.driverRoot = "/" + } + if d.targetDriverRoot == "" { + d.targetDriverRoot = "/" + } + if d.devRoot == "" { + d.devRoot = d.driverRoot + } + if d.targetDevRoot == "" { + d.targetDevRoot = d.targetDriverRoot + } + + var transformers []transform.Transformer + + if d.targetDevRoot != d.targetDriverRoot { + devRootTransformer := New( + WithRoot(ensureDev(d.devRoot)), + WithTargetRoot(ensureDev(d.targetDevRoot)), + ) + transformers = append(transformers, devRootTransformer) + } + + driverRootTransformer := New( + WithRoot(d.driverRoot), + WithTargetRoot(d.targetDriverRoot), + ) + transformers = append(transformers, driverRootTransformer) + + return transform.Merge(transformers...) +} + +func ensureDev(p string) string { + return filepath.Join(strings.TrimSuffix(filepath.Clean(p), "/dev"), "/dev") +} diff --git a/pkg/nvcdi/transform/root/driver-root_test.go b/pkg/nvcdi/transform/root/driver-root_test.go new file mode 100644 index 00000000..36628053 --- /dev/null +++ b/pkg/nvcdi/transform/root/driver-root_test.go @@ -0,0 +1,208 @@ +/** +# Copyright 2024 NVIDIA CORPORATION +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package root + +import ( + "testing" + + "github.com/stretchr/testify/require" + "tags.cncf.io/container-device-interface/specs-go" +) + +func TestDriverTransformer(t *testing.T) { + testCases := []struct { + description string + driverRoot string + targetDriverRoot string + devRoot string + targetDevRoot string + spec *specs.Spec + expectedError error + expectedSpec *specs.Spec + }{ + { + description: "dev root not specified", + driverRoot: "/driver-root", + targetDriverRoot: "/host/driver/root/", + spec: &specs.Spec{ + ContainerEdits: specs.ContainerEdits{ + Mounts: []*specs.Mount{ + { + HostPath: "/driver-root/host/path", + ContainerPath: "/driver-root/container/path", + }, + }, + DeviceNodes: []*specs.DeviceNode{ + { + HostPath: "/driver-root/dev/host/path", + Path: "/driver-root/dev/container/path", + }, + }, + }, + }, + expectedSpec: &specs.Spec{ + ContainerEdits: specs.ContainerEdits{ + Mounts: []*specs.Mount{ + { + HostPath: "/host/driver/root/host/path", + ContainerPath: "/driver-root/container/path", + }, + }, + DeviceNodes: []*specs.DeviceNode{ + { + HostPath: "/host/driver/root/dev/host/path", + Path: "/driver-root/dev/container/path", + }, + }, + }, + }, + }, + { + description: "dev driver root matches", + driverRoot: "/driver-root", + targetDriverRoot: "/host/driver/root/", + devRoot: "/driver-root", + targetDevRoot: "/host/driver/root/", + spec: &specs.Spec{ + ContainerEdits: specs.ContainerEdits{ + Mounts: []*specs.Mount{ + { + HostPath: "/driver-root/host/path", + ContainerPath: "/driver-root/container/path", + }, + }, + DeviceNodes: []*specs.DeviceNode{ + { + HostPath: "/driver-root/dev/host/path", + Path: "/driver-root/dev/container/path", + }, + }, + }, + }, + expectedSpec: &specs.Spec{ + ContainerEdits: specs.ContainerEdits{ + Mounts: []*specs.Mount{ + { + HostPath: "/host/driver/root/host/path", + ContainerPath: "/driver-root/container/path", + }, + }, + DeviceNodes: []*specs.DeviceNode{ + { + HostPath: "/host/driver/root/dev/host/path", + Path: "/driver-root/dev/container/path", + }, + }, + }, + }, + }, + { + description: "dev driver root matches separate target dev root", + driverRoot: "/driver-root", + targetDriverRoot: "/host/driver/root/", + devRoot: "/driver-root", + targetDevRoot: "/host/dev/root/", + spec: &specs.Spec{ + ContainerEdits: specs.ContainerEdits{ + Mounts: []*specs.Mount{ + { + HostPath: "/driver-root/host/path", + ContainerPath: "/driver-root/container/path", + }, + }, + DeviceNodes: []*specs.DeviceNode{ + { + HostPath: "/driver-root/dev/host/path", + Path: "/driver-root/dev/container/path", + }, + }, + }, + }, + expectedSpec: &specs.Spec{ + ContainerEdits: specs.ContainerEdits{ + Mounts: []*specs.Mount{ + { + HostPath: "/host/driver/root/host/path", + ContainerPath: "/driver-root/container/path", + }, + }, + DeviceNodes: []*specs.DeviceNode{ + { + HostPath: "/host/dev/root/dev/host/path", + Path: "/driver-root/dev/container/path", + }, + }, + }, + }, + }, + { + description: "dev root specified with explicit target", + driverRoot: "/driver-root", + targetDriverRoot: "/host/driver/root/", + devRoot: "/", + targetDevRoot: "/dev/root/", + spec: &specs.Spec{ + ContainerEdits: specs.ContainerEdits{ + Mounts: []*specs.Mount{ + { + HostPath: "/driver-root/host/path", + ContainerPath: "/driver-root/container/path", + }, + }, + DeviceNodes: []*specs.DeviceNode{ + { + HostPath: "/dev/host/path", + Path: "/dev/container/path", + }, + }, + }, + }, + expectedSpec: &specs.Spec{ + ContainerEdits: specs.ContainerEdits{ + Mounts: []*specs.Mount{ + { + HostPath: "/host/driver/root/host/path", + ContainerPath: "/driver-root/container/path", + }, + }, + DeviceNodes: []*specs.DeviceNode{ + { + HostPath: "/dev/root/dev/host/path", + Path: "/dev/container/path", + }, + }, + }, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.description, func(t *testing.T) { + transformer := NewDriverTransformer( + WithDriverRoot(tc.driverRoot), + WithTargetDriverRoot(tc.targetDriverRoot), + WithDevRoot(tc.devRoot), + WithTargetDevRoot(tc.targetDevRoot), + ) + + err := transformer.Transform(tc.spec) + + require.ErrorIs(t, err, tc.expectedError) + require.EqualValues(t, tc.expectedSpec, tc.spec) + }) + } +} diff --git a/tools/container/toolkit/toolkit.go b/tools/container/toolkit/toolkit.go index 25ee9497..9e161d63 100644 --- a/tools/container/toolkit/toolkit.go +++ b/tools/container/toolkit/toolkit.go @@ -48,7 +48,9 @@ const ( type options struct { DriverRoot string + DevRoot string DriverRootCtrPath string + DevRootCtrPath string ContainerRuntimeMode string ContainerRuntimeDebug string @@ -132,6 +134,18 @@ func main() { Destination: &opts.DriverRootCtrPath, EnvVars: []string{"DRIVER_ROOT_CTR_PATH"}, }, + &cli.StringFlag{ + Name: "dev-root", + Usage: "Specify the root where `/dev` is located. If this is not specified, the driver-root is assumed.", + Destination: &opts.DevRoot, + EnvVars: []string{"NVIDIA_DEV_ROOT", "DEV_ROOT"}, + }, + &cli.StringFlag{ + Name: "dev-root-ctr-path", + Usage: "Specify the root where `/dev` is located in the container. If this is not specified, the driver-root-ctr-path is assumed.", + Destination: &opts.DevRootCtrPath, + EnvVars: []string{"DEV_ROOT_CTR_PATH"}, + }, &cli.StringFlag{ Name: "nvidia-container-runtime.debug", Aliases: []string{"nvidia-container-runtime-debug"}, @@ -750,14 +764,14 @@ func createDeviceNodes(opts *options) error { } devices, err := nvdevices.New( - nvdevices.WithDevRoot(opts.DriverRootCtrPath), + nvdevices.WithDevRoot(opts.DevRootCtrPath), ) if err != nil { return fmt.Errorf("failed to create library: %v", err) } for _, mode := range modes { - log.Infof("Creating %v device nodes at %v", mode, opts.DriverRootCtrPath) + log.Infof("Creating %v device nodes at %v", mode, opts.DevRootCtrPath) if mode != "control" { log.Warningf("Unrecognised device mode: %v", mode) continue @@ -778,6 +792,7 @@ func generateCDISpec(opts *options, nvidiaCDIHookPath string) error { cdilib, err := nvcdi.New( nvcdi.WithMode(nvcdi.ModeManagement), nvcdi.WithDriverRoot(opts.DriverRootCtrPath), + nvcdi.WithDevRoot(opts.DevRootCtrPath), nvcdi.WithNVIDIACDIHookPath(nvidiaCDIHookPath), nvcdi.WithVendor(opts.cdiVendor), nvcdi.WithClass(opts.cdiClass), @@ -790,11 +805,14 @@ func generateCDISpec(opts *options, nvidiaCDIHookPath string) error { if err != nil { return fmt.Errorf("failed to genereate CDI spec for management containers: %v", err) } - err = transformroot.New( - transformroot.WithRoot(opts.DriverRootCtrPath), - transformroot.WithTargetRoot(opts.DriverRoot), - ).Transform(spec.Raw()) - if err != nil { + + transformer := transformroot.NewDriverTransformer( + transformroot.WithDriverRoot(opts.DriverRootCtrPath), + transformroot.WithTargetDriverRoot(opts.DriverRoot), + transformroot.WithDevRoot(opts.DevRootCtrPath), + transformroot.WithTargetDevRoot(opts.DevRoot), + ) + if err := transformer.Transform(spec.Raw()); err != nil { return fmt.Errorf("failed to transform driver root in CDI spec: %v", err) }