package ai.libs.jaicore.ml.classification.singlelabel.timeseries.dataset;

import ai.libs.jaicore.ml.classification.singlelabel.timeseries.dataset.attribute.NDArrayTimeseriesAttribute;
import ai.libs.jaicore.ml.classification.singlelabel.timeseries.model.NDArrayTimeseries;
import ai.libs.jaicore.ml.core.dataset.ADataset;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import org.api4.java.ai.ml.core.dataset.schema.ILabeledInstanceSchema;
import org.api4.java.ai.ml.core.dataset.schema.attribute.IAttribute;
import org.api4.java.ai.ml.core.dataset.schema.attribute.ITimeseriesAttribute;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset;
import org.api4.java.ai.ml.core.exception.DatasetCreationException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:ai/libs/jaicore/ml/classification/singlelabel/timeseries/dataset/TimeSeriesDataset.class */
public class TimeSeriesDataset extends ADataset<ITimeSeriesInstance> implements ILabeledDataset<ITimeSeriesInstance> {
    private static final long serialVersionUID = -6819487387561457394L;
    private List<INDArray> valueMatrices;
    private List<INDArray> timestampMatrices;
    private transient List<Object> targets;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:ai/libs/jaicore/ml/classification/singlelabel/timeseries/dataset/TimeSeriesDataset$TimeSeriesDatasetIterator.class */
    public class TimeSeriesDatasetIterator implements Iterator<ITimeSeriesInstance> {
        private int current = 0;

        TimeSeriesDatasetIterator() {
        }

        @Override // java.util.Iterator
        public boolean hasNext() {
            return TimeSeriesDataset.this.getNumberOfInstances() > ((long) this.current);
        }

        /* JADX WARN: Can't rename method to resolve collision */
        @Override // java.util.Iterator
        public ITimeSeriesInstance next() {
            if (!hasNext()) {
                throw new NoSuchElementException();
            }
            TimeSeriesDataset timeSeriesDataset = TimeSeriesDataset.this;
            int i = this.current;
            this.current = i + 1;
            return timeSeriesDataset.get(i);
        }
    }

    public TimeSeriesDataset(ILabeledInstanceSchema iLabeledInstanceSchema, List<INDArray> list, List<INDArray> list2, List<Object> list3) {
        this(iLabeledInstanceSchema);
        Iterator it = iLabeledInstanceSchema.getAttributeList().iterator();
        while (it.hasNext()) {
            if (!(((IAttribute) it.next()) instanceof ITimeseriesAttribute)) {
                throw new IllegalArgumentException("The schema contains attributes which are not timeseries");
            }
        }
        Set set = (Set) list.stream().map(iNDArray -> {
            return Long.valueOf(iNDArray.shape()[0]);
        }).collect(Collectors.toSet());
        if (set.size() > 1) {
            throw new IllegalArgumentException("The value matrices vary in length i.e. they have different number of instances");
        }
        Set set2 = (Set) list2.stream().map(iNDArray2 -> {
            return Long.valueOf(iNDArray2.shape()[0]);
        }).collect(Collectors.toSet());
        if (set2.size() > 1) {
            throw new IllegalArgumentException("The timestamp matrices vary in length i.e. they have different number of instances");
        }
        set.addAll(set2);
        if (set.size() > 1) {
            throw new IllegalArgumentException("There are different number of instances for values and timestamps");
        }
        this.valueMatrices = list;
        this.timestampMatrices = list2;
        this.targets = list3;
    }

    public TimeSeriesDataset(ILabeledInstanceSchema iLabeledInstanceSchema) {
        super(iLabeledInstanceSchema);
    }

    public void add(String str, INDArray iNDArray, INDArray iNDArray2) {
        this.valueMatrices.add(iNDArray);
        this.timestampMatrices.add(iNDArray2);
        addAttribute(str, iNDArray);
    }

    public void removeColumn(int i) {
        this.valueMatrices.remove(i);
        this.timestampMatrices.remove(i);
        m28getInstanceSchema().removeAttribute(i);
    }

    public void replace(int i, INDArray iNDArray, INDArray iNDArray2) {
        this.valueMatrices.set(i, iNDArray);
        if (iNDArray2 != null && this.timestampMatrices != null && this.timestampMatrices.size() > i) {
            this.timestampMatrices.set(i, iNDArray2);
        }
        NDArrayTimeseriesAttribute createAttribute = createAttribute("ts" + i, iNDArray);
        m28getInstanceSchema().removeAttribute(i);
        m28getInstanceSchema().addAttribute(i, createAttribute);
    }

    public Object getTargets() {
        return this.targets;
    }

    public INDArray getTargetsAsINDArray() {
        if (this.targets.get(0) instanceof Number) {
            return Nd4j.create(this.targets.stream().mapToDouble(obj -> {
                return ((Double) obj).doubleValue();
            }).toArray());
        }
        return null;
    }

    public int getNumberOfVariables() {
        return this.valueMatrices.size();
    }

    public long getNumberOfInstances() {
        return this.valueMatrices.get(0).shape()[0];
    }

    public INDArray getValues(int i) {
        return this.valueMatrices.get(i);
    }

    public INDArray getTimestamps(int i) {
        return this.timestampMatrices.get(i);
    }

    public INDArray getValuesOrNull(int i) {
        if (this.valueMatrices.size() > i) {
            return this.valueMatrices.get(i);
        }
        return null;
    }

    public INDArray getTimestampsOrNull(int i) {
        if (this.timestampMatrices == null || this.timestampMatrices.size() <= i) {
            return null;
        }
        return this.timestampMatrices.get(i);
    }

    @Override // java.util.ArrayList, java.util.AbstractCollection, java.util.Collection, java.util.List
    public boolean isEmpty() {
        return this.valueMatrices.isEmpty();
    }

    public boolean isUnivariate() {
        return this.valueMatrices.size() == 1;
    }

    public boolean isMultivariate() {
        return this.valueMatrices.size() > 1;
    }

    private NDArrayTimeseriesAttribute createAttribute(String str, INDArray iNDArray) {
        return new NDArrayTimeseriesAttribute(str, (int) iNDArray.shape()[1]);
    }

    private void addAttribute(String str, INDArray iNDArray) {
        m28getInstanceSchema().addAttribute(createAttribute(str, iNDArray));
        this.valueMatrices.add(iNDArray);
    }

    @Override // java.util.ArrayList, java.util.AbstractList, java.util.List
    public TimeSeriesInstance get(int i) {
        ArrayList arrayList = new ArrayList();
        for (int i2 = 0; i2 < this.valueMatrices.size(); i2++) {
            arrayList.add(new NDArrayTimeseries(this.valueMatrices.get(i2).getRow(i)));
        }
        return new TimeSeriesInstance(arrayList, this.targets.get(i));
    }

    @Override // java.util.ArrayList, java.util.AbstractList, java.util.AbstractCollection, java.util.Collection, java.lang.Iterable, java.util.List
    public Iterator<ITimeSeriesInstance> iterator() {
        return new TimeSeriesDatasetIterator();
    }

    /* renamed from: createEmptyCopy, reason: merged with bridge method [inline-methods] and merged with bridge method [inline-methods] and merged with bridge method [inline-methods] */
    public TimeSeriesDataset m7createEmptyCopy() throws DatasetCreationException, InterruptedException {
        return new TimeSeriesDataset(m28getInstanceSchema());
    }

    @Override // ai.libs.jaicore.ml.core.dataset.ADataset
    public Object[][] getFeatureMatrix() {
        throw new UnsupportedOperationException();
    }

    @Override // ai.libs.jaicore.ml.core.dataset.ADataset
    public Object[] getLabelVector() {
        return this.targets.toArray();
    }

    /* renamed from: createCopy, reason: merged with bridge method [inline-methods] and merged with bridge method [inline-methods] */
    public TimeSeriesDataset m6createCopy() throws DatasetCreationException, InterruptedException {
        TimeSeriesDataset m7createEmptyCopy = m7createEmptyCopy();
        Iterator<ITimeSeriesInstance> it = iterator();
        while (it.hasNext()) {
            m7createEmptyCopy.add(it.next());
        }
        return m7createEmptyCopy;
    }

    @Override // ai.libs.jaicore.ml.core.dataset.ADataset, java.util.ArrayList, java.util.AbstractList, java.util.Collection, java.util.List
    public int hashCode() {
        return (31 * ((31 * ((31 * super.hashCode()) + (this.targets == null ? 0 : this.targets.hashCode()))) + (this.timestampMatrices == null ? 0 : this.timestampMatrices.hashCode()))) + (this.valueMatrices == null ? 0 : this.valueMatrices.hashCode());
    }

    @Override // ai.libs.jaicore.ml.core.dataset.ADataset, java.util.ArrayList, java.util.AbstractList, java.util.Collection, java.util.List
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (!super.equals(obj) || getClass() != obj.getClass()) {
            return false;
        }
        TimeSeriesDataset timeSeriesDataset = (TimeSeriesDataset) obj;
        if (this.targets == null) {
            if (timeSeriesDataset.targets != null) {
                return false;
            }
        } else if (!this.targets.equals(timeSeriesDataset.targets)) {
            return false;
        }
        if (this.timestampMatrices == null) {
            if (timeSeriesDataset.timestampMatrices != null) {
                return false;
            }
        } else if (!this.timestampMatrices.equals(timeSeriesDataset.timestampMatrices)) {
            return false;
        }
        return this.valueMatrices == null ? timeSeriesDataset.valueMatrices == null : this.valueMatrices.equals(timeSeriesDataset.valueMatrices);
    }
}
