package multisab.processing.machineLearning;

import multisab.processing.multisabException.ModelingException;

import java.io.BufferedWriter;
import java.io.FileWriter;
import java.io.IOException;

/**
 * Evaluation metrics for one classifier run, derived from actual vs. predicted class labels.
 *
 * <p>Call {@link #constructConfusionMatrixAndDetermineMetrics()} before reading any metric or
 * saving the report. Until then the per-class getters return {@code null` and the scalar
 * getters return {@code 0.0}.
 *
 * <p>Per-class metrics are computed one-vs-rest from each class's TP/TN/FP/FN counts. The
 * "average" metrics are unweighted (macro) means over classes. This class treats confusion
 * matrix rows as actual classes and columns as predicted classes.
 */
public class EvaluationMeasures {
    private ConfusionMatrix cm;
    private int[] truePositivesPerClass;
    private int[] trueNegativesPerClass;
    private int[] falsePositivesPerClass;
    private int[] falseNegativesPerClass;
    private double[] accuracyTestSetPerClass;
    private double[] sensitivityTestSetPerClass;
    private double[] specificityTestSetPerClass;
    private double[] positivePredictiveValueTestSetPerClass;
    private double[] negativePredictiveValueTestSetPerClass;
    private double[] F1ScoreTestSetPerClass;
    private double[] DERTestSetPerClass;
    private double[] MCCTestSetPerClass;
    private double overallAccuracyTestSet;
    private double averageAccuracyTestSet;
    private double averageSensitivityTestSet;
    private double averageSpecificityTestSet;
    private double averagePositivePredictiveValueTestSet;
    private double averageNegativePredictiveValueTestSet;
    private double averageF1ScoreTestSet;
    private double averageDERTestSet;
    private double averageMCCTestSet;
    private final String[] actualClasses;
    private final String[] predictedClasses;
    private final String[] allUniqueClasses;
    private final String classifierName;

    /**
     * Creates an evaluation for a single test set.
     *
     * @param classifierName name written at the top of the saved report
     * @param actualClasses true class label of each test sample
     * @param predictedClasses predicted label of each test sample, aligned with {@code actualClasses}
     * @param allUniqueClasses all class labels; passed on to {@link ConfusionMatrix}
     */
    public EvaluationMeasures(String classifierName, String[] actualClasses, String[] predictedClasses, String[] allUniqueClasses) {
        this.actualClasses = actualClasses;
        this.predictedClasses = predictedClasses;
        this.classifierName = classifierName;
        this.allUniqueClasses = allUniqueClasses;
    }

    /**
     * Creates an evaluation that pools the results of several folds, e.g. from k-fold
     * cross-validation. The folds are concatenated in order, and they may differ in size.
     *
     * @param classifierName name written at the top of the saved report
     * @param foldsActualClasses true labels, one array per fold
     * @param foldsPredictedClasses predicted labels, one array per fold, aligned with {@code foldsActualClasses}
     * @param allUniqueClasses all class labels; passed on to {@link ConfusionMatrix}
     */
    public EvaluationMeasures(String classifierName, String[][] foldsActualClasses, String[][] foldsPredictedClasses, String[] allUniqueClasses) {
        this.classifierName = classifierName;
        this.allUniqueClasses = allUniqueClasses;
        // Concatenate each fold by its own length instead of assuming every fold is as long as
        // the first one. Uneven folds arise whenever the sample count is not a multiple of k.
        this.predictedClasses = flattenFolds(foldsPredictedClasses);
        this.actualClasses = flattenFolds(foldsActualClasses);
    }

    /** Concatenates the given folds, in order, into a single array. */
    private static String[] flattenFolds(String[][] folds) {
        int total = 0;
        for (String[] fold : folds) {
            total += fold.length;
        }
        String[] flat = new String[total];
        int offset = 0;
        for (String[] fold : folds) {
            System.arraycopy(fold, 0, flat, offset, fold.length);
            offset += fold.length;
        }
        return flat;
    }

    /**
     * Builds the confusion matrix from the stored labels and computes every per-class, averaged
     * and overall metric.
     *
     * <p>When a ratio's denominator is zero, sensitivity, specificity, PPV, NPV, F1 and DER are
     * set to {@code 1.0} and MCC is set to {@code 0.0}.
     *
     * @throws ModelingException propagated from {@link ConfusionMatrix}
     */
    public void constructConfusionMatrixAndDetermineMetrics() throws ModelingException {
        cm = new ConfusionMatrix(this.actualClasses, this.predictedClasses, this.allUniqueClasses);

        int[][] classifications = cm.getClassifications();
        int classCount = cm.getClassNames().length;
        int matrixSize = classifications.length;

        truePositivesPerClass = new int[classCount];
        trueNegativesPerClass = new int[classCount];
        falsePositivesPerClass = new int[classCount];
        falseNegativesPerClass = new int[classCount];

        int sampleCount = 0;
        int correctCount = 0;
        for (int actual = 0; actual < matrixSize; actual++) {
            for (int predicted = 0; predicted < matrixSize; predicted++) {
                int count = classifications[actual][predicted];
                sampleCount += count;
                if (actual == predicted) {
                    truePositivesPerClass[actual] = count;
                    correctCount += count;
                } else {
                    falseNegativesPerClass[actual] += count;
                    falsePositivesPerClass[predicted] += count;
                }
            }
        }
        // A class's TN is every sample that is neither actually nor predicted as that class.
        // This includes confusions between two *other* classes, not only their diagonal cells.
        for (int c = 0; c < classCount; c++) {
            trueNegativesPerClass[c] = sampleCount - truePositivesPerClass[c]
                    - falsePositivesPerClass[c] - falseNegativesPerClass[c];
        }
        overallAccuracyTestSet = (double) correctCount / sampleCount;

        accuracyTestSetPerClass = new double[classCount];
        sensitivityTestSetPerClass = new double[classCount];
        specificityTestSetPerClass = new double[classCount];
        positivePredictiveValueTestSetPerClass = new double[classCount];
        negativePredictiveValueTestSetPerClass = new double[classCount];
        F1ScoreTestSetPerClass = new double[classCount];
        DERTestSetPerClass = new double[classCount];
        MCCTestSetPerClass = new double[classCount];

        for (int c = 0; c < classCount; c++) {
            int tp = truePositivesPerClass[c];
            int tn = trueNegativesPerClass[c];
            int fp = falsePositivesPerClass[c];
            int fn = falseNegativesPerClass[c];

            accuracyTestSetPerClass[c] = (double) (tp + tn) / (tp + tn + fp + fn);
            sensitivityTestSetPerClass[c] = ratioOrOne(tp, tp + fn);
            specificityTestSetPerClass[c] = ratioOrOne(tn, tn + fp);
            positivePredictiveValueTestSetPerClass[c] = ratioOrOne(tp, tp + fp);
            negativePredictiveValueTestSetPerClass[c] = ratioOrOne(tn, tn + fn);
            F1ScoreTestSetPerClass[c] = ratioOrOne(2 * tp, 2 * tp + fp + fn);
            DERTestSetPerClass[c] = ratioOrOne(fp + fn, tp + fn);
            MCCTestSetPerClass[c] = matthewsCorrelation(tp, tn, fp, fn);
        }

        averageAccuracyTestSet = mean(accuracyTestSetPerClass);
        averageSensitivityTestSet = mean(sensitivityTestSetPerClass);
        averageSpecificityTestSet = mean(specificityTestSetPerClass);
        averagePositivePredictiveValueTestSet = mean(positivePredictiveValueTestSetPerClass);
        averageNegativePredictiveValueTestSet = mean(negativePredictiveValueTestSetPerClass);
        averageF1ScoreTestSet = mean(F1ScoreTestSetPerClass);
        averageDERTestSet = mean(DERTestSetPerClass);
        averageMCCTestSet = mean(MCCTestSetPerClass);
    }

    /** Returns {@code numerator / denominator}, or {@code 1.0} when the denominator is zero. */
    private static double ratioOrOne(int numerator, int denominator) {
        return denominator == 0 ? 1.0 : (double) numerator / denominator;
    }

    /** Matthews correlation coefficient for one class; {@code 0.0} when undefined. */
    private static double matthewsCorrelation(int tp, int tn, int fp, int fn) {
        // Widen to double before multiplying: the int products overflow for moderately large test sets.
        double denominator = Math.sqrt((double) (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn));
        if (denominator == 0.0) {
            return 0.0;
        }
        return ((double) tp * tn - (double) fp * fn) / denominator;
    }

    /** Unweighted arithmetic mean of the values (NaN for an empty array). */
    private static double mean(double[] values) {
        double sum = 0.0;
        for (double value : values) {
            sum += value;
        }
        return sum / values.length;
    }

    /** @return true positives per class, in the order of {@code ConfusionMatrix.getClassNames()} */
    public int[] getTruePositivesPerClass() {
        return this.truePositivesPerClass;
    }

    /** @return true negatives per class (one-vs-rest) */
    public int[] getTrueNegativesPerClass() {
        return this.trueNegativesPerClass;
    }

    /** @return false negatives per class */
    public int[] getFalseNegativesPerClass() {
        return this.falseNegativesPerClass;
    }

    /** @return false positives per class */
    public int[] getFalsePositivesPerClass() {
        return this.falsePositivesPerClass;
    }

    /** @return one-vs-rest accuracy per class */
    public double[] getAccuracyTestSetPerClass() {
        return this.accuracyTestSetPerClass;
    }

    /** @return sensitivity (recall) per class */
    public double[] getSensitivityTestSetPerClass() {
        return this.sensitivityTestSetPerClass;
    }

    /** @return specificity per class */
    public double[] getSpecificityTestSetPerClass() {
        return this.specificityTestSetPerClass;
    }

    /** @return positive predictive value (precision) per class */
    public double[] getPositivePredictiveValueTestSetPerClass() {
        return this.positivePredictiveValueTestSetPerClass;
    }

    /** @return negative predictive value per class */
    public double[] getNegativePredictiveValueTestSetPerClass() {
        return this.negativePredictiveValueTestSetPerClass;
    }

    /** @return F1 score per class */
    public double[] getF1ScoreTestSetPerClass() {
        return this.F1ScoreTestSetPerClass;
    }

    /** @return (FP + FN) / (TP + FN) per class */
    public double[] getDERTestSetPerClass() {
        return this.DERTestSetPerClass;
    }

    /** @return Matthews correlation coefficient per class (0.0 where undefined) */
    public double[] getMCCTestSetPerClass() {
        return this.MCCTestSetPerClass;
    }

    /** @return fraction of all samples whose predicted class equals the actual class */
    public double getOverallAccuracyTestSet() {
        return this.overallAccuracyTestSet;
    }

    /** @return macro-averaged one-vs-rest accuracy */
    public double getAverageAccuracyTestSet() {
        return this.averageAccuracyTestSet;
    }

    /** @return macro-averaged sensitivity */
    public double getAverageSensitivityTestSet() {
        return this.averageSensitivityTestSet;
    }

    /** @return macro-averaged specificity */
    public double getAverageSpecificityTestSet() {
        return this.averageSpecificityTestSet;
    }

    /** @return macro-averaged positive predictive value */
    public double getAveragePositivePredictiveValueTestSet() {
        return this.averagePositivePredictiveValueTestSet;
    }

    /** @return macro-averaged negative predictive value */
    public double getAverageNegativePredictiveValueTestSet() {
        return this.averageNegativePredictiveValueTestSet;
    }

    /** @return macro-averaged F1 score */
    public double getAverageF1ScoreTestSet() {
        return this.averageF1ScoreTestSet;
    }

    /** @return macro-averaged DER */
    public double getAverageDERTestSet() {
        return this.averageDERTestSet;
    }

    /** @return macro-averaged Matthews correlation coefficient */
    public double getAverageMCCTestSet() {
        return this.averageMCCTestSet;
    }

    /**
     * Writes a plain-text report to {@code filename}. The report contains the evaluation
     * settings, the confusion matrix, per-class counts and metrics, the averaged metrics, and
     * every (actual, predicted) label pair. Does nothing when {@code filename} is {@code null}.
     * MCC is available through the getters but is not part of this report.
     *
     * @param filename target file, or {@code null} to skip saving
     * @param append {@code true} to append to an existing file instead of overwriting it
     * @param evaluationMethod method name; {@code "Holdout"} and {@code "kFoldCrossValidation"}
     *        each add a line with their parameter
     * @param holdoutPercentage training-set percentage, written only for {@code "Holdout"}
     * @param kFolds number of folds, written only for {@code "kFoldCrossValidation"}
     * @param first {@code true} if in-order sampling was used, {@code false} for random sampling
     * @throws ModelingException if the file cannot be written (the I/O error is attached as cause)
     * @throws IllegalStateException if the metrics have not been computed yet
     */
    public void saveClassificationResultsToFile(String filename, boolean append, String evaluationMethod, String holdoutPercentage, int kFolds, boolean first) throws ModelingException {
        if (filename == null) {
            return;
        }
        // Check before opening the file so a misuse does not truncate an existing report.
        if (cm == null) {
            throw new IllegalStateException(
                    "constructConfusionMatrixAndDetermineMetrics() must be called before saving results");
        }
        // try-with-resources closes the file even when a write fails part-way.
        try (BufferedWriter bw = new BufferedWriter(new FileWriter(filename, append))) {
            writeLine(bw, this.classifierName);
            writeLine(bw, evaluationMethod);
            if (evaluationMethod.equals("Holdout")) {
                writeLine(bw, "Training set percentage included in holdout: " + holdoutPercentage);
            } else if (evaluationMethod.equals("kFoldCrossValidation")) {
                writeLine(bw, "Number of folds: " + kFolds);
            }
            writeLine(bw, first ? "In order sampling was used." : "Random sampling was used.");

            writeLine(bw, join(cm.getClassNames()));
            for (int[] row : cm.getClassifications()) {
                StringBuilder line = new StringBuilder();
                for (int count : row) {
                    line.append(count).append(' ');
                }
                writeLine(bw, line.toString());
            }
            bw.newLine();

            writeLine(bw, "TP per class: " + join(truePositivesPerClass));
            writeLine(bw, "TN per class: " + join(trueNegativesPerClass));
            writeLine(bw, "FP per class: " + join(falsePositivesPerClass));
            writeLine(bw, "FN per class: " + join(falseNegativesPerClass));
            // Overall accuracy; for two classes this equals either class's one-vs-rest accuracy.
            writeLine(bw, "ACC: " + overallAccuracyTestSet);
            writeLine(bw, "SENS per class: " + join(sensitivityTestSetPerClass));
            writeLine(bw, "SPEC per class: " + join(specificityTestSetPerClass));
            writeLine(bw, "PPV per class: " + join(positivePredictiveValueTestSetPerClass));
            writeLine(bw, "NPV per class: " + join(negativePredictiveValueTestSetPerClass));
            writeLine(bw, "F1 per class: " + join(F1ScoreTestSetPerClass));
            writeLine(bw, "DER per class: " + join(DERTestSetPerClass));

            writeLine(bw, "Mean ACC: " + averageAccuracyTestSet);
            writeLine(bw, "Mean SENS: " + averageSensitivityTestSet);
            writeLine(bw, "Mean SPEC: " + averageSpecificityTestSet);
            writeLine(bw, "Mean PPV: " + averagePositivePredictiveValueTestSet);
            writeLine(bw, "Mean NPV: " + averageNegativePredictiveValueTestSet);
            writeLine(bw, "Mean F1: " + averageF1ScoreTestSet);
            writeLine(bw, "Mean DER: " + averageDERTestSet);

            writeLine(bw, "Actual class   Predicted class");
            for (int i = 0; i < actualClasses.length; i++) {
                writeLine(bw, actualClasses[i] + " " + predictedClasses[i]);
            }
        } catch (IOException exc) {
            ModelingException failure = new ModelingException("Classification results cannot be saved to file");
            failure.initCause(exc); // keep the underlying I/O error for diagnosis
            throw failure;
        }
    }

    /** Writes {@code text} followed by the platform line separator. */
    private static void writeLine(BufferedWriter bw, String text) throws IOException {
        bw.write(text);
        bw.newLine();
    }

    /** Joins the values with "," and no spaces; an empty array yields "". */
    private static String join(String[] values) {
        StringBuilder joined = new StringBuilder();
        for (int k = 0; k < values.length; k++) {
            if (k > 0) {
                joined.append(',');
            }
            joined.append(values[k]);
        }
        return joined.toString();
    }

    /** Joins the values with "," and no spaces; an empty array yields "". */
    private static String join(int[] values) {
        StringBuilder joined = new StringBuilder();
        for (int k = 0; k < values.length; k++) {
            if (k > 0) {
                joined.append(',');
            }
            joined.append(values[k]);
        }
        return joined.toString();
    }

    /** Joins the values with "," and no spaces; an empty array yields "". */
    private static String join(double[] values) {
        StringBuilder joined = new StringBuilder();
        for (int k = 0; k < values.length; k++) {
            if (k > 0) {
                joined.append(',');
            }
            joined.append(values[k]);
        }
        return joined.toString();
    }
}
