diff --git a/cmd/nvidia-ctk/cdi/cdi.go b/cmd/nvidia-ctk/cdi/cdi.go index 84b42717..e6c75a2d 100644 --- a/cmd/nvidia-ctk/cdi/cdi.go +++ b/cmd/nvidia-ctk/cdi/cdi.go @@ -18,6 +18,7 @@ package cdi import ( "github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-ctk/cdi/generate" + "github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-ctk/cdi/transform" "github.com/sirupsen/logrus" "github.com/urfave/cli/v2" ) @@ -44,6 +45,7 @@ func (m command) build() *cli.Command { hook.Subcommands = []*cli.Command{ generate.NewCommand(m.logger), + transform.NewCommand(m.logger), } return &hook diff --git a/cmd/nvidia-ctk/cdi/transform/root/root.go b/cmd/nvidia-ctk/cdi/transform/root/root.go new file mode 100644 index 00000000..6014cea8 --- /dev/null +++ b/cmd/nvidia-ctk/cdi/transform/root/root.go @@ -0,0 +1,104 @@ +/** +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package root + +import ( + "fmt" + + "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec" + "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/transform" + "github.com/sirupsen/logrus" + "github.com/urfave/cli/v2" +) + +type loadSaver interface { + Load() (spec.Interface, error) + Save(spec.Interface) error +} + +type command struct { + logger *logrus.Logger + + handler loadSaver +} + +type config struct { + from string + to string +} + +// NewCommand constructs a generate-cdi command with the specified logger +func NewCommand(logger *logrus.Logger, specHandler loadSaver) *cli.Command { + c := command{ + logger: logger, + handler: specHandler, + } + return c.build() +} + +// build creates the CLI command +func (m command) build() *cli.Command { + cfg := config{} + + c := cli.Command{ + Name: "root", + Usage: "Apply a root transform to a CDI specification", + Before: func(c *cli.Context) error { + return m.validateFlags(c, &cfg) + }, + Action: func(c *cli.Context) error { + return m.run(c, &cfg) + }, + } + + c.Flags = []cli.Flag{ + &cli.StringFlag{ + Name: "from", + Usage: "specify the root to be transformed", + Destination: &cfg.from, + }, + &cli.StringFlag{ + Name: "to", + Usage: "specify the replacement root. If this is the same as the from root, the transform is a no-op.", + Value: "", + Destination: &cfg.to, + }, + } + + return &c +} + +func (m command) validateFlags(c *cli.Context, cfg *config) error { + return nil +} + +func (m command) run(c *cli.Context, cfg *config) error { + spec, err := m.handler.Load() + if err != nil { + return fmt.Errorf("failed to load CDI specification: %w", err) + } + + err = transform.NewRootTransformer( + cfg.from, + cfg.to, + ).Transform(spec.Raw()) + if err != nil { + return fmt.Errorf("failed to transform CDI specification: %w", err) + } + + return m.handler.Save(spec) +} diff --git a/cmd/nvidia-ctk/cdi/transform/transform.go b/cmd/nvidia-ctk/cdi/transform/transform.go new file mode 100644 index 00000000..22ad2e13 --- /dev/null +++ b/cmd/nvidia-ctk/cdi/transform/transform.go @@ -0,0 +1,128 @@ +/** +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package transform + +import ( + "fmt" + "io" + "os" + + "github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-ctk/cdi/transform/root" + "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec" + "github.com/container-orchestrated-devices/container-device-interface/pkg/cdi" + "github.com/sirupsen/logrus" + "github.com/urfave/cli/v2" +) + +type command struct { + logger *logrus.Logger +} + +type options struct { + input string + output string +} + +// NewCommand constructs a command with the specified logger +func NewCommand(logger *logrus.Logger) *cli.Command { + c := command{ + logger: logger, + } + return c.build() +} + +// build creates the CLI command +func (m command) build() *cli.Command { + opts := options{} + + c := cli.Command{ + Name: "transform", + Usage: "Apply a transform to a CDI specification", + Before: func(c *cli.Context) error { + return m.validateFlags(c, &opts) + }, + Action: func(c *cli.Context) error { + return m.run(c, &opts) + }, + } + + c.Flags = []cli.Flag{ + &cli.StringFlag{ + Name: "input", + Usage: "Specify the file to read the CDI specification from. If this is '-' the specification is read from STDIN", + Value: "-", + Destination: &opts.input, + }, + &cli.StringFlag{ + Name: "output", + Usage: "Specify the file to output the generated CDI specification to. If this is '' the specification is output to STDOUT", + Destination: &opts.output, + }, + } + + c.Subcommands = []*cli.Command{ + root.NewCommand(m.logger, &opts), + } + + return &c +} + +func (m command) validateFlags(c *cli.Context, opts *options) error { + return nil +} + +func (m command) run(c *cli.Context, cfg *options) error { + return nil +} + +// Load lodas the input CDI specification +func (o options) Load() (spec.Interface, error) { + contents, err := o.getContents() + if err != nil { + return nil, fmt.Errorf("failed to read spec contents: %v", err) + } + + raw, err := cdi.ParseSpec(contents) + if err != nil { + return nil, fmt.Errorf("failed to parse CDI spec: %v", err) + } + + return spec.New( + spec.WithRawSpec(raw), + ) +} + +func (o options) getContents() ([]byte, error) { + if o.input == "-" { + return io.ReadAll(os.Stdin) + } + + return os.ReadFile(o.input) +} + +// Save saves the CDI specification to the output file +func (o options) Save(s spec.Interface) error { + if o.output == "" { + _, err := s.WriteTo(os.Stdout) + if err != nil { + return fmt.Errorf("failed to write CDI spec to STDOUT: %v", err) + } + return nil + } + + return s.Save(o.output) +} diff --git a/pkg/nvcdi/spec/builder.go b/pkg/nvcdi/spec/builder.go index 78df7c17..6379ad0f 100644 --- a/pkg/nvcdi/spec/builder.go +++ b/pkg/nvcdi/spec/builder.go @@ -41,6 +41,13 @@ func newBuilder(opts ...Option) *builder { for _, opt := range opts { opt(s) } + if s.raw != nil { + s.noSimplify = true + vendor, class := cdi.ParseQualifier(s.raw.Kind) + s.vendor = vendor + s.class = class + } + if s.version == "" { s.version = DetectMinimumVersion } @@ -60,7 +67,6 @@ func newBuilder(opts ...Option) *builder { // Build builds a CDI spec form the spec builder. func (o *builder) Build() (*spec, error) { raw := o.raw - if raw == nil { raw = &specs.Spec{ Version: o.version, @@ -144,3 +150,10 @@ func WithNoSimplify(noSimplify bool) Option { o.noSimplify = noSimplify } } + +// WithRawSpec sets the raw spec for the spec builder +func WithRawSpec(raw *specs.Spec) Option { + return func(o *builder) { + o.raw = raw + } +}