package multisab.processing.machineLearning.RFcore;

import java.util.ArrayList;
import java.util.List;
import java.util.Set;

/**
 * Extremely-randomized-trees variant of {@link RandomForest}.
 *
 * <p>Two behaviours are defined here:
 * <ul>
 *   <li>{@link #subsample} hands back the dataset unchanged (no sampling).</li>
 *   <li>{@link #getSplit} tries one random threshold per randomly chosen feature and keeps the
 *       candidate with the lowest Gini index.</li>
 * </ul>
 *
 * <p>NOTE(review): these methods presumably override hooks of the same signature in
 * {@code RandomForest}; confirm against the parent class.
 */
public class ExtraTrees extends RandomForest {

	/**
	 * Creates the model, forwarding {@code params} unchanged to the {@link RandomForest}
	 * constructor.
	 *
	 * @param params configuration strings interpreted by the parent class
	 */
	public ExtraTrees(List<String> params) {
		super(params);
	}

	/**
	 * Returns {@code dataset} itself, unchanged. This variant performs no subsampling, so
	 * anything built through this hook sees the full training set.
	 *
	 * @param dataset the training data
	 * @param ratio   ignored
	 * @return the same {@code dataset} instance
	 */
	protected Dataset subsample(Dataset dataset, double ratio) {
		return dataset;
	}

	/**
	 * Chooses a split for {@code dataset}.
	 *
	 * <p>The last column of every entry is treated as the class label; all other columns are
	 * candidate features. {@code nFeatures} distinct candidates are drawn at random. For each
	 * one, a single threshold is drawn uniformly from {@code [min, max)} of that column's values.
	 * The candidate whose partition has the lowest Gini index wins.
	 *
	 * @param dataset   rows to split; must contain at least one entry and at least one
	 *                  feature column besides the label
	 * @param nFeatures number of candidate features to try; values larger than the number of
	 *                  feature columns are clamped to that number
	 * @return a node holding the chosen feature index, the threshold (as a
	 *         {@code RealValField}) and the left/right partitions
	 * @throws IllegalArgumentException if the dataset is empty, has no feature columns, or
	 *                                  {@code nFeatures < 1}
	 * @throws IllegalStateException    if no candidate produced a Gini index below +infinity
	 *                                  (e.g. {@code giniIndex} returned NaN for all of them)
	 */
	protected Node getSplit(Dataset dataset, int nFeatures) {
		if (dataset.getEntries().isEmpty()) {
			throw new IllegalArgumentException("dataset must contain at least one entry");
		}
		// Every column except the last (the class label) is a candidate feature.
		int featureCount = dataset.getEntries().get(0).getFields().size() - 1;
		if (featureCount < 1) {
			throw new IllegalArgumentException("dataset has no feature columns besides the class label");
		}
		if (nFeatures < 1) {
			throw new IllegalArgumentException("nFeatures must be >= 1 but was " + nFeatures);
		}

		Set<Integer> classValues = dataset.dictionaries.get(featureCount).keySet();

		// Asking for more distinct features than exist would make the draw loop spin forever.
		List<Integer> features = pickRandomFeatures(Math.min(nFeatures, featureCount), featureCount);

		int bestIndex = -1;
		Field bestValue = null;
		double bestScore = Double.POSITIVE_INFINITY;
		List<List<Entry>> bestGroups = null;

		for (int feature : features) {
			double threshold = randomThreshold(dataset, feature);
			List<List<Entry>> groups = extraTestSplit(feature, threshold, dataset);
			double gini = giniIndex(groups, classValues);
			if (gini < bestScore) {
				RealValField rvf = new RealValField(feature);
				rvf.setValue(threshold);
				bestIndex = feature;
				bestValue = rvf;
				bestScore = gini;
				bestGroups = groups;
			}
		}

		if (bestGroups == null) {
			throw new IllegalStateException("no candidate split produced a usable Gini index");
		}

		Node node = new Node();
		node.setIndex(bestIndex);
		node.setValue(bestValue);
		node.setLeftData(bestGroups.get(0));
		node.setRightData(bestGroups.get(1));
		return node;
	}

	/**
	 * Draws {@code count} distinct column indices uniformly from {@code [0, featureCount)}
	 * using {@code Algorithm.generator}. Requires {@code count <= featureCount}.
	 */
	private static List<Integer> pickRandomFeatures(int count, int featureCount) {
		List<Integer> features = new ArrayList<>(count);
		while (features.size() < count) {
			int index = (int) (Algorithm.generator.nextDouble() * featureCount);
			if (!features.contains(index)) {
				features.add(index);
			}
		}
		return features;
	}

	/**
	 * Draws a threshold uniformly from {@code [min, max)} of column {@code feature} over all
	 * entries of {@code dataset}. For a constant column this returns that constant.
	 *
	 * <p>NOTE(review): a constant column therefore yields an empty left group, because no value
	 * is strictly below the threshold.
	 */
	private static double randomThreshold(Dataset dataset, int feature) {
		double min = numericValue(dataset.getEntries().get(0), feature);
		double max = min;
		for (Entry e : dataset.getEntries()) {
			double v = numericValue(e, feature);
			if (v < min) {
				min = v;
			}
			if (v > max) {
				max = v;
			}
		}
		return Algorithm.generator.nextDouble() * (max - min) + min;
	}

	/**
	 * Reads column {@code feature} of {@code entry} as a {@code double}. Assumes the field
	 * holds a numeric (Double) value; otherwise the cast throws.
	 */
	private static double numericValue(Entry entry, int feature) {
		return (double) entry.getFields().get(feature).getValue();
	}

	/**
	 * Partitions {@code dataset} on column {@code index}. Entries whose value compares below
	 * {@code value} go to the left group; all others (equal or greater) go to the right.
	 *
	 * @return a two-element list {@code [left, right]}
	 */
	private List<List<Entry>> extraTestSplit(int index, double value, Dataset dataset) {
		List<Entry> left = new ArrayList<>();
		List<Entry> right = new ArrayList<>();

		for (Entry e : dataset.getEntries()) {
			int comparisonRes = e.getFields().get(index).compareValues(e.getFields().get(index).getValue(), value);
			// Treat any negative result as "less than"; don't depend on exactly -1.
			if (comparisonRes < 0) {
				left.add(e);
			} else {
				right.add(e);
			}
		}

		List<List<Entry>> groups = new ArrayList<>(2);
		groups.add(left);
		groups.add(right);
		return groups;
	}

}
