package multisab.processing.machineLearning;

import org.encog.neural.networks.training.Train;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

public class ModelEvaluation {

    /**
     * Splits a dataset into one training partition and one test partition (holdout validation).
     *
     * <p>Every row placed in the result is a fresh copy of the corresponding input row; when
     * {@code includesOrdinalNumber} is set, column 0 is stripped from the copy.
     *
     * @param features Input dataset, one feature vector per row; must contain at least one row
     * @param percentageForTraining Number between 0 and 1, where 0 means that all data is used for testing, and 1 means that all data is used for training
     * @param first If set to true, then the dataset will be split taking the first M feature vectors for training, and N-M feature vectors for testing;
     *              otherwise the training vectors are selected pseudo-randomly
     * @param seed Seed of the pseudo-random generator used when {@code first} is false; the same seed yields the same split
     * @param last Currently unused (goal index handling is disabled)
     * @param classesLabels Class label lookup, see the note on label lookup inside the method
     * @param includesOrdinalNumber If true, column 0 of every row holds the feature vector's ordinal number; it is removed
     *                              from the output rows and reported as the test-set feature vector number
     * @return a {@code TrainTestSet} populated with the train set, test set, both label arrays and the test-set feature vector numbers
     * @throws IllegalArgumentException if {@code features} is empty or {@code percentageForTraining} yields a training size outside {@code [0, features.length]}
     */
    public static TrainTestSet holdout(double[][] features, double percentageForTraining, boolean first, long seed, boolean last, String [] classesLabels, boolean includesOrdinalNumber){
        if (features.length == 0) {
            throw new IllegalArgumentException("features must contain at least one feature vector");
        }
        int total = features.length;
        int trainCount = (int) (total * percentageForTraining);
        if (trainCount < 0 || trainCount > total) {
            throw new IllegalArgumentException(
                    "percentageForTraining must be between 0 and 1, got " + percentageForTraining);
        }
        int testCount = total - trainCount;

        int offset = includesOrdinalNumber ? 1 : 0;
        int width = features[0].length - offset;
        // NOTE(review): in random mode with an ordinal column the label is looked up by the
        // ordinal number; in every other mode it is looked up by row position. This mirrors the
        // long-standing behaviour -- confirm that ordinal numbers are valid indexes into classesLabels.
        boolean labelsByOrdinal = includesOrdinalNumber && !first;

        // order[0 .. trainCount) are the training rows, the remainder are the test rows.
        int[] order = selectionOrder(total, trainCount, first, includesOrdinalNumber, seed);

        double[][] trainDataset = new double[trainCount][];
        String[] trainSetClassesLabels = new String[trainCount];
        for (int i = 0; i < trainCount; i++) {
            int position = order[i];
            double[] row = features[position];
            trainDataset[i] = copyRow(row, offset, width);
            trainSetClassesLabels[i] = classesLabels[labelsByOrdinal ? (int) row[0] : position];
        }

        double[][] testDataset = new double[testCount][];
        String[] testSetClassesLabels = new String[testCount];
        int[] testSetFeatureVectorsNumber = new int[testCount];
        for (int i = 0; i < testCount; i++) {
            int position = order[trainCount + i];
            double[] row = features[position];
            testDataset[i] = copyRow(row, offset, width);
            testSetClassesLabels[i] = classesLabels[labelsByOrdinal ? (int) row[0] : position];
            testSetFeatureVectorsNumber[i] = includesOrdinalNumber ? (int) row[0] : position;
        }

        TrainTestSet trainTestSet = new TrainTestSet();
        trainTestSet.setTrainSet(trainDataset);
        trainTestSet.setTestSet(testDataset);
        trainTestSet.setTrainSetClassesLabels(trainSetClassesLabels);
        trainTestSet.setTestSetClassesLabels(testSetClassesLabels);
        trainTestSet.setTestSetFeatureVectorNumber(testSetFeatureVectorsNumber);
        return trainTestSet;
    }

    /**
     * Partitions a dataset into {@code kFolds} disjoint folds of {@code features.length / kFolds}
     * rows each (integer division) and registers every fold through {@code addFoldData}.
     *
     * <p>Rows that do not fit into a full fold are not part of any fold. The feature vector number
     * array passed to {@code setTestSetFeatureVectorNumber} has {@code features.length} entries:
     * slot {@code fold * foldSize + j} identifies row {@code j} of that fold, and the trailing slots
     * identify the rows left out of every fold.
     *
     * <p>Every row placed in a fold is a fresh copy of the corresponding input row; when
     * {@code includesOrdinalNumber} is set, column 0 is stripped from the copy.
     *
     * @param features Input dataset, one feature vector per row; must contain at least one row
     * @param kFolds Number of folds; must be positive
     * @param first If true, folds are consecutive slices of the dataset in its given order; otherwise rows are assigned pseudo-randomly
     * @param seed Seed of the pseudo-random generator used when {@code first} is false
     * @param last Currently unused (goal index handling is disabled)
     * @param goalIndex Currently unused (goal index handling is disabled)
     * @param classesLabels Class label lookup, see the note on label lookup inside the method
     * @param includesOrdinalNumber If true, column 0 of every row holds the feature vector's ordinal number; it is removed
     *                              from the fold rows and reported as the feature vector number
     * @return a {@code TrainTestSet} with the fold count, the data and labels of every fold, and the feature vector numbers set
     * @throws IllegalArgumentException if {@code kFolds} is not positive or {@code features} is empty
     */
    public static TrainTestSet kFoldCrossValidation(double[][] features, int kFolds, boolean first, long seed, boolean last, int goalIndex, String [] classesLabels, boolean includesOrdinalNumber) {
        if (kFolds <= 0) {
            throw new IllegalArgumentException("kFolds must be positive, got " + kFolds);
        }
        if (features.length == 0) {
            throw new IllegalArgumentException("features must contain at least one feature vector");
        }

        TrainTestSet trainTestSet = new TrainTestSet();
        trainTestSet.setNewFoldsCount(kFolds);

        int total = features.length;
        int foldSize = total / kFolds;
        int assigned = kFolds * foldSize;

        int offset = includesOrdinalNumber ? 1 : 0;
        int width = features[0].length - offset;
        // NOTE(review): in random mode with an ordinal column the label is looked up by the
        // ordinal number; in every other mode it is looked up by row position. This mirrors the
        // long-standing behaviour -- confirm that ordinal numbers are valid indexes into classesLabels.
        boolean labelsByOrdinal = includesOrdinalNumber && !first;

        // order[f * foldSize .. (f + 1) * foldSize) are the rows of fold f.
        int[] order = selectionOrder(total, assigned, first, includesOrdinalNumber, seed);

        int[] featureVectorsNumber = new int[total];
        for (int f = 0; f < kFolds; f++) {
            double[][] foldData = new double[foldSize][];
            String[] foldClassesLabels = new String[foldSize];
            for (int j = 0; j < foldSize; j++) {
                // The slot must advance by foldSize per fold (not by kFolds), otherwise
                // folds overlap and identifiers are written to the wrong position.
                int slot = f * foldSize + j;
                int position = order[slot];
                double[] row = features[position];
                foldData[j] = copyRow(row, offset, width);
                foldClassesLabels[j] = classesLabels[labelsByOrdinal ? (int) row[0] : position];
                featureVectorsNumber[slot] = includesOrdinalNumber ? (int) row[0] : position;
            }
            trainTestSet.addFoldData(foldData, foldClassesLabels);
        }
        // Identify the rows that were left out of every fold.
        for (int slot = assigned; slot < total; slot++) {
            int position = order[slot];
            featureVectorsNumber[slot] = includesOrdinalNumber ? (int) features[position][0] : position;
        }
        trainTestSet.setTestSetFeatureVectorNumber(featureVectorsNumber);
        return trainTestSet;
    }

    /**
     * Produces the order in which row positions {@code 0 .. total-1} are consumed by a split.
     *
     * <p>Sequential mode returns the identity order. The two random modes intentionally use
     * different procedures so that results for a given seed stay reproducible: with an ordinal
     * column, {@code picks} positions are drawn one at a time without replacement and the
     * positions never drawn follow in ascending order; without it, the whole position list is
     * shuffled once.
     */
    private static int[] selectionOrder(int total, int picks, boolean first, boolean includesOrdinalNumber, long seed) {
        int[] order = new int[total];
        if (first) {
            for (int i = 0; i < total; i++) {
                order[i] = i;
            }
            return order;
        }

        List<Integer> positions = new ArrayList<>(total);
        for (int i = 0; i < total; i++) {
            positions.add(i);
        }

        if (includesOrdinalNumber) {
            Random random = new Random(seed);
            for (int i = 0; i < picks; i++) {
                order[i] = positions.remove(random.nextInt(total - i));
            }
            for (int i = picks; i < total; i++) {
                order[i] = positions.get(i - picks);
            }
        } else {
            Collections.shuffle(positions, new Random(seed));
            for (int i = 0; i < total; i++) {
                order[i] = positions.get(i);
            }
        }
        return order;
    }

    /** Returns a new array holding {@code width} values of {@code row}, starting at index {@code offset}. */
    private static double[] copyRow(double[] row, int offset, int width) {
        double[] copy = new double[width];
        System.arraycopy(row, offset, copy, 0, width);
        return copy;
    }
}
