package net.sourceforge.cilib.measurement.single;

import java.util.Iterator;
import net.sourceforge.cilib.algorithm.Algorithm;
import net.sourceforge.cilib.io.StandardPatternDataTable;
import net.sourceforge.cilib.io.pattern.StandardPattern;
import net.sourceforge.cilib.measurement.Measurement;
import net.sourceforge.cilib.nn.NeuralNetwork;
import net.sourceforge.cilib.nn.architecture.visitors.OutputErrorVisitor;
import net.sourceforge.cilib.problem.nn.NNTrainingProblem;
import net.sourceforge.cilib.type.types.Numeric;
import net.sourceforge.cilib.type.types.Real;
import net.sourceforge.cilib.type.types.Type;
import net.sourceforge.cilib.type.types.container.Vector;

/* loaded from: input_file:net/sourceforge/cilib/measurement/single/MSEGeneralisationError.class */
/**
 * Measures the mean squared error (MSE) of an algorithm's best solution on the
 * generalisation set of a neural-network training problem.
 *
 * <p>The best solution's position is used as the network's weight vector; every
 * pattern in the generalisation set is then evaluated and the squared
 * components of the per-pattern error vector (as produced by
 * {@link OutputErrorVisitor}) are averaged.
 */
public class MSEGeneralisationError implements Measurement {
    private static final long serialVersionUID = -1014032196750640716L;

    /**
     * Returns this instance. The class has no per-instance state, so sharing a
     * single instance is equivalent to copying it.
     *
     * @return this measurement
     */
    @Override // net.sourceforge.cilib.util.Cloneable
    public Measurement getClone() {
        return this;
    }

    /**
     * Computes the generalisation MSE of the algorithm's best solution.
     *
     * <p>Side effect: the best solution's position is written into the
     * problem's {@link NeuralNetwork} via {@code setWeights}, and the network is
     * left holding those weights after this call.
     *
     * @param algorithm the algorithm whose best solution is evaluated; its best
     *        solution's position must be a {@link Vector} and its optimisation
     *        problem must be an {@link NNTrainingProblem} (otherwise a
     *        {@link ClassCastException} is thrown)
     * @return the mean of the squared error components over every output of
     *         every generalisation pattern, or {@code NaN} when no error terms
     *         were produced (for example, an empty generalisation set)
     */
    @Override // net.sourceforge.cilib.measurement.Measurement
    public Type getValue(Algorithm algorithm) {
        Vector weights = (Vector) algorithm.getBestSolution().getPosition();
        NNTrainingProblem problem = (NNTrainingProblem) algorithm.getOptimisationProblem();
        StandardPatternDataTable generalisationSet = problem.getGeneralisationSet();
        NeuralNetwork neuralNetwork = problem.getNeuralNetwork();
        neuralNetwork.setWeights(weights);

        OutputErrorVisitor errorVisitor = new OutputErrorVisitor();
        double sumSquaredError = 0.0;
        // Count the terms actually summed (as a long) rather than multiplying
        // the row count by the last pattern's output width: this cannot
        // overflow, does not depend on a possibly-null "last" vector, and stays
        // correct if patterns ever produce error vectors of different widths.
        long errorTermCount = 0L;

        Iterator<StandardPattern> patterns = generalisationSet.iterator();
        while (patterns.hasNext()) {
            StandardPattern pattern = patterns.next();
            neuralNetwork.evaluatePattern(pattern);
            errorVisitor.setInput(pattern);
            neuralNetwork.getArchitecture().accept(errorVisitor);

            Iterator<Numeric> errors = errorVisitor.getOutput().iterator();
            while (errors.hasNext()) {
                double error = errors.next().doubleValue();
                sumSquaredError += error * error;
                errorTermCount++;
            }
        }

        if (errorTermCount == 0L) {
            // No patterns (or only zero-width outputs): the mean is undefined.
            // The original threw a NullPointerException here for an empty set.
            return Real.valueOf(Double.NaN);
        }
        return Real.valueOf(sumSquaredError / errorTermCount);
    }
}
