package multisab.processing.machineLearning;

import org.encog.Encog;
import org.encog.mathutil.Equilateral;
import org.encog.ml.MLEncodable;
import org.encog.ml.MLMethod;
import org.encog.ml.MLRegression;
import org.encog.ml.data.MLData;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLData;
import org.encog.ml.data.basic.BasicMLDataSet;
import org.encog.ml.factory.MLMethodFactory;
import org.encog.ml.factory.MLTrainFactory;
import org.encog.ml.train.MLTrain;
import org.encog.ml.train.strategy.StopTrainingStrategy;
import org.encog.neural.networks.training.strategy.RegularizationStrategy;
import org.encog.util.csv.ReadCSV;
import org.encog.util.simple.EncogUtility;

import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static org.encog.util.csv.CSVFormat.DECIMAL_POINT;

/**
 * Trains an Encog support vector machine on the Iris data set and prints, for every row of the
 * data file, the predicted class next to the correct one.
 *
 * <p>Every row of the data file is parsed as data (no header row) and must contain four numeric
 * feature columns followed by an integer class label. Features are min/max-scaled into [-1, 1]
 * before training, and class labels are mapped to consecutive indices (0, 1, 2, ...) in order of
 * first appearance; those indices are what the model is trained to predict.
 *
 * Created by Davor on 6.10.2016..
 */
public class IrisClassification_SVM {

    /** Data file used when no path is given as the first command-line argument. */
    private static final String DEFAULT_DATA_FILE = "E:/iris_integer.csv";

    /** Number of numeric feature columns that precede the class-label column in each row. */
    private static final int FEATURE_COUNT = 4;

    /** Lower bound of the interval features are scaled into. */
    private static final double RANGE_MIN = -1;

    /** Upper bound of the interval features are scaled into. */
    private static final double RANGE_MAX = 1;

    /** Training error at which {@link EncogUtility#trainToError} stops. */
    private static final double TARGET_ERROR = 0.01;

    /**
     * Loads the data, trains the SVM, evaluates it on the training data and prints one
     * prediction line per row. Errors are reported via a stack trace rather than propagated.
     *
     * @param args optional; {@code args[0]} overrides the data file path
     *             (defaults to {@value #DEFAULT_DATA_FILE})
     */
    public void run(String[] args) {
        File sourceFile = new File(args != null && args.length > 0 ? args[0] : DEFAULT_DATA_FILE);

        try {
            List<String[]> rows = readRows(sourceFile);
            if (rows.isEmpty()) {
                throw new IllegalStateException("No data rows found in " + sourceFile);
            }
            int rowCount = rows.size();

            // Parse the features and map each distinct label to a class index
            // (order of first appearance); the SVM is trained on those indices.
            double[][] raw = new double[rowCount][];
            double[][] ideal = new double[rowCount][1];
            List<Integer> classLabels = new ArrayList<Integer>();
            for (int r = 0; r < rowCount; r++) {
                String[] row = rows.get(r);
                raw[r] = parseFeatures(row);
                int label = Integer.parseInt(row[FEATURE_COUNT]);
                if (!classLabels.contains(label)) {
                    classLabels.add(label);
                }
                ideal[r][0] = classLabels.indexOf(label);
            }

            // Scale every feature column into [RANGE_MIN, RANGE_MAX] using the data's min/max.
            double[] min = new double[FEATURE_COUNT];
            double[] max = new double[FEATURE_COUNT];
            computeColumnRanges(raw, min, max);
            double[][] features = new double[rowCount][];
            for (int r = 0; r < rowCount; r++) {
                features[r] = normalize(raw[r], min, max);
            }

            MLDataSet dataSet = new BasicMLDataSet(features, ideal);
            MLMethod method = trainModel(dataSet);

            EncogUtility.evaluate((MLRegression) method, dataSet);
            System.out.println("Final model: " + method);

            // Feed every original row back through the model. The same min/max scaling is
            // applied here, which is also how previously unseen rows would be classified.
            for (String[] row : rows) {
                int irisChosen = classify((MLRegression) method, parseFeatures(row), min, max, classLabels);
                System.out.println(Arrays.toString(Arrays.copyOf(row, FEATURE_COUNT))
                        + " -> predicted: " + irisChosen
                        + "(correct: " + row[FEATURE_COUNT] + ")");
            }
        } catch (Exception ex) {
            ex.printStackTrace();
        } finally {
            // Shut Encog down even when loading or training failed.
            Encog.getInstance().shutdown();
        }
    }

    /**
     * Reads every row of {@code sourceFile} as raw strings: the feature columns followed by the
     * class-label column.
     *
     * @param sourceFile CSV file to read (no header row, decimal-point number format)
     * @return one {@code String[FEATURE_COUNT + 1]} per row, in file order
     */
    private static List<String[]> readRows(File sourceFile) {
        List<String[]> rows = new ArrayList<String[]>();
        ReadCSV csv = new ReadCSV(sourceFile, false, DECIMAL_POINT);
        try {
            while (csv.next()) {
                String[] row = new String[FEATURE_COUNT + 1];
                for (int c = 0; c <= FEATURE_COUNT; c++) {
                    row[c] = csv.get(c);
                }
                rows.add(row);
            }
        } finally {
            // Release the file handle even if reading fails part-way.
            csv.close();
        }
        return rows;
    }

    /** Parses the first {@link #FEATURE_COUNT} columns of a raw row as doubles. */
    private static double[] parseFeatures(String[] row) {
        double[] values = new double[FEATURE_COUNT];
        for (int c = 0; c < FEATURE_COUNT; c++) {
            values[c] = Double.parseDouble(row[c]);
        }
        return values;
    }

    /**
     * Fills {@code min} and {@code max} with the per-column minimum and maximum of {@code rows}.
     *
     * @param rows non-empty matrix of raw feature values
     * @param min  output array, one entry per column
     * @param max  output array, one entry per column
     */
    private static void computeColumnRanges(double[][] rows, double[] min, double[] max) {
        for (int c = 0; c < min.length; c++) {
            min[c] = rows[0][c];
            max[c] = rows[0][c];
            for (int r = 1; r < rows.length; r++) {
                if (rows[r][c] < min[c]) {
                    min[c] = rows[r][c];
                }
                if (rows[r][c] > max[c]) {
                    max[c] = rows[r][c];
                }
            }
        }
    }

    /** Scales each value of {@code raw} into [RANGE_MIN, RANGE_MAX] using the given column ranges. */
    private static double[] normalize(double[] raw, double[] min, double[] max) {
        double[] scaled = new double[raw.length];
        for (int c = 0; c < raw.length; c++) {
            scaled[c] = scale(raw[c], min[c], max[c]);
        }
        return scaled;
    }

    /**
     * Linearly maps {@code value} from [min, max] to [RANGE_MIN, RANGE_MAX]. Values outside the
     * training range map outside the target interval.
     */
    private static double scale(double value, double min, double max) {
        double span = max - min;
        if (span == 0) {
            // A constant column carries no information; map it to the middle of the target
            // interval instead of dividing by zero (which would yield NaN).
            return (RANGE_MIN + RANGE_MAX) / 2;
        }
        return (value - min) * (RANGE_MAX - RANGE_MIN) / span + RANGE_MIN;
    }

    /**
     * Creates a support vector machine and trains it on {@code dataSet}.
     *
     * @return the trained model as reported by the trainer
     */
    private static MLMethod trainModel(MLDataSet dataSet) {
        // "?->C->?": the input/output sizes come from the arguments below. Per Encog's SVM factory
        // the "C" requests a classification SVM (vs. "R" for regression) -- confirm for the Encog
        // version in use.
        MLMethod method = new MLMethodFactory().create(MLMethodFactory.TYPE_SVM, "?->C->?", FEATURE_COUNT, 1);
        MLTrain train = new MLTrainFactory().create(method, dataSet, MLTrainFactory.TYPE_SVM, "");
        EncogUtility.trainToError(train, TARGET_ERROR);
        return train.getMethod();
    }

    /**
     * Scales one row of raw features, runs it through {@code model} and maps the predicted class
     * index back to the original label.
     *
     * @throws IllegalStateException if the model output does not correspond to a known class
     */
    private static int classify(MLRegression model, double[] rawFeatures, double[] min, double[] max,
                                List<Integer> classLabels) {
        MLData output = model.compute(new BasicMLData(normalize(rawFeatures, min, max)));
        double classValue = output.getData()[0];
        // Round instead of requiring an exact double match, so a value that is off by
        // floating-point noise still resolves to the intended class.
        long index = Math.round(classValue);
        if (index < 0 || index >= classLabels.size()) {
            throw new IllegalStateException("Model produced an unknown class index: " + classValue);
        }
        return classLabels.get((int) index);
    }

    /** Command-line entry point; see {@link #run(String[])} for the accepted arguments. */
    public static void main(String[] args) {
        IrisClassification_SVM prg = new IrisClassification_SVM();
        prg.run(args);
    }
}