package net.sourceforge.cilib.measurement.single;

import java.util.Iterator;
import java.util.List;
import net.sourceforge.cilib.algorithm.Algorithm;
import net.sourceforge.cilib.measurement.Measurement;
import net.sourceforge.cilib.nn.architecture.Layer;
import net.sourceforge.cilib.problem.nn.NNTrainingProblem;
import net.sourceforge.cilib.type.types.Numeric;
import net.sourceforge.cilib.type.types.Real;
import net.sourceforge.cilib.type.types.container.Vector;

/* loaded from: input_file:net/sourceforge/cilib/measurement/single/CascadeNetworkWeightShift.class */
/**
 * Measures how much the weights of the older hidden units of a cascade network have moved
 * since the network was last expanded.
 *
 * <p>On every call the best solution's position is read as a weight vector. This class assumes
 * the vector holds the hidden-unit weights first, followed by
 * {@code size(output layer) * (neurons in all preceding layers)} output weights at the end.
 * The output weights are ignored.
 *
 * <p>The remaining hidden-weight prefix is remembered on every call. When that prefix is longer
 * than on the previous call, the previous prefix becomes the "before last expansion" snapshot.
 * The reported value is the sum of squared differences between that snapshot and the current
 * hidden weights, excluding the incoming weights of the newest (second-to-last) layer.
 *
 * <p>Instances are stateful: the result depends on the sequence of earlier calls.
 */
public class CascadeNetworkWeightShift implements Measurement {
    /** Hidden weights recorded on the last call before the most recent expansion; null until an expansion is observed. */
    private double[] previousExpansion;
    /** Hidden weights recorded on the previous call; null before the first call. */
    private double[] previousIteration;

    /** Creates a measurement with no recorded history. */
    public CascadeNetworkWeightShift() {
        this.previousExpansion = null;
        this.previousIteration = null;
    }

    /**
     * Copy constructor.
     *
     * <p>Recorded history is deep-copied. History that has not been recorded yet (null) stays
     * unset, so copying a fresh instance is safe.
     *
     * @param cascadeNetworkWeightShift the instance to copy
     */
    public CascadeNetworkWeightShift(CascadeNetworkWeightShift cascadeNetworkWeightShift) {
        this.previousExpansion = copyOrNull(cascadeNetworkWeightShift.previousExpansion);
        this.previousIteration = copyOrNull(cascadeNetworkWeightShift.previousIteration);
    }

    @Override // net.sourceforge.cilib.util.Cloneable
    public CascadeNetworkWeightShift getClone() {
        return new CascadeNetworkWeightShift(this);
    }

    /**
     * Computes the weight shift of the older hidden units relative to the last expansion.
     *
     * @param algorithm the algorithm whose best solution holds the network weights; its
     *        optimisation problem must be an {@link NNTrainingProblem}
     * @return 0 on the first call, when no expansion has been observed yet, or when there are
     *         no older hidden weights to compare; otherwise the sum of squared differences
     * @throws IllegalStateException if the position is shorter than the layer sizes imply, or
     *         if the recorded snapshot does not line up with the current older hidden weights
     *         (for example, after more than one expansion between two calls)
     */
    @Override // net.sourceforge.cilib.measurement.Measurement
    public Real getValue(Algorithm algorithm) {
        // The position is only read, never modified. Weights are copied by index instead of
        // trimmed in place with remove(get(i)). In-place trimming may alter the algorithm's
        // best solution, and removal by element can drop an equal-valued earlier weight.
        Vector position = (Vector) algorithm.getBestSolution().getPosition();
        List<Layer> layers = ((NNTrainingProblem) algorithm.getOptimisationProblem())
                .getNeuralNetwork().getArchitecture().getLayers();

        int newestLayerSize = layers.get(layers.size() - 2).size();
        int outputLayerSize = layers.get(layers.size() - 1).size();
        int olderLayersSize = 0;
        for (int layer = 0; layer < layers.size() - 2; layer++) {
            olderLayersSize += layers.get(layer).size();
        }

        // Each output neuron has one weight per neuron in every preceding layer, stored at the end.
        int hiddenWeightCount = position.size() - outputLayerSize * (olderLayersSize + newestLayerSize);
        double[] hiddenWeights = prefix(position, hiddenWeightCount);

        if (this.previousIteration == null) {
            this.previousIteration = hiddenWeights;
            return Real.valueOf(0.0d);
        }
        // A longer hidden block than last time means the network was expanded in between.
        if (hiddenWeights.length > this.previousIteration.length) {
            this.previousExpansion = this.previousIteration;
        }
        this.previousIteration = hiddenWeights;

        // Exclude the newest layer's incoming weights, which sit at the end of the hidden block.
        int comparedCount = hiddenWeightCount - olderLayersSize * newestLayerSize;
        if (comparedCount < 0) {
            throw new IllegalStateException("Hidden weight block (" + hiddenWeightCount
                    + ") is smaller than the newest layer's incoming weights ("
                    + olderLayersSize * newestLayerSize + ")");
        }
        if (comparedCount == 0 || this.previousExpansion == null) {
            // Nothing to compare, or no expansion baseline has been recorded yet.
            return Real.valueOf(0.0d);
        }
        if (this.previousExpansion.length != comparedCount) {
            throw new IllegalStateException("Recorded expansion snapshot has " + this.previousExpansion.length
                    + " weights but " + comparedCount + " older hidden weights are present");
        }

        double shift = 0.0d;
        for (int j = 0; j < comparedCount; j++) {
            shift += Math.pow(hiddenWeights[j] - this.previousExpansion[j], 2.0d);
        }
        return Real.valueOf(shift);
    }

    /** Copies the first {@code count} components of {@code vector} into a new array. */
    private static double[] prefix(Vector vector, int count) {
        if (count < 0) {
            throw new IllegalStateException("Position has " + vector.size()
                    + " weights, fewer than the output layer requires");
        }
        double[] values = new double[count];
        for (int j = 0; j < count; j++) {
            values[j] = vector.get(j).doubleValue();
        }
        return values;
    }

    /** Returns a copy of {@code values}, or null when {@code values} is null. */
    private static double[] copyOrNull(double[] values) {
        return values == null ? null : values.clone();
    }
}
