diff --git a/cmd/nvidia-ctk/cdi/generate/generate.go b/cmd/nvidia-ctk/cdi/generate/generate.go index 008145ad..c235c90f 100644 --- a/cmd/nvidia-ctk/cdi/generate/generate.go +++ b/cmd/nvidia-ctk/cdi/generate/generate.go @@ -167,15 +167,7 @@ func (m command) run(c *cli.Context, cfg *config) error { return nil } - path := cfg.output - if filepath.Clean(filepath.Dir(path)) == "." { - pwd, err := os.Getwd() - if err != nil { - return fmt.Errorf("failed to get current working directory: %v", err) - } - path = filepath.Join(pwd, path) - } - return spec.Save(path) + return spec.Save(cfg.output) } func formatFromFilename(filename string) string { diff --git a/pkg/nvcdi/spec/spec.go b/pkg/nvcdi/spec/spec.go index c0e96946..2bb26a71 100644 --- a/pkg/nvcdi/spec/spec.go +++ b/pkg/nvcdi/spec/spec.go @@ -17,6 +17,7 @@ package spec import ( + "fmt" "io" "os" "path/filepath" @@ -39,7 +40,10 @@ func New(opts ...Option) (Interface, error) { // Save writes the spec to the specified path and overwrites the file if it exists. func (s *spec) Save(path string) error { - path = s.normalizePath(path) + path, err := s.normalizePath(path) + if err != nil { + return fmt.Errorf("failed to normalize path: %w", err) + } specDir := filepath.Dir(path) registry := cdi.GetRegistry( @@ -57,7 +61,7 @@ func (s *spec) WriteTo(w io.Writer) (int64, error) { return 0, err } - path := s.normalizePath(name) + path, _ := s.normalizePath(name) tmpFile, err := os.CreateTemp("", "*"+filepath.Base(path)) if err != nil { return 0, err @@ -88,12 +92,20 @@ func (s *spec) Raw() *specs.Spec { } // normalizePath ensures that the specified path has a supported extension -func (s *spec) normalizePath(path string) string { +func (s *spec) normalizePath(path string) (string, error) { if ext := filepath.Ext(path); ext != ".yaml" && ext != ".json" { path += s.extension() } - return path + if filepath.Clean(filepath.Dir(path)) == "." { + pwd, err := os.Getwd() + if err != nil { + return path, fmt.Errorf("failed to get current working directory: %v", err) + } + path = filepath.Join(pwd, path) + } + + return path, nil } func (s *spec) extension() string {