package multisab.processing.machineLearning;

import multisab.processing.machineLearning.normalization.*;

import org.encog.ml.MLMethod;
import org.encog.ml.MLRegression;
import org.encog.ml.data.MLData;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLData;
import org.encog.ml.data.basic.BasicMLDataSet;
import org.encog.ml.factory.MLMethodFactory;
import org.encog.ml.factory.MLTrainFactory;
import org.encog.ml.train.MLTrain;
import org.encog.ml.train.strategy.*;
import org.encog.ml.train.strategy.end.*;
import org.encog.neural.networks.training.strategy.*;
import org.encog.util.simple.EncogUtility;
import org.encog.persist.EncogDirectoryPersistence;

import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Multi-class classifier backed by an Encog machine-learning method.
 *
 * <p>The constructor picks a preset for one of the supported method families
 * ({@link #FEEDFORWARD}, {@link #SVM}, {@link #RBFNETWORK}, {@link #NEAT}, {@link #PNN}):
 * the Encog method type and arguments, the training type and arguments, the input/output
 * normalization ranges, the training strategies and the output coding scheme. Every preset
 * value can be overridden through the setters before one of the {@code train} methods is called.
 *
 * <p>Training fits a feature normalization, encodes the class labels with an output
 * {@link Coder}, builds the Encog method and trains it with {@code EncogUtility.trainToError}
 * against {@link #getTargetError()}. After training, or after
 * {@link #loadTrainedModelFromFile(String)}, {@link #classify(double[][])} and
 * {@link #classifyToString(double[][])} map feature rows back to class labels.
 *
 * <p>Instances hold mutable state and perform no synchronization.
 *
 * <p>Created by Davor on 6.10.2016.
 */
public class Classifier {

    public static final String FEEDFORWARD = "FEEDFORWARD";
    public static final String SVM = "SVM";
    public static final String RBFNETWORK = "RBFNETWORK";
    public static final String NEAT = "NEAT";
    public static final String PNN = "PNN";

    // Output coding scheme names accepted by setCodeType().
    private static final String CODE_EQUILATERAL = "Equilateral";
    private static final String CODE_ONE_OF_N = "OneOfN";
    private static final String CODE_INT = "Int";
    private static final String CODE_INT_ONE_OF_N = "Int->OneOfN";

    /** Feed-forward argument template; the hidden-layer size is filled in at training time. */
    private static final String FEEDFORWARD_ARGS_TEMPLATE = "?:B->TANH->?:B->?";

    /** RBF argument template; the value of {@code c} is filled in at training time. */
    private static final String RBF_ARGS_TEMPLATE = "?->gaussian(c=?)->?";

    /** Default error target handed to {@code EncogUtility.trainToError}. */
    private static final double DEFAULT_TARGET_ERROR = 0.01;

    private MLMethod method;
    private Normalization featureNormalization;
    private Coder outputCoder;
    private String methodType;
    private String methodArgs;
    private String trainingType;
    private String trainingArgs;
    private final List<Strategy> strategy = new ArrayList<>();
    private double normalizationMin;
    private double normalizationMax;
    private double outputNormalizationMin;
    private double outputNormalizationMax;
    private String codeType;
    private final int numberOfClasses;
    private double targetError = DEFAULT_TARGET_ERROR;

    /**
     * Creates a classifier preconfigured for the given method family.
     *
     * <p>Problems with more than two classes default to equilateral output coding and binary
     * problems to one-of-n; the SVM and PNN presets replace this with their own coding. An
     * unrecognised {@code MLType} leaves the method and training settings unset, so they must be
     * supplied through the setters before training.
     *
     * @param MLType one of {@link #FEEDFORWARD}, {@link #SVM}, {@link #RBFNETWORK},
     *     {@link #NEAT} or {@link #PNN}
     * @param numberOfClasses number of distinct class labels
     */
    public Classifier(String MLType, int numberOfClasses) {
        this.numberOfClasses = numberOfClasses;
        codeType = numberOfClasses > 2 ? CODE_EQUILATERAL : CODE_ONE_OF_N;

        // Strategies are added to the list directly: calling the overridable addStrategy()
        // from a constructor would expose a partially constructed object to subclasses.
        switch (MLType) {
            case FEEDFORWARD:
                methodType = "feedforward";
                methodArgs = FEEDFORWARD_ARGS_TEMPLATE;
                trainingType = "rprop";
                trainingArgs = "";
                normalizationMin = -1;
                normalizationMax = 1;
                outputNormalizationMin = -1;
                outputNormalizationMax = 1;
                strategy.add(new RegularizationStrategy(0.01));
                strategy.add(new StopTrainingStrategy(0.01, 100));
                break;
            case SVM:
                methodType = "svm";
                methodArgs = "?->C->?";
                trainingType = "svm-train";
                trainingArgs = "";
                normalizationMin = 0;
                normalizationMax = 1;
                codeType = CODE_INT;
                break;
            case RBFNETWORK:
                methodType = "rbfnetwork";
                methodArgs = RBF_ARGS_TEMPLATE;
                trainingType = "rprop";
                trainingArgs = "";
                normalizationMin = 0;
                normalizationMax = 1;
                outputNormalizationMin = -1;
                outputNormalizationMax = 1;
                strategy.add(new StopTrainingStrategy(0.001, 100));
                break;
            case NEAT:
                methodType = "neat";
                methodArgs = "cycles=4";
                trainingType = "neat-ga";
                trainingArgs = "";
                normalizationMin = 0;
                normalizationMax = 1;
                outputNormalizationMin = -1;
                outputNormalizationMax = 1;
                strategy.add(new StopTrainingStrategy(0.001, 100));
                break;
            case PNN:
                methodType = "pnn";
                methodArgs = "?->C(kernel=gaussian)->?";
                trainingType = "pnn";
                trainingArgs = "";
                normalizationMin = 0;
                normalizationMax = 1;
                codeType = CODE_INT_ONE_OF_N;
                break;
            default:
                // Unknown family: settings stay unset and must be configured via the setters.
                break;
        }
    }

    /**
     * Trains the classifier on integer class labels.
     *
     * @param features one row of feature values per sample
     * @param classes class label of each row; must contain at least {@code features.length}
     *     entries
     * @throws IllegalArgumentException if {@code features} is empty or has more rows than
     *     there are labels
     * @throws IllegalStateException if no supported code type or no method type is configured
     */
    public void train(double[][] features, int[] classes) {
        requireTargetsFor(features, classes.length);
        outputCoder = createIntCoder();
        trainML(features, outputCoder.encode(classes));
    }

    /**
     * Trains the classifier on string class labels.
     *
     * <p>Only the first {@code features.length} labels are used, one per feature row.
     *
     * @param features one row of feature values per sample
     * @param classes class label of each row; must contain at least {@code features.length}
     *     entries
     * @throws IllegalArgumentException if {@code features} is empty or has more rows than
     *     there are labels
     * @throws IllegalStateException if no supported code type or no method type is configured
     */
    public void train(double[][] features, String[] classes) {
        requireTargetsFor(features, classes.length);
        outputCoder = createStringCoder();
        String[] trainClasses = Arrays.copyOf(classes, features.length);
        trainML(features, outputCoder.encode(trainClasses));
    }

    /**
     * Fits the feature normalization, builds the Encog method and trains it on encoded targets.
     *
     * <p>A fresh feature normalization is created from the configured input range and fitted
     * on {@code features}. When the method arguments are still the default feed-forward or RBF
     * template, the hidden size is derived from the input and output counts for this call only.
     * The configured arguments are left untouched, so later trainings with different dimensions
     * recompute it.
     *
     * @param features one row of feature values per sample
     * @param ideal encoded target row for each sample; the width of its first row is the
     *     output count for every method except PNN
     * @throws IllegalArgumentException if {@code features} is empty or {@code ideal} has fewer
     *     rows than {@code features}
     * @throws IllegalStateException if no method type is configured
     */
    public void trainML(double[][] features, double[][] ideal) {
        requireTargetsFor(features, ideal == null ? 0 : ideal.length);
        if (methodType == null) {
            throw new IllegalStateException("No ML method type configured; pass a supported type"
                    + " to the constructor or call setMethodType()");
        }

        featureNormalization = new Normalization(normalizationMin, normalizationMax);
        featureNormalization.analyze(features);
        double[][] normFeatures = featureNormalization.normalization(features);
        MLDataSet dataSet = new BasicMLDataSet(normFeatures, ideal);

        method = createMethod(features[0].length, ideal[0].length);

        MLTrain train = new MLTrainFactory().create(method, dataSet, trainingType, trainingArgs);
        for (Strategy s : strategy) {
            train.addStrategy(s);
        }

        EncogUtility.trainToError(train, targetError);
        method = train.getMethod();
    }

    /**
     * Classifies each feature row and returns integer class labels.
     *
     * @param features one row of feature values per sample
     * @return the decoded class label of each row, in input order
     * @throws IllegalStateException if the classifier has not been trained or loaded
     */
    public int[] classify(double[][] features) {
        double[][] normFeatures = normalizeForPrediction(features);
        int[] intOutput = new int[features.length];
        for (int i = 0; i < intOutput.length; i++) {
            intOutput[i] = outputCoder.decode(computeRaw(normFeatures[i]));
        }
        return intOutput;
    }

    /**
     * Classifies each feature row and returns string class labels.
     *
     * @param features one row of feature values per sample
     * @return the decoded class label of each row, in input order
     * @throws IllegalStateException if the classifier has not been trained or loaded
     */
    public String[] classifyToString(double[][] features) {
        double[][] normFeatures = normalizeForPrediction(features);
        String[] stringOutput = new String[features.length];
        for (int i = 0; i < stringOutput.length; i++) {
            stringOutput[i] = outputCoder.decodeToString(computeRaw(normFeatures[i]));
        }
        return stringOutput;
    }

    /**
     * Persists the trained method, feature normalization and output coder to three files.
     *
     * <p>The file names are derived from {@code filename} by inserting {@code _M}, {@code _N}
     * and {@code _O} before the extension (for example {@code model.eg} becomes
     * {@code model_M.eg}). If the file name has no extension, the suffix is appended instead.
     * A {@code null} filename selects {@code ModelFile_M.mod}, {@code ModelFile_N.mod} and
     * {@code ModelFile_O.mod}, resolved against the working directory.
     *
     * @param filename base file name, or {@code null} for the default names
     * @throws IllegalStateException if the classifier has not been trained or loaded
     */
    public void saveTrainedModelToFile(String filename) {
        requireTrained();
        EncogDirectoryPersistence.saveObject(modelPartFile(filename, "M"), this.method);
        EncogDirectoryPersistence.saveObject(modelPartFile(filename, "N"), this.featureNormalization);
        EncogDirectoryPersistence.saveObject(modelPartFile(filename, "O"), this.outputCoder);
    }

    /**
     * Loads a model previously written by {@link #saveTrainedModelToFile(String)}.
     *
     * <p>The three file names are derived from {@code filename} with the same rules that
     * {@link #saveTrainedModelToFile(String)} uses.
     *
     * @param filename base file name, or {@code null} for the default names
     */
    public void loadTrainedModelFromFile(String filename) {
        method = (MLMethod) EncogDirectoryPersistence.loadObject(modelPartFile(filename, "M"));
        featureNormalization =
                (Normalization) EncogDirectoryPersistence.loadObject(modelPartFile(filename, "N"));
        outputCoder = (Coder) EncogDirectoryPersistence.loadObject(modelPartFile(filename, "O"));
    }

    public void setMethodType(String methodType) {
        this.methodType = methodType;
    }

    public String getMethodType() {
        return methodType;
    }

    /**
     * Sets the Encog method argument string.
     *
     * <p>The presets {@code "?:B->TANH->?:B->?"} (feed-forward) and {@code "?->gaussian(c=?)->?"}
     * (RBF) are templates whose hidden size is resolved at training time.
     *
     * @param methodArgs Encog architecture/argument string
     */
    public void setMethodArgs(String methodArgs) {
        this.methodArgs = methodArgs;
    }

    public String getMethodArgs() {
        return methodArgs;
    }

    public void setTrainingType(String trainingType) {
        this.trainingType = trainingType;
    }

    public String getTrainingType() {
        return trainingType;
    }

    public void setTrainingArgs(String trainingArgs) {
        this.trainingArgs = trainingArgs;
    }

    public String getTrainingArgs() {
        return trainingArgs;
    }

    /** Appends a strategy that is attached to the trainer on every subsequent training. */
    public void addStrategy(Strategy strategy) {
        this.strategy.add(strategy);
    }

    /** Removes all training strategies, including the preset ones. */
    public void resetStrategy() {
        strategy.clear();
    }

    /** Returns the live, mutable list of training strategies. */
    public List<Strategy> getStrategy() {
        return strategy;
    }

    /** Sets the range that input features are normalized to on the next training. */
    public void setNormalizationRange(double normalizationMin, double normalizationMax) {
        this.normalizationMin = normalizationMin;
        this.normalizationMax = normalizationMax;
    }

    /** Sets the range passed to the output coder created on the next training. */
    public void setOutputNormalizationRange(double outputNormalizationMin, double outputNormalizationMax) {
        this.outputNormalizationMin = outputNormalizationMin;
        this.outputNormalizationMax = outputNormalizationMax;
    }

    /**
     * Sets the output coding scheme used on the next training.
     *
     * @param codeType one of {@code "Equilateral"}, {@code "OneOfN"}, {@code "Int"} or
     *     {@code "Int->OneOfN"}
     */
    public void setCodeType(String codeType) {
        this.codeType = codeType;
    }

    public String getCodeType() {
        return codeType;
    }

    /**
     * Sets the error at which {@code EncogUtility.trainToError} is asked to stop (default 0.01).
     *
     * @param targetError non-negative error target
     * @throws IllegalArgumentException if {@code targetError} is negative or NaN
     */
    public void setTargetError(double targetError) {
        if (!(targetError >= 0)) {
            throw new IllegalArgumentException("targetError must be >= 0 but was " + targetError);
        }
        this.targetError = targetError;
    }

    public double getTargetError() {
        return targetError;
    }

    /** Checks that there is at least one sample and a label/target for every feature row. */
    private static void requireTargetsFor(double[][] features, int targetCount) {
        if (features == null || features.length == 0) {
            throw new IllegalArgumentException("At least one training sample is required");
        }
        if (targetCount < features.length) {
            throw new IllegalArgumentException("Expected at least " + features.length
                    + " class labels/targets (one per feature row) but got " + targetCount);
        }
    }

    /** Returns the configured code type, failing fast when none is set. */
    private String requireCodeType() {
        if (codeType == null) {
            throw new IllegalStateException("No output code type configured");
        }
        return codeType;
    }

    /** Builds the output coder for integer class labels according to the configured code type. */
    private Coder createIntCoder() {
        switch (requireCodeType()) {
            case CODE_EQUILATERAL:
                return new EquilateralCoder(outputNormalizationMin, outputNormalizationMax, numberOfClasses);
            case CODE_ONE_OF_N:
                return new OneOfNCoder(outputNormalizationMin, outputNormalizationMax, numberOfClasses);
            case CODE_INT:
                return new IntCoder(outputNormalizationMin, outputNormalizationMax, numberOfClasses);
            case CODE_INT_ONE_OF_N:
                return new IntOneOfNCoder(outputNormalizationMin, outputNormalizationMax, numberOfClasses);
            default:
                throw new IllegalStateException("Unsupported output code type: " + codeType);
        }
    }

    /** Builds the output coder for string class labels according to the configured code type. */
    private Coder createStringCoder() {
        switch (requireCodeType()) {
            case CODE_EQUILATERAL:
                return new EquilateralStringCoder(outputNormalizationMin, outputNormalizationMax, numberOfClasses);
            case CODE_ONE_OF_N:
                return new OneOfNStringCoder(outputNormalizationMin, outputNormalizationMax, numberOfClasses);
            case CODE_INT:
                return new IntStringCoder(outputNormalizationMin, outputNormalizationMax, numberOfClasses);
            case CODE_INT_ONE_OF_N:
                return new IntOneOfNStringCoder(outputNormalizationMin, outputNormalizationMax, numberOfClasses);
            default:
                throw new IllegalStateException("Unsupported output code type: " + codeType);
        }
    }

    /** Creates the untrained Encog method, resolving hidden sizes for the default templates. */
    private MLMethod createMethod(int numberOfFeatures, int numberOfOutputs) {
        MLMethodFactory methodFactory = new MLMethodFactory();

        if ("feedforward".equals(methodType) && FEEDFORWARD_ARGS_TEMPLATE.equals(methodArgs)) {
            // Hidden layer: mean of input and output counts, rounded up.
            int numberOfHiddenNeurons = (int) Math.ceil((numberOfFeatures + numberOfOutputs) / 2.0);
            String resolvedArgs = "?:B->TANH->" + numberOfHiddenNeurons + ":B->?";
            return methodFactory.create(methodType, resolvedArgs, numberOfFeatures, numberOfOutputs);
        }
        if ("rbfnetwork".equals(methodType) && RBF_ARGS_TEMPLATE.equals(methodArgs)) {
            int numberOfHiddenNeurons = numberOfFeatures + numberOfOutputs;
            String resolvedArgs = "?->gaussian(c=" + numberOfHiddenNeurons + ")->?";
            // Size the output layer by the encoded target width, not numberOfClasses: the two
            // differ for coders that do not emit one column per class (presumably equilateral),
            // and a mismatch would make the network disagree with the training data set.
            return methodFactory.create(methodType, resolvedArgs, numberOfFeatures, numberOfOutputs);
        }
        if ("pnn".equals(methodType)) {
            // PNN is sized by the number of classes rather than by the encoded target width.
            return methodFactory.create(methodType, methodArgs, numberOfFeatures, numberOfClasses);
        }
        return methodFactory.create(methodType, methodArgs, numberOfFeatures, numberOfOutputs);
    }

    /** Fails fast when no model is available (neither trained nor loaded). */
    private void requireTrained() {
        if (method == null || featureNormalization == null || outputCoder == null) {
            throw new IllegalStateException("Classifier must be trained or loaded first");
        }
    }

    /** Applies the fitted feature normalization to samples that are about to be classified. */
    private double[][] normalizeForPrediction(double[][] features) {
        requireTrained();
        return featureNormalization.normalization(features);
    }

    /** Runs one normalized sample through the trained method and returns the raw outputs. */
    private double[] computeRaw(double[] normalizedSample) {
        MLData output = ((MLRegression) method).compute(new BasicMLData(normalizedSample));
        return output.getData();
    }

    /**
     * Derives the file for one model part ({@code M}, {@code N} or {@code O}).
     *
     * <p>Only a dot inside the last path component counts as an extension separator, so names
     * such as {@code "./model"} or {@code "runs.v2/model"} get the suffix appended rather than
     * being split at a directory dot.
     */
    private static File modelPartFile(String filename, String part) {
        if (filename == null) {
            return new File("ModelFile_" + part + ".mod");
        }
        String suffix = "_" + part;
        int lastSeparator = Math.max(filename.lastIndexOf('/'), filename.lastIndexOf(File.separatorChar));
        int dot = filename.lastIndexOf('.');
        if (dot <= lastSeparator) {
            return new File(filename + suffix);
        }
        return new File(filename.substring(0, dot) + suffix + filename.substring(dot));
    }
}
