Move path manipulation to spec.Save

Signed-off-by: Evan Lezar <elezar@nvidia.com>
This commit is contained in:
Evan Lezar 2023-03-01 07:50:46 +02:00
parent 221781bd0b
commit 314059fcf0
2 changed files with 17 additions and 13 deletions

View File

@ -167,15 +167,7 @@ func (m command) run(c *cli.Context, cfg *config) error {
return nil return nil
} }
path := cfg.output return spec.Save(cfg.output)
if filepath.Clean(filepath.Dir(path)) == "." {
pwd, err := os.Getwd()
if err != nil {
return fmt.Errorf("failed to get current working directory: %v", err)
}
path = filepath.Join(pwd, path)
}
return spec.Save(path)
} }
func formatFromFilename(filename string) string { func formatFromFilename(filename string) string {

View File

@ -17,6 +17,7 @@
package spec package spec
import ( import (
"fmt"
"io" "io"
"os" "os"
"path/filepath" "path/filepath"
@ -39,7 +40,10 @@ func New(opts ...Option) (Interface, error) {
// Save writes the spec to the specified path and overwrites the file if it exists. // Save writes the spec to the specified path and overwrites the file if it exists.
func (s *spec) Save(path string) error { func (s *spec) Save(path string) error {
path = s.normalizePath(path) path, err := s.normalizePath(path)
if err != nil {
return fmt.Errorf("failed to normalize path: %w", err)
}
specDir := filepath.Dir(path) specDir := filepath.Dir(path)
registry := cdi.GetRegistry( registry := cdi.GetRegistry(
@ -57,7 +61,7 @@ func (s *spec) WriteTo(w io.Writer) (int64, error) {
return 0, err return 0, err
} }
path := s.normalizePath(name) path, _ := s.normalizePath(name)
tmpFile, err := os.CreateTemp("", "*"+filepath.Base(path)) tmpFile, err := os.CreateTemp("", "*"+filepath.Base(path))
if err != nil { if err != nil {
return 0, err return 0, err
@ -88,12 +92,20 @@ func (s *spec) Raw() *specs.Spec {
} }
// normalizePath ensures that the specified path has a supported extension // normalizePath ensures that the specified path has a supported extension
func (s *spec) normalizePath(path string) string { func (s *spec) normalizePath(path string) (string, error) {
if ext := filepath.Ext(path); ext != ".yaml" && ext != ".json" { if ext := filepath.Ext(path); ext != ".yaml" && ext != ".json" {
path += s.extension() path += s.extension()
} }
return path if filepath.Clean(filepath.Dir(path)) == "." {
pwd, err := os.Getwd()
if err != nil {
return path, fmt.Errorf("failed to get current working directory: %v", err)
}
path = filepath.Join(pwd, path)
}
return path, nil
} }
func (s *spec) extension() string { func (s *spec) extension() string {