package tests

import (
	"context"
	"encoding/json"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"k8s.io/apimachinery/pkg/util/rand"

	"github.com/go-test/deep"

	v12 "k8s.io/apimachinery/pkg/apis/meta/v1"

	"k8s.io/apimachinery/pkg/types"

	"k8s.io/apimachinery/pkg/api/resource"

	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"

	"github.com/lyft/flyteplugins/go/tasks/pluginmachinery/catalog"
	catalogMocks "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/catalog/mocks"
	"github.com/lyft/flyteplugins/go/tasks/pluginmachinery/workqueue"

	v1 "k8s.io/api/core/v1"

	"github.com/lyft/flyteplugins/go/tasks/pluginmachinery/io"
	"github.com/stretchr/testify/mock"

	coreMocks "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core/mocks"
	ioMocks "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/io/mocks"

	pluginCore "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core"
	"github.com/lyft/flytestdlib/promutils"
	"github.com/lyft/flytestdlib/storage"
	"github.com/stretchr/testify/assert"

	idlCore "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core"
)

// createSampleContainerTask returns a minimal container spec for the end-to-end
// helpers: image "img1" running "cmd" with the single argument "{{$inputPrefix}}",
// plus one config entry mapping "dynamic_queue" to "queue1".
func createSampleContainerTask() *idlCore.Container {
	queueEntry := &idlCore.KeyValuePair{Key: "dynamic_queue", Value: "queue1"}

	container := &idlCore.Container{
		Image:   "img1",
		Command: []string{"cmd"},
		Args:    []string{"{{$inputPrefix}}"},
	}
	container.Config = []*idlCore.KeyValuePair{queueEntry}

	return container
}

// BuildTaskTemplate returns a TaskTemplate whose only populated field is its
// container target, built by createSampleContainerTask.
func BuildTaskTemplate() *idlCore.TaskTemplate {
	target := &idlCore.TaskTemplate_Container{Container: createSampleContainerTask()}
	tmpl := &idlCore.TaskTemplate{}
	tmpl.Target = target
	return tmpl
}

// RunPluginEndToEndTest repeatedly calls executor.Handle with a fully mocked
// TaskExecutionContext until the returned transition reaches a terminal phase,
// then verifies the outcome.
//
// The mocked context is backed by an in-memory storage.DataStore (inputs are
// written to <prefix>/inputs.pb before the first Handle call), fixed task
// metadata, an in-memory catalog wrapped in a catalog.AsyncClient, and a
// plugin-state reader/writer pair that round-trips state through JSON.
//
// Parameters:
//   - template: returned by the task reader's Read.
//   - inputs: written to the datastore as the task's inputs.
//   - expectedOutputs: if non-nil, the final phase must be a success and the
//     outputs the plugin hands to OutputWriter.Put must deep-equal it.
//   - expectedFailure: checked only when expectedOutputs is nil; the final phase
//     must be a failure and the error the plugin writes must deep-equal it.
//   - iterationUpdate: optional hook invoked after every Handle call; a returned
//     error fails the test.
func RunPluginEndToEndTest(t *testing.T, executor pluginCore.Plugin, template *idlCore.TaskTemplate,
	inputs *idlCore.LiteralMap, expectedOutputs *idlCore.LiteralMap, expectedFailure *idlCore.ExecutionError,
	iterationUpdate func(ctx context.Context, tCtx pluginCore.TaskExecutionContext) error) {

	ctx := context.Background()

	ds, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope())
	assert.NoError(t, err)

	// Random suffix used for the storage prefix, generated name and owner name.
	execID := rand.String(3)

	basePrefix := storage.DataReference("fake://bucket/prefix/" + execID)
	assert.NoError(t, ds.WriteProtobuf(ctx, basePrefix+"/inputs.pb", storage.Options{}, inputs))

	tr := &coreMocks.TaskReader{}
	tr.OnRead(ctx).Return(template, nil)

	inputReader := &ioMocks.InputReader{}
	inputReader.OnGetInputPrefixPath().Return(basePrefix)
	inputReader.OnGetInputPath().Return(basePrefix + "/inputs.pb")

	outputWriter := &ioMocks.OutputWriter{}
	outputWriter.OnGetOutputPrefixPath().Return(basePrefix)
	outputWriter.OnGetErrorPath().Return(basePrefix + "/error.pb")
	outputWriter.OnGetOutputPath().Return(basePrefix + "/outputs.pb")
	// Persist whatever the plugin publishes so the final assertions can read it back.
	outputWriter.OnPutMatch(mock.Anything, mock.Anything).Return(nil).Run(func(args mock.Arguments) {
		or := args.Get(1).(io.OutputReader)
		literals, ee, err := or.Read(ctx)
		assert.NoError(t, err)

		if ee != nil {
			assert.NoError(t, ds.WriteProtobuf(ctx, outputWriter.GetErrorPath(), storage.Options{}, ee))
		}

		if literals != nil {
			assert.NoError(t, ds.WriteProtobuf(ctx, outputWriter.GetOutputPath(), storage.Options{}, literals))
		}
	})

	// latestKnownState always holds a []byte: the JSON encoding of the last value
	// passed to PluginStateWriter.Put, or nil after Reset. atomic.Value.Store panics
	// on a nil value and on values of differing concrete types, so storing raw plugin
	// state (or nil on Reset) is not safe. Encoding at Put time also snapshots the
	// state, so later mutations of the plugin's object are not observed by Get.
	var latestKnownState atomic.Value
	pluginStateWriter := &coreMocks.PluginStateWriter{}
	pluginStateWriter.OnPutMatch(mock.Anything, mock.Anything).Return(nil).Run(func(args mock.Arguments) {
		raw, err := json.Marshal(args.Get(1))
		assert.NoError(t, err)
		latestKnownState.Store(raw)
	})

	pluginStateWriter.OnReset().Return(nil).Run(func(args mock.Arguments) {
		latestKnownState.Store([]byte(nil))
	})

	pluginStateReader := &coreMocks.PluginStateReader{}
	pluginStateReader.OnGetMatch(mock.Anything).Return(0, nil).Run(func(args mock.Arguments) {
		// The comma-ok form yields nil when nothing has been stored yet.
		raw, _ := latestKnownState.Load().([]byte)
		if len(raw) == 0 {
			// No state written (or it was reset): leave the caller's value untouched.
			return
		}

		o := args.Get(0)
		assert.NoError(t, json.Unmarshal(raw, &o))
	})
	pluginStateReader.OnGetStateVersion().Return(0)

	tID := &coreMocks.TaskExecutionID{}
	tID.OnGetGeneratedName().Return(execID + "-my-task-1")
	tID.OnGetID().Return(idlCore.TaskExecutionIdentifier{
		TaskId: &idlCore.Identifier{
			ResourceType: idlCore.ResourceType_TASK,
			Project:      "a",
			Domain:       "d",
			Name:         "n",
			Version:      "abc",
		},
		NodeExecutionId: &idlCore.NodeExecutionIdentifier{
			NodeId: "node1",
			ExecutionId: &idlCore.WorkflowExecutionIdentifier{
				Project: "a",
				Domain:  "d",
				Name:    "exec",
			},
		},
		RetryAttempt: 0,
	})

	overrides := &coreMocks.TaskOverrides{}
	overrides.OnGetConfig().Return(&v1.ConfigMap{Data: map[string]string{
		"dynamic-queue": "queue1",
	}})
	overrides.OnGetResources().Return(&v1.ResourceRequirements{
		Requests: map[v1.ResourceName]resource.Quantity{},
		Limits:   map[v1.ResourceName]resource.Quantity{},
	})

	tMeta := &coreMocks.TaskExecutionMetadata{}
	tMeta.OnGetTaskExecutionID().Return(tID)
	tMeta.OnGetOverrides().Return(overrides)
	tMeta.OnGetK8sServiceAccount().Return("s")
	tMeta.OnGetNamespace().Return("fake-development")
	tMeta.OnGetLabels().Return(map[string]string{})
	tMeta.OnGetAnnotations().Return(map[string]string{})
	tMeta.OnGetOwnerReference().Return(v12.OwnerReference{})
	tMeta.OnGetOwnerID().Return(types.NamespacedName{
		Namespace: "fake-development",
		Name:      execID,
	})

	// In-memory catalog: Put records the uploaded outputs under their key and Get
	// serves them back, reporting codes.NotFound for unknown keys.
	catClient := &catalogMocks.Client{}
	catData := sync.Map{}
	catClient.On("Get", mock.Anything, mock.Anything).Return(
		func(ctx context.Context, key catalog.Key) io.OutputReader {
			data, found := catData.Load(key)
			if !found {
				return nil
			}

			or := &ioMocks.OutputReader{}
			or.OnExistsMatch(mock.Anything).Return(true, nil)
			or.OnIsErrorMatch(mock.Anything).Return(false, nil)
			or.OnReadMatch(mock.Anything).Return(data.(*idlCore.LiteralMap), nil, nil)
			return or
		},
		func(ctx context.Context, key catalog.Key) error {
			_, found := catData.Load(key)
			if !found {
				return status.Error(codes.NotFound, "No output found for key")
			}

			return nil
		})
	// The method name must be spelled out: testify matches it literally, so passing
	// mock.Anything here would register a method literally named "mock.Anything".
	// NOTE(review): four argument matchers assume Put(ctx, key, reader, metadata);
	// adjust if catalog.Client.Put has a different arity.
	catClient.On("Put", mock.Anything, mock.Anything, mock.Anything, mock.Anything).
		Return(nil).
		Run(func(args mock.Arguments) {
			key := args.Get(1).(catalog.Key)
			or := args.Get(2).(io.OutputReader)
			o, ee, err := or.Read(ctx)
			assert.NoError(t, err)
			// TODO: Outputting error is not yet supported.
			assert.Nil(t, ee)
			catData.Store(key, o)
		})
	cat, err := catalog.NewAsyncClient(catClient, catalog.Config{
		ReaderWorkqueueConfig: workqueue.Config{
			MaxRetries:         0,
			Workers:            2,
			IndexCacheMaxItems: 100,
		},
		WriterWorkqueueConfig: workqueue.Config{
			MaxRetries:         0,
			Workers:            2,
			IndexCacheMaxItems: 100,
		},
	}, promutils.NewTestScope())
	assert.NoError(t, err)
	assert.NoError(t, cat.Start(ctx))

	eRecorder := &coreMocks.EventsRecorder{}
	eRecorder.OnRecordRawMatch(mock.Anything, mock.Anything).Return(nil)

	resourceManager := &coreMocks.ResourceManager{}
	resourceManager.OnAllocateResourceMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(pluginCore.AllocationStatusGranted, nil)
	resourceManager.OnReleaseResourceMatch(mock.Anything, mock.Anything, mock.Anything).Return(nil)

	tCtx := &coreMocks.TaskExecutionContext{}
	tCtx.OnInputReader().Return(inputReader)
	tCtx.OnTaskRefreshIndicator().Return(func(ctx context.Context) {})
	tCtx.OnOutputWriter().Return(outputWriter)
	tCtx.OnDataStore().Return(ds)
	tCtx.OnTaskReader().Return(tr)
	tCtx.OnPluginStateWriter().Return(pluginStateWriter)
	tCtx.OnPluginStateReader().Return(pluginStateReader)
	tCtx.OnTaskExecutionMetadata().Return(tMeta)
	tCtx.OnCatalog().Return(cat)
	tCtx.OnEventsRecorder().Return(eRecorder)
	tCtx.OnResourceManager().Return(resourceManager)
	tCtx.OnMaxDatasetSizeBytes().Return(1000000)
	// TODO: return a real SecretManager. No return value is configured here, so a
	// plugin calling SecretManager() presumably fails inside the mock.
	tCtx.OnSecretManager()

	trns := pluginCore.DoTransitionType(pluginCore.TransitionTypeBarrier, pluginCore.PhaseInfoQueued(time.Now(), 0, ""))
	for !trns.Info().Phase().IsTerminal() {
		trns, err = executor.Handle(ctx, tCtx)
		if !assert.NoError(t, err) {
			// The transition returned alongside an error is not guaranteed to be
			// terminal, so continuing could loop forever; the failure is recorded.
			return
		}

		if iterationUpdate != nil {
			assert.NoError(t, iterationUpdate(ctx, tCtx))
		}
	}

	if expectedOutputs != nil {
		assert.True(t, trns.Info().Phase().IsSuccess())
		actualOutputs := &idlCore.LiteralMap{}
		assert.NoError(t, ds.ReadProtobuf(ctx, outputWriter.GetOutputPath(), actualOutputs))

		if diff := deep.Equal(expectedOutputs, actualOutputs); diff != nil {
			t.Errorf("Expected != Actual. Diff: %v", diff)
		}
	} else if expectedFailure != nil {
		assert.True(t, trns.Info().Phase().IsFailure())
		actualError := &idlCore.ExecutionError{}
		assert.NoError(t, ds.ReadProtobuf(ctx, outputWriter.GetErrorPath(), actualError))

		if diff := deep.Equal(expectedFailure, actualError); diff != nil {
			t.Errorf("Expected != Actual. Diff: %v", diff)
		}
	}
}
