From a10379555325863cc83730990173b18cd3bffcf8 Mon Sep 17 00:00:00 2001 From: Evan Lezar Date: Wed, 9 Apr 2025 13:47:28 +0200 Subject: [PATCH] Add EnvVar to Discover interface This change adds environment variables to the Discover interface. Signed-off-by: Evan Lezar --- internal/discover/cache.go | 15 +++++++ internal/discover/discover.go | 7 ++++ internal/discover/discover_mock.go | 41 +++++++++++++++++++ internal/discover/envvar.go | 41 +++++++++++++++++++ internal/discover/first-valid.go | 13 ++++++ internal/discover/hooks.go | 5 +++ internal/discover/list.go | 15 +++++++ internal/discover/none.go | 5 +++ internal/edits/edits.go | 9 ++++ internal/edits/envvar.go | 39 ++++++++++++++++++ .../platform-support/dgpu/by-path-hooks.go | 5 +++ .../platform-support/dgpu/nvsandboxutils.go | 4 ++ .../workarounds-device-folder-permissions.go | 5 +++ 13 files changed, 204 insertions(+) create mode 100644 internal/discover/envvar.go create mode 100644 internal/edits/envvar.go diff --git a/internal/discover/cache.go b/internal/discover/cache.go index e9e4b615..56523386 100644 --- a/internal/discover/cache.go +++ b/internal/discover/cache.go @@ -23,6 +23,7 @@ type cache struct { sync.Mutex devices []Device + envVars []EnvVar hooks []Hook mounts []Mount } @@ -51,6 +52,20 @@ func (c *cache) Devices() ([]Device, error) { return c.devices, nil } +func (c *cache) EnvVars() ([]EnvVar, error) { + c.Lock() + defer c.Unlock() + + if c.envVars == nil { + envVars, err := c.d.EnvVars() + if err != nil { + return nil, err + } + c.envVars = envVars + } + return c.envVars, nil +} + func (c *cache) Hooks() ([]Hook, error) { c.Lock() defer c.Unlock() diff --git a/internal/discover/discover.go b/internal/discover/discover.go index fc639296..07aeb937 100644 --- a/internal/discover/discover.go +++ b/internal/discover/discover.go @@ -22,6 +22,12 @@ type Device struct { Path string } +// EnvVar represents a discovered environment variable. +type EnvVar struct { + Name string + Value string +} + // Mount represents a discovered mount. type Mount struct { HostPath string @@ -41,6 +47,7 @@ type Hook struct { //go:generate moq -rm -fmt=goimports -stub -out discover_mock.go . Discover type Discover interface { Devices() ([]Device, error) + EnvVars() ([]EnvVar, error) Mounts() ([]Mount, error) Hooks() ([]Hook, error) } diff --git a/internal/discover/discover_mock.go b/internal/discover/discover_mock.go index d9be6ffb..2a66cf0e 100644 --- a/internal/discover/discover_mock.go +++ b/internal/discover/discover_mock.go @@ -20,6 +20,9 @@ var _ Discover = &DiscoverMock{} // DevicesFunc: func() ([]Device, error) { // panic("mock out the Devices method") // }, +// EnvVarsFunc: func() ([]EnvVar, error) { +// panic("mock out the EnvVars method") +// }, // HooksFunc: func() ([]Hook, error) { // panic("mock out the Hooks method") // }, @@ -36,6 +39,9 @@ type DiscoverMock struct { // DevicesFunc mocks the Devices method. DevicesFunc func() ([]Device, error) + // EnvVarsFunc mocks the EnvVars method. + EnvVarsFunc func() ([]EnvVar, error) + // HooksFunc mocks the Hooks method. HooksFunc func() ([]Hook, error) @@ -47,6 +53,9 @@ type DiscoverMock struct { // Devices holds details about calls to the Devices method. Devices []struct { } + // EnvVars holds details about calls to the EnvVars method. + EnvVars []struct { + } // Hooks holds details about calls to the Hooks method. Hooks []struct { } @@ -55,6 +64,7 @@ type DiscoverMock struct { } } lockDevices sync.RWMutex + lockEnvVars sync.RWMutex lockHooks sync.RWMutex lockMounts sync.RWMutex } @@ -90,6 +100,37 @@ func (mock *DiscoverMock) DevicesCalls() []struct { return calls } +// EnvVars calls EnvVarsFunc. +func (mock *DiscoverMock) EnvVars() ([]EnvVar, error) { + callInfo := struct { + }{} + mock.lockEnvVars.Lock() + mock.calls.EnvVars = append(mock.calls.EnvVars, callInfo) + mock.lockEnvVars.Unlock() + if mock.EnvVarsFunc == nil { + var ( + envVarsOut []EnvVar + errOut error + ) + return envVarsOut, errOut + } + return mock.EnvVarsFunc() +} + +// EnvVarsCalls gets all the calls that were made to EnvVars. +// Check the length with: +// +// len(mockedDiscover.EnvVarsCalls()) +func (mock *DiscoverMock) EnvVarsCalls() []struct { +} { + var calls []struct { + } + mock.lockEnvVars.RLock() + calls = mock.calls.EnvVars + mock.lockEnvVars.RUnlock() + return calls +} + // Hooks calls HooksFunc. func (mock *DiscoverMock) Hooks() ([]Hook, error) { callInfo := struct { diff --git a/internal/discover/envvar.go b/internal/discover/envvar.go new file mode 100644 index 00000000..3c84db0c --- /dev/null +++ b/internal/discover/envvar.go @@ -0,0 +1,41 @@ +/** +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package discover + +var _ Discover = (*EnvVar)(nil) + +// Devices returns an empty list of devices for a EnvVar discoverer. +func (e EnvVar) Devices() ([]Device, error) { + return nil, nil +} + +// EnvVars returns an empty list of envs for a EnvVar discoverer. +func (e EnvVar) EnvVars() ([]EnvVar, error) { + return []EnvVar{e}, nil +} + +// Mounts returns an empty list of mounts for a EnvVar discoverer. +func (e EnvVar) Mounts() ([]Mount, error) { + return nil, nil +} + +// Hooks allows the Hook type to also implement the Discoverer interface. +// It returns a single hook +func (e EnvVar) Hooks() ([]Hook, error) { + return nil, nil +} diff --git a/internal/discover/first-valid.go b/internal/discover/first-valid.go index 36de9204..81c93d3e 100644 --- a/internal/discover/first-valid.go +++ b/internal/discover/first-valid.go @@ -45,6 +45,19 @@ func (f firstOf) Devices() ([]Device, error) { return nil, errs } +func (f firstOf) EnvVars() ([]EnvVar, error) { + var errs error + for _, d := range f { + envs, err := d.EnvVars() + if err != nil { + errs = errors.Join(errs, err) + continue + } + return envs, nil + } + return nil, errs +} + func (f firstOf) Hooks() ([]Hook, error) { var errs error for _, d := range f { diff --git a/internal/discover/hooks.go b/internal/discover/hooks.go index 4259ccf8..fb3ebf6a 100644 --- a/internal/discover/hooks.go +++ b/internal/discover/hooks.go @@ -29,6 +29,11 @@ func (h Hook) Devices() ([]Device, error) { return nil, nil } +// EnvVars returns an empty list of envs for a Hook discoverer. +func (h Hook) EnvVars() ([]EnvVar, error) { + return nil, nil +} + // Mounts returns an empty list of mounts for a Hook discoverer. func (h Hook) Mounts() ([]Mount, error) { return nil, nil diff --git a/internal/discover/list.go b/internal/discover/list.go index 8fbf6e75..a01180cc 100644 --- a/internal/discover/list.go +++ b/internal/discover/list.go @@ -51,6 +51,21 @@ func (d list) Devices() ([]Device, error) { return allDevices, nil } +// EnvVars returns all environment variables from the included discoverers. +func (d list) EnvVars() ([]EnvVar, error) { + var allEnvs []EnvVar + + for i, di := range d.discoverers { + envs, err := di.EnvVars() + if err != nil { + return nil, fmt.Errorf("error discovering envs for discoverer %v: %w", i, err) + } + allEnvs = append(allEnvs, envs...) + } + + return allEnvs, nil +} + // Mounts returns all mounts from the included discoverers func (d list) Mounts() ([]Mount, error) { var allMounts []Mount diff --git a/internal/discover/none.go b/internal/discover/none.go index 554e7eab..3ce62472 100644 --- a/internal/discover/none.go +++ b/internal/discover/none.go @@ -27,6 +27,11 @@ func (e None) Devices() ([]Device, error) { return nil, nil } +// EnvVars returns an empty list of devices +func (e None) EnvVars() ([]EnvVar, error) { + return nil, nil +} + // Mounts returns an empty list of mounts func (e None) Mounts() ([]Mount, error) { return nil, nil diff --git a/internal/edits/edits.go b/internal/edits/edits.go index 029e7885..4538ac31 100644 --- a/internal/edits/edits.go +++ b/internal/edits/edits.go @@ -55,6 +55,11 @@ func FromDiscoverer(d discover.Discover) (*cdi.ContainerEdits, error) { return nil, fmt.Errorf("failed to discover devices: %v", err) } + envs, err := d.EnvVars() + if err != nil { + return nil, fmt.Errorf("failed to discover environment variables: %w", err) + } + mounts, err := d.Mounts() if err != nil { return nil, fmt.Errorf("failed to discover mounts: %v", err) @@ -74,6 +79,10 @@ func FromDiscoverer(d discover.Discover) (*cdi.ContainerEdits, error) { c.Append(edits) } + for _, e := range envs { + c.Append(envvar(e).toEdits()) + } + for _, m := range mounts { c.Append(mount(m).toEdits()) } diff --git a/internal/edits/envvar.go b/internal/edits/envvar.go new file mode 100644 index 00000000..359ce6b9 --- /dev/null +++ b/internal/edits/envvar.go @@ -0,0 +1,39 @@ +/** +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package edits + +import ( + "fmt" + + "tags.cncf.io/container-device-interface/pkg/cdi" + "tags.cncf.io/container-device-interface/specs-go" + + "github.com/NVIDIA/nvidia-container-toolkit/internal/discover" +) + +type envvar discover.EnvVar + +// toEdits converts a discovered envvar to CDI Container Edits. +func (d envvar) toEdits() *cdi.ContainerEdits { + e := cdi.ContainerEdits{ + ContainerEdits: &specs.ContainerEdits{ + Env: []string{fmt.Sprintf("%s=%s", d.Name, d.Value)}, + }, + } + return &e +} diff --git a/internal/platform-support/dgpu/by-path-hooks.go b/internal/platform-support/dgpu/by-path-hooks.go index cd38f5e7..ff9c9917 100644 --- a/internal/platform-support/dgpu/by-path-hooks.go +++ b/internal/platform-support/dgpu/by-path-hooks.go @@ -41,6 +41,11 @@ func (d *byPathHookDiscoverer) Devices() ([]discover.Device, error) { return nil, nil } +// EnvVars returns the empty list for the by-path hook discoverer +func (d *byPathHookDiscoverer) EnvVars() ([]discover.EnvVar, error) { + return nil, nil +} + // Hooks returns the hooks for the GPU device. // The following hooks are detected: // 1. A hook to create /dev/dri/by-path symlinks diff --git a/internal/platform-support/dgpu/nvsandboxutils.go b/internal/platform-support/dgpu/nvsandboxutils.go index ebeea7c8..9fd7bb4e 100644 --- a/internal/platform-support/dgpu/nvsandboxutils.go +++ b/internal/platform-support/dgpu/nvsandboxutils.go @@ -106,6 +106,10 @@ func (d *nvsandboxutilsDGPU) Devices() ([]discover.Device, error) { return devices, nil } +func (d *nvsandboxutilsDGPU) EnvVars() ([]discover.EnvVar, error) { + return nil, nil +} + // Hooks returns a hook to create the by-path symlinks for the discovered devices. func (d *nvsandboxutilsDGPU) Hooks() ([]discover.Hook, error) { if len(d.deviceLinks) == 0 { diff --git a/pkg/nvcdi/workarounds-device-folder-permissions.go b/pkg/nvcdi/workarounds-device-folder-permissions.go index 511eb1fc..59f623fc 100644 --- a/pkg/nvcdi/workarounds-device-folder-permissions.go +++ b/pkg/nvcdi/workarounds-device-folder-permissions.go @@ -55,6 +55,11 @@ func (d *deviceFolderPermissions) Devices() ([]discover.Device, error) { return nil, nil } +// EnvVars are empty for this discoverer +func (d *deviceFolderPermissions) EnvVars() ([]discover.EnvVar, error) { + return nil, nil +} + // Hooks returns a set of hooks that sets the file mode to 755 of parent folders for nested device nodes. func (d *deviceFolderPermissions) Hooks() ([]discover.Hook, error) { folders, err := d.getDeviceSubfolders()