/*
 * Decompiled with CFR 0.152.
 */
package dev.brachtendorf.clustering;

import com.github.kilianB.pcg.fast.PcgRSFast;
import dev.brachtendorf.ArrayUtil;
import dev.brachtendorf.clustering.ClusterAlgorithm;
import dev.brachtendorf.clustering.ClusterResult;
import dev.brachtendorf.clustering.distance.DistanceFunction;
import dev.brachtendorf.clustering.distance.EuclideanDistance;
import java.util.DoubleSummaryStatistics;
import java.util.Objects;

/**
 * K-means clustering: partitions data points into {@code k} clusters by repeatedly assigning every
 * point to its nearest cluster center and then rebuilding each center's per-dimension statistics
 * from the points assigned to it, until a full pass changes no assignment.
 *
 * <p>Initial centers are drawn at random inside the per-dimension bounding box of the data using a
 * freshly constructed {@link PcgRSFast} generator (presumably randomly seeded), so results may
 * differ between runs. The iteration count of the last run is kept in an unsynchronized field.
 */
public class KMeans implements ClusterAlgorithm {

    /** Number of clusters to compute. */
    protected int k;
    /** Distance between a cluster center (per-dimension statistics) and a data point. */
    protected DistanceFunction distanceFunction;
    /** Number of assignment/update rounds performed by the last {@link #computeKMeans} run. */
    protected int lastIterationCount;

    /**
     * Creates a k-means clusterer that uses {@link EuclideanDistance}.
     *
     * @param clusters number of clusters to compute; must be at least 1 when clustering
     */
    public KMeans(int clusters) {
        this(clusters, new EuclideanDistance());
    }

    /**
     * Creates a k-means clusterer with a custom distance function.
     *
     * @param clusters         number of clusters to compute; must be at least 1 when clustering
     * @param distanceFunction distance between a cluster center and a data point
     * @throws NullPointerException if {@code distanceFunction} is null
     */
    public KMeans(int clusters, DistanceFunction distanceFunction) {
        this.k = clusters;
        this.distanceFunction = Objects.requireNonNull(distanceFunction, "distanceFunction");
    }

    /**
     * Assigns each row of {@code data} to one of {@code k} clusters.
     *
     * @param data data points, one row per point; every row is assumed to have at least as many
     *             entries as the first row
     * @return the cluster index (0 .. k-1) of every data point, together with the data
     * @throws IllegalStateException    if {@code k} is less than 1
     * @throws IllegalArgumentException if {@code k != 1} and {@code k >= data.length}, or if a data
     *                                  point cannot be assigned to any cluster center
     */
    @Override
    public ClusterResult cluster(double[][] data) {
        if (this.k < 1) {
            throw new IllegalStateException("Number of clusters must be at least 1 but was " + this.k);
        }
        // Java zero-initializes int arrays, so every point already starts in cluster 0.
        int[] cluster = new int[data.length];
        if (this.k == 1) {
            return new ClusterResult(cluster, data);
        }
        if (this.k >= data.length) {
            throw new IllegalArgumentException("Can't compute more clusters than datapoints are present");
        }
        int dataDimension = data[0].length;
        DoubleSummaryStatistics[][] clusterMeans = this.computeStartingClusters(data, this.k, dataDimension);
        this.computeKMeans(clusterMeans, data, cluster, dataDimension);
        return new ClusterResult(cluster, data);
    }

    /**
     * Creates {@code k} random starting centers. Each coordinate is drawn independently within the
     * [min, max] range observed for that dimension over all data points.
     *
     * @param data          data points, one row per point
     * @param k             number of centers to create
     * @param dataDimension number of coordinates per point
     * @return {@code k x dataDimension} statistics, each holding one random starting coordinate
     */
    protected DoubleSummaryStatistics[][] computeStartingClusters(double[][] data, int k, int dataDimension) {
        PcgRSFast rng = new PcgRSFast();

        // Bounding box of the data: min/max per dimension across all points.
        double[] min = new double[dataDimension];
        double[] max = new double[dataDimension];
        for (int d = 0; d < dataDimension; ++d) {
            min[d] = Double.MAX_VALUE;
            max[d] = -Double.MAX_VALUE;
        }
        for (double[] point : data) {
            for (int d = 0; d < dataDimension; ++d) {
                double value = point[d];
                if (value < min[d]) {
                    min[d] = value;
                }
                if (value > max[d]) {
                    max[d] = value;
                }
            }
        }

        DoubleSummaryStatistics[][] clusterMeans = new DoubleSummaryStatistics[k][dataDimension];
        ArrayUtil.fillArrayMulti(clusterMeans, () -> new DoubleSummaryStatistics());
        for (int c = 0; c < k; ++c) {
            for (int d = 0; d < dataDimension; ++d) {
                // Assumes nextDouble() is uniform in [0, 1) - TODO confirm for PcgRSFast.
                clusterMeans[c][d].accept(rng.nextDouble() * (max[d] - min[d]) + min[d]);
            }
        }
        return clusterMeans;
    }

    /**
     * Runs the k-means iteration in place until a full assignment pass changes nothing.
     *
     * @param clusterMeans  per-cluster, per-dimension center statistics; rebuilt in place
     * @param data          data points, one row per point
     * @param cluster       current cluster index of every point; updated in place
     * @param dataDimension number of coordinates per point
     * @throws IllegalArgumentException if a point has no distance below {@link Double#MAX_VALUE}
     *                                  to any center (e.g. NaN data)
     */
    protected void computeKMeans(DoubleSummaryStatistics[][] clusterMeans, double[][] data, int[] cluster,
            int dataDimension) {
        this.lastIterationCount = 0;
        boolean dirty;
        do {
            dirty = false;
            // Assignment step: move every point to its nearest center.
            for (int dataIndex = 0; dataIndex < data.length; ++dataIndex) {
                int bestCluster = nearestCluster(clusterMeans, data[dataIndex]);
                if (bestCluster < 0) {
                    throw new IllegalArgumentException("Data point " + dataIndex
                            + " could not be assigned to any cluster (all distances NaN or >= Double.MAX_VALUE)");
                }
                if (cluster[dataIndex] != bestCluster) {
                    cluster[dataIndex] = bestCluster;
                    dirty = true;
                }
            }
            // Update step: rebuild each center's statistics from its assigned points. Skipped when
            // nothing moved, because the loop terminates in that case.
            // NOTE(review): a cluster that receives no points is left with empty statistics; how the
            // DistanceFunction treats an empty center is not visible here - confirm.
            if (dirty) {
                ArrayUtil.fillArrayMulti(clusterMeans, () -> new DoubleSummaryStatistics());
                for (int dataIndex = 0; dataIndex < data.length; ++dataIndex) {
                    double[] point = data[dataIndex];
                    DoubleSummaryStatistics[] center = clusterMeans[cluster[dataIndex]];
                    for (int d = 0; d < dataDimension; ++d) {
                        center[d].accept(point[d]);
                    }
                }
            }
            ++this.lastIterationCount;
        } while (dirty);
    }

    /**
     * Returns the index of the center closest to {@code point}, or -1 when no distance is smaller
     * than {@link Double#MAX_VALUE} (e.g. every distance is NaN). Ties go to the lowest index.
     */
    private int nearestCluster(DoubleSummaryStatistics[][] clusterMeans, double[] point) {
        double minDistance = Double.MAX_VALUE;
        int bestCluster = -1;
        for (int clusterIndex = 0; clusterIndex < this.k; ++clusterIndex) {
            double distance = this.distanceFunction.distance(clusterMeans[clusterIndex], point);
            if (distance < minDistance) {
                bestCluster = clusterIndex;
                minDistance = distance;
            }
        }
        return bestCluster;
    }

    /**
     * Returns how many assignment/update rounds the most recent {@link #computeKMeans} run performed,
     * including the final round in which no assignment changed. Not updated by the {@code k == 1}
     * shortcut in {@link #cluster(double[][])}.
     *
     * @return the iteration count of the last k-means run
     */
    public int iterations() {
        return this.lastIterationCount;
    }
}

