### FUNCTION FOR DISCONTINUOUS HAMILTONIAN MONTE CARLO
#' Discontinuous Hamiltonian Monte Carlo using both manual and automatic termination criteria.
#'
#' @description The function allows generating multiple Markov Chains for sampling from both continuous and discontinuous
#' posterior distributions using a variety of algorithms. Classic Hamiltonian Monte Carlo \insertCite{duane1987hybrid}{XDNUTS}, 
#' NUTS \insertCite{hoffman2014no}{XDNUTS}, and XHMC \insertCite{betancourt2016identifying}{XDNUTS} are embedded into the framework
#' described in \insertCite{nishimura2020discontinuous}{XDNUTS}, which allows dealing with such posteriors.
#' Furthermore, for each method, it is possible to recycle samples from the trajectories using
#' the method proposed by \insertCite{Nishimura_2020}{XDNUTS}.
#' This recycling is used to improve the estimate of the Mass Matrix during the warm-up phase
#' without a significant additional computational cost.
#' 
#' @param theta0 a list containing the starting values for each chain. These starting values are vectors of length \eqn{d}. 
#' The last \eqn{k \in [0,d]} elements refer to parameters which determine a discontinuity in the posterior distribution. 
#' @param nlp a function which evaluates the negative log posterior and its gradient with respect to 
#' parameters that do not induce any discontinuity in the posterior distribution (more generally, the first \eqn{d-k} parameters).
#' This function must take 3 arguments:
#' \describe{
#' \item{par}{a vector of length-\eqn{d} containing the parameter values.}
#' \item{args}{a list object that contains the necessary arguments, namely data and hyperparameters.}
#' \item{eval_nlp}{a boolean value, \code{TRUE} to evaluate only the negative log posterior of the models, 
#' \code{FALSE} to evaluate its gradient with respect to the continuous components of the posterior.}
#' }
#' 
#' @param args a list containing the inputs for the negative posterior function.
#' @param k an integer value that states the number of parameters that determine a discontinuity in the posterior distribution.
#' Actually, since the algorithm proposed in \insertCite{nishimura2020discontinuous}{XDNUTS} also works for the full continuous case,
#' \code{k} is the number of parameters specified by the user for which this algorithm is used.
#' @param tau the threshold for the virial termination criterion \insertCite{betancourt2016identifying}{XDNUTS}.
#' @param L the desired length of the trajectory of classic Hamiltonian Monte Carlo algorithm.
#' @param N the number of draws from the posterior distribution, after warm-up, for each chain. Default value is \code{1000}.
#' @param K the number of recycled samples per iteration. By default recycling is used only during the warm-up phase.
#' Default value is \code{3}. To recycle in the sampling phase too, specify \code{recycle_only_init = FALSE}
#' in the \code{control} argument.
#' @param method a character value which defines the type of algorithm to exploit:\describe{
#' \item{\code{"NUTS"}}{applies the No U-Turn Sampler of \insertCite{hoffman2014no}{XDNUTS}.}
#' \item{\code{"XHMC"}}{applies the Exhaustion Hamiltonian Monte Carlo of \insertCite{betancourt2016identifying}{XDNUTS}.}
#' \item{\code{"HMC"}}{applies one of the classic version of Hamiltonian Monte Carlo algorithm,
#' in particular the one described in \insertCite{betancourt2017conceptual}{XDNUTS}, which samples from the trajectory instead of always returning the last value.}
#' }
#' @param thin the number of necessary and discarded samples to obtain a final iteration of one chain.
#' @param control an object of class \code{control_xdnuts}, output of the function \link{set_parameters}.
#' @param parallel a boolean value specifying whether the chains must be run in parallel. Default value is \code{FALSE}.
#' @param verbose a boolean value for printing all the information regarding the sampling process.
#' @param hide a boolean value that omits the printing to the console if set to \code{TRUE}.
#'
#' @return a list of class \code{XDNUTS} containing \describe{
#' \item{chains}{a list of the same length of \code{theta0}, each element containing the output from the function \link{main_function}.}
#' \item{d}{the dimension of the parameter space.}
#' \item{k}{the number of parameters that lead to a discontinuous posterior distribution. 
#' Or, more generally, for which the algorithm of \insertCite{nishimura2020discontinuous}{XDNUTS} is exploited.}
#' \item{K}{the number of recycled samples for each iteration during the sampling phase.}
#' \item{N}{the number of posterior draws for each chain.}
#' \item{method}{the MCMC method used. This could be either "NUTS", "XHMC", or "HMC".}
#' \item{tau}{the threshold for the virial termination criterion \insertCite{betancourt2016identifying}{XDNUTS}. 
#' Only if \code{method = "XHMC"} this value is different from zero.}
#' \item{L}{the desired length of the trajectory of classic Hamiltonian Monte Carlo algorithm specified by the user.
#' This argument is necessary if \code{method = "HMC"}.}
#' \item{thin}{the number of discarded samples for every final iteration, specified by the user.}
#' \item{control}{an object of class \code{control_xdnuts}, output of the function \link{set_parameters} with arguments specified by the user.}
#' \item{verbose}{the boolean value specified by the user regarding the printing of the sampling process information.}
#' \item{parallel}{the boolean value specified by the user regarding parallel processing.}
#' }
#' 
#' 
#' @references 
#'  \insertAllCited{}
#'
#' @export xdnuts
xdnuts <- function(theta0,
                  nlp,
                  args,
                  k,
                  N = 1e3,
                  K = 3,
                  method = "NUTS",
                  tau = NULL,
                  L = NULL,
                  thin = 1,
                  control = set_parameters(),
                  parallel = FALSE,
                  verbose = FALSE,
                  hide = FALSE){
  
  #small validators shared by the checks below:
  #TRUE only for a single, non-missing, strictly positive number
  is_positive_scalar <- function(x){
    is.numeric(x) && length(x) == 1L && !is.na(x) && x > 0
  }
  #TRUE only for a single, non-missing logical value
  is_flag <- function(x){
    is.logical(x) && length(x) == 1L && !is.na(x)
  }
  
  #let's make sure that the first input is a non-empty list
  if(!is.list(theta0) || length(theta0) == 0L){
    base::stop("'theta0' must be a list containing the initial value for each chain!")
  }
  
  #let's make sure that nlp is a function. This must be verified before
  #nlp is called on the starting values, otherwise the user gets an
  #unrelated 'could not find function' error
  if(!is.function(nlp)){
    base::stop("'nlp' must be a function object!")
  }
  
  #let's make sure that args is a list
  if(!is.list(args)){
    base::stop("'args' must be a list object!")
  }
  
  #dimension of the parameter space, taken from the first starting value
  d <- length(theta0[[1]])
  
  #let's make sure that k is a scalar between 0 and d
  if(!is.numeric(k) || length(k) != 1L || is.na(k) || k < 0 || k > d){
    base::stop("'k' must be a scalar, bounded  between 0 and the total number of parameters!")
  }
  
  #let's make sure that the initial value of the chains are admissible.
  #eval_nlp is documented as a boolean: TRUE asks for the negative log
  #posterior, FALSE for its gradient, hence proper logicals are passed
  for(i in seq_along(theta0)){
    nlp_value <- nlp(theta0[[i]],args,TRUE)
    admissible <- length(nlp_value) == 1L && is.finite(nlp_value)
    
    #the gradient only concerns the first d-k parameters, so it is
    #queried only when at least one of them exists
    if(admissible && k < d){
      admissible <- all(is.finite(nlp(theta0[[i]],args,FALSE)))
    }
    
    if(!admissible){
      base::stop("Not an admissible starting value for chain ", i)
    }
  }
  
  #let's make sure that the sample size is adequate
  if(!is_positive_scalar(N)){
    base::stop("'N' must be an integer scalar greater than zero!")
  }
  
  #let's make sure that the number of recycled samples is adequate
  if(!is_positive_scalar(K)){
    base::stop("'K' must be an integer scalar greater than zero!")
  }
  
  #let's make sure that the method specified is included among those available
  if(!is.character(method) || length(method) != 1L ||
     !(method %in% c("NUTS","XHMC","HMC"))){
    base::stop("'method' must be either 'NUTS', 'XHMC' or 'HMC'!")
  }
  
  #once the method is specified set to zero the unnecessary arguments
  #and make sure that the necessary one are appropriate
  if(method == "NUTS"){
    tau <- 0
    L <- 0
  }else if(method == "XHMC"){
    L <- 0
    if(is.null(tau)){
      base::stop("'tau' must be specified!")
    }
    if(!is_positive_scalar(tau)){
      base::stop("'tau' must be a scalar greater than zero!")
    }
  }else{
    #method == "HMC"
    tau <- 0
    if(is.null(L)){
      base::stop("'L' must be specified!")
    }
    if(!is_positive_scalar(L)){
      base::stop("'L' must be a scalar integer greater than zero!")
    }
  }
  
  #let's make sure that the number of sample to discard is appropriate
  if(!is_positive_scalar(thin)){
    base::stop("'thin' must be a scalar integer greater than zero!")
  }
  
  #let's make sure that the control arguments is of the type control_xdnuts
  #and not a simple list
  if(!inherits(control, "control_xdnuts")){
    base::stop("'control' must be an object of class control_xdnuts!")
  }
  
  #let's make sure that the parallel argument is logical
  if(!is_flag(parallel)){
    base::stop("'parallel' must be a logical scalar!")
  }
  
  #let's make sure that the verbose argument is logical
  if(!is_flag(verbose)){
    base::stop("'verbose' must be a logical scalar!")
  }
  
  #let's make sure that the hide argument is logical
  if(!is_flag(hide)){
    base::stop("'hide' must be a logical scalar!")
  }
  
  #get the number of chains from the length of the list
  n_chains <- length(theta0)
  
  #get the name of the coordinate from the first element of the list
  nomi <- names(theta0[[1]])
  if(is.null(nomi)){
    #if no name is specified, set it as 'theta' by default
    nomi <- base::paste0("theta",seq_len(d))
  }
  
  #function that runs a single chain. Every input is passed explicitly so
  #that the very same function is used sequentially and on the workers
  run_chain <- function(i,theta0,nlp,args,k,N,K,tau,L,thin,verbose,control){
    main_function(theta0 = theta0[[i]],
                  nlp = nlp,
                  args = args,
                  k = k,
                  N = N,
                  K = K,
                  tau = tau,
                  L = L,
                  thin = thin,
                  chain_id = i,
                  verbose = verbose,
                  control = control)
  }
  
  #function that gives the parameter names to the draws of one chain.
  #The same rule is applied to sequential and parallel output: the warm up
  #draws are named only if they were requested and are actually present
  #as a two dimensional object
  name_columns <- function(chain){
    base::colnames(chain$values) <- nomi
    if(isTRUE(as.logical(control$keep_warm_up)[1]) &&
       length(dim(chain$warm_up)) == 2L){
      base::colnames(chain$warm_up) <- nomi
    }
    chain
  }
  
  #MCMC
  if(!parallel){
    
    chains <- vector("list", n_chains)
    
    #if the user doesn't want anything printed on console
    #use the function quietly to silent the sampler
    runner <- if(hide) purrr::quietly(run_chain) else run_chain
    
    #let's cycle for every chain
    for(i in seq_len(n_chains)){
      out_i <- runner(i,
                      theta0 = theta0,
                      nlp = nlp,
                      args = args,
                      k = k,
                      N = N,
                      K = K,
                      tau = tau,
                      L = L,
                      thin = thin,
                      verbose = verbose,
                      control = control)
      
      #quietly wraps the output, the sampler result is in 'result'
      chains[[i]] <- if(hide) out_i$result else out_i
      
      if(!hide){
        cat("\n")
      }
    }
    
  }else{
    #parallel chains
    
    #cluster initialization
    if(hide){
      #no output are printed to console
      cl <- parallel::makeCluster(n_chains)
    }else{
      #the output are printed to console
      cl <- parallel::makeCluster(n_chains, outfile = "")
    }
    
    #run the chains; the cluster is always stopped, and a failure is
    #reported together with the error raised by the workers
    chains <- base::tryCatch(
      parallel::parLapply(cl,
                          seq_len(n_chains),
                          run_chain,
                          theta0 = theta0,
                          nlp = nlp,
                          args = args,
                          k = k,
                          N = N,
                          K = K,
                          tau = tau,
                          L = L,
                          thin = thin,
                          verbose = verbose,
                          control = control),
      error = function(e){
        base::stop("Something went wrong in processing the sampling in parallel. Try parallel = FALSE.",
                   " Original error: ", base::conditionMessage(e), call. = FALSE)
      },
      finally = parallel::stopCluster(cl))
  }
  
  #let's assign each coordinate it's name, for every chain
  chains <- base::lapply(chains, name_columns)
  
  #let's create the output list and make it an S3 object by assigning it a class
  mcmc_out <- list(chains = chains,
                   d = d,
                   k = k,
                   K = K,
                   N = N,
                   method = method,
                   tau = tau,
                   L = L,
                   thin = thin,
                   control = control,
                   verbose = verbose,
                   parallel = parallel)
  class(mcmc_out) <- "XDNUTS"
  
  #count the number of divergent transitions encountered during Hamilton equation
  #approximation
  n_div <- sum(base::vapply(mcmc_out$chains,
                            function(x) as.numeric(base::NROW(x$div_trans)),
                            numeric(1)))
  if(n_div != 0){
    #report to the user the number of this
    base::warning(n_div, " trajectory ended with a divergent transition!")
  }
  
  #let's make sure that the K field of the output list is the one relative
  #to the sampling phase and not the warm up one
  if(control$recycle_only_init){
    mcmc_out$K <- 1
  }
  
  #return the output
  return(mcmc_out)
}

### PLOTS FUNCTION OF THE MCMC OUTPUT
#' Function to view the draws from the posterior distribution.
#'
#' @param x an object of class \code{XDNUTS}.
#' @param type the type of plot to display. \describe{
#' \item{\code{type = 1}}{marginal chains, one for each desired dimension.}
#' \item{\code{type = 2}}{bivariate plot.}
#' \item{\code{type = 3}}{time series plot of the energy level sets. Good for a quick diagnostic of big models.}
#' \item{\code{type = 4}}{stickplot of the step-length of each iteration.}
#' \item{\code{type = 5}}{Histograms of the centered marginal energy in gray and of the first differences of energy in red.}
#' \item{\code{type = 6}}{Autocorrelation function plot of the parameters.}
#' \item{\code{type = 7}}{Matplot of the empirical acceptance rate and refraction rates.}}
#'
#' @param which either a numerical vector indicating the index of the parameters of interest or a string \describe{
#' \item{\code{which = 'continuous'}}{for plotting the first \eqn{d-k} parameters.}
#' \item{\code{which = 'discontinuous'}}{for plotting the last \eqn{k} parameters.}
#' }
#' where both \eqn{d} and \eqn{k} are elements contained in the output of the \link{xdnuts} function.
#' 
#'
#' @param warm_up a boolean value indicating whether the plot should be made using the warm-up samples.
#' @param cex.titles a numerical value for regulating the size of the plot's titles, default value is \code{1}.
#' @param cex.legend a numerical value for regulating the size of the plot's legend, default value is \code{1}.
#' @param plot.new a boolean value indicating whether a new graphical window should be opened. This is advised if the parameters space is big.
#'
#' @param which_chains a numerical vector indicating the index of the chains of interest.
#' @param colors a numerical vector of the same length as \code{which_chains} containing the colors for each chain.
#' @param ... additional arguments to customize plots.
#' 
#' @return No return value.
#'
#' @export plot.XDNUTS
#' @export
plot.XDNUTS <- function(x,type = 1,which = NULL,warm_up = FALSE,
                       cex.titles = 1, cex.legend = 1, plot.new = FALSE,
                       which_chains = NULL,colors = NULL,...){
  
  #make sure that the requested plot is one of the available ones
  if(length(type) != 1L || is.na(type) || !(type %in% 1:7)){
    base::stop("'type' must be an integer between 1 and 7!")
  }
  
  #get the number of chains
  nc <- length(x$chains)
  
  #initialize and make sure that the index of the chain to use is admissible
  if(is.null(which_chains)){
    which_chains <- seq_len(nc)
  }else{
    if(any(which_chains > nc | which_chains < 1)){
      base::stop("Incorrect chain indexes!")
    }
    which_chains <- base::unique(which_chains)
  }
  
  #colors: 'colori' is always indexed by the absolute chain index
  if(is.null(colors) || length(colors) == 0L){
    colori <- grDevices::adjustcolor(seq_len(nc), alpha.f = 0.5)
  }else if(length(colors) == length(which_chains)){
    #one color per selected chain, in the order of 'which_chains'
    colori <- rep(NA, nc)
    colori[which_chains] <- colors
  }else{
    #otherwise the colors are recycled over all the chains
    colori <- base::rep_len(colors, nc)
  }
  
  #which parameters do we want to see the graph of?
  #make sure it the input is admissible
  n_par <- base::NCOL(x$chains[[1]]$values)
  if(is.null(which)){
    which <- seq_len(n_par)
  }else if(is.character(which)){
    if(all(which == "continuous")){
      #first d-k parameters
      which <- seq_len(x$d - x$k)
    }else if(all(which == "discontinuous")){
      #last k parameters. A logical mask is used because a negative index
      #of length zero (k == d) would select nothing instead of everything
      which <- seq_len(x$d)[seq_len(x$d) > x$d - x$k]
    }else{
      base::stop("Incorrect index of parameters!")
    }
  }else if(any(which < 1 | which > n_par)){
    base::stop("Incorrect index of parameters!")
  }
  
  #update the dimension of the parameter space to be plotted
  d <- length(which)
  
  #plots 1, 2 and 6 are made parameter by parameter
  if(d == 0L && type %in% c(1,2,6)){
    base::stop("No parameters to plot!")
  }
  
  #do we want to see the warm up or the sampling?
  #make sure that the input is admissible
  if(warm_up == TRUE){
    quale <- "warm_up"
    if(is.null(x$chains[[1]][[quale]])){
      base::stop("No warm-up phase available!")
    }
  }else{
    quale <- "values"
  }
  
  #plot 7 compares the chains, hence it needs more than one of them
  if(type == 7 && length(which_chains) == 1){
    stop("alphas plots has no sense with only one chain!")
  }
  
  #do we want a new graphics window? It is opened before saving the
  #graphical parameters so that they are saved from, and restored to,
  #the device which is actually drawn on
  if(plot.new) {
    grDevices::X11()
  }
  
  #save current graphic window appearence in order to reset them at the end
  op <- graphics::par(no.readonly = TRUE)
  on.exit(graphics::par(op), add = TRUE)
  
  #number of rows and columns of the most compact grid with n_panels cells
  grid_dims <- function(n_panels){
    n_col <- ceiling(sqrt(n_panels))
    n_row <- if(n_panels <= n_col*(n_col-1)) n_col - 1 else n_col
    c(n_row,n_col)
  }
  
  #empty panel containing only the legend of the selected chains
  legend_panel <- function(){
    graphics::par(mar = c(1.1,1.1,1.1,1.1))
    base::plot(1, type = "n", axes = FALSE, ylab = "", xlab = "", bty = "n")
    graphics::legend("top", title = "Chain", legend = which_chains, col = colori[which_chains],
                     bty = "n", lty = 1, lwd = 2, cex = cex.legend, title.cex = cex.legend,bg = "transparent")
  }
  
  if(type == 1){
    #plot1: marginal chains
    
    #get MCMC iteration
    tt <- seq_len(base::NROW(x$chains[[1]][[quale]]))
    
    #let's partition the graphics window accordingly,
    #the last column is reserved to the legend
    dims <- grid_dims(d)
    k2 <- dims[1]
    k1 <- dims[2]
    graphics::layout(mat = base::cbind(base::matrix(seq_len(k1*k2),k2,k1, byrow = TRUE),k1*k2+1))
    
    #set the margins
    graphics::par(mar = c(1.1,3.1,3.1,1.1))
    
    #get iteration range
    xlim <- range(tt)
    
    #make a plot for every parameter
    for(i in which){
      
      #compute the range of this parameter chains
      ylim <- range(base::sapply(which_chains, function(idx) range(x$chains[[idx]][[quale]][,i])))
      
      #empty plot
      base::plot(NULL,xlim = xlim,ylim = ylim, xlab = "", ylab = "",
                 main = base::colnames(x$chains[[1]][[quale]])[i], cex.main = cex.titles)
      
      #add each chain
      for(j in which_chains){
        graphics::lines(tt,x$chains[[j]][[quale]][,i], col = colori[j])
      }
      
      #add a dashed line in zero
      graphics::abline(h = 0, col = 2, lty = 2)
    }
    
    #empty plots due to a non perfect graphical window partition
    for(i in seq_len(k2*k1 - d)){
      graphics::plot.new()
    }
    
    #add legend on the right window
    legend_panel()
    
  }else if(type == 2){
    #plot2: marginal and bivariate densities
    
    #let's partition the graphics window accordingly
    graphics::par(mfrow = c(d,d), mar = c(0.6,0.6,0.6,0.6))
    for(i in seq_len(d)){
      for(j in seq_len(d)){
        if(i > j) {
          #makes empty plots below the diagonal block
          graphics::plot.new()
        }else if(i == j){
          #plots marginal densities for each chain on the diagonal block
          
          #get posterior densities for each chain of this parameter
          dens_out <- base::lapply(which_chains,function(idx){
            base::do.call( base::cbind, stats::density(x$chains[[idx]][[quale]][,which[i]])[c("x","y")])
          })
          
          #gets the x and y limits
          xlim <- range(base::sapply(dens_out,function(dd) range(dd[,1])))
          ylim <- range(base::sapply(dens_out,function(dd) range(dd[,2])))
          
          #empty plot
          base::plot(NULL,xlim = xlim,ylim = ylim,xlab = "",ylab = "",
                     main = base::colnames(x$chains[[1]][[quale]])[which[i]],
                     cex.main = cex.titles)
          
          #overlap the density of each chain
          for(ii in seq_along(dens_out)){
            graphics::lines(dens_out[[ii]], col = colori[which_chains[[ii]]])
          }
          
          #add a vertical dashed line in zero
          graphics::abline(v = 0, col = 2, lty = 2)
        }else{
          #on the upper triangle section plot the bivariate density
          #of the draws pooled over the selected chains
          pooled <- base::do.call(base::rbind,
                                  base::lapply(which_chains,
                                               function(idx) x$chains[[idx]][[quale]][,which[c(j,i)]]))
          graphics::smoothScatter(pooled, main = "")
          #add vertical and horizontal dashed line in zero
          graphics::abline(h = 0, col = 2, lty = 2)
          graphics::abline(v = 0, col = 2, lty = 2)
        }
      }
    }
  }else if(type == 3){
    #plot3: energy Markov chain
    
    #gets each chain energy, one column per selected chain
    energie <- base::sapply(which_chains,function(idx) x$chains[[idx]]$energy)
    
    #gets x and y limit
    xlim <- base::NROW(energie)
    ylim <- range(energie)
    
    #partition the graphic window in order to add the legend on the right
    graphics::layout(mat = base::matrix(1:2,1,2), widths = c(0.6,0.4))
    graphics::par(mar = c(5.1,4.1,2.1,1.1))
    
    #empty plot
    base::plot(NULL,xlim = c(1,xlim), ylim = ylim, xlab = "Iteration", ylab = "Energy", main = "",bty = "L")
    
    #add each energy chain plot
    for(i in seq_along(which_chains)){
      graphics::lines(seq_len(xlim),energie[,i], col = colori[which_chains[i]])
    }
    
    #add the legend on the right
    legend_panel()
    
  }else if(type == 4){
    #plot4: stick plot of trajectory length
    
    #gets trajectory length frequency for each chain:
    #first column the observed lengths, second column their counts
    vals <- base::lapply(which_chains,function(idx) {
      tmp <- base::table(x$chains[[idx]][["step_length"]])
      base::cbind(as.numeric(names(tmp)),tmp)
    })
    
    #compute x and y limits
    x_range <- range(c(1,base::sapply(vals,function(v) max(v[,1]))))
    y_range <- c(0,max(base::sapply(vals,function(v) max(v[,2]))))
    
    #empty plot
    base::plot(NULL,xlim = x_range, ylim = y_range,  xlab = "L",
               ylab = "Frequency",main = "Iteration's Step-Length", bty = "L",
               cex.main = cex.titles)
    
    #add each chain stick plot, slightly shifted to avoid overlapping
    for(i in seq_along(vals)){
      for(j in seq_len(NROW(vals[[i]]))){
        graphics::lines(rep(vals[[i]][j,1],2)+0.1*(i-1) , c(0,vals[[i]][j,2]), col = colori[which_chains[i]], lwd = 3)
      }
    }
    
    #add legend
    graphics::legend("topright",bty = "n", bg = "transparent", title = "Chain",
                     legend = which_chains, col = colori[which_chains], lty = 1, lwd = 2, cex = cex.legend)
    
  }else if(type == 5){
    
    #plot5: marginal and first difference energy histogram
    
    #get energy chains
    E <- base::sapply(which_chains,function(i) x$chains[[i]]$energy)
    
    #get first difference energy chains
    delta_E <- base::sapply(which_chains,function(i) x$chains[[i]]$delta_energy)
    
    #get sample size
    N <- prod(dim(E))
    
    #compute the number of bins
    n <- min(100,N / 33)
    
    #compute marginal energy frequencies
    g1 <- graphics::hist(c(E),nclass = n, plot = FALSE)
    
    #center them on their mode
    g1$mids <- g1$mids - g1$mids[base::which.max(g1$density)[1]]
    
    #compute first difference energy frequencies
    g2 <- graphics::hist(c(delta_E),nclass = n, plot = FALSE)
    
    #center them on their mode
    g2$mids <- g2$mids - g2$mids[base::which.max(g2$density)[1]]
    
    #compute x and y limits
    xlim <- range(g1$mids,g2$mids)
    ylim <- c(0,max(g1$density,g2$density))
    
    #compute an adeguate span between each stick
    span <- min(0.5,base::diff(xlim)/n/2)
    
    #colors of the two histograms, shared by sticks and legend
    hist_cols <- grDevices::adjustcolor(c("darkgray","darkred"),0.7)
    
    #empty plot
    base::plot(NULL,xlim = xlim,ylim = ylim, xlab = "Centered Energy", ylab = "Frequency", bty = "L")
    
    #add marginal histogram, one stick for each bin
    for(ii in seq_along(g1$mids)){
      graphics::lines(rep(g1$mids[ii],2),c(0,g1$density[ii]), lwd = 2, col = hist_cols[1])
    }
    
    #add first differences histogram, one stick for each bin
    for(ii in seq_along(g2$mids)){
      graphics::lines(rep(g2$mids[ii],2) + span,c(0,g2$density[ii]), lwd = 2, col = hist_cols[2])
    }
    
    #add legend
    graphics::legend("topright",bty = "n", bg = "transparent", title = "",
                     legend = c("E",expression(paste(Delta,"E"))), 
                     col = hist_cols, lty = 1, lwd = 2, cex = cex.legend)
  }else if(type == 6){
    #plot6: autocorrelation plots
    
    #get lag maximum number, it can't exceed the number of available
    #iterations minus one (the same cap applied internally by stats::acf)
    n_iter <- base::NROW(x$chains[[which_chains[1]]][[quale]])
    lag_max <- min(floor(10 * log10(x$N)), n_iter - 1)
    if(lag_max < 1){
      base::stop("Not enough iterations to compute the autocorrelation!")
    }
    
    #compute autocorrelation for each chain and parameter specified:
    #one matrix per parameter, lags on the rows and chains on the columns
    acfs <- base::lapply(which,function(i){
      per_chain <- base::vapply(x$chains[which_chains], function(XX)
        stats::acf(XX[[quale]][,i], plot = FALSE,lag.max = lag_max)$acf[-1,,1],
        numeric(lag_max))
      base::matrix(per_chain, nrow = lag_max)
    })
    
    #compute for each parameter the y limits
    ylim <- base::vapply(acfs,range,numeric(2))
    
    #let's partition the graphics window accordingly: a title row on top
    #and a legend column on the right
    dims <- grid_dims(d)
    k2 <- dims[1]
    k1 <- dims[2]
    graphics::layout(mat = rbind(1,cbind(matrix(1 + seq_len(k1*k2),k2,k1, byrow = TRUE),k1*k2+2)),
                     heights = c(max(0.25,1-1/k2),rep(1,k2)))
    
    #add title
    graphics::par(mar = c(1.1,1.1,1.1,1.1))
    graphics::plot.new()
    graphics::text(0.5,0.5,"Autocorrelation Plots",cex=cex.titles,font=2)
    
    graphics::par(mar = c(1.1,2.1,3.1,1.1))
    
    #add each parameter plot
    for(i in seq_along(which)){
      #empty plot
      base::plot(NULL,xlim = c(1,lag_max),
                 ylim = range(0,ylim[,i],1, finite = TRUE),
                 main = base::colnames(x$chains[[1]]$values)[which[i]],
                 xlab = "", ylab = "")
      #autocorrelation of each chain
      for(j in seq_along(which_chains)){
        graphics::lines(seq_len(lag_max), acfs[[i]][,j], col = colori[which_chains[j]])
      }
    }
    
    #remaining empty plot
    for(i in seq_len(k2*k1 - d)){
      graphics::plot.new()
    }
    
    #add legend
    legend_panel()
    
  }else{
    #plot7: matplot of alphas
    
    #one row per selected chain, one column per empirical rate
    alphas <- base::do.call(base::rbind,
                            base::unname(base::lapply(x$chains,function(ch) ch[["alpha"]])))
    alphas <- alphas[which_chains, , drop = FALSE]
    
    #the first rate is drawn in solid black, the remaining in dashed red
    n_others <- max(base::NCOL(alphas) - 1, 0)
    
    graphics::par(mar = c(4.1,4.1,5.1,2.1), xpd = TRUE)
    graphics::matplot(alphas, 
                      type = "l", col = grDevices::adjustcolor(c(1,rep(2,n_others)),0.6),
                      lty = c(1,rep(2,n_others)),
                      ylim = c(0,1), xlab = "Chain", ylab = "Empirical rates",
                      bty = "L")
    graphics::legend("topright", title = "", legend = c("global","refraction"),
                     col = 1:2, lty = 1:2, lwd = 2, bty = "n", bg = "transparent", inset = c(0,-0.2),
                     cex = cex.legend)
  }
  
  #the function is called for its side effects only
  invisible(NULL)
}

### SUMMARY FUNCTION OF THE MCMC OUTPUT 
#' Function to print the summary of an XDNUTS model.
#'
#' @param object an object of class \code{XDNUTS}.
#' @param ... additional arguments to customize the summary.
#' @param q.val desired quantiles of the posterior distribution for each coordinate.
#' Default values are \code{0.05,0.25,0.5,0.75,0.95}.
#'
#' @param which either a numerical vector indicating the index of the parameters of interest or a string \describe{
#' \item{\code{which = 'continuous'}}{for summarizing the first \eqn{d-k} parameters.}
#' \item{\code{which = 'discontinuous'}}{for summarizing the last \eqn{k} parameters.}
#' }
#' where both \eqn{d} and \eqn{k} are elements contained in the output of the function \link{xdnuts}.
#' @param which_chains a numerical vector indicating the index of the chains of interest.
#' @param digits number of digits in the summary table.
#'
#' @return a list containing a data frame named \code{stats} with the following columns: \describe{\item{mean}{the mean of the posterior distribution.}
#' \item{sd}{the standard deviation of the posterior distribution.}
#' \item{q.val}{the desired quantiles of the posterior distribution.}
#' \item{ESS}{the Effective Sample Size for each marginal distribution.}
#' \item{R_hat}{the Potential Scale Reduction Factor of Gelman \insertCite{gelman1992inference}{XDNUTS}, only if multiple chains are available.}
#' \item{R_hat_upper_CI}{the upper confidence interval for the latter, only if multiple chains are available.}
#' }
#' Other quantities returned are:\describe{
#' \item{Gelman.Test}{the value of the multivariate Potential Scale Reduction Factor test \insertCite{gelman1992inference}{XDNUTS}.}
#' \item{BFMI}{the value of the empirical Bayesian Fraction of Missing Information \insertCite{betancourt2016diagnosing}{XDNUTS}. 
#' A value below 0.2 indicates a bad random walk behavior in the energy Markov Chain, mostly due to a suboptimal
#' specification of the momentum parameters probability density.}}
#'
#' @references 
#'  \insertAllCited{} 
#' 
#' @export summary.XDNUTS
#' @export
summary.XDNUTS <- function(object, ..., q.val = c(0.05,0.25,0.5,0.75,0.95),
                           which = NULL, which_chains = NULL, digits = 5){
  # Prints a table of posterior summaries (mean, sd, quantiles, ESS and, when
  # available, R_hat) for the selected parameters and chains, followed by the
  # multivariate Gelman statistic, the estimated BFMI and warnings about
  # divergent / truncated trajectories. Invisibly returns
  # list(stats = <data frame>, Gelman.Test = <mpsrf or NULL>, BFMI = <scalar>).
  
  #get the initial number of chains
  nc <- length(object$chains)
  
  #get the indexes of desired chains and make sure they are admissible
  if(is.null(which_chains)){
    which_chains <- seq_len(nc)
  }else{
    if(any(which_chains > nc | which_chains < 1)){
      base::stop("Incorrect chain indexes!")
    }
    which_chains <- base::unique(which_chains)
  }
  
  #number of parameters stored in the chains
  n_par <- base::NCOL(object$chains[[1]]$values)
  
  #which parameters do we want to see the summary of?
  #make sure the input is admissible
  if(is.null(which)){
    which <- seq_len(n_par)
  }else if(is.character(which)){
    if(all(which == "continuous")){
      #first d-k parameters
      which <- seq_len(object$d - object$k)
    }else if(all(which == "discontinuous")){
      #last k parameters. Built by offsetting instead of negative indexing:
      #seq_len(d)[-seq_len(0)] would select nothing when k == d
      which <- (object$d - object$k) + seq_len(object$k)
    }else{
      base::stop("Incorrect index of parameters!")
    }
  }else if(any(which < 1 | which > n_par)){
    base::stop("Incorrect index of parameters!")
  }
  
  #transform the original output into an mcmc.list object of the coda package
  res <- base::lapply(which_chains, function(i) 
    coda::mcmc(object$chains[[i]]$values[,which,drop = FALSE]))
  res <- coda::as.mcmc.list(res)
  
  #compute potential scale reduction factor; if coda cannot compute it
  #the statistics are set to NULL and simply left out of the summary
  gelman.test <- base::tryCatch(coda::gelman.diag(res), 
                                error = function(x) list(psrf = NULL, mpsrf = NULL))
  
  #compute posterior statistics on the pooled samples of the selected chains
  out <- t(base::apply(base::do.call(base::rbind,res),2,function(x) 
    c(base::mean(x),stats::sd(x),stats::quantile(x,q.val))))
  #cbind() silently ignores gelman.test$psrf when it is NULL
  out <- base::cbind(out,coda::effectiveSize(res),gelman.test$psrf)
  out <- base::as.data.frame(out)
  
  #if there is more than 1 chain, add the Rhat statistics
  stat_names <- c("mean","sd",base::paste0(q.val*100,"%"),"ESS")
  if(length(which_chains) > 1 && !is.null(gelman.test$psrf)){
    stat_names <- c(stat_names,"R_hat","R_hat_upper_CI")
  }
  base::colnames(out) <- stat_names
  
  #print to console the table
  base::print(round(out,digits = digits))
  if(!is.null(gelman.test$mpsrf)){
    #if available, add the multivariate test statistic
    base::cat("\nMultivariate Gelman Test: ",round(gelman.test$mpsrf, digits = digits))
  }
  
  #compute the empirical bayesian fraction of missing information.
  #The energy diagnostics below are computed on ALL the chains of the object,
  #regardless of 'which_chains'.
  
  #get energy (unlist/lapply also works when chains have different lengths)
  E <- base::unlist(base::lapply(object$chains,function(x) x$energy))
  
  #get first difference energy
  delta_E <- base::unlist(base::lapply(object$chains,function(x) x$delta_energy))
  
  #estimate it
  #NOTE(review): the usual E-BFMI is Var(first differences) / Var(energy);
  #this ratio looks inverted unless 'energy'/'delta_energy' are stored with
  #swapped meaning. Left unchanged - confirm against the sampler output.
  BFMI <- stats::var(E) / stats::var(delta_E)
  
  #print it to console
  base::cat("\nEstimated Bayesian Fraction of Missing Information: ",round(BFMI, digits = digits))
  
  #build output list
  out <- list(stats = out, Gelman.Test = gelman.test$mpsrf,BFMI = BFMI)
  
  #if present, report the presence of divergent transitions
  n_div <- sum(base::vapply(object$chains,
                            function(x) base::NROW(x$div_trans), numeric(1)))
  if(n_div != 0){
    base::cat("\nWarning: ",n_div, " trajectory ended with a divergent transition!")
  }
  
  #compute and report the number of trajectories terminated prematurely,
  #i.e. whose length equals the maximum allowed by max_treedepth
  step_lengths <- base::unlist(base::lapply(object$chains,function(x) x$step_length))
  n_trunc <- sum(step_lengths == (2^object$control$max_treedepth - 1))
  if(n_trunc != 0){
    base::cat("\nWarning: ",n_trunc, " trajectory ended before reaching an effective termination!")
  }
  
  #return the list
  return(invisible(out))
  
}

### FUNCTION FOR CONVENIENT SAMPLE EXTRACTION
#' Function to extract samples from the output of an XDNUTS model.
#'
#' @param X an object of class \code{XDNUTS}.
#'
#' @param which either a numerical vector indicating the index of the parameters of interest or a string \describe{
#' \item{\code{which = 'continuous'}}{for extracting the first \eqn{d-k} parameters.}
#' \item{\code{which = 'discontinuous'}}{for extracting the last \eqn{k} parameters.}
#' }
#' where both \eqn{d} and \eqn{k} are elements contained in the output of the function \link{xdnuts}.
#' By default, all parameters are extracted.
#'
#' @param which_chains a vector of indices containing the chains to extract. By default, all chains are considered.
#'
#' @param collapse a boolean value. If TRUE, all samples from every chain are collapsed into one. The default value is FALSE.
#' 
#' @return if \code{collapse = FALSE}, an array with one row per stored iteration, one column per selected
#' parameter and one slice per selected chain (named \code{chain_<index>}); if \code{collapse = TRUE},
#' a matrix obtained by stacking the samples of the selected chains, one column per selected parameter.
#'
#' @export xdextract
xdextract <- function(X, which = NULL, which_chains = NULL, collapse = FALSE){
  
  #get the number of chains
  nc <- length(X$chains)
  
  #get chain indexes to extract and make sure they are admissible
  if(is.null(which_chains)){
    which_chains <- seq_len(nc)
  }else{
    if(any(which_chains > nc | which_chains < 1)){
      base::stop("Incorrect chain indexes!")
    }
    which_chains <- base::unique(which_chains)
  }
  
  #number of parameters stored in the chains
  n_par <- base::NCOL(X$chains[[1]]$values)
  
  #of which parameters do we want to extract samples?
  #make sure they are admissible
  if(is.null(which)){
    which <- seq_len(n_par)
  }else if(is.character(which)){
    if(all(which == "continuous")){
      #first d-k parameters
      which <- seq_len(X$d - X$k)
    }else if(all(which == "discontinuous")){
      #last k parameters. Built by offsetting instead of negative indexing:
      #seq_len(d)[-seq_len(0)] would select nothing when k == d
      which <- (X$d - X$k) + seq_len(X$k)
    }else{
      base::stop("Incorrect index of parameters!")
    }
  }else if(any(which < 1 | which > n_par)){
    #the upper bound is the number of columns of the stored samples
    base::stop("Incorrect index of parameters!")
  }
  
  #names of the selected parameters (NULL if the samples are unnamed)
  par_names <- base::colnames(X$chains[[1]]$values)[which]
  
  if(!collapse){
    #return an array
    
    #initialize the array
    out <- base::array(NA,dim = c(base::NROW(X$chains[[1]]$values),
                                  length(which),
                                  length(which_chains)))
    
    #fill it, one slice per selected chain
    for(j in seq_along(which_chains)){
      out[,,j] <- X$chains[[which_chains[j]]]$values[,which, drop = FALSE]
    }
    
    #assign proper dimension names
    dimnames(out) <- list(NULL,par_names,
                          base::paste0("chain_",which_chains))
  }else{
    #return a matrix
    
    #number of stored iterations of each selected chain: using the actual
    #row counts avoids silent recycling if they differ from X$N*X$K
    n_rows <- base::vapply(which_chains,
                           function(i) base::NROW(X$chains[[i]]$values),
                           numeric(1))
    
    #initialize the matrix
    out <- base::matrix(NA,sum(n_rows),length(which))
    
    #fill it, stacking the chains one below the other
    offset <- 0
    for(j in seq_along(which_chains)){
      out[offset + seq_len(n_rows[j]),] <- 
        X$chains[[which_chains[j]]]$values[,which, drop = FALSE]
      offset <- offset + n_rows[j]
    }
    
    #assign proper dimension names
    base::colnames(out) <- par_names
  }
  
  #return the array/matrix
  return(out)
}

### FUNCTION THAT APPLIES A TRANSFORMATION TO THE CHAINS
#' Function to apply a transformation to the samples from the output of an XDNUTS model.
#'
#' @param X an object of class \code{XDNUTS}.
#' @param which a vector of indices or of parameter names indicating which parameters the transformation
#'  should be applied to. If \code{NULL}, all parameters are selected.
#'  When the number of selected parameters equals the number of parameters in the chains, \code{FUN} is
#'  applied to each MCMC iteration (row) as a whole and may return a vector of different length;
#'  otherwise \code{FUN} is applied separately to each selected parameter (column).
#' @param FUN a function object which takes one or more components of an MCMC iteration and any other possible arguments.
#' @param ... optional arguments for FUN.
#' @param new.names a character vector containing the parameter names in the new parameterization.  
#'  If only one value is provided, but the transformation involves more, the name is iterated with an increasing index.
#'  If \code{NULL}, the transformed parameters are named \code{theta1}, \code{theta2}, and so on.
#'
#' @return an object of class \code{XDNUTS} with the specified transformation applied to each chain
#'  (both to the samples and, if present, to the warm-up samples).
#'
#' @export xdtransform
xdtransform <- function(X, which = NULL, FUN = NULL, ..., new.names = NULL){
  
  #a transformation is mandatory: fail early with a clear message
  if(is.null(FUN)){
    base::stop("Argument 'FUN' must be specified!")
  }
  
  #copy the original XDNUTS object
  out <- X
  
  #get old parameters names and their number
  old_names <- base::colnames(X$chains[[1]]$values)
  n_par <- base::NCOL(X$chains[[1]]$values)
  
  #ensure argument which is admissible
  if(is.null(which)){
    which <- seq_len(n_par)
  }else if(is.character(which)){
    #character vector case: translate the names into column indexes
    if(!all(which %in% old_names)){
      base::stop("Incorrect parameter names specified!")
    }
    which <- base::match(which, old_names)
  }else if(is.numeric(which)){
    #numeric vector case
    if(any(which < 1 | which > n_par)){
      base::stop("Incorrect index of parameters")
    }
  }else{
    base::stop("Wrong 'which' type!")
  }
  
  #do we update all the components? (decided on the number of selected
  #parameters only, their order is not taken into account)
  all_components <- length(which) == n_par
  
  #build the names for n transformed parameters
  new_labels <- function(n){
    if(is.null(new.names)){
      base::paste0("theta",seq_len(n))
    }else if(length(new.names) == 1 && n != 1){
      #a single name for several parameters: iterate it with an index
      base::paste0(new.names,seq_len(n))
    }else if(length(new.names) == n){
      new.names
    }else{
      base::stop("Incorrect number of values in 'new.names'!")
    }
  }
  
  #loop for every chain and apply the transformation to the samples and,
  #eventually, to the warm up phase
  for(cc in seq_along(X$chains)){
    for(slot in c("values","warm_up")){
      
      samples <- X$chains[[cc]][[slot]]
      if(is.null(samples)) next
      
      if(all_components){
        #transformation of the whole iteration
        res <- base::apply(samples , 1 , FUN , ... )
        
        if(is.list(res)){
          base::stop("'FUN' must return a vector of the same length for every iteration!")
        }
        
        if(is.null(dim(res))){
          #FUN returned a scalar: apply() gives a plain vector with one
          #value per iteration, which must become a one column matrix
          res <- base::matrix(res, ncol = 1)
        }else{
          #apply() stores each iteration in a column: put them back in rows
          res <- t(res)
        }
        
        #give names to the new parameters
        base::colnames(res) <- new_labels(base::NCOL(res))
        
      }else{
        #transformation of each selected component separately
        res <- out$chains[[cc]][[slot]]
        res[,which] <- 
          base::apply(samples[,which,drop = FALSE] , 2 , FUN , ... )
        
        #give names to the new parameters
        base::colnames(res)[which] <- new_labels(length(which))
      }
      
      out$chains[[cc]][[slot]] <- res
    }
  }
  
  #return the new XDNUTS object
  return(out)
}
