mirror of
https://github.com/NVIDIA/nvidia-container-toolkit
synced 2025-04-05 04:59:18 +00:00
Use functions in oci package to parse arguments
Signed-off-by: Evan Lezar <elezar@nvidia.com>
This commit is contained in:
parent
c05f9dffc5
commit
1a755d471a
@ -66,22 +66,9 @@ func (r nvidiaContainerRuntime) Exec(args []string) error {
|
||||
// modificationRequired checks the intput arguments to determine whether a modification
|
||||
// to the OCI spec is required.
|
||||
func (r nvidiaContainerRuntime) modificationRequired(args []string) bool {
|
||||
var previousWasBundle bool
|
||||
for _, a := range args {
|
||||
// We check for '--bundle create' explicitly to ensure that we
|
||||
// don't inadvertently trigger a modification if the bundle directory
|
||||
// is specified as `create`
|
||||
if !previousWasBundle && isBundleFlag(a) {
|
||||
previousWasBundle = true
|
||||
continue
|
||||
}
|
||||
|
||||
if !previousWasBundle && a == "create" {
|
||||
r.logger.Infof("'create' command detected; modification required")
|
||||
return true
|
||||
}
|
||||
|
||||
previousWasBundle = false
|
||||
if oci.HasCreateSubcommand(args) {
|
||||
r.logger.Infof("'create' command detected; modification required")
|
||||
return true
|
||||
}
|
||||
|
||||
r.logger.Infof("No modification required")
|
||||
|
@ -18,7 +18,6 @@ package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@ -28,35 +27,6 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestArgsGetConfigFilePath(t *testing.T) {
|
||||
wd, err := os.Getwd()
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
bundleDir string
|
||||
ociSpecPath string
|
||||
}{
|
||||
{
|
||||
ociSpecPath: fmt.Sprintf("%v/config.json", wd),
|
||||
},
|
||||
{
|
||||
bundleDir: "/foo/bar",
|
||||
ociSpecPath: "/foo/bar/config.json",
|
||||
},
|
||||
{
|
||||
bundleDir: "/foo/bar/",
|
||||
ociSpecPath: "/foo/bar/config.json",
|
||||
},
|
||||
}
|
||||
|
||||
for i, tc := range testCases {
|
||||
cp, err := getOCISpecFilePath(tc.bundleDir)
|
||||
|
||||
require.NoErrorf(t, err, "%d: %v", i, tc)
|
||||
require.Equalf(t, tc.ociSpecPath, cp, "%d: %v", i, tc)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddNvidiaHook(t *testing.T) {
|
||||
logger, logHook := testlog.NewNullLogger()
|
||||
shim := nvidiaContainerRuntime{
|
||||
|
@ -18,9 +18,6 @@ package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
|
||||
)
|
||||
@ -53,15 +50,15 @@ func newRuntime(argv []string) (oci.Runtime, error) {
|
||||
|
||||
// newOCISpec constructs an OCI spec for the provided arguments
|
||||
func newOCISpec(argv []string) (oci.Spec, error) {
|
||||
bundlePath, err := getBundlePath(argv)
|
||||
bundleDir, err := oci.GetBundleDir(argv)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing command line arguments: %v", err)
|
||||
}
|
||||
logger.Infof("Using bundle directory: %v", bundleDir)
|
||||
|
||||
ociSpecPath := oci.GetSpecFilePath(bundleDir)
|
||||
logger.Infof("Using OCI specification file path: %v", ociSpecPath)
|
||||
|
||||
ociSpecPath, err := getOCISpecFilePath(bundlePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting OCI specification file path: %v", err)
|
||||
}
|
||||
ociSpec := oci.NewSpecFromFile(ociSpecPath)
|
||||
|
||||
return ociSpec, nil
|
||||
@ -75,70 +72,3 @@ func newRuncRuntime() (oci.Runtime, error) {
|
||||
runcExecutableName,
|
||||
)
|
||||
}
|
||||
|
||||
// getBundlePath checks the specified slice of strings (argv) for a 'bundle' flag as allowed by runc.
|
||||
// The following are supported:
|
||||
// --bundle{{SEP}}BUNDLE_PATH
|
||||
// -bundle{{SEP}}BUNDLE_PATH
|
||||
// -b{{SEP}}BUNDLE_PATH
|
||||
// where {{SEP}} is either ' ' or '='
|
||||
func getBundlePath(argv []string) (string, error) {
|
||||
var bundlePath string
|
||||
|
||||
for i := 0; i < len(argv); i++ {
|
||||
param := argv[i]
|
||||
|
||||
parts := strings.SplitN(param, "=", 2)
|
||||
if !isBundleFlag(parts[0]) {
|
||||
continue
|
||||
}
|
||||
|
||||
// The flag has the format --bundle=/path
|
||||
if len(parts) == 2 {
|
||||
bundlePath = parts[1]
|
||||
continue
|
||||
}
|
||||
|
||||
// The flag has the format --bundle /path
|
||||
if i+1 < len(argv) {
|
||||
bundlePath = argv[i+1]
|
||||
i++
|
||||
continue
|
||||
}
|
||||
|
||||
// --bundle / -b was the last element of argv
|
||||
return "", fmt.Errorf("bundle option requires an argument")
|
||||
}
|
||||
|
||||
return bundlePath, nil
|
||||
}
|
||||
|
||||
func isBundleFlag(arg string) bool {
|
||||
if !strings.HasPrefix(arg, "-") {
|
||||
return false
|
||||
}
|
||||
|
||||
trimmed := strings.TrimLeft(arg, "-")
|
||||
return trimmed == "b" || trimmed == "bundle"
|
||||
}
|
||||
|
||||
// getOCISpecFilePath returns the expected path to the OCI specification file for the given
|
||||
// bundle directory or the current working directory if not specified.
|
||||
func getOCISpecFilePath(bundleDir string) (string, error) {
|
||||
if bundleDir == "" {
|
||||
logger.Infof("Bundle directory path is empty, using working directory.")
|
||||
workingDirectory, err := os.Getwd()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error getting working directory: %v", err)
|
||||
}
|
||||
bundleDir = workingDirectory
|
||||
}
|
||||
|
||||
logger.Infof("Using bundle directory: %v", bundleDir)
|
||||
|
||||
OCISpecFilePath := filepath.Join(bundleDir, ociSpecFileName)
|
||||
|
||||
logger.Infof("Using OCI specification file path: %v", OCISpecFilePath)
|
||||
|
||||
return OCISpecFilePath, nil
|
||||
}
|
||||
|
@ -28,111 +28,3 @@ func TestConstructor(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, shim)
|
||||
}
|
||||
|
||||
func TestGetBundlePath(t *testing.T) {
|
||||
type expected struct {
|
||||
bundle string
|
||||
isError bool
|
||||
}
|
||||
testCases := []struct {
|
||||
argv []string
|
||||
expected expected
|
||||
}{
|
||||
{
|
||||
argv: []string{},
|
||||
},
|
||||
{
|
||||
argv: []string{"create"},
|
||||
},
|
||||
{
|
||||
argv: []string{"--bundle"},
|
||||
expected: expected{
|
||||
isError: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
argv: []string{"-b"},
|
||||
expected: expected{
|
||||
isError: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
argv: []string{"--bundle", "/foo/bar"},
|
||||
expected: expected{
|
||||
bundle: "/foo/bar",
|
||||
},
|
||||
},
|
||||
{
|
||||
argv: []string{"--not-bundle", "/foo/bar"},
|
||||
},
|
||||
{
|
||||
argv: []string{"--"},
|
||||
},
|
||||
{
|
||||
argv: []string{"-bundle", "/foo/bar"},
|
||||
expected: expected{
|
||||
bundle: "/foo/bar",
|
||||
},
|
||||
},
|
||||
{
|
||||
argv: []string{"--bundle=/foo/bar"},
|
||||
expected: expected{
|
||||
bundle: "/foo/bar",
|
||||
},
|
||||
},
|
||||
{
|
||||
argv: []string{"-b=/foo/bar"},
|
||||
expected: expected{
|
||||
bundle: "/foo/bar",
|
||||
},
|
||||
},
|
||||
{
|
||||
argv: []string{"-b=/foo/=bar"},
|
||||
expected: expected{
|
||||
bundle: "/foo/=bar",
|
||||
},
|
||||
},
|
||||
{
|
||||
argv: []string{"-b", "/foo/bar"},
|
||||
expected: expected{
|
||||
bundle: "/foo/bar",
|
||||
},
|
||||
},
|
||||
{
|
||||
argv: []string{"create", "-b", "/foo/bar"},
|
||||
expected: expected{
|
||||
bundle: "/foo/bar",
|
||||
},
|
||||
},
|
||||
{
|
||||
argv: []string{"-b", "create", "create"},
|
||||
expected: expected{
|
||||
bundle: "create",
|
||||
},
|
||||
},
|
||||
{
|
||||
argv: []string{"-b=create", "create"},
|
||||
expected: expected{
|
||||
bundle: "create",
|
||||
},
|
||||
},
|
||||
{
|
||||
argv: []string{"-b", "create"},
|
||||
expected: expected{
|
||||
bundle: "create",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tc := range testCases {
|
||||
bundle, err := getBundlePath(tc.argv)
|
||||
|
||||
if tc.expected.isError {
|
||||
require.Errorf(t, err, "%d: %v", i, tc)
|
||||
} else {
|
||||
require.NoErrorf(t, err, "%d: %v", i, tc)
|
||||
}
|
||||
|
||||
require.Equalf(t, tc.expected.bundle, bundle, "%d: %v", i, tc)
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user