package multisab.processing.machineLearning;

/**
 * Holds a training set and a test set (2-D double arrays, one row per sample) together with the
 * class label of each row, plus an optional list of folds from which k-fold style train/test
 * splits can be formed via {@link #formTrainTestSetForFoldNumber(int)}.
 *
 * <p>Arrays passed to setters and {@link #addFoldData} are stored by reference, not copied.
 */
public class TrainTestSet {
    double [][] trainSet;
    double [][] testSet;
    String [] trainSetClassesLabels;
    String [] testSetClassesLabels;

    /** folds[i] holds the rows of fold i; foldsClassesLabels[i][j] is the label of folds[i][j]. */
    double [][][] folds;
    String [][] foldsClassesLabels;

    //int goalIndex;
    int [] testSetFeatureVectorNumber;

    private int foldsCount;
    /** Index of the next fold slot to fill; -1 until {@link #setNewFoldsCount} is called. */
    private int currentFold;

    /** Creates an empty set with no folds configured. */
    public TrainTestSet(){
        currentFold = -1;
    }

    /** @param trainSetClassesLabels label of each training row (stored by reference) */
    public void setTrainSetClassesLabels(String [] trainSetClassesLabels){
        this.trainSetClassesLabels = trainSetClassesLabels;
    }

    /** @return label of each training row, or {@code null} if never set */
    public String [] getTrainSetClassesLabels(){
        return this.trainSetClassesLabels;
    }

    /** @param testSetClassesLabels label of each test row (stored by reference) */
    public void setTestSetClassesLabels(String [] testSetClassesLabels){
        this.testSetClassesLabels = testSetClassesLabels;
    }

    /** @return label of each test row, or {@code null} if never set */
    public String [] getTestSetClassesLabels(){
        return this.testSetClassesLabels;
    }

    /** @param trainSet training rows (stored by reference) */
    public void setTrainSet(double[][] trainSet) {
        this.trainSet = trainSet;
    }

    /** @param testSet test rows (stored by reference) */
    public void setTestSet(double [][] testSet){
        this.testSet = testSet;
    }

    /** @return training rows, or {@code null} if never set */
    public double [][] getTrainSet(){
        return this.trainSet;
    }

    /** @return test rows, or {@code null} if never set */
    public double [][] getTestSet(){
        return this.testSet;
    }

    /**
     * Discards any existing folds and allocates {@code foldsCount} empty fold slots, to be filled
     * in order by {@link #addFoldData}.
     *
     * @param foldsCount number of folds
     */
    public void setNewFoldsCount(int foldsCount){
        this.foldsCount = foldsCount;
        folds = new double[foldsCount][][];
        foldsClassesLabels = new String[foldsCount][];
        currentFold = 0;
    }

    /**
     * Stores the next fold and its labels in the next free slot.
     *
     * @param fold                rows of the fold (stored by reference)
     * @param foldsClassesLabels  label of each row of {@code fold} (stored by reference)
     * @return {@code true} if the fold was stored; {@code false} if no slot is free, including
     *         when {@link #setNewFoldsCount} has not been called yet
     */
    public boolean addFoldData(double [][] fold, String[] foldsClassesLabels){
        // Guard against use before setNewFoldsCount(): folds is null and currentFold is -1 then.
        if (folds == null || currentFold < 0 || currentFold >= foldsCount) {
            return false;
        }
        this.folds[currentFold] = fold;
        this.foldsClassesLabels[currentFold] = foldsClassesLabels;
        currentFold++;
        return true;
    }

    /** @param featureVectorNumber per-test-row numbers (stored by reference) */
    public void setTestSetFeatureVectorNumber(int [] featureVectorNumber){
        this.testSetFeatureVectorNumber = featureVectorNumber;
    }

    /** @return the array passed to {@link #setTestSetFeatureVectorNumber}, or {@code null} */
    public int[] getTestSetFeatureVectorNumber(){
        return this.testSetFeatureVectorNumber;
    }

    /** @return number of fold slots allocated, or 0 if {@link #setNewFoldsCount} was never called */
    public int getFoldsCount(){
        return this.folds == null ? 0 : this.folds.length;
    }

    /**
     * @param foldNumber zero-based fold index
     * @return rows of that fold ({@code null} if the slot has not been filled)
     */
    public double[][] getFold(int foldNumber){
        return this.folds[foldNumber];
    }

    /**
     * Builds a new split that uses fold {@code foldNo} as the test set and the concatenation of
     * all other folds, in fold order, as the training set. Labels are carried along row by row.
     * Folds may have different sizes. Rows are shared by reference with this object's folds.
     *
     * @param foldNo zero-based index of the fold to hold out as the test set
     * @return a new {@code TrainTestSet} with train/test rows and labels set
     * @throws ArrayIndexOutOfBoundsException if {@code foldNo} is not a valid fold index
     * @throws NullPointerException if folds were never configured or a fold slot is unfilled
     */
    public TrainTestSet formTrainTestSetForFoldNumber(int foldNo){
        // Reading the held-out fold first surfaces an invalid foldNo as ArrayIndexOutOfBoundsException.
        double[][] heldOutFold = folds[foldNo];
        String[] heldOutLabels = foldsClassesLabels[foldNo];

        // Folds need not be equally sized (e.g. sample count not divisible by the fold count),
        // so size the training set from the actual remaining folds, not (k-1) * |fold 0|.
        int trainingSize = 0;
        for (int i = 0; i < folds.length; i++) {
            if (i != foldNo) {
                trainingSize += folds[i].length;
            }
        }

        double[][] trainingSet = new double[trainingSize][];
        String[] trainingSetLabels = new String[trainingSize];
        int offset = 0;
        for (int i = 0; i < folds.length; i++) {
            if (i == foldNo) {
                continue;
            }
            int rows = folds[i].length;
            System.arraycopy(folds[i], 0, trainingSet, offset, rows);
            System.arraycopy(foldsClassesLabels[i], 0, trainingSetLabels, offset, rows);
            offset += rows;
        }

        int testRows = heldOutFold.length;
        double[][] foldTestSet = new double[testRows][];
        String[] foldTestLabels = new String[testRows];
        System.arraycopy(heldOutFold, 0, foldTestSet, 0, testRows);
        System.arraycopy(heldOutLabels, 0, foldTestLabels, 0, testRows);

        TrainTestSet kfoldTrainTestSet = new TrainTestSet();
        kfoldTrainTestSet.setTrainSet(trainingSet);
        kfoldTrainTestSet.setTestSet(foldTestSet);
        kfoldTrainTestSet.setTrainSetClassesLabels(trainingSetLabels);
        kfoldTrainTestSet.setTestSetClassesLabels(foldTestLabels);

        return kfoldTrainTestSet;
    }

}
