package net.sourceforge.cilib.cascadecorrelationalgorithm;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import net.sourceforge.cilib.algorithm.AbstractAlgorithm;
import net.sourceforge.cilib.nn.NeuralNetwork;
import net.sourceforge.cilib.nn.architecture.builder.CascadeArchitectureBuilder;
import net.sourceforge.cilib.nn.architecture.builder.LayerConfiguration;
import net.sourceforge.cilib.nn.architecture.visitors.CascadeVisitor;
import net.sourceforge.cilib.nn.components.Neuron;
import net.sourceforge.cilib.problem.Problem;
import net.sourceforge.cilib.problem.nn.CascadeHiddenNeuronCorrelationProblem;
import net.sourceforge.cilib.problem.nn.CascadeOutputLayerTrainingProblem;
import net.sourceforge.cilib.problem.nn.NNTrainingProblem;
import net.sourceforge.cilib.problem.solution.Fitness;
import net.sourceforge.cilib.problem.solution.InferiorFitness;
import net.sourceforge.cilib.problem.solution.OptimisationSolution;
import net.sourceforge.cilib.type.types.Numeric;
import net.sourceforge.cilib.type.types.Real;
import net.sourceforge.cilib.type.types.container.Vector;

/* loaded from: input_file:net/sourceforge/cilib/cascadecorrelationalgorithm/CascadeCorrelationAlgorithm.class */
/**
 * Cascade-correlation style training for an {@link NNTrainingProblem}.
 *
 * <p>Each iteration runs two phases on the problem's neural network:
 * <ol>
 *   <li>Phase 1 (skipped on the very first iteration): a fresh clone of {@code phase1Algorithm}
 *       optimises {@link CascadeHiddenNeuronCorrelationProblem}. Every solution it reports becomes
 *       one neuron in a new layer, inserted directly before the output layer.</li>
 *   <li>Phase 2: a fresh clone of {@code phase2Algorithm} optimises
 *       {@link CascadeOutputLayerTrainingProblem}. Its best position is copied into the tail of
 *       the network's weight vector.</li>
 * </ol>
 *
 * <p>The problem must use a {@link CascadeArchitectureBuilder} and a {@link CascadeVisitor} (see
 * {@link #setOptimisationProblem(Problem)}). Both phase algorithms must be set before the
 * algorithm runs.
 */
public class CascadeCorrelationAlgorithm extends AbstractAlgorithm {
    /** Prototype algorithm cloned for every phase-1 (candidate neuron) run; may be unset. */
    private AbstractAlgorithm phase1Algorithm;
    /** Prototype algorithm cloned for every phase-2 (output weights) run; may be unset. */
    private AbstractAlgorithm phase2Algorithm;
    private CascadeHiddenNeuronCorrelationProblem phase1Problem;
    private CascadeOutputLayerTrainingProblem phase2Problem;
    /** Fitness of the most recent phase-2 best solution; starts as the inferior fitness. */
    private Fitness trackedFitness;
    /** Neuron handed to the phase-1 problem; its activation function is used for new layers. */
    private Neuron neuronPrototype;

    /** Creates an algorithm with a default neuron prototype and no phase algorithms set. */
    public CascadeCorrelationAlgorithm() {
        this.neuronPrototype = new Neuron();
        this.trackedFitness = InferiorFitness.instance();
        this.phase1Problem = new CascadeHiddenNeuronCorrelationProblem();
        this.phase2Problem = new CascadeOutputLayerTrainingProblem();
    }

    /**
     * Copy constructor: clones every component of {@code cascadeCorrelationAlgorithm}.
     * Phase algorithms that have not been set are copied as {@code null}, not dereferenced.
     *
     * @param cascadeCorrelationAlgorithm the instance to copy
     */
    public CascadeCorrelationAlgorithm(CascadeCorrelationAlgorithm cascadeCorrelationAlgorithm) {
        this.neuronPrototype = cascadeCorrelationAlgorithm.neuronPrototype.getClone();
        this.trackedFitness = cascadeCorrelationAlgorithm.trackedFitness.getClone();
        this.phase1Algorithm = cascadeCorrelationAlgorithm.phase1Algorithm == null
                ? null : cascadeCorrelationAlgorithm.phase1Algorithm.getClone();
        this.phase2Algorithm = cascadeCorrelationAlgorithm.phase2Algorithm == null
                ? null : cascadeCorrelationAlgorithm.phase2Algorithm.getClone();
        this.phase1Problem = cascadeCorrelationAlgorithm.phase1Problem.getClone();
        this.phase2Problem = cascadeCorrelationAlgorithm.phase2Problem.getClone();
    }

    @Override // net.sourceforge.cilib.algorithm.AbstractAlgorithm, net.sourceforge.cilib.util.Cloneable
    public CascadeCorrelationAlgorithm getClone() {
        return new CascadeCorrelationAlgorithm(this);
    }

    /**
     * Initialises the training problem and shares its network and data sets with both phase
     * problems. The phase-1 problem additionally receives the neuron prototype.
     */
    @Override // net.sourceforge.cilib.algorithm.AbstractAlgorithm
    public void algorithmInitialisation() {
        NNTrainingProblem trainingProblem = (NNTrainingProblem) this.optimisationProblem;
        trainingProblem.initialise();
        NeuralNetwork neuralNetwork = trainingProblem.getNeuralNetwork();

        this.phase1Problem.setNeuron(this.neuronPrototype);
        this.phase1Problem.setTrainingSet(trainingProblem.getTrainingSet());
        this.phase1Problem.setValidationSet(trainingProblem.getValidationSet());
        this.phase1Problem.setGeneralisationSet(trainingProblem.getGeneralisationSet());
        this.phase1Problem.setNeuralNetwork(neuralNetwork);

        this.phase2Problem.setTrainingSet(trainingProblem.getTrainingSet());
        this.phase2Problem.setValidationSet(trainingProblem.getValidationSet());
        this.phase2Problem.setGeneralisationSet(trainingProblem.getGeneralisationSet());
        this.phase2Problem.setNeuralNetwork(neuralNetwork);
    }

    /**
     * One iteration. Phase 1 is skipped on the first iteration, so the initial network only has
     * its output weights trained. Afterwards, every iteration adds a layer and then retrains the
     * output weights.
     */
    @Override // net.sourceforge.cilib.algorithm.AbstractAlgorithm
    protected void algorithmIteration() {
        if (getIterations() > 0) {
            phase1();
        }
        phase2();
    }

    /**
     * Trains candidate neurons with a clone of the phase-1 algorithm and adds them to the network.
     *
     * <p>The new layer holds one neuron per solution reported by the phase-1 algorithm and is
     * inserted just before the output layer. Each candidate's position supplies that neuron's
     * incoming weights. Every output neuron receives one NaN placeholder weight per new neuron;
     * these placeholders are presumably overwritten by {@link #phase2()}, which runs immediately
     * afterwards in {@link #algorithmIteration()}.
     *
     * @throws IllegalStateException if no phase-1 algorithm has been set
     */
    @VisibleForTesting
    protected void phase1() {
        Preconditions.checkState(this.phase1Algorithm != null, "Phase 1 algorithm has not been set.");
        NeuralNetwork neuralNetwork = ((NNTrainingProblem) this.optimisationProblem).getNeuralNetwork();
        Vector weights = neuralNetwork.getWeights();

        AbstractAlgorithm candidateTrainer = this.phase1Algorithm.getClone();
        this.phase1Problem.initialise();
        candidateTrainer.setOptimisationProblem(this.phase1Problem);
        candidateTrainer.performInitialisation();
        candidateTrainer.runAlgorithm();
        List<OptimisationSolution> candidates =
                Lists.<OptimisationSolution>newArrayList(candidateTrainer.getSolutions());
        int candidateCount = candidates.size();

        List<LayerConfiguration> layerConfigurations =
                neuralNetwork.getArchitecture().getArchitectureBuilder().getLayerConfigurations();
        int outputLayerIndex = layerConfigurations.size() - 1;

        // fanIn: count of neurons plus one per bias across all non-output layers, i.e. the number
        // of inputs each neuron of the new layer receives. weightIndex: sum over non-output layers
        // of (fanIn before that layer * layer size). Assuming cascade connectivity (every neuron
        // sees all earlier neurons and biases), this is the index of the first output weight.
        int fanIn = 0;
        int weightIndex = 0;
        for (int layer = 0; layer < outputLayerIndex; layer++) {
            LayerConfiguration configuration = layerConfigurations.get(layer);
            weightIndex += fanIn * configuration.getSize();
            fanIn += configuration.getSize();
            if (configuration.isBias()) {
                fanIn++;
            }
        }

        // Splice each candidate's incoming weights in, in order, ahead of the output weights.
        for (OptimisationSolution candidate : candidates) {
            Iterator<Numeric> candidateWeights = ((Vector) candidate.getPosition()).iterator();
            while (candidateWeights.hasNext()) {
                weights.insert(weightIndex++, candidateWeights.next());
            }
        }

        // Each output neuron's block now starts with its existing fanIn weights; append
        // candidateCount placeholders after them, then step over the enlarged block.
        int outputWeightIndex = weightIndex + fanIn;
        int outputCount = layerConfigurations.get(outputLayerIndex).getSize();
        for (int output = 0; output < outputCount; output++) {
            for (int k = 0; k < candidateCount; k++) {
                weights.insert(outputWeightIndex, Real.valueOf(Double.NaN));
            }
            outputWeightIndex += candidateCount + fanIn;
        }

        neuralNetwork.getArchitecture().getArchitectureBuilder().addLayer(outputLayerIndex,
                new LayerConfiguration(candidateCount, this.neuronPrototype.getActivationFunction(), false));
        neuralNetwork.initialise();
        neuralNetwork.setWeights(weights);
    }

    /**
     * Trains the trailing (output-layer) weights with a clone of the phase-2 algorithm.
     *
     * <p>The best position found is copied over the last {@code position.size()} entries of the
     * network's weight vector. Its fitness becomes the tracked fitness.
     *
     * @throws IllegalStateException if no phase-2 algorithm has been set, or if the best position
     *         is longer than the network's weight vector
     */
    @VisibleForTesting
    protected void phase2() {
        Preconditions.checkState(this.phase2Algorithm != null, "Phase 2 algorithm has not been set.");
        NeuralNetwork neuralNetwork = ((NNTrainingProblem) this.optimisationProblem).getNeuralNetwork();
        Vector weights = neuralNetwork.getWeights();

        AbstractAlgorithm outputTrainer = this.phase2Algorithm.getClone();
        this.phase2Problem.initialise();
        outputTrainer.setOptimisationProblem(this.phase2Problem);
        outputTrainer.performInitialisation();
        outputTrainer.runAlgorithm();

        OptimisationSolution bestSolution = outputTrainer.getBestSolution();
        this.trackedFitness = bestSolution.getFitness();
        Vector trainedWeights = (Vector) bestSolution.getPosition();

        // Align the trained weights with the tail of the network weight vector.
        int offset = weights.size() - trainedWeights.size();
        Preconditions.checkState(offset >= 0,
                "Phase 2 solution has %s weights but the network only has %s.",
                trainedWeights.size(), weights.size());
        for (int k = 0; k < trainedWeights.size(); k++) {
            weights.set(offset + k, trainedWeights.get(k));
        }
        neuralNetwork.setWeights(weights);
    }

    /**
     * Returns the network's current weights paired with the fitness of the latest phase-2 best
     * solution (or the inferior fitness if phase 2 has not yet run).
     */
    @Override // net.sourceforge.cilib.algorithm.AbstractAlgorithm, net.sourceforge.cilib.algorithm.Algorithm
    public OptimisationSolution getBestSolution() {
        return new OptimisationSolution(
                ((NNTrainingProblem) this.optimisationProblem).getNeuralNetwork().getWeights(), this.trackedFitness);
    }

    /** Returns a new single-element list containing {@link #getBestSolution()}. */
    @Override // net.sourceforge.cilib.algorithm.AbstractAlgorithm, net.sourceforge.cilib.algorithm.Algorithm
    public Iterable<OptimisationSolution> getSolutions() {
        List<OptimisationSolution> solutions = new ArrayList<OptimisationSolution>(1);
        solutions.add(getBestSolution());
        return solutions;
    }

    /**
     * Sets the problem to train on.
     *
     * @param problem an {@link NNTrainingProblem} whose network uses a
     *        {@link CascadeArchitectureBuilder} and a {@link CascadeVisitor}
     * @throws IllegalArgumentException if any of those requirements is not met
     */
    @Override // net.sourceforge.cilib.algorithm.AbstractAlgorithm, net.sourceforge.cilib.algorithm.Algorithm
    public void setOptimisationProblem(Problem problem) {
        Preconditions.checkArgument(problem instanceof NNTrainingProblem, "CascadeCorrelationAlgorithm can only be used with NNTrainingProblem.");
        NeuralNetwork neuralNetwork = ((NNTrainingProblem) problem).getNeuralNetwork();
        Preconditions.checkArgument(neuralNetwork.getArchitecture().getArchitectureBuilder() instanceof CascadeArchitectureBuilder, "Cascade architecture is needed.");
        Preconditions.checkArgument(neuralNetwork.getOperationVisitor() instanceof CascadeVisitor, "CascadeVisitor is needed.");
        this.optimisationProblem = problem;
    }

    /** Sets the prototype algorithm cloned for each phase-1 run. */
    public void setPhase1Algorithm(AbstractAlgorithm abstractAlgorithm) {
        this.phase1Algorithm = abstractAlgorithm;
    }

    /** Sets the prototype algorithm cloned for each phase-2 run. */
    public void setPhase2Algorithm(AbstractAlgorithm abstractAlgorithm) {
        this.phase2Algorithm = abstractAlgorithm;
    }

    /** Returns the phase-1 problem's fitness evaluation count. */
    public int getPhase1EvaluationCount() {
        return this.phase1Problem.getFitnessEvaluations();
    }

    /** Returns the phase-2 problem's fitness evaluation count. */
    public int getPhase2EvaluationCount() {
        return this.phase2Problem.getFitnessEvaluations();
    }

    /** Returns the phase-1 problem's weight evaluation count. */
    public int getPhase1WeightEvaluationCount() {
        return this.phase1Problem.getWeightEvaluationCount();
    }

    /** Returns the phase-2 problem's weight evaluation count. */
    public int getPhase2WeightEvaluationCount() {
        return this.phase2Problem.getWeightEvaluationCount();
    }
}
