package ml.shifu.guagua.example.kmeans;

import com.google.common.base.Splitter;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import ml.shifu.guagua.hadoop.io.GuaguaLineRecordReader;
import ml.shifu.guagua.hadoop.io.GuaguaWritableAdapter;
import ml.shifu.guagua.io.Bytable;
import ml.shifu.guagua.io.GuaguaFileSplit;
import ml.shifu.guagua.util.MemoryDiskList;
import ml.shifu.guagua.worker.AbstractWorkerComputable;
import ml.shifu.guagua.worker.WorkerContext;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:ml/shifu/guagua/example/kmeans/KMeansWorker.class */
/**
 * Worker side of the K-means example.
 *
 * <p>Each input line is parsed into a {@code Double[]} of {@code c} columns and appended to a
 * {@link MemoryDiskList}. On the first iteration the worker proposes candidate centers sampled
 * from its own data; on every later iteration it assigns each record to a center taken from the
 * last master result and reports per-center column sums and record counts.
 */
public class KMeansWorker extends AbstractWorkerComputable<KMeansMasterParams, KMeansWorkerParams, GuaguaWritableAdapter<LongWritable>, GuaguaWritableAdapter<Text>> {
    private static final Logger LOG = LoggerFactory.getLogger(KMeansWorker.class);

    /** Records loaded by this worker. */
    private MemoryDiskList<TaggedRecord> dataList;

    /** Number of clusters, read from {@code KMeansContants.KMEANS_K_NUMBER}. */
    private int k;

    /** Number of columns per record, read from {@code KMeansContants.KMEANS_COLUMN_NUMBER}. */
    private int c;

    /** Field separator of an input line, read from {@code KMeansContants.KMEANS_DATA_SEPERATOR}. */
    private String separator;

    /**
     * Installs a {@link GuaguaLineRecordReader} and initializes it with the given split.
     *
     * @param guaguaFileSplit the file split this worker reads
     * @throws IOException if the record reader fails to initialize
     */
    public void initRecordReader(GuaguaFileSplit guaguaFileSplit) throws IOException {
        setRecordReader(new GuaguaLineRecordReader());
        getRecordReader().initialize(guaguaFileSplit);
    }

    /**
     * Reads the K-means settings from the context properties and creates the record list.
     *
     * <p>The list is given a byte budget of {@code maxMemory * guagua.data.memoryFraction}
     * (default {@code 0.5}) and a folder path built from {@code guagua.data.tmpfolder} (default
     * {@code tmp}) plus the current time in milliseconds. A JVM shutdown hook closes and clears
     * the list, and the list is stored as the context attachment.
     *
     * @param workerContext context supplying the configuration properties
     * @throws NumberFormatException if the cluster count, column count or memory fraction
     *         property is missing or not numeric
     */
    public void init(WorkerContext<KMeansMasterParams, KMeansWorkerParams> workerContext) {
        this.k = Integer.parseInt(workerContext.getProps().getProperty(KMeansContants.KMEANS_K_NUMBER));
        this.c = Integer.parseInt(workerContext.getProps().getProperty(KMeansContants.KMEANS_COLUMN_NUMBER));
        this.separator = workerContext.getProps().getProperty(KMeansContants.KMEANS_DATA_SEPERATOR);

        double memoryFraction = Double.parseDouble(workerContext.getProps().getProperty("guagua.data.memoryFraction", "0.5"));
        long memoryBudget = (long) (Runtime.getRuntime().maxMemory() * memoryFraction);
        String folder = workerContext.getProps().getProperty("guagua.data.tmpfolder", "tmp") + File.separator + System.currentTimeMillis();
        this.dataList = new MemoryDiskList<>(memoryBudget, folder);

        // Release the list's resources when the JVM exits.
        Runtime.getRuntime().addShutdownHook(new Thread(new Runnable() {
            @Override
            public void run() {
                KMeansWorker.this.dataList.close();
                KMeansWorker.this.dataList.clear();
            }
        }));
        workerContext.setAttachment(this.dataList);
    }

    /**
     * Runs one iteration: center proposal on iteration 1, record assignment afterwards.
     *
     * @param workerContext context of the current iteration
     * @return the worker result to send to the master
     */
    public KMeansWorkerParams doCompute(WorkerContext<KMeansMasterParams, KMeansWorkerParams> workerContext) {
        if (workerContext.getCurrentIteration() == 1) {
            return doFirstIteration(workerContext);
        }
        // The list is re-opened before each later pass over the data.
        this.dataList.reOpen();
        return doOtherIterations(workerContext);
    }

    /**
     * Proposes candidate centers from this worker's data.
     *
     * <p>If the worker holds at most {@code k} records, all of them are proposed. Otherwise every
     * {@code (size / k)}-th record is proposed, starting with the first one; because of integer
     * division this can yield more than {@code k} points.
     *
     * @param workerContext context of the current iteration (not used)
     * @return a result flagged as first iteration that carries the proposed points
     */
    private KMeansWorkerParams doFirstIteration(WorkerContext<KMeansMasterParams, KMeansWorkerParams> workerContext) {
        KMeansWorkerParams workerResult = new KMeansWorkerParams();
        workerResult.setK(this.k);
        workerResult.setC(this.c);
        workerResult.setFirstIteration(true);

        // Stay in long arithmetic so a very large list cannot overflow an int, and do not
        // pre-size the result to the data size: only about k points are kept.
        long dataSize = this.dataList.size();
        List<double[]> candidates = new ArrayList<>();
        if (this.k >= dataSize) {
            for (TaggedRecord record : this.dataList) {
                candidates.add(toDouble(record));
            }
        } else {
            long step = dataSize / this.k;
            long position = 0;
            this.dataList.reOpen();
            for (TaggedRecord record : this.dataList) {
                if (position % step == 0) {
                    candidates.add(toDouble(record));
                }
                position++;
            }
        }
        workerResult.setPointList(candidates);
        return workerResult;
    }

    /**
     * Copies a record into a primitive array, mapping {@code null} entries to {@code 0.0}.
     *
     * @param taggedRecord the record to convert
     * @return a new array with one value per column of the record
     */
    private double[] toDouble(TaggedRecord taggedRecord) {
        Double[] boxed = taggedRecord.getRecord();
        double[] values = new double[boxed.length];
        for (int i = 0; i < boxed.length; i++) {
            Double value = boxed[i];
            // Each column goes to its own slot.
            values[i] = value == null ? 0.0d : value.doubleValue();
        }
        return values;
    }

    /**
     * Assigns every record to a center from the last master result and accumulates, per center,
     * the column sums and the number of assigned records.
     *
     * @param workerContext context giving access to the last master result
     * @return a result carrying {@code k} sum vectors of length {@code c} and {@code k} counts
     */
    private KMeansWorkerParams doOtherIterations(WorkerContext<KMeansMasterParams, KMeansWorkerParams> workerContext) {
        List<double[]> centers = workerContext.getLastMasterResult().getPointList();
        if (LOG.isDebugEnabled()) {
            LOG.debug("Initial centers:{}", formatPoints(centers));
        }

        // Array-backed lists: both are read and written by index once per record.
        List<double[]> sumList = new ArrayList<>();
        List<Integer> countList = new ArrayList<>();
        for (int i = 0; i < this.k; i++) {
            sumList.add(new double[this.c]);
            countList.add(0);
        }

        for (TaggedRecord record : this.dataList) {
            Double[] values = record.getRecord();
            int center = findClosedCenter(values, centers);
            record.setTag(center);
            countList.set(center, countList.get(center) + 1);
            double[] sums = sumList.get(center);
            for (int column = 0; column < this.c; column++) {
                Double value = values[column];
                // null entries contribute nothing to the sum.
                if (value != null) {
                    sums[column] += value.doubleValue();
                }
            }
        }

        if (LOG.isDebugEnabled()) {
            LOG.debug("sumList:{}", formatPoints(sumList));
            LOG.debug("countList:{}", countList);
        }

        KMeansWorkerParams workerResult = new KMeansWorkerParams();
        workerResult.setK(this.k);
        workerResult.setC(this.c);
        workerResult.setFirstIteration(false);
        workerResult.setPointList(sumList);
        workerResult.setCountList(countList);
        return workerResult;
    }

    /**
     * Renders a list of points with the array contents spelled out, for log output.
     *
     * @param points the points to render, may be {@code null}
     * @return a printable form of {@code points}
     */
    private static String formatPoints(List<double[]> points) {
        // List#toString would print array identities ("[D@...") instead of the values.
        return points == null ? "null" : java.util.Arrays.deepToString(points.toArray());
    }

    /**
     * Called after all records are loaded; switches the state of the record list.
     * Presumably this moves the list from appending to reading - confirm against MemoryDiskList.
     *
     * @param workerContext context of the worker (not used)
     */
    protected void postLoad(WorkerContext<KMeansMasterParams, KMeansWorkerParams> workerContext) {
        this.dataList.switchState();
    }

    /**
     * Returns the index of the center with the smallest {@link #distance} to the record; ties
     * keep the earliest center.
     *
     * @param record the record values, entries may be {@code null}
     * @param centers the candidate centers, must not be empty
     * @return index into {@code centers}
     */
    private int findClosedCenter(Double[] record, List<double[]> centers) {
        int closest = 0;
        double minDistance = distance(record, centers.get(0));
        for (int i = 1; i < centers.size(); i++) {
            double current = distance(record, centers.get(i));
            if (current < minDistance) {
                closest = i;
                // Track the running minimum so later centers are compared against the best one
                // seen so far, not against the first center.
                minDistance = current;
            }
        }
        return closest;
    }

    /**
     * Computes {@code dot(record, center) / (|record| * |center|)} over the columns of
     * {@code center}, treating {@code null} record entries as {@code 0}.
     *
     * <p>The result is {@code NaN} when either vector has zero norm.
     *
     * <p>NOTE(review): this value is a cosine similarity (larger means more alike), yet
     * {@link #findClosedCenter} selects the smallest value - confirm the intended metric.
     *
     * @param record the record values, entries may be {@code null}
     * @param center the center to compare against
     * @return the ratio described above
     */
    private double distance(Double[] record, double[] center) {
        double dotProduct = 0.0d;
        double recordNormSquared = 0.0d;
        double centerNormSquared = 0.0d;
        for (int i = 0; i < center.length; i++) {
            Double boxed = record[i];
            if (boxed != null) {
                double value = boxed.doubleValue();
                dotProduct += value * center[i];
                recordNormSquared += value * value;
            }
            centerNormSquared += center[i] * center[i];
        }
        return dotProduct / (Math.sqrt(recordNormSquared) * Math.sqrt(centerNormSquared));
    }

    /**
     * Parses one input line into a record of {@code c} columns and appends it to the record list.
     *
     * <p>The line is split on the configured separator. A field that is not a valid number is
     * stored as {@code null} in its own column; columns missing from a short line stay
     * {@code null}; fields beyond the {@code c}-th are ignored.
     *
     * @param guaguaWritableAdapter key of the current line (not used)
     * @param guaguaWritableAdapter2 value holding the text of the current line
     * @param workerContext context of the worker (not used)
     */
    public void load(GuaguaWritableAdapter<LongWritable> guaguaWritableAdapter, GuaguaWritableAdapter<Text> guaguaWritableAdapter2, WorkerContext<KMeansMasterParams, KMeansWorkerParams> workerContext) {
        String line = guaguaWritableAdapter2.getWritable().toString();
        Double[] record = new Double[this.c];
        int column = 0;
        for (String field : Splitter.on(this.separator).split(line)) {
            if (column >= this.c) {
                // Surplus fields would overflow the record; drop them like other bad input.
                break;
            }
            Double value;
            try {
                value = Double.valueOf(field);
            } catch (NumberFormatException ignored) {
                // Unparseable field: keep the column but mark it as missing.
                value = null;
            }
            // Advance exactly once per field so later columns keep their positions.
            record[column++] = value;
        }
        this.dataList.append(new TaggedRecord(record));
    }

    // NOTE(review): the two methods below look like compiler bridge methods that the decompiler
    // wrote out explicitly; they only delegate to the typed overloads above. Kept unchanged.
    public /* bridge */ /* synthetic */ void load(Bytable bytable, Bytable bytable2, WorkerContext workerContext) {
        load((GuaguaWritableAdapter<LongWritable>) bytable, (GuaguaWritableAdapter<Text>) bytable2, (WorkerContext<KMeansMasterParams, KMeansWorkerParams>) workerContext);
    }

    /* renamed from: doCompute, reason: collision with other method in class */
    public /* bridge */ /* synthetic */ Bytable m5doCompute(WorkerContext workerContext) {
        return doCompute((WorkerContext<KMeansMasterParams, KMeansWorkerParams>) workerContext);
    }
}
