/*
 * Decompiled with CFR 0.152.
 */
package com.github.chen0040.data.evaluators;

import com.github.chen0040.data.evaluators.ConfusionMatrix;
import com.github.chen0040.data.utils.NumberUtils;
import java.io.Serializable;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Accumulates (actual, predicted) label pairs in a {@link ConfusionMatrix} and derives
 * classification metrics from it: accuracy, per-class precision / recall / fallout / F1,
 * and aggregate F1 scores.
 *
 * <p>Throughout this class, {@code getCount(actual, predicted)} is the number of instances
 * with the given actual and predicted labels. Row sums are treated as totals per actual label
 * and column sums as totals per predicted label.
 *
 * <p>This class contains no synchronization; callers sharing an instance across threads must
 * synchronize externally.
 */
public class ClassifierEvaluator
implements Serializable {
    private static final long serialVersionUID = -6691826271325237852L;
    private ConfusionMatrix confusionMatrix = new ConfusionMatrix();

    /**
     * Records one classification outcome.
     *
     * @param actual    the true label of the instance
     * @param predicted the label the classifier produced
     */
    public void evaluate(String actual, String predicted) {
        this.confusionMatrix.incCount(actual, predicted);
    }

    /**
     * Returns the labels known to the underlying confusion matrix.
     *
     * @return the labels reported by {@link ConfusionMatrix#getLabels()}
     */
    public List<String> classLabels() {
        return this.confusionMatrix.getLabels();
    }

    /** Clears all recorded outcomes by resetting the underlying confusion matrix. */
    public void reset() {
        this.confusionMatrix.reset();
    }

    /**
     * Returns the confusion matrix backing this evaluator (not a copy).
     *
     * @return the live confusion matrix
     */
    public ConfusionMatrix getConfusionMatrix() {
        return this.confusionMatrix;
    }

    /**
     * Replaces the confusion matrix backing this evaluator.
     *
     * @param confusionMatrix the matrix to use from now on
     */
    public void setConfusionMatrix(ConfusionMatrix confusionMatrix) {
        this.confusionMatrix = confusionMatrix;
    }

    /**
     * Computes the fraction of recorded instances whose predicted label equals the actual label.
     *
     * @return correct / total, or 0.0 if nothing has been recorded
     */
    public double getAccuracy() {
        List<String> labels = this.confusionMatrix.getLabels();
        // long accumulators: summing many int cells could overflow an int
        long correctCount = 0L;
        long totalCount = 0L;
        for (String actual : labels) {
            correctCount += this.confusionMatrix.getCount(actual, actual);
            for (String predicted : labels) {
                totalCount += this.confusionMatrix.getCount(actual, predicted);
            }
        }
        if (totalCount == 0L) {
            return 0.0;
        }
        return (double) correctCount / (double) totalCount;
    }

    /**
     * Returns the misclassification rate.
     *
     * @return {@code 1.0 - getAccuracy()}
     */
    public double getMisclassificationRate() {
        return 1.0 - this.getAccuracy();
    }

    /**
     * Returns the number of instances of {@code classLabel} that were predicted as that label.
     *
     * @param classLabel the label of interest
     * @return the diagonal cell {@code getCount(classLabel, classLabel)}
     */
    public int getTruePositiveCount(String classLabel) {
        return this.confusionMatrix.getCount(classLabel, classLabel);
    }

    /**
     * Returns the number of instances predicted as {@code classLabel} whose actual label differs.
     *
     * @param classLabel the label of interest
     * @return the column sum for {@code classLabel} minus its true-positive count
     */
    public int getFalsePositiveCount(String classLabel) {
        return this.confusionMatrix.getColumnSum(classLabel) - this.getTruePositiveCount(classLabel);
    }

    /**
     * Returns the mean true-positive count over all labels.
     *
     * @return the average of {@link #getTruePositiveCount(String)}, or 0.0 when there are no labels
     */
    public double avgTruePositive() {
        List<String> labels = this.classLabels();
        if (labels.isEmpty()) {
            return 0.0;
        }
        long sum = 0L;
        for (String label : labels) {
            sum += this.getTruePositiveCount(label);
        }
        return (double) sum / (double) labels.size();
    }

    /**
     * Returns the mean false-positive count over all labels.
     *
     * @return the average of {@link #getFalsePositiveCount(String)}, or 0.0 when there are no labels
     */
    public double avgFalsePositive() {
        List<String> labels = this.classLabels();
        if (labels.isEmpty()) {
            return 0.0;
        }
        long sum = 0L;
        for (String label : labels) {
            sum += this.getFalsePositiveCount(label);
        }
        return (double) sum / (double) labels.size();
    }

    /**
     * Computes precision for every label: TP / (number of instances predicted as the label).
     *
     * @return a map from each label to its precision; 0.0 for labels that were never predicted
     */
    public Map<String, Double> getPrecisionByClass() {
        Map<String, Double> result = new HashMap<String, Double>();
        for (String label : this.classLabels()) {
            int correctCount = this.confusionMatrix.getCount(label, label);
            int totalPredictedCount = this.confusionMatrix.getColumnSum(label);
            double precision = 0.0;
            if (totalPredictedCount > 0) {
                precision = (double) correctCount / (double) totalPredictedCount;
            }
            result.put(label, precision);
        }
        return result;
    }

    /**
     * Computes recall for every label: TP / (number of instances whose actual label is it).
     *
     * @return a map from each label to its recall; 0.0 for labels with no actual instances
     */
    public Map<String, Double> getRecallByClass() {
        Map<String, Double> result = new HashMap<String, Double>();
        for (String label : this.classLabels()) {
            int correctCount = this.confusionMatrix.getCount(label, label);
            int totalTrueCount = this.confusionMatrix.getRowSum(label);
            double recall = 0.0;
            if (totalTrueCount > 0) {
                recall = (double) correctCount / (double) totalTrueCount;
            }
            result.put(label, recall);
        }
        return result;
    }

    /**
     * Computes the fallout (false-positive rate) for every label: the fraction of instances whose
     * actual label is something else that were nonetheless predicted as this label.
     *
     * @return a map from each label to its fallout; 0.0 when there are no negative instances
     */
    public Map<String, Double> getFalloutByClass() {
        Map<String, Double> result = new HashMap<String, Double>();
        List<String> labels = this.classLabels();
        for (int i = 0; i < labels.size(); ++i) {
            String label = labels.get(i);
            int totalNegativeCount = 0;
            int falsePositiveCount = 0;
            for (int j = 0; j < labels.size(); ++j) {
                if (i == j) {
                    continue; // skip the label's own row: those instances are positives
                }
                String notTrueLabel = labels.get(j);
                falsePositiveCount += this.confusionMatrix.getCount(notTrueLabel, label);
                totalNegativeCount += this.confusionMatrix.getRowSum(notTrueLabel);
            }
            double fallout = 0.0;
            if (totalNegativeCount > 0) {
                fallout = (double) falsePositiveCount / (double) totalNegativeCount;
            }
            result.put(label, fallout);
        }
        return result;
    }

    /**
     * Computes the F1 score (harmonic mean of precision and recall) for every label.
     *
     * @return a map containing every label; a label whose precision + recall is zero maps to 0.0
     */
    public Map<String, Double> getF1ScoreByClass() {
        Map<String, Double> precisions = this.getPrecisionByClass();
        Map<String, Double> recalls = this.getRecallByClass();
        Map<String, Double> result = new HashMap<String, Double>();
        for (String label : this.classLabels()) {
            double precision = precisions.get(label);
            double recall = recalls.get(label);
            double f1score = 0.0;
            // F1 is undefined when precision + recall == 0. Report 0.0 rather than dropping the
            // label, so getMacroF1Score still counts it instead of inflating the average.
            if (!NumberUtils.isZero(precision + recall)) {
                f1score = 2.0 * (precision * recall) / (precision + recall);
            }
            result.put(label, f1score);
        }
        return result;
    }

    /**
     * Computes the macro-averaged F1 score: the unweighted mean of the per-label F1 scores.
     *
     * @return the mean of {@link #getF1ScoreByClass()} values, or 0.0 when there are no labels
     */
    public double getMacroF1Score() {
        Map<String, Double> data = this.getF1ScoreByClass();
        if (data.isEmpty()) {
            return 0.0;
        }
        double sum = 0.0;
        for (Double f1 : data.values()) {
            sum += f1.doubleValue();
        }
        return sum / (double) data.size();
    }

    /**
     * Computes the harmonic mean of the label-averaged precision and the label-averaged recall.
     *
     * <p>NOTE(review): despite its name, this is not the standard micro-averaged F1 (which pools
     * TP/FP/FN across labels and, for single-label data, equals accuracy). The formula is kept
     * unchanged for backward compatibility.
     *
     * @return the score, or 0.0 when there are no labels or both averages are zero
     */
    public double getMicroF1Score() {
        List<String> labels = this.classLabels();
        if (labels.isEmpty()) {
            return 0.0; // avoid 0/0 = NaN
        }
        Map<String, Double> precisions = this.getPrecisionByClass();
        Map<String, Double> recalls = this.getRecallByClass();
        double precisionSum = 0.0;
        double recallSum = 0.0;
        for (String label : labels) {
            precisionSum += precisions.get(label);
            recallSum += recalls.get(label);
        }
        double precisionAvg = precisionSum / (double) labels.size();
        double recallAvg = recallSum / (double) labels.size();
        if (NumberUtils.isZero(precisionAvg + recallAvg)) {
            return 0.0; // harmonic mean undefined; match getF1ScoreByClass convention
        }
        return 2.0 * (precisionAvg * recallAvg) / (precisionAvg + recallAvg);
    }

    /**
     * Builds a human-readable multi-line summary of the aggregate metrics.
     *
     * @return accuracy, misclassification rate, macro F1 and "micro" F1, one per line
     */
    public String getSummary() {
        StringBuilder sb = new StringBuilder();
        sb.append("accuracy: ").append(this.getAccuracy());
        sb.append("\nmis-classification: ").append(this.getMisclassificationRate());
        sb.append("\nmacro f1-score: ").append(this.getMacroF1Score());
        sb.append("\nmicro f1-score: ").append(this.getMicroF1Score());
        return sb.toString();
    }

    /** Prints {@link #getSummary()} to standard output. */
    public void report() {
        System.out.println(this.getSummary());
    }
}

