package burlap.behavior.stochasticgames.solvers;

import scpsolver.constraints.LinearBiggerThanEqualsConstraint;
import scpsolver.constraints.LinearEqualsConstraint;
import scpsolver.lpsolver.SolverFactory;
import scpsolver.problems.LinearProgram;

/**
 * Computes correlated equilibria of two-player general-sum (bimatrix) games by solving linear
 * programs with SCPSolver.
 *
 * <p>Payoff matrices are indexed {@code [rowAction][colAction]}. The LP decision variables for the
 * joint-action probabilities are laid out row-major: joint action {@code (r, c)} is variable
 * {@code r * numCols + c} (see {@link #jointIndex} and {@link #rowCol}). Returned joint strategies
 * use the same {@code [rowAction][colAction]} indexing.
 */
public class CorrelatedEquilibriumSolver {

    /** Criterion used to select one equilibrium among the (possibly many) correlated equilibria. */
    public enum CorrelatedEquilibriumObjective {
        /** Maximize the sum of both players' expected payoffs. */
        UTILITARIAN,
        /** Maximize the minimum of the two players' expected payoffs. */
        EGALITARIAN,
        /** Maximize the maximum of the two players' expected payoffs. */
        REPUBLICAN,
        /** Maximize the row player's expected payoff. */
        LIBERTARIAN
    }

    private CorrelatedEquilibriumSolver() {
    }

    /**
     * Small demo: solves the utilitarian correlated equilibrium of a fixed 2x2 game and prints the
     * joint strategy, a separator line, then the row player's and column player's marginals.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        double[][] rowPayoff = {{6.0, 2.0}, {7.0, 0.0}};
        double[][] colPayoff = {{6.0, 7.0}, {2.0, 0.0}};
        double[][] jointStrategy = getCorrelatedEQJointStrategy(
                CorrelatedEquilibriumObjective.UTILITARIAN, rowPayoff, colPayoff);
        double[] rowMarginal = GeneralBimatrixSolverTools.marginalizeRowPlayerStrategy(jointStrategy);
        double[] colMarginal = GeneralBimatrixSolverTools.marginalizeColPlayerStrategy(jointStrategy);

        for (double[] row : jointStrategy) {
            printValues(row);
        }
        System.out.println("------");
        printValues(rowMarginal);
        printValues(colMarginal);
    }

    /** Prints each value followed by a single space, then terminates the line. */
    private static void printValues(double[] values) {
        for (double v : values) {
            System.out.print(v + " ");
        }
        System.out.println();
    }

    /**
     * Computes a correlated equilibrium joint strategy selected by the given objective.
     *
     * @param objective  which equilibrium-selection criterion to use; {@code LIBERTARIAN} is
     *                   interpreted from the row player's perspective
     * @param rowPayoff  row player's payoffs, indexed {@code [rowAction][colAction]}
     * @param colPayoff  column player's payoffs, indexed {@code [rowAction][colAction]}
     * @return joint-action probabilities indexed {@code [rowAction][colAction]}
     * @throws NullPointerException if {@code objective} is null
     */
    public static double[][] getCorrelatedEQJointStrategy(CorrelatedEquilibriumObjective objective,
            double[][] rowPayoff, double[][] colPayoff) {
        switch (objective) {
            case UTILITARIAN:
                return getCorrelatedEQJointStrategyUtilitarian(rowPayoff, colPayoff);
            case EGALITARIAN:
                return getCorrelatedEQJointStrategyEgalitarian(rowPayoff, colPayoff);
            case REPUBLICAN:
                return getCorrelatedEQJointStrategyRepublican(rowPayoff, colPayoff);
            case LIBERTARIAN:
                return getCorrelatedEQJointStrategyLibertarianForRow(rowPayoff, colPayoff);
            default:
                throw new RuntimeException("Unknown objective type");
        }
    }

    /**
     * Returns the correlated equilibrium that maximizes the sum of both players' expected payoffs.
     *
     * @param rowPayoff row player's payoffs, indexed {@code [rowAction][colAction]}
     * @param colPayoff column player's payoffs, indexed {@code [rowAction][colAction]}
     * @return joint-action probabilities indexed {@code [rowAction][colAction]}
     */
    public static double[][] getCorrelatedEQJointStrategyUtilitarian(double[][] rowPayoff, double[][] colPayoff) {
        return solveMaximizing(getUtilitarianObjective(rowPayoff, colPayoff), rowPayoff, colPayoff);
    }

    /**
     * Returns the correlated equilibrium that maximizes the minimum of the two players' expected
     * payoffs.
     *
     * <p>An auxiliary variable {@code t} (the last LP variable) is maximized subject to
     * {@code E[rowPayoff] - t >= 0} and {@code E[colPayoff] - t >= 0}.
     *
     * @param rowPayoff row player's payoffs, indexed {@code [rowAction][colAction]}
     * @param colPayoff column player's payoffs, indexed {@code [rowAction][colAction]}
     * @return joint-action probabilities indexed {@code [rowAction][colAction]}
     */
    public static double[][] getCorrelatedEQJointStrategyEgalitarian(double[][] rowPayoff, double[][] colPayoff) {
        int numRows = rowPayoff.length;
        int numCols = rowPayoff[0].length;
        int numJoint = numRows * numCols;
        int numVars = numJoint + 1; // joint-action probabilities followed by t
        LinearProgram lp = new LinearProgram(getEgalitarianObjective(rowPayoff, colPayoff));
        int nextConstraint = addCorrelatedEquilibriumMainConstraints(lp, rowPayoff, colPayoff, numVars, 0);

        // Each player's expected payoff must be at least t; maximizing t maximizes their minimum.
        double[] rowBound = new double[numVars];
        double[] colBound = new double[numVars];
        for (int k = 0; k < numJoint; k++) {
            int[] rc = rowCol(k, numCols);
            rowBound[k] = rowPayoff[rc[0]][rc[1]];
            colBound[k] = colPayoff[rc[0]][rc[1]];
        }
        rowBound[numJoint] = -1.0d;
        colBound[numJoint] = -1.0d;
        lp.addConstraint(new LinearBiggerThanEqualsConstraint(rowBound, 0.0d, "c" + nextConstraint));
        lp.addConstraint(new LinearBiggerThanEqualsConstraint(colBound, 0.0d, "c" + (nextConstraint + 1)));
        return runLPAndGetJointActionProbs(lp, numRows, numCols);
    }

    /**
     * Returns the correlated equilibrium that maximizes the maximum of the two players' expected
     * payoffs.
     *
     * <p>Solves one LP maximizing the row player's payoff and one maximizing the column player's
     * payoff, and returns whichever achieves the larger value for its own player. Ties favor the
     * column-optimal solution.
     *
     * @param rowPayoff row player's payoffs, indexed {@code [rowAction][colAction]}
     * @param colPayoff column player's payoffs, indexed {@code [rowAction][colAction]}
     * @return joint-action probabilities indexed {@code [rowAction][colAction]}
     */
    public static double[][] getCorrelatedEQJointStrategyRepublican(double[][] rowPayoff, double[][] colPayoff) {
        double[][] rowBest = getCorrelatedEQJointStrategyLibertarianForRow(rowPayoff, colPayoff);
        double[][] colBest = getCorrelatedEQJointStrategyLibertarianForCol(rowPayoff, colPayoff);
        // NOTE(review): assumes expectedPayoffs returns {rowPlayerValue, colPlayerValue} — confirm.
        // Each candidate must be scored by the payoff of the player it was optimized for; comparing
        // the row player's value in both would (almost) always select rowBest.
        double rowValue = GeneralBimatrixSolverTools.expectedPayoffs(rowPayoff, colPayoff, rowBest)[0];
        double colValue = GeneralBimatrixSolverTools.expectedPayoffs(rowPayoff, colPayoff, colBest)[1];
        return rowValue > colValue ? rowBest : colBest;
    }

    /**
     * Returns the correlated equilibrium that maximizes the row player's expected payoff.
     *
     * @param rowPayoff row player's payoffs, indexed {@code [rowAction][colAction]}
     * @param colPayoff column player's payoffs, indexed {@code [rowAction][colAction]}
     * @return joint-action probabilities indexed {@code [rowAction][colAction]}
     */
    public static double[][] getCorrelatedEQJointStrategyLibertarianForRow(double[][] rowPayoff, double[][] colPayoff) {
        return solveMaximizing(getRepublicanObjective(rowPayoff), rowPayoff, colPayoff);
    }

    /**
     * Returns the correlated equilibrium that maximizes the column player's expected payoff.
     *
     * @param rowPayoff row player's payoffs, indexed {@code [rowAction][colAction]}
     * @param colPayoff column player's payoffs, indexed {@code [rowAction][colAction]}
     * @return joint-action probabilities indexed {@code [rowAction][colAction]}
     */
    public static double[][] getCorrelatedEQJointStrategyLibertarianForCol(double[][] rowPayoff, double[][] colPayoff) {
        return solveMaximizing(getRepublicanObjective(colPayoff), rowPayoff, colPayoff);
    }

    /**
     * Builds an LP with the given objective, adds the correlated-equilibrium constraints and
     * solves it as a maximization. The number of LP variables is taken from the objective length.
     */
    private static double[][] solveMaximizing(double[] objective, double[][] rowPayoff, double[][] colPayoff) {
        LinearProgram lp = new LinearProgram(objective);
        addCorrelatedEquilibriumMainConstraints(lp, rowPayoff, colPayoff, objective.length, 0);
        return runLPAndGetJointActionProbs(lp, rowPayoff.length, rowPayoff[0].length);
    }

    /**
     * Solves the LP as a maximization problem and reshapes the first {@code numRows * numCols}
     * solution entries into a joint-action probability matrix.
     *
     * @param lp      the fully constrained linear program; it is switched to maximization
     * @param numRows number of row-player actions
     * @param numCols number of column-player actions
     * @return joint-action probabilities indexed {@code [rowAction][colAction]}
     * @throws IllegalStateException if the solver returns no solution vector
     */
    protected static double[][] runLPAndGetJointActionProbs(LinearProgram lp, int numRows, int numCols) {
        lp.setMinProblem(false);
        double[] solution = SolverFactory.newDefault().solve(lp);
        if (solution == null) {
            // NOTE(review): whether/when SCPSolver returns null (e.g. infeasible LP) depends on the backend.
            throw new IllegalStateException("LP solver returned no solution for the correlated equilibrium program");
        }
        double[][] jointProbs = new double[numRows][numCols];
        for (int k = 0; k < numRows * numCols; k++) {
            int[] rc = rowCol(k, numCols);
            jointProbs[rc[0]][rc[1]] = solution[k];
        }
        return jointProbs;
    }

    /**
     * Adds the constraints that every correlated equilibrium must satisfy:
     * <ul>
     *   <li>row-player incentive constraints: for every recommended row action {@code r} and
     *       deviation {@code r'}, {@code sum_c p(r,c) * (R[r][c] - R[r'][c]) >= 0};</li>
     *   <li>column-player incentive constraints: for every recommended column action {@code c} and
     *       deviation {@code c'}, {@code sum_r p(r,c) * (C[r][c] - C[r][c']) >= 0};</li>
     *   <li>the joint-action probabilities sum to 1;</li>
     *   <li>each joint-action probability is {@code >= 0}.</li>
     * </ul>
     * Constraints are named {@code "c<index>"} with consecutive indices.
     *
     * @param lp                   program to add constraints to
     * @param rowPayoff            row player's payoffs, indexed {@code [rowAction][colAction]}
     * @param colPayoff            column player's payoffs, indexed {@code [rowAction][colAction]}
     * @param numVars              total number of LP variables (at least {@code numRows * numCols};
     *                             extra trailing variables receive zero coefficients)
     * @param firstConstraintIndex index used to name the first added constraint
     * @return the next unused constraint index
     */
    protected static int addCorrelatedEquilibriumMainConstraints(LinearProgram lp, double[][] rowPayoff,
            double[][] colPayoff, int numVars, int firstConstraintIndex) {
        int numRows = rowPayoff.length;
        int numCols = rowPayoff[0].length;
        int numJoint = numRows * numCols;
        int constraintIndex = firstConstraintIndex;

        // Row player must not gain by deviating from recommended row action r to r'.
        for (int r = 0; r < numRows; r++) {
            for (int rAlt = 0; rAlt < numRows; rAlt++) {
                if (rAlt == r) {
                    continue;
                }
                double[] coeffs = GeneralBimatrixSolverTools.constantDoubleArray(0.0d, numVars);
                for (int c = 0; c < numCols; c++) {
                    coeffs[jointIndex(r, c, numCols)] = rowPayoff[r][c] - rowPayoff[rAlt][c];
                }
                lp.addConstraint(new LinearBiggerThanEqualsConstraint(coeffs, 0.0d, "c" + constraintIndex));
                constraintIndex++;
            }
        }

        // Column player must not gain by deviating from recommended column action c to c'.
        for (int c = 0; c < numCols; c++) {
            for (int cAlt = 0; cAlt < numCols; cAlt++) {
                if (cAlt == c) {
                    continue;
                }
                double[] coeffs = GeneralBimatrixSolverTools.constantDoubleArray(0.0d, numVars);
                for (int r = 0; r < numRows; r++) {
                    coeffs[jointIndex(r, c, numCols)] = colPayoff[r][c] - colPayoff[r][cAlt];
                }
                lp.addConstraint(new LinearBiggerThanEqualsConstraint(coeffs, 0.0d, "c" + constraintIndex));
                constraintIndex++;
            }
        }

        // Joint-action probabilities form a distribution (auxiliary variables excluded).
        double[] sumCoeffs = GeneralBimatrixSolverTools.constantDoubleArray(0.0d, numVars);
        for (int k = 0; k < numJoint; k++) {
            sumCoeffs[k] = 1.0d;
        }
        lp.addConstraint(new LinearEqualsConstraint(sumCoeffs, 1.0d, "c" + constraintIndex));
        constraintIndex++;

        for (int k = 0; k < numJoint; k++) {
            lp.addConstraint(new LinearBiggerThanEqualsConstraint(
                    GeneralBimatrixSolverTools.zero1Array(k, numVars), 0.0d, "c" + constraintIndex));
            constraintIndex++;
        }
        return constraintIndex;
    }

    /**
     * Builds the utilitarian objective: for each joint action, the sum of both players' payoffs.
     *
     * @param rowPayoff row player's payoffs, indexed {@code [rowAction][colAction]}
     * @param colPayoff column player's payoffs, indexed {@code [rowAction][colAction]}
     * @return row-major objective coefficients of length {@code numRows * numCols}
     */
    public static double[] getUtilitarianObjective(double[][] rowPayoff, double[][] colPayoff) {
        int numCols = rowPayoff[0].length;
        int numJoint = rowPayoff.length * numCols;
        double[] objective = new double[numJoint];
        for (int k = 0; k < numJoint; k++) {
            int[] rc = rowCol(k, numCols);
            objective[k] = rowPayoff[rc[0]][rc[1]] + colPayoff[rc[0]][rc[1]];
        }
        return objective;
    }

    /**
     * Builds the egalitarian objective: zero coefficients on every joint-action probability and a
     * coefficient on the single trailing auxiliary variable (as produced by
     * {@code GeneralBimatrixSolverTools.zero1Array}).
     *
     * @param rowPayoff row player's payoffs (only its dimensions are used)
     * @param colPayoff column player's payoffs (unused)
     * @return objective coefficients of length {@code numRows * numCols + 1}
     */
    public static double[] getEgalitarianObjective(double[][] rowPayoff, double[][] colPayoff) {
        int numVars = (rowPayoff.length * rowPayoff[0].length) + 1;
        return GeneralBimatrixSolverTools.zero1Array(numVars - 1, numVars);
    }

    /**
     * Flattens a single player's payoff matrix into row-major objective coefficients, so that
     * maximizing it maximizes that player's expected payoff.
     *
     * @param payoff a player's payoffs, indexed {@code [rowAction][colAction]}
     * @return row-major objective coefficients of length {@code numRows * numCols}
     */
    public static double[] getRepublicanObjective(double[][] payoff) {
        int numCols = payoff[0].length;
        int numJoint = payoff.length * numCols;
        double[] objective = new double[numJoint];
        for (int k = 0; k < numJoint; k++) {
            int[] rc = rowCol(k, numCols);
            objective[k] = payoff[rc[0]][rc[1]];
        }
        return objective;
    }

    /**
     * Maps a (row, col) joint action to its row-major LP variable index.
     *
     * @param row     row action
     * @param col     column action
     * @param numCols number of column actions
     * @return {@code row * numCols + col}
     */
    protected static int jointIndex(int row, int col, int numCols) {
        return (row * numCols) + col;
    }

    /**
     * Inverse of {@link #jointIndex}: maps a row-major LP variable index back to (row, col).
     *
     * @param index   row-major joint-action index
     * @param numCols number of column actions
     * @return {@code {index / numCols, index % numCols}}
     */
    protected static int[] rowCol(int index, int numCols) {
        return new int[]{index / numCols, index % numCols};
    }

    /**
     * Returns a copy of the matrix with all-zero rows removed, preserving the order of the rest.
     * Result rows have the width of the input's first row.
     *
     * @param matrix input matrix; an empty matrix yields an empty result
     * @return a new matrix containing only the non-zero rows
     */
    protected static double[][] removeZeroRows(double[][] matrix) {
        if (matrix.length == 0) {
            return new double[0][0];
        }
        int width = matrix[0].length;
        double[][] kept = new double[matrix.length][width];
        int numKept = 0;
        for (double[] row : matrix) {
            if (!isZeroArray(row)) {
                System.arraycopy(row, 0, kept[numKept], 0, row.length);
                numKept++;
            }
        }
        if (numKept == matrix.length) {
            return kept;
        }
        double[][] trimmed = new double[numKept][];
        System.arraycopy(kept, 0, trimmed, 0, numKept);
        return trimmed;
    }

    /**
     * Returns true when every entry equals {@code 0.0} (vacuously true for an empty array).
     *
     * @param values array to inspect
     * @return whether all entries are zero
     */
    protected static boolean isZeroArray(double[] values) {
        for (double v : values) {
            if (v != 0.0d) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns a copy in which every entry that is not strictly positive (including NaN) is 0.0.
     *
     * @param values input values (not modified)
     * @return a new array with negatives clamped to zero
     */
    public static double[] roundNegativesToZero(double[] values) {
        double[] clamped = new double[values.length];
        for (int i = 0; i < values.length; i++) {
            clamped[i] = values[i] > 0.0d ? values[i] : 0.0d;
        }
        return clamped;
    }
}
