Fix race condition in mounts cache

This change switches to using the WithCache decorator for
mounts instead of keeping track of a cache locally.

This addresses a race condition when using the mounts structure.

Signed-off-by: Evan Lezar <elezar@nvidia.com>
This commit is contained in:
Evan Lezar 2024-10-11 15:56:35 +02:00
parent 7fb31bd1dc
commit 7352a90b95

View File

@ -20,7 +20,6 @@ import (
"fmt" "fmt"
"path/filepath" "path/filepath"
"strings" "strings"
"sync"
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger" "github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup" "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup"
@ -35,15 +34,13 @@ type mounts struct {
lookup lookup.Locator lookup lookup.Locator
root string root string
required []string required []string
sync.Mutex
cache []Mount
} }
var _ Discover = (*mounts)(nil) var _ Discover = (*mounts)(nil)
// NewMounts creates a discoverer for the required mounts using the specified locator. // NewMounts creates a discoverer for the required mounts using the specified locator.
func NewMounts(logger logger.Interface, lookup lookup.Locator, root string, required []string) Discover { func NewMounts(logger logger.Interface, lookup lookup.Locator, root string, required []string) Discover {
return newMounts(logger, lookup, root, required) return WithCache(newMounts(logger, lookup, root, required))
} }
// newMounts creates a discoverer for the required mounts using the specified locator. // newMounts creates a discoverer for the required mounts using the specified locator.
@ -60,15 +57,6 @@ func (d *mounts) Mounts() ([]Mount, error) {
if d.lookup == nil { if d.lookup == nil {
return nil, fmt.Errorf("no lookup defined") return nil, fmt.Errorf("no lookup defined")
} }
if d.cache != nil {
d.logger.Debugf("returning cached mounts")
return d.cache, nil
}
d.Lock()
defer d.Unlock()
uniqueMounts := make(map[string]Mount) uniqueMounts := make(map[string]Mount)
for _, candidate := range d.required { for _, candidate := range d.required {
@ -112,10 +100,7 @@ func (d *mounts) Mounts() ([]Mount, error) {
for _, m := range uniqueMounts { for _, m := range uniqueMounts {
mounts = append(mounts, m) mounts = append(mounts, m)
} }
return mounts, nil
d.cache = mounts
return d.cache, nil
} }
// relativeTo returns the path relative to the root for the file locator // relativeTo returns the path relative to the root for the file locator