/*-------------------------------------------------------------------------------
 This file is part of unityForest.

 Copyright (c) [2014-2018] [Marvin N. Wright]

 This software may be modified and distributed under the terms of the MIT license.

 Please note that the C++ core of divfor is distributed under MIT license and the
 R package "unityForest" under GPL3 license.
 #-------------------------------------------------------------------------------*/

#include <algorithm>
#include <chrono>
#include <cmath>
#include <iterator>
#include <thread>
#include <utility>

#include <Rcpp.h>

#include "Tree.h"
#include "utility.h"

namespace unityForest
{

  /**
   * Default constructor: creates a tree that has not been configured yet.
   *
   * Counters and sizes start at zero, every pointer member starts as null,
   * tuning parameters take the project-wide DEFAULT_* constants, and all
   * flags are off except sample_with_replacement. Members that are not
   * listed here are default-constructed.
   */
  Tree::Tree()
      : dependent_varID(0),
        mtry(0),
        nsplits(0),
        npairs(0),
        proptry(0.0),
        prop_var_root(0),
        num_samples(0),
        num_samples_oob(0),
        min_node_size(0),
        min_node_size_root(0),
        deterministic_varIDs(nullptr),
        split_select_varIDs(nullptr),
        split_select_weights(nullptr),
        case_weights(nullptr),
        manual_inbag(nullptr),
        oob_sampleIDs(0),
        promispairs(nullptr),
        eim_mode(0),
        divfortype(0),
        metricind(nullptr),
        holdout(false),
        keep_inbag(false),
        data(nullptr),
        variable_importance(nullptr),
        importance_mode(DEFAULT_IMPORTANCE_MODE),
        sample_with_replacement(true),
        sample_fraction(nullptr),
        memory_saving_splitting(false),
        splitrule(DEFAULT_SPLITRULE),
        alpha(DEFAULT_ALPHA),
        minprop(DEFAULT_MINPROP),
        num_random_splits(DEFAULT_NUM_RANDOM_SPLITS),
        max_depth(DEFAULT_MAXDEPTH),
        max_depth_root(DEFAULT_MAXDEPTHROOT),
        num_cand_trees(DEFAULT_NUMCANDTREES),
        depth(0),
        last_left_nodeID(0),
        last_left_nodeID_loop(0)
  {
  }

  /**
   * Constructor for a tree whose structure already exists.
   *
   * The split/child members are initialised from the arguments of the same
   * name; all remaining members receive the same start values as in the
   * default constructor (zeros, null pointers, DEFAULT_* constants).
   *
   * @param child_nodeIDs    Source for the child_nodeIDs member.
   * @param split_varIDs     Source for the split_varIDs member.
   * @param split_values     Source for the split_values member.
   * @param split_types      Source for the split_types member.
   * @param split_multvarIDs Source for the split_multvarIDs member.
   * @param split_directs    Source for the split_directs member.
   * @param split_multvalues Source for the split_multvalues member.
   */
  Tree::Tree(std::vector<std::vector<size_t>> &child_nodeIDs, std::vector<size_t> &split_varIDs,
             std::vector<double> &split_values, std::vector<size_t> &split_types, std::vector<std::vector<size_t>> &split_multvarIDs,
             std::vector<std::vector<std::vector<bool>>> &split_directs,
             std::vector<std::vector<std::vector<double>>> &split_multvalues)
      : dependent_varID(0),
        mtry(0),
        nsplits(0),
        npairs(0),
        proptry(0.0),
        prop_var_root(0),
        num_samples(0),
        num_samples_oob(0),
        min_node_size(0),
        min_node_size_root(0),
        deterministic_varIDs(nullptr),
        split_select_varIDs(nullptr),
        split_select_weights(nullptr),
        case_weights(nullptr),
        manual_inbag(nullptr),
        split_varIDs(split_varIDs),
        split_values(split_values),
        split_types(split_types),
        split_multvarIDs(split_multvarIDs),
        split_directs(split_directs),
        split_multvalues(split_multvalues),
        child_nodeIDs(child_nodeIDs),
        oob_sampleIDs(0),
        promispairs(nullptr),
        eim_mode(0),
        divfortype(0),
        metricind(nullptr),
        holdout(false),
        keep_inbag(false),
        data(nullptr),
        variable_importance(nullptr),
        importance_mode(DEFAULT_IMPORTANCE_MODE),
        sample_with_replacement(true),
        sample_fraction(nullptr),
        memory_saving_splitting(false),
        splitrule(DEFAULT_SPLITRULE),
        alpha(DEFAULT_ALPHA),
        minprop(DEFAULT_MINPROP),
        num_random_splits(DEFAULT_NUM_RANDOM_SPLITS),
        max_depth(DEFAULT_MAXDEPTH),
        max_depth_root(DEFAULT_MAXDEPTHROOT),
        num_cand_trees(DEFAULT_NUMCANDTREES),
        depth(0),
        last_left_nodeID(0),
        last_left_nodeID_loop(0)
  {
  }
  
  // Unity Forests: Constructor for repr_tree_mode:
  /**
   * Same as the seven-argument constructor for an existing tree structure,
   * except that the data pointer is set to data_ptr instead of null.
   *
   * @param data_ptr Data set the tree should refer to; stored, not copied.
   */
  Tree::Tree(std::vector<std::vector<size_t>> &child_nodeIDs, std::vector<size_t> &split_varIDs,
             std::vector<double> &split_values, std::vector<size_t> &split_types, std::vector<std::vector<size_t>> &split_multvarIDs,
             std::vector<std::vector<std::vector<bool>>> &split_directs,
             std::vector<std::vector<std::vector<double>>> &split_multvalues, const Data* data_ptr)
      : Tree(child_nodeIDs, split_varIDs, split_values, split_types, split_multvarIDs, split_directs, split_multvalues)
  {
    // The delegated constructor leaves data null; attach the supplied data set.
    data = data_ptr;
  }

  void Tree::init(const Data *data, uint mtry, uint nsplits, uint npairs, double proptry, double prop_var_root, size_t dependent_varID, size_t num_samples, uint seed,
                  std::vector<size_t> *deterministic_varIDs, std::vector<size_t> *split_select_varIDs,
                  std::vector<double> *split_select_weights, ImportanceMode importance_mode, uint min_node_size, uint min_node_size_root,
                  bool sample_with_replacement, bool memory_saving_splitting, SplitRule splitrule, std::vector<double> *case_weights,
                  std::vector<size_t> *manual_inbag, bool keep_inbag, std::vector<double> *sample_fraction, double alpha,
                  double minprop, bool holdout, uint num_random_splits, uint max_depth, uint max_depth_root, uint num_cand_trees, std::vector<std::vector<size_t>> *promispairs, uint eim_mode, uint divfortype, std::vector<size_t> *metricind, std::vector<size_t> repr_vars)
  {

    this->data = data;
    this->mtry = mtry;
    this->dependent_varID = dependent_varID;
    this->num_samples = num_samples;
    this->memory_saving_splitting = memory_saving_splitting;
    this->nsplits = nsplits;
	this->npairs = npairs;
    this->proptry = proptry;
	this->prop_var_root = prop_var_root;
    this->divfortype = divfortype;
	this->metricind = metricind;
    this->repr_vars = repr_vars;

    // Create root node, assign bootstrap sample and oob samples
    child_nodeIDs.push_back(std::vector<size_t>());
    child_nodeIDs.push_back(std::vector<size_t>());
    if (divfortype == 1)
    {
      createEmptyNode();
    }
    if (divfortype == 2)
    {
      createEmptyNodeMultivariate();
    }
    if (divfortype == 3)
    {
      createEmptyNodeInternal();
    }
	if (divfortype == 4)
    {
	  createEmptyNodeFullTree();
	}
	
    // Initialize random number generator and set seed
    random_number_generator.seed(seed);

    this->deterministic_varIDs = deterministic_varIDs;
    this->split_select_varIDs = split_select_varIDs;
    this->split_select_weights = split_select_weights;
    this->importance_mode = importance_mode;
    this->min_node_size = min_node_size;
    this->min_node_size_root = min_node_size_root;
    this->sample_with_replacement = sample_with_replacement;
    this->splitrule = splitrule;
    this->case_weights = case_weights;
    this->manual_inbag = manual_inbag;
    this->keep_inbag = keep_inbag;
    this->sample_fraction = sample_fraction;
    this->holdout = holdout;
    this->alpha = alpha;
    this->minprop = minprop;
    this->num_random_splits = num_random_splits;
    this->eim_mode = eim_mode;
    this->max_depth = max_depth;
	this->max_depth_root = max_depth_root;
	this->num_cand_trees = num_cand_trees;
    this->promispairs = promispairs;
  }

  void Tree::grow(std::vector<double> *variable_importance)
  {

    // Print line number:
    // // Rcpp::Rcout << "Line number: " << __LINE__ << std::endl << std::flush;

    // Allocate memory for tree growing
    allocateMemory();

    this->variable_importance = variable_importance;

    // Print line number:
    // // Rcpp::Rcout << "Line number: " << __LINE__ << std::endl << std::flush;

    // Bootstrap, dependent if weighted or not and with or without replacement
    if (!case_weights->empty())
    {
      if (sample_with_replacement)
      {
        bootstrapWeighted();
      }
      else
      {
        bootstrapWithoutReplacementWeighted();
      }
    }
    else if (sample_fraction->size() > 1)
    {
      if (sample_with_replacement)
      {
        bootstrapClassWise();
      }
      else
      {
        bootstrapWithoutReplacementClassWise();
      }
    }
    else if (!manual_inbag->empty())
    {
      setManualInbag();
    }
    else
    {
      if (sample_with_replacement)
      {
        bootstrap();
      }
      else
      {
        bootstrapWithoutReplacement();
      }
    }

    /*
    // For each variable in the data set, print its values (use ", " as separators between the values from a variable), but only for the observations in the bootstrap sample:
    Rcpp::Rcout << "Bootstrap sample: " << std::endl;
    for (size_t i = 0; i < data->getNumCols(); ++i)
     {
      Rcpp::Rcout << "c(";
      for (size_t j = 0; j < sampleIDs.size(); ++j)
      {
      Rcpp::Rcout << data->get(sampleIDs[j], i);
      if (j < sampleIDs.size() - 1)
      {
        Rcpp::Rcout << ", ";
      }
      }
      Rcpp::Rcout << ")" << std::endl;
      if (i < data->getNumCols() - 1)
      {
        Rcpp::Rcout << ", ";
      }
     }
     Rcpp::Rcout << std::endl;

     // Print the values of sampleIDs:
      Rcpp::Rcout << "sampleIDs <- c(";
      for (size_t i = 0; i < sampleIDs.size() - 1; ++i)
      {
		Rcpp::Rcout << sampleIDs[i] << ", ";
      }
	  Rcpp::Rcout << sampleIDs[sampleIDs.size() - 1] << ")";
      Rcpp::Rcout << std::endl;

     // For each variable in the data set, print its values (use ", " as separators between the values from a variable), but only for the observations in the OOB sample:
    Rcpp::Rcout << "OOB sample: " << std::endl;
    for (size_t i = 0; i < data->getNumCols(); ++i)
     {
      Rcpp::Rcout << "c(";
      for (size_t j = 0; j < oob_sampleIDs.size(); ++j)
      {
      Rcpp::Rcout << data->get(oob_sampleIDs[j], i);
      if (j < oob_sampleIDs.size() - 1)
      {
        Rcpp::Rcout << ", ";
      }
      }
      Rcpp::Rcout << ")" << std::endl;
      if (i < data->getNumCols() - 1)
      {
        Rcpp::Rcout << ", ";
      }
     }
     Rcpp::Rcout << std::endl;

      // Print the values of oob_sampleIDs:
        Rcpp::Rcout << "oob_sampleIDs <- c(";
        for (size_t i = 0; i < oob_sampleIDs.size()-1; ++i)
        {
          Rcpp::Rcout << oob_sampleIDs[i] << ", ";
        }
		Rcpp::Rcout << oob_sampleIDs[oob_sampleIDs.size() - 1] << ")";
        Rcpp::Rcout << std::endl;
     */
    
    // Print line number:
    // // Rcpp::Rcout << "Line number: " << __LINE__ << std::endl << std::flush;

    // Randomly draw a proportion prop_var_root of the variables (default 0.7), which will be used for the tree root:
    std::vector<size_t> varIDs_root;
    // Determine the number of variables to be used for the tree root as prop_var_root times the number of all available variables:
    size_t num_vars_root = round((data->getNumCols()-1) * prop_var_root);
    if (num_vars_root == 0)
    {
      num_vars_root = 1;
    }
    
    // Draw num_vars_root variables without replacement from all available variables:
    // Get the vector with all available variables, while excluding the variables in data->getNoSplitVariables():
    const std::vector<size_t>& all_vars = *allowedVarIDs_;
    // Draw num_vars_root variables without replacement from all_vars:
    drawWithoutReplacementFromVector(varIDs_root, all_vars, random_number_generator, num_vars_root);
    
    /*
    // Print the values of varIDs_root:
    Rcpp::Rcout << "varIDs_root: ";
    for (size_t i = 0; i < varIDs_root.size(); ++i)
    {
      Rcpp::Rcout << varIDs_root[i] << " ";
    }
    Rcpp::Rcout << std::endl;
    */

    // Print line number:
    // // Rcpp::Rcout << "Line number: " << __LINE__ << std::endl << std::flush;

    // Generate 'num_cand_trees' trees (default: 1000) with random splits, where each tree is grown to a maximum depth of three:

    child_nodeIDs_loop.push_back(std::vector<size_t>());
    child_nodeIDs_loop.push_back(std::vector<size_t>());

    double best_decrease = -1;

    // Distribution to generate double between 0.0 and 1.0
    std::uniform_real_distribution<double> distr(0.0, 1.0);

    const size_t MAX_NODES = static_cast<size_t>(pow(2, max_depth_root + 1)) - 1;
    split_varIDs_loop.reserve(MAX_NODES);
    split_values_loop.reserve(MAX_NODES);
    child_nodeIDs_loop[0].reserve(MAX_NODES);
    child_nodeIDs_loop[1].reserve(MAX_NODES);
    start_pos_loop.reserve(MAX_NODES);
    end_pos_loop.reserve(MAX_NODES);
 
    for (size_t j = 0; j < num_cand_trees; ++j)
    {

    // Print line number:
    // // Rcpp::Rcout << "Line number: " << __LINE__ << std::endl << std::flush;

      // Clear vectors
      clearRandomTree();

    // Print line number:
    // // Rcpp::Rcout << "Line number: " << __LINE__ << std::endl << std::flush;

      // Make empty node:
      createEmptyNodeRandomTree();

    // Print line number:
    // // Rcpp::Rcout << "Line number: " << __LINE__ << std::endl << std::flush;

      // Init start and end positions
      start_pos_loop[0] = 0;
      end_pos_loop[0] = sampleIDs.size();
  
      // Print the index of the current tree:
      // Rcpp::Rcout << "Tree index: " << j << std::endl;

      // While not all nodes terminal, split next node
      size_t num_open_nodes = 1;
      last_left_nodeID_loop = 0;
      size_t i = 0;
      depth = 0;
      std::vector<size_t> terminal_nodes; // Vector to store indices of terminal nodes
      while (num_open_nodes > 0)
      {

        // Split node at random
        bool is_terminal_node = splitNodeRandom(i, varIDs_root);

        if (is_terminal_node)
        {
          terminal_nodes.push_back(i); // Add index of terminal node to vector
          --num_open_nodes;
        }
        else
        {
          ++num_open_nodes;
          if (i >= last_left_nodeID_loop)
          {
            // If new level, increase depth
            // (left_node saves left-most node in current level, new level reached if that node is splitted)
            last_left_nodeID_loop = split_varIDs_loop.size() - 2;
            ++depth;
          }
        }
        ++i;
      }

      // Evaluate the tree:
      double decrease = evaluateRandomTree(terminal_nodes);

      // Save the current tree if it is better than the best tree so far:
      if (decrease >= best_decrease) 
      {
        // If decrease is equal to best_decrease, the current tree is saved with a probability of 0.5:
        if (decrease == best_decrease)
        {
          if (distr(random_number_generator) < 0.5)
          {
            split_varIDs_best = split_varIDs_loop;
            split_values_best = split_values_loop;
            child_nodeIDs_best = child_nodeIDs_loop;
            best_decrease = decrease;
          }
        }
        else
        {
          split_varIDs_best = split_varIDs_loop;
          split_values_best = split_values_loop;
          child_nodeIDs_best = child_nodeIDs_loop;
          best_decrease = decrease;
        }
      }

    // Print line number:
    // // Rcpp::Rcout << "Line number: " << __LINE__ << std::endl << std::flush;

    }

    // Print line number:
    // // Rcpp::Rcout << "Line number: " << __LINE__ << std::endl << std::flush;

    // Extend the best tree to the maximum depth using conventional splitting:

    // Init start and end positions
    start_pos[0] = 0;
    end_pos[0] = sampleIDs.size();

    // The first root in the (full) tree is always from the tree root
    nodeID_in_root[0] = 0;

    // While not all nodes terminal, split next node
    size_t num_open_nodes = 1;
    size_t i = 0;
    depth = 0;
    while (num_open_nodes > 0)
    {
      // Split node
      bool is_terminal_node = splitNodeFullTree(i);
      if (is_terminal_node)
      {
        --num_open_nodes;
      }
      else
      {
        ++num_open_nodes;
        if (i >= last_left_nodeID)
        {
          // If new level, increase depth
          // (left_node saves left-most node in current level, new level reached if that node is splitted)
          if (divfortype == 1)
          {
            last_left_nodeID = split_varIDs.size() - 2;
          }
          if (divfortype == 2)
          {
            last_left_nodeID = split_multvarIDs.size() - 2;
          }
          if (divfortype == 4)
          {
            last_left_nodeID = split_varIDs.size() - 2;
          }
          ++depth;
        }
      }
      ++i;
    }

    /*
    // Print the values of split_varIDs:
    Rcpp::Rcout << "split_varIDs <- c(";
    for (size_t i = 0; i < split_varIDs.size() - 1; ++i)
    {
      Rcpp::Rcout << split_varIDs[i] << ", ";
    }
	Rcpp::Rcout << split_varIDs[split_varIDs.size() - 1] << ")";
    Rcpp::Rcout << std::endl;

    // Print the values of split_values:
    Rcpp::Rcout << "split_values <- c(";
    for (size_t i = 0; i < split_values.size() - 1; ++i)
    {
      Rcpp::Rcout << split_values[i] << ", ";
    }
    Rcpp::Rcout << split_values[split_values.size() - 1] << ")";
    Rcpp::Rcout << std::endl;

    // Print the values of child_nodeIDs[0]:
    Rcpp::Rcout << "child_nodeIDs_0 <- c(";
    for (size_t i = 0; i < child_nodeIDs[0].size() - 1; ++i)
    {
      Rcpp::Rcout << child_nodeIDs[0][i] << ", ";
    }
	Rcpp::Rcout << child_nodeIDs[0][child_nodeIDs[0].size()-1] << ")";
    Rcpp::Rcout << std::endl;

    // Print the values of child_nodeIDs[1]:
    Rcpp::Rcout << "child_nodeIDs_1 <- c(";
    for (size_t i = 0; i < child_nodeIDs[1].size() - 1; ++i)
    {
      Rcpp::Rcout << child_nodeIDs[1][i] << ", ";
    }
    Rcpp::Rcout << child_nodeIDs[1][child_nodeIDs[1].size()-1] << ")";
    Rcpp::Rcout << std::endl;

    // Print an empty line:
    Rcpp::Rcout << std::endl;

   */

   /*
    // Print the values of the response variable for each terminal node in the tree:
    for (size_t i = 0; i < child_nodeIDs[0].size(); ++i)
    {
      if (child_nodeIDs[0][i] == 0 && child_nodeIDs[1][i] == 0)
      {
        Rcpp::Rcout << "c(";
        for (size_t j = start_pos[i]; j < end_pos[i]; ++j)
        {
          Rcpp::Rcout << data->get(sampleIDs[j], dependent_varID);
          if (j < end_pos[i] - 1)
          {
            Rcpp::Rcout << ", ";
          }
        }
        Rcpp::Rcout << ")," << std::endl;
      }
    }
    

        // Print line number:
    // // Rcpp::Rcout << "Line number: " << __LINE__ << std::endl << std::flush;
	
	*/

    // Delete sampleID vector to save memory
    ///sampleIDs.clear();
    ///sampleIDs.shrink_to_fit();
    cleanUpInternal();

    // Print line number:
    // // Rcpp::Rcout << "Line number: " << __LINE__ << std::endl << std::flush;

  }

  /**
   * Drop samples down the tree and record the terminal node each one reaches.
   *
   * The result is written to prediction_terminal_nodeIDs, one entry per
   * predicted sample.
   *
   * @param prediction_data Data set the split variables are read from.
   * @param oob_prediction  If true, only the num_samples_oob samples listed
   *                        in oob_sampleIDs are predicted; otherwise all rows
   *                        of prediction_data are.
   */
  void Tree::predict(const Data *prediction_data, bool oob_prediction)
  {
    const size_t num_samples_predict = oob_prediction ? num_samples_oob : prediction_data->getNumRows();

    prediction_terminal_nodeIDs.resize(num_samples_predict, 0);

    // For each sample start in root, drop down the tree and return final value
    for (size_t i = 0; i < num_samples_predict; ++i)
    {
      const size_t sample_idx = oob_prediction ? oob_sampleIDs[i] : i;

      size_t nodeID = 0;
      // A node is terminal when both of its child IDs are 0.
      while (child_nodeIDs[0][nodeID] != 0 || child_nodeIDs[1][nodeID] != 0)
      {
        const size_t split_varID = split_varIDs[nodeID];
        const double value = prediction_data->get(sample_idx, split_varID);

        bool go_left;
        if (prediction_data->isOrderedVariable(split_varID))
        {
          // Ordered variable: threshold split.
          go_left = (value <= split_values[nodeID]);
        }
        else
        {
          // Unordered variable: the split value is read as a bit mask over
          // the factor levels; level k (1-based) maps to bit k - 1.
          const size_t factorID = floor(value) - 1;
          const size_t splitID = floor(split_values[nodeID]);

          // Left if 0 found at position factorID. The shifted constant must
          // be as wide as splitID: a plain int `1` overflows from bit 31 on.
          go_left = !(splitID & (static_cast<size_t>(1) << factorID));
        }

        nodeID = go_left ? child_nodeIDs[0][nodeID] : child_nodeIDs[1][nodeID];
      }

      prediction_terminal_nodeIDs[i] = nodeID;
    }
  }
  
  /**
   * Base-class version of predictMuw: intentionally does nothing.
   * Per the original note, this virtual function is only implemented for
   * classification and probability trees.
   */
  void Tree::predictMuw(const Data *prediction_data, bool oob_prediction)
  {
    // Both arguments are deliberately unused here.
    (void)prediction_data;
    (void)oob_prediction;
  }

  /**
   * Add this tree's permutation importance to the forest-level accumulators.
   *
   * For every independent variable that occurs in split_varIDs, the OOB
   * samples are dropped down the tree again via
   * randomizedDropDownOobSamples() and the loss in prediction accuracy
   * relative to the unpermuted accuracy is added to forest_importance.
   * Variables the tree never splits on contribute a difference of 0.
   *
   * @param forest_importance Accumulator indexed by position among the
   *                          independent variables (no-split variables are
   *                          skipped when mapping position to variable ID).
   * @param forest_variance   Accumulator for squared differences; updated
   *                          only for IMP_PERM_BREIMAN and IMP_PERM_LIAW
   *                          (the latter scaled by num_samples_oob).
   */
  void Tree::computePermutationImportance(std::vector<double> &forest_importance, std::vector<double> &forest_variance)
  {
    const size_t num_independent_variables = data->getNumCols() - data->getNoSplitVariables().size();

    // Compute normal prediction accuracy for each tree. Predictions already computed..
    const double accuracy_normal = computePredictionAccuracyInternal();

    prediction_terminal_nodeIDs.clear();
    prediction_terminal_nodeIDs.resize(num_samples_oob, 0);

    // Randomly permute for all independent variables
    for (size_t i = 0; i < num_independent_variables; ++i)
    {
      // Map position i to a variable ID by skipping the no-split variables
      size_t varID = i;
      for (auto &skip : data->getNoSplitVariables())
      {
        if (varID >= skip)
        {
          ++varID;
        }
      }

      // Only variables the tree actually splits on can change the prediction;
      // for all others the difference stays 0 and nothing is re-evaluated.
      double accuracy_difference = 0;
      const bool is_split_variable =
          std::find(split_varIDs.begin(), split_varIDs.end(), varID) != split_varIDs.end();
      if (is_split_variable)
      {
        // Permute and compute prediction accuracy again for this permutation and save difference
        randomizedDropDownOobSamples(varID);
        const double accuracy_randomized = computePredictionAccuracyInternal();
        accuracy_difference = accuracy_normal - accuracy_randomized;
        forest_importance[i] += accuracy_difference;
      }

      // Compute variance
      if (importance_mode == IMP_PERM_BREIMAN)
      {
        forest_variance[i] += accuracy_difference * accuracy_difference;
      }
      else if (importance_mode == IMP_PERM_LIAW)
      {
        forest_variance[i] += accuracy_difference * accuracy_difference * num_samples_oob;
      }
    }
  }
 

  /**
   * Add this tree's unity-forest importance to the forest-level accumulator.
   *
   * Nodes flagged with is_in_best == 1 are collected, every OOB sample is
   * dropped down the tree, and for each flagged node that at least one OOB
   * sample passes through, computeUFNodeImportance() is added to the entry
   * of the node's split variable. Does nothing if no node is flagged.
   *
   * @param forest_importance Accumulator, indexed here by the raw split
   *                          variable ID (split_varIDs[nodeID]).
   *                          NOTE(review): computePermutationImportance()
   *                          indexes its accumulator by position among the
   *                          independent variables instead - confirm the
   *                          caller sizes this vector accordingly.
   */
  void Tree::computeUFImportance(std::vector<double> &forest_importance)
  {
    // Determine the node IDs in the tree for which is_in_best is 1:
    std::vector<size_t> best_nodeIDs;
    for (size_t i = 0; i < is_in_best.size(); ++i)
    {
      if (is_in_best[i] == 1)
      {
        best_nodeIDs.push_back(i);
      }
    }
    if (best_nodeIDs.empty())
    {
      return;
    }

    // Lookup table node ID -> position in best_nodeIDs, so that the drop-down
    // below needs one array access per visited node instead of a linear
    // search through best_nodeIDs.
    const size_t NOT_TRACKED = static_cast<size_t>(-1);
    std::vector<size_t> slot_of_node(is_in_best.size(), NOT_TRACKED);
    for (size_t slot = 0; slot < best_nodeIDs.size(); ++slot)
    {
      slot_of_node[best_nodeIDs[slot]] = slot;
    }

    // Drop the OOB observations down the tree and for each flagged node
    // collect the OOB observations that pass through it:
    std::vector<std::vector<size_t>> oob_sampleIDs_nodeID(best_nodeIDs.size());
    for (size_t sampleID : oob_sampleIDs)
    {
      size_t nodeID = 0;
      while (true)
      {
        // Node IDs beyond the range of is_in_best cannot be flagged.
        if (nodeID < slot_of_node.size() && slot_of_node[nodeID] != NOT_TRACKED)
        {
          oob_sampleIDs_nodeID[slot_of_node[nodeID]].push_back(sampleID);
        }
        if (child_nodeIDs[0][nodeID] == 0 && child_nodeIDs[1][nodeID] == 0)
        {
          break;
        }
        const size_t split_varID = split_varIDs[nodeID];
        const double value = data->get(sampleID, split_varID);
        if (value <= split_values[nodeID])
        {
          nodeID = child_nodeIDs[0][nodeID];
        }
        else
        {
          nodeID = child_nodeIDs[1][nodeID];
        }
      }
    }

    // Compute the importance for every flagged node reached by at least one
    // OOB observation, in ascending node order (the order matters because
    // computeUFNodeImportance() consumes the random number generator).
    for (size_t slot = 0; slot < best_nodeIDs.size(); ++slot)
    {
      if (oob_sampleIDs_nodeID[slot].empty())
      {
        continue;
      }
      const size_t nodeID = best_nodeIDs[slot];
      // The per-node sample list is not needed afterwards, so hand it over.
      forest_importance[split_varIDs[nodeID]] += computeUFNodeImportance(nodeID, std::move(oob_sampleIDs_nodeID[slot]));
    }
  }

  // Unity Forests:
  /**
   * Importance contribution of a single node.
   *
   * Computes the OOB split criterion of the node once with the given OOB
   * observations and once with a random permutation of them, and returns the
   * difference multiplied by end_pos[nodeID] - start_pos[nodeID] (described
   * in the original as the number of in-bag observations passing the node).
   *
   * @param nodeID               Node to evaluate.
   * @param oob_sampleIDs_nodeID OOB sample IDs that pass through the node.
   * @return Weighted difference (unpermuted minus permuted criterion).
   */
  double Tree::computeUFNodeImportance(size_t nodeID, std::vector<size_t> oob_sampleIDs_nodeID)
  {
    // Criterion with the OOB observations as they are.
    const double criterion_observed = computeOOBSplitCriterionValue(nodeID, oob_sampleIDs_nodeID);

    // Shuffled copy of the same sample IDs.
    std::vector<size_t> shuffled_sampleIDs = oob_sampleIDs_nodeID;
    std::shuffle(shuffled_sampleIDs.begin(), shuffled_sampleIDs.end(), random_number_generator);

    // Criterion after permuting the OOB observations.
    const double criterion_permuted =
        computeOOBSplitCriterionValuePermuted(nodeID, oob_sampleIDs_nodeID, shuffled_sampleIDs);

    // Weight the loss in criterion by the node's sample range.
    const auto node_weight = end_pos[nodeID] - start_pos[nodeID];
    return (criterion_observed - criterion_permuted) * node_weight;
  }
 
// Unity Forests:
/**
 * Compute the in-bag split criterion for the nodes of the tree root.
 *
 * split_criterion is resized to one entry per node (new entries start at
 * -1; entries that already exist keep their value). All in-bag samples are
 * dropped down the tree, and for every node that is the first node or has a
 * non-zero nodeID_in_root entry, and whose left child exists and also has a
 * non-zero nodeID_in_root entry, the entry is set to
 * computeSplitCriterion(left samples, right samples) times the number of
 * in-bag samples passing through the node.
 */
void Tree::computeSplitCriterionValues()
{
    split_criterion.resize(split_varIDs.size(), -1);

    // Drop all in-bag observations down the tree and record, for every node,
    // the samples that pass through it.
    std::vector<std::vector<size_t>> inbag_sampleIDs_nodeID(split_varIDs.size());
    for (size_t sampleID : sampleIDs)
    {
        size_t nodeID = 0;
        while (true)
        {
            inbag_sampleIDs_nodeID[nodeID].push_back(sampleID);
            if (child_nodeIDs[0][nodeID] == 0 && child_nodeIDs[1][nodeID] == 0)
            {
                break;
            }
            const size_t split_varID = split_varIDs[nodeID];
            const double value = data->get(sampleID, split_varID);
            if (value <= split_values[nodeID])
            {
                nodeID = child_nodeIDs[0][nodeID];
            }
            else
            {
                nodeID = child_nodeIDs[1][nodeID];
            }
        }
    }

    // Calculate the split criterion for each qualifying node.
    for (size_t i = 0; i < split_varIDs.size(); ++i)
    {
        // The node must be the first node or be in the tree root, and must
        // have a left child that is in the tree root as well.
        const bool qualifies = (i == 0 || nodeID_in_root[i] != 0) &&
                               (child_nodeIDs[0][i] != 0 && nodeID_in_root[child_nodeIDs[0][i]] != 0);
        if (!qualifies)
        {
            continue;
        }

        // Refer to the children's sample lists directly instead of copying
        // both vectors for every node.
        std::vector<size_t> &inbag_sampleIDs_left_child = inbag_sampleIDs_nodeID[child_nodeIDs[0][i]];
        std::vector<size_t> &inbag_sampleIDs_right_child = inbag_sampleIDs_nodeID[child_nodeIDs[1][i]];

        // Criterion of the split, weighted by the number of in-bag
        // observations that pass through the node.
        split_criterion[i] = computeSplitCriterion(inbag_sampleIDs_left_child, inbag_sampleIDs_right_child) * static_cast<double>(inbag_sampleIDs_nodeID[i].size());
    }
}

// Unity Forests:
void Tree::computeOOBSplitCriterionValues()
{

// Rcpp::Rcout << "Line number: " << __LINE__ << std::endl << std::flush;
//std::this_thread::sleep_for(std::chrono::milliseconds(10));

    // Initialize the split criterion for each node in the tree that is in the tree root and is not a terminal node:
    split_criterion.resize(split_varIDs.size(), -1);

// Rcpp::Rcout << "Line number: " << __LINE__ << std::endl << std::flush;
//std::this_thread::sleep_for(std::chrono::milliseconds(10));

    // Make a new vector to store the out of bag sample IDs.
    std::vector<size_t> oob_sampleIDs;
    // The out of bag sample IDs are the indices of inbag_counts that are equal to 0.
    for (size_t i = 0; i < inbag_counts.size(); ++i)
    {
        if (inbag_counts[i] == 0)
        {
            oob_sampleIDs.push_back(i);
        }
    }

    // Drop all out-of-bag observations down the tree and store the node IDs
    std::vector<std::vector<size_t>> oob_sampleIDs_nodeID(split_varIDs.size());
    for (size_t sampleID : oob_sampleIDs)
    {
        size_t nodeID = 0;
        while (true)
        {
            oob_sampleIDs_nodeID[nodeID].push_back(sampleID);
            if (child_nodeIDs[0][nodeID] == 0 && child_nodeIDs[1][nodeID] == 0)
            {
                break;
            }
            size_t split_varID = split_varIDs[nodeID];
            double value = data->get(sampleID, split_varID);
            if (value <= split_values[nodeID])
            {
                nodeID = child_nodeIDs[0][nodeID];
            }
            else
            {
                nodeID = child_nodeIDs[1][nodeID];
            }
        }
    }

    // If 

    // Calculate the split criterion for each node in the tree that is in the tree root and is not a terminal node:
    std::vector<size_t> oob_sampleIDs_left_child;
    std::vector<size_t> oob_sampleIDs_right_child;
    for (size_t i = 0; i < split_varIDs.size(); ++i)
    {
        // If the node is the first node or is in the tree root and has children that are in the tree root and the split variable is in repr_vars, calculate the split criterion:
        if ((i == 0 || nodeID_in_root[i] != 0) && (child_nodeIDs[0][i] != 0 && nodeID_in_root[child_nodeIDs[0][i]] != 0))// && inbag_sampleIDs_nodeID[i].size() >= 20)
        {

          if (std::find(repr_vars.begin(), repr_vars.end(), split_varIDs[i]) != repr_vars.end()) {
            // Print the values of split_varIDs[i]:
            //Rcpp::Rcout << "split_varIDs[i]: " << split_varIDs[i] << std::endl;
            //std::this_thread::sleep_for(std::chrono::milliseconds(10));

          // Print the values of i:
          //Rcpp::Rcout << "i: " << i << std::endl;
            //std::this_thread::sleep_for(std::chrono::milliseconds(10));

            // Determine the out-of-bag sample IDs of the two child nodes of the current node (hint: the in-bag sample IDs of the left child node are stored in oob_sampleIDs_nodeID[child_nodeIDs[0][i]]):
            oob_sampleIDs_left_child = oob_sampleIDs_nodeID[child_nodeIDs[0][i]];
            oob_sampleIDs_right_child = oob_sampleIDs_nodeID[child_nodeIDs[1][i]];

            // Print the value of child_nodeIDs[0][i]:
            //Rcpp::Rcout << "child_nodeIDs[0][i]: " << child_nodeIDs[0][i] << std::endl;
//std::this_thread::sleep_for(std::chrono::milliseconds(10));

            // Print the value of child_nodeIDs[1][i]:
            //Rcpp::Rcout << "child_nodeIDs[1][i]: " << child_nodeIDs[1][i] << std::endl;
//std::this_thread::sleep_for(std::chrono::milliseconds(10));

            // Print the values of inbag_sampleIDs_left_child:
            //Rcpp::Rcout << "inbag_sampleIDs_left_child: ";
            //for (size_t j = 0; j < inbag_sampleIDs_left_child.size(); ++j)
            //{
            //    Rcpp::Rcout << inbag_sampleIDs_left_child[j] << " ";
            //}
            //Rcpp::Rcout << std::endl;

            // Print the values of inbag_sampleIDs_right_child:
            //Rcpp::Rcout << "inbag_sampleIDs_right_child: ";
            //for (size_t j = 0; j < inbag_sampleIDs_right_child.size(); ++j)
            //{
            //    Rcpp::Rcout << inbag_sampleIDs_right_child[j] << " ";
            //}
            //Rcpp::Rcout << std::endl;
			//std::this_thread::sleep_for(std::chrono::milliseconds(10));

            // If neither child node is empty, calculate the split criterion:
            if (!oob_sampleIDs_left_child.empty() && !oob_sampleIDs_right_child.empty())
            {
              // Calculate the split criterion for the node using in-bag observations and multiply it by the number of in-bag observations that pass through the
            // Calculate the split criterion for the node using in-bag observations and multiply it by the number of in-bag observations that pass through the node:
            split_criterion[i] = computeSplitCriterion(oob_sampleIDs_left_child, oob_sampleIDs_right_child) * static_cast<double>(oob_sampleIDs_nodeID[i].size());
            }

				// Rcpp::Rcout << "Line number: " << __LINE__ << std::endl << std::flush;
//std::this_thread::sleep_for(std::chrono::milliseconds(10));
          }			

        }
    }
	
	// Rcpp::Rcout << "Line number: " << __LINE__ << std::endl << std::flush;
	
	// Rcpp::Rcout << "Line number: " << __LINE__ << std::endl << std::flush;
	// Rcpp::Rcout << "Line number: " << __LINE__ << std::endl << std::flush;
	//std::this_thread::sleep_for(std::chrono::milliseconds(10));
	
}

// Base-class fallback for the split criterion of a two-child partition.
// It never returns a value: it always throws std::runtime_error. The error text
// suggests that concrete tree types provide their own implementation
// (presumably via a virtual override; confirm in Tree.h).
double Tree::computeSplitCriterion(std::vector<size_t> sampleIDs_left_child, std::vector<size_t> sampleIDs_right_child)
{
    // The arguments are intentionally ignored by the fallback.
    static_cast<void>(sampleIDs_left_child);
    static_cast<void>(sampleIDs_right_child);

    throw std::runtime_error("computeSplitCriterion not implemented for this subclass.");
}

// Base-class fallback for the OOB split criterion of a single node.
// It never returns a value: it always throws std::runtime_error. The error text
// suggests that concrete tree types provide their own implementation
// (presumably via a virtual override; confirm in Tree.h).
double Tree::computeOOBSplitCriterionValue(size_t nodeID, std::vector<size_t> oob_sampleIDs_nodeID)
{
    // The arguments are intentionally ignored by the fallback.
    static_cast<void>(nodeID);
    static_cast<void>(oob_sampleIDs_nodeID);

    throw std::runtime_error("computeOOBSplitCriterionValue not implemented for this subclass.");
}

  // Base-class fallback for the OOB split criterion of a single node under a
  // permutation. It never returns a value: it always throws std::runtime_error.
  // The error text suggests that concrete tree types provide their own
  // implementation (presumably via a virtual override; confirm in Tree.h).
  double Tree::computeOOBSplitCriterionValuePermuted(size_t nodeID, std::vector<size_t> oob_sampleIDs_nodeID, std::vector<size_t> permutations)
  {
    // The arguments are intentionally ignored by the fallback.
    static_cast<void>(nodeID);
    static_cast<void>(oob_sampleIDs_nodeID);
    static_cast<void>(permutations);

    throw std::runtime_error("computeOOBSplitCriterionValuePermuted not implemented for this subclass.");
  }

// Unity Forests:
void Tree::collectSplits(size_t tree_idx, std::vector<std::vector<SplitData>> &all_splits_per_variable)
{

      // Rcpp::Rcout << "Line number: " << __LINE__ << std::endl << std::flush;

  // Print the length of split_varIDs:
  // Rcpp::Rcout << "split_varIDs.size(): " << split_varIDs.size() << std::endl;
  
// Print the length of split_criterion:
  // Rcpp::Rcout << "split_criterion.size(): " << split_criterion.size() << std::endl;

  // Loop over all nodes in the tree and collect the split data:
  for (size_t i = 0; i < split_varIDs.size(); ++i)
  {

    // Print the value of i:
    // Rcpp::Rcout << "i: " << i << std::endl;

    // Rcpp::Rcout << "Line number: " << __LINE__ << std::endl << std::flush;

    // Determine tree index:
    size_t tree_index = tree_idx;
    // Determine node ID:
    size_t nodeID = i;
    // Determine the variable ID of the split:
    size_t varID = split_varIDs[i];
    // Determine the split value:
    double split_value = split_criterion[i];

    // Print the values of tree_index, nodeID, varID, and split_value:
    // Rcpp::Rcout << "tree_index: " << tree_index << std::endl;
    // Rcpp::Rcout << "nodeID: " << nodeID << std::endl;
    // Rcpp::Rcout << "varID: " << varID << std::endl;
    // Rcpp::Rcout << "split_value: " << split_value << std::endl;

    // Print the length of all_splits_per_variable:
    // Rcpp::Rcout << "all_splits_per_variable.size(): " << all_splits_per_variable.size() << std::endl;

    // Rcpp::Rcout << "Line number: " << __LINE__ << std::endl << std::flush;

    // Create a SplitData object and add it to all_splits_per_variable:
    SplitData split_data(tree_index, nodeID, split_value);

    // Rcpp::Rcout << "Line number: " << __LINE__ << std::endl << std::flush;

    // Add the SplitData object to all_splits_per_variable:
    all_splits_per_variable[varID].push_back(split_data);

    // Rcpp::Rcout << "Line number: " << __LINE__ << std::endl << std::flush;

  }
 
}

// Unity Forests:
void Tree::collectOOBSplits(size_t tree_idx, std::vector<std::vector<SplitData>> &all_splits_per_variable)
{

      // Rcpp::Rcout << "Line number: " << __LINE__ << std::endl << std::flush;

  // Print the length of split_varIDs:
  // Rcpp::Rcout << "split_varIDs.size(): " << split_varIDs.size() << std::endl;
  
// Print the length of split_criterion:
  // Rcpp::Rcout << "split_criterion.size(): " << split_criterion.size() << std::endl;

  // Loop over all nodes in the tree and collect the split data:
  for (size_t i = 0; i < split_varIDs.size(); ++i)
  {

    // If split_varIDs[i] is in repr_vars, collect the split data:
    if (std::find(repr_vars.begin(), repr_vars.end(), split_varIDs[i]) != repr_vars.end())
    {

    // Print the value of i:
    // Rcpp::Rcout << "i: " << i << std::endl;

    // Rcpp::Rcout << "Line number: " << __LINE__ << std::endl << std::flush;

    // Determine tree index:
    size_t tree_index = tree_idx;
    // Determine node ID:
    size_t nodeID = i;
    // Determine the index of split_varIDs[i] in repr_vars:
    size_t varID_index = std::find(repr_vars.begin(), repr_vars.end(), split_varIDs[i]) - repr_vars.begin();
    // Determine the split value:
    double split_value = split_criterion[i];

    // Print the values of tree_index, nodeID, varID, and split_value:
    // Rcpp::Rcout << "tree_index: " << tree_index << std::endl;
    // Rcpp::Rcout << "nodeID: " << nodeID << std::endl;
    // Rcpp::Rcout << "varID: " << varID << std::endl;
    // Rcpp::Rcout << "split_value: " << split_value << std::endl;

    // Print the length of all_splits_per_variable:
    // Rcpp::Rcout << "all_splits_per_variable.size(): " << all_splits_per_variable.size() << std::endl;

    // Rcpp::Rcout << "Line number: " << __LINE__ << std::endl << std::flush;

    // Create a SplitData object and add it to all_splits_per_variable:
    SplitData split_data(tree_index, nodeID, split_value);

    // Rcpp::Rcout << "Line number: " << __LINE__ << std::endl << std::flush;

    // Add the SplitData object to all_splits_per_variable:
    all_splits_per_variable[varID_index].push_back(split_data);

    }
    // Rcpp::Rcout << "Line number: " << __LINE__ << std::endl << std::flush;

  }
 
}

// Unity Forests:
// Registers this tree in bestTreesPerVariable for every variable of repr_vars
// that has at least one node flagged in is_in_best.
//
// tree_idx              index of this tree; this is the value that gets appended
// bestTreesPerVariable  output; indexed by the position of the variable within
//                       repr_vars. tree_idx is appended at most once per
//                       variable, even if several flagged nodes split in it.
void Tree::determineBestTreesPerVariable(size_t tree_idx, std::vector<std::vector<size_t>>& bestTreesPerVariable)
{

  // Nothing to do if no node of this tree is flagged (this also covers an
  // empty is_in_best vector).
  if (std::find(is_in_best.begin(), is_in_best.end(), 1) == is_in_best.end())
  {
    return;
  }

  // Collect the repr_vars positions of all flagged nodes.
  std::vector<size_t> flagged_positions;
  for (size_t nodeID = 0; nodeID < split_varIDs.size(); ++nodeID)
  {
    if (is_in_best[nodeID] != 1)
    {
      continue;
    }

    const auto hit = std::find(repr_vars.begin(), repr_vars.end(), split_varIDs[nodeID]);
    if (hit != repr_vars.end())
    {
      flagged_positions.push_back(static_cast<size_t>(hit - repr_vars.begin()));
    }
  }

  // Sort and deduplicate so that each variable is reported only once.
  std::sort(flagged_positions.begin(), flagged_positions.end());
  flagged_positions.erase(std::unique(flagged_positions.begin(), flagged_positions.end()), flagged_positions.end());

  // Register the tree for each affected variable (in ascending position order).
  for (const size_t position : flagged_positions)
  {
    bestTreesPerVariable[position].push_back(tree_idx);
  }

}


// Unity Forests:
// Counts how often each variable is used for splitting inside the tree root.
//
// var_counts  in/out; indexed by variable ID. The entry of split_varIDs[node] is
//             incremented for every node that is node 0 or has a non-zero
//             nodeID_in_root entry AND whose left child exists and also has a
//             non-zero nodeID_in_root entry.
void Tree::countVariables(std::vector<size_t>& var_counts)
{

  const size_t num_nodes = split_varIDs.size();

  for (size_t nodeID = 0; nodeID < num_nodes; ++nodeID)
  {
    // Skip nodes that are not part of the tree root (node 0 always is).
    if (nodeID != 0 && nodeID_in_root[nodeID] == 0)
    {
      continue;
    }

    // Skip nodes without a left child inside the tree root.
    const size_t left_child = child_nodeIDs[0][nodeID];
    if (left_child == 0 || nodeID_in_root[left_child] == 0)
    {
      continue;
    }

    ++var_counts[split_varIDs[nodeID]];
  }

}

// Unity Forests:
void Tree::computeUv(size_t tree_ind, std::vector<std::vector<double>>& Uv) {

  size_t depth_temp = 1; 
  std::vector<size_t> curr_child_nodeIDs;

  for (size_t i = 0; i < split_varIDs.size(); ++i)
  {
    if ((i == 0 || nodeID_in_root[i] != 0) && (child_nodeIDs[0][i] != 0 && nodeID_in_root[child_nodeIDs[0][i]] != 0))
    {
       // If nodeID_in_root[i] is in curr_child_nodeIDs, increment depth_temp:
       if (std::find(curr_child_nodeIDs.begin(), curr_child_nodeIDs.end(), nodeID_in_root[i]) != curr_child_nodeIDs.end())
       {
         depth_temp++;
         // Clear the curr_child_nodeIDs vector:
         curr_child_nodeIDs.clear();
       }       

       // Add 1/(2^(depth_temp-1)) to the Uv for the variable:	
       Uv[tree_ind][split_varIDs[i]] += 1.0 / std::pow(2.0, depth_temp - 1);

        // Add the child node IDs to the curr_child_nodeIDs vector:
        curr_child_nodeIDs.push_back(nodeID_in_root[child_nodeIDs[0][i]]);
        curr_child_nodeIDs.push_back(nodeID_in_root[child_nodeIDs[1][i]]);
    }
  }

  // Divide all elements of Uv[tree_ind] by depth_temp:
  for (size_t i = 0; i < Uv[tree_ind].size(); ++i)
  {
    Uv[tree_ind][i] /= depth_temp;
  }

}

// Unity Forests:
// Fills score_values (one entry per node) from a per-variable score vector.
//
// scores_tree  scores indexed by variable ID; it must cover every split
//              variable ID used by a tree-root node of this tree
//
// score_values is resized to the number of nodes, with newly created entries set
// to -99. For every node that is node 0 or has a non-zero nodeID_in_root entry,
// and whose left child exists with a non-zero nodeID_in_root entry, the entry is
// set to the score of the node's split variable. All other entries are left
// untouched.
void Tree::setScoreVector(std::vector<double> scores_tree)
{

  const size_t num_nodes = split_varIDs.size();

  // -99 marks nodes that never receive a score.
  score_values.resize(num_nodes, -99.0);

  for (size_t nodeID = 0; nodeID < num_nodes; ++nodeID)
  {
    // Skip nodes that are not part of the tree root (node 0 always is).
    if (nodeID != 0 && nodeID_in_root[nodeID] == 0)
    {
      continue;
    }

    // Skip nodes without a left child inside the tree root.
    const size_t left_child = child_nodeIDs[0][nodeID];
    if (left_child == 0 || nodeID_in_root[left_child] == 0)
    {
      continue;
    }

    // The score of a node is the score of the variable it splits in.
    score_values[nodeID] = scores_tree[split_varIDs[nodeID]];
  }

}

// Unity Forests:
// Sets the per-node flag is_in_best for every node of this tree.
//
// tree_idex   index of this tree as used in the (tree, node) pairs of bestSplits
// bestSplits  indexed by variable ID; each entry is the set of (tree index,
//             node ID) pairs selected for that variable
//
// is_in_best[node] becomes 1 if (tree_idex, node) is contained in the set of the
// node's split variable, and 0 otherwise.
void Tree::markBestSplits(size_t tree_idex, const std::vector<std::set<std::pair<size_t, size_t>>>& bestSplits)
{

  const size_t num_nodes = split_varIDs.size();

  // One flag per node; every flag is written explicitly below.
  is_in_best.resize(num_nodes);

  for (size_t nodeID = 0; nodeID < num_nodes; ++nodeID)
  {
    const std::set<std::pair<size_t, size_t>> &best_of_variable = bestSplits[split_varIDs[nodeID]];
    const bool selected = best_of_variable.find(std::make_pair(tree_idex, nodeID)) != best_of_variable.end();

    is_in_best[nodeID] = selected ? 1 : 0;
  }

}

// Unity Forests:
// Sets the per-node flag is_in_best based on the OOB best-split selection.
//
// tree_idex   index of this tree as used in the (tree, node) pairs of bestSplits
// bestSplits  indexed by the POSITION of a variable within repr_vars; each entry
//             is the set of (tree index, node ID) pairs selected for that variable
//
// After the call, is_in_best[node] is 1 if the node's split variable is part of
// repr_vars and (tree_idex, node) is contained in the corresponding set, and 0
// for every other node.
void Tree::markBestOOBSplits(size_t tree_idex, const std::vector<std::set<std::pair<size_t, size_t>>>& bestSplits)
{

  const size_t num_nodes = split_varIDs.size();

  // Reset ALL flags to 0. assign() is required here: resize(n, 0) only
  // initialises newly appended elements, so flags set by an earlier call of
  // markBestSplits()/markBestOOBSplits() on this tree would otherwise survive,
  // because the loop below only ever writes 1.
  is_in_best.assign(num_nodes, 0);

  for (size_t nodeID = 0; nodeID < num_nodes; ++nodeID)
  {
    // Look the split variable up once; the hit yields both the membership
    // test and the position inside repr_vars.
    const auto hit = std::find(repr_vars.begin(), repr_vars.end(), split_varIDs[nodeID]);
    if (hit == repr_vars.end())
    {
      continue;
    }
    const size_t varID_index = static_cast<size_t>(hit - repr_vars.begin());

    if (bestSplits[varID_index].count(std::make_pair(tree_idex, nodeID)) > 0)
    {
      // Mark the split as one of the best splits.
      is_in_best[nodeID] = 1;
    }
  }

}

  // Unity Forests: Compute the contribution to the variable importance for one tree
  //
  // forest_importance  in/out; indexed by independent-variable position, i.e. the
  //                    column index after removing the no-split variables. For
  //                    every tree-root node that received OOB observations, the
  //                    value calculateSplitCriterionDiffOOB(node, OOB IDs in node),
  //                    multiplied by the in-bag size of the node
  //                    (end_pos - start_pos), is added to the entry of the
  //                    node's split variable.
  // forest_variance    accepted for interface compatibility; not read or written.
  //
  // "Tree-root nodes" are node 0 plus every node with a non-zero nodeID_in_root
  // entry whose left child exists and also has a non-zero nodeID_in_root entry.
  void Tree::computeImportanceUF(std::vector<double> &forest_importance, std::vector<double> &forest_variance)
  {

    static_cast<void>(forest_variance);

    const size_t num_independent_variables = data->getNumCols() - data->getNoSplitVariables().size();

    // Determine the node IDs in the tree root that have children in the tree root.
    // Node 0 is always included.
    std::vector<size_t> root_nodeIDs;
    root_nodeIDs.push_back(0);
    for (size_t i = 1; i < split_varIDs.size(); ++i)
    {
      if (nodeID_in_root[i] != 0 && (child_nodeIDs[0][i] != 0 && nodeID_in_root[child_nodeIDs[0][i]] != 0))
      {
        root_nodeIDs.push_back(i);
      }
    }

    // OOB sample IDs per tree-root node (same order as root_nodeIDs).
    // Every node splits in exactly one variable, so one list per node is enough.
    // Storing the lists per node also avoids indexing a container of size
    // num_independent_variables with raw variable IDs, which can be as large as
    // getNumCols() - 1 and would then be out of range.
    std::vector<std::vector<size_t>> oob_sampleIDs_root_node(root_nodeIDs.size());

    // Drop the out-of-bag observations down the tree root.
    for (size_t j = 0; j < num_samples_oob; ++j)
    {
      const size_t sample_idx = oob_sampleIDs[j];
      size_t current_nodeID = 0;

      while (true)
      {
        // Stop as soon as the observation leaves the tree root.
        const auto root_pos = std::find(root_nodeIDs.begin(), root_nodeIDs.end(), current_nodeID);
        if (root_pos == root_nodeIDs.end())
        {
          break;
        }

        // Node 0 is listed unconditionally. If it has no children there is no
        // split to evaluate, and following child ID 0 would revisit node 0 forever.
        if (child_nodeIDs[0][current_nodeID] == 0 && child_nodeIDs[1][current_nodeID] == 0)
        {
          break;
        }

        const size_t nodeID_idx = static_cast<size_t>(std::distance(root_nodeIDs.begin(), root_pos));
        oob_sampleIDs_root_node[nodeID_idx].push_back(sample_idx);

        // Move to the child: values <= split value go left, all others go right.
        const size_t varID = split_varIDs[current_nodeID];
        const double value = data->get(sample_idx, varID);
        current_nodeID = (value <= split_values[current_nodeID]) ? child_nodeIDs[0][current_nodeID]
                                                                  : child_nodeIDs[1][current_nodeID];
      }
    }

    // Compute the variable importance contribution for each independent variable.
    for (size_t i = 0; i < num_independent_variables; ++i)
    {

      // Translate the independent-variable position into the raw variable ID
      // by skipping the no-split variables.
      size_t varID = i;
      for (const auto &skip : data->getNoSplitVariables())
      {
        if (varID >= skip)
        {
          ++varID;
        }
      }

      for (size_t j = 0; j < root_nodeIDs.size(); ++j)
      {
        const size_t nodeID = root_nodeIDs[j];

        // Only nodes that split in this variable and were reached by at least
        // one OOB observation contribute; otherwise nothing is added.
        if (split_varIDs[nodeID] != varID || oob_sampleIDs_root_node[j].empty())
        {
          continue;
        }

        std::vector<size_t> &oob_sampleIDs_nodeID = oob_sampleIDs_root_node[j];
        const double split_criterion_difference = calculateSplitCriterionDiffOOB(nodeID, oob_sampleIDs_nodeID);

        // Add the difference in split criterion, weighted by the in-bag size of the node.
        forest_importance[i] += split_criterion_difference * (end_pos[nodeID] - start_pos[nodeID]);
      }

    }

  }

  // Writes this tree to an already opened output stream.
  // The write order is: child_nodeIDs, split_varIDs, split_values, followed by
  // whatever appendToFileInternal() emits. Any reader has to consume the data in
  // the same order, so the sequence of calls below must not be changed.
  void Tree::appendToFile(std::ofstream &file)
  {

    // Save general fields (tree topology, split variables, split points)
    saveVector2D(child_nodeIDs, file);
    saveVector1D(split_varIDs, file);
    saveVector1D(split_values, file);

    // Call special functions for subclasses to save special fields.
    appendToFileInternal(file);
  }

  // Draws the candidate split variables for one node and appends them to result.
  //
  // - With split-select weights: mtry variables are drawn from
  //   *split_select_varIDs according to *split_select_weights.
  // - Without weights: mtry variables are drawn uniformly without replacement
  //   from [0, num_vars), skipping the no-split variables and, if present, the
  //   deterministic variables.
  // In all cases the deterministic variables are appended afterwards.
  // For corrected Gini importance, num_vars additionally covers one dummy
  // variable per independent variable.
  void Tree::createPossibleSplitVarSubset(std::vector<size_t> &result)
  {

    const auto &no_split_varIDs = data->getNoSplitVariables();

    // For corrected Gini importance add dummy variables
    size_t num_vars = data->getNumCols();
    if (importance_mode == IMP_GINI_CORRECTED)
    {
      num_vars += data->getNumCols() - no_split_varIDs.size();
    }

    if (!split_select_weights->empty())
    {
      // Weighted draw
      drawWithoutReplacementWeighted(result, random_number_generator, *split_select_varIDs, mtry, *split_select_weights);
    }
    else if (deterministic_varIDs->empty())
    {
      // Uniform draw, only the no-split variables are excluded
      drawWithoutReplacementSkip(result, random_number_generator, num_vars, no_split_varIDs, mtry);
    }
    else
    {
      // Uniform draw, excluding no-split and deterministic variables.
      // The skip list has to be sorted before it is handed to the sampler.
      std::vector<size_t> skip(no_split_varIDs.begin(), no_split_varIDs.end());
      skip.insert(skip.end(), deterministic_varIDs->begin(), deterministic_varIDs->end());
      std::sort(skip.begin(), skip.end());
      drawWithoutReplacementSkip(result, random_number_generator, num_vars, skip, mtry);
    }

    // Always use deterministic variables
    result.insert(result.end(), deterministic_varIDs->begin(), deterministic_varIDs->end());
  }

  // Unity Forests: Split node in (full) tree
  bool Tree::splitNodeFullTree(size_t nodeID)
  {

    bool split_in_root;

    if ((nodeID_in_root[nodeID] != 0 || nodeID == 0) && !(child_nodeIDs_best[0][nodeID_in_root[nodeID]] == 0 && child_nodeIDs_best[1][nodeID_in_root[nodeID]] == 0))
    {
      // If node is in tree root and not terminal, use split from tree root
      split_varIDs[nodeID] = split_varIDs_best[nodeID_in_root[nodeID]];
      split_values[nodeID] = split_values_best[nodeID_in_root[nodeID]];

      // The split is in the tree root
      split_in_root = true;
    }
    else
    {

      // Select random subset of variables to possibly split at
      std::vector<size_t> possible_split_varIDs;
      createPossibleSplitVarSubset(possible_split_varIDs);

      // Call subclass method, sets split_varIDs and split_values
      bool stop = splitNodeInternal(nodeID, possible_split_varIDs);
      if (stop)
      {
        // Terminal node
        return true;
      }

      // The split is not in the tree root
      split_in_root = false;
    }

    size_t split_varID = split_varIDs[nodeID];
    double split_value = split_values[nodeID];

    // Save non-permuted variable for prediction
    split_varIDs[nodeID] = data->getUnpermutedVarID(split_varID);

    // Create child nodes
    size_t left_child_nodeID = split_varIDs.size();
    child_nodeIDs[0][nodeID] = left_child_nodeID;
    createEmptyNodeFullTree();
    if (split_in_root)
    {
      nodeID_in_root[left_child_nodeID] = child_nodeIDs_best[0][nodeID_in_root[nodeID]];
    }
    start_pos[left_child_nodeID] = start_pos[nodeID];

    size_t right_child_nodeID = split_varIDs.size();
    child_nodeIDs[1][nodeID] = right_child_nodeID;
    createEmptyNodeFullTree();
    if (split_in_root)
    {
      nodeID_in_root[right_child_nodeID] = child_nodeIDs_best[1][nodeID_in_root[nodeID]];
    }
    start_pos[right_child_nodeID] = end_pos[nodeID];

    // For each sample in node, assign to left or right child

    // Ordered: left is <= splitval and right is > splitval
    size_t pos = start_pos[nodeID];
    while (pos < start_pos[right_child_nodeID])
    {
      size_t sampleID = sampleIDs[pos];
      if (data->get(sampleID, split_varID) <= split_value)
      {
        // If going to left, do nothing
        ++pos;
      }
      else
      {
        // If going to right, move to right end
        --start_pos[right_child_nodeID];
        std::swap(sampleIDs[pos], sampleIDs[start_pos[right_child_nodeID]]);
      }
    }

    // End position of left child is start position of right child
    end_pos[left_child_nodeID] = start_pos[right_child_nodeID];
    end_pos[right_child_nodeID] = end_pos[nodeID];

    // No terminal node
    return false;
  }

  // New function.
  // This function samples the pairs of variable IDs and splits in these
  // variables.
  void Tree::drawSplitsUnivariate(size_t nodeID, size_t n_triedsplits, std::vector<std::pair<size_t, double>> &sampled_varIDs_values)
  {

    // Get the total number of variables
    size_t num_vars = data->getNumCols();

    // For corrected Gini importance add dummy variables
    if (importance_mode == IMP_GINI_CORRECTED)
    {
      num_vars += data->getNumCols() - data->getNoSplitVariables().size();
    }

    // Determine the indices of the covariates:
    ////////////////

    // REMARK: The covariates are not necessarily all variables
    // different from the target variable, but there may be more
    // variables which should not be used for splitting.
    // For example in the survival case, we have two variables
    // associated with the target variable, the time variable and
    // the censoring indicator. Apart from this, it is also possible
    // for the user to specify variables that should not be used
    // for splitting ("no split variables").
    // Therefore, when determening the indices of the covariates to
    // use, we have to cycle through all variables and skip those
    // variable that should not be used for splitting.

    // Initialize an empty vector of consecutive numbers 0, 1, 2, ...:
    // REMARK: This vector will be modified to exclude the "no split variables":
    std::vector<int> all_varIDsPre(num_vars - data->getNoSplitVariables().size());
    std::iota(all_varIDsPre.begin(), all_varIDsPre.end(), 0);

    // Initialize empty vector, which will contain the indices of
    // the covariates:
    std::vector<int> all_varIDs(num_vars - data->getNoSplitVariables().size());

    // Cycle through "all_varIDsPre" and skip the "no split variables":
    size_t countertemp = 0;
    size_t varIDtemp = 0;

    for (auto &varID : all_varIDsPre)
    {
      varIDtemp = varID;
      // Go through the "no split variables"; if the current variable
      // "varID" is equal to the respective "no split variable",
      // increase index of the current variable:
      for (auto &skip_value : data->getNoSplitVariables())
      {
        if (varIDtemp >= skip_value)
        {
          ++varIDtemp;
        }
      }
      all_varIDs[countertemp] = varIDtemp;
      ++countertemp;
    }

    // Cycle through all variables, count their numbers of split
    // points and add these up:
    ///////////////////////

    size_t n_splitstotal = 0;
    size_t n_triedsplitscandidate;
    for (auto &varID : all_varIDs)
    {

      // Create possible split values for variable 'varID'
      std::vector<double> possible_split_values;
      data->getAllValues(possible_split_values, sampleIDs, varID, start_pos[nodeID], end_pos[nodeID]);

      // Add the number of split values up to the total number of
      // split values:
      n_splitstotal += possible_split_values.size() - 1;

      // Break the loop, if the number n_triedsplitscandidate = proptry * n_splitstotal
      // already exceeds the maximum number of splits to sample:
      n_triedsplitscandidate = (size_t)((double)n_splitstotal * proptry + 0.5);

      if (n_triedsplitscandidate > n_triedsplits)
      {
        break;
      }
    }

    // If the calculated number of splits to sample
    // is larger than the maximum number of splits to sample 'nsplits',
    // use 'nsplits':
    n_triedsplits = std::min(n_triedsplits, n_triedsplitscandidate);

    // Sample the pairs of variable IDs and splits:
    //////////////////////

    // If "n_triedsplits" is zero no splits should be sampled,
    // which will result in findBestSplitUnivariate() returning
    // zero, leading the node splitting to stop:
    if (n_triedsplits > 0)
    {

      // Initialize:
      sampled_varIDs_values.reserve(n_triedsplits);

      // Random number generator for the covariates:
      std::uniform_int_distribution<size_t> unif_distvarID(0, num_vars - 1 - data->getNoSplitVariables().size());

      // Draw the covariate/split pairs by a loop:
      size_t drawnvarID;
      double drawnvalue;

      for (size_t i = 0; i < n_triedsplits; ++i)
      {

        std::pair<size_t, double> drawnpair;
        bool pairnotfound = false;

        // Loop that stops as soon "pairnotfound" becomes FALSE.
        do
        {

          // Draw a covariate, while skipping the "no split variables":
          drawnvarID = unif_distvarID(random_number_generator);
          for (auto &skip_value : data->getNoSplitVariables())
          {
            if (drawnvarID >= skip_value)
            {
              ++drawnvarID;
            }
          }

          // Create possible split values for variable 'varID':
          std::vector<double> possible_split_values;
          data->getAllValues(possible_split_values, sampleIDs, drawnvarID, start_pos[nodeID], end_pos[nodeID]);

          // The pair is declared not found if there is only one
          // or less possible split values in the drawn covariate
          // (and a new variable will be drawn as a consequence)
          // REMINDER: This might be computationally (very) ineffective
          // for higher dimensional data with many dichotome covariates,
          // because here it can happen that there will be no possible
          // splits in a large quantity of covariates after a few splits,
          // which might have the effect that the process of drawing the
          // covariate has to repeated many times before a suitable
          // covariate has been drawn. For this reason it might be
          // better to store the indices of the covariates for which
          // there are no splits left, so that these are not drawn
          // again and again.
          pairnotfound = possible_split_values.size() < 2;

          if (!pairnotfound)
          {

            // Determine the splits in the drawn covariates, which are the mid points
            // between the neighboring covariate values:
            std::vector<double> all_mid_points(possible_split_values.size() - 1);
            for (size_t i = 0; i < possible_split_values.size() - 1; ++i)
            {
              all_mid_points[i] = (possible_split_values[i] + possible_split_values[i + 1]) / 2;
            }

            // Random number generator for the splits:
            std::uniform_int_distribution<size_t> unif_distvalue(0, all_mid_points.size() - 1);

            // Draw a split:
            drawnvalue = all_mid_points[unif_distvalue(random_number_generator)];

            /*
        // Draw values:
        std::vector<double> drawnvalues;
        drawDoublesWithoutReplacement(drawnvalues, random_number_generator, possible_split_values, 2);

        // ...and take their average:
        drawnvalue = (drawnvalues[0] + drawnvalues[1]) / 2;

*/

            // Make the drawn covariate/split pair:
            drawnpair = std::make_pair(drawnvarID, drawnvalue);

            // Check whether this pair is already existent in the drawn pairs.
            // If this is the case, "pairnotfound" will be set to false
            // and the search for a suitable pair continues.
            pairnotfound = std::find(sampled_varIDs_values.begin(), sampled_varIDs_values.end(), drawnpair) != sampled_varIDs_values.end();
          }

        } while (pairnotfound);

        // Add the drawn pair to "sampled_varIDs_values":
        sampled_varIDs_values.push_back(drawnpair);
      }

      // Some console outputs I had used, while developing the function:
      //std::vector<size_t> gezogenevars;
      //std::vector<double> gezogenepunkte;
      //for (size_t i = 0; i < sampled_varIDs_values.size(); ++i) {
      //	gezogenevars.push_back(std::get<0>(sampled_varIDs_values[i]));
      //	gezogenepunkte.push_back(std::get<1>(sampled_varIDs_values[i]));
      //	}

    }
  }

  // Draw the candidate splits (univariate and bivariate) for node 'nodeID'.
  //
  // The function repeatedly picks a random pair of variables from the list of
  // promising pairs ('promispairs') and, for each pair, tries to generate
  // seven candidate splits:
  //   - two univariate splits (type 1), one per variable of the pair, each at
  //     the mid point between two neighboring unique values in the node;
  //   - five bivariate splits (types 2 to 6) that all share one split point
  //     (x, y). The x value is the mid point of two randomly drawn unique
  //     values of the first variable; the y value is drawn from the range of
  //     the second variable that is covered on both sides of the x split.
  // If no suitable (x, y) point is found, only the two univariate splits of
  // that attempt are kept.
  //
  // Sampling stops once 'npairs' pairs delivered all seven splits or once
  // ceil(1.5 * npairs) pairs have been tried. At most 7 * npairs splits are
  // kept, and the kept splits are returned in random order.
  //
  // Parameters:
  //   nodeID                    node whose samples (start_pos/end_pos range) are used
  //   n_triedsplits             currently not used by this function
  //   sampled_split_types       output: type (1 to 6) of each sampled split
  //   sampled_split_multvarIDs  output: variable ID(s) of each sampled split
  //   sampled_split_directs     output: direction flags of each sampled split
  //   sampled_split_multvalues  output: split point(s) of each sampled split
  //
  // All four output vectors are overwritten and have the same length on
  // return. They are empty if no split could be sampled.
  void Tree::drawSplitsMultivariate(size_t nodeID, size_t n_triedsplits, std::vector<size_t> &sampled_split_types, std::vector<std::vector<size_t>> &sampled_split_multvarIDs, std::vector<std::vector<std::vector<bool>>> &sampled_split_directs, std::vector<std::vector<std::vector<double>>> &sampled_split_multvalues)
  {

    // Start from empty outputs; the splits are appended one by one:
    sampled_split_types.clear();
    sampled_split_multvarIDs.clear();
    sampled_split_directs.clear();
    sampled_split_multvalues.clear();

    // Number of promising feature pairs. Without any pair there is nothing to
    // sample from (and 'npromispairs - 1' below would wrap around), so the
    // outputs are left empty:
    const size_t npromispairs = (*promispairs).size();
    if (npromispairs == 0)
    {
      return;
    }

    // Random number generator for the promising feature pairs:
    std::uniform_int_distribution<size_t> unif_promispairs(0, npromispairs - 1);

    // Random number generator for selecting a random number out of {1,2}:
    std::uniform_int_distribution<size_t> getoneortwo(1, 2);

    // Number of pairs that should deliver a full set of splits, maximum
    // number of pairs to try and number of splits generated per pair:
    const size_t num_pairs_wanted = static_cast<size_t>(npairs);
    const size_t max_attempts = static_cast<size_t>(std::ceil(npairs * 1.5));
    const size_t splits_per_pair = 7;

    // Reserve space for the split information:
    sampled_split_types.reserve(splits_per_pair * max_attempts);
    sampled_split_multvarIDs.reserve(splits_per_pair * max_attempts);
    sampled_split_directs.reserve(splits_per_pair * max_attempts);
    sampled_split_multvalues.reserve(splits_per_pair * max_attempts);

    // Templates for the seven splits generated per pair: entries 0 and 1 are
    // the univariate splits, entries 2 to 6 the bivariate splits.
    const size_t template_types[7] = {1, 1, 2, 3, 4, 5, 6};
    const std::vector<std::vector<bool>> template_directs = {
        {true}, {true}, {true, true}, {true, false}, {false, true}, {false, false}, {true, true}};

    // Appends one split, built from template 'template_index', to the outputs:
    const auto append_split = [&](size_t template_index, const std::vector<size_t> &varIDs, const std::vector<double> &values)
    {
      sampled_split_types.push_back(template_types[template_index]);
      sampled_split_directs.push_back(std::vector<std::vector<bool>>(1, template_directs[template_index]));
      sampled_split_multvarIDs.push_back(varIDs);
      sampled_split_multvalues.push_back(std::vector<std::vector<double>>(1, values));
    };

    // Returns the sorted unique values of 'values':
    const auto sorted_unique = [](std::vector<double> values) -> std::vector<double>
    {
      std::sort(values.begin(), values.end());
      values.erase(std::unique(values.begin(), values.end()), values.end());
      return values;
    };

    // Extends the running minimum/maximum [lo, hi] by 'y'. The comparisons
    // mirror those of std::min_element / std::max_element.
    const auto extend_range = [](bool &seen, double &lo, double &hi, double y)
    {
      if (!seen)
      {
        lo = y;
        hi = y;
        seen = true;
      }
      else
      {
        if (y < lo)
        {
          lo = y;
        }
        if (hi < y)
        {
          hi = y;
        }
      }
    };

    size_t pairs_accepted = 0; // pairs that delivered all seven splits
    size_t attempts = 0;       // pairs tried so far

    while (pairs_accepted < num_pairs_wanted && attempts < max_attempts)
    {

      // Every pass through the loop counts as one attempt, regardless of
      // how it ends:
      ++attempts;

      // Randomly select promising pair:
      std::vector<size_t> drawnvarIDs = (*promispairs)[unif_promispairs(random_number_generator)];

      // Randomly permute the order in the selected pair:
      if (getoneortwo(random_number_generator) == 1)
      {
        std::swap(drawnvarIDs[0], drawnvarIDs[1]);
      }

      // Skip variables not to use for splitting:
      for (size_t i = 0; i < 2; ++i)
      {
        for (auto &skip_value : data->getNoSplitVariables())
        {
          if (drawnvarIDs[i] >= skip_value)
          {
            ++drawnvarIDs[i];
          }
        }
      }

      // Values of both variables for the samples in the node (same order):
      std::vector<double> values_variable1;
      data->getRawValues(values_variable1, sampleIDs, drawnvarIDs[0], start_pos[nodeID], end_pos[nodeID]);

      std::vector<double> values_variable2;
      data->getRawValues(values_variable2, sampleIDs, drawnvarIDs[1], start_pos[nodeID], end_pos[nodeID]);

      // Sorted unique values of both variables:
      const std::vector<double> possible_split_values_1 = sorted_unique(values_variable1);
      const std::vector<double> possible_split_values_2 = sorted_unique(values_variable2);

      // Discard the pair, if one of the variables offers no possible split:
      if (possible_split_values_1.size() < 2 || possible_split_values_2.size() < 2)
      {
        continue;
      }

      // Univariate split points: mid points between two neighboring unique
      // values, first for variable 1, then for variable 2:
      std::uniform_int_distribution<size_t> unif_dist(0, possible_split_values_1.size() - 2);
      size_t draw = unif_dist(random_number_generator);
      const double xvalueuniv = (possible_split_values_1[draw] + possible_split_values_1[draw + 1]) / 2;

      std::uniform_int_distribution<size_t> unif_dist2(0, possible_split_values_2.size() - 2);
      draw = unif_dist2(random_number_generator);
      const double yvalueuniv = (possible_split_values_2[draw] + possible_split_values_2[draw + 1]) / 2;

      append_split(0, std::vector<size_t>(1, drawnvarIDs[0]), std::vector<double>(1, xvalueuniv));
      append_split(1, std::vector<size_t>(1, drawnvarIDs[1]), std::vector<double>(1, yvalueuniv));

      // Search (at most 20 tries) for a split point in variable 1 for which
      // the ranges of variable 2 on both sides of that point overlap:
      bool foundsplit = false;
      double lowerbound = 0.0;
      double upperbound = 0.0;
      double drawnxvalue = 0.0;

      for (size_t numbertried = 0; numbertried < 20 && !foundsplit; ++numbertried)
      {

        // Draw two unique values...
        std::vector<double> drawnvalues;
        drawDoublesWithoutReplacement(drawnvalues, random_number_generator, possible_split_values_1, 2);

        // ...and take their average:
        drawnxvalue = (drawnvalues[0] + drawnvalues[1]) / 2;

        // Minimum and maximum of the values of variable 2 among the samples
        // with variable 1 smaller / larger than the split point:
        bool any_small = false;
        bool any_large = false;
        double minyxsmall = 0.0;
        double maxyxsmall = 0.0;
        double minyxlarge = 0.0;
        double maxyxlarge = 0.0;

        for (size_t j = 0; j < values_variable2.size(); ++j)
        {
          if (values_variable1[j] < drawnxvalue)
          {
            extend_range(any_small, minyxsmall, maxyxsmall, values_variable2[j]);
          }
          else if (values_variable1[j] > drawnxvalue)
          {
            extend_range(any_large, minyxlarge, maxyxlarge, values_variable2[j]);
          }
        }

        // No samples on one of the sides: no overlap can be determined.
        if (!any_small || !any_large)
        {
          continue;
        }

        lowerbound = std::max(minyxsmall, minyxlarge);
        upperbound = std::min(maxyxsmall, maxyxlarge);

        // The try fails, if the two ranges do not overlap:
        foundsplit = !(lowerbound >= upperbound);
      }

      if (!foundsplit)
      {
        continue;
      }

      // Keep only the unique values of variable 2 that lie in the
      // overlap [lowerbound, upperbound]:
      std::vector<double> possible_split_values_2_interval;
      possible_split_values_2_interval.reserve(possible_split_values_2.size());
      for (const double value : possible_split_values_2)
      {
        if (!(value < lowerbound || value > upperbound))
        {
          possible_split_values_2_interval.push_back(value);
        }
      }

      if (possible_split_values_2_interval.size() < 2)
      {
        continue;
      }

      // Draw two values from the overlap...
      std::vector<double> drawnvalues;
      drawDoublesWithoutReplacement(drawnvalues, random_number_generator, possible_split_values_2_interval, 2);

      // ...and take their average:
      const double drawnyvalue = (drawnvalues[0] + drawnvalues[1]) / 2;

      // The five bivariate splits share the variables and the split point:
      const std::vector<double> drawnxyvalues = {drawnxvalue, drawnyvalue};
      for (size_t template_index = 2; template_index < splits_per_pair; ++template_index)
      {
        append_split(template_index, drawnvarIDs, drawnxyvalues);
      }

      ++pairs_accepted;
    }

    // Keep at most 7 * npairs splits (failed attempts may have added
    // additional univariate splits):
    const size_t numberkeep = std::min<size_t>(sampled_split_types.size(), splits_per_pair * num_pairs_wanted);

    sampled_split_multvarIDs.resize(numberkeep);
    sampled_split_types.resize(numberkeep);
    sampled_split_directs.resize(numberkeep);
    sampled_split_multvalues.resize(numberkeep);

    // Bring the kept splits into random order (same permutation for all
    // four vectors):
    if (sampled_split_multvarIDs.size() > 0)
    {

      std::vector<size_t> randindices(sampled_split_multvarIDs.size());
      std::iota(randindices.begin(), randindices.end(), 0);
      std::shuffle(randindices.begin(), randindices.end(), random_number_generator);

      sampled_split_multvarIDs = reorder(sampled_split_multvarIDs, randindices);
      sampled_split_types = reorder(sampled_split_types, randindices);
      sampled_split_directs = reorder(sampled_split_directs, randindices);
      sampled_split_multvalues = reorder(sampled_split_multvalues, randindices);
    }
  }

  // Check whether observation 'sampleID' lies inside the region described by
  // a split.
  //
  //   split_type == 1:  univariate split; inside, if the value of the first
  //                     variable is smaller than the split point.
  //   split_type 2..5:  bivariate split; inside, if the observation lies in
  //                     one quadrant relative to the split point (x, y):
  //                       2: x smaller, y smaller    3: x smaller, y larger
  //                       4: x larger,  y smaller    5: x larger,  y larger
  //   any other type:   inside, if the observation lies in the
  //                     "x smaller, y smaller" or in the "x larger, y larger"
  //                     quadrant.
  //
  // All comparisons are strict. 'split_direct' is not evaluated here.
  bool Tree::IsInRectangle(const Data *data, size_t sampleID, size_t split_type, std::vector<size_t> &split_multvarID, std::vector<std::vector<bool>> &split_direct, std::vector<std::vector<double>> &split_multvalue)
  {

    // Univariate splits only involve the first variable:
    if (split_type == 1)
    {
      return data->get(sampleID, split_multvarID[0]) < split_multvalue[0][0];
    }

    // All remaining split types are bivariate. Get the x- and y-axis values
    // of sampleID and the split point:
    const double x = data->get(sampleID, split_multvarID[0]);
    const double y = data->get(sampleID, split_multvarID[1]);

    const double x_split = split_multvalue[0][0];
    const double y_split = split_multvalue[0][1];

    switch (split_type)
    {
    case 2:
      return x < x_split && y < y_split;
    case 3:
      return x < x_split && y > y_split;
    case 4:
      return x > x_split && y < y_split;
    case 5:
      return x > x_split && y > y_split;
    default:
      // sampleID has to be contained in either the first or the second
      // rectangle:
      return (x < x_split && y < y_split) || (x > x_split && y > y_split);
    }
  }

  bool Tree::splitNode(size_t nodeID)
  {

    bool stop;

    /// Rcpp::Rcout << "nodeID: " << nodeID << std::endl;

    if (divfortype == 1)
    {

      // Rcpp::Rcout << "Laenge sampled_split_types" << sampled_split_types.size() << std::endl;

      // Draw the variables and the candidate splits - after performing this step,
      // sampled_varIDs_values will contain the variables and candidate splits:
      size_t n_triedsplits = (size_t)nsplits;
      std::vector<std::pair<size_t, double>> sampled_varIDs_values;
      drawSplitsUnivariate(nodeID, n_triedsplits, sampled_varIDs_values);

      // Perform the splitting using the subclass method:
      stop = splitNodeUnivariateInternal(nodeID, sampled_varIDs_values);

      if (stop)
      {
        // Terminal node
        return true;
      }

      size_t split_varID = split_varIDs[nodeID];
      double split_value = split_values[nodeID];

      // Save non-permuted variable for prediction
      split_varIDs[nodeID] = data->getUnpermutedVarID(split_varID);

      // Create child nodes
      size_t left_child_nodeID = split_varIDs.size();
      child_nodeIDs[0][nodeID] = left_child_nodeID;
      createEmptyNode();
      start_pos[left_child_nodeID] = start_pos[nodeID];

      size_t right_child_nodeID = split_varIDs.size();
      child_nodeIDs[1][nodeID] = right_child_nodeID;
      createEmptyNode();
      start_pos[right_child_nodeID] = end_pos[nodeID];

      // For each sample in node, assign to left or right child
      if (data->isOrderedVariable(split_varID))
      {
        // Ordered: left is <= splitval and right is > splitval
        size_t pos = start_pos[nodeID];
        while (pos < start_pos[right_child_nodeID])
        {
          size_t sampleID = sampleIDs[pos];
          if (data->get(sampleID, split_varID) <= split_value)
          {
            // If going to left, do nothing
            ++pos;
          }
          else
          {
            // If going to right, move to right end
            --start_pos[right_child_nodeID];
            std::swap(sampleIDs[pos], sampleIDs[start_pos[right_child_nodeID]]);
          }
        }
      }
      else
      {
        // Unordered: If bit at position is 1 -> right, 0 -> left
        size_t pos = start_pos[nodeID];
        while (pos < start_pos[right_child_nodeID])
        {
          size_t sampleID = sampleIDs[pos];
          double level = data->get(sampleID, split_varID);
          size_t factorID = floor(level) - 1;
          size_t splitID = floor(split_value);

          // Left if 0 found at position factorID
          if (!(splitID & (1 << factorID)))
          {
            // If going to left, do nothing
            ++pos;
          }
          else
          {
            // If going to right, move to right end
            --start_pos[right_child_nodeID];
            std::swap(sampleIDs[pos], sampleIDs[start_pos[right_child_nodeID]]);
          }
        }
      }

      // End position of left child is start position of right child
      end_pos[left_child_nodeID] = start_pos[right_child_nodeID];
      end_pos[right_child_nodeID] = end_pos[nodeID];

      // Rcpp::Rcout << "left_child_nodeID: " << left_child_nodeID << std::endl;

      // No terminal node
      return false;
    }

    if (divfortype == 2)
    {

      size_t n_triedsplits = (size_t)nsplits;
      std::vector<size_t> sampled_split_types;
      std::vector<std::vector<size_t>> sampled_split_multvarIDs;
      std::vector<std::vector<std::vector<bool>>> sampled_split_directs;
      std::vector<std::vector<std::vector<double>>> sampled_split_multvalues;
      drawSplitsMultivariate(nodeID, n_triedsplits, sampled_split_types, sampled_split_multvarIDs, sampled_split_directs, sampled_split_multvalues);

      // Perform the splitting using the subclass method:
      stop = splitNodeMultivariateInternal(nodeID, sampled_split_types, sampled_split_multvarIDs, sampled_split_directs, sampled_split_multvalues); // asdf

      if (stop)
      {
        // Terminal node
        return true;
      }

      size_t split_type = split_types[nodeID];
      std::vector<size_t> split_multvarID = split_multvarIDs[nodeID];
      std::vector<std::vector<bool>> split_direct = split_directs[nodeID];
      std::vector<std::vector<double>> split_multvalue = split_multvalues[nodeID];

      // Save non-permuted variable for prediction
      for (size_t i = 0; i < split_multvarIDs[nodeID].size(); ++i)
      {
        split_multvarIDs[nodeID][i] = data->getUnpermutedVarID(split_multvarIDs[nodeID][i]);
      }

      // Create child nodes
      size_t left_child_nodeID = split_multvarIDs.size();
      child_nodeIDs[0][nodeID] = left_child_nodeID;
      createEmptyNodeMultivariate();
      start_pos[left_child_nodeID] = start_pos[nodeID];

      size_t right_child_nodeID = split_multvarIDs.size();
      child_nodeIDs[1][nodeID] = right_child_nodeID;
      createEmptyNodeMultivariate();
      start_pos[right_child_nodeID] = end_pos[nodeID];

      // For each sample in node, assign to left or right child
      /////if (data->isOrderedVariable(split_varID))  Remark: Currently only ordered variables
      /////{
      // Ordered: left is <= splitval and right is > splitval
      bool inrectangle;
      size_t pos = start_pos[nodeID];
      while (pos < start_pos[right_child_nodeID])
      {
        size_t sampleID = sampleIDs[pos];

        //// MAKE NEW FUNCTION IsInRectangle TO CHECK WHETHER IN RECTANGLE

        /// 1. Schritt: IsInRectangle in Tree.cpp definieren

        inrectangle = IsInRectangle(data, sampleID, split_type, split_multvarID, split_direct, split_multvalue);
        if (inrectangle)
        {
          // If going to left, do nothing
          ++pos;
        }
        else
        {
          // If going to right, move to right end
          --start_pos[right_child_nodeID];
          std::swap(sampleIDs[pos], sampleIDs[start_pos[right_child_nodeID]]);
        }
      }

      // End position of left child is start position of right child
      end_pos[left_child_nodeID] = start_pos[right_child_nodeID];
      end_pos[right_child_nodeID] = end_pos[nodeID];

      // Rcpp::Rcout << "left_child_nodeID: " << left_child_nodeID << std::endl;

      // No terminal node
      return false;
    }

    // To satisfy the compiler:
    return false;
  }

  // Unity Forests: Split node in random tree
  //
  // Splits node 'nodeID' of the random tree at a random split point in a
  // randomly drawn variable, or declares the node terminal.
  //
  // A variable is drawn (with replacement, at most 500 tries) from
  // 'varIDs_root' until one is found that takes at least two different
  // values in the node. The split point is the mid point between two
  // neighboring unique values of that variable, chosen uniformly at random.
  // Two child nodes are created and the samples of the node (sampleIDs in the
  // range [start_pos_loop[nodeID], end_pos_loop[nodeID])) are partitioned in
  // place: values <= split point go to the left child, all others to the
  // right child.
  //
  // Returns true, if the node is terminal (stopping criterion reached,
  // 'varIDs_root' empty or no suitable variable found), and false, if the
  // node was split.
  bool Tree::splitNodeRandom(size_t nodeID, const std::vector<size_t>& varIDs_root)
  {

    if (checkWhetherFinalRandom(nodeID))
    {
      // Terminal node
      return true;
    }

    // Without candidate variables no split is possible (and
    // 'varIDs_root.size() - 1' below would wrap around):
    if (varIDs_root.empty())
    {
      return true;
    }

    // Randomly draw one variable out of varIDs_root that has at least two
    // different values in the current node:
    std::uniform_int_distribution<size_t> uni(0, varIDs_root.size() - 1);

    const size_t max_tries = 500;
    bool found_variable = false;
    size_t drawnvarID = 0;

    for (size_t t = 0; t < max_tries && !found_variable; ++t)
    {
      const size_t varID = varIDs_root[uni(random_number_generator)];

      if (twoDifferentValues(nodeID, varID))
      {
        drawnvarID = varID;
        found_variable = true;
      }
    }

    if (!found_variable)
    {
      return true; // No suitable variable found; make node terminal
    }

    // Collect the values of the drawn variable in the node:
    values_buffer.clear();
    values_buffer.reserve(end_pos_loop[nodeID] - start_pos_loop[nodeID]);

    for (size_t pos = start_pos_loop[nodeID]; pos < end_pos_loop[nodeID]; ++pos)
    {
      values_buffer.push_back(data->get(sampleIDs[pos], drawnvarID));
    }

    // Sort and remove duplicates
    std::sort(values_buffer.begin(), values_buffer.end());
    values_buffer.erase(std::unique(values_buffer.begin(), values_buffer.end()), values_buffer.end());

    // Select a split value between two neighboring unique values (there are
    // at least two unique values, because twoDifferentValues() returned true):
    std::uniform_int_distribution<size_t> unif_dist(0, values_buffer.size() - 2);
    const size_t index = unif_dist(random_number_generator);
    const double split_value = (values_buffer[index] + values_buffer[index + 1]) / 2;

    // Set the split variable and split value
    split_varIDs_loop[nodeID] = data->getUnpermutedVarID(drawnvarID);
    split_values_loop[nodeID] = split_value;

    // Create child nodes
    const size_t left_child_nodeID = split_varIDs_loop.size();
    child_nodeIDs_loop[0][nodeID] = left_child_nodeID;
    createEmptyNodeRandomTree();
    start_pos_loop[left_child_nodeID] = start_pos_loop[nodeID];

    const size_t right_child_nodeID = split_varIDs_loop.size();
    child_nodeIDs_loop[1][nodeID] = right_child_nodeID;
    createEmptyNodeRandomTree();
    // The start of the right child is moved downwards while samples are
    // assigned to it:
    start_pos_loop[right_child_nodeID] = end_pos_loop[nodeID];

    // For each sample in node, assign to left or right child.
    // Ordered: left is <= splitval and right is > splitval
    size_t pos = start_pos_loop[nodeID];
    while (pos < start_pos_loop[right_child_nodeID])
    {
      const size_t sampleID = sampleIDs[pos];
      if (data->get(sampleID, drawnvarID) <= split_value)
      {
        // If going to left, do nothing
        ++pos;
      }
      else
      {
        // If going to right, move to right end
        --start_pos_loop[right_child_nodeID];
        std::swap(sampleIDs[pos], sampleIDs[start_pos_loop[right_child_nodeID]]);
      }
    }

    // End position of left child is start position of right child
    end_pos_loop[left_child_nodeID] = start_pos_loop[right_child_nodeID];
    end_pos_loop[right_child_nodeID] = end_pos_loop[nodeID];

    // No terminal node
    return false;
  }

// Unity Forests: Check whether variable 'varID' takes at least two different
// values among the samples of node 'nodeID' in the random tree, i.e. among
// sampleIDs in the range [start_pos_loop[nodeID], end_pos_loop[nodeID]).
// The scan stops at the first value that differs from the value of the
// first sample in the node.
bool Tree::twoDifferentValues(size_t nodeID, size_t varID)
{
    const size_t node_begin = start_pos_loop[nodeID];
    const size_t node_end = end_pos_loop[nodeID];

    // Value of the first sample in the node serves as the reference:
    const double reference_value = data->get(sampleIDs[node_begin], varID);

    // Advance as long as the values agree with the reference:
    size_t pos = node_begin + 1;
    while (pos < node_end && !(data->get(sampleIDs[pos], varID) != reference_value))
    {
        ++pos;
    }

    // If the scan stopped before the end of the node, a different value was
    // found; otherwise all values are equal:
    return pos < node_end;
}

  // Append one zero-initialised entry for a new node to each per-node container
  // of the tree, then call createEmptyNodeInternal().
  void Tree::createEmptyNode()
  {
    // Start and end position of the node
    start_pos.push_back(0);
    end_pos.push_back(0);

    // Left (index 0) and right (index 1) child node IDs
    for (size_t side = 0; side < 2; ++side)
    {
      child_nodeIDs[side].push_back(0);
    }

    // Split variable and split value
    split_varIDs.push_back(0);
    split_values.push_back(0);

    createEmptyNodeInternal();
  }
  
  // Unity Forests: Create an empty node in a random tree.
  // Appends one zero-initialised entry to each of the "_loop" per-node containers
  // and then calls createEmptyNodeRandomTreeInternal().
  void Tree::createEmptyNodeRandomTree()
  {
    // Start and end position of the node
    start_pos_loop.push_back(0);
    end_pos_loop.push_back(0);

    // Left (index 0) and right (index 1) child node IDs
    for (size_t side = 0; side < 2; ++side)
    {
      child_nodeIDs_loop[side].push_back(0);
    }

    // Split variable and split value
    split_varIDs_loop.push_back(0);
    split_values_loop.push_back(0);

    createEmptyNodeRandomTreeInternal();
  }

  // Unity Forests: Create an empty node in a (full) tree.
  // Appends one zero-initialised entry to each per-node container, including
  // nodeID_in_root, and then calls createEmptyNodeFullTreeInternal().
  void Tree::createEmptyNodeFullTree()
  {
    // Start and end position of the node
    start_pos.push_back(0);
    end_pos.push_back(0);

    // Left (index 0) and right (index 1) child node IDs
    for (size_t side = 0; side < 2; ++side)
    {
      child_nodeIDs[side].push_back(0);
    }

    // Split variable and split value
    split_varIDs.push_back(0);
    split_values.push_back(0);

    // Additional per-node entry that only the full tree keeps
    nodeID_in_root.push_back(0);

    createEmptyNodeFullTreeInternal();
  }

  // Append an empty node for a tree with multivariate splits: the split
  // description containers receive an empty element, all other per-node
  // containers a zero. Finishes by calling createEmptyNodeInternal().
  void Tree::createEmptyNodeMultivariate()
  {
    // Split description: type, variables, directions and values
    split_types.push_back(0);
    split_multvarIDs.push_back(std::vector<size_t>{});
    split_directs.push_back(std::vector<std::vector<bool>>{});
    split_multvalues.push_back(std::vector<std::vector<double>>{});

    // Left (index 0) and right (index 1) child node IDs
    for (size_t side = 0; side < 2; ++side)
    {
      child_nodeIDs[side].push_back(0);
    }

    // Start and end position of the node
    start_pos.push_back(0);
    end_pos.push_back(0);

    createEmptyNodeInternal();
  }
  
  // Unity Forests: Remove all nodes of the current random tree by emptying the
  // "_loop" per-node containers, then call clearRandomTreeInternal().
  void Tree::clearRandomTree()
  {
    // Left (index 0) and right (index 1) child node IDs
    for (size_t side = 0; side < 2; ++side)
    {
      child_nodeIDs_loop[side].clear();
    }

    // Node positions
    start_pos_loop.clear();
    end_pos_loop.clear();

    // Split variables and split values
    split_varIDs_loop.clear();
    split_values_loop.clear();

    clearRandomTreeInternal();
  }

// Unity Forests: Difference between the Gini split criterion values in a node, computed
// on the OOB observations only, before and after permuting the values of the split variable.
// This implementation does not use its arguments and always returns 0.0.
// NOTE(review): presumably a placeholder that tree-type-specific classes replace - confirm in Tree.h.
double Tree::calculateSplitCriterionDiffOOB(size_t nodeID, std::vector<size_t> oob_sampleIDs_nodeID)
{
  // Neither argument is needed by this implementation
  (void)nodeID;
  (void)oob_sampleIDs_nodeID;

  return 0.0;
}

  // Drop a sample down the tree while taking the values of one variable from
  // another sample.
  //
  // permuted_varID:    variable whose value is read from permuted_sampleID
  // sampleID:          sample providing the values of all other variables
  // permuted_sampleID: sample providing the value of permuted_varID
  //
  // Returns the ID of the node at which the descent stops, i.e. the first node
  // whose two child IDs are both 0.
  size_t Tree::dropDownSamplePermuted(size_t permuted_varID, size_t sampleID, size_t permuted_sampleID)
  {

    // Start in root and drop down
    size_t nodeID = 0;
    while (child_nodeIDs[0][nodeID] != 0 || child_nodeIDs[1][nodeID] != 0)
    {

      // Permute if variable is permutation variable
      const size_t split_varID = split_varIDs[nodeID];
      const size_t sampleID_final = (split_varID == permuted_varID) ? permuted_sampleID : sampleID;

      // Decide on the child
      const double value = data->get(sampleID_final, split_varID);
      bool to_left;
      if (data->isOrderedVariable(split_varID))
      {
        // Ordered: left is <= split value and right is > split value
        to_left = (value <= split_values[nodeID]);
      }
      else
      {
        // Unordered: the split value is read as a bit mask over the factor levels
        size_t factorID = floor(value) - 1;
        size_t splitID = floor(split_values[nodeID]);

        // Left if 0 found at position factorID. The shifted constant has to be
        // 64 bits wide: shifting a plain int by 31 or more positions overflows
        // and would test the wrong bit for factors with many levels.
        to_left = !(splitID & (1ULL << factorID));
      }

      // Move to left (index 0) or right (index 1) child
      nodeID = child_nodeIDs[to_left ? 0 : 1][nodeID];
    }
    return nodeID;
  }

  // Unity forest: Drop sample sampleID down the tree. From the first split that
  // uses varID as split variable onwards, every assignment to a child node is
  // made at random (see randomAssignLeftChildNode()); before that, the regular
  // split points decide.
  //
  // Returns the ID of the node at which the descent stops.
  size_t Tree::randomizedDropDownSample(size_t varID, size_t sampleID)
  {
    // Start in root
    size_t nodeID = 0;

    // Becomes true at the first split on varID and stays true afterwards
    bool randomize = false;

    while (child_nodeIDs[0][nodeID] != 0 || child_nodeIDs[1][nodeID] != 0)
    {
      if (split_varIDs[nodeID] == varID)
      {
        randomize = true;
      }

      bool go_left;
      if (randomize)
      {
        // Random assignment, with probabilities proportional to the sizes of
        // the two child nodes
        go_left = randomAssignLeftChildNode(nodeID);
      }
      else
      {
        // Regular assignment: left is <= split value and right is > split value
        const double value = data->get(sampleID, split_varIDs[nodeID]);
        go_left = (value <= split_values[nodeID]);
      }

      nodeID = go_left ? child_nodeIDs[0][nodeID] : child_nodeIDs[1][nodeID];
    }
    return nodeID;
  }

  // Unity Forest: Random choice between the two child nodes of nodeID, used when
  // assigning OOB observations at random. The left child is chosen with a
  // probability equal to its share of the samples in both children.
  //
  // Returns true for the left child and false for the right child.
  bool Tree::randomAssignLeftChildNode(size_t nodeID)
  {
    const size_t left_child = child_nodeIDs[0][nodeID];
    const size_t right_child = child_nodeIDs[1][nodeID];

    // Sizes of the child nodes
    const size_t left_size = end_pos[left_child] - start_pos[left_child];
    const size_t right_size = end_pos[right_child] - start_pos[right_child];

    // Left child node size divided by the combined size of both children
    const double prob_left = static_cast<double>(left_size) / static_cast<double>(left_size + right_size);

    // Compare a uniformly distributed draw with prob_left
    std::uniform_real_distribution<double> unit_interval(0.0, 1.0);
    const double draw = unit_interval(random_number_generator);
    return draw <= prob_left;
  }

  // Shuffle the given sample IDs in place and drop every OOB sample down the
  // tree, reading the value of permuted_varID from the shuffled ID at the same
  // index. The node reached by the i-th OOB sample is stored in
  // prediction_terminal_nodeIDs[i].
  void Tree::permuteAndPredictOobSamples(size_t permuted_varID, std::vector<size_t> &permutations)
  {
    // Permute the IDs supplied by the caller
    std::shuffle(permutations.begin(), permutations.end(), random_number_generator);

    // For each sample, drop down the tree and record the node it ends in
    for (size_t oob_idx = 0; oob_idx < num_samples_oob; ++oob_idx)
    {
      prediction_terminal_nodeIDs[oob_idx] =
          dropDownSamplePermuted(permuted_varID, oob_sampleIDs[oob_idx], permutations[oob_idx]);
    }
  }

  // Unity forest: Drop all OOB samples down the tree with
  // randomizedDropDownSample(), i.e. with random assignment to child nodes from
  // the first split on varID onwards. The node reached by the i-th OOB sample
  // is stored in prediction_terminal_nodeIDs[i].
  void Tree::randomizedDropDownOobSamples(size_t varID)
  {
    for (size_t oob_idx = 0; oob_idx < num_samples_oob; ++oob_idx)
    {
      prediction_terminal_nodeIDs[oob_idx] = randomizedDropDownSample(varID, oob_sampleIDs[oob_idx]);
    }
  }

  // Draw the in-bag samples of this tree uniformly with replacement.
  // Fills sampleIDs with the draws, oob_sampleIDs with every sample that was
  // never drawn, and sets num_samples_oob. inbag_counts is kept only if
  // keep_inbag is set.
  void Tree::bootstrap()
  {
    // Number of draws: the first sample fraction times the number of samples
    const auto inbag_fraction = (*sample_fraction)[0];
    const size_t num_samples_inbag = static_cast<size_t>(num_samples) * inbag_fraction;

    // Reserve space; for the OOB samples reserve a little more to be safe
    sampleIDs.reserve(num_samples_inbag);
    oob_sampleIDs.reserve(num_samples * (exp(-inbag_fraction) + 0.1));

    // Start with all samples OOB
    inbag_counts.resize(num_samples, 0);

    // Draw num_samples_inbag samples with replacement and count how often each is drawn
    std::uniform_int_distribution<size_t> unif_dist(0, num_samples - 1);
    for (size_t drawn = 0; drawn < num_samples_inbag; ++drawn)
    {
      const size_t draw = unif_dist(random_number_generator);
      sampleIDs.push_back(draw);
      ++inbag_counts[draw];
    }

    // Samples that were never drawn are OOB
    for (size_t id = 0; id < inbag_counts.size(); ++id)
    {
      if (inbag_counts[id] == 0)
      {
        oob_sampleIDs.push_back(id);
      }
    }
    num_samples_oob = oob_sampleIDs.size();

    if (keep_inbag)
    {
      return;
    }

    // In-bag counts not requested: release their memory
    inbag_counts.clear();
    inbag_counts.shrink_to_fit();
  }

  // Draw the in-bag samples of this tree with replacement, using case_weights
  // as sampling weights.
  // Fills sampleIDs with the draws and oob_sampleIDs with the OOB samples (in
  // holdout mode the cases with weight 0, otherwise the samples never drawn),
  // and sets num_samples_oob. inbag_counts is kept only if keep_inbag is set.
  void Tree::bootstrapWeighted()
  {
    // Number of draws: the first sample fraction times the number of samples
    const auto inbag_fraction = (*sample_fraction)[0];
    const size_t num_samples_inbag = static_cast<size_t>(num_samples) * inbag_fraction;

    // Reserve space; for the OOB samples reserve a little more to be safe
    sampleIDs.reserve(num_samples_inbag);
    oob_sampleIDs.reserve(num_samples * (exp(-inbag_fraction) + 0.1));

    // Start with all samples OOB
    inbag_counts.resize(num_samples, 0);

    // Draw num_samples_inbag samples with replacement and count how often each is drawn
    std::discrete_distribution<> weighted_dist(case_weights->begin(), case_weights->end());
    for (size_t drawn = 0; drawn < num_samples_inbag; ++drawn)
    {
      const size_t draw = weighted_dist(random_number_generator);
      sampleIDs.push_back(draw);
      ++inbag_counts[draw];
    }

    // Save OOB samples
    if (holdout)
    {
      // Holdout mode: the cases with 0 weight are OOB
      for (size_t id = 0; id < (*case_weights).size(); ++id)
      {
        if ((*case_weights)[id] == 0)
        {
          oob_sampleIDs.push_back(id);
        }
      }
    }
    else
    {
      // Otherwise: the samples that were never drawn are OOB
      for (size_t id = 0; id < inbag_counts.size(); ++id)
      {
        if (inbag_counts[id] == 0)
        {
          oob_sampleIDs.push_back(id);
        }
      }
    }
    num_samples_oob = oob_sampleIDs.size();

    if (keep_inbag)
    {
      return;
    }

    // In-bag counts not requested: release their memory
    inbag_counts.clear();
    inbag_counts.shrink_to_fit();
  }

  // Select the in-bag samples of this tree without replacement by handing
  // sampleIDs and oob_sampleIDs to shuffleAndSplit(), then set num_samples_oob.
  // If keep_inbag is set, inbag_counts is filled with 1 for every sample except
  // the OOB samples, which get 0.
  void Tree::bootstrapWithoutReplacement()
  {
    // Number of in-bag samples: the first sample fraction times the number of samples
    const size_t num_samples_inbag = static_cast<size_t>(num_samples) * (*sample_fraction)[0];

    shuffleAndSplit(sampleIDs, oob_sampleIDs, num_samples, num_samples_inbag, random_number_generator);
    num_samples_oob = oob_sampleIDs.size();

    if (!keep_inbag)
    {
      return;
    }

    // All observations are 0 or 1 times inbag
    inbag_counts.resize(num_samples, 1);
    for (const auto &oob_id : oob_sampleIDs)
    {
      inbag_counts[oob_id] = 0;
    }
  }

  // Select the in-bag samples of this tree without replacement via
  // drawWithoutReplacementWeighted(), passing case_weights as weights.
  // Fills oob_sampleIDs (in holdout mode the cases with weight 0, otherwise the
  // samples that are not in-bag) and sets num_samples_oob. inbag_counts is kept
  // only if keep_inbag is set.
  void Tree::bootstrapWithoutReplacementWeighted()
  {
    // Number of in-bag samples: the first sample fraction times the number of samples
    const size_t num_samples_inbag = static_cast<size_t>(num_samples) * (*sample_fraction)[0];

    drawWithoutReplacementWeighted(sampleIDs, random_number_generator, num_samples - 1, num_samples_inbag, *case_weights);

    // All observations are 0 or 1 times inbag
    inbag_counts.resize(num_samples, 0);
    for (const auto &inbag_id : sampleIDs)
    {
      inbag_counts[inbag_id] = 1;
    }

    // Save OOB samples
    if (holdout)
    {
      // Holdout mode: the cases with 0 weight are OOB
      for (size_t id = 0; id < (*case_weights).size(); ++id)
      {
        if ((*case_weights)[id] == 0)
        {
          oob_sampleIDs.push_back(id);
        }
      }
    }
    else
    {
      // Otherwise: the samples that are not in-bag are OOB
      for (size_t id = 0; id < inbag_counts.size(); ++id)
      {
        if (inbag_counts[id] == 0)
        {
          oob_sampleIDs.push_back(id);
        }
      }
    }
    num_samples_oob = oob_sampleIDs.size();

    if (keep_inbag)
    {
      return;
    }

    // In-bag counts not requested: release their memory
    inbag_counts.clear();
    inbag_counts.shrink_to_fit();
  }

  // Class-wise bootstrap sampling. This implementation does nothing.
  void Tree::bootstrapClassWise()
  {
    // Empty on purpose (virtual function only implemented in classification and probability)
  }

  // Class-wise sampling without replacement. This implementation does nothing.
  void Tree::bootstrapWithoutReplacementClassWise()
  {
    // Empty on purpose (virtual function only implemented in classification and probability)
  }

  // Build the in-bag and OOB sample lists from the manual_inbag vector: sample i
  // is added to sampleIDs (*manual_inbag)[i] times, or to oob_sampleIDs if that
  // count is not positive. Afterwards sampleIDs is shuffled. inbag_counts is
  // kept only if keep_inbag is set.
  void Tree::setManualInbag()
  {
    const size_t num_entries = manual_inbag->size();

    sampleIDs.reserve(num_entries);
    inbag_counts.resize(num_samples, 0);

    // Select observations as specified in manual_inbag vector
    for (size_t i = 0; i < num_entries; ++i)
    {
      if (!((*manual_inbag)[i] > 0))
      {
        // Not in-bag at all: sample is OOB
        oob_sampleIDs.push_back(i);
        continue;
      }

      // Add the sample as many times as requested
      const size_t inbag_count = (*manual_inbag)[i];
      sampleIDs.insert(sampleIDs.end(), inbag_count, i);
      inbag_counts[i] = inbag_count;
    }
    num_samples_oob = oob_sampleIDs.size();

    // Shuffle samples
    std::shuffle(sampleIDs.begin(), sampleIDs.end(), random_number_generator);

    if (keep_inbag)
    {
      return;
    }

    // In-bag counts not requested: release their memory
    inbag_counts.clear();
    inbag_counts.shrink_to_fit();
  }

} // namespace unityForest
