package multisab.processing.machineLearning;

import multisab.processing.analysis.ExtractedFeatures;
import multisab.processing.analysis.ModelingMethod;
import multisab.processing.analysis.PerformModeling;
import multisab.processing.machineLearning.RFcore.*;
import multisab.processing.multisabException.ModelingException;


import java.io.File;
import java.util.*;





/**
 * Ad-hoc driver for manually exercising the random forest implementation on an
 * extracted-feature CSV file.
 *
 * <p>{@link #run(String[])} performs a holdout split with {@code ModelEvaluation.holdout},
 * converts both halves into RFcore {@link Dataset}s, runs a {@link RandomForest} and prints
 * the resulting accuracy. {@link #main(String[])} instead delegates to
 * {@link PerformModeling#executeModeling}.
 */
public class RFtesting {

    /** Feature-vector CSV used when no path is supplied to {@link #run(String[])}. */
    private static final String DEFAULT_FEATURES_FILE =
            "E:\\Nikolina\\FER\\MUTLISAB\\MULTISAB\\OutputFeatureVectors_1.csv";

    /** Output path passed to {@link PerformModeling#executeModeling} by {@link #main(String[])}. */
    private static final String DEFAULT_OUTPUT_FILE =
            "E:\\Nikolina\\FER\\MUTLISAB\\MULTISAB\\RF_test_1.txt";

    // Random forest hyperparameters shared by run() and main(); they are passed to the
    // algorithm as strings in this order: maxDepth, minSize, sampleSize, nTrees.
    private static final int MAX_DEPTH = 10000;
    private static final int MIN_SIZE = 1;
    private static final double SAMPLE_SIZE = 1.0;
    private static final int N_TREES = 10;

    /**
     * Runs a holdout evaluation of the random forest and prints the accuracy to stdout.
     * Any exception is caught and its stack trace printed.
     *
     * @param args optional; when non-empty, {@code args[0]} is the feature CSV to read,
     *             otherwise {@link #DEFAULT_FEATURES_FILE} is used
     */
    public void run(String[] args) {
        try {
            String requested = (args != null && args.length > 0) ? args[0] : DEFAULT_FEATURES_FILE;
            String sourcePath = new File(requested).getAbsolutePath();
            double[][] features = ExtractedFeatures.readNumericExtractedFeatures(sourcePath, true);

            /*
                HOLDOUT (0.66 is presumably the training fraction — confirm in ModelEvaluation)
             */
            TrainTestSet trainTestSet = ModelEvaluation.holdout(features, 0.66, false,
                    System.currentTimeMillis(), true,
                    ExtractedFeatures.readExtractedDisorders(sourcePath), true);

            if (trainTestSet.trainSet.length == 0) {
                throw new IllegalStateException("Holdout split produced an empty training set");
            }
            // The class label is appended as the field right after the features, so its
            // column index is also the key of the label dictionary.
            int labelIndex = trainTestSet.trainSet[0].length;

            // create train dataset
            Dataset trainDS = new Dataset();
            if (trainDS.dictionaries.get(labelIndex) == null) {
                trainDS.dictionaries.put(labelIndex, new HashMap<Integer, String>());
            }
            Map dict = trainDS.dictionaries.get(labelIndex);
            for (int i = 0; i < trainTestSet.trainSet.length; i++) {
                ArrayList<Field> fields = new ArrayList<>();
                for (int j = 0; j < trainTestSet.trainSet[i].length; j++) {
                    RealValField rvf = new RealValField(j);
                    rvf.setValue(trainTestSet.trainSet[i][j]);
                    fields.add(rvf);
                }
                // Encode the label as an integer id; an unseen label gets id dict.size().
                if (!dict.containsValue(trainTestSet.trainSetClassesLabels[i])) {
                    dict.put(dict.size(), trainTestSet.trainSetClassesLabels[i]);
                }
                CategoryField cf = new CategoryField(labelIndex);
                cf.setValue(Entry.getKey(dict, trainTestSet.trainSetClassesLabels[i]));
                fields.add(cf);
                Entry e = new Entry();
                e.setFields(fields);
                trainDS.addNewEntry(e);
            }

            // create test dataset
            Dataset testDS = new Dataset();
            // The test set must share the train set's dictionaries so that equal labels map to
            // equal ids; labels that only occur in the test set are appended to that dictionary.
            testDS.dictionaries = trainDS.dictionaries;
            for (int i = 0; i < trainTestSet.testSet.length; i++) {
                ArrayList<Field> fields = new ArrayList<>();
                for (int j = 0; j < trainTestSet.testSet[i].length; j++) {
                    RealValField rvf = new RealValField(j);
                    rvf.setValue(trainTestSet.testSet[i][j]);
                    fields.add(rvf);
                }
                if (!dict.containsValue(trainTestSet.testSetClassesLabels[i])) {
                    dict.put(dict.size(), trainTestSet.testSetClassesLabels[i]);
                }
                CategoryField cf = new CategoryField(labelIndex);
                cf.setValue(Entry.getKey(dict, trainTestSet.testSetClassesLabels[i]));
                fields.add(cf);
                Entry e = new Entry();
                e.setFields(fields);
                testDS.addNewEntry(e);
            }

            // Do random forest. The fifth parameter is floor(sqrt(#features)) — presumably the
            // number of features sampled per split; verify against RandomForest.
            int nrFeatures = (int) Math.sqrt(labelIndex);
            List<String> params = new ArrayList<>(Arrays.asList(
                    String.valueOf(MAX_DEPTH),
                    String.valueOf(MIN_SIZE),
                    String.valueOf(SAMPLE_SIZE),
                    String.valueOf(N_TREES),
                    String.valueOf(nrFeatures)));

            Algorithm rf = new RandomForest(params);
            List<Object> predictedRes = new ArrayList<>();
            predictedRes.addAll(rf.execute(trainDS, testDS));

            double accuracy = accuracyMetric(testDS, predictedRes);
            System.out.println(accuracy);

        } catch (Exception ex) {
            ex.printStackTrace();
        }
    }

    /**
     * Computes classification accuracy: the fraction of entries in {@code actual} whose label
     * (taken from the last field position of the first entry) equals the prediction at the
     * same index.
     *
     * @param actual    dataset holding the true labels as {@code Integer} field values
     * @param predicted predicted label ids as {@code Integer}s, in the same order as the entries
     *                  of {@code actual}; extra trailing predictions are ignored
     * @return accuracy in {@code [0, 1]}
     * @throws IllegalArgumentException if {@code actual} has no entries or there are fewer
     *                                  predictions than entries
     */
    double accuracyMetric(Dataset actual, List<Object> predicted) {
        int total = actual.getEntries().size();
        if (total == 0) {
            throw new IllegalArgumentException("Cannot compute accuracy of an empty dataset");
        }
        if (predicted.size() < total) {
            throw new IllegalArgumentException("Expected at least " + total
                    + " predictions but got " + predicted.size());
        }
        int resIndex = actual.getEntries().get(0).getFields().size() - 1;

        int correct = 0;
        for (int i = 0; i < total; i++) {
            Integer actValue = (Integer) actual.getEntries().get(i).getFields().get(resIndex).getValue();
            if (actValue.intValue() == ((Integer) predicted.get(i)).intValue()) {
                correct++;
            }
        }
        return (double) correct / total;
    }

    /**
     * Runs the full modeling pipeline via {@link PerformModeling#executeModeling} with the
     * random forest hyperparameters above, reading {@link #DEFAULT_FEATURES_FILE} and writing
     * to {@link #DEFAULT_OUTPUT_FILE}. To run the standalone holdout evaluation instead, call
     * {@code new RFtesting().run(args)}.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        String[] params = {
                String.valueOf(MAX_DEPTH),
                String.valueOf(MIN_SIZE),
                String.valueOf(SAMPLE_SIZE),
                String.valueOf(N_TREES)
        };
        String[] method = {ModelingMethod.EXT};

        try {
            PerformModeling.executeModeling(DEFAULT_FEATURES_FILE, null, null, DEFAULT_OUTPUT_FILE,
                    "Holdout", "66", 1, false, true, 1, method, params, false, false, false);
        } catch (ModelingException e) {
            e.printStackTrace();
        }
    }
}
