From 314059fcf033621f9d4a24bd5748c2f4f9143b4d Mon Sep 17 00:00:00 2001 From: Evan Lezar Date: Wed, 1 Mar 2023 07:50:46 +0200 Subject: [PATCH] Move path manipulation to spec.Save Signed-off-by: Evan Lezar --- cmd/nvidia-ctk/cdi/generate/generate.go | 10 +--------- pkg/nvcdi/spec/spec.go | 20 ++++++++++++++++---- 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/cmd/nvidia-ctk/cdi/generate/generate.go b/cmd/nvidia-ctk/cdi/generate/generate.go index 008145ad..c235c90f 100644 --- a/cmd/nvidia-ctk/cdi/generate/generate.go +++ b/cmd/nvidia-ctk/cdi/generate/generate.go @@ -167,15 +167,7 @@ func (m command) run(c *cli.Context, cfg *config) error { return nil } - path := cfg.output - if filepath.Clean(filepath.Dir(path)) == "." { - pwd, err := os.Getwd() - if err != nil { - return fmt.Errorf("failed to get current working directory: %v", err) - } - path = filepath.Join(pwd, path) - } - return spec.Save(path) + return spec.Save(cfg.output) } func formatFromFilename(filename string) string { diff --git a/pkg/nvcdi/spec/spec.go b/pkg/nvcdi/spec/spec.go index c0e96946..2bb26a71 100644 --- a/pkg/nvcdi/spec/spec.go +++ b/pkg/nvcdi/spec/spec.go @@ -17,6 +17,7 @@ package spec import ( + "fmt" "io" "os" "path/filepath" @@ -39,7 +40,10 @@ func New(opts ...Option) (Interface, error) { // Save writes the spec to the specified path and overwrites the file if it exists. func (s *spec) Save(path string) error { - path = s.normalizePath(path) + path, err := s.normalizePath(path) + if err != nil { + return fmt.Errorf("failed to normalize path: %w", err) + } specDir := filepath.Dir(path) registry := cdi.GetRegistry( @@ -57,7 +61,7 @@ func (s *spec) WriteTo(w io.Writer) (int64, error) { return 0, err } - path := s.normalizePath(name) + path, _ := s.normalizePath(name) tmpFile, err := os.CreateTemp("", "*"+filepath.Base(path)) if err != nil { return 0, err @@ -88,12 +92,20 @@ func (s *spec) Raw() *specs.Spec { } // normalizePath ensures that the specified path has a supported extension -func (s *spec) normalizePath(path string) string { +func (s *spec) normalizePath(path string) (string, error) { if ext := filepath.Ext(path); ext != ".yaml" && ext != ".json" { path += s.extension() } - return path + if filepath.Clean(filepath.Dir(path)) == "." { + pwd, err := os.Getwd() + if err != nil { + return path, fmt.Errorf("failed to get current working directory: %v", err) + } + path = filepath.Join(pwd, path) + } + + return path, nil } func (s *spec) extension() string {