package ai.libs.jaicore.ml.ranking.dyad.learner.util;

import ai.libs.jaicore.ml.ranking.dyad.dataset.DenseDyadRankingInstance;
import ai.libs.jaicore.ml.ranking.dyad.dataset.SparseDyadRankingInstance;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.function.ToDoubleFunction;
import org.apache.commons.math3.stat.descriptive.SummaryStatistics;
import org.api4.java.ai.ml.ranking.dyad.dataset.IDyad;
import org.api4.java.ai.ml.ranking.dyad.dataset.IDyadRankingDataset;
import org.api4.java.ai.ml.ranking.dyad.dataset.IDyadRankingInstance;
import org.api4.java.common.math.IVector;

/**
 * Base class for scalers that operate on the feature vectors of dyad ranking datasets.
 *
 * <p>{@link #fit(IDyadRankingDataset)} records one {@link SummaryStatistics} per feature of the
 * dyad contexts (instances) in {@link #statsX} and one per feature of the dyad alternatives in
 * {@link #statsY}. How those statistics are applied to vectors is defined by the concrete
 * subclass through the abstract {@code transform*} hooks. The concrete {@code transform*}
 * methods of this class only walk the dataset and dispatch to those hooks.
 */
public abstract class AbstractDyadScaler implements Serializable {
    private static final long serialVersionUID = -825893010030419116L;

    /** Per-feature statistics of the dyad contexts; {@code null} until {@link #fit} has been called. */
    protected SummaryStatistics[] statsX;

    /** Per-feature statistics of the dyad alternatives; {@code null} until {@link #fit} has been called. */
    protected SummaryStatistics[] statsY;

    /**
     * Returns the per-feature statistics of the dyad contexts.
     *
     * @return the internal (not copied) statistics array, or {@code null} if the scaler is not fitted
     */
    public SummaryStatistics[] getStatsX() {
        return this.statsX;
    }

    /**
     * Returns the per-feature statistics of the dyad alternatives.
     *
     * @return the internal (not copied) statistics array, or {@code null} if the scaler is not fitted
     */
    public SummaryStatistics[] getStatsY() {
        return this.statsY;
    }

    /**
     * Collects per-feature statistics over every dyad of every instance in the dataset.
     *
     * <p>The number of context and alternative features is taken from the first dyad of the first
     * instance's label, and every dyad is read at exactly those indices. Any previous fit is
     * replaced only after the whole dataset has been processed successfully.
     *
     * @param iDyadRankingDataset the dataset to fit on; must contain at least one instance
     * @throws IllegalArgumentException if the dataset contains no instances
     */
    public void fit(IDyadRankingDataset iDyadRankingDataset) {
        IDyad reference = firstDyad(iDyadRankingDataset);
        SummaryStatistics[] contextStats = newStatistics(reference.getContext().length());
        SummaryStatistics[] alternativeStats = newStatistics(reference.getAlternative().length());
        Iterator<?> instances = iDyadRankingDataset.iterator();
        while (instances.hasNext()) {
            for (IDyad dyad : (IDyadRankingInstance) instances.next()) {
                addValues(contextStats, dyad.getContext());
                addValues(alternativeStats, dyad.getAlternative());
            }
        }
        // Publish only once every dyad was consumed, so an exception part-way through does not
        // leave the scaler with half-filled statistics or discard a previous fit.
        this.statsX = contextStats;
        this.statsY = alternativeStats;
    }

    /**
     * Applies the fitted transformation to both the contexts and the alternatives of the dataset.
     *
     * <p>The feature dimensions of the first dyad of the first instance must match the dimensions
     * seen during {@link #fit}. Both transformations are invoked with an empty index list.
     *
     * @param iDyadRankingDataset the dataset to transform; must contain at least one instance
     * @throws IllegalStateException if the scaler has not been fitted yet
     * @throws IllegalArgumentException if the dataset is empty or its feature dimensions differ from the fitted ones
     */
    public void transform(IDyadRankingDataset iDyadRankingDataset) {
        requireFitted();
        IDyad reference = firstDyad(iDyadRankingDataset);
        int contextLength = reference.getContext().length();
        int alternativeLength = reference.getAlternative().length();
        if (contextLength != this.statsX.length || alternativeLength != this.statsY.length) {
            throw new IllegalArgumentException("The scaler was fit to dyads with instances of length " + this.statsX.length + " and alternatives of length " + this.statsY.length + "\n but received instances of length " + contextLength + " and alternatives of length " + alternativeLength);
        }
        transformInstances(iDyadRankingDataset);
        transformAlternatives(iDyadRankingDataset);
    }

    /**
     * Transforms the contexts of all instances in the dataset, passing an empty index list.
     *
     * @param iDyadRankingDataset the dataset whose contexts are transformed
     */
    public void transformInstances(IDyadRankingDataset iDyadRankingDataset) {
        transformInstances(iDyadRankingDataset, new ArrayList<>());
    }

    /**
     * Transforms the alternatives of all instances in the dataset, passing an empty index list.
     *
     * @param iDyadRankingDataset the dataset whose alternatives are transformed
     */
    public void transformAlternatives(IDyadRankingDataset iDyadRankingDataset) {
        transformAlternatives(iDyadRankingDataset, new ArrayList<>());
    }

    /**
     * Hook: transforms the context (instance features) of a single dyad.
     *
     * <p>NOTE(review): returns nothing, so implementations presumably modify the dyad's vector in
     * place — confirm against subclasses.
     *
     * @param iDyad the dyad whose context is transformed
     * @param list index list forwarded unchanged by the callers in this class; its meaning is
     *     defined by the subclass (presumably the feature indices to scale or skip — TODO confirm)
     */
    public abstract void transformInstances(IDyad iDyad, List<Integer> list);

    /**
     * Hook: transforms the alternative of a single dyad.
     *
     * <p>NOTE(review): returns nothing, so implementations presumably modify the dyad's vector in
     * place — confirm against subclasses.
     *
     * @param iDyad the dyad whose alternative is transformed
     * @param list index list forwarded unchanged by the callers in this class; its meaning is
     *     defined by the subclass (TODO confirm)
     */
    public abstract void transformAlternatives(IDyad iDyad, List<Integer> list);

    /**
     * Hook: transforms a context vector directly. It is used for the instance-level context of
     * {@link SparseDyadRankingInstance}s. (The method name's spelling is kept for compatibility
     * with existing subclasses.)
     *
     * @param iVector the context vector to transform
     * @param list index list forwarded unchanged by the callers in this class (TODO confirm meaning)
     */
    public abstract void transformInstaceVector(IVector iVector, List<Integer> list);

    /**
     * Transforms the single instance-level context vector of a sparse instance.
     *
     * @param sparseDyadRankingInstance the sparse instance whose context is transformed
     * @param list index list forwarded to {@link #transformInstaceVector}
     */
    public void transformInstances(SparseDyadRankingInstance sparseDyadRankingInstance, List<Integer> list) {
        transformInstaceVector(sparseDyadRankingInstance.getContext(), list);
    }

    /**
     * Transforms the context of every dyad in the instance, one dyad at a time.
     *
     * @param iDyadRankingInstance the instance whose dyad contexts are transformed
     * @param list index list forwarded to {@link #transformInstances(IDyad, List)}
     */
    public void transformInstances(IDyadRankingInstance iDyadRankingInstance, List<Integer> list) {
        for (IDyad dyad : iDyadRankingInstance) {
            transformInstances(dyad, list);
        }
    }

    /**
     * Transforms the alternative of every dyad in the instance, one dyad at a time.
     *
     * @param iDyadRankingInstance the instance whose dyad alternatives are transformed
     * @param list index list forwarded to {@link #transformAlternatives(IDyad, List)}
     */
    public void transformAlternatives(IDyadRankingInstance iDyadRankingInstance, List<Integer> list) {
        for (IDyad dyad : iDyadRankingInstance) {
            transformAlternatives(dyad, list);
        }
    }

    /**
     * Transforms the contexts of all instances in the dataset.
     *
     * <p>Sparse instances expose one instance-level context vector, which is transformed once.
     * Dense instances are transformed dyad by dyad.
     *
     * @param iDyadRankingDataset the dataset whose contexts are transformed
     * @param list index list forwarded to the subclass hooks
     * @throws IllegalArgumentException if an instance is neither sparse nor dense
     */
    public void transformInstances(IDyadRankingDataset iDyadRankingDataset, List<Integer> list) {
        Iterator<?> instances = iDyadRankingDataset.iterator();
        while (instances.hasNext()) {
            IDyadRankingInstance instance = (IDyadRankingInstance) instances.next();
            if (instance instanceof SparseDyadRankingInstance) {
                transformInstances((SparseDyadRankingInstance) instance, list);
            } else if (instance instanceof DenseDyadRankingInstance) {
                // The cast keeps the call bound to the per-dyad overload, exactly as before.
                transformInstances((DenseDyadRankingInstance) instance, list);
            } else {
                throw new IllegalArgumentException("The scalers only support SparseDyadRankingInstance and DenseDyadRankingInstance!");
            }
        }
    }

    /**
     * Transforms the alternatives of all instances in the dataset.
     *
     * @param iDyadRankingDataset the dataset whose alternatives are transformed
     * @param list index list forwarded to the subclass hooks
     */
    public void transformAlternatives(IDyadRankingDataset iDyadRankingDataset, List<Integer> list) {
        Iterator<?> instances = iDyadRankingDataset.iterator();
        while (instances.hasNext()) {
            transformAlternatives((IDyadRankingInstance) instances.next(), list);
        }
    }

    /**
     * Fits the scaler on the dataset and then transforms the same dataset.
     *
     * @param iDyadRankingDataset the dataset to fit on and transform
     */
    public void fitTransform(IDyadRankingDataset iDyadRankingDataset) {
        fit(iDyadRankingDataset);
        transform(iDyadRankingDataset);
    }

    /**
     * Renders the fitted standard deviation of every context and alternative feature.
     *
     * @return one line for instances and one for alternatives, each value followed by {@code ", "}
     * @throws IllegalStateException if the scaler has not been fitted yet
     */
    public String getPrettySTDString() {
        return formatStatistics("Standard deviations", SummaryStatistics::getStandardDeviation);
    }

    /**
     * Renders the fitted mean of every context and alternative feature.
     *
     * @return one line for instances and one for alternatives, each value followed by {@code ", "}
     * @throws IllegalStateException if the scaler has not been fitted yet
     */
    public String getPrettyMeansString() {
        return formatStatistics("Means", SummaryStatistics::getMean);
    }

    /** Returns the first dyad of the first instance's label, the reference for feature dimensions. */
    private static IDyad firstDyad(IDyadRankingDataset dataset) {
        if (!dataset.iterator().hasNext()) {
            throw new IllegalArgumentException("The dataset must contain at least one dyad ranking instance!");
        }
        return (IDyad) ((IDyadRankingInstance) dataset.get(0)).getLabel().get(0);
    }

    /** Creates an array of {@code size} fresh, empty statistics accumulators. */
    private static SummaryStatistics[] newStatistics(int size) {
        SummaryStatistics[] stats = new SummaryStatistics[size];
        for (int i = 0; i < size; i++) {
            stats[i] = new SummaryStatistics();
        }
        return stats;
    }

    /** Adds feature {@code i} of {@code vector} to {@code stats[i]} for every accumulator. */
    private static void addValues(SummaryStatistics[] stats, IVector vector) {
        for (int i = 0; i < stats.length; i++) {
            stats[i].addValue(vector.getValue(i));
        }
    }

    /** Throws if {@link #fit} has not populated both statistics arrays yet. */
    private void requireFitted() {
        if (this.statsX == null || this.statsY == null) {
            throw new IllegalStateException("The scaler must be fit before calling this method!");
        }
    }

    /** Formats one metric of the context and alternative statistics as two labelled lines. */
    private String formatStatistics(String metricName, ToDoubleFunction<SummaryStatistics> metric) {
        requireFitted();
        StringBuilder sb = new StringBuilder();
        appendRow(sb, metricName + " for instances: ", this.statsX, metric);
        appendRow(sb, metricName + " for alternatives: ", this.statsY, metric);
        return sb.toString();
    }

    /** Appends the heading, each metric value followed by ", ", and a line separator. */
    private static void appendRow(StringBuilder sb, String heading, SummaryStatistics[] stats, ToDoubleFunction<SummaryStatistics> metric) {
        sb.append(heading);
        for (SummaryStatistics stat : stats) {
            sb.append(metric.applyAsDouble(stat));
            sb.append(", ");
        }
        sb.append(System.lineSeparator());
    }
}
