package net.sourceforge.cilib.problem.nn;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import net.sourceforge.cilib.nn.architecture.Layer;
import net.sourceforge.cilib.nn.architecture.visitors.CascadeVisitor;
import net.sourceforge.cilib.nn.architecture.visitors.OutputErrorVisitor;
import net.sourceforge.cilib.nn.components.Neuron;
import net.sourceforge.cilib.problem.objective.Maximise;
import net.sourceforge.cilib.problem.solution.MaximisationFitness;
import net.sourceforge.cilib.type.StringBasedDomainRegistry;
import net.sourceforge.cilib.type.types.Type;
import net.sourceforge.cilib.type.types.container.Vector;

/* loaded from: input_file:net/sourceforge/cilib/problem/nn/CascadeHiddenNeuronCorrelationProblem.class */
/**
 * Training problem for a single candidate hidden neuron in a cascade network.
 *
 * <p>A solution is the candidate neuron's weight vector. Its fitness, which is maximised, is
 * <pre>
 *   S = sum_o | sum_p (V_p - mean(V)) * (E_{p,o} - mean(E_o)) |
 * </pre>
 * where {@code V_p} is the candidate's activation on cached pattern {@code p} and {@code E_{p,o}}
 * is element {@code o} of the cached {@link OutputErrorVisitor} output for that pattern.
 *
 * <p>{@link #initialise()} must be called first. It builds the per-pattern caches from the current
 * training set and network, so each fitness evaluation avoids re-running the network.
 */
public class CascadeHiddenNeuronCorrelationProblem extends NNTrainingProblem {
    /** Candidate neuron; its weights are replaced by the solution vector on every evaluation. */
    private Neuron neuron;
    /**
     * One entry per training pattern. Each entry is a clone of a {@link Layer} holding every
     * neuron of every network layer, captured right after a {@link CascadeVisitor} pass.
     * NOTE(review): presumably the clone preserves each neuron's activation from that pass;
     * confirm against Layer/Neuron cloning semantics.
     */
    private ArrayList<Layer> activationCache;
    /** One entry per training pattern: the output of {@link OutputErrorVisitor} for that pattern. */
    private ArrayList<Vector> errorCache;
    /** Element-wise mean of {@link #errorCache} across all cached patterns (empty if none). */
    private Vector errorMeans;
    /** Running total of solution-vector elements passed to {@link #calculateFitness(Type)}. */
    private int weightEvaluationCount;

    /** Creates a maximisation problem with a fresh candidate neuron and empty caches. */
    public CascadeHiddenNeuronCorrelationProblem() {
        this.objective = new Maximise();
        this.neuron = new Neuron();
        this.activationCache = new ArrayList<>();
        this.errorCache = new ArrayList<>();
        this.errorMeans = Vector.of();
        this.weightEvaluationCount = 0;
    }

    /**
     * Copy constructor.
     *
     * <p>The neuron, the error means and every cache entry are cloned rather than shared.
     *
     * @param copy the instance to copy
     */
    public CascadeHiddenNeuronCorrelationProblem(CascadeHiddenNeuronCorrelationProblem copy) {
        super(copy);
        this.objective = new Maximise();
        this.neuron = copy.neuron.getClone();
        this.errorMeans = copy.errorMeans.getClone();
        this.weightEvaluationCount = copy.weightEvaluationCount;

        this.activationCache = new ArrayList<>(copy.activationCache.size());
        for (Layer cachedLayer : copy.activationCache) {
            this.activationCache.add(cachedLayer.getClone());
        }

        this.errorCache = new ArrayList<>(copy.errorCache.size());
        for (Vector cachedError : copy.errorCache) {
            this.errorCache.add(cachedError.getClone());
        }
    }

    /** {@inheritDoc} */
    @Override
    public CascadeHiddenNeuronCorrelationProblem getClone() {
        return new CascadeHiddenNeuronCorrelationProblem(this);
    }

    /**
     * Rebuilds the activation and error caches, then sets the solution domain.
     *
     * <p>The domain has one dimension per neuron in every architecture layer except the final
     * one. Its element domain is the layer builder's domain string.
     */
    @Override
    public void initialise() {
        generateCache();

        List<Layer> layers = this.neuralNetwork.getArchitecture().getLayers();
        // Count neurons in all layers except the last. The candidate gets one weight per neuron.
        int weightCount = 0;
        for (int layerIndex = 0; layerIndex < layers.size() - 1; layerIndex++) {
            weightCount += layers.get(layerIndex).size();
        }

        String domain = this.neuralNetwork.getArchitecture().getArchitectureBuilder().getLayerBuilder().getDomain();
        this.domainRegistry = new StringBasedDomainRegistry();
        this.domainRegistry.setDomainString(domain + "^" + weightCount);
    }

    /**
     * Scores a candidate weight vector by the summed absolute covariance between the candidate's
     * activation and each cached error component.
     *
     * <p>Pattern counts come from the caches built by {@link #initialise()}, not the live training
     * set. This keeps the activation, error and mean data consistent with each other. If nothing
     * is cached, there is nothing to correlate and the fitness is {@code 0.0}.
     *
     * <p>Declared {@code public} here; per the decompiler note, the overridden method was
     * originally {@code protected}.
     *
     * @param type the candidate weights; must be a {@link Vector}
     * @return the correlation score wrapped as a maximisation fitness
     */
    @Override
    public MaximisationFitness calculateFitness(Type type) {
        Vector weights = (Vector) type;
        this.weightEvaluationCount += weights.size();
        this.neuron.setWeights(weights);

        int patternCount = this.activationCache.size();
        if (patternCount == 0) {
            // Previously this indexed an empty array and threw ArrayIndexOutOfBoundsException.
            return new MaximisationFitness(Double.valueOf(0.0d));
        }

        // Candidate activation for each cached pattern, summed in the same pass for the mean.
        double[] activations = new double[patternCount];
        double activationSum = 0.0d;
        for (int pattern = 0; pattern < patternCount; pattern++) {
            activations[pattern] = this.neuron.calculateActivation(this.activationCache.get(pattern));
            activationSum += activations[pattern];
        }
        double activationMean = activationSum / patternCount;

        double correlation = 0.0d;
        for (int output = 0; output < this.errorMeans.size(); output++) {
            double errorMean = this.errorMeans.get(output).doubleValue();
            double covariance = 0.0d;
            for (int pattern = 0; pattern < patternCount; pattern++) {
                double errorDeviation = this.errorCache.get(pattern).get(output).doubleValue() - errorMean;
                covariance += (activations[pattern] - activationMean) * errorDeviation;
            }
            // The sign of the correlation is irrelevant; only its magnitude is rewarded.
            correlation += Math.abs(covariance);
        }
        return new MaximisationFitness(Double.valueOf(correlation));
    }

    /**
     * Runs every training pattern through the network and fills the activation and error caches.
     * It then computes {@link #errorMeans}.
     *
     * <p>For an empty training set the caches stay empty and {@link #errorMeans} becomes an empty
     * vector (previously this threw on {@code errorCache.get(0)}).
     */
    private void generateCache() {
        this.activationCache.clear();
        this.errorCache.clear();

        CascadeVisitor cascadeVisitor = new CascadeVisitor();
        OutputErrorVisitor outputErrorVisitor = new OutputErrorVisitor();
        int patternCount = this.trainingSet.size();
        for (int pattern = 0; pattern < patternCount; pattern++) {
            cascadeVisitor.setInput(this.trainingSet.getRow(pattern));
            cascadeVisitor.visit(this.neuralNetwork.getArchitecture());

            // Flatten every neuron of every layer into one Layer, then store a clone of it.
            Layer snapshot = new Layer();
            for (Layer layer : this.neuralNetwork.getArchitecture().getLayers()) {
                for (int neuronIndex = 0; neuronIndex < layer.size(); neuronIndex++) {
                    snapshot.add(layer.getNeuron(neuronIndex));
                }
            }
            this.activationCache.add(snapshot.getClone());

            outputErrorVisitor.setInput(this.trainingSet.getRow(pattern));
            outputErrorVisitor.visit(this.neuralNetwork.getArchitecture());
            this.errorCache.add(outputErrorVisitor.getOutput());
        }

        if (this.errorCache.isEmpty()) {
            this.errorMeans = Vector.of();
            return;
        }
        Vector errorSum = Vector.copyOf(this.errorCache.get(0));
        for (int pattern = 1; pattern < this.errorCache.size(); pattern++) {
            errorSum = errorSum.plus(this.errorCache.get(pattern));
        }
        this.errorMeans = errorSum.divide(this.errorCache.size());
    }

    /** @return the live (not copied) per-pattern activation cache */
    public ArrayList<Layer> getActivationCache() {
        return this.activationCache;
    }

    /** @return the live (not copied) per-pattern error cache */
    public ArrayList<Vector> getErrorCache() {
        return this.errorCache;
    }

    /** @return the element-wise mean of the error cache */
    public Vector getErrorMeans() {
        return this.errorMeans;
    }

    /**
     * Replaces the candidate neuron; it is stored by reference, not cloned.
     *
     * @param neuron the neuron whose weights will be optimised
     */
    public void setNeuron(Neuron neuron) {
        this.neuron = neuron;
    }

    /** @return total number of weight-vector elements evaluated so far */
    public int getWeightEvaluationCount() {
        return this.weightEvaluationCount;
    }
}
