// Copyright 2017 The Cockroach Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
// implied. See the License for the specific language governing
// permissions and limitations under the License. See the AUTHORS file
// for names of contributors.

package tpcc

import (
	"context"
	"math"
	"strconv"
	"strings"
	"time"

	"github.com/cockroachdb/cockroach/pkg/util/timeutil"
	"github.com/cockroachdb/cockroach/pkg/workload"
	"github.com/cockroachdb/cockroach/pkg/workload/histogram"
	"github.com/pkg/errors"
	"golang.org/x/exp/rand"
)

const (
	// Per-warehouse sizing constants. Neither is referenced by the code visible
	// in this file.
	// NOTE(review): going by the names, these are presumably the number of
	// worker goroutines and database connections allotted to each warehouse --
	// confirm against the code that starts the workers.
	numWorkersPerWarehouse = 10
	numConnsPerWarehouse   = 2
)

// tpccTx is an interface for running a TPCC transaction. One implementation
// exists per transaction type listed in allTxs; instances are created through
// a createTxFn.
type tpccTx interface {
	// run executes the TPCC transaction against the given warehouse ID. The
	// first return value is transaction-specific; worker.run discards it and
	// only inspects the error.
	run(ctx context.Context, wID int) (interface{}, error)
}

// createTxFn constructs a tpccTx from the workload configuration and the
// connection pool. newWorker invokes one createTxFn per entry in
// config.txInfos, so every worker gets its own set of transaction objects.
type createTxFn func(ctx context.Context, config *tpcc, mcp *workload.MultiConnPool) (tpccTx, error)

// txInfo stores high-level information about the TPCC transactions. The
// constructor function is used to create an object that implements tpccTx.
type txInfo struct {
	name        string // display name; also the key used in the mix string and for the latency histogram
	constructor createTxFn
	keyingTime  int     // keying time in seconds, see 5.2.5.7
	thinkTime   float64 // minimum mean of think time distribution in seconds, 5.2.5.7
	weight      int     // number of deck entries for this transaction, i.e. its relative likelihood of being run
}

// allTxs is the template list of TPCC transaction types, with the keying and
// think times (in seconds) that worker.run applies when config.doWaits is set.
// The weight field is left at zero here: initializeMix copies this array into
// config.txInfos and fills in the weights parsed from the mix string, so the
// order of entries defines the indexes stored in the deck.
var allTxs = [...]txInfo{
	{
		name:        "newOrder",
		constructor: createNewOrder,
		keyingTime:  18,
		thinkTime:   12,
	},
	{
		name:        "payment",
		constructor: createPayment,
		keyingTime:  3,
		thinkTime:   12,
	},
	{
		name:        "orderStatus",
		constructor: createOrderStatus,
		keyingTime:  2,
		thinkTime:   10,
	},
	{
		name:        "delivery",
		constructor: createDelivery,
		keyingTime:  2,
		thinkTime:   5,
	},
	{
		name:        "stockLevel",
		constructor: createStockLevel,
		keyingTime:  2,
		thinkTime:   5,
	},
}

// initializeMix parses config.mix, a comma-separated list of name=weight pairs
// (for example `newOrder=10,payment=10`), and populates config.txInfos and
// config.deck from it.
//
// config.txInfos is set to a fresh copy of allTxs with the parsed weights
// applied; transactions that the mix does not mention keep a weight of zero.
// If a transaction is listed more than once, the last weight wins.
// config.deck receives one entry per unit of weight, each entry being an index
// into config.txInfos.
//
// Whitespace around names and weights is ignored. An error is returned if an
// item is not a k=v pair, a weight is not a non-negative integer, a
// transaction name is unknown, or all weights are zero. On error config.deck
// is left untouched, but config.txInfos may already have been replaced.
func initializeMix(config *tpcc) error {
	config.txInfos = append([]txInfo(nil), allTxs[0:]...)
	nameToTx := make(map[string]int, len(config.txInfos))
	for i, tx := range config.txInfos {
		nameToTx[tx.name] = i
	}

	for _, item := range strings.Split(config.mix, `,`) {
		kv := strings.Split(item, `=`)
		if len(kv) != 2 {
			return errors.Errorf(`Invalid mix %s: %s is not a k=v pair`, config.mix, item)
		}
		txName, weightStr := strings.TrimSpace(kv[0]), strings.TrimSpace(kv[1])

		weight, err := strconv.Atoi(weightStr)
		if err != nil {
			return errors.Errorf(
				`Invalid percentage mix %s: %s is not an integer`, config.mix, weightStr)
		}
		// A negative weight cannot be represented in the deck and, if it made
		// the total negative, would panic when sizing the deck below.
		if weight < 0 {
			return errors.Errorf(
				`Invalid percentage mix %s: weight %d for %s is negative`,
				config.mix, weight, txName)
		}

		i, ok := nameToTx[txName]
		if !ok {
			return errors.Errorf(
				`Invalid percentage mix %s: no such transaction %s`, config.mix, txName)
		}

		config.txInfos[i].weight = weight
	}

	// Sum the weights that were actually stored rather than accumulating while
	// parsing, so that a transaction listed twice is only counted once.
	totalWeight := 0
	for i := range config.txInfos {
		totalWeight += config.txInfos[i].weight
	}
	// An empty deck would make every worker index out of range on its first
	// transaction, so reject it up front.
	if totalWeight == 0 {
		return errors.Errorf(
			`Invalid percentage mix %s: all transaction weights are zero`, config.mix)
	}

	config.deck = make([]int, 0, totalWeight)
	for i := range config.txInfos {
		for j := 0; j < config.txInfos[i].weight; j++ {
			config.deck = append(config.deck, i)
		}
	}

	return nil
}

// worker runs TPCC transactions one at a time via its run method.
//
// config is the shared workload configuration and hists is where transaction
// latencies are recorded, keyed by transaction name. warehouse is the
// warehouse ID passed to transactions when config.doWaits is set; otherwise
// run picks a random warehouse for each transaction.
//
// deckPerm is this worker's private copy of config.deck, shuffled in place by
// run, and permIdx is the position of the next entry to draw from it. When
// permIdx equals len(deckPerm) the deck is exhausted and run reshuffles it;
// newWorker starts permIdx there so the first call to run shuffles.
type worker struct {
	config *tpcc
	// txs maps 1-to-1 with config.txInfos.
	txs       []tpccTx
	hists     *histogram.Histograms
	warehouse int

	deckPerm []int
	permIdx  int
}

// newWorker builds a worker for the given warehouse. It takes a private copy
// of config.deck and constructs one tpccTx per entry in config.txInfos using
// that entry's constructor, so the resulting txs slice lines up index-for-index
// with config.txInfos.
//
// The first constructor error aborts creation and is returned unchanged with a
// nil worker.
func newWorker(
	ctx context.Context,
	config *tpcc,
	mcp *workload.MultiConnPool,
	hists *histogram.Histograms,
	warehouse int,
) (*worker, error) {
	transactions := make([]tpccTx, len(config.txInfos))
	deck := append([]int(nil), config.deck...)

	for idx := range transactions {
		tx, err := config.txInfos[idx].constructor(ctx, config, mcp)
		if err != nil {
			return nil, err
		}
		transactions[idx] = tx
	}

	return &worker{
		config:    config,
		txs:       transactions,
		hists:     hists,
		warehouse: warehouse,
		deckPerm:  deck,
		// Starting at the end of the deck forces a shuffle on the first run.
		permIdx: len(deck),
	}, nil
}

// run executes a single TPCC transaction: it draws the next transaction type
// from the worker's shuffled deck, runs it, and records its latency in the
// histogram named after the transaction.
//
// When config.doWaits is set the transaction targets the worker's own
// warehouse and is preceded by the keying time and followed by a randomized
// think time; otherwise a random warehouse is used and no waits are performed.
//
// It returns an error if the deck is empty or the transaction fails (wrapped
// with the transaction name). Otherwise it returns ctx.Err(), which is nil for
// as long as ctx has not been canceled or expired.
func (w *worker) run(ctx context.Context) error {
	// An empty deck (a mix whose weights are all zero) leaves nothing to draw;
	// fail cleanly instead of indexing out of range below.
	if len(w.deckPerm) == 0 {
		return errors.New("empty transaction deck")
	}

	// 5.2.4.2: the required mix is achieved by selecting each new transaction
	// uniformly at random from a deck whose content guarantees the required
	// transaction mix. Each pass through a deck must be made in a different
	// uniformly random order.
	if w.permIdx == len(w.deckPerm) {
		rand.Shuffle(len(w.deckPerm), func(i, j int) {
			w.deckPerm[i], w.deckPerm[j] = w.deckPerm[j], w.deckPerm[i]
		})
		w.permIdx = 0
	}
	// Move through our permutation slice until it's exhausted, using each value
	// to index into the txInfos / txs slices.
	opIdx := w.deckPerm[w.permIdx]
	info := &w.config.txInfos[opIdx]
	tx := w.txs[opIdx]
	w.permIdx++

	warehouseID := w.warehouse
	if !w.config.doWaits {
		warehouseID = rand.Intn(w.config.warehouses)
	} else {
		// Wait out the entire keying and think time even if the context is
		// expired. This prevents all workers from immediately restarting when
		// the workload's ramp period expires, which can overload a cluster.
		time.Sleep(time.Duration(info.keyingTime) * time.Second)
	}

	// Run transactions with a background context because we don't want to
	// cancel them when the context expires. Instead, let them finish normally
	// but don't account for them in the histogram.
	start := timeutil.Now()
	if _, err := tx.run(context.Background(), warehouseID); err != nil {
		return errors.Wrapf(err, "error in %s", info.name)
	}
	if ctx.Err() == nil {
		elapsed := timeutil.Since(start)
		w.hists.Get(info.name).Record(elapsed)
	}

	if w.config.doWaits {
		// 5.2.5.4: Think time is taken independently from a negative exponential
		// distribution. Think time = -log(r) * u, where r is a uniform random number
		// between 0 and 1 and u is the mean think time per operation.
		// Each distribution is truncated at 10 times its mean value. The
		// truncation also covers r == 0, for which -log(r) is +Inf.
		thinkTime := -math.Log(rand.Float64()) * info.thinkTime
		if maxThinkTime := info.thinkTime * 10; thinkTime > maxThinkTime {
			thinkTime = maxThinkTime
		}
		// Scale to nanoseconds before converting: time.Duration(thinkTime)
		// alone would drop the fractional seconds and skew the distribution
		// toward shorter waits.
		time.Sleep(time.Duration(thinkTime * float64(time.Second)))
	}
	return ctx.Err()
}
