diff --git a/cmd/nvidia-ctk/hook/hook.go b/cmd/nvidia-ctk/hook/hook.go new file mode 100644 index 00000000..4d85dcd5 --- /dev/null +++ b/cmd/nvidia-ctk/hook/hook.go @@ -0,0 +1,50 @@ +/** +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package hook + +import ( + ldcache "github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-ctk/hook/update-ldcache" + "github.com/sirupsen/logrus" + "github.com/urfave/cli/v2" +) + +type hookCommand struct { + logger *logrus.Logger +} + +// NewCommand constructs a hook command with the specified logger +func NewCommand(logger *logrus.Logger) *cli.Command { + c := hookCommand{ + logger: logger, + } + return c.build() +} + +// build +func (m hookCommand) build() *cli.Command { + // Create the 'hook' command + hook := cli.Command{ + Name: "hook", + Usage: "A collection of hooks that may be injected into an OCI spec", + } + + hook.Subcommands = []*cli.Command{ + ldcache.NewCommand(m.logger), + } + + return &hook +} diff --git a/cmd/nvidia-ctk/hook/update-ldcache/update-ldcache.go b/cmd/nvidia-ctk/hook/update-ldcache/update-ldcache.go new file mode 100644 index 00000000..c0132b33 --- /dev/null +++ b/cmd/nvidia-ctk/hook/update-ldcache/update-ldcache.go @@ -0,0 +1,155 @@ +/** +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package ldcache + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + "syscall" + + "github.com/NVIDIA/nvidia-container-toolkit/internal/oci" + "github.com/opencontainers/runtime-spec/specs-go" + "github.com/sirupsen/logrus" + "github.com/urfave/cli/v2" +) + +type command struct { + logger *logrus.Logger +} + +type config struct { + folders cli.StringSlice + containerSpec string +} + +// NewCommand constructs an update-ldcache command with the specified logger +func NewCommand(logger *logrus.Logger) *cli.Command { + c := command{ + logger: logger, + } + return c.build() +} + +// build the update-ldcache command +func (m command) build() *cli.Command { + cfg := config{} + + // Create the 'update-ldcache' command + c := cli.Command{ + Name: "update-ldcache", + Usage: "Update ldcache in a container by running ldconfig", + Action: func(c *cli.Context) error { + return m.run(c, &cfg) + }, + } + + c.Flags = []cli.Flag{ + &cli.StringSliceFlag{ + Name: "folders", + Usage: "Specifiy the additional folders to add to /etc/ld.so.conf before updating the ld cache", + Destination: &cfg.folders, + }, + &cli.StringFlag{ + Name: "containerSpec", + Usage: "Specify the path to the OCI container spec. If empty or '-' the spec will be read from STDIN", + Destination: &cfg.containerSpec, + }, + } + + return &c +} + +func (m command) run(c *cli.Context, cfg *config) error { + var s specs.State + + inputReader := os.Stdin + if cfg.containerSpec != "" && cfg.containerSpec != "-" { + inputFile, err := os.Open(cfg.containerSpec) + if err != nil { + return fmt.Errorf("failed to open intput: %v", err) + } + defer inputFile.Close() + inputReader = inputFile + } + + d := json.NewDecoder(inputReader) + if err := d.Decode(&s); err != nil { + return fmt.Errorf("failed to decode container state: %v", err) + } + + specFilePath := oci.GetSpecFilePath(s.Bundle) + specFile, err := os.Open(specFilePath) + if err != nil { + return fmt.Errorf("failed to open OCI spec file: %v", err) + } + defer specFile.Close() + + spec, err := oci.LoadFrom(specFile) + if err != nil { + return fmt.Errorf("failed to load OCI spec: %v", err) + } + + var containerRoot string + if spec.Root != nil { + containerRoot = spec.Root.Path + } + + err = m.createConfig(containerRoot, cfg.folders.Value()) + if err != nil { + return fmt.Errorf("failed to update ld.so.conf: %v", err) + } + + args := []string{"/sbin/ldconfig"} + if containerRoot != "" { + args = append(args, "-r", containerRoot) + } + + return syscall.Exec(args[0], args, nil) +} + +// createConfig creates (or updates) /etc/ld.so.conf.d/nvcr-.conf in the container +// to include the required paths. +func (m command) createConfig(root string, folders []string) error { + if len(folders) == 0 { + m.logger.Debugf("No folders to add to /etc/ld.so.conf") + return nil + } + + configFile, err := os.CreateTemp(filepath.Join(root, "/etc/ld.so.conf.d"), "nvcr-*.conf") + if err != nil { + return fmt.Errorf("failed to create config file: %v", err) + } + defer configFile.Close() + + m.logger.Debugf("Adding folders %v to %v", folders, configFile.Name()) + + configured := make(map[string]bool) + for _, folder := range folders { + if configured[folder] { + continue + } + _, err = configFile.WriteString(fmt.Sprintf("%s\n", folder)) + if err != nil { + return fmt.Errorf("failed to update ld.so.conf.d: %v", err) + } + configured[folder] = true + } + + return nil +} diff --git a/cmd/nvidia-ctk/main.go b/cmd/nvidia-ctk/main.go index a2b77bb7..08374ed6 100644 --- a/cmd/nvidia-ctk/main.go +++ b/cmd/nvidia-ctk/main.go @@ -19,6 +19,7 @@ package main import ( "os" + "github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-ctk/hook" log "github.com/sirupsen/logrus" cli "github.com/urfave/cli/v2" ) @@ -68,7 +69,9 @@ func main() { } // Define the subcommands - c.Commands = []*cli.Command{} + c.Commands = []*cli.Command{ + hook.NewCommand(logger), + } // Run the CLI err := c.Run(os.Args) diff --git a/internal/oci/spec_file.go b/internal/oci/spec_file.go index ff0cbb46..3465652d 100644 --- a/internal/oci/spec_file.go +++ b/internal/oci/spec_file.go @@ -52,7 +52,7 @@ func (s *fileSpec) Load() error { } defer specFile.Close() - spec, err := loadFrom(specFile) + spec, err := LoadFrom(specFile) if err != nil { return fmt.Errorf("error loading OCI specification from file: %v", err) } @@ -60,8 +60,8 @@ func (s *fileSpec) Load() error { return nil } -// loadFrom reads the contents of the OCI spec from the specified io.Reader. -func loadFrom(reader io.Reader) (*specs.Spec, error) { +// LoadFrom reads the contents of the OCI spec from the specified io.Reader. +func LoadFrom(reader io.Reader) (*specs.Spec, error) { decoder := json.NewDecoder(reader) var spec specs.Spec diff --git a/internal/oci/spec_file_test.go b/internal/oci/spec_file_test.go index 94dfb3b3..e1c1fe0f 100644 --- a/internal/oci/spec_file_test.go +++ b/internal/oci/spec_file_test.go @@ -44,7 +44,7 @@ func TestLoadFrom(t *testing.T) { for i, tc := range testCases { var spec *specs.Spec - spec, err := loadFrom(bytes.NewReader(tc.contents)) + spec, err := LoadFrom(bytes.NewReader(tc.contents)) if tc.isError { require.Error(t, err, "%d: %v", i, tc)