/**
# Copyright 2024 NVIDIA CORPORATION
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/

package discover

import (
	"testing"

	"github.com/stretchr/testify/require"
)

func TestWithWithDriverDotSoSymlinks(t *testing.T) {
	testCases := []struct {
		description          string
		discover             Discover
		version              string
		expectedDevices      []Device
		expectedDevicesError error
		expectedHooks        []Hook
		expectedHooksError   error
		expectedMounts       []Mount
		expectedMountsError  error
	}{
		{
			description: "empty discoverer remains empty",
			discover:    None{},
		},
		{
			description: "non-matching discoverer remains unchanged",
			discover: &DiscoverMock{
				DevicesFunc: func() ([]Device, error) {
					devices := []Device{
						{
							Path: "/dev/dev1",
						},
					}
					return devices, nil
				},
				HooksFunc: func() ([]Hook, error) {
					hooks := []Hook{
						{
							Lifecycle: "prestart",
							Path:      "/path/to/a/hook",
							Args:      []string{"hook", "arg1", "arg2"},
						},
					}
					return hooks, nil
				},
				MountsFunc: func() ([]Mount, error) {
					mounts := []Mount{
						{
							Path: "/usr/lib/libnotcuda.so.1.2.3",
						},
					}
					return mounts, nil
				},
			},
			expectedDevices: []Device{
				{
					Path: "/dev/dev1",
				},
			},
			expectedHooks: []Hook{
				{
					Lifecycle: "prestart",
					Path:      "/path/to/a/hook",
					Args:      []string{"hook", "arg1", "arg2"},
				},
			},
			expectedMounts: []Mount{
				{
					Path: "/usr/lib/libnotcuda.so.1.2.3",
				},
			},
		},
		{
			description: "libcuda.so.RM_VERSION is matched",
			discover: &DiscoverMock{
				DevicesFunc: func() ([]Device, error) {
					return nil, nil
				},
				HooksFunc: func() ([]Hook, error) {
					return nil, nil
				},
				MountsFunc: func() ([]Mount, error) {
					mounts := []Mount{
						{
							Path: "/usr/lib/libcuda.so.1.2.3",
						},
					}
					return mounts, nil
				},
			},
			version: "1.2.3",
			expectedMounts: []Mount{
				{
					Path: "/usr/lib/libcuda.so.1.2.3",
				},
			},
			expectedHooks: []Hook{
				{
					Lifecycle: "createContainer",
					Path:      "/path/to/nvidia-cdi-hook",
					Args:      []string{"nvidia-cdi-hook", "create-symlinks", "--link", "libcuda.so.1::/usr/lib/libcuda.so"},
				},
			},
		},
		{
			description: "libcuda.so.RM_VERSION is matched by pattern",
			discover: &DiscoverMock{
				DevicesFunc: func() ([]Device, error) {
					return nil, nil
				},
				HooksFunc: func() ([]Hook, error) {
					return nil, nil
				},
				MountsFunc: func() ([]Mount, error) {
					mounts := []Mount{
						{
							Path: "/usr/lib/libcuda.so.1.2.3",
						},
					}
					return mounts, nil
				},
			},
			version: "",
			expectedMounts: []Mount{
				{
					Path: "/usr/lib/libcuda.so.1.2.3",
				},
			},
			expectedHooks: []Hook{
				{
					Lifecycle: "createContainer",
					Path:      "/path/to/nvidia-cdi-hook",
					Args:      []string{"nvidia-cdi-hook", "create-symlinks", "--link", "libcuda.so.1::/usr/lib/libcuda.so"},
				},
			},
		},
		{
			description: "beta libcuda.so.RM_VERSION is matched",
			discover: &DiscoverMock{
				DevicesFunc: func() ([]Device, error) {
					return nil, nil
				},
				HooksFunc: func() ([]Hook, error) {
					return nil, nil
				},
				MountsFunc: func() ([]Mount, error) {
					mounts := []Mount{
						{
							Path: "/usr/lib/libcuda.so.1.2",
						},
					}
					return mounts, nil
				},
			},
			expectedMounts: []Mount{
				{
					Path: "/usr/lib/libcuda.so.1.2",
				},
			},
			expectedHooks: []Hook{
				{
					Lifecycle: "createContainer",
					Path:      "/path/to/nvidia-cdi-hook",
					Args:      []string{"nvidia-cdi-hook", "create-symlinks", "--link", "libcuda.so.1::/usr/lib/libcuda.so"},
				},
			},
		},
		{
			description: "non-matching libcuda.so.RM_VERSION is ignored",
			discover: &DiscoverMock{
				DevicesFunc: func() ([]Device, error) {
					return nil, nil
				},
				HooksFunc: func() ([]Hook, error) {
					return nil, nil
				},
				MountsFunc: func() ([]Mount, error) {
					mounts := []Mount{
						{
							Path: "/usr/lib/libcuda.so.1.2.3",
						},
					}
					return mounts, nil
				},
			},
			version: "4.5.6",
			expectedMounts: []Mount{
				{
					Path: "/usr/lib/libcuda.so.1.2.3",
				},
			},
		},
		{
			description: "hooks are extended",
			discover: &DiscoverMock{
				DevicesFunc: func() ([]Device, error) {
					return nil, nil
				},
				HooksFunc: func() ([]Hook, error) {
					hooks := []Hook{
						{
							Lifecycle: "prestart",
							Path:      "/path/to/a/hook",
							Args:      []string{"hook", "arg1", "arg2"},
						},
					}
					return hooks, nil
				},
				MountsFunc: func() ([]Mount, error) {
					mounts := []Mount{
						{
							Path: "/usr/lib/libcuda.so.1.2.3",
						},
					}
					return mounts, nil
				},
			},
			version: "1.2.3",
			expectedMounts: []Mount{
				{
					Path: "/usr/lib/libcuda.so.1.2.3",
				},
			},
			expectedHooks: []Hook{
				{
					Lifecycle: "prestart",
					Path:      "/path/to/a/hook",
					Args:      []string{"hook", "arg1", "arg2"},
				},
				{
					Lifecycle: "createContainer",
					Path:      "/path/to/nvidia-cdi-hook",
					Args:      []string{"nvidia-cdi-hook", "create-symlinks", "--link", "libcuda.so.1::/usr/lib/libcuda.so"},
				},
			},
		},
		{
			description: "all driver so symlinks are matched",
			discover: &DiscoverMock{
				DevicesFunc: func() ([]Device, error) {
					return nil, nil
				},
				HooksFunc: func() ([]Hook, error) {
					return nil, nil
				},
				MountsFunc: func() ([]Mount, error) {
					mounts := []Mount{
						{
							Path: "/usr/lib/libcuda.so.1.2.3",
						},
						{
							Path: "/usr/lib/libGLX_nvidia.so.1.2.3",
						},
						{
							Path: "/usr/lib/libnvidia-opticalflow.so.1.2.3",
						},
						{
							Path: "/usr/lib/libanother.so.1.2.3",
						},
					}
					return mounts, nil
				},
			},
			expectedMounts: []Mount{
				{
					Path: "/usr/lib/libcuda.so.1.2.3",
				},
				{
					Path: "/usr/lib/libGLX_nvidia.so.1.2.3",
				},
				{
					Path: "/usr/lib/libnvidia-opticalflow.so.1.2.3",
				},
				{
					Path: "/usr/lib/libanother.so.1.2.3",
				},
			},
			expectedHooks: []Hook{
				{
					Lifecycle: "createContainer",
					Path:      "/path/to/nvidia-cdi-hook",
					Args: []string{
						"nvidia-cdi-hook", "create-symlinks",
						"--link", "libcuda.so.1::/usr/lib/libcuda.so",
						"--link", "libGLX_nvidia.so.1.2.3::/usr/lib/libGLX_indirect.so.0",
						"--link", "libnvidia-opticalflow.so.1::/usr/lib/libnvidia-opticalflow.so",
					},
				},
			},
		},
	}

	for _, tc := range testCases {
		t.Run(tc.description, func(t *testing.T) {
			d := WithDriverDotSoSymlinks(
				tc.discover,
				tc.version,
				"/path/to/nvidia-cdi-hook",
			)

			devices, err := d.Devices()
			require.ErrorIs(t, err, tc.expectedDevicesError)
			require.EqualValues(t, tc.expectedDevices, devices)

			hooks, err := d.Hooks()
			require.ErrorIs(t, err, tc.expectedHooksError)
			require.EqualValues(t, tc.expectedHooks, hooks)

			mounts, err := d.Mounts()
			require.ErrorIs(t, err, tc.expectedMountsError)
			require.EqualValues(t, tc.expectedMounts, mounts)
		})
	}
}