/*
 * Decompiled with CFR 0.152.
 */
package net.finmath.montecarlo.automaticdifferentiation.forward;

import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.DoubleBinaryOperator;
import java.util.function.DoubleUnaryOperator;
import java.util.function.Function;
import java.util.function.IntToDoubleFunction;
import java.util.stream.Collectors;
import java.util.stream.DoubleStream;
import net.finmath.functions.DoubleTernaryOperator;
import net.finmath.montecarlo.RandomVariableFromDoubleArray;
import net.finmath.montecarlo.automaticdifferentiation.RandomVariableDifferentiable;
import net.finmath.stochastic.ConditionalExpectationEstimator;
import net.finmath.stochastic.RandomVariable;
import net.finmath.stochastic.Scalar;

/**
 * A {@link RandomVariable} wrapper that records, for every arithmetic operation applied to it,
 * an {@link OperatorTreeNode} describing the operator and its arguments. The recorded tree is
 * walked by {@link #getGradient(Set)} to accumulate derivatives of this random variable with
 * respect to the leaf nodes of the tree.
 *
 * <p>All statistics and realization accessors delegate to the wrapped {@link #getValues() values}.
 * Arithmetic methods return a new instance whose values are computed by the wrapped values and
 * whose tree node references the operands.
 */
public class RandomVariableDifferentiableAD
implements RandomVariableDifferentiable {
    private static final long serialVersionUID = 2459373647785530657L;

    /** Type priority used by the constructors that do not take an explicit priority. */
    private static final int typePriorityDefault = 3;

    private final int typePriority;

    /** Source of the ids handed out to operator tree nodes; incremented once per created node. */
    private static final AtomicLong indexOfNextRandomVariable = new AtomicLong(0L);

    /** The plain values of this random variable; replaced by {@link #cache()}. */
    private RandomVariable values;

    private final OperatorTreeNode operatorTreeNode;

    /**
     * Creates a differentiable random variable from a constant.
     *
     * @param value the constant value
     * @return a new leaf node wrapping {@code value}
     */
    public static RandomVariableDifferentiableAD of(double value) {
        return new RandomVariableDifferentiableAD(value);
    }

    /**
     * Creates a differentiable random variable wrapping the given random variable.
     *
     * @param randomVariable the values to wrap
     * @return a new leaf node wrapping {@code randomVariable}
     */
    public static RandomVariableDifferentiableAD of(RandomVariable randomVariable) {
        return new RandomVariableDifferentiableAD(randomVariable);
    }

    /**
     * Creates a leaf node (no operator, no arguments) from a constant.
     *
     * @param value the constant value
     */
    public RandomVariableDifferentiableAD(double value) {
        this(new RandomVariableFromDoubleArray(value), null, null);
    }

    /**
     * Creates a leaf node (no operator, no arguments) from a time and an array of realizations.
     *
     * @param time the time passed to {@link RandomVariableFromDoubleArray}
     * @param realisations the realizations passed to {@link RandomVariableFromDoubleArray}
     */
    public RandomVariableDifferentiableAD(double time, double[] realisations) {
        this(new RandomVariableFromDoubleArray(time, realisations), null, null);
    }

    /**
     * Creates a leaf node (no operator, no arguments) wrapping the given values.
     *
     * @param randomVariable the values to wrap
     */
    public RandomVariableDifferentiableAD(RandomVariable randomVariable) {
        this(randomVariable, null, null);
    }

    /** Creates a node for {@code operator} applied to {@code arguments}, without an estimator. */
    private RandomVariableDifferentiableAD(RandomVariable values, List<RandomVariable> arguments, OperatorType operator) {
        this(values, arguments, null, operator);
    }

    /**
     * Creates a node with the default type priority.
     *
     * @param values the result values of the operation
     * @param arguments the operands of the operation (may be {@code null} for a leaf)
     * @param estimator the estimator stored with the node (may be {@code null})
     * @param operator the operator that produced {@code values} (may be {@code null} for a leaf)
     */
    public RandomVariableDifferentiableAD(RandomVariable values, List<RandomVariable> arguments, ConditionalExpectationEstimator estimator, OperatorType operator) {
        this(values, arguments, estimator, operator, typePriorityDefault);
    }

    /**
     * Creates a node with an explicit type priority.
     *
     * @param values the result values of the operation
     * @param arguments the operands of the operation (may be {@code null} for a leaf)
     * @param estimator the estimator stored with the node (may be {@code null})
     * @param operator the operator that produced {@code values} (may be {@code null} for a leaf)
     * @param methodArgumentTypePriority the value returned by {@link #getTypePriority()}
     */
    public RandomVariableDifferentiableAD(RandomVariable values, List<RandomVariable> arguments, ConditionalExpectationEstimator estimator, OperatorType operator, int methodArgumentTypePriority) {
        this.values = values;
        this.operatorTreeNode = new OperatorTreeNode(operator, arguments, estimator);
        this.typePriority = methodArgumentTypePriority;
    }

    /**
     * @return the operator tree node recorded for this random variable
     */
    public OperatorTreeNode getOperatorTreeNode() {
        return this.operatorTreeNode;
    }

    /**
     * @return the wrapped (plain) values of this random variable
     */
    @Override
    public RandomVariable getValues() {
        return this.values;
    }

    /**
     * @return the id of this random variable's operator tree node
     */
    @Override
    public Long getID() {
        return this.getOperatorTreeNode().id;
    }

    /**
     * Computes derivatives of this random variable by walking the recorded operator tree from
     * this node towards its arguments.
     *
     * <p>Nodes are processed in descending id order. For every node that has arguments, its
     * derivative is pushed to the arguments and its own entry is removed from the result, so the
     * returned map only contains entries for nodes without arguments.
     *
     * <p>NOTE: {@code independentIDs} is not consulted by this implementation; the result is not
     * filtered by it.
     *
     * @param independentIDs ids of the independents (currently ignored)
     * @return map from node id to the derivative of this random variable with respect to that node
     */
    @Override
    public Map<Long, RandomVariable> getGradient(Set<Long> independentIDs) {
        Map<Long, RandomVariable> derivatives = new HashMap<>();
        // Derivative of this node with respect to itself.
        derivatives.put(this.getID(), new RandomVariableFromDoubleArray(1.0));

        // Work list keyed by id; the entry with the largest id is processed first. This relies on
        // a result node having a larger id than its arguments (ids are drawn from an increasing
        // counter when a node is constructed, and arguments exist before their result).
        TreeMap<Long, OperatorTreeNode> pending = new TreeMap<>();
        pending.put(this.getID(), this.getOperatorTreeNode());
        while (!pending.isEmpty()) {
            Map.Entry<Long, OperatorTreeNode> entry = pending.lastEntry();
            Long id = entry.getKey();
            OperatorTreeNode node = entry.getValue();
            List<OperatorTreeNode> arguments = node.arguments;
            if (arguments != null && !arguments.isEmpty()) {
                node.propagateDerivativesFromResultToArgument(derivatives);
                for (OperatorTreeNode argument : arguments) {
                    // A null entry marks an operand that is not a RandomVariableDifferentiableAD.
                    if (argument == null) {
                        continue;
                    }
                    pending.put(argument.id, argument);
                }
                // Keep only nodes without arguments in the result.
                derivatives.remove(id);
            }
            pending.remove(id);
        }
        return derivatives;
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public Map<Long, RandomVariable> getTangents(Set<Long> dependentIDs) {
        throw new UnsupportedOperationException();
    }

    /** Delegates to {@code getValues().equals(randomVariable)}. */
    @Override
    public boolean equals(RandomVariable randomVariable) {
        return this.getValues().equals(randomVariable);
    }

    /** Delegates to the wrapped values. */
    @Override
    public double getFiltrationTime() {
        return this.getValues().getFiltrationTime();
    }

    /**
     * @return the type priority given at construction (see the binary operators for its use)
     */
    @Override
    public int getTypePriority() {
        return this.typePriority;
    }

    /** Delegates to the wrapped values. */
    @Override
    public double get(int pathOrState) {
        return this.getValues().get(pathOrState);
    }

    /** Delegates to the wrapped values. */
    @Override
    public int size() {
        return this.getValues().size();
    }

    /** Delegates to the wrapped values. */
    @Override
    public boolean isDeterministic() {
        return this.getValues().isDeterministic();
    }

    /** Delegates to the wrapped values. */
    @Override
    public double[] getRealizations() {
        return this.getValues().getRealizations();
    }

    /** Delegates to the wrapped values. */
    @Override
    public Double doubleValue() {
        return this.getValues().doubleValue();
    }

    /** Delegates to the wrapped values. */
    @Override
    public double getMin() {
        return this.getValues().getMin();
    }

    /** Delegates to the wrapped values. */
    @Override
    public double getMax() {
        return this.getValues().getMax();
    }

    /** Delegates to the wrapped values. */
    @Override
    public double getAverage() {
        return this.getValues().getAverage();
    }

    /** Delegates to the wrapped values. */
    @Override
    public double getAverage(RandomVariable probabilities) {
        return this.getValues().getAverage(probabilities);
    }

    /** Delegates to the wrapped values. */
    @Override
    public double getVariance() {
        return this.getValues().getVariance();
    }

    /** Delegates to the wrapped values. */
    @Override
    public double getVariance(RandomVariable probabilities) {
        return this.getValues().getVariance(probabilities);
    }

    /** Delegates to the wrapped values. */
    @Override
    public double getSampleVariance() {
        return this.getValues().getSampleVariance();
    }

    /** Delegates to the wrapped values. */
    @Override
    public double getStandardDeviation() {
        return this.getValues().getStandardDeviation();
    }

    /** Delegates to the wrapped values. */
    @Override
    public double getStandardDeviation(RandomVariable probabilities) {
        return this.getValues().getStandardDeviation(probabilities);
    }

    /** Delegates to the wrapped values. */
    @Override
    public double getStandardError() {
        return this.getValues().getStandardError();
    }

    /** Delegates to the wrapped values. */
    @Override
    public double getStandardError(RandomVariable probabilities) {
        return this.getValues().getStandardError(probabilities);
    }

    /** Delegates to the wrapped values. */
    @Override
    public double getQuantile(double quantile) {
        return this.getValues().getQuantile(quantile);
    }

    /** Delegates to the wrapped values. */
    @Override
    public double getQuantile(double quantile, RandomVariable probabilities) {
        return this.getValues().getQuantile(quantile, probabilities);
    }

    /** Delegates to the wrapped values. */
    @Override
    public double getQuantileExpectation(double quantileStart, double quantileEnd) {
        return this.getValues().getQuantileExpectation(quantileStart, quantileEnd);
    }

    /** Delegates to the wrapped values. */
    @Override
    public double[] getHistogram(double[] intervalPoints) {
        return this.getValues().getHistogram(intervalPoints);
    }

    /** Delegates to the wrapped values. */
    @Override
    public double[][] getHistogram(int numberOfPoints, double standardDeviations) {
        return this.getValues().getHistogram(numberOfPoints, standardDeviations);
    }

    /**
     * Replaces the wrapped values by {@code values.cache()} and returns this instance. The
     * operator tree node is left untouched.
     */
    @Override
    public RandomVariable cache() {
        this.values = this.values.cache();
        return this;
    }

    /** Records a CAP node with operands (this, constant {@code cap}). */
    @Override
    public RandomVariable cap(double cap) {
        return new RandomVariableDifferentiableAD(this.getValues().cap(cap), Arrays.asList(this, new RandomVariableFromDoubleArray(cap)), OperatorType.CAP);
    }

    /** Records a FLOOR node with operands (this, constant {@code floor}). */
    @Override
    public RandomVariable floor(double floor) {
        return new RandomVariableDifferentiableAD(this.getValues().floor(floor), Arrays.asList(this, new RandomVariableFromDoubleArray(floor)), OperatorType.FLOOR);
    }

    /** Records an ADD node with operands (this, constant {@code value}). */
    @Override
    public RandomVariable add(double value) {
        return new RandomVariableDifferentiableAD(this.getValues().add(value), Arrays.asList(this, new RandomVariableFromDoubleArray(value)), OperatorType.ADD);
    }

    /** Records a SUB node with operands (this, constant {@code value}). */
    @Override
    public RandomVariable sub(double value) {
        return new RandomVariableDifferentiableAD(this.getValues().sub(value), Arrays.asList(this, new RandomVariableFromDoubleArray(value)), OperatorType.SUB);
    }

    /** Records a MULT node with operands (this, constant {@code value}). */
    @Override
    public RandomVariable mult(double value) {
        return new RandomVariableDifferentiableAD(this.getValues().mult(value), Arrays.asList(this, new RandomVariableFromDoubleArray(value)), OperatorType.MULT);
    }

    /** Records a DIV node with operands (this, constant {@code value}). */
    @Override
    public RandomVariable div(double value) {
        return new RandomVariableDifferentiableAD(this.getValues().div(value), Arrays.asList(this, new RandomVariableFromDoubleArray(value)), OperatorType.DIV);
    }

    /** Records a POW node with operands (this, constant {@code exponent}). */
    @Override
    public RandomVariable pow(double exponent) {
        return new RandomVariableDifferentiableAD(this.getValues().pow(exponent), Arrays.asList(this, new RandomVariableFromDoubleArray(exponent)), OperatorType.POW);
    }

    /** Records an AVERAGE node with operand (this). */
    @Override
    public RandomVariable average() {
        return new RandomVariableDifferentiableAD(this.getValues().average(), Arrays.asList(this), OperatorType.AVERAGE);
    }

    /**
     * Records a CONDITIONAL_EXPECTATION node with operand (this) and stores the estimator with
     * the node so that derivative propagation can apply the same estimator.
     *
     * <p>The node's values are obtained by applying {@code estimator} to the wrapped values
     * (previously the unconditional average was stored, which did not match the recorded
     * operator).
     *
     * @param estimator the conditional expectation estimator to apply; must not be {@code null}
     * @return a new node holding the conditional expectation of this random variable
     */
    @Override
    public RandomVariable getConditionalExpectation(ConditionalExpectationEstimator estimator) {
        return new RandomVariableDifferentiableAD(estimator.getConditionalExpectation(this.getValues()), Arrays.asList(this), estimator, OperatorType.CONDITIONAL_EXPECTATION);
    }

    /** Records a SQUARED node with operand (this). */
    @Override
    public RandomVariable squared() {
        return new RandomVariableDifferentiableAD(this.getValues().squared(), Arrays.asList(this), OperatorType.SQUARED);
    }

    /** Records a SQRT node with operand (this). */
    @Override
    public RandomVariable sqrt() {
        return new RandomVariableDifferentiableAD(this.getValues().sqrt(), Arrays.asList(this), OperatorType.SQRT);
    }

    /** Records an EXP node with operand (this). */
    @Override
    public RandomVariable exp() {
        return new RandomVariableDifferentiableAD(this.getValues().exp(), Arrays.asList(this), OperatorType.EXP);
    }

    /** Records a LOG node with operand (this). */
    @Override
    public RandomVariable log() {
        return new RandomVariableDifferentiableAD(this.getValues().log(), Arrays.asList(this), OperatorType.LOG);
    }

    /** Records a SIN node with operand (this). */
    @Override
    public RandomVariable sin() {
        return new RandomVariableDifferentiableAD(this.getValues().sin(), Arrays.asList(this), OperatorType.SIN);
    }

    /** Records a COS node with operand (this). */
    @Override
    public RandomVariable cos() {
        return new RandomVariableDifferentiableAD(this.getValues().cos(), Arrays.asList(this), OperatorType.COS);
    }

    /**
     * Adds a random variable. If the argument reports a higher type priority the operation is
     * handed to the argument; otherwise an ADD node with operands (this, argument) is recorded.
     */
    @Override
    public RandomVariable add(RandomVariable randomVariable) {
        if (randomVariable.getTypePriority() > this.getTypePriority()) {
            return randomVariable.add(this);
        }
        return new RandomVariableDifferentiableAD(this.getValues().add(randomVariable.getValues()), Arrays.asList(this, randomVariable), OperatorType.ADD);
    }

    /**
     * Subtracts a random variable. If the argument reports a higher type priority the operation
     * is handed to the argument's {@code bus}; otherwise a SUB node with operands
     * (this, argument) is recorded.
     */
    @Override
    public RandomVariable sub(RandomVariable randomVariable) {
        if (randomVariable.getTypePriority() > this.getTypePriority()) {
            return randomVariable.bus(this);
        }
        return new RandomVariableDifferentiableAD(this.getValues().sub(randomVariable.getValues()), Arrays.asList(this, randomVariable), OperatorType.SUB);
    }

    /**
     * Reverse subtraction. If the argument reports a higher type priority the operation is
     * handed to the argument's {@code sub}; otherwise a SUB node with operands
     * (argument, this) is recorded.
     */
    @Override
    public RandomVariable bus(RandomVariable randomVariable) {
        if (randomVariable.getTypePriority() > this.getTypePriority()) {
            return randomVariable.sub(this);
        }
        return new RandomVariableDifferentiableAD(this.getValues().bus(randomVariable.getValues()), Arrays.asList(randomVariable, this), OperatorType.SUB);
    }

    /**
     * Multiplies by a random variable. If the argument reports a higher type priority the
     * operation is handed to the argument; otherwise a MULT node with operands (this, argument)
     * is recorded.
     */
    @Override
    public RandomVariable mult(RandomVariable randomVariable) {
        if (randomVariable.getTypePriority() > this.getTypePriority()) {
            return randomVariable.mult(this);
        }
        return new RandomVariableDifferentiableAD(this.getValues().mult(randomVariable.getValues()), Arrays.asList(this, randomVariable), OperatorType.MULT);
    }

    /**
     * Divides by a random variable. If the argument reports a higher type priority the operation
     * is handed to the argument's {@code vid}; otherwise a DIV node with operands
     * (this, argument) is recorded.
     */
    @Override
    public RandomVariable div(RandomVariable randomVariable) {
        if (randomVariable.getTypePriority() > this.getTypePriority()) {
            return randomVariable.vid(this);
        }
        return new RandomVariableDifferentiableAD(this.getValues().div(randomVariable.getValues()), Arrays.asList(this, randomVariable), OperatorType.DIV);
    }

    /**
     * Reverse division. If the argument reports a higher type priority the operation is handed
     * to the argument's {@code div}; otherwise a DIV node with operands (argument, this) is
     * recorded.
     */
    @Override
    public RandomVariable vid(RandomVariable randomVariable) {
        if (randomVariable.getTypePriority() > this.getTypePriority()) {
            return randomVariable.div(this);
        }
        return new RandomVariableDifferentiableAD(this.getValues().vid(randomVariable.getValues()), Arrays.asList(randomVariable, this), OperatorType.DIV);
    }

    /**
     * Caps by a random variable. If the argument reports a higher type priority the operation is
     * handed to the argument; otherwise a CAP node with operands (this, argument) is recorded.
     */
    @Override
    public RandomVariable cap(RandomVariable randomVariable) {
        if (randomVariable.getTypePriority() > this.getTypePriority()) {
            return randomVariable.cap(this);
        }
        return new RandomVariableDifferentiableAD(this.getValues().cap(randomVariable.getValues()), Arrays.asList(this, randomVariable), OperatorType.CAP);
    }

    /**
     * Floors by a random variable. If the argument reports a higher type priority the operation
     * is handed to the argument; otherwise a FLOOR node with operands (this, argument) is
     * recorded.
     */
    @Override
    public RandomVariable floor(RandomVariable floor) {
        if (floor.getTypePriority() > this.getTypePriority()) {
            return floor.floor(this);
        }
        return new RandomVariableDifferentiableAD(this.getValues().floor(floor.getValues()), Arrays.asList(this, floor), OperatorType.FLOOR);
    }

    /**
     * Accrues this random variable. If {@code rate} reports a higher type priority the result is
     * built from the rate's own operations as {@code (rate * periodLength + 1) * this};
     * otherwise an ACCRUE node with operands (this, rate, constant periodLength) is recorded.
     */
    @Override
    public RandomVariable accrue(RandomVariable rate, double periodLength) {
        if (rate.getTypePriority() > this.getTypePriority()) {
            return rate.mult(periodLength).add(1.0).mult(this);
        }
        return new RandomVariableDifferentiableAD(this.getValues().accrue(rate.getValues(), periodLength), Arrays.asList(this, rate, new RandomVariableFromDoubleArray(periodLength)), OperatorType.ACCRUE);
    }

    /**
     * Discounts this random variable. If {@code rate} reports a higher type priority the result
     * is built from the rate's own operations as {@code 1 / (rate * periodLength + 1) * this};
     * otherwise a DISCOUNT node with operands (this, rate, constant periodLength) is recorded.
     */
    @Override
    public RandomVariable discount(RandomVariable rate, double periodLength) {
        if (rate.getTypePriority() > this.getTypePriority()) {
            return rate.mult(periodLength).add(1.0).invert().mult(this);
        }
        return new RandomVariableDifferentiableAD(this.getValues().discount(rate.getValues(), periodLength), Arrays.asList(this, rate, new RandomVariableFromDoubleArray(periodLength)), OperatorType.DISCOUNT);
    }

    /**
     * Records a BARRIER node with operands (this, valueIfTriggerNonNegative,
     * valueIfTriggerNegative); this random variable acts as the trigger.
     */
    @Override
    public RandomVariable choose(RandomVariable valueIfTriggerNonNegative, RandomVariable valueIfTriggerNegative) {
        return new RandomVariableDifferentiableAD(this.getValues().choose(valueIfTriggerNonNegative.getValues(), valueIfTriggerNegative.getValues()), Arrays.asList(this, valueIfTriggerNonNegative, valueIfTriggerNegative), OperatorType.BARRIER);
    }

    /** Records an INVERT node with operand (this). */
    @Override
    public RandomVariable invert() {
        return new RandomVariableDifferentiableAD(this.getValues().invert(), Arrays.asList(this), OperatorType.INVERT);
    }

    /** Records an ABS node with operand (this). */
    @Override
    public RandomVariable abs() {
        return new RandomVariableDifferentiableAD(this.getValues().abs(), Arrays.asList(this), OperatorType.ABS);
    }

    /**
     * Computes {@code this + factor1 * factor2}. If {@code factor1} reports a higher type
     * priority the result is built from its own operations; otherwise an ADDPRODUCT node with
     * operands (this, factor1, constant factor2) is recorded.
     */
    @Override
    public RandomVariable addProduct(RandomVariable factor1, double factor2) {
        if (factor1.getTypePriority() > this.getTypePriority()) {
            return factor1.mult(factor2).add(this);
        }
        return new RandomVariableDifferentiableAD(this.getValues().addProduct(factor1.getValues(), factor2), Arrays.asList(this, factor1, new RandomVariableFromDoubleArray(factor2)), OperatorType.ADDPRODUCT);
    }

    /**
     * Computes {@code this + factor1 * factor2}. If either factor reports a higher type priority
     * the result is built as {@code factor1.mult(factor2).add(this)}; otherwise an ADDPRODUCT
     * node with operands (this, factor1, factor2) is recorded.
     */
    @Override
    public RandomVariable addProduct(RandomVariable factor1, RandomVariable factor2) {
        if (factor1.getTypePriority() > this.getTypePriority() || factor2.getTypePriority() > this.getTypePriority()) {
            return factor1.mult(factor2).add(this);
        }
        return new RandomVariableDifferentiableAD(this.getValues().addProduct(factor1.getValues(), factor2.getValues()), Arrays.asList(this, factor1, factor2), OperatorType.ADDPRODUCT);
    }

    /**
     * Computes {@code this + numerator / denominator}. If either operand reports a higher type
     * priority the result is built as {@code numerator.div(denominator).add(this)}; otherwise an
     * ADDRATIO node with operands (this, numerator, denominator) is recorded.
     */
    @Override
    public RandomVariable addRatio(RandomVariable numerator, RandomVariable denominator) {
        if (numerator.getTypePriority() > this.getTypePriority() || denominator.getTypePriority() > this.getTypePriority()) {
            return numerator.div(denominator).add(this);
        }
        return new RandomVariableDifferentiableAD(this.getValues().addRatio(numerator.getValues(), denominator.getValues()), Arrays.asList(this, numerator, denominator), OperatorType.ADDRATIO);
    }

    /**
     * Computes {@code this - numerator / denominator}. If either operand reports a higher type
     * priority the result is built as {@code numerator.div(denominator).mult(-1.0).add(this)};
     * otherwise a SUBRATIO node with operands (this, numerator, denominator) is recorded.
     */
    @Override
    public RandomVariable subRatio(RandomVariable numerator, RandomVariable denominator) {
        if (numerator.getTypePriority() > this.getTypePriority() || denominator.getTypePriority() > this.getTypePriority()) {
            return numerator.div(denominator).mult(-1.0).add(this);
        }
        return new RandomVariableDifferentiableAD(this.getValues().subRatio(numerator.getValues(), denominator.getValues()), Arrays.asList(this, numerator, denominator), OperatorType.SUBRATIO);
    }

    /** Delegates to the wrapped values; the result is not recorded in the operator tree. */
    @Override
    public RandomVariable isNaN() {
        return this.getValues().isNaN();
    }

    /** Delegates to the wrapped values. */
    @Override
    public IntToDoubleFunction getOperator() {
        return this.getValues().getOperator();
    }

    /** Delegates to the wrapped values. */
    @Override
    public DoubleStream getRealizationsStream() {
        return this.getValues().getRealizationsStream();
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public RandomVariable apply(DoubleUnaryOperator operator) {
        throw new UnsupportedOperationException("Applying functions is not supported.");
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public RandomVariable apply(DoubleBinaryOperator operator, RandomVariable argument) {
        throw new UnsupportedOperationException("Applying functions is not supported.");
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public RandomVariable apply(DoubleTernaryOperator operator, RandomVariable argument1, RandomVariable argument2) {
        throw new UnsupportedOperationException("Applying functions is not supported.");
    }

    /** Records a VARIANCE node with operand (this) holding {@link #getVariance()} as a constant. */
    public RandomVariable getVarianceAsRandomVariableAAD() {
        return new RandomVariableDifferentiableAD(new RandomVariableFromDoubleArray(this.getVariance()), Arrays.asList(this), OperatorType.VARIANCE);
    }

    /** Records an SVARIANCE node with operand (this) holding {@link #getSampleVariance()} as a constant. */
    public RandomVariable getSampleVarianceAsRandomVariableAAD() {
        return new RandomVariableDifferentiableAD(new RandomVariableFromDoubleArray(this.getSampleVariance()), Arrays.asList(this), OperatorType.SVARIANCE);
    }

    /** Records a STDEV node with operand (this) holding {@link #getStandardDeviation()} as a constant. */
    public RandomVariable getStandardDeviationAsRandomVariableAAD() {
        return new RandomVariableDifferentiableAD(new RandomVariableFromDoubleArray(this.getStandardDeviation()), Arrays.asList(this), OperatorType.STDEV);
    }

    /** Records a STDERROR node with operand (this) holding {@link #getStandardError()} as a constant. */
    public RandomVariable getStandardErrorAsRandomVariableAAD() {
        return new RandomVariableDifferentiableAD(new RandomVariableFromDoubleArray(this.getStandardError()), Arrays.asList(this), OperatorType.STDERROR);
    }

    /** Records a MIN node with operand (this) holding {@link #getMin()} as a constant. */
    public RandomVariable getMinAsRandomVariableAAD() {
        return new RandomVariableDifferentiableAD(new RandomVariableFromDoubleArray(this.getMin()), Arrays.asList(this), OperatorType.MIN);
    }

    /** Records a MAX node with operand (this) holding {@link #getMax()} as a constant. */
    public RandomVariable getMaxAsRandomVariableAAD() {
        return new RandomVariableDifferentiableAD(new RandomVariableFromDoubleArray(this.getMax()), Arrays.asList(this), OperatorType.MAX);
    }

    /**
     * Tangents are not provided by this class.
     *
     * <p>NOTE(review): this overload returns {@code null} whereas {@link #getTangents(Set)}
     * throws {@link UnsupportedOperationException}; callers must handle the {@code null}.
     *
     * @return always {@code null}
     */
    @Override
    public Map<Long, RandomVariable> getTangents() {
        return null;
    }

    /**
     * A node of the recorded operator tree: the operator, the argument nodes, and those argument
     * values that are required to evaluate the operator's partial derivatives.
     */
    private static class OperatorTreeNode {
        /** Unique id of this node, drawn from the enclosing class' counter at construction. */
        private final Long id = indexOfNextRandomVariable.getAndIncrement();
        private final OperatorType operatorType;
        /** Argument nodes; an entry is {@code null} if the operand was not a RandomVariableDifferentiableAD. */
        private final List<OperatorTreeNode> arguments;
        /** Plain argument values; {@code null} (list or entry) where pruned as not needed. */
        private final List<RandomVariable> argumentValues;
        /** Extra operator data; cast to ConditionalExpectationEstimator for CONDITIONAL_EXPECTATION. */
        private final Object operator;
        private static final RandomVariable zero = new Scalar(0.0);
        private static final RandomVariable one = new Scalar(1.0);
        private static final RandomVariable minusOne = new Scalar(-1.0);

        /**
         * Creates a node from the operands as passed to the arithmetic methods.
         *
         * @param operatorType the operator ({@code null} for a leaf)
         * @param arguments the operands ({@code null} for a leaf)
         * @param operator extra operator data (e.g. the conditional expectation estimator)
         */
        OperatorTreeNode(OperatorType operatorType, List<RandomVariable> arguments, Object operator) {
            this(operatorType, extractNodes(arguments), extractValues(arguments), operator);
        }

        /**
         * Creates a node from already separated argument nodes and argument values. Argument
         * values that no partial derivative of the operator will read are dropped.
         */
        OperatorTreeNode(OperatorType operatorType, List<OperatorTreeNode> arguments, List<RandomVariable> argumentValues, Object operator) {
            this.operatorType = operatorType;
            this.arguments = arguments;
            this.operator = operator;
            this.argumentValues = pruneArgumentValues(operatorType, arguments, argumentValues);
        }

        /** Maps each operand to its tree node, or to {@code null} if it is not differentiable. */
        private static List<OperatorTreeNode> extractNodes(List<RandomVariable> arguments) {
            if (arguments == null) {
                return null;
            }
            return arguments.stream()
                    .map(x -> x instanceof RandomVariableDifferentiableAD ? ((RandomVariableDifferentiableAD) x).getOperatorTreeNode() : null)
                    .collect(Collectors.toList());
        }

        /** Maps each operand to its plain values (unwrapping differentiable operands). */
        private static List<RandomVariable> extractValues(List<RandomVariable> arguments) {
            if (arguments == null) {
                return null;
            }
            return arguments.stream()
                    .map(x -> x instanceof RandomVariableDifferentiableAD ? ((RandomVariableDifferentiableAD) x).getValues() : x)
                    .collect(Collectors.toList());
        }

        /**
         * Drops argument values that cannot be read by {@link #getPartialDerivative}: a partial
         * derivative is only evaluated for arguments whose node is non-null, so values used
         * exclusively by derivatives with respect to null-node arguments are released.
         *
         * @return the (possibly modified in place) value list, or {@code null} if no value is needed
         */
        private static List<RandomVariable> pruneArgumentValues(OperatorType operatorType, List<OperatorTreeNode> arguments, List<RandomVariable> argumentValues) {
            if (operatorType == null) {
                return argumentValues;
            }
            switch (operatorType) {
                case ADD:
                case SUB:
                case AVERAGE: {
                    // Partial derivatives are constants; no argument value is required.
                    return null;
                }
                case MULT: {
                    // d/dX = Y, d/dY = X.
                    if (arguments.get(0) == null) {
                        argumentValues.set(1, null);
                    }
                    if (arguments.get(1) == null) {
                        argumentValues.set(0, null);
                    }
                    break;
                }
                case DIV: {
                    // d/dX = 1/Y needs only Y; X is needed only for d/dY.
                    if (arguments.get(1) == null) {
                        argumentValues.set(0, null);
                    }
                    break;
                }
                case ADDPRODUCT: {
                    // X + Y*Z: X is never needed; d/dY = Z, d/dZ = Y.
                    argumentValues.set(0, null);
                    if (arguments.get(1) == null) {
                        argumentValues.set(2, null);
                    }
                    if (arguments.get(2) == null) {
                        argumentValues.set(1, null);
                    }
                    break;
                }
                case ACCRUE: {
                    // X*(1+Y*Z): d/dX = Y*Z+1, d/dY = X*Z, d/dZ = X*Y.
                    // X is read by d/dY and d/dZ only.
                    if (arguments.get(1) == null && arguments.get(2) == null) {
                        argumentValues.set(0, null);
                    }
                    // Z is read by d/dX and d/dY only.
                    if (arguments.get(0) == null && arguments.get(1) == null) {
                        argumentValues.set(2, null);
                    }
                    // Y is read by d/dX and d/dZ only.
                    if (arguments.get(0) == null && arguments.get(2) == null) {
                        argumentValues.set(1, null);
                    }
                    break;
                }
                case BARRIER: {
                    // Y and Z are read only by the derivative with respect to the trigger X.
                    if (arguments.get(0) == null) {
                        argumentValues.set(1, null);
                        argumentValues.set(2, null);
                    }
                    break;
                }
                default: {
                    break;
                }
            }
            return argumentValues;
        }

        /**
         * Adds this node's contribution to the derivatives of its arguments: for every non-null
         * argument node, {@code derivatives[argument] += derivatives[this] * partialDerivative}.
         * For AVERAGE the incoming derivative is averaged first; for CONDITIONAL_EXPECTATION the
         * stored estimator is applied to it first.
         *
         * @param derivatives map from node id to accumulated derivative; must contain this node's id
         */
        private void propagateDerivativesFromResultToArgument(Map<Long, RandomVariable> derivatives) {
            if (this.arguments == null) {
                return;
            }
            for (int argumentIndex = 0; argumentIndex < this.arguments.size(); ++argumentIndex) {
                OperatorTreeNode argument = this.arguments.get(argumentIndex);
                if (argument == null) {
                    continue;
                }
                Long argumentID = argument.id;
                RandomVariable partialDerivative = this.getPartialDerivative(argument, argumentIndex);
                RandomVariable derivative = derivatives.get(this.id);
                RandomVariable argumentDerivative = derivatives.get(argumentID);
                if (this.operatorType == OperatorType.AVERAGE) {
                    derivative = derivative.average();
                }
                if (this.operatorType == OperatorType.CONDITIONAL_EXPECTATION) {
                    ConditionalExpectationEstimator estimator = (ConditionalExpectationEstimator) this.operator;
                    derivative = estimator.getConditionalExpectation(derivative);
                }
                argumentDerivative = argumentDerivative == null ? derivative.mult(partialDerivative) : argumentDerivative.addProduct(partialDerivative, derivative);
                derivatives.put(argumentID, argumentDerivative);
            }
        }

        /**
         * Returns the partial derivative of this node's operator with respect to the argument at
         * {@code differentialIndex}. X, Y, Z denote the values of the first, second and third
         * argument (or {@code null} where absent or pruned).
         *
         * @param differential the argument node to differentiate with respect to
         * @param differentialIndex the position of that argument in the operand list
         * @return the partial derivative, or zero if {@code differential} is not an argument
         * @throws IllegalArgumentException if the operator has no derivative rule
         */
        private RandomVariable getPartialDerivative(OperatorTreeNode differential, int differentialIndex) {
            if (!this.arguments.contains(differential)) {
                return zero;
            }
            RandomVariable X = this.arguments.size() > 0 && this.argumentValues != null ? this.argumentValues.get(0) : null;
            RandomVariable Y = this.arguments.size() > 1 && this.argumentValues != null ? this.argumentValues.get(1) : null;
            RandomVariable Z = this.arguments.size() > 2 && this.argumentValues != null ? this.argumentValues.get(2) : null;
            RandomVariable derivative = null;
            switch (this.operatorType) {
                case SQUARED: {
                    derivative = X.mult(2.0);
                    break;
                }
                case SQRT: {
                    derivative = X.sqrt().invert().mult(0.5);
                    break;
                }
                case EXP: {
                    derivative = X.exp();
                    break;
                }
                case LOG: {
                    derivative = X.invert();
                    break;
                }
                case SIN: {
                    derivative = X.cos();
                    break;
                }
                case COS: {
                    derivative = X.sin().mult(-1.0);
                    break;
                }
                case INVERT: {
                    derivative = X.invert().squared().mult(-1.0);
                    break;
                }
                case AVERAGE: {
                    derivative = one;
                    break;
                }
                case CONDITIONAL_EXPECTATION: {
                    derivative = one;
                    break;
                }
                case VARIANCE: {
                    derivative = X.sub(X.getAverage() * (2.0 * (double)X.size() - 1.0) / (double)X.size()).mult(2.0 / (double)X.size());
                    break;
                }
                case STDEV: {
                    derivative = X.sub(X.getAverage() * (2.0 * (double)X.size() - 1.0) / (double)X.size()).mult(2.0 / (double)X.size()).mult(0.5).div(Math.sqrt(X.getVariance()));
                    break;
                }
                case MIN: {
                    // Indicator of the realizations attaining the minimum.
                    final double min = X.getMin();
                    derivative = X.apply(realization -> realization == min ? 1.0 : 0.0);
                    break;
                }
                case MAX: {
                    // Indicator of the realizations attaining the maximum.
                    final double max = X.getMax();
                    derivative = X.apply(realization -> realization == max ? 1.0 : 0.0);
                    break;
                }
                case ABS: {
                    derivative = X.choose(one, minusOne);
                    break;
                }
                case STDERROR: {
                    derivative = X.sub(X.getAverage() * (2.0 * (double)X.size() - 1.0) / (double)X.size()).mult(2.0 / (double)X.size()).mult(0.5).div(Math.sqrt(X.getVariance() * (double)X.size()));
                    break;
                }
                case SVARIANCE: {
                    derivative = X.sub(X.getAverage() * (2.0 * (double)X.size() - 1.0) / (double)X.size()).mult(2.0 / (double)(X.size() - 1));
                    break;
                }
                case ADD: {
                    derivative = one;
                    break;
                }
                case SUB: {
                    derivative = differentialIndex == 0 ? one : minusOne;
                    break;
                }
                case MULT: {
                    derivative = differentialIndex == 0 ? Y : X;
                    break;
                }
                case DIV: {
                    derivative = differentialIndex == 0 ? Y.invert() : X.div(Y.squared()).mult(-1.0);
                    break;
                }
                case CAP: {
                    if (differentialIndex == 0) {
                        derivative = X.sub(Y).choose(zero, one);
                        break;
                    }
                    derivative = X.sub(Y).choose(one, zero);
                    break;
                }
                case FLOOR: {
                    if (differentialIndex == 0) {
                        derivative = X.sub(Y).choose(one, zero);
                        break;
                    }
                    derivative = X.sub(Y).choose(zero, one);
                    break;
                }
                case AVERAGE2: {
                    derivative = differentialIndex == 0 ? Y : X;
                    break;
                }
                case VARIANCE2: {
                    derivative = differentialIndex == 0 ? Y.mult(2.0).mult(X.mult(Y.add(X.getAverage(Y) * (double)(X.size() - 1)).sub(X.getAverage(Y)))) : X.mult(2.0).mult(Y.mult(X.add(Y.getAverage(X) * (double)(X.size() - 1)).sub(Y.getAverage(X))));
                    break;
                }
                case STDEV2: {
                    derivative = differentialIndex == 0 ? Y.mult(2.0).mult(X.mult(Y.add(X.getAverage(Y) * (double)(X.size() - 1)).sub(X.getAverage(Y)))).div(Math.sqrt(X.getVariance(Y))) : X.mult(2.0).mult(Y.mult(X.add(Y.getAverage(X) * (double)(X.size() - 1)).sub(Y.getAverage(X)))).div(Math.sqrt(Y.getVariance(X)));
                    break;
                }
                case STDERROR2: {
                    derivative = differentialIndex == 0 ? Y.mult(2.0).mult(X.mult(Y.add(X.getAverage(Y) * (double)(X.size() - 1)).sub(X.getAverage(Y)))).div(Math.sqrt(X.getVariance(Y) * (double)X.size())) : X.mult(2.0).mult(Y.mult(X.add(Y.getAverage(X) * (double)(X.size() - 1)).sub(Y.getAverage(X)))).div(Math.sqrt(Y.getVariance(X) * (double)Y.size()));
                    break;
                }
                case POW: {
                    // The exponent Y is read through its average; the derivative with respect to
                    // the exponent is taken as zero.
                    derivative = differentialIndex == 0 ? X.pow(Y.getAverage() - 1.0).mult(Y) : zero;
                    break;
                }
                case ADDPRODUCT: {
                    if (differentialIndex == 0) {
                        derivative = one;
                        break;
                    }
                    if (differentialIndex == 1) {
                        derivative = Z;
                        break;
                    }
                    derivative = Y;
                    break;
                }
                case ADDRATIO: {
                    if (differentialIndex == 0) {
                        derivative = one;
                        break;
                    }
                    if (differentialIndex == 1) {
                        derivative = Z.invert();
                        break;
                    }
                    derivative = Y.div(Z.squared()).mult(-1.0);
                    break;
                }
                case SUBRATIO: {
                    if (differentialIndex == 0) {
                        derivative = one;
                        break;
                    }
                    if (differentialIndex == 1) {
                        derivative = Z.invert().mult(-1.0);
                        break;
                    }
                    derivative = Y.div(Z.squared());
                    break;
                }
                case ACCRUE: {
                    if (differentialIndex == 0) {
                        derivative = Y.mult(Z).add(1.0);
                        break;
                    }
                    if (differentialIndex == 1) {
                        derivative = X.mult(Z);
                        break;
                    }
                    derivative = X.mult(Y);
                    break;
                }
                case DISCOUNT: {
                    if (differentialIndex == 0) {
                        derivative = Y.mult(Z).add(1.0).invert();
                        break;
                    }
                    if (differentialIndex == 1) {
                        derivative = X.mult(Z).div(Y.mult(Z).add(1.0).squared()).mult(-1.0);
                        break;
                    }
                    derivative = X.mult(Y).div(Y.mult(Z).add(1.0).squared()).mult(-1.0);
                    break;
                }
                case BARRIER: {
                    if (differentialIndex == 0) {
                        // Derivative with respect to the trigger: (Y - Z) times an indicator of
                        // |X| within epsilon/2 of zero, divided by epsilon, where epsilon is
                        // 0.2 standard deviations of X.
                        derivative = Y.sub(Z);
                        double epsilon = 0.2 * X.getStandardDeviation();
                        derivative = derivative.mult(X.add(epsilon / 2.0).choose(new RandomVariableFromDoubleArray(1.0), new RandomVariableFromDoubleArray(0.0)));
                        derivative = derivative.mult(X.sub(epsilon / 2.0).choose(new RandomVariableFromDoubleArray(0.0), new RandomVariableFromDoubleArray(1.0)));
                        derivative = derivative.div(epsilon);
                        break;
                    }
                    if (differentialIndex == 1) {
                        derivative = X.choose(one, zero);
                        break;
                    }
                    derivative = X.choose(zero, one);
                    break;
                }
                default: {
                    throw new IllegalArgumentException("Operation " + this.operatorType.name() + " not supported in differentiation.");
                }
            }
            return derivative;
        }
    }

    /** The operators that can be recorded in the operator tree. */
    private enum OperatorType {
        ADD,
        MULT,
        DIV,
        SUB,
        SQUARED,
        SQRT,
        LOG,
        SIN,
        COS,
        EXP,
        INVERT,
        CAP,
        FLOOR,
        ABS,
        ADDPRODUCT,
        ADDRATIO,
        SUBRATIO,
        BARRIER,
        DISCOUNT,
        ACCRUE,
        POW,
        MIN,
        MAX,
        AVERAGE,
        VARIANCE,
        STDEV,
        STDERROR,
        SVARIANCE,
        AVERAGE2,
        VARIANCE2,
        STDEV2,
        STDERROR2,
        CONDITIONAL_EXPECTATION;

    }
}

