mirror of
https://github.com/NVIDIA/nvidia-container-toolkit
synced 2024-11-22 00:08:11 +00:00
Add transform to deduplicate entities in CDI spec
Signed-off-by: Evan Lezar <elezar@nvidia.com>
This commit is contained in:
parent
df618d3cba
commit
5ff206e1a9
@ -2,6 +2,7 @@
|
||||
|
||||
## v1.13.0-rc.3
|
||||
|
||||
* Add transformers to deduplicate and simplify CDI specifications.
|
||||
* Fix the generation of CDI specifications for management containers when the driver libraries are not in the LDCache.
|
||||
* Prefer /run over /var/run when locating nvidia-persistenced and nvidia-fabricmanager sockets.
|
||||
* Only initialize NVML for modes that require it when runing `nvidia-ctk cdi generate`
|
||||
|
151
pkg/nvcdi/transform/deduplicate.go
Normal file
151
pkg/nvcdi/transform/deduplicate.go
Normal file
@ -0,0 +1,151 @@
|
||||
/**
|
||||
# Copyright (c) NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
**/
|
||||
|
||||
package transform
|
||||
|
||||
import (
|
||||
"github.com/container-orchestrated-devices/container-device-interface/specs-go"
|
||||
)
|
||||
|
||||
type dedupe struct{}
|
||||
|
||||
var _ Transformer = (*dedupe)(nil)
|
||||
|
||||
// NewDedupe creates a transformer that deduplicates container edits.
|
||||
func NewDedupe() (Transformer, error) {
|
||||
return &dedupe{}, nil
|
||||
}
|
||||
|
||||
// Transform removes duplicate entris from devices and common container edits.
|
||||
func (d dedupe) Transform(spec *specs.Spec) error {
|
||||
if spec == nil {
|
||||
return nil
|
||||
}
|
||||
if err := d.transformEdits(&spec.ContainerEdits); err != nil {
|
||||
return err
|
||||
}
|
||||
var updatedDevices []specs.Device
|
||||
for _, device := range spec.Devices {
|
||||
if err := d.transformEdits(&device.ContainerEdits); err != nil {
|
||||
return err
|
||||
}
|
||||
updatedDevices = append(updatedDevices, device)
|
||||
}
|
||||
spec.Devices = updatedDevices
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d dedupe) transformEdits(edits *specs.ContainerEdits) error {
|
||||
deviceNodes, err := d.deduplicateDeviceNodes(edits.DeviceNodes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
edits.DeviceNodes = deviceNodes
|
||||
|
||||
envs, err := d.deduplicateEnvs(edits.Env)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
edits.Env = envs
|
||||
|
||||
hooks, err := d.deduplicateHooks(edits.Hooks)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
edits.Hooks = hooks
|
||||
|
||||
mounts, err := d.deduplicateMounts(edits.Mounts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
edits.Mounts = mounts
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d dedupe) deduplicateDeviceNodes(entities []*specs.DeviceNode) ([]*specs.DeviceNode, error) {
|
||||
seen := make(map[string]bool)
|
||||
var deviceNodes []*specs.DeviceNode
|
||||
for _, e := range entities {
|
||||
if e == nil {
|
||||
continue
|
||||
}
|
||||
id, err := deviceNode(*e).id()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if seen[id] {
|
||||
continue
|
||||
}
|
||||
seen[id] = true
|
||||
deviceNodes = append(deviceNodes, e)
|
||||
}
|
||||
return deviceNodes, nil
|
||||
}
|
||||
|
||||
func (d dedupe) deduplicateEnvs(entities []string) ([]string, error) {
|
||||
seen := make(map[string]bool)
|
||||
var envs []string
|
||||
for _, e := range entities {
|
||||
id := e
|
||||
if seen[id] {
|
||||
continue
|
||||
}
|
||||
seen[id] = true
|
||||
envs = append(envs, e)
|
||||
}
|
||||
return envs, nil
|
||||
}
|
||||
|
||||
func (d dedupe) deduplicateHooks(entities []*specs.Hook) ([]*specs.Hook, error) {
|
||||
seen := make(map[string]bool)
|
||||
var hooks []*specs.Hook
|
||||
for _, e := range entities {
|
||||
if e == nil {
|
||||
continue
|
||||
}
|
||||
id, err := hook(*e).id()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if seen[id] {
|
||||
continue
|
||||
}
|
||||
seen[id] = true
|
||||
hooks = append(hooks, e)
|
||||
}
|
||||
return hooks, nil
|
||||
}
|
||||
|
||||
func (d dedupe) deduplicateMounts(entities []*specs.Mount) ([]*specs.Mount, error) {
|
||||
seen := make(map[string]bool)
|
||||
var mounts []*specs.Mount
|
||||
for _, e := range entities {
|
||||
if e == nil {
|
||||
continue
|
||||
}
|
||||
id, err := mount(*e).id()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if seen[id] {
|
||||
continue
|
||||
}
|
||||
seen[id] = true
|
||||
mounts = append(mounts, e)
|
||||
}
|
||||
return mounts, nil
|
||||
}
|
250
pkg/nvcdi/transform/deduplicate_test.go
Normal file
250
pkg/nvcdi/transform/deduplicate_test.go
Normal file
@ -0,0 +1,250 @@
|
||||
/**
|
||||
# Copyright (c) NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
**/
|
||||
|
||||
package transform
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/container-orchestrated-devices/container-device-interface/specs-go"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestDeduplicate(t *testing.T) {
|
||||
testCases := []struct {
|
||||
description string
|
||||
spec *specs.Spec
|
||||
expectedError error
|
||||
expectedSpec *specs.Spec
|
||||
}{
|
||||
{
|
||||
description: "nil spec",
|
||||
},
|
||||
{
|
||||
description: "duplicate deviceNode is removed",
|
||||
spec: &specs.Spec{
|
||||
ContainerEdits: specs.ContainerEdits{
|
||||
DeviceNodes: []*specs.DeviceNode{
|
||||
{
|
||||
Path: "/dev/gpu0",
|
||||
},
|
||||
{
|
||||
Path: "/dev/gpu0",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedSpec: &specs.Spec{
|
||||
ContainerEdits: specs.ContainerEdits{
|
||||
DeviceNodes: []*specs.DeviceNode{
|
||||
{
|
||||
Path: "/dev/gpu0",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "duplicate deviceNode is remved from device edits",
|
||||
spec: &specs.Spec{
|
||||
Devices: []specs.Device{
|
||||
{
|
||||
Name: "device0",
|
||||
ContainerEdits: specs.ContainerEdits{
|
||||
DeviceNodes: []*specs.DeviceNode{
|
||||
{
|
||||
Path: "/dev/gpu0",
|
||||
},
|
||||
{
|
||||
Path: "/dev/gpu0",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedSpec: &specs.Spec{
|
||||
Devices: []specs.Device{
|
||||
{
|
||||
Name: "device0",
|
||||
ContainerEdits: specs.ContainerEdits{
|
||||
DeviceNodes: []*specs.DeviceNode{
|
||||
{
|
||||
Path: "/dev/gpu0",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "duplicate hook is removed",
|
||||
spec: &specs.Spec{
|
||||
ContainerEdits: specs.ContainerEdits{
|
||||
Hooks: []*specs.Hook{
|
||||
{
|
||||
HookName: "createContainer",
|
||||
Path: "/usr/bin/nvidia-ctk",
|
||||
Args: []string{"nvidia-ctk", "hook", "chmod", "--mode", "755", "--path", "/dev/dri"},
|
||||
},
|
||||
{
|
||||
HookName: "createContainer",
|
||||
Path: "/usr/bin/nvidia-ctk",
|
||||
Args: []string{"nvidia-ctk", "hook", "chmod", "--mode", "755", "--path", "/dev/dri"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedSpec: &specs.Spec{
|
||||
ContainerEdits: specs.ContainerEdits{
|
||||
Hooks: []*specs.Hook{
|
||||
{
|
||||
HookName: "createContainer",
|
||||
Path: "/usr/bin/nvidia-ctk",
|
||||
Args: []string{"nvidia-ctk", "hook", "chmod", "--mode", "755", "--path", "/dev/dri"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "duplicate mount is removed",
|
||||
spec: &specs.Spec{
|
||||
Devices: []specs.Device{
|
||||
{
|
||||
Name: "device0",
|
||||
ContainerEdits: specs.ContainerEdits{
|
||||
Mounts: []*specs.Mount{
|
||||
{
|
||||
HostPath: "/host/mount2",
|
||||
ContainerPath: "/mount2",
|
||||
},
|
||||
{
|
||||
HostPath: "/host/mount2",
|
||||
ContainerPath: "/mount2",
|
||||
},
|
||||
{
|
||||
HostPath: "/host/mount1",
|
||||
ContainerPath: "/mount1",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
ContainerEdits: specs.ContainerEdits{
|
||||
Mounts: []*specs.Mount{
|
||||
{
|
||||
HostPath: "/host/mount1",
|
||||
ContainerPath: "/mount1",
|
||||
Options: []string{"bind", "ro"},
|
||||
Type: "tmpfs",
|
||||
},
|
||||
{
|
||||
HostPath: "/host/mount1",
|
||||
ContainerPath: "/mount1",
|
||||
Options: []string{"bind", "ro"},
|
||||
Type: "tmpfs",
|
||||
},
|
||||
{
|
||||
HostPath: "/host/mount1",
|
||||
ContainerPath: "/mount1",
|
||||
Options: []string{"bind", "ro"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedSpec: &specs.Spec{
|
||||
Devices: []specs.Device{
|
||||
{
|
||||
Name: "device0",
|
||||
ContainerEdits: specs.ContainerEdits{
|
||||
Mounts: []*specs.Mount{
|
||||
{
|
||||
HostPath: "/host/mount2",
|
||||
ContainerPath: "/mount2",
|
||||
},
|
||||
{
|
||||
HostPath: "/host/mount1",
|
||||
ContainerPath: "/mount1",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
ContainerEdits: specs.ContainerEdits{
|
||||
Mounts: []*specs.Mount{
|
||||
{
|
||||
HostPath: "/host/mount1",
|
||||
ContainerPath: "/mount1",
|
||||
Options: []string{"bind", "ro"},
|
||||
Type: "tmpfs",
|
||||
},
|
||||
{
|
||||
HostPath: "/host/mount1",
|
||||
ContainerPath: "/mount1",
|
||||
Options: []string{"bind", "ro"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "duplicate env is removed",
|
||||
spec: &specs.Spec{
|
||||
Devices: []specs.Device{
|
||||
{
|
||||
Name: "device0",
|
||||
ContainerEdits: specs.ContainerEdits{
|
||||
Env: []string{"ENV1=VAL1", "ENV1=VAL1", "ENV2=ONE_VALUE", "ENV2=ANOTHER_VALUE"},
|
||||
},
|
||||
},
|
||||
},
|
||||
ContainerEdits: specs.ContainerEdits{
|
||||
Env: []string{"ENV1=VAL1", "ENV1=VAL1", "ENV2=ONE_VALUE", "ENV2=ANOTHER_VALUE"},
|
||||
},
|
||||
},
|
||||
expectedSpec: &specs.Spec{
|
||||
Devices: []specs.Device{
|
||||
{
|
||||
Name: "device0",
|
||||
ContainerEdits: specs.ContainerEdits{
|
||||
Env: []string{"ENV1=VAL1", "ENV2=ONE_VALUE", "ENV2=ANOTHER_VALUE"},
|
||||
},
|
||||
},
|
||||
},
|
||||
ContainerEdits: specs.ContainerEdits{
|
||||
Env: []string{"ENV1=VAL1", "ENV2=ONE_VALUE", "ENV2=ANOTHER_VALUE"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.description, func(t *testing.T) {
|
||||
d := dedupe{}
|
||||
|
||||
err := d.Transform(tc.spec)
|
||||
if tc.expectedError != nil {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
require.EqualValues(t, tc.expectedSpec, tc.spec)
|
||||
})
|
||||
}
|
||||
}
|
166
pkg/nvcdi/transform/edits.go
Normal file
166
pkg/nvcdi/transform/edits.go
Normal file
@ -0,0 +1,166 @@
|
||||
/*
|
||||
*
|
||||
# Copyright (c) NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
**/
|
||||
|
||||
package transform
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"github.com/container-orchestrated-devices/container-device-interface/specs-go"
|
||||
)
|
||||
|
||||
type containerEdits specs.ContainerEdits
|
||||
|
||||
// IsEmpty returns true if the edits are empty.
|
||||
func (e containerEdits) IsEmpty() bool {
|
||||
// Devices with empty edits are invalid
|
||||
if len(e.DeviceNodes) > 0 {
|
||||
return false
|
||||
}
|
||||
if len(e.Env) > 0 {
|
||||
return false
|
||||
}
|
||||
if len(e.Hooks) > 0 {
|
||||
return false
|
||||
}
|
||||
if len(e.Mounts) > 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (e *containerEdits) getEntityIds() ([]string, error) {
|
||||
if e == nil {
|
||||
return nil, nil
|
||||
}
|
||||
uniqueIDs := make(map[string]bool)
|
||||
|
||||
deviceNodes, err := e.getDeviceNodeIDs()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for k := range deviceNodes {
|
||||
uniqueIDs[k] = true
|
||||
}
|
||||
|
||||
envs, err := e.getEnvIDs()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for k := range envs {
|
||||
uniqueIDs[k] = true
|
||||
}
|
||||
|
||||
hooks, err := e.getHookIDs()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for k := range hooks {
|
||||
uniqueIDs[k] = true
|
||||
}
|
||||
|
||||
mounts, err := e.getMountIDs()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for k := range mounts {
|
||||
uniqueIDs[k] = true
|
||||
}
|
||||
|
||||
var ids []string
|
||||
for k := range uniqueIDs {
|
||||
ids = append(ids, k)
|
||||
}
|
||||
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
func (e *containerEdits) getDeviceNodeIDs() (map[string]bool, error) {
|
||||
deviceIDs := make(map[string]bool)
|
||||
for _, entity := range e.DeviceNodes {
|
||||
id, err := deviceNode(*entity).id()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
deviceIDs[id] = true
|
||||
}
|
||||
return deviceIDs, nil
|
||||
}
|
||||
|
||||
func (e *containerEdits) getEnvIDs() (map[string]bool, error) {
|
||||
envIDs := make(map[string]bool)
|
||||
for _, entity := range e.Env {
|
||||
id, err := env(entity).id()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
envIDs[id] = true
|
||||
}
|
||||
return envIDs, nil
|
||||
}
|
||||
|
||||
func (e *containerEdits) getHookIDs() (map[string]bool, error) {
|
||||
hookIDs := make(map[string]bool)
|
||||
for _, entity := range e.Hooks {
|
||||
id, err := hook(*entity).id()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
hookIDs[id] = true
|
||||
}
|
||||
return hookIDs, nil
|
||||
}
|
||||
|
||||
func (e *containerEdits) getMountIDs() (map[string]bool, error) {
|
||||
mountIDs := make(map[string]bool)
|
||||
for _, entity := range e.Mounts {
|
||||
id, err := mount(*entity).id()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mountIDs[id] = true
|
||||
}
|
||||
return mountIDs, nil
|
||||
}
|
||||
|
||||
type deviceNode specs.DeviceNode
|
||||
|
||||
func (dn deviceNode) id() (string, error) {
|
||||
b, err := json.Marshal(dn)
|
||||
return string(b), err
|
||||
}
|
||||
|
||||
type env string
|
||||
|
||||
func (e env) id() (string, error) {
|
||||
return string(e), nil
|
||||
}
|
||||
|
||||
type mount specs.Mount
|
||||
|
||||
func (m mount) id() (string, error) {
|
||||
b, err := json.Marshal(m)
|
||||
return string(b), err
|
||||
}
|
||||
|
||||
type hook specs.Hook
|
||||
|
||||
func (m hook) id() (string, error) {
|
||||
b, err := json.Marshal(m)
|
||||
return string(b), err
|
||||
}
|
105
pkg/nvcdi/transform/remove.go
Normal file
105
pkg/nvcdi/transform/remove.go
Normal file
@ -0,0 +1,105 @@
|
||||
/**
|
||||
# Copyright (c) NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
**/
|
||||
|
||||
package transform
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/container-orchestrated-devices/container-device-interface/specs-go"
|
||||
)
|
||||
|
||||
type remove map[string]bool
|
||||
|
||||
func newRemover(ids ...string) Transformer {
|
||||
r := make(remove)
|
||||
for _, id := range ids {
|
||||
r[id] = true
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// Transform remove the specified entities from the spec.
|
||||
func (r remove) Transform(spec *specs.Spec) error {
|
||||
if spec == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, device := range spec.Devices {
|
||||
if err := r.transformEdits(&device.ContainerEdits); err != nil {
|
||||
return fmt.Errorf("failed to remove edits from device %q: %w", device.Name, err)
|
||||
}
|
||||
}
|
||||
|
||||
return r.transformEdits(&spec.ContainerEdits)
|
||||
}
|
||||
|
||||
func (r remove) transformEdits(edits *specs.ContainerEdits) error {
|
||||
if edits == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var deviceNodes []*specs.DeviceNode
|
||||
for _, entity := range edits.DeviceNodes {
|
||||
id, err := deviceNode(*entity).id()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if r[id] {
|
||||
continue
|
||||
}
|
||||
deviceNodes = append(deviceNodes, entity)
|
||||
}
|
||||
edits.DeviceNodes = deviceNodes
|
||||
|
||||
var envs []string
|
||||
for _, entity := range edits.Env {
|
||||
id := entity
|
||||
if r[id] {
|
||||
continue
|
||||
}
|
||||
envs = append(envs, entity)
|
||||
}
|
||||
edits.Env = envs
|
||||
|
||||
var hooks []*specs.Hook
|
||||
for _, entity := range edits.Hooks {
|
||||
id, err := hook(*entity).id()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if r[id] {
|
||||
continue
|
||||
}
|
||||
hooks = append(hooks, entity)
|
||||
}
|
||||
edits.Hooks = hooks
|
||||
|
||||
var mounts []*specs.Mount
|
||||
for _, entity := range edits.Mounts {
|
||||
id, err := mount(*entity).id()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if r[id] {
|
||||
continue
|
||||
}
|
||||
mounts = append(mounts, entity)
|
||||
}
|
||||
edits.Mounts = mounts
|
||||
|
||||
return nil
|
||||
}
|
74
pkg/nvcdi/transform/simplify.go
Normal file
74
pkg/nvcdi/transform/simplify.go
Normal file
@ -0,0 +1,74 @@
|
||||
/**
|
||||
# Copyright (c) NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
**/
|
||||
|
||||
package transform
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/container-orchestrated-devices/container-device-interface/specs-go"
|
||||
)
|
||||
|
||||
type simplify struct{}
|
||||
|
||||
var _ Transformer = (*simplify)(nil)
|
||||
|
||||
// NewSimplifier creates a simplifier transformer.
|
||||
// This transoformer ensures that entities in the spec are deduplicated and that common edits are removed from device-specific edits.
|
||||
func NewSimplifier() (Transformer, error) {
|
||||
return &simplify{}, nil
|
||||
}
|
||||
|
||||
// Transform simplifies the supplied spec.
|
||||
// Edits that are present in the common edits are removed from device-specific edits.
|
||||
func (s simplify) Transform(spec *specs.Spec) error {
|
||||
if spec == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
dedupe := dedupe{}
|
||||
if err := dedupe.Transform(spec); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
commonEntityIDs, err := (*containerEdits)(&spec.ContainerEdits).getEntityIds()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
toRemove := newRemover(commonEntityIDs...)
|
||||
var updatedDevices []specs.Device
|
||||
for _, device := range spec.Devices {
|
||||
deviceAsSpec := specs.Spec{
|
||||
ContainerEdits: device.ContainerEdits,
|
||||
}
|
||||
err := toRemove.Transform(&deviceAsSpec)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to transform device edits: %w", err)
|
||||
}
|
||||
|
||||
if !(containerEdits)(deviceAsSpec.ContainerEdits).IsEmpty() {
|
||||
// Devices with empty edits are invalid.
|
||||
// We only update the container edits for the device if this would
|
||||
// result in a valid device.
|
||||
device.ContainerEdits = deviceAsSpec.ContainerEdits
|
||||
}
|
||||
updatedDevices = append(updatedDevices, device)
|
||||
}
|
||||
spec.Devices = updatedDevices
|
||||
|
||||
return nil
|
||||
}
|
125
pkg/nvcdi/transform/simplify_test.go
Normal file
125
pkg/nvcdi/transform/simplify_test.go
Normal file
@ -0,0 +1,125 @@
|
||||
/**
|
||||
# Copyright (c) NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
**/
|
||||
|
||||
package transform
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/container-orchestrated-devices/container-device-interface/specs-go"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSimplify(t *testing.T) {
|
||||
testCases := []struct {
|
||||
description string
|
||||
spec *specs.Spec
|
||||
expectedError error
|
||||
expectedSpec *specs.Spec
|
||||
}{
|
||||
{
|
||||
description: "nil spec is a no-op",
|
||||
},
|
||||
{
|
||||
description: "empty spec is simplified",
|
||||
spec: &specs.Spec{},
|
||||
expectedSpec: &specs.Spec{},
|
||||
},
|
||||
{
|
||||
description: "simplify does not allow empty device",
|
||||
spec: &specs.Spec{
|
||||
Devices: []specs.Device{
|
||||
{
|
||||
Name: "device0",
|
||||
ContainerEdits: specs.ContainerEdits{
|
||||
Env: []string{"FOO=BAR"},
|
||||
},
|
||||
},
|
||||
},
|
||||
ContainerEdits: specs.ContainerEdits{
|
||||
Env: []string{"FOO=BAR"},
|
||||
},
|
||||
},
|
||||
expectedSpec: &specs.Spec{
|
||||
Devices: []specs.Device{
|
||||
{
|
||||
Name: "device0",
|
||||
ContainerEdits: specs.ContainerEdits{
|
||||
Env: []string{"FOO=BAR"},
|
||||
},
|
||||
},
|
||||
},
|
||||
ContainerEdits: specs.ContainerEdits{
|
||||
Env: []string{"FOO=BAR"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
description: "simplify removes common entities",
|
||||
spec: &specs.Spec{
|
||||
Devices: []specs.Device{
|
||||
{
|
||||
Name: "device0",
|
||||
ContainerEdits: specs.ContainerEdits{
|
||||
Env: []string{"FOO=BAR"},
|
||||
DeviceNodes: []*specs.DeviceNode{
|
||||
{
|
||||
Path: "/dev/gpu0",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
ContainerEdits: specs.ContainerEdits{
|
||||
Env: []string{"FOO=BAR"},
|
||||
},
|
||||
},
|
||||
expectedSpec: &specs.Spec{
|
||||
Devices: []specs.Device{
|
||||
{
|
||||
Name: "device0",
|
||||
ContainerEdits: specs.ContainerEdits{
|
||||
DeviceNodes: []*specs.DeviceNode{
|
||||
{
|
||||
Path: "/dev/gpu0",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
ContainerEdits: specs.ContainerEdits{
|
||||
Env: []string{"FOO=BAR"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.description, func(t *testing.T) {
|
||||
s := simplify{}
|
||||
|
||||
err := s.Transform(tc.spec)
|
||||
if tc.expectedError != nil {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
require.EqualValues(t, tc.expectedSpec, tc.spec)
|
||||
|
||||
})
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user