package multisab.processing.machineLearning.featureSelection;

import multisab.processing.machineLearning.discretization.Discretization;
import multisab.processing.machineLearning.discretization.CACC;

import java.util.List;
import java.util.ArrayList;
import java.util.Random;
import java.util.Set;
import java.util.HashSet;
import java.util.stream.IntStream;

/**
 * Main class for feature selection
 * Implemented methods:
 * - Symmetrical Uncertainty
 * - Chi Square
 * - ReliefF ("Marko Robnik-Sikonja, Igor Kononenko: An adaptation of Relief for attribute estimation in regression. In: Fourteenth International Conference on Machine Learning, 296-304, 1997.")
 *
 * @author Josip Renic, Davor Kukolja
 */

//TODO: Feature selection does not support unknown (missing) values
//TODO: Feature selection only supports datasets whose features are all nominal or all numeric

public class FeatureSelectionMethods {

	// NOTE(review): the contingency-table based scores pass value/class counts to ContingencyTable, which
	// presumably expects values coded 0..k-1 and labels coded 0..numberOfClasses-1 — confirm there.

	/**
	 * Scores one discretized feature against the class labels.
	 * Implemented by the four-argument {@code symmetricalUncertainty} and {@code chiSquareScore} overloads.
	 */
	@FunctionalInterface
	private interface DiscreteScore {
		double score(int[] feature, int numberOfDiscreteValues, int[] classes, int numberOfClasses);
	}

	/**
	 * Difference in [0, 1] between two instances (given by row index) on a single feature.
	 * Used by ReliefF both to build instance distances and to update the feature weights.
	 */
	@FunctionalInterface
	private interface FeatureDiff {
		double diff(int first, int second, int feature);
	}

	/**
	 * Computes the symmetrical uncertainty of every numeric feature (column) with respect to the classes.
	 * All features are discretized together with CACC first.
	 *
	 * @param data    instances in rows, numeric features in columns
	 * @param classes class label of each instance
	 * @return one score per feature; an empty array when {@code data} has no rows
	 */
	public static double[] symmetricalUncertainty(double[][] data, int[] classes) {
		return scoreDiscretizedColumns(data, classes, FeatureSelectionMethods::symmetricalUncertainty);
	}

	/**
	 * Computes the symmetrical uncertainty of every nominal feature (column) with respect to the classes.
	 *
	 * @param data    instances in rows, nominal (integer coded) features in columns
	 * @param classes class label of each instance
	 * @return one score per feature; an empty array when {@code data} has no rows
	 */
	public static double[] symmetricalUncertainty(int[][] data, int[] classes) {
		return scoreNominalColumns(data, classes, FeatureSelectionMethods::symmetricalUncertainty);
	}

	/**
	 * Computes the symmetrical uncertainty of a single numeric feature after CACC discretization.
	 *
	 * @param data    value of the feature for each instance
	 * @param classes class label of each instance
	 * @return the symmetrical uncertainty score
	 */
	public static double symmetricalUncertainty(double[] data, int[] classes) {
		return scoreDiscretizedFeature(data, classes, FeatureSelectionMethods::symmetricalUncertainty);
	}

	/**
	 * Computes the symmetrical uncertainty of a single nominal feature.
	 *
	 * @param data    integer coded value of the feature for each instance
	 * @param classes class label of each instance
	 * @return the symmetrical uncertainty score
	 */
	public static double symmetricalUncertainty(int[] data, int[] classes) {
		return scoreNominalFeature(data, classes, FeatureSelectionMethods::symmetricalUncertainty);
	}

	/**
	 * Computes SU = 2 * (H(X) - H(X|Y)) / (H(X) + H(Y)) for a discretized feature X and the classes Y.
	 *
	 * @param data                   discretized value of the feature for each instance
	 * @param numberOfDiscreteValues number of distinct feature values
	 * @param classes                class label of each instance
	 * @param numberOfClasses        number of distinct class labels
	 * @return the symmetrical uncertainty, or 0 when both entropies are 0 (the ratio would be 0/0)
	 */
	public static double symmetricalUncertainty(int[] data, int numberOfDiscreteValues, int[] classes, int numberOfClasses) {
		int[] contingencyTable = ContingencyTable.generateTable(data, numberOfDiscreteValues, classes, numberOfClasses);
		int[] boundaryColumns = ContingencyTable.boundaryTable(classes, numberOfClasses);
		double[] conditional = ContingencyTable.generateConditionalTableOverColumns(contingencyTable, boundaryColumns, numberOfDiscreteValues, numberOfClasses);
		double conditionalEntropy = ContingencyTable.conditionalEntropyColumns(conditional, boundaryColumns, data.length);

		double featureEntropy = ContingencyTable.entropy(data, numberOfDiscreteValues);
		double classEntropy = ContingencyTable.entropy(classes, numberOfClasses);
		double denominator = featureEntropy + classEntropy;
		if (denominator == 0.0) {
			// Feature and class are both constant: no information is shared, and dividing would give NaN.
			return 0.0;
		}
		return 2 * ((featureEntropy - conditionalEntropy) / denominator);
	}

	/**
	 * Computes the chi-square score of every numeric feature (column) after joint CACC discretization.
	 *
	 * @param data    instances in rows, numeric features in columns
	 * @param classes class label of each instance
	 * @return one score per feature; an empty array when {@code data} has no rows
	 */
	public static double[] chiSquareScore(double[][] data, int[] classes) {
		return scoreDiscretizedColumns(data, classes, FeatureSelectionMethods::chiSquareScore);
	}

	/**
	 * Computes the chi-square score of every nominal feature (column).
	 *
	 * @param data    instances in rows, nominal (integer coded) features in columns
	 * @param classes class label of each instance
	 * @return one score per feature; an empty array when {@code data} has no rows
	 */
	public static double[] chiSquareScore(int[][] data, int[] classes) {
		return scoreNominalColumns(data, classes, FeatureSelectionMethods::chiSquareScore);
	}

	/**
	 * Computes the chi-square score of a single numeric feature after CACC discretization.
	 *
	 * @param data    value of the feature for each instance
	 * @param classes class label of each instance
	 * @return the chi-square score
	 */
	public static double chiSquareScore(double[] data, int[] classes) {
		return scoreDiscretizedFeature(data, classes, FeatureSelectionMethods::chiSquareScore);
	}

	/**
	 * Computes the chi-square score of a single nominal feature.
	 *
	 * @param data    integer coded value of the feature for each instance
	 * @param classes class label of each instance
	 * @return the chi-square score
	 */
	public static double chiSquareScore(int[] data, int[] classes) {
		return scoreNominalFeature(data, classes, FeatureSelectionMethods::chiSquareScore);
	}

	/**
	 * Computes the chi-square statistic sum((E - O)^2 / E) over the value x class contingency table,
	 * where E = rowTotal * columnTotal / N.
	 *
	 * @param data                   discretized value of the feature for each instance
	 * @param numberOfDiscreteValues number of distinct feature values (table rows)
	 * @param classes                class label of each instance
	 * @param numberOfClasses        number of distinct class labels (table columns)
	 * @return the chi-square score; cells with zero expected count contribute nothing
	 */
	public static double chiSquareScore(int[] data, int numberOfDiscreteValues, int[] classes, int numberOfClasses) {
		int[] table = ContingencyTable.generateTable(data, numberOfDiscreteValues, classes, numberOfClasses);
		// Presumably the marginal counts per feature value and per class; used below as row/column totals.
		int[] valueTotals = ContingencyTable.boundaryTable(data, numberOfDiscreteValues);
		int[] classTotals = ContingencyTable.boundaryTable(classes, numberOfClasses);

		double total = IntStream.of(table).sum();
		double chiScore = 0.0;

		for (int i = 0; i < numberOfDiscreteValues; i++) {
			for (int j = 0; j < numberOfClasses; j++) {
				// Multiply in double: the int product of two large totals can overflow.
				double expected = (double) valueTotals[i] * classTotals[j] / total;
				if (!(expected > 0.0)) {
					// An empty row/column (or empty data) would contribute 0/0 = NaN.
					continue;
				}
				chiScore += Math.pow((expected - table[i * numberOfClasses + j]), 2) / expected;
			}
		}

		return chiScore;
	}

	/**
	 * ReliefF for numeric data with defaults kMax = 10, m = data.length, seed = 1.
	 *
	 * @see #ReliefF(double[][], int[], int, int, int)
	 */
	public static double[] ReliefF(double[][] data, int[] classes) {
		return ReliefF(data, classes, 10, data.length, 1);
	}

	/**
	 * ReliefF for numeric data with defaults m = data.length, seed = 1.
	 *
	 * @see #ReliefF(double[][], int[], int, int, int)
	 */
	public static double[] ReliefF(double[][] data, int[] classes, int kMax) {
		return ReliefF(data, classes, kMax, data.length, 1);
	}

	/**
	 * ReliefF for numeric data with default seed = 1.
	 *
	 * @see #ReliefF(double[][], int[], int, int, int)
	 */
	public static double[] ReliefF(double[][] data, int[] classes, int kMax, int m) {
		return ReliefF(data, classes, kMax, m, 1);
	}

	/**
	 * Estimates ReliefF feature weights for numeric data.
	 * The per-feature difference is |a - b| divided by the feature's observed range; constant features contribute 0.
	 *
	 * @param data    instances in rows, numeric features in columns (assumed rectangular)
	 * @param classes class label of each instance
	 * @param kMax    maximum number of nearest hits / misses used per class
	 * @param m       number of sampled instances; weights are normalised by {@code m * k}
	 * @param seed    seed for the random instance selection
	 * @return one weight per feature; an empty array when {@code data} has no rows
	 */
	public static double[] ReliefF(double[][] data, int[] classes, int kMax, int m, int seed) {
		if (data.length == 0) {
			return new double[0];
		}
		int numberOfFeatures = data[0].length;

		double[] minFeatureValue = new double[numberOfFeatures];
		double[] maxFeatureValue = new double[numberOfFeatures];
		for (int f = 0; f < numberOfFeatures; f++) {
			minFeatureValue[f] = data[0][f];
			maxFeatureValue[f] = data[0][f];
			for (int j = 1; j < data.length; j++) {
				if (data[j][f] < minFeatureValue[f]) {
					minFeatureValue[f] = data[j][f];
				}
				if (data[j][f] > maxFeatureValue[f]) {
					maxFeatureValue[f] = data[j][f];
				}
			}
		}

		return reliefFWeights(data.length, numberOfFeatures, classes, kMax, m, seed,
				(a, b, feature) -> diff(data[a], data[b], feature, minFeatureValue, maxFeatureValue));
	}

	/**
	 * ReliefF for nominal data with defaults kMax = 10, m = data.length, seed = 1.
	 *
	 * @see #ReliefF(int[][], int[], int, int, int)
	 */
	public static double[] ReliefF(int[][] data, int[] classes) {
		return ReliefF(data, classes, 10, data.length, 1);
	}

	/**
	 * ReliefF for nominal data with defaults m = data.length, seed = 1.
	 *
	 * @see #ReliefF(int[][], int[], int, int, int)
	 */
	public static double[] ReliefF(int[][] data, int[] classes, int kMax) {
		return ReliefF(data, classes, kMax, data.length, 1);
	}

	/**
	 * ReliefF for nominal data with default seed = 1.
	 *
	 * @see #ReliefF(int[][], int[], int, int, int)
	 */
	public static double[] ReliefF(int[][] data, int[] classes, int kMax, int m) {
		return ReliefF(data, classes, kMax, m, 1);
	}

	/**
	 * Estimates ReliefF feature weights for nominal data (per-feature difference is 0 if equal, 1 otherwise).
	 *
	 * @param data    instances in rows, integer coded features in columns (assumed rectangular)
	 * @param classes class label of each instance
	 * @param kMax    maximum number of nearest hits / misses used per class
	 * @param m       number of sampled instances; weights are normalised by {@code m * k}
	 * @param seed    seed for the random instance selection
	 * @return one weight per feature; an empty array when {@code data} has no rows
	 */
	public static double[] ReliefF(int[][] data, int[] classes, int kMax, int m, int seed) {
		if (data.length == 0) {
			return new double[0];
		}
		return reliefFWeights(data.length, data[0].length, classes, kMax, m, seed,
				(a, b, feature) -> diff(data[a], data[b], feature));
	}

	/**
	 * Shared ReliefF loop for numeric and nominal data.
	 * With at least 10 instances, the first {@code instancesNum} iterations visit the instances in order.
	 * Tiny datasets, and any iterations beyond one full pass, draw the instance with the seeded generator.
	 *
	 * @param instancesNum     number of instances (rows)
	 * @param numberOfFeatures number of features (columns)
	 * @param classes          class label of each instance
	 * @param kMax             maximum number of neighbours used per class
	 * @param m                number of iterations, also used to normalise the weights
	 * @param seed             seed for the random instance selection
	 * @param featureDiff      per-feature difference between two instances
	 * @return one weight per feature
	 */
	private static double[] reliefFWeights(int instancesNum, int numberOfFeatures, int[] classes,
			int kMax, int m, int seed, FeatureDiff featureDiff) {
		double[] weights = new double[numberOfFeatures];
		if (instancesNum == 0) {
			return weights;
		}

		// Use the labels that actually occur, in ascending order, instead of assuming 0..numberOfClasses-1.
		int[] labels = IntStream.of(classes).distinct().sorted().toArray();
		double[] priors = new double[labels.length];
		for (int c = 0; c < labels.length; c++) {
			priors[c] = prob(labels[c], classes);
		}

		Random rand = new Random(seed);
		double[] distances = new double[instancesNum];

		for (int i = 0; i < m; i++) {
			int target = i;
			if (instancesNum < 10 || i >= instancesNum) {
				// Same draw expression as before, so seeded results stay reproducible.
				target = Math.abs(rand.nextInt() % instancesNum);
			}
			int ownLabel = classes[target];

			for (int j = 0; j < instancesNum; j++) {
				if (j != target) {
					distances[j] = totalDiff(j, target, numberOfFeatures, featureDiff);
				}
			}

			// neighbours.get(c): instances labelled labels[c] (excluding the target), nearest first.
			List<List<Integer>> neighbours = new ArrayList<>(labels.length);
			int ownIndex = -1;
			for (int c = 0; c < labels.length; c++) {
				if (labels[c] == ownLabel) {
					ownIndex = c;
				}
				neighbours.add(nearestOfClass(target, labels[c], classes, distances));
			}
			double ownPrior = priors[ownIndex];

			for (int feature = 0; feature < numberOfFeatures; feature++) {
				List<Integer> hits = neighbours.get(ownIndex);
				int k = Math.min(hits.size(), kMax);
				double hitTerm = sumDiff(target, hits, k, feature, featureDiff);
				if (k > 0) {
					hitTerm /= (double) k * m;
				}

				double missTerm = 0.0;
				for (int c = 0; c < labels.length; c++) {
					if (c == ownIndex) {
						continue;
					}
					List<Integer> misses = neighbours.get(c);
					int kc = Math.min(misses.size(), kMax);
					double classTerm = sumDiff(target, misses, kc, feature, featureDiff);
					// Weight each miss class by its prior among the classes other than the target's.
					classTerm *= priors[c] / (1 - ownPrior);
					if (kc > 0) {
						classTerm /= (double) kc * m;
					}
					missTerm += classTerm;
				}

				weights[feature] = weights[feature] - hitTerm + missTerm;
			}
		}
		return weights;
	}

	/** Returns the instances of {@code classLabel} other than {@code target}, ordered by ascending distance. */
	private static List<Integer> nearestOfClass(int target, int classLabel, int[] classes, double[] distances) {
		List<Integer> members = new ArrayList<>();
		for (int j = 0; j < distances.length; j++) {
			if (j != target && classes[j] == classLabel) {
				members.add(j);
			}
		}
		// List.sort is stable, so equally distant instances keep ascending index order.
		members.sort((a, b) -> Double.compare(distances[a], distances[b]));
		return members;
	}

	/** Sums the feature difference between {@code target} and its first {@code k} neighbours. */
	private static double sumDiff(int target, List<Integer> neighbours, int k, int feature, FeatureDiff featureDiff) {
		double sum = 0.0;
		for (int r = 0; r < k; r++) {
			sum += featureDiff.diff(target, neighbours.get(r), feature);
		}
		return sum;
	}

	/** Distance between two instances: the sum of their per-feature differences. */
	private static double totalDiff(int first, int second, int numberOfFeatures, FeatureDiff featureDiff) {
		double dist = 0.0;
		for (int f = 0; f < numberOfFeatures; f++) {
			dist += featureDiff.diff(first, second, f);
		}
		return dist;
	}

	/** Relative frequency of {@code label} in {@code classes}. */
	private static double prob(int label, int[] classes) {
		int sum = 0;
		for (int c : classes) {
			if (c == label) {
				sum++;
			}
		}
		return sum / (double) classes.length;
	}

	/** Nominal difference: 1 when the two instances disagree on feature {@code i}, otherwise 0. */
	private static double diff(int[] instance1, int[] instance2, int i) {
		return instance1[i] != instance2[i] ? 1.0 : 0.0;
	}

	/** Numeric difference on feature {@code i}, scaled by the feature's range; 0 for a constant feature. */
	private static double diff(double[] instance1, double[] instance2, int i, double minFeatureValue[], double maxFeatureValue[]) {
		double range = maxFeatureValue[i] - minFeatureValue[i];
		if (range == 0.0) {
			// All instances agree on a constant feature; dividing would produce 0/0 = NaN.
			return 0.0;
		}
		return Math.abs(instance1[i] - instance2[i]) / range;
	}

	/** Number of distinct values in {@code values}. */
	private static int countDistinct(int[] values) {
		return (int) IntStream.of(values).distinct().count();
	}

	/**
	 * Discretizes all numeric columns jointly with CACC and scores each one.
	 * The interval count is the discretization scheme length plus one (presumably cut points + 1).
	 */
	private static double[] scoreDiscretizedColumns(double[][] data, int[] classes, DiscreteScore score) {
		if (data.length == 0) {
			return new double[0];
		}
		int numberOfClasses = countDistinct(classes);

		Discretization discretization = new CACC().caccDiscretization(data, classes);
		int[][] discretized = discretization.discretizeData(data);

		double[] scores = new double[data[0].length];
		int[] column = new int[data.length];
		for (int f = 0; f < scores.length; f++) {
			for (int r = 0; r < column.length; r++) {
				column[r] = discretized[r][f];
			}
			int numberOfDiscreteValues = discretization.getDiscretizationScheme(f).length + 1;
			scores[f] = score.score(column, numberOfDiscreteValues, classes, numberOfClasses);
		}
		return scores;
	}

	/** Scores each nominal column; the number of values is the count of distinct codes in that column. */
	private static double[] scoreNominalColumns(int[][] data, int[] classes, DiscreteScore score) {
		if (data.length == 0) {
			return new double[0];
		}
		int numberOfClasses = countDistinct(classes);

		double[] scores = new double[data[0].length];
		int[] column = new int[data.length];
		for (int f = 0; f < scores.length; f++) {
			for (int r = 0; r < column.length; r++) {
				column[r] = data[r][f];
			}
			scores[f] = score.score(column, countDistinct(column), classes, numberOfClasses);
		}
		return scores;
	}

	/** Discretizes one numeric feature with CACC and scores it. */
	private static double scoreDiscretizedFeature(double[] data, int[] classes, DiscreteScore score) {
		Discretization discretization = new CACC().caccDiscretization(data, classes);
		int[] discretized = discretization.discretizeData(data);
		int numberOfDiscreteValues = discretization.getDiscretizationScheme(0).length + 1;
		return score.score(discretized, numberOfDiscreteValues, classes, countDistinct(classes));
	}

	/** Scores one nominal feature using its distinct value count. */
	private static double scoreNominalFeature(int[] data, int[] classes, DiscreteScore score) {
		return score.score(data, countDistinct(data), classes, countDistinct(classes));
	}
}