package multisab.processing.ecgAnalysis.ecgFiducialPointsDetection.qrsDetection;

import multisab.processing.preprocessing.filtering.*;
import multisab.processing.preprocessing.iirj.*;
import multisab.processing.preprocessing.otherFunctions.*;
import java.util.Arrays;

/**
 *
 * This class implements the Pan-Tompkins QRS detection algorithm.
 * The implementation is based on the paper "A Real-Time QRS Detection Algorithm" (1985)
 * and on the MATLAB implementation by Hooman Sedghamiz, Linköping University.
 *
 * @author Krešimir Friganović
 * Date: 5.12.2016.
 *
 * Modified by Davor
 *
 */
public class PanTompkins {

    /**
     * Detects QRS complexes (R spikes) in an ECG signal with the Pan-Tompkins algorithm.
     *
     * The signal is band-limited, differentiated, squared and integrated over a 150 ms
     * moving window; every stage is passed through {@code ArrayFunctions.Normalization}.
     * Candidate peaks of the integrated signal (at least 200 ms apart) are then classified
     * as QRS or noise with adaptive thresholds. A search-back recovers beats that were
     * missed when the current RR interval is much longer than the running average, and a
     * slope comparison rejects T waves that follow a QRS by less than 360 ms.
     *
     * NOTE(review): this method assumes that
     * {@code FindLocalExtremes.findMaximumWithMinPeakDistance} returns an array as long as
     * the signal, holding the peak amplitude at peak positions and 0 elsewhere, and that
     * {@code FindLocalExtremes.findMaximum(signal, from, to)} returns {value, index}.
     * Confirm against those helpers.
     *
     * @param ecg raw ECG samples
     * @param fs  sampling frequency in Hz
     * @return array of the same length as {@code ecg} holding 1 at every accepted QRS
     *         position and 0 elsewhere; positions are indices into the integrated signal
     *         and are not corrected for filter delay. An empty input yields an empty array.
     */
    public static double [] DetectRSpike(double [] ecg, double fs) {

        int n = ecg.length;
        if (n == 0) {
            // Nothing to analyse; avoids running the threshold training on an empty signal.
            return new double[0];
        }

        //*******************LINEAR AND NONLINEAR TRANSFORMATION OF ECG SIGNAL ***********************************/

        // Band-limit the ECG (see PanTompkinsFilterECG)
        double[] ecgFiltered = ArrayFunctions.Normalization(PanTompkinsFilterECG(ecg, fs));

        // Five-point derivative of the filtered signal
        double[] ecgDerivated = ArrayFunctions.Normalization(PanTompkinsDerivateECG(ecgFiltered, fs));

        // Point-wise squaring
        double[] ecgSquared = ArrayFunctions.Normalization(PanTompkinsSquaringECG(ecgDerivated));

        // Moving window integration - window size 0.15 s (150 ms)
        double[] ecgIntegrated = ArrayFunctions.Normalization(PanTompkinsIntegrateECG(ecgSquared, fs, 0.15));

        // Window lengths in samples
        int refractory = (int) (0.2 * fs);      // minimum peak distance / search-back margin
        int filteredSearch = (int) (0.15 * fs); // look-back for the peak in the filtered ECG
        int tWaveWindow = (int) (0.36 * fs);    // candidates closer than this may be T waves
        double[] slopeBuffer = new double[(int) (0.075 * fs)];

        //***********FINDING R SPIKE FIDUCIAL MARKS OF TRANSFORMED ECG SIGNAL*************************************/

        // Local maxima of the integrated signal that are at least 200 ms apart
        double[] pks = FindLocalExtremes.findMaximumWithMinPeakDistance(ecgIntegrated, refractory);

        //****************INITIALIZING THE THRESHOLDS*************************************************************/

        // Training phase: first two seconds, clamped so short recordings stay in bounds.
        int training = Math.min((int) (2 * fs), n);

        // Integrated ECG thresholds
        // NOTE(review): the extra division by (2*fs) is kept from the original code; whether it is
        // intended depends on what ArrayFunctions.Mean returns - confirm.
        double[] maxI = FindLocalExtremes.findMaximum(ecgIntegrated, 0, training);
        double meanI = ArrayFunctions.Mean(ecgIntegrated, 0, training);
        double thrSig = maxI[0] * 0.3;
        double thrNoise = meanI / (2 * fs) * 0.5;
        double noiseLev = thrNoise;
        double sigLev = thrSig;

        // Filtered ECG thresholds (noise threshold is derived from the filtered signal's mean)
        double[] maxF = FindLocalExtremes.findMaximum(ecgFiltered, 0, training);
        double meanF = ArrayFunctions.Mean(ecgFiltered, 0, training);
        double thrSig1 = maxF[0] * 0.3;
        double thrNoise1 = meanF / (2 * fs) * 0.5;
        double noiseLev1 = thrNoise1;
        double sigLev1 = thrSig1;

        //****************STATE***********************************************************************************/
        double[] qrsMarks = new double[n]; // returned: 1 at accepted QRS positions
        double[] qrsIdx = new double[n];   // positions of accepted QRS complexes, in detection order
        int qrsCount = 0;
        double meanRR = 0;                 // mean RR interval of the most recent accepted beats

        //*****************Threshold adaptation and online decision***********************************************/
        for (int i = 0; i < n; i++) { //TODO: optimization - loop only through peaks, not every sample
            if (!(pks[i] > 0)) {
                continue; // not a candidate peak
            }
            double peak = pks[i];

            // Locate the corresponding peak in the filtered ECG (has less delay, so search is backwards)
            double yi = 0;
            for (int j = Math.max(0, i - filteredSearch); j < i; j++) {
                if (ecgFiltered[j] > yi) {
                    yi = ecgFiltered[j];
                }
            }

            // Update the average RR interval from the last eight accepted beats
            double selectedRR = 0;
            if (qrsCount >= 8) {
                double[] diffRR = RRIntervals.findRRIntervals(
                        Arrays.copyOfRange(qrsIdx, qrsCount - 8, qrsCount));
                meanRR = ArrayFunctions.Mean(diffRR, 0, diffRR.length);
                double lastRR = qrsIdx[qrsCount - 1] - qrsIdx[qrsCount - 2];

                if (lastRR <= 0.92 * meanRR || lastRR >= 1.16 * meanRR) {
                    // Irregular rhythm: lower both signal thresholds to detect better
                    thrSig = 0.5 * thrSig;
                    thrSig1 = 0.5 * thrSig1;
                } else {
                    // Last RR interval is "okay": use the running mean as the regular RR
                    selectedRR = meanRR;
                }
            }

            // RR estimate used to decide whether a beat was missed
            double testRR;
            if (selectedRR > 0) {
                testRR = selectedRR;
            } else if (meanRR > 0) {
                testRR = meanRR;
            } else {
                testRR = 0;
            }

            // Search-back: no QRS for more than 166 % of the expected RR interval
            if (testRR > 0) {
                // testRR > 0 implies at least eight beats were accepted, so the index is valid
                int last = (int) qrsIdx[qrsCount - 1];
                int from = last + refractory;
                int to = i - refractory;
                if ((i - last) >= (int) (1.66 * testRR) && from < to) {
                    double[] missed = FindLocalExtremes.findMaximum(ecgIntegrated, from, to);
                    // NOTE(review): the "- 1" is kept from the original port and treats the returned
                    // index as relative to 'from' - confirm against FindLocalExtremes.findMaximum.
                    int missedAt = from + (int) missed[1] - 1;

                    if (missed[0] > thrNoise && missedAt >= 0 && missedAt < n) {
                        // Accept the recovered beat at its own position
                        qrsMarks[missedAt] = 1;
                        qrsIdx[qrsCount] = missedAt;
                        qrsCount++;

                        // Find the matching peak in the filtered signal
                        int fFrom = Math.max(0, (int) (missedAt - 0.15 * fs));
                        int fTo = Math.min(missedAt, ecgFiltered.length);
                        if (fFrom < fTo) {
                            double[] yiT = FindLocalExtremes.findMaximum(ecgFiltered, fFrom, fTo);
                            if (yiT[0] > thrNoise1) {
                                // beat found with the second (noise) threshold
                                sigLev1 = 0.25 * yiT[0] + 0.75 * sigLev1;
                            }
                        }
                        sigLev = 0.25 * missed[0] + 0.75 * sigLev;
                    }
                }
            }

            // Classify the current candidate as QRS or noise
            if (peak >= thrSig) {
                boolean tWave = false;

                if (qrsCount > 2) {
                    int last = (int) qrsIdx[qrsCount - 1];
                    boolean slopesInBounds = i - slopeBuffer.length >= 1
                            && last - slopeBuffer.length >= 1;
                    if ((i - last) < tWaveWindow && slopesInBounds) {
                        // Compare the slope before this candidate with the slope before the last QRS
                        double slope1 = meanBackwardDifference(ecgIntegrated, i, slopeBuffer);
                        double slope2 = meanBackwardDifference(ecgIntegrated, last, slopeBuffer);

                        if (Math.abs(slope1) <= Math.abs(0.5 * slope2)) {
                            // Much flatter than the previous QRS: treat as a T wave (noise)
                            tWave = true;
                            noiseLev1 = 0.125 * yi + 0.875 * noiseLev1;
                            noiseLev = 0.125 * peak + 0.875 * noiseLev;
                        }
                    }
                }

                if (!tWave) {
                    qrsMarks[i] = 1;
                    qrsIdx[qrsCount] = i;
                    qrsCount++;

                    if (yi > thrSig1) {
                        sigLev1 = 0.125 * yi + 0.875 * sigLev1;
                    }
                    sigLev = 0.125 * peak + 0.875 * sigLev;
                }
            } else {
                // Below the signal threshold (above or below the noise threshold):
                // the peak is noise, adjust the noise levels of both signals
                noiseLev1 = 0.125 * yi + 0.875 * noiseLev1;
                noiseLev = 0.125 * peak + 0.875 * noiseLev;
            }

            // Adjust the thresholds with SNR (integrated signal)
            if (noiseLev != 0 || sigLev != 0) {
                thrSig = noiseLev + 0.25 * Math.abs(sigLev - noiseLev);
                thrNoise = 0.5 * thrSig;
            }

            // Adjust the thresholds with SNR (band-passed signal)
            if (noiseLev1 != 0 || sigLev1 != 0) {
                thrSig1 = noiseLev1 + 0.25 * Math.abs(sigLev1 - noiseLev1);
                thrNoise1 = 0.5 * thrSig1;
            }
        }

        return qrsMarks;
    }

    /**
     * Fills {@code buffer} with signal[j - 1] - signal[j] for the buffer.length samples that
     * end just before {@code end} and returns {@code ArrayFunctions.Mean} over the buffer.
     * The caller must ensure end - buffer.length is at least 1 and end does not exceed signal.length.
     */
    private static double meanBackwardDifference(double[] signal, int end, double[] buffer) {
        int k = 0;
        for (int j = end - buffer.length; j < end; j++) {
            buffer[k] = signal[j - 1] - signal[j];
            k++;
        }
        return ArrayFunctions.Mean(buffer, 0, buffer.length);
    }



    //******LINEAR AND NONLINEAR SIGNAL TRANSFORMATIONS USED IN PAN TOMPKINS ALGORITHM**********************************/
    //*****************************************************************************************************************/
    /**
     * Band-limits the ECG by cascading a low-pass biquad (cutoff about 11 Hz) with a
     * high-pass biquad (cutoff about 5 Hz).
     *
     * The high-pass stage uses the coefficients of the "highBiquad" design; previously the
     * low-pass coefficients were applied a second time and the high-pass design was unused.
     *
     * @param x  input signal
     * @param fs sampling frequency in Hz
     * @return the output of {@code Filter.filterSignal} after both stages
     */
    public static double [] PanTompkinsFilterECG(double[] x, double fs) {

        //Lowpass filtering - cutoff freq. about 11 Hz
        FilterTransferFunction lowpass = new FilterTransferFunction(11, fs, "lowBiquad");
        double[] lowPassSignal = Filter.filterSignal(x, lowpass.getB(), lowpass.getA());

        //Highpass filtering - cutoff freq. about 5 Hz
        FilterTransferFunction highpass = new FilterTransferFunction(5, fs, "highBiquad");
        double[] y = Filter.filterSignal(lowPassSignal, highpass.getB(), highpass.getA());

        // NOTE: a single Butterworth band-pass (8 - 20 Hz) was previously sketched here as an
        // alternative to the two-stage cascade; it is not used.
        return y;
    }
    /**
     * Applies the five-point derivative stage of the Pan-Tompkins chain.
     *
     * The numerator taps are {2, 1, 0, -1, -2} each divided by (8 / fs); the denominator is {1},
     * so the filter is purely FIR. The signal is run through {@code Filter.filterSignal}.
     *
     * @param x  input signal (the band-limited ECG)
     * @param fs sampling frequency in Hz
     * @return the filtered signal as returned by {@code Filter.filterSignal}
     */
    public static double [] PanTompkinsDerivateECG (double [] x, double fs) {

        // Scale factor 8T, with T = 1/fs the sampling period
        final double eightT = 8 / fs;

        // Tap weights of the five-point derivative, newest sample first
        final double[] weights = {2, 1, 0, -1, -2};

        final double[] numerator = new double[weights.length];
        for (int tap = 0; tap < weights.length; tap++) {
            numerator[tap] = weights[tap] / eightT;
        }
        final double[] denominator = {1};

        return Filter.filterSignal(x, numerator, denominator);
    }
    /**
     * Squares the signal sample by sample (the non-linear stage of the Pan-Tompkins chain).
     *
     * @param x input signal; it is not modified
     * @return a new array of the same length where element k equals x[k] * x[k]
     */
    public static double [] PanTompkinsSquaringECG (double [] x){

        // Work on a copy so the caller's array stays untouched
        double[] squared = x.clone();

        int k = 0;
        while (k < squared.length) {
            squared[k] *= squared[k];
            k++;
        }
        return squared;
    }
    /**
     * Moving-window integration: filters the signal with a moving average whose length is
     * (int)(window * fs) samples, every tap being 1/N.
     *
     * The previous version also allocated an unused array as long as the input on every
     * call; that dead allocation is removed.
     *
     * @param x      input signal (the squared derivative in the Pan-Tompkins chain)
     * @param fs     sampling frequency in Hz
     * @param window integration window length in seconds (the detector uses 0.15)
     * @return the filtered signal as returned by {@code Filter.filterSignal}
     */
    public static double [] PanTompkinsIntegrateECG (double [] x, double fs, double window) {

        int N = (int) (window * fs); // window length in samples, e.g. 0.15 * fs for 150 ms

        // FIR moving average: N equal taps of 1/N, denominator {1}
        double[] movingAverageB = new double[N];
        Arrays.fill(movingAverageB, 1 / (double) N);

        return Filter.filterSignal(x, movingAverageB, new double[] {1});
    }

}

