diff --git a/cmd/nvidia-ctk/cdi/generate/generate.go b/cmd/nvidia-ctk/cdi/generate/generate.go index c5986f46..aeee4967 100644 --- a/cmd/nvidia-ctk/cdi/generate/generate.go +++ b/cmd/nvidia-ctk/cdi/generate/generate.go @@ -44,10 +44,11 @@ type command struct { } type config struct { - output string - format string - root string - nvidiaCTKPath string + output string + format string + deviceNameStrategy string + root string + nvidiaCTKPath string } // NewCommand constructs a generate-cdi command with the specified logger @@ -86,6 +87,12 @@ func (m command) build() *cli.Command { Value: formatYAML, Destination: &cfg.format, }, + &cli.StringFlag{ + Name: "device-name-strategy", + Usage: "Specify the strategy for generating device names. One of [type-index | index | uuid]", + Value: deviceNameStrategyTypeIndex, + Destination: &cfg.deviceNameStrategy, + }, &cli.StringFlag{ Name: "root", Usage: "Specify the root to use when discovering the entities that should be included in the CDI specification.", @@ -110,13 +117,24 @@ func (m command) validateFlags(r *cli.Context, cfg *config) error { return fmt.Errorf("invalid output format: %v", cfg.format) } + _, err := NewDeviceNamer(cfg.deviceNameStrategy) + if err != nil { + return err + } + return nil } func (m command) run(c *cli.Context, cfg *config) error { + deviceNamer, err := NewDeviceNamer(cfg.deviceNameStrategy) + if err != nil { + return fmt.Errorf("failed to create device namer: %v", err) + } + spec, err := m.generateSpec( cfg.root, discover.FindNvidiaCTK(m.logger, cfg.nvidiaCTKPath), + deviceNamer, ) if err != nil { return fmt.Errorf("failed to generate CDI spec: %v", err) @@ -196,7 +214,7 @@ func writeToOutput(format string, data []byte, output io.Writer) error { return nil } -func (m command) generateSpec(root string, nvidiaCTKPath string) (*specs.Spec, error) { +func (m command) generateSpec(root string, nvidiaCTKPath string, namer deviceNamer) (*specs.Spec, error) { nvmllib := nvml.New() if r := nvmllib.Init(); r != nvml.SUCCESS { return nil, r @@ -205,7 +223,7 @@ func (m command) generateSpec(root string, nvidiaCTKPath string) (*specs.Spec, e devicelib := device.New(device.WithNvml(nvmllib)) - deviceSpecs, err := m.generateDeviceSpecs(devicelib, root, nvidiaCTKPath) + deviceSpecs, err := m.generateDeviceSpecs(devicelib, root, nvidiaCTKPath, namer) if err != nil { return nil, fmt.Errorf("failed to create device CDI specs: %v", err) } @@ -266,7 +284,7 @@ func (m command) generateSpec(root string, nvidiaCTKPath string) (*specs.Spec, e return &spec, nil } -func (m command) generateDeviceSpecs(devicelib device.Interface, root string, nvidiaCTKPath string) ([]specs.Device, error) { +func (m command) generateDeviceSpecs(devicelib device.Interface, root string, nvidiaCTKPath string, namer deviceNamer) ([]specs.Device, error) { var deviceSpecs []specs.Device err := devicelib.VisitDevices(func(i int, d device.Device) error { @@ -287,8 +305,12 @@ func (m command) generateDeviceSpecs(devicelib device.Interface, root string, nv return fmt.Errorf("failed to create container edits for device: %v", err) } + deviceName, err := namer.GetDeviceName(i, d) + if err != nil { + return fmt.Errorf("failed to get device name: %v", err) + } deviceSpec := specs.Device{ - Name: fmt.Sprintf("gpu%d", i), + Name: deviceName, ContainerEdits: *deviceEdits.ContainerEdits, } @@ -310,8 +332,12 @@ func (m command) generateDeviceSpecs(devicelib device.Interface, root string, nv return fmt.Errorf("failed to create container edits for MIG device: %v", err) } + deviceName, err := namer.GetMigDeviceName(i, j, mig) + if err != nil { + return fmt.Errorf("failed to get device name: %v", err) + } deviceSpec := specs.Device{ - Name: fmt.Sprintf("mig%v:%v", i, j), + Name: deviceName, ContainerEdits: *deviceEdits.ContainerEdits, } diff --git a/cmd/nvidia-ctk/cdi/generate/namer.go b/cmd/nvidia-ctk/cdi/generate/namer.go new file mode 100644 index 00000000..63131ca9 --- /dev/null +++ b/cmd/nvidia-ctk/cdi/generate/namer.go @@ -0,0 +1,84 @@ +/** +# Copyright (c) NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package generate + +import ( + "fmt" + + "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/device" + "gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml" +) + +type deviceNamer interface { + GetDeviceName(int, device.Device) (string, error) + GetMigDeviceName(int, int, device.MigDevice) (string, error) +} + +const ( + deviceNameStrategyIndex = "index" + deviceNameStrategyTypeIndex = "type-index" + deviceNameStrategyUUID = "uuid" +) + +type deviceNameIndex struct { + gpuPrefix string + migPrefix string +} +type deviceNameUUID struct{} + +// NewDeviceNamer creates a Device Namer based on the supplied strategy. +// This namer can be used to construct the names for MIG and GPU devices when generating the CDI spec. +func NewDeviceNamer(strategy string) (deviceNamer, error) { + switch strategy { + case deviceNameStrategyIndex: + return deviceNameIndex{}, nil + case deviceNameStrategyTypeIndex: + return deviceNameIndex{gpuPrefix: "gpu", migPrefix: "mig"}, nil + case deviceNameStrategyUUID: + return deviceNameUUID{}, nil + } + + return nil, fmt.Errorf("invalid device name strategy: %v", strategy) +} + +// GetDeviceName returns the name for the specified device based on the naming strategy +func (s deviceNameIndex) GetDeviceName(i int, d device.Device) (string, error) { + return fmt.Sprintf("%s%d", s.gpuPrefix, i), nil +} + +// GetMigDeviceName returns the name for the specified device based on the naming strategy +func (s deviceNameIndex) GetMigDeviceName(i int, j int, d device.MigDevice) (string, error) { + return fmt.Sprintf("%s%d:%d", s.migPrefix, i, j), nil +} + +// GetDeviceName returns the name for the specified device based on the naming strategy +func (s deviceNameUUID) GetDeviceName(i int, d device.Device) (string, error) { + uuid, ret := d.GetUUID() + if ret != nvml.SUCCESS { + return "", fmt.Errorf("failed to get device UUID: %v", ret) + } + return uuid, nil +} + +// GetMigDeviceName returns the name for the specified device based on the naming strategy +func (s deviceNameUUID) GetMigDeviceName(i int, j int, d device.MigDevice) (string, error) { + uuid, ret := d.GetUUID() + if ret != nvml.SUCCESS { + return "", fmt.Errorf("failed to get device UUID: %v", ret) + } + return uuid, nil +}