package multisab.processing.machineLearning.discretization;

import java.util.*;

import static java.lang.Math.log;
import static java.lang.Math.sqrt;

/**
 * Main class for the CACC discretization algorithm:
 * 'A Discretization Algorithm based on Class-Attribute Contingency Coefficient' (CACC),
 * by Sheng-Jung Tsai, Shien-I Lee and Wei-Pang Yang, Information Sciences, Elsevier, 2008.
 *
 * <p>For every feature the midpoints between adjacent distinct values are used as candidate
 * cut points. Cut points are added greedily, one per round, always taking the candidate that
 * yields the highest cacc value, for as long as that value improves on the best one so far.
 *
 * <p>Class labels may be arbitrary {@code int} values; internally they are mapped to their
 * position in the sorted set of distinct labels.
 *
 * <p>The public methods store per-call state in instance fields, so one instance must not be
 * used by several threads at the same time.
 *
 * @author Davor Kukolja
 */
public class CACC {
    /** Limit on accepted cut points, recomputed on every public call as floor(0.75 * number of samples). */
    private int maxNumIntervals;
    /** Distinct class labels of the current call, in ascending order. */
    private int[] uniqueClasses;

    /**
     * Discretizes every feature (column) of a data matrix.
     *
     * @param data    samples in rows, features in columns; the number of features is taken from the first row
     * @param classes class label of each sample; must have at least {@code data.length} entries
     * @return a {@code Discretization} created for {@code data[0].length} features, with the scheme
     *         (ascending cut points) of feature {@code i} stored at index {@code i}
     * @throws IllegalArgumentException if {@code data} has no rows or {@code classes} has fewer
     *                                  entries than {@code data} has rows
     */
    public Discretization caccDiscretization (double[][] data, int[] classes) {
        int sampleCount = data.length;
        int[] classIndices = prepare(sampleCount, classes);
        int featureCount = data[0].length;

        Discretization discretization = new Discretization(featureCount);

        for (int feature = 0; feature < featureCount; feature++) {
            // Extract one column so that a single feature can be processed on its own.
            double[] column = new double[sampleCount];
            for (int sample = 0; sample < sampleCount; sample++) {
                column[sample] = data[sample][feature];
            }
            discretization.setDiscretizationScheme(feature, oneFeatureCACC(column, classIndices));
        }

        return discretization;
    }

    /**
     * Discretizes a single feature.
     *
     * @param data    value of the feature for each sample
     * @param classes class label of each sample; must have at least {@code data.length} entries
     * @return a {@code Discretization} created for one feature, with the scheme stored at index 0
     * @throws IllegalArgumentException if {@code data} is empty or {@code classes} has fewer
     *                                  entries than {@code data}
     */
    public Discretization caccDiscretization (double[] data, int[] classes) {
        int[] classIndices = prepare(data.length, classes);

        Discretization discretization = new Discretization(1);
        discretization.setDiscretizationScheme(0, oneFeatureCACC(data, classIndices));

        return discretization;
    }

    /**
     * Validates the input sizes, refreshes the per-call fields and translates every class label
     * into its index within the sorted {@link #uniqueClasses} array.
     *
     * @return for each of the first {@code sampleCount} samples, the row of the quanta matrix it belongs to
     */
    private int[] prepare(int sampleCount, int[] classes) {
        if (sampleCount == 0) {
            throw new IllegalArgumentException("data must contain at least one sample");
        }
        if (classes.length < sampleCount) {
            throw new IllegalArgumentException("classes has " + classes.length
                    + " entries but data has " + sampleCount + " samples");
        }

        findUniqueClasses(classes);

        // Assume the maximum number of intervals is M*0.75
        maxNumIntervals = (int) Math.floor(sampleCount * 0.75);

        int[] classIndices = new int[sampleCount];
        for (int sample = 0; sample < sampleCount; sample++) {
            // uniqueClasses is sorted and contains every label, so the search always succeeds.
            classIndices[sample] = Arrays.binarySearch(uniqueClasses, classes[sample]);
        }
        return classIndices;
    }

    /**
     * Runs the greedy CACC search for one feature.
     *
     * @param data         value of the feature for each sample (non-empty)
     * @param classIndices for each sample, the index of its class within {@link #uniqueClasses}
     * @return the accepted cut points in ascending order; empty if no cut point gave a positive cacc value
     */
    private double[] oneFeatureCACC (double[] data, int[] classIndices) {
        // minimum d0 of the data; a candidate equal to it is never tried
        double d0 = min(data, 0, data.length - 1);

        // all distinct values of data in ascending order
        TreeSet<Double> distinctValues = new TreeSet<Double>();
        for (double value : data) {
            distinctValues.add(value);
        }

        // midpoints of all adjacent pairs are the candidate cut points
        List<Double> candidates = new ArrayList<Double>();
        Double previous = null;
        for (Double current : distinctValues) {
            if (previous != null) {
                double midpoint = (previous + current) / 2;
                // Rounding can make two neighbouring midpoints collide; keep each value once so
                // that a candidate can never coincide with an already accepted cut point.
                if (candidates.isEmpty() || candidates.get(candidates.size() - 1) != midpoint) {
                    candidates.add(midpoint);
                }
            }
            previous = current;
        }

        // initial discretization scheme is the single interval [d0,dn] (no cut points), globalcacc = 0
        List<Double> scheme = new ArrayList<Double>();
        double globalCacc = 0;
        int acceptedCuts = 0;

        // Every round either accepts (and removes) one candidate or ends the search.
        while (!candidates.isEmpty()) {
            double bestCacc = 0;
            int bestIndex = -1;
            List<Double> bestScheme = null;

            // try each remaining candidate on top of the current scheme and keep the best one
            for (int c = 0; c < candidates.size(); c++) {
                Double candidate = candidates.get(c);
                if (candidate.doubleValue() == d0) {
                    continue;
                }
                List<Double> trial = new ArrayList<Double>(scheme);
                trial.add(candidate);
                Collections.sort(trial);

                double caccVal = caccValue(data, trial, classIndices);
                if (caccVal > bestCacc) {
                    bestCacc = caccVal;
                    bestIndex = c;
                    bestScheme = trial;
                }
            }

            // Without an improvement nothing changes, so any further round would repeat this one.
            if (bestIndex < 0 || bestCacc <= globalCacc) {
                break;
            }

            // replace D with D', globalcacc = cacc
            candidates.remove(bestIndex);
            scheme = bestScheme;
            globalCacc = bestCacc;
            acceptedCuts++;
            // stop as soon as more than maxNumIntervals cut points have been accepted
            if (acceptedCuts > maxNumIntervals) {
                break;
            }
        }

        double[] discretizationScheme = new double[scheme.size()];
        for (int i = 0; i < discretizationScheme.length; i++) {
            discretizationScheme[i] = scheme.get(i);
        }
        return discretizationScheme;
    }

    /**
     * Computes the cacc value of a discretization scheme.
     *
     * <p>A sample falls into interval {@code t} where {@code t} is the first cut point with
     * {@code value <= cut}; values above every cut point fall into the last interval. With the
     * quanta matrix q (classes x intervals), row totals r and column totals c:
     * {@code y = sum(q[p][j]^2 / (r[p] * c[j]))}, {@code y' = M * (y - 1) / ln(intervals)} and
     * the result is {@code sqrt(y' / (y' + M))}, M being the number of samples.
     *
     * @param data                 value of the feature for each sample
     * @param discretizationScheme cut points in ascending order
     * @param classIndices         for each sample, the index of its class within {@link #uniqueClasses}
     * @return the cacc value; may be NaN when {@code y'} is not positive or the scheme has no cut points
     */
    private double caccValue(double[] data, List<Double> discretizationScheme, int[] classIndices) {
        int sampleCount = data.length;
        int classCount = uniqueClasses.length;
        int cutCount = discretizationScheme.size();
        int intervalCount = cutCount + 1;

        // Unbox the cut points once instead of once per sample.
        double[] cuts = new double[cutCount];
        for (int t = 0; t < cutCount; t++) {
            cuts[t] = discretizationScheme.get(t);
        }

        // Discretize the continuous data and build the quanta matrix together with its totals.
        int[][] quantaMatrix = new int[classCount][intervalCount];
        int[] rowTotals = new int[classCount];
        int[] columnTotals = new int[intervalCount];

        for (int sample = 0; sample < sampleCount; sample++) {
            int interval = cutCount;
            for (int t = 0; t < cutCount; t++) {
                if (data[sample] <= cuts[t]) {
                    interval = t;
                    break;
                }
            }
            int classIndex = classIndices[sample];
            quantaMatrix[classIndex][interval]++;
            rowTotals[classIndex]++;
            columnTotals[interval]++;
        }

        // Compute y value by using the quanta matrix:
        double y = 0;
        for (int p = 0; p < classCount; p++) {
            for (int q = 0; q < intervalCount; q++) {
                if (rowTotals[p] > 0 && columnTotals[q] > 0) {
                    // Multiply as double: the int products overflow for large sample counts.
                    double cell = quantaMatrix[p][q];
                    y = y + (cell * cell) / ((double) rowTotals[p] * columnTotals[q]);
                }
            }
        }

        // Compute y' value from y value:
        double yPrime = sampleCount * (y - 1) / log(intervalCount);

        // Compute CACC value from y' value:
        return sqrt(yPrime / (yPrime + sampleCount));
    }

    /**
     * Stores the distinct values of {@code classes} in {@link #uniqueClasses}, in ascending order.
     *
     * @param classes class labels; every entry is taken into account
     */
    private void findUniqueClasses(int[] classes) {
        TreeSet<Integer> labels = new TreeSet<Integer>();
        for (int label : classes) {
            labels.add(label);
        }

        // Possible class values
        uniqueClasses = new int[labels.size()];
        int next = 0;
        for (Integer label : labels) {
            uniqueClasses[next++] = label;
        }
    }

    /**
     * Returns the largest value of {@code x[start..stop]} (both bounds inclusive).
     * If {@code stop <= start}, {@code x[start]} is returned.
     *
     * @param x     values to scan
     * @param start index of the first element considered
     * @param stop  index of the last element considered
     * @return the maximum of the scanned elements
     */
    public static double max(double[] x, int start, int stop){
        double m = x[start];

        for (int i = start + 1; i < stop + 1; i++){
            if (x[i] > m){
                m = x[i];
            }
        }

        return m;
    }

    /**
     * Returns the smallest value of {@code x[start..stop]} (both bounds inclusive).
     * If {@code stop <= start}, {@code x[start]} is returned.
     *
     * @param x     values to scan
     * @param start index of the first element considered
     * @param stop  index of the last element considered
     * @return the minimum of the scanned elements
     */
    public static double min(double[] x, int start, int stop){
        double m = x[start];

        for (int i = start + 1; i < stop + 1; i++){
            if (x[i] < m){
                m = x[i];
            }
        }

        return m;
    }
}
