package multisab.processing.machineLearning;

import static org.encog.util.csv.CSVFormat.DECIMAL_POINT;

import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.encog.Encog;
import org.encog.util.csv.ReadCSV;

/**
 * Command-line demo that trains a {@code Classifier} on a headerless CSV file and then
 * prints the predicted and the actual label for every row of that same file.
 *
 * <p>Each CSV row is expected to hold four numeric feature columns followed by one
 * label column.
 *
 * <p>Created by Davor on 11.10.2016.
 */
public class IrisClassificationMain {

    /** Dataset location used when no path is supplied on the command line. */
    private static final String DEFAULT_DATASET_PATH = "E:/iris.csv";

    /** Number of numeric feature columns that precede the label column in each row. */
    private static final int FEATURE_COUNT = 4;

    /**
     * Loads the dataset, trains the classifier on all rows and prints one line per row
     * in the form {@code [f0, f1, f2, f3] -> predicted: X(correct: Y)}.
     *
     * <p>Any exception is caught and its stack trace printed; nothing is propagated to
     * the caller. Encog is shut down whether or not the run succeeds.
     *
     * @param args optional command-line arguments; when present, {@code args[0]} is the
     *             path of the CSV file, otherwise {@link #DEFAULT_DATASET_PATH} is used
     */
    public void run(String[] args) {
        String datasetPath = DEFAULT_DATASET_PATH;
        if (args != null && args.length > 0) {
            datasetPath = args[0];
        }
        File sourceFile = new File(datasetPath);

        try {
            // Read the file once; the raw text is kept so the report can echo the
            // feature values exactly as they appear in the CSV.
            List<String[]> rawRows = new ArrayList<String[]>();
            List<String> labels = new ArrayList<String>();
            loadDataset(sourceFile, rawRows, labels);

            if (rawRows.isEmpty()) {
                System.err.println("No rows found in " + sourceFile);
                return;
            }

            // Size the training arrays from the data instead of assuming 150 rows.
            int rowCount = rawRows.size();
            double[][] in = new double[rowCount][FEATURE_COUNT];
            for (int row = 0; row < rowCount; row++) {
                String[] fields = rawRows.get(row);
                for (int col = 0; col < FEATURE_COUNT; col++) {
                    in[row][col] = Double.parseDouble(fields[col]);
                }
            }
            String[] out = labels.toArray(new String[0]);

            // Other model names previously tried here: "SVM", "PNN".
            // NOTE(review): the second argument is presumably the number of classes - confirm
            // against Classifier.
            Classifier classifier = new Classifier("FEEDFORWARD", 3);
            classifier.train(in, out);

            // Feed every row of the original dataset back through the trained model.
            // A separate one-row buffer is used so the classifier never sees (or
            // touches) the training array itself.
            double[][] input = new double[1][FEATURE_COUNT];
            for (int row = 0; row < rowCount; row++) {
                System.arraycopy(in[row], 0, input[0], 0, FEATURE_COUNT);

                String[] irisChosen = classifier.classifyToString(input);

                StringBuilder result = new StringBuilder();
                result.append(Arrays.toString(rawRows.get(row)));
                result.append(" -> predicted: ");
                result.append(irisChosen[0]);
                result.append("(correct: ");
                result.append(out[row]);
                result.append(")");

                System.out.println(result.toString());
            }
        } catch (Exception ex) {
            ex.printStackTrace();
        } finally {
            // Release Encog's resources on failure as well as on success.
            Encog.getInstance().shutdown();
        }
    }

    /**
     * Reads every row of the CSV file, appending the raw feature strings to
     * {@code rawRows} and the label (column {@link #FEATURE_COUNT}) to {@code labels}.
     * The reader is always closed, even when reading fails part-way.
     *
     * @param sourceFile CSV file without a header row
     * @param rawRows    receives one {@code String[FEATURE_COUNT]} per row
     * @param labels     receives the label of each row, in file order
     */
    private static void loadDataset(File sourceFile, List<String[]> rawRows, List<String> labels) {
        ReadCSV csv = new ReadCSV(sourceFile, false, DECIMAL_POINT);
        try {
            while (csv.next()) {
                String[] fields = new String[FEATURE_COUNT];
                for (int col = 0; col < FEATURE_COUNT; col++) {
                    fields[col] = csv.get(col);
                }
                rawRows.add(fields);
                labels.add(csv.get(FEATURE_COUNT));
            }
        } finally {
            csv.close();
        }
    }

    /**
     * Program entry point; delegates to {@link #run(String[])}.
     *
     * @param args command-line arguments, forwarded unchanged
     */
    public static void main(String[] args) {
        IrisClassificationMain prg = new IrisClassificationMain();
        prg.run(args);
    }
}
