package ml.shifu.guagua.example.nn;

import java.io.File;
import java.util.Arrays;
import java.util.Iterator;
import java.util.NoSuchElementException;

import ml.shifu.guagua.util.SizeEstimator;
import org.encog.ml.data.MLData;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLData;
import org.encog.ml.data.basic.BasicMLDataPair;
import org.encog.ml.data.basic.BasicMLDataSet;
import org.encog.ml.data.buffer.BufferedMLDataSet;

/* loaded from: input_file:ml/shifu/guagua/example/nn/MemoryDiskMLDataSet.class */
public class MemoryDiskMLDataSet implements MLDataSet {
    private long maxByteSize;
    private long byteSize;
    private MLDataSet memoryDataSet;
    private MLDataSet diskDataSet;
    private int inputCount;
    private int outputCount;
    private String fileName;
    private long memoryCount;
    private long diskCount;

    public MemoryDiskMLDataSet(String str, int i, int i2) {
        this.maxByteSize = Long.MAX_VALUE;
        this.byteSize = 0L;
        this.memoryCount = 0L;
        this.diskCount = 0L;
        this.memoryDataSet = new BasicMLDataSet();
        this.inputCount = i;
        this.outputCount = i2;
        this.fileName = str;
    }

    public MemoryDiskMLDataSet(long j, String str) {
        this.maxByteSize = Long.MAX_VALUE;
        this.byteSize = 0L;
        this.memoryCount = 0L;
        this.diskCount = 0L;
        this.maxByteSize = j;
        this.memoryDataSet = new BasicMLDataSet();
        this.fileName = str;
    }

    public MemoryDiskMLDataSet(long j, String str, int i, int i2) {
        this.maxByteSize = Long.MAX_VALUE;
        this.byteSize = 0L;
        this.memoryCount = 0L;
        this.diskCount = 0L;
        this.maxByteSize = j;
        this.memoryDataSet = new BasicMLDataSet();
        this.inputCount = i;
        this.outputCount = i2;
        this.fileName = str;
    }

    public final void beginLoad(int i, int i2) {
        this.inputCount = i;
        this.outputCount = i2;
        if (this.diskDataSet != null) {
            this.diskDataSet.beginLoad(this.inputCount, this.outputCount);
        }
    }

    public final void endLoad() {
        if (this.diskDataSet != null) {
            this.diskDataSet.endLoad();
        }
    }

    public Iterator<MLDataPair> iterator() {
        return new Iterator<MLDataPair>() { // from class: ml.shifu.guagua.example.nn.MemoryDiskMLDataSet.1
            private Iterator<MLDataPair> iter1;
            private Iterator<MLDataPair> iter2;
            private boolean isMemoryHasNext;
            private boolean isDiskHasNext;

            {
                this.iter1 = MemoryDiskMLDataSet.this.memoryDataSet.iterator();
                this.iter2 = MemoryDiskMLDataSet.this.diskDataSet == null ? null : MemoryDiskMLDataSet.this.diskDataSet.iterator();
                this.isMemoryHasNext = false;
                this.isDiskHasNext = false;
            }

            @Override // java.util.Iterator
            public boolean hasNext() {
                boolean hasNext = this.iter1.hasNext();
                if (hasNext) {
                    this.isMemoryHasNext = true;
                    this.isDiskHasNext = false;
                    return hasNext;
                }
                boolean hasNext2 = this.iter2 == null ? false : this.iter2.hasNext();
                if (hasNext2) {
                    this.isMemoryHasNext = false;
                    this.isDiskHasNext = true;
                } else {
                    this.isMemoryHasNext = false;
                    this.isDiskHasNext = false;
                }
                return hasNext2;
            }

            /* JADX WARN: Can't rename method to resolve collision */
            @Override // java.util.Iterator
            public MLDataPair next() {
                if (this.isMemoryHasNext) {
                    return this.iter1.next();
                }
                if (!this.isDiskHasNext || this.iter2 == null) {
                    return null;
                }
                return this.iter2.next();
            }

            @Override // java.util.Iterator
            public void remove() {
                throw new UnsupportedOperationException();
            }
        };
    }

    public int getIdealSize() {
        return this.outputCount;
    }

    public int getInputSize() {
        return this.inputCount;
    }

    public boolean isSupervised() {
        return this.memoryDataSet.isSupervised();
    }

    public long getRecordCount() {
        long recordCount = this.memoryDataSet.getRecordCount();
        if (this.diskDataSet != null) {
            recordCount += this.diskDataSet.getRecordCount();
        }
        return recordCount;
    }

    public void getRecord(long j, MLDataPair mLDataPair) {
        if (j < this.memoryCount) {
            this.memoryDataSet.getRecord(j, mLDataPair);
        } else {
            this.diskDataSet.getRecord(j - this.memoryCount, mLDataPair);
        }
    }

    public MLDataSet openAdditional() {
        throw new UnsupportedOperationException();
    }

    public void add(MLData mLData) {
        long estimate = SizeEstimator.estimate(mLData);
        if (this.byteSize + estimate < this.maxByteSize) {
            this.byteSize += estimate;
            this.memoryCount++;
            this.memoryDataSet.add(mLData);
        } else {
            if (this.diskDataSet == null) {
                this.diskDataSet = new BufferedMLDataSet(new File(this.fileName));
                this.diskDataSet.beginLoad(this.inputCount, this.outputCount);
            }
            this.byteSize += estimate;
            this.diskCount++;
            this.diskDataSet.add(mLData);
        }
    }

    public void add(MLData mLData, MLData mLData2) {
        long estimate = SizeEstimator.estimate(mLData) + SizeEstimator.estimate(mLData2);
        if (this.byteSize + estimate < this.maxByteSize) {
            this.byteSize += estimate;
            this.memoryCount++;
            this.memoryDataSet.add(mLData, mLData2);
        } else {
            if (this.diskDataSet == null) {
                this.diskDataSet = new BufferedMLDataSet(new File(this.fileName));
                this.diskDataSet.beginLoad(this.inputCount, this.outputCount);
            }
            this.byteSize += estimate;
            this.diskCount++;
            this.diskDataSet.add(mLData, mLData2);
        }
    }

    public void add(MLDataPair mLDataPair) {
        long estimate = SizeEstimator.estimate(mLDataPair);
        if (this.byteSize + estimate < this.maxByteSize) {
            this.byteSize += estimate;
            this.memoryCount++;
            this.memoryDataSet.add(mLDataPair);
        } else {
            if (this.diskDataSet == null) {
                this.diskDataSet = new BufferedMLDataSet(new File(this.fileName));
                this.diskDataSet.beginLoad(this.inputCount, this.outputCount);
            }
            this.byteSize += estimate;
            this.diskCount++;
            this.diskDataSet.add(mLDataPair);
        }
    }

    public void close() {
        this.memoryDataSet.close();
        if (this.diskDataSet != null) {
            this.diskDataSet.close();
        }
    }

    public long getMemoryCount() {
        return this.memoryCount;
    }

    public long getDiskCount() {
        return this.diskCount;
    }

    public static void main(String[] strArr) {
        double[] dArr = {1.0d};
        MLDataPair basicMLDataPair = new BasicMLDataPair(new BasicMLData(createInput(1.0d)), new BasicMLData(dArr));
        MemoryDiskMLDataSet memoryDiskMLDataSet = new MemoryDiskMLDataSet(400L, "a.txt");
        memoryDiskMLDataSet.beginLoad(10, 1);
        memoryDiskMLDataSet.add(basicMLDataPair);
        BasicMLDataPair basicMLDataPair2 = new BasicMLDataPair(new BasicMLData(createInput(2.0d)), new BasicMLData(dArr));
        BasicMLDataPair basicMLDataPair3 = new BasicMLDataPair(new BasicMLData(createInput(3.0d)), new BasicMLData(dArr));
        BasicMLDataPair basicMLDataPair4 = new BasicMLDataPair(new BasicMLData(createInput(4.0d)), new BasicMLData(dArr));
        BasicMLDataPair basicMLDataPair5 = new BasicMLDataPair(new BasicMLData(createInput(5.0d)), new BasicMLData(dArr));
        BasicMLDataPair basicMLDataPair6 = new BasicMLDataPair(new BasicMLData(createInput(6.0d)), new BasicMLData(dArr));
        memoryDiskMLDataSet.add((MLDataPair) basicMLDataPair2);
        memoryDiskMLDataSet.add((MLDataPair) basicMLDataPair3);
        memoryDiskMLDataSet.add((MLDataPair) basicMLDataPair4);
        memoryDiskMLDataSet.add((MLDataPair) basicMLDataPair5);
        memoryDiskMLDataSet.add((MLDataPair) basicMLDataPair6);
        memoryDiskMLDataSet.endLoad();
        long recordCount = memoryDiskMLDataSet.getRecordCount();
        long j = 0;
        while (true) {
            long j2 = j;
            if (j2 >= recordCount) {
                break;
            }
            long currentTimeMillis = System.currentTimeMillis();
            BasicMLDataPair basicMLDataPair7 = new BasicMLDataPair(new BasicMLData(createInput(6.0d)), new BasicMLData(dArr));
            memoryDiskMLDataSet.getRecord(j2, basicMLDataPair7);
            System.out.println((System.currentTimeMillis() - currentTimeMillis) + " " + basicMLDataPair7);
            j = j2 + 1;
        }
        System.out.println();
        Iterator<MLDataPair> it = memoryDiskMLDataSet.iterator();
        while (it.hasNext()) {
            System.out.println((System.currentTimeMillis() - System.currentTimeMillis()) + " " + it.next());
        }
        System.out.println();
        Iterator<MLDataPair> it2 = memoryDiskMLDataSet.iterator();
        while (it2.hasNext()) {
            System.out.println((System.currentTimeMillis() - System.currentTimeMillis()) + " " + it2.next());
        }
        memoryDiskMLDataSet.close();
        System.out.println(SizeEstimator.estimate(basicMLDataPair));
    }

    private static double[] createInput(double d) {
        double[] dArr = new double[10];
        for (int i = 0; i < dArr.length; i++) {
            dArr[i] = d;
        }
        return dArr;
    }
}
