package multisab.processing.machineLearning.RFcore;

import multisab.processing.machineLearning.RFcore.*;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Random forest classifier: an ensemble of binary decision trees whose splits are
 * chosen by minimum Gini impurity, each trained on a random sample (drawn with
 * replacement) of the training data. Predictions are made by majority vote.
 *
 * <p>The class label of every entry is read from its last field as an {@code Integer}.
 */
public class RandomForest extends Algorithm {

	/** Depth limit for each tree; when a node reaches it, its children become leaves. */
	protected int maxDepth;
	/** A group with at most this many entries becomes a leaf instead of being split further. */
	protected int minSize;
	/** Fraction of the training set size drawn (with replacement) as the sample for each tree. */
	protected double sampleSize;
	/** Number of trees built by {@link #execute}. */
	protected int nTrees;
	/** Number of candidate feature columns sampled when choosing each split. */
	protected int nFeatures;

	/**
	 * Creates a forest configured from string parameters.
	 *
	 * @param params values in this order: maxDepth, minSize, sampleSize, nTrees, nFeatures
	 * @throws NumberFormatException if a value cannot be parsed
	 * @throws IndexOutOfBoundsException if fewer than five values are supplied
	 */
	public RandomForest(List<String> params) {
		this.maxDepth = Integer.parseInt(params.get(0));
		this.minSize = Integer.parseInt(params.get(1));
		this.sampleSize = Double.parseDouble(params.get(2));
		this.nTrees = Integer.parseInt(params.get(3));
		this.nFeatures = Integer.parseInt(params.get(4));
	}

	/**
	 * Builds {@link #nTrees} trees in parallel from samples of {@code trainData}, then
	 * predicts a class for every entry of {@code testData}.
	 *
	 * @param trainData data the trees are grown from
	 * @param testData  entries to classify
	 * @return one predicted class value (an {@code Integer}) per test entry, in test order
	 * @throws IllegalStateException if this thread is interrupted while waiting for the
	 *                               tree-building threads; the interrupt flag is restored
	 */
	@Override
	public List<Object> execute(Dataset trainData, Dataset testData) {
		// Every worker receives this list and presumably appends its finished root to it
		// (TODO confirm in ParallelTreeBuildThread). Several workers may do so at the same
		// time, so the list must be synchronized.
		List<Node> trees = Collections.synchronizedList(new ArrayList<>());

		List<ParallelTreeBuildThread> workers = new ArrayList<>();
		for (int i = 0; i < nTrees; i++) {
			Dataset sample = subsample(trainData, this.sampleSize);
			ParallelTreeBuildThread worker = new ParallelTreeBuildThread(sample, this.maxDepth, this.minSize,
					this.nFeatures, trees, this);
			workers.add(worker);
			worker.start();
		}

		try {
			for (ParallelTreeBuildThread worker : workers) {
				worker.join();
			}
		} catch (InterruptedException ex) {
			// Continuing would predict with a partial forest while workers may still be running.
			Thread.currentThread().interrupt();
			throw new IllegalStateException("Interrupted while building random forest trees", ex);
		}

		List<Object> results = new ArrayList<>();
		for (Entry e : testData.getEntries()) {
			int res = baggingPredict(trees, e);
			results.add(res);
		}
		return results;
	}

	/**
	 * Draws a random sample of entries, with replacement, from {@code dataset}.
	 *
	 * @param dataset source data
	 * @param ratio   sample size as a fraction of the dataset size (rounded to the nearest int)
	 * @return a new dataset holding the sampled entries and sharing the source dictionaries
	 */
	protected Dataset subsample(Dataset dataset, double ratio) {
		Dataset subsample = new Dataset();
		int sourceSize = dataset.getEntries().size();
		int nSample = (int) Math.round(sourceSize * ratio);
		while (subsample.getEntries().size() < nSample) {
			int index = (int) (Algorithm.generator.nextDouble() * sourceSize);
			subsample.addNewEntry(dataset.getEntries().get(index));
		}
		// The sample must carry the dictionaries so split() can find the class column.
		subsample.dictionaries = dataset.dictionaries;
		return subsample;
	}

	/**
	 * Grows a single decision tree from {@code sample}.
	 *
	 * @param sample    training entries for this tree
	 * @param maxDepth  depth limit (the root is depth 1)
	 * @param minSize   groups of at most this size become leaves
	 * @param nFeatures candidate feature columns sampled per split
	 * @return the root node of the tree
	 */
	public Node buildTree(Dataset sample, int maxDepth, int minSize, int nFeatures) {
		Node root = getSplit(sample, nFeatures);
		split(root, maxDepth, minSize, nFeatures, 1, sample.dictionaries);
		return root;
	}

	/**
	 * Chooses the best split of {@code dataset} among {@code nFeatures} randomly chosen
	 * feature columns (never the last, label, column). Every value present in a candidate
	 * column is tried as a threshold; the split with the lowest Gini impurity wins.
	 *
	 * @param dataset   non-empty data to split
	 * @param nFeatures number of distinct candidate columns; capped at the number of
	 *                  non-label columns available
	 * @return a node holding the chosen column index, threshold field and both partitions
	 * @throws IllegalArgumentException if the dataset is empty, has no feature columns,
	 *                                  or {@code nFeatures < 1}
	 */
	protected Node getSplit(Dataset dataset, int nFeatures) {
		if (dataset.getEntries().isEmpty()) {
			throw new IllegalArgumentException("Cannot split an empty dataset");
		}
		int nColumns = dataset.getEntries().get(0).getFields().size();
		int nCandidates = nColumns - 1; // the last column holds the class label
		if (nCandidates < 1) {
			throw new IllegalArgumentException("Dataset has no feature columns to split on");
		}
		if (nFeatures < 1) {
			throw new IllegalArgumentException("nFeatures must be at least 1, was " + nFeatures);
		}
		Set<Integer> classValues = dataset.dictionaries.get(nColumns - 1).keySet();

		// Asking for more distinct columns than exist would make this loop spin forever.
		int wanted = Math.min(nFeatures, nCandidates);
		List<Integer> features = new ArrayList<>();
		while (features.size() < wanted) {
			int index = (int) (Algorithm.generator.nextDouble() * nCandidates);
			if (!features.contains(index)) {
				features.add(index);
			}
		}

		int bestIndex = 999;
		Field bestValue = null;
		double bestScore = Double.POSITIVE_INFINITY;
		List<List<Entry>> bestGroups = null;
		for (Integer feature : features) {
			for (Entry e : dataset.getEntries()) {
				Field threshold = e.getFields().get(feature);
				List<List<Entry>> groups = testSplit(feature, threshold, dataset);
				double gini = giniIndex(groups, classValues);
				if (gini < bestScore) { // strict: the first of equally good splits is kept
					bestIndex = feature;
					bestValue = threshold;
					bestScore = gini;
					bestGroups = groups;
				}
			}
		}

		Node node = new Node();
		node.setIndex(bestIndex);
		node.setValue(bestValue);
		node.setLeftData(bestGroups.get(0));
		node.setRightData(bestGroups.get(1));
		return node;
	}

	/**
	 * Partitions {@code dataset} on column {@code index} around the value of {@code field}.
	 *
	 * @return a two-element list: entries whose value compares below the threshold first,
	 *         all others (equal or greater) second
	 */
	protected List<List<Entry>> testSplit(int index, Field field, Dataset dataset) {
		List<Entry> left = new ArrayList<>();
		List<Entry> right = new ArrayList<>();
		for (Entry e : dataset.getEntries()) {
			// Any negative comparison result means "less than"; predict() uses the same rule.
			if (field.compareValues(e.getFields().get(index).getValue(), field.getValue()) < 0) {
				left.add(e);
			} else {
				right.add(e);
			}
		}
		List<List<Entry>> groups = new ArrayList<>();
		groups.add(left);
		groups.add(right);
		return groups;
	}

	/**
	 * Recursively attaches children to {@code node} using the partitions stored on it.
	 * When one partition is empty, both children become the same leaf. When
	 * {@code depth >= maxDepth}, each child becomes a leaf. Otherwise each partition
	 * becomes a leaf if it has at most {@code minSize} entries, or is split further.
	 *
	 * @param depth        depth of {@code node} (the root is 1)
	 * @param dictionaries column dictionaries; assumes the largest key is the class column
	 *                     (NOTE(review): confirm against Dataset)
	 */
	protected void split(Node node, int maxDepth, int minSize, int nFeatures, int depth,
			Map<Integer, Map<Integer, String>> dictionaries) {
		List<Entry> left = node.getLeftData();
		List<Entry> right = node.getRightData();

		if (left == null || left.isEmpty()) {
			Node leaf = makeLeaf(right, dictionaries);
			node.setLeftNode(leaf);
			node.setRightNode(leaf);
			return;
		}
		if (right == null || right.isEmpty()) {
			Node leaf = makeLeaf(left, dictionaries);
			node.setLeftNode(leaf);
			node.setRightNode(leaf);
			return;
		}
		if (depth >= maxDepth) {
			node.setLeftNode(makeLeaf(left, dictionaries));
			node.setRightNode(makeLeaf(right, dictionaries));
			return;
		}
		// Left subtree is grown completely before the right one.
		node.setLeftNode(growChild(left, maxDepth, minSize, nFeatures, depth, dictionaries));
		node.setRightNode(growChild(right, maxDepth, minSize, nFeatures, depth, dictionaries));
	}

	/** Turns one partition into either a leaf (if small enough) or a recursively split subtree. */
	private Node growChild(List<Entry> group, int maxDepth, int minSize, int nFeatures, int depth,
			Map<Integer, Map<Integer, String>> dictionaries) {
		if (group.size() <= minSize) {
			return makeLeaf(group, dictionaries);
		}
		Dataset subset = new Dataset();
		subset.addMultipleEntries(new ArrayList<>(group));
		subset.dictionaries = dictionaries;
		Node child = getSplit(subset, nFeatures);
		split(child, maxDepth, minSize, nFeatures, depth + 1, dictionaries);
		return child;
	}

	/** Builds a leaf whose value is a CategoryField holding the majority class of {@code group}. */
	private Node makeLeaf(List<Entry> group, Map<Integer, Map<Integer, String>> dictionaries) {
		Integer classIndex = Collections.max(dictionaries.keySet());
		CategoryField label = new CategoryField(classIndex);
		label.setValue(toTerminal(group));
		Node leaf = new Node();
		leaf.setLeaf(true);
		leaf.setIndex(classIndex);
		leaf.setValue(label);
		return leaf;
	}

	/**
	 * Computes the size-weighted Gini impurity of a split: for each non-empty group,
	 * {@code 1 - sum(p_c^2)} over {@code classValues}, weighted by the group's share
	 * of all entries.
	 *
	 * @return the impurity; 0.0 if every group is empty
	 */
	protected double giniIndex(List<List<Entry>> groups, Set<Integer> classValues) {
		int nInstances = 0;
		for (List<Entry> group : groups) {
			nInstances += group.size();
		}

		double gini = 0.0;
		for (List<Entry> group : groups) {
			int size = group.size();
			if (size == 0) {
				continue; // contributes nothing, and would divide by zero
			}
			List<Integer> labels = new ArrayList<>(size);
			for (Entry e : group) {
				labels.add(classLabel(e));
			}
			double score = 0.0;
			for (Integer classVal : classValues) {
				double p = (double) Collections.frequency(labels, classVal) / size;
				score += p * p;
			}
			gini += (1.0 - score) * ((double) size / nInstances);
		}
		return gini;
	}

	/**
	 * Predicts a class for {@code e} by majority vote over all {@code trees}.
	 *
	 * @return the most common prediction, or {@code null} if {@code trees} is empty
	 */
	protected Integer baggingPredict(List<Node> trees, Entry e) {
		List<Integer> votes = new ArrayList<>(trees.size());
		for (Node tree : trees) {
			votes.add(predict(tree, e));
		}
		return mostFrequent(votes);
	}

	/**
	 * Walks the tree rooted at {@code t} for entry {@code e}: below-threshold values go
	 * left, all others go right, until a leaf is reached.
	 *
	 * @return the class value stored in the reached leaf
	 */
	protected Integer predict(Node t, Entry e) {
		// A leaf stores the predicted class; no comparison is needed.
		if (t.isLeaf()) {
			return (Integer) t.getValue().getValue();
		}
		Field f = e.getFields().get(t.getIndex());
		boolean goLeft = f.compareValues(f.getValue(), t.getValue().getValue()) < 0;
		return predict(goLeft ? t.getLeftNode() : t.getRightNode(), e);
	}

	/**
	 * Returns the most frequent class label among the entries of {@code group}.
	 *
	 * @throws IllegalArgumentException if {@code group} is empty
	 */
	protected int toTerminal(List<Entry> group) {
		if (group.isEmpty()) {
			throw new IllegalArgumentException("Cannot derive a leaf label from an empty group");
		}
		List<Integer> outcomes = new ArrayList<>(group.size());
		for (Entry e : group) {
			outcomes.add(classLabel(e));
		}
		return mostFrequent(outcomes);
	}

	/** Reads the class label, stored as an Integer in the last field of {@code e}. */
	private static Integer classLabel(Entry e) {
		return (Integer) e.getFields().get(e.getFields().size() - 1).getValue();
	}

	/**
	 * Returns the most frequent value in {@code values}, or {@code null} when it is empty.
	 * Ties go to whichever tied value a {@link HashSet} of the values iterates first.
	 */
	private static Integer mostFrequent(List<Integer> values) {
		Integer best = null;
		int bestCount = 0;
		for (Integer candidate : new HashSet<>(values)) {
			int count = Collections.frequency(values, candidate);
			if (count > bestCount) {
				bestCount = count;
				best = candidate;
			}
		}
		return best;
	}

}
