/*
# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
*/

package main

import (
	"fmt"
	"strings"
	"testing"

	"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
	"github.com/opencontainers/runtime-spec/specs-go"
	testlog "github.com/sirupsen/logrus/hooks/test"
	"github.com/stretchr/testify/require"
)

// TestAddNvidiaHook runs addNVIDIAHook over a table of OCI specs and verifies
// the resulting prestart hooks.
//
// Each case provides the spec to modify, an optional errorPrefix (when set,
// the call must fail with an error starting with that prefix), and
// shouldNotAdd, which marks specs that already contain the NVIDIA hook and
// are therefore expected to keep their prestart hook list unchanged.
//
// For successful calls the test expects exactly one NVIDIA hook in the spec
// (as counted by nvidiaHookCount) and, when a hook is expected to be added,
// that it is appended as the last prestart hook with the "prestart" argument.
func TestAddNvidiaHook(t *testing.T) {
	logger, logHook := testlog.NewNullLogger()
	shim := nvidiaContainerRuntime{
		logger: logger,
	}

	testCases := []struct {
		spec         *specs.Spec
		errorPrefix  string
		shouldNotAdd bool
	}{
		{
			spec: &specs.Spec{},
		},
		{
			spec: &specs.Spec{
				Hooks: &specs.Hooks{},
			},
		},
		{
			spec: &specs.Spec{
				Hooks: &specs.Hooks{
					Prestart: []specs.Hook{{
						Path: "some-hook",
					}},
				},
			},
		},
		{
			spec: &specs.Spec{
				Hooks: &specs.Hooks{
					Prestart: []specs.Hook{{
						Path: "nvidia-container-runtime-hook",
					}},
				},
			},
			shouldNotAdd: true,
		},
	}

	for i, tc := range testCases {
		logHook.Reset()

		// Record the prestart hook count before the call so that the effect of
		// addNVIDIAHook can be measured afterwards.
		var numPrestartHooks int
		if tc.spec.Hooks != nil {
			numPrestartHooks = len(tc.spec.Hooks.Prestart)
		}

		err := shim.addNVIDIAHook(tc.spec)

		if tc.errorPrefix != "" {
			// Assert the error is present before dereferencing it; otherwise a
			// missing error would panic instead of failing the test cleanly.
			require.Errorf(t, err, "%d: %v", i, tc)
			require.Truef(t, strings.HasPrefix(err.Error(), tc.errorPrefix), "%d: %v", i, tc)
			continue
		}
		require.NoErrorf(t, err, "%d: %v", i, tc)

		// The hook checks belong to the success path: this is where the spec is
		// expected to carry exactly one NVIDIA hook.
		require.NotNilf(t, tc.spec.Hooks, "%d: %v", i, tc)
		require.Equalf(t, 1, nvidiaHookCount(tc.spec.Hooks), "%d: %v", i, tc)

		if tc.shouldNotAdd {
			// The hook was already present, so no additional entry is expected.
			require.Equalf(t, numPrestartHooks, len(tc.spec.Hooks.Prestart), "%d: %v", i, tc)
			continue
		}
		require.Equalf(t, numPrestartHooks+1, len(tc.spec.Hooks.Prestart), "%d: %v", i, tc)

		nvidiaHook := tc.spec.Hooks.Prestart[len(tc.spec.Hooks.Prestart)-1]

		// The directory the hook is resolved from depends on how the test
		// environment is set up, so only the executable name is pinned here.
		// NOTE(review): the previous (unreachable) assertion expected
		// "/usr/bin/nvidia-container-runtime-hook"; confirm whether the exact
		// location should be asserted.
		expectedName := "nvidia-container-runtime-hook"
		require.Truef(t, strings.HasSuffix(nvidiaHook.Path, expectedName), "%d: %v", i, tc)
		require.Equalf(t, []string{nvidiaHook.Path, "prestart"}, nvidiaHook.Args, "%d: %v", i, tc)
		require.Emptyf(t, nvidiaHook.Env, "%d: %v", i, tc)
		require.Nilf(t, nvidiaHook.Timeout, "%d: %v", i, tc)
	}
}

// TestNvidiaContainerRuntime exercises nvidiaContainerRuntime with a mocked
// OCI spec and a mocked wrapped runtime.
//
// For every scenario it checks that modificationRequired matches the expected
// value for the given arguments, that Exec fails exactly when the mocked
// Modify or Flush returns an error, and that Modify and Flush are each called
// the expected number of times.
func TestNvidiaContainerRuntime(t *testing.T) {
	nullLogger, logHook := testlog.NewNullLogger()

	type scenario struct {
		shim         nvidiaContainerRuntime
		shouldModify bool
		args         []string
		modifyError  error
		writeError   error
	}

	scenarios := []scenario{
		// No arguments at all: nothing to modify.
		{},
		{
			args:         []string{"create"},
			shouldModify: true,
		},
		// "create" appearing only as a flag value must not trigger modification.
		{
			args: []string{"--bundle=create"},
		},
		{
			args: []string{"--bundle", "create"},
		},
		{
			args:         []string{"create"},
			shouldModify: true,
		},
		{
			args:         []string{"create"},
			modifyError:  fmt.Errorf("error modifying"),
			shouldModify: true,
		},
		{
			args:         []string{"create"},
			writeError:   fmt.Errorf("error writing"),
			shouldModify: true,
		},
	}

	for idx, sc := range scenarios {
		sc.shim.logger = nullLogger
		logHook.Reset()

		// The mock hands back whatever errors the scenario configured.
		specMock := &oci.SpecMock{
			ModifyFunc: func(_ oci.SpecModifier) error {
				return sc.modifyError
			},
			FlushFunc: func() error {
				return sc.writeError
			},
		}
		require.Equal(t, sc.shouldModify, sc.shim.modificationRequired(sc.args), "%d: %v", idx, sc)

		sc.shim.ociSpec = specMock
		sc.shim.runtime = &MockShim{}

		execErr := sc.shim.Exec(sc.args)
		switch {
		case sc.modifyError != nil, sc.writeError != nil:
			require.Error(t, execErr, "%d: %v", idx, sc)
		default:
			require.NoError(t, execErr, "%d: %v", idx, sc)
		}

		// Modify is expected once whenever modification is required.
		wantModifyCalls := 0
		if sc.shouldModify {
			wantModifyCalls = 1
		}
		require.Equal(t, wantModifyCalls, len(specMock.ModifyCalls()), "%d: %v", idx, sc)

		// Flush is only expected after a successful modification.
		wantFlushCalls := 0
		if sc.shouldModify && sc.modifyError == nil {
			wantFlushCalls = 1
		}
		require.Equal(t, wantFlushCalls, len(specMock.FlushCalls()), "%d: %v", idx, sc)
	}
}

// MockShim is a test double that records how Exec was invoked and returns a
// preconfigured error.
type MockShim struct {
	called      bool     // set to true once Exec has been invoked
	args        []string // arguments passed to the most recent Exec call
	returnError error    // value returned by Exec
}

// Exec marks the mock as called, stores the given arguments, and returns the
// configured returnError.
func (s *MockShim) Exec(args []string) error {
	s.called, s.args = true, args
	return s.returnError
}