#' @title Model-average marginal posterior distributions
#'
#' @description Creates marginal model-averaged posterior distributions for a given
#' parameter based on model-averaged posterior samples and parameter name
#' (and formula with \code{at} specification).
#'
#' @param samples model-averaged posterior samples created by \code{mix_posteriors()}
#' @param parameter parameter of interest
#' @param formula model formula (needs to be specified if \code{parameter} was part of a formula)
#' @param at named list with predictor levels of the formula for which marginalization
#' should be performed. If a predictor level is missing, \code{0} is used for continuous
#' predictors, the baseline factor level is used for factors with \code{contrast = "treatment"} prior
#' distributions, and the parameter is completely omitted for factors with \code{contrast = "meandif"},
#' \code{contrast = "orthonormal"}, and \code{contrast = "independent"} levels
#' @param prior_samples whether marginal prior distributions should be generated
#' @param use_formula whether the parameter should be evaluated as a part of supplied formula
#' @param n_samples number of samples to be drawn for the model-averaged
#' prior distribution
#' @inheritParams density.prior
#'
#' @return \code{marginal_posterior} returns a named list of mixed marginal posterior
#' distributions (either a vector or matrix).
#'
#' @export
marginal_posterior <- function(samples, parameter, formula = NULL, at = NULL, prior_samples = FALSE, use_formula = TRUE,
                               transformation = NULL, transformation_arguments = NULL, transformation_settings = FALSE,
                               n_samples = 10000, ...){

  # Overview:
  #  - formula branch: the parameter is a term of a formula; a design matrix is built
  #    from the 'at' specification and evaluated on the (scaled) posterior samples,
  #    the result is split by the levels of the manipulated predictor(s)
  #  - simple branch: the samples of the parameter are returned directly
  #    (factor samples are first transformed to the per-level scale)
  # In both branches, marginal prior samples are attached as the "prior_samples"
  # attribute when 'prior_samples = TRUE'.

  ### check input
  check_list(samples, "samples")
  if(!inherits(samples, "mixed_posteriors"))
    stop("'samples' must be a be an object generated by 'mix_posteriors' function.")
  check_char(parameter, "parameter", allow_values = names(samples))
  if(!is.null(formula) && !is.language(formula))
    stop("'formula' must be a formula")
  if(!is.null(at) && !is.list(at))
    stop("'at' must be a list")
  check_bool(prior_samples, "prior_samples")
  check_bool(use_formula, "use_formula")
  .check_transformation_input(transformation, transformation_arguments, transformation_settings)


  # deal formula vs non-formula marginal posterior
  if(use_formula && inherits(samples[[parameter]], "mixed_posteriors.formula")){

      # remove the specified response (would crash the model.frame if not included)
      formula <- .remove_response(formula)
      formula_parameter <- attr(samples[[parameter]], "formula_parameter")

      ### extract the terms information from the formula
      formula_terms          <- stats::terms(formula)
      has_intercept          <- attr(formula_terms, "intercept") == 1
      # the first element of "variables" is the 'list' call itself -> drop it
      predictors             <- as.character(attr(formula_terms, "variables"))[-1]
      model_terms            <- c(if(has_intercept) "intercept", attr(formula_terms, "term.labels"))

      JAGS_model_terms <- JAGS_parameter_names(parameters = model_terms, formula_parameter = formula_parameter)
      JAGS_predictors  <- JAGS_parameter_names(parameters = predictors, formula_parameter = formula_parameter)


      ### obtain posterior samples and check that all are present
      if(!all(JAGS_model_terms %in% names(samples)))
        stop(paste0("The posterior samples for the ", paste0("'", JAGS_model_terms[!JAGS_model_terms %in% names(samples)], "'", collapse = ", ")," term is missing in the samples."))


      ### obtain prior list and information
      prior_list <- lapply(names(samples), function(model_term) attr(samples[[model_term]], "prior_list"))
      names(prior_list) <- names(samples)

      # get parameter information
      priors_info <- lapply(names(prior_list), function(model_term) list(
        term              = model_term,
        intercept         = model_term == "intercept",
        factor            = inherits(samples[[model_term]], "mixed_posteriors.factor"),
        levels            = attr(samples[[model_term]], "levels"),
        level_names       = attr(samples[[model_term]], "level_names"),
        interaction       = attr(samples[[model_term]], "interaction"),
        interaction_terms = attr(samples[[model_term]], "interaction_terms"),
        treatment         = attr(samples[[model_term]], "treatment"),
        independent       = attr(samples[[model_term]], "independent"),
        orthonormal       = attr(samples[[model_term]], "orthonormal"),
        meandif           = attr(samples[[model_term]], "meandif")
        ))
      names(priors_info) <- names(prior_list)
      model_terms_type <- sapply(JAGS_model_terms, function(model_term){
        if(priors_info[[model_term]][["factor"]]){
          return("factor")
        }else{
          return("continuous")
        }
      })
      predictors_type  <- model_terms_type[JAGS_parameter_names(parameters = predictors, formula_parameter = formula_parameter)]


      ### prepare at specification
      # in case of an interaction, all levels need to be set
      if(!is.null(priors_info[[parameter]][["interaction"]]) && priors_info[[parameter]][["interaction"]]){
        at_manipulated <- JAGS_parameter_names(priors_info[[parameter]][["interaction_terms"]], formula_parameter = formula_parameter)
      }else{
        at_manipulated <- parameter
      }

      if(!all(names(at) %in% predictors))
        stop(paste0("The following values passed via the 'at' argument do not correspond to the specified model: ", paste0("'", names(at)[!names(at) %in% predictors], "'", collapse = ", ")))
      if(any(format_parameter_names(at_manipulated, formula_parameters = formula_parameter, formula_prefix = FALSE) %in% names(at)))
        stop("Values of the parameter of interested cannot be specified via the 'at' argument.")

      # fill in with default values if needed
      for(i in seq_along(predictors)){
        if(JAGS_predictors[i] %in% at_manipulated){
          # specify levels for the parameter of interest
          if(model_terms_type[[JAGS_predictors[i]]] == "continuous"){
            # continuous predictors of interest are evaluated at -1, 0, and 1
            # (labelled with the "SD" suffix in the output names below)
            at[[predictors[i]]] <- c(-1, 0, 1)
          }else{
            at[[predictors[i]]] <- priors_info[[JAGS_predictors[i]]][["level_names"]]
          }
        }else if(is.null(at[[predictors[i]]])){
          # specify levels for the remaining parameters
          if(model_terms_type[[JAGS_predictors[i]]] == "continuous"){
            # fill in zeroes for unspecified continuous predictors
            at[[predictors[i]]] <- 0
          }else if(priors_info[[JAGS_predictors[i]]][["treatment"]]){
            # fill in the default category for unspecified treatment factors
            at[[predictors[i]]] <- priors_info[[JAGS_predictors[i]]][["level_names"]][1]
          }else{
            # fill in NA for any other factor type
            at[[predictors[i]]] <- NA
          }
        }
      }

      # transform to a data.frame (one row per combination of the 'at' levels)
      data <- as.data.frame(expand.grid(at))

      # check the specified data
      if(any(predictors_type == "factor")){

        # check the proper data input for each factor variable
        for(i in seq_along(predictors_type)[predictors_type == "factor"]){

          if(is.factor(data[,predictors[i]])){
            if(all(levels(data[,predictors[i]]) %in% priors_info[[JAGS_predictors[i]]][["level_names"]])){
              # either the formatting is correct, or the supplied levels are a subset of the original levels
              # reformat to check ordering and etc...
              data[,predictors[i]] <- factor(data[,predictors[i]], levels = priors_info[[JAGS_predictors[i]]][["level_names"]])
            }else{
              # there are some additional levels
              stop(paste0("Levels specified in the '", predictors[i], "' factor variable do not match the levels used for model specification."))
            }
          }else if(all(stats::na.omit(unique(data[,predictors[i]])) %in% priors_info[[JAGS_predictors[i]]][["level_names"]])){
            # the variable was not passed as a factor but the values matches the factor levels
            data[,predictors[i]] <- factor(data[,predictors[i]], levels = priors_info[[JAGS_predictors[i]]][["level_names"]])
          }else{
            # there are some additional mismatching values
            stop(paste0("Levels specified in the '", predictors[i], "' factor variable do not match the levels used for model specification."))
          }

          # set the contrast
          if(priors_info[[JAGS_predictors[i]]][["orthonormal"]]){
            stats::contrasts(data[,predictors[i]]) <- "contr.orthonormal"
          }else if(priors_info[[JAGS_predictors[i]]][["meandif"]]){
            stats::contrasts(data[,predictors[i]]) <- "contr.meandif"
          }else if(priors_info[[JAGS_predictors[i]]][["independent"]]){
            stats::contrasts(data[,predictors[i]]) <- "contr.independent"
          }else if(priors_info[[JAGS_predictors[i]]][["treatment"]]){
            stats::contrasts(data[,predictors[i]]) <- "contr.treatment"
            if(anyNA(data[,predictors[i]]))
              stop("Unspecified levels in the '", predictors[i], "' factor (NAs not allowed for 'treatment' factors).")
          }
        }
      }
      if(any(predictors_type == "continuous")){

        # check the proper data input for each continuous variable
        for(i in seq_along(predictors_type)[predictors_type == "continuous"]){
          if(anyNA(data[,predictors[i]]))
            stop("Unspecified levels in the '", predictors[i], "' variable (NAs not allowed for continuous variables).")
          if(!is.numeric(data[,predictors[i]]))
            stop("Nonnumeric values in the '", predictors[i], "' continuous variable.")
        }
      }


      ### get the design matrix
      # 'na.action = NULL' keeps the rows with NA factor levels
      model_frame  <- stats::model.frame(formula, data = data, na.action = NULL)
      model_matrix <- stats::model.matrix(model_frame, data = model_frame, formula = formula)

      # replaces NAs by zero to omit the corresponding coefficients
      model_matrix[is.na(model_matrix)] <- 0


      ### prepare posterior samples and information
      for(i in seq_along(samples)){
        # de-name factor levels
        if(priors_info[[i]][["factor"]]){
          if(priors_info[[i]][["levels"]] == 1){
            colnames(samples[[i]]) <- priors_info[[i]][["term"]]
          }else{
            colnames(samples[[i]]) <- paste0(priors_info[[i]][["term"]], "[", 1:priors_info[[i]][["levels"]], "]")
          }
        }
      }
      posterior_samples_matrix <- do.call(cbind, samples)


      # obtain samples information
      models_ind <- do.call(cbind, lapply(c(if(has_intercept) "intercept", model_terms), function(x) attr(samples[[JAGS_parameter_names(x, formula_parameter = formula_parameter)]], "models_ind")))
      sample_ind <- do.call(cbind, lapply(c(if(has_intercept) "intercept", model_terms), function(x) attr(samples[[JAGS_parameter_names(x, formula_parameter = formula_parameter)]], "sample_ind")))
      if(!inherits(samples, "as_mixed_posteriors") && (!all(models_ind[,1] == models_ind) || !all(sample_ind[,1] == sample_ind)))
        stop("the posterior samples are not alligned across models/draws")
      models_ind <- models_ind[,1]


      ### evaluate the design matrix on the samples -> output[data, posterior]
      if(has_intercept){

        # shift the term indexes by one so that they index 'JAGS_model_terms'
        # (which starts with the intercept); the intercept column itself is set to 0
        terms_indexes    <- attr(model_matrix, "assign") + 1
        terms_indexes[1] <- 0

        # get model/sample indices and check for scaling factors
        temp_multiply_by <- .get_combined_parameter_scaling_factor_matrix(
          JAGS_parameter_names("intercept", formula_parameter = formula_parameter),
          prior_list  = prior_list,
          posterior   = posterior_samples_matrix,
          models_ind  = models_ind,
          nrow        = nrow(data),
          simple_list = inherits(samples, "as_mixed_posteriors")
        )


        # 'byrow = TRUE' repeats the vector of intercept samples in each data row
        marginal_posterior_samples <- temp_multiply_by * matrix(posterior_samples_matrix[,JAGS_parameter_names("intercept", formula_parameter = formula_parameter)],
                                                       nrow = nrow(data), ncol = nrow(posterior_samples_matrix), byrow = TRUE)

      }else{

        terms_indexes <- attr(model_matrix, "assign")
        marginal_posterior_samples <- matrix(0, nrow = nrow(data), ncol = nrow(posterior_samples_matrix))

      }

      # add remaining terms (omitting the intercept indexed as 0)
      for(i in unique(terms_indexes[terms_indexes > 0])){

        # subset the model matrix
        temp_data <- model_matrix[,terms_indexes == i,drop = FALSE]

        temp_posterior <- posterior_samples_matrix[,paste0(
          JAGS_model_terms[i],
          if(model_terms_type[i] == "factor" && priors_info[[JAGS_model_terms[i]]][["levels"]] > 1) paste0("[", 1:priors_info[[JAGS_model_terms[i]]][["levels"]], "]"))
          ,drop = FALSE]

        # check for scaling factors
        temp_multiply_by <- .get_combined_parameter_scaling_factor_matrix(
          JAGS_model_terms[i],
          prior_list  = prior_list,
          posterior   = posterior_samples_matrix,
          models_ind  = models_ind,
          nrow        = nrow(data),
          simple_list = inherits(samples, "as_mixed_posteriors")
        )

        # [data x coefficients] %*% [coefficients x draws] -> [data x draws]
        marginal_posterior_samples <- marginal_posterior_samples + temp_multiply_by * (temp_data %*% t(temp_posterior))

      }


      # apply transformations
      if(!is.null(transformation)){
        marginal_posterior_samples <- .density.prior_transformation_x(marginal_posterior_samples, transformation, transformation_arguments)
      }


      ### split the output into lists based on specification
      # create indexing and names for the manipulated predictors
      if(length(at_manipulated) == 1 && format_parameter_names(at_manipulated, formula_parameters = formula_parameter, formula_prefix = FALSE) == "intercept"){

        class(marginal_posterior_samples)             <- c(class(marginal_posterior_samples), "marginal_posterior.simple")
        attr(marginal_posterior_samples, "parameter") <- parameter
        attr(marginal_posterior_samples, "level")     <- "intercept"
        attr(marginal_posterior_samples, "data")      <- data

        marginal_posterior_samples <- list("intercept" = marginal_posterior_samples)

        attr(marginal_posterior_samples, "data")        <- data
        attr(marginal_posterior_samples, "level_at")    <- NULL
        attr(marginal_posterior_samples, "level_names") <- "intercept"
        attr(marginal_posterior_samples, "parameter")   <- parameter

      }else{

        manipulated_predictors <- format_parameter_names(at_manipulated, formula_parameters = formula_parameter, formula_prefix = FALSE)
        at_index_output        <- at[manipulated_predictors]
        at_index_output.names  <- at_index_output

        # rename continuous predictors levels
        for(i in seq_along(at_manipulated)){
          if(predictors_type[[at_manipulated[i]]] == "continuous"){
            at_index_output.names[[manipulated_predictors[i]]] <- paste0(at_index_output.names[[manipulated_predictors[i]]], "SD")
          }
        }

        at_index_output_frame       <- expand.grid(at_index_output)
        at_index_output.names_frame <- expand.grid(at_index_output.names)
        level_names                 <- apply(at_index_output.names_frame, 1, paste0, collapse = ", ")

        # split the output samples
        # (for each output level, flag the data rows matching all manipulated predictors)
        data_split <- lapply(seq_len(nrow(at_index_output_frame)), function(i){
          apply(do.call(cbind, lapply(colnames(at_index_output_frame), function(pred){
            data[, pred] == at_index_output_frame[i, pred]
          })), 1, all)
        })
        marginal_posterior_samples <- lapply(seq_along(data_split), function(lvl){
          temp_marginal_posterior_samples <- marginal_posterior_samples[data_split[[lvl]],]
          temp_data                       <- data[data_split[[lvl]],]
          class(temp_marginal_posterior_samples)             <- c(class(temp_marginal_posterior_samples), "marginal_posterior.simple")
          attr(temp_marginal_posterior_samples, "parameter") <- parameter
          attr(temp_marginal_posterior_samples, "level")     <- level_names[lvl]
          attr(temp_marginal_posterior_samples, "data")      <- temp_data
          return(temp_marginal_posterior_samples)
        })
        names(marginal_posterior_samples) <- level_names
        class(marginal_posterior_samples) <- c(class(marginal_posterior_samples), "marginal_posterior.factor")

        attr(marginal_posterior_samples, "data")        <- data
        attr(marginal_posterior_samples, "level_at")    <- at_index_output_frame
        attr(marginal_posterior_samples, "level_names") <- level_names
        attr(marginal_posterior_samples, "parameter")   <- parameter

      }


      # add priors
      if(prior_samples){

        ### generate prior samples matrix in the same format as are the posterior samples
        # (from here on, 'prior_samples' holds the list of prior samples instead of the flag)
        if(inherits(samples, "as_mixed_posteriors")){
          prior_samples <- .as_mixed_priors(prior_list = prior_list, n_samples = n_samples, conditional = attr(samples, "conditional", exact = TRUE), conditional_rule = attr(samples, "conditional_rule"))
        }else{
          prior_samples <- .mix_priors(prior_list = prior_list, n_samples = n_samples)
        }

        for(i in seq_along(prior_samples)){
          # de-name factor levels
          if(priors_info[[i]][["factor"]]){
            if(priors_info[[i]][["levels"]] == 1){
              colnames(prior_samples[[i]]) <- priors_info[[i]][["term"]]
            }else{
              colnames(prior_samples[[i]]) <- paste0(priors_info[[i]][["term"]], "[", 1:priors_info[[i]][["levels"]], "]")
            }
          }
        }
        prior_samples_matrix <- do.call(cbind, prior_samples)


        # obtain prior_samples information
        models_ind <- do.call(cbind, lapply(c(if(has_intercept) "intercept", model_terms), function(x) attr(prior_samples[[JAGS_parameter_names(x, formula_parameter = formula_parameter)]], "models_ind")))
        sample_ind <- do.call(cbind, lapply(c(if(has_intercept) "intercept", model_terms), function(x) attr(prior_samples[[JAGS_parameter_names(x, formula_parameter = formula_parameter)]], "sample_ind")))
        if(!inherits(samples, "as_mixed_posteriors") && (!all(models_ind[,1] == models_ind) || !all(sample_ind[,1] == sample_ind)))
          stop("the prior prior_samples are not alligned across models/draws")
        models_ind <- models_ind[,1]


        ### evaluate the design matrix on the prior_samples -> output[data, prior]
        if(has_intercept){

          terms_indexes    <- attr(model_matrix, "assign") + 1
          terms_indexes[1] <- 0

          # get model/sample indices and check for scaling factors
          temp_multiply_by  <- .get_combined_parameter_scaling_factor_matrix(
            JAGS_parameter_names("intercept", formula_parameter = formula_parameter),
            prior_list  = prior_list,
            posterior   = prior_samples_matrix,
            models_ind  = models_ind,
            nrow        = nrow(data),
            simple_list = inherits(samples, "as_mixed_posteriors")
          )

          marginal_prior_samples <- temp_multiply_by * matrix(prior_samples_matrix[,JAGS_parameter_names("intercept", formula_parameter = formula_parameter)],
                                                         nrow = nrow(data), ncol = nrow(prior_samples_matrix), byrow = TRUE)

        }else{

          terms_indexes     <- attr(model_matrix, "assign")
          marginal_prior_samples <- matrix(0, nrow = nrow(data), ncol = nrow(prior_samples_matrix))

        }

        # add remaining terms (omitting the intercept indexed as 0)
        for(i in unique(terms_indexes[terms_indexes > 0])){

          # subset the model matrix
          temp_data <- model_matrix[,terms_indexes == i,drop = FALSE]

          temp_prior <- prior_samples_matrix[,paste0(
            JAGS_model_terms[i],
            if(model_terms_type[i] == "factor" && priors_info[[JAGS_model_terms[i]]][["levels"]] > 1) paste0("[", 1:priors_info[[JAGS_model_terms[i]]][["levels"]], "]"))
            ,drop = FALSE]

          # check for scaling factors
          temp_multiply_by <- .get_combined_parameter_scaling_factor_matrix(
            JAGS_model_terms[i],
            prior_list   = prior_list,
            posterior    = prior_samples_matrix,
            models_ind   = models_ind,
            nrow         = nrow(data),
            simple_list  = inherits(samples, "as_mixed_posteriors")
          )

          marginal_prior_samples <- marginal_prior_samples + temp_multiply_by * (temp_data %*% t(temp_prior))

        }

        # apply transformations
        if(!is.null(transformation)){
          marginal_prior_samples <- .density.prior_transformation_x(marginal_prior_samples, transformation, transformation_arguments)
        }


        ### split the output into lists based on specification
        if(length(at_manipulated) == 1 && format_parameter_names(at_manipulated, formula_parameters = formula_parameter, formula_prefix = FALSE) == "intercept"){

          class(marginal_prior_samples)                   <- c(class(marginal_prior_samples), "marginal_posterior.simple")
          attr(marginal_prior_samples, "parameter")       <- parameter
          attr(marginal_prior_samples, "level")           <- "intercept"
          attr(marginal_prior_samples, "data")            <- data
          attr(marginal_prior_samples, "all_alternative") <- attr(prior_samples, "all_alternative")

          attr(marginal_posterior_samples[["intercept"]], "prior_samples") <- marginal_prior_samples

        }else{

          # reuses 'data_split' and 'level_names' created when splitting the posterior samples
          marginal_prior_samples <- lapply(seq_along(data_split), function(lvl){
            temp_marginal_prior_samples <- marginal_prior_samples[data_split[[lvl]],]
            temp_data                       <- data[data_split[[lvl]],]
            class(temp_marginal_prior_samples)                   <- c(class(temp_marginal_prior_samples), "marginal_posterior.simple")
            attr(temp_marginal_prior_samples, "parameter")       <- parameter
            attr(temp_marginal_prior_samples, "level")           <- level_names[lvl]
            attr(temp_marginal_prior_samples, "data")            <- temp_data
            attr(temp_marginal_prior_samples, "all_alternative") <- attr(prior_samples, "all_alternative")
            return(temp_marginal_prior_samples)
          })
          names(marginal_prior_samples) <- level_names

          for(lvl in level_names){
            attr(marginal_posterior_samples[[lvl]], "prior_samples") <- marginal_prior_samples[[lvl]]
          }
        }

      }

      attr(marginal_posterior_samples, "formula_parameter") <- formula_parameter
      class(marginal_posterior_samples) <- c(class(marginal_posterior_samples), "marginal_posterior.formula")

  }else{

    if(!is.null(formula))
      stop("'formula' is supposed to be NULL when dealing with simple posteriors")
    if(!is.null(at))
      stop("'at' is supposed to be NULL when dealing with simple posteriors")


    ### obtain prior list and information
    prior_list <- lapply(names(samples), function(model_term) attr(samples[[model_term]], "prior_list"))
    names(prior_list) <- names(samples)


    ### extract the corresponding samples
    if(inherits(samples[[parameter]], "mixed_posteriors.factor")){

      # transform factor levels
      marginal_posterior_samples <- transform_factor_samples(samples[parameter])
      marginal_posterior_samples <- transform_treatment_samples(marginal_posterior_samples)[[parameter]]

      # apply transformations
      if(!is.null(transformation)){
        marginal_posterior_samples <- .density.prior_transformation_x(marginal_posterior_samples, transformation, transformation_arguments)
      }

      # TODO: change once dealing with factors interactions is solved
      if(attr(marginal_posterior_samples, "interaction")){
        if(length(attr(marginal_posterior_samples, "level_names")) == 1){
          level_names <- attr(marginal_posterior_samples, "level_names")[[1]]
        }else{
          stop("de-transformation for interaction of multiple factors is not implemented.")
        }
      }else{
        level_names <- attr(marginal_posterior_samples, "level_names")
      }

      # create output object
      marginal_posterior_samples <- lapply(level_names, function(lvl){
        temp_marginal_posterior_samples <- marginal_posterior_samples[,level_names == lvl]
        class(temp_marginal_posterior_samples) <- c(class(temp_marginal_posterior_samples), "marginal_posterior.factor")
        attr(temp_marginal_posterior_samples, "parameter")  <- parameter
        attr(temp_marginal_posterior_samples, "level_name") <- lvl
        return(temp_marginal_posterior_samples)
      })
      names(marginal_posterior_samples) <- level_names
      attr(marginal_posterior_samples, "level_names") <- level_names
      class(marginal_posterior_samples) <- c(class(marginal_posterior_samples), "marginal_posterior.factor")

    }else if(inherits(samples[[parameter]], "mixed_posteriors.simple")){

      # extract the samples first -- the transformation below operates on them
      # (same order of operations as for the simple prior samples further down)
      marginal_posterior_samples <- samples[[parameter]]

      # apply transformations
      if(!is.null(transformation)){
        marginal_posterior_samples <- .density.prior_transformation_x(marginal_posterior_samples, transformation, transformation_arguments)
      }

      class(marginal_posterior_samples) <- c(class(marginal_posterior_samples), "marginal_posterior.simple")

    }else{

      # neither factor nor simple samples: nothing was extracted, fail with an informative message
      stop("Marginal posterior distributions can be created only for simple and factor posterior samples.")

    }


    # add prior samples
    if(prior_samples){

      # (from here on, 'prior_samples' holds the list of prior samples instead of the flag)
      if(inherits(samples, "as_mixed_posteriors")){
        prior_samples <- .as_mixed_priors(prior_list = prior_list, n_samples = n_samples, conditional = attr(samples, "conditional", exact = TRUE), conditional_rule = attr(samples, "conditional_rule"))
      }else{
        prior_samples <- .mix_priors(prior_list = prior_list, n_samples = n_samples)
      }
      marginal_prior_samples <- prior_samples[[parameter]]

      # transform if factors
      ### extract the corresponding samples
      if(inherits(prior_samples[[parameter]], "mixed_posteriors.factor")){

        # transform factor levels
        marginal_prior_samples <- transform_factor_samples(prior_samples[parameter])
        marginal_prior_samples <- transform_treatment_samples(marginal_prior_samples)[[parameter]]

        # apply transformations
        if(!is.null(transformation)){
          marginal_prior_samples <- .density.prior_transformation_x(marginal_prior_samples, transformation, transformation_arguments)
        }

        # create output object
        # ('level_names' were obtained from the posterior samples above)
        marginal_prior_samples <- lapply(level_names, function(lvl){
          temp_marginal_prior_samples <- marginal_prior_samples[,level_names == lvl]
          class(temp_marginal_prior_samples) <- c(class(temp_marginal_prior_samples), "marginal_posterior.factor")
          attr(temp_marginal_prior_samples, "parameter")  <- parameter
          attr(temp_marginal_prior_samples, "level_name") <- lvl
          return(temp_marginal_prior_samples)
        })
        names(marginal_prior_samples) <- level_names
        class(marginal_prior_samples) <- c(class(marginal_prior_samples), "marginal_posterior.factor")

        for(lvl in level_names){
          attr(marginal_posterior_samples[[lvl]], "prior_samples") <- marginal_prior_samples[[lvl]]
        }


      }else if(inherits(prior_samples[[parameter]], "mixed_posteriors.simple")){

        marginal_prior_samples <- prior_samples[[parameter]]

        # apply transformations
        if(!is.null(transformation)){
          marginal_prior_samples <- .density.prior_transformation_x(marginal_prior_samples, transformation, transformation_arguments)
        }

        class(marginal_prior_samples) <- c(class(marginal_prior_samples), "marginal_posterior.simple")
        attr(marginal_posterior_samples, "prior_samples") <- marginal_prior_samples

      }

    }
  }

  class(marginal_posterior_samples) <- c(class(marginal_posterior_samples), "marginal_posterior")
  return(marginal_posterior_samples)
}

# Assembles the matrix of scaling factors for a single term across all draws.
#
# term        : name of the term, forwarded to '.get_parameter_scaling_factor_matrix'
# prior_list  : named list of prior distributions; when 'simple_list = FALSE' each element
#               is itself a list that is indexed by the model index
# posterior   : matrix of samples with one row per draw
# models_ind  : vector assigning each row of 'posterior' to a model
#               (used only when 'simple_list = FALSE')
# nrow        : number of rows of the resulting matrix
# simple_list : whether 'prior_list' holds a single prior per parameter (TRUE)
#               or one prior per model for each parameter (FALSE)
#
# Returns the output of '.get_parameter_scaling_factor_matrix' directly (simple list),
# or the per-model outputs bound column-wise in the order in which the models first
# appear in 'models_ind'.
# NOTE(review): the column-wise binding matches the row order of 'posterior' only if the
# draws of each model are stored contiguously -- presumably guaranteed by the mixing
# functions; confirm against the callers.
.get_combined_parameter_scaling_factor_matrix <- function(term, prior_list, posterior, models_ind, nrow, simple_list = FALSE){

  # a single set of priors applies to all draws
  if(simple_list){
    return(.get_parameter_scaling_factor_matrix(term, prior_list, posterior, nrow = nrow, ncol = nrow(posterior)))
  }

  # the priors differ across models: build one block per model
  model_blocks <- lapply(unique(models_ind), function(m){
    in_model        <- models_ind == m
    temp_prior_list <- lapply(prior_list, function(parameter_priors) parameter_priors[[m]])
    temp_posterior  <- posterior[in_model,,drop=FALSE]
    return(.get_parameter_scaling_factor_matrix(term, temp_prior_list, temp_posterior, nrow = nrow, ncol = sum(in_model)))
  })

  return(do.call(cbind, model_blocks))
}

# Generates mixed prior samples for every parameter in 'prior_list'
# (adapted from 'mix_posteriors').
#
# prior_list : named list (one element per parameter) of lists of prior distributions
# seed       : seed for the random number generator; drawn at random when NULL
# n_samples  : number of samples, forwarded to the type-specific mixing functions
#
# Returns a named list with one element of mixed prior samples per parameter. Elements
# whose priors carry a "parameter" attribute additionally get the
# "mixed_posteriors.formula" class and the "formula_parameter" attribute.
.mix_priors                <- function(prior_list, seed = NULL, n_samples = 10000){

  ### check input
  check_list(prior_list, "prior_list")
  for(parameter_priors in prior_list){
    if(any(!sapply(parameter_priors, is.prior)))
      stop("'prior_list' must be a list of prior distributions")
  }
  check_real(seed, "seed", allow_NULL = TRUE)
  check_int(n_samples, "n_samples")

  ### get model indices
  # standardized prior weights of the individual parameters (one column per parameter)
  weights_matrix <- do.call(cbind, lapply(unname(prior_list), function(parameter_priors){
    raw_weights <- sapply(parameter_priors, function(prior) prior[["prior_weights"]])
    raw_weights / sum(raw_weights)
  }))
  # all parameters must share the same prior weights across the models
  if(!all(weights_matrix[,1] == weights_matrix))
    stop("the prior samples are not alligned across models/draws")
  prior_weights <- weights_matrix[,1]

  # set seed only once at the beginning -- not in the individual draws as the priors will end up completely correlated
  set.seed(if(is.null(seed)) sample(.Machine$integer.max, 1) else seed)

  # default priors used in place of missing (NULL) priors
  default_none  <- function(weight) prior_none(prior_weights = weight)
  default_spike <- function(weight) prior("spike", parameters = list("location" = 0), prior_weights = weight)

  out <- list()

  for(temp_parameter in names(prior_list)){

    # prepare parameter specific values
    temp_priors <- prior_list[[temp_parameter]]
    is_type     <- function(predicate) sapply(temp_priors, predicate)

    # select the default prior and the mixing function based on the type of the priors
    if(any(is_type(is.prior.weightfunction)) && all(is_type(is.prior.weightfunction) | is_type(is.prior.point) | is_type(is.prior.none) | is_type(is.null))){
      # weightfunctions (missing priors are replaced with prior: none)
      make_default <- default_none
      mix_function <- .mix_priors.weightfunction
    }else if(any(is_type(is.prior.factor)) && all(is_type(is.prior.factor) | is_type(is.prior.point) | is_type(is.null))){
      # factor priors (missing priors are replaced with spike(0))
      make_default <- default_spike
      mix_function <- .mix_priors.factor
    }else if(any(is_type(is.prior.vector)) && all(is_type(is.prior.vector) | is_type(is.prior.point) | is_type(is.null))){
      # vector priors (missing priors are replaced with spike(0))
      make_default <- default_spike
      mix_function <- .mix_priors.vector
    }else if(all(is_type(is.prior.simple) | is_type(is.prior.point) | is_type(is.null))){
      # simple priors (missing priors are replaced with spike(0))
      make_default <- default_spike
      mix_function <- .mix_priors.simple
    }else{
      stop("The posterior samples cannot be mixed: unsupported mixture of prior distributions.")
    }

    # replace missing priors with the default prior
    for(i in seq_along(temp_priors)){
      if(is.null(temp_priors[[i]])){
        temp_priors[[i]] <- make_default(prior_weights[i])
      }
    }

    # no seed is passed on -- it was already set above
    out[[temp_parameter]] <- mix_function(temp_priors, temp_parameter, NULL, n_samples)

    # add formula relevant information
    formula_parameter <- unique(unlist(lapply(temp_priors, attr, which = "parameter")))
    if(!is.null(formula_parameter)){
      class(out[[temp_parameter]]) <- c(class(out[[temp_parameter]]), "mixed_posteriors.formula")
      attr(out[[temp_parameter]], "formula_parameter")  <- formula_parameter
    }
  }

  return(out)
}
# Draws a model-averaged prior sample for a scalar parameter.
#
# Each prior contributes a number of draws proportional to its normalized
# prior weight (rounded up); the draws are stacked prior by prior, in the
# order of 'priors', and truncated to 'n_samples'.
#
# priors    - list of simple (or point) prior distributions
# parameter - parameter name, stored in the "parameter" attribute
# seed      - optional seed handed to set.seed() before sampling
# n_samples - number of samples kept in the output
#
# Returns the unnamed draws with class 'mixed_posteriors' /
# 'mixed_posteriors.simple' and the attributes "sample_ind" (index of the draw
# within its prior), "models_ind" (index of the prior the draw came from),
# "parameter", and "prior_list".
.mix_priors.simple         <- function(priors, parameter, seed = NULL, n_samples = 10000){

  # validate the arguments
  check_list(priors, "priors")
  check_char(parameter, "parameter")
  check_real(seed, "seed", allow_NULL = TRUE)
  check_int(n_samples, "n_samples")
  if(!all(sapply(priors, is.prior.simple) | sapply(priors, is.prior.point)))
    stop("'priors' must be a list of simple priors")

  # normalized prior model probabilities
  weights <- sapply(priors, function(prior) prior[["prior_weights"]])
  weights <- weights / sum(weights)

  # seeding is optional: re-using one seed for several parameters of the same
  # model would make their draws completely correlated
  if(!is.null(seed)){
    set.seed(seed)
  }

  # draws allotted to each prior (rounded up) and the priors receiving any
  n_draws <- ceiling(weights * n_samples)
  active  <- seq_along(priors)[n_draws >= 1]

  # within-prior sample indexes, one sequence per active prior
  ind_list <- lapply(active, function(i) 1:n_draws[i])

  # sample the active priors one after another (the order fixes how the
  # random number stream is consumed)
  draw_list <- lapply(seq_along(active), function(j){
    rng(priors[[active[j]]], length(ind_list[[j]]), transform_factor_samples = FALSE)
  })

  # stack the per-prior pieces
  samples    <- do.call(c, draw_list)
  sample_ind <- do.call(c, ind_list)
  models_ind <- rep(active, times = lengths(ind_list))

  # rounding up may produce surplus draws -- keep the first 'n_samples'
  keep       <- 1:n_samples
  samples    <- unname(samples[keep])
  sample_ind <- sample_ind[keep]
  models_ind <- models_ind[keep]

  attr(samples, "sample_ind") <- sample_ind
  attr(samples, "models_ind") <- models_ind
  attr(samples, "parameter")  <- parameter
  attr(samples, "prior_list") <- priors
  class(samples) <- c("mixed_posteriors", "mixed_posteriors.simple")

  samples
}
# Draws a model-averaged prior sample for a vector (multivariate) parameter.
#
# Each prior contributes a number of rows proportional to its normalized
# prior weight (rounded up); the rows are stacked prior by prior, in the
# order of 'priors', and truncated to 'n_samples'.
#
# priors    - list of vector priors (simple point priors are allowed and are
#             expanded to constant rows at their location)
# parameter - parameter name, used for the column names "parameter[k]"
# seed      - optional seed handed to set.seed() before sampling
# n_samples - number of rows kept in the output
#
# Returns an n_samples x K matrix of class 'mixed_posteriors' /
# 'mixed_posteriors.vector' with the attributes "sample_ind", "models_ind",
# "parameter", and "prior_list". Stops if the vector priors disagree on K.
.mix_priors.vector         <- function(priors, parameter, seed = NULL, n_samples = 10000){

  # check input
  check_list(priors, "priors")
  check_char(parameter, "parameter")
  check_real(seed, "seed", allow_NULL = TRUE)
  check_int(n_samples, "n_samples")
  if(!all(sapply(priors, is.prior.vector) | sapply(priors, is.prior.point)))
    stop("'priors' must be a list of vector priors")

  # get prior model probabilities
  prior_probs <- sapply(priors, function(prior) prior[["prior_weights"]])
  prior_probs <- prior_probs / sum(prior_probs)

  # do not set seed when sampling multiple priors for the same model -- they will end up completely correlated
  if(!is.null(seed)){
    set.seed(seed)
  }

  # prepare output objects
  K <- unique(sapply(priors[sapply(priors, is.prior.vector)], function(p) p$parameters[["K"]]))
  if(length(K) != 1)
    stop("all vector priors must be of the same length")

  samples    <- matrix(nrow = 0, ncol = K)
  sample_ind <- NULL
  models_ind <- NULL

  # number of draws allotted to each prior (rounded up)
  n_draws <- ceiling(prior_probs * n_samples)

  # mix samples
  # (priors allotted a single draw must be included as well, the same way as in
  #  '.mix_priors.simple' -- skipping them can leave fewer than 'n_samples' rows
  #  and break the truncation below)
  for(i in seq_along(priors)[n_draws >= 1]){

    # sample indexes
    temp_ind <- seq_len(n_draws[i])

    if(is.prior.point(priors[[i]]) && is.prior.simple(priors[[i]])){
      # not sampling the priors in case they were imputed (missing dimensions)
      samples <- rbind(samples, matrix(priors[[i]]$parameters[["location"]], nrow = length(temp_ind), ncol = K))
    }else if(K == 1){
      # a single dimension is forced into a one-column matrix
      samples <- rbind(samples, matrix(rng(priors[[i]], length(temp_ind), transform_factor_samples = FALSE), nrow = length(temp_ind), ncol = K))
    }else{
      samples <- rbind(samples, rng(priors[[i]], length(temp_ind), transform_factor_samples = FALSE))
    }

    sample_ind <- c(sample_ind, temp_ind)
    models_ind <- c(models_ind, rep(i, length(temp_ind)))
  }

  # assure the correct number of samples (rounding up may produce surplus rows)
  samples    <- samples[1:n_samples,,drop=FALSE]
  sample_ind <- sample_ind[1:n_samples]
  models_ind <- models_ind[1:n_samples]

  rownames(samples) <- NULL
  colnames(samples) <- paste0(parameter,"[",1:K,"]")
  attr(samples, "sample_ind") <- sample_ind
  attr(samples, "models_ind") <- models_ind
  attr(samples, "parameter")  <- parameter
  attr(samples, "prior_list") <- priors
  class(samples) <- c("mixed_posteriors", "mixed_posteriors.vector")

  return(samples)
}
# Draws a model-averaged prior sample for a factor parameter.
#
# The factor priors must share the number of levels and the contrast
# specification. Treatment and independent contrasts are sampled level by
# level via '.mix_priors.simple'; orthonormal and meandif contrasts are
# sampled jointly via '.mix_priors.vector'.
#
# priors    - list of factor priors (point priors are allowed)
# parameter - parameter name, used for the column names
# seed      - optional seed forwarded to the underlying samplers
# n_samples - number of rows kept in the output
#
# Returns a matrix of class 'mixed_posteriors' / 'mixed_posteriors.factor' /
# 'mixed_posteriors.vector' carrying the sampling attributes ("sample_ind",
# "models_ind", "parameter", "prior_list") and the factor description
# ("levels", "level_names", "interaction", "treatment", "independent",
# "orthonormal", "meandif").
.mix_priors.factor         <- function(priors, parameter, seed = NULL, n_samples = 10000){

  # check input
  check_list(priors, "priors")
  check_char(parameter, "parameter")
  check_real(seed, "seed", allow_NULL = TRUE)
  check_int(n_samples, "n_samples")
  if(!all(sapply(priors, is.prior.factor) | sapply(priors, is.prior.point)))
    stop("'priors' must be a list of factor priors")

  # check the prior levels
  levels <- unique(sapply(priors[sapply(priors, is.prior.factor)], .get_prior_factor_levels))
  if(length(levels) != 1)
    stop("all factor priors must be of the same number of levels")

  # gather and check compatibility of prior distributions
  # (point / none priors carry no factor specification and are flagged with FALSE)
  priors_info <- lapply(priors, function(p){
    if(is.prior.point(p) || is.prior.none(p)){
      return(FALSE)
    }else if(is.prior.factor(p)){
      return(list(
        "levels"      = .get_prior_factor_levels(p),
        "level_names" = .get_prior_factor_level_names(p),
        "interaction" = .is_prior_interaction(p),
        "treatment"   = is.prior.treatment(p),
        "independent" = is.prior.independent(p),
        "orthonormal" = is.prior.orthonormal(p),
        "meandif"     = is.prior.meandif(p)
      ))
    }else{
      stop("unsupported prior type")
    }
  })
  priors_info <- priors_info[!sapply(priors_info, isFALSE)]

  # 'all.equal' returns a character description (not FALSE) on a mismatch,
  # it must be wrapped in 'isTRUE' before it can be negated
  info_matches <- vapply(priors_info, function(info) isTRUE(all.equal(info, priors_info[[1]])), logical(1))
  if(length(priors_info) >= 2 && !all(info_matches)){
    stop("non-matching prior factor type specifications")
  }
  priors_info <- priors_info[[1]]

  if(priors_info[["treatment"]] || priors_info[["independent"]]){

    # each level is sampled separately from the mixture of priors
    if(levels == 1){

      samples <- .mix_priors.simple(priors, parameter, seed, n_samples)

      sample_ind <- attr(samples, "sample_ind")
      models_ind <- attr(samples, "models_ind")

      samples <- matrix(samples, ncol = 1)

    }else{

      samples <- lapply(1:levels, function(i) .mix_priors.simple(priors, paste0(parameter, "[", i, "]"), seed, n_samples))

      # the indexes of the first level are stored for the whole matrix
      sample_ind <- attr(samples[[1]], "sample_ind")
      models_ind <- attr(samples[[1]], "models_ind")

      samples <- do.call(cbind, samples)

    }

    # treatment contrasts have no column for the first level name
    if(priors_info[["treatment"]]){
      level_labels <- priors_info[["level_names"]][-1]
    }else{
      level_labels <- priors_info[["level_names"]]
    }

    rownames(samples) <- NULL
    colnames(samples) <- paste0(parameter,"[",level_labels,"]")
    attr(samples, "sample_ind") <- sample_ind
    attr(samples, "models_ind") <- models_ind
    attr(samples, "parameter")  <- parameter
    attr(samples, "prior_list") <- priors
    class(samples) <- c("mixed_posteriors", "mixed_posteriors.factor", "mixed_posteriors.vector")

  }else if(priors_info[["orthonormal"]] || priors_info[["meandif"]]){

    # the vector sampler reads the dimensionality from the 'K' parameter
    for(i in seq_along(priors)){
      if(is.prior.factor(priors[[i]])){
        priors[[i]]$parameters[["K"]] <- levels
      }
    }

    samples <- .mix_priors.vector(priors, parameter, seed, n_samples)
    class(samples) <- c(class(samples), "mixed_posteriors.factor")

  }else{

    # none of the known contrast types applies
    stop("unsupported prior type")

  }

  attr(samples, "levels")      <- priors_info[["levels"]]
  attr(samples, "level_names") <- priors_info[["level_names"]]
  attr(samples, "interaction") <- priors_info[["interaction"]]
  attr(samples, "treatment")   <- priors_info[["treatment"]]
  attr(samples, "independent") <- priors_info[["independent"]]
  attr(samples, "orthonormal") <- priors_info[["orthonormal"]]
  attr(samples, "meandif")     <- priors_info[["meandif"]]

  return(samples)
}
# Draws a model-averaged prior sample for weight function (omega) parameters.
#
# The cut points of all weight functions are combined via
# 'weightfunctions_mapping' and the draws of each prior are mapped onto the
# combined intervals. Priors of type 'none' contribute rows of ones.
#
# priors    - list of weight function priors (point and none priors allowed)
# parameter - parameter name, stored in the "parameter" attribute
# seed      - optional seed handed to set.seed() before sampling
# n_samples - number of rows kept in the output
#
# Returns a matrix with one column per interval ("omega[lower,upper]") of
# class 'mixed_posteriors' / 'mixed_posteriors.weightfunction' with the
# attributes "sample_ind", "models_ind", "parameter", and "prior_list".
.mix_priors.weightfunction <- function(priors, parameter, seed = NULL, n_samples = 10000){

  # check input
  check_list(priors, "priors")
  check_char(parameter, "parameter")
  check_real(seed, "seed", allow_NULL = TRUE)
  check_int(n_samples, "n_samples")
  if(!all(sapply(priors, is.prior.weightfunction) | sapply(priors, is.prior.point) | sapply(priors, is.prior.none)))
    stop("'priors' must be a list of weightfunction priors distributions")

  # get prior model probabilities
  prior_probs <- sapply(priors, function(prior) prior[["prior_weights"]])
  prior_probs <- prior_probs / sum(prior_probs)

  # do not set seed when sampling multiple priors for the same model -- they will end up completely correlated
  if(!is.null(seed)){
    set.seed(seed)
  }

  # obtain mapping for the weight coefficients
  omega_mapping <- weightfunctions_mapping(priors)
  omega_cuts    <- weightfunctions_mapping(priors, cuts_only = TRUE)
  omega_names   <- sapply(1:(length(omega_cuts)-1), function(i)paste0("omega[",omega_cuts[i],",",omega_cuts[i+1],"]"))

  # prepare output objects
  samples    <- matrix(nrow = 0, ncol = length(omega_cuts) - 1)
  sample_ind <- NULL
  models_ind <- NULL

  # number of draws allotted to each prior (rounded up)
  n_draws <- ceiling(prior_probs * n_samples)

  # mix samples
  # (priors allotted a single draw must be included as well, the same way as in
  #  '.mix_priors.simple' -- skipping them can leave fewer than 'n_samples' rows
  #  and break the truncation below)
  for(i in seq_along(priors)[n_draws >= 1]){

    # sample indexes
    temp_ind <- seq_len(n_draws[i])

    if(is.prior.none(priors[[i]])){
      # no weight function: all weights are equal to one
      samples <- rbind(samples, matrix(1, ncol = length(omega_cuts) - 1, nrow = length(temp_ind)))
    }else{
      # create temp samples so names can be matched by mapping
      temp_samples <- rng(priors[[i]], length(temp_ind))
      colnames(temp_samples) <- paste0("omega[",1:ncol(temp_samples),"]")
      # 'drop = FALSE' keeps the matrix shape for a single row or column
      samples <- rbind(samples, temp_samples[, paste0("omega[",omega_mapping[[i]],"]"), drop = FALSE])
    }

    sample_ind <- c(sample_ind, temp_ind)
    models_ind <- c(models_ind, rep(i, length(temp_ind)))
  }

  # assure the correct number of samples (rounding up may produce surplus rows)
  samples    <- samples[1:n_samples,,drop=FALSE]
  sample_ind <- sample_ind[1:n_samples]
  models_ind <- models_ind[1:n_samples]

  rownames(samples) <- NULL
  colnames(samples) <- omega_names
  attr(samples, "sample_ind") <- sample_ind
  attr(samples, "models_ind") <- models_ind
  attr(samples, "parameter")  <- parameter
  attr(samples, "prior_list") <- priors
  class(samples) <- c("mixed_posteriors", "mixed_posteriors.weightfunction")

  return(samples)
}

# Generates prior samples for every parameter of a single model and returns
# them in the same format as the mixed posterior samples.
#
# prior_list       - named list of prior distributions (names are used as the
#                    parameter names of the output)
# seed             - seed for the random number generator; a random seed is
#                    drawn when NULL
# n_samples        - number of samples to return for each parameter
# conditional      - names of the parameters (elements of 'prior_list') the
#                    samples should be conditioned on; no conditioning when empty
# conditional_rule - "AND" (all conditional parameters must be included) or
#                    "OR" (at least one must be included); compared against
#                    "AND" / "OR" whenever 'conditional' is non-empty
#
# Returns a named list with one sample object per parameter. When conditioning
# is performed, only the samples satisfying the condition are kept and the list
# carries an "all_alternative" attribute.
.as_mixed_priors            <- function(prior_list, seed = NULL, n_samples = 10000, conditional = NULL, conditional_rule = NULL){

  check_list(prior_list, "prior_list")
  if(any(!sapply(prior_list, is.prior)))
    stop("'prior_list' must be a list of prior distributions")
  check_real(seed, "seed", allow_NULL = TRUE)
  check_int(n_samples, "n_samples")

  # set seed only once at the beginning -- not in the individual draws as the priors will end up completely correlated
  if(is.null(seed)){
    seed <- sample(.Machine$integer.max, 1)
  }
  set.seed(seed)

  # adapted from 'as_mixed_posteriors'
  parameters <- names(prior_list)
  out        <- list()

  # estimate the number of necessary samples for conditioning
  if(length(conditional) > 0){

    # prior probability that each conditional parameter is "included"
    conditioning_probabilities <- sapply(conditional, function(parameter){

      temp_prior <- prior_list[[parameter]]

      if(is.prior.spike_and_slab(temp_prior)){

        # mean of the inclusion prior
        return(mean(.get_spike_and_slab_inclusion(temp_prior)))

      }else if(is.prior.mixture(temp_prior)){

        components    <- attr(temp_prior, "components")
        prior_weights <- attr(temp_prior, "prior_weights")

        if(!all(components %in% c("null", "alternative")))
          stop("conditional mixture posterior distributions are available only for 'null' and 'alternative' components")

        # share of the prior weight assigned to the alternative components
        return(sum(prior_weights[components == "alternative"]) / sum(prior_weights))

      }else{

        warning(sprintf("The parameter '%s' is not a conditional parameter. All samples are assumed to compe from the conditional posterior distribution.", parameter), call. = FALSE, immediate. = TRUE)
        return(1)
      }
    })
    # add a check when forwarding samples to marginal inference
    # (TRUE when every conditional parameter is always included)
    all_alternative <- all(sapply(conditional, function(parameter){

      temp_prior <- prior_list[[parameter]]

      if(is.prior.spike_and_slab(temp_prior)){
        return(mean(.get_spike_and_slab_inclusion(temp_prior)) == 1)
      }else if(is.prior.mixture(temp_prior)){
        return(all(attr(temp_prior, "components") == "alternative"))
      }else{
        return(TRUE)
      }
    }))

    # multiply by 1.25 to ensure that the requested number of samples is reached
    # ('n_samples' is inflated by the inverse probability of the condition,
    #  the original request is kept in 'requested_samples')
    requested_samples <- n_samples
    if(conditional_rule == "AND"){
      n_samples <- round(n_samples / prod(conditioning_probabilities) * 1.25)
    }else if(conditional_rule == "OR"){
      n_samples  <- round(n_samples / (1 - prod(1 - conditioning_probabilities)) * 1.25)
    }
  }

  # create the samples
  for(p in seq_along(parameters)){

    # prepare parameter specific values
    temp_parameter <- parameters[p]
    temp_prior     <- prior_list[[temp_parameter]]

    # the order of the checks matters: the more specific prior types are tested first
    # (the individual samplers do not receive the seed -- it was already set above)
    if(is.prior.spike_and_slab(temp_prior)){
      # spike and slab priors
      out[[temp_parameter]] <- .as_mixed_priors.spike_and_slab(temp_prior, temp_parameter, NULL, n_samples)

    }else if(is.prior.mixture(temp_prior)){
      # mixture priors
      out[[temp_parameter]] <- .as_mixed_priors.mixture(temp_prior, temp_parameter, NULL, n_samples)

    }else if(is.prior.weightfunction(temp_prior)){
      # weightfunctions:
      out[[temp_parameter]] <- .as_mixed_priors.weightfunction(temp_prior, temp_parameter, NULL, n_samples)

    }else if(is.prior.factor(temp_prior)){
      # factor priors
      out[[temp_parameter]] <- .as_mixed_priors.factor(temp_prior, temp_parameter, NULL, n_samples)

    }else if(is.prior.vector(temp_prior)){
      # vector priors:
      out[[temp_parameter]] <- .as_mixed_priors.vector(temp_prior, temp_parameter, NULL, n_samples)

    }else if(is.prior.simple(temp_prior)){
      # simple priors:
      out[[temp_parameter]] <- .as_mixed_priors.simple(temp_prior, temp_parameter, NULL, n_samples)

    }else{
      stop("The posterior samples cannot be mixed: unsupported mixture of prior distributions.")
    }

    # add formula relevant information
    if(!is.null(attr(temp_prior, which = "parameter"))){
      class(out[[temp_parameter]]) <- c(class(out[[temp_parameter]]), "mixed_posteriors.formula")
      attr(out[[temp_parameter]], "formula_parameter")  <- attr(temp_prior, which = "parameter")
    }
  }


  # perform conditioning (and copy back with attributes)
  if(length(conditional) > 0){

    # obtain the indicator samples
    # (one logical column per conditional parameter, TRUE = the sample is "included")
    conditioning_samples <- do.call(cbind, lapply(conditional, function(parameter){

      temp_prior <- prior_list[[parameter]]

      if(is.prior.spike_and_slab(temp_prior)){

        # the spike and slab sampler stores the 0/1 inclusion draws in 'models_ind'
        return(attr(out[[parameter]], "models_ind") == 1)

      }else if(is.prior.mixture(temp_prior)){

        components <- attr(temp_prior, "components")
        return(attr(out[[parameter]], "models_ind") %in% which(components == "alternative"))

      }else{

        return(rep(TRUE, n_samples))
      }
    }))
    # combine the indicators row-wise: 'all' for "AND", 'any' otherwise
    conditioning_samples <- apply(conditioning_samples, 1, ifelse(conditional_rule == "AND", all, any))

    # check enough samples were drawn (if too many remove the extra ones)
    if(sum(conditioning_samples) < requested_samples){
      warning(sprintf("Only %d samples were drawn from the prior distributions due to conditioning.", sum(conditioning_samples)))
    }else{
      # keep only the first 'requested_samples' samples that satisfy the condition
      conditioning_samples[which(conditioning_samples)[-(1:requested_samples)]] <- FALSE
    }

    # select the conditional samples (and copy attributes)
    for(p in seq_along(parameters)){
      # the attributes are stored before subsetting and restored afterwards
      temp <- attributes(out[[parameters[p]]])
      if(is.null(dim(out[[parameters[p]]]))){
        # vector samples
        out[[parameters[p]]] <- out[[parameters[p]]][conditioning_samples]
        attributes(out[[parameters[p]]]) <- c(attributes(out[[parameters[p]]]), temp)
        attr(out[[parameters[p]]], "models_ind") <- attr(out[[parameters[p]]], "models_ind")[conditioning_samples]
      }else{
        # matrix samples: the new 'dim' is kept, the stored 'dimnames' are restored
        out[[parameters[p]]] <- out[[parameters[p]]][conditioning_samples,,drop=FALSE]
        attributes(out[[parameters[p]]]) <- c(attributes(out[[parameters[p]]])[!names(attributes(out[[parameters[p]]])) %in% c("dimnames")], temp[!names(temp) %in% c("dim")])
        attr(out[[parameters[p]]], "models_ind") <- attr(out[[parameters[p]]], "models_ind")[conditioning_samples]
      }
    }

    # put a check whether all samples were conditional
    attr(out, "all_alternative") <- all_alternative
  }

  return(out)
}
# Samples a single simple prior distribution and formats the draws in the
# same way as the mixed posterior samples.
#
# prior     - prior distribution to sample from
# parameter - parameter name, stored in the "parameter" attribute
# seed      - optional seed handed to set.seed() before sampling
# n_samples - number of draws requested from rng()
#
# Returns the draws with class 'mixed_posteriors' / 'mixed_posteriors.simple';
# "sample_ind" and "models_ind" are set to FALSE (there is only one prior).
.as_mixed_priors.simple         <- function(prior, parameter, seed = NULL, n_samples = 10000){

  # validate the arguments
  check_list(prior, "prior")
  check_char(parameter, "parameter")
  check_real(seed, "seed", allow_NULL = TRUE)
  check_int(n_samples, "n_samples")

  # seeding is optional: re-using one seed for several parameters of the same
  # model would make their draws completely correlated
  if(!is.null(seed))
    set.seed(seed)

  # draw from the prior
  draws <- rng(prior, n_samples, transform_factor_samples = FALSE)

  # no model / sample bookkeeping is needed for a single prior
  attr(draws, "sample_ind") <- FALSE
  attr(draws, "models_ind") <- FALSE
  attr(draws, "parameter")  <- parameter
  attr(draws, "prior_list") <- prior
  class(draws) <- c("mixed_posteriors", "mixed_posteriors.simple")

  draws
}
# Samples a single vector (multivariate) prior distribution and formats the
# draws in the same way as the mixed posterior samples.
#
# prior     - prior distribution; its dimensionality is read from
#             prior$parameters[["K"]]
# parameter - parameter name, used for the column names "parameter[k]"
# seed      - optional seed handed to set.seed() before sampling
# n_samples - number of rows of the output
#
# Returns an n_samples x K matrix of class 'mixed_posteriors' /
# 'mixed_posteriors.vector'; "sample_ind" and "models_ind" are set to FALSE
# (there is only one prior).
.as_mixed_priors.vector         <- function(prior, parameter, seed = NULL, n_samples = 10000){

  # check input
  check_list(prior, "prior")
  check_char(parameter, "parameter")
  check_real(seed, "seed", allow_NULL = TRUE)
  check_int(n_samples, "n_samples")

  # do not set seed when sampling multiple prior for the same model -- they will end up completely correlated
  if(!is.null(seed)){
    set.seed(seed)
  }

  # prepare output objects
  K <- prior$parameters[["K"]]

  # (scalar condition: '&&' is the appropriate operator inside 'if')
  if(is.prior.point(prior) && is.prior.simple(prior)){
    # not sampling the prior in case they were imputed (missing dimensions)
    samples <- matrix(prior$parameters[["location"]], nrow = n_samples, ncol = K)
  }else if(K == 1){
    # a single dimension is forced into a one-column matrix
    samples <- matrix(rng(prior, n_samples, transform_factor_samples = FALSE), nrow = n_samples, ncol = K)
  }else{
    samples <- rng(prior, n_samples, transform_factor_samples = FALSE)
  }

  rownames(samples) <- NULL
  colnames(samples) <- paste0(parameter,"[",1:K,"]")
  attr(samples, "sample_ind") <- FALSE
  attr(samples, "models_ind") <- FALSE
  attr(samples, "parameter")  <- parameter
  attr(samples, "prior_list") <- prior
  class(samples) <- c("mixed_posteriors", "mixed_posteriors.vector")

  return(samples)
}
# Samples a single factor prior distribution and formats the draws in the
# same way as the mixed posterior samples.
#
# Treatment and independent contrasts are sampled level by level via
# '.as_mixed_priors.simple'; orthonormal and meandif contrasts are sampled
# jointly via '.as_mixed_priors.vector'.
#
# prior     - factor prior distribution
# parameter - parameter name, used for the column names
# seed      - optional seed forwarded to the underlying samplers
# n_samples - number of rows of the output
#
# Returns a matrix carrying the sampling attributes and the factor description
# ("levels", "level_names", "interaction", "treatment", "independent",
# "orthonormal", "meandif").
.as_mixed_priors.factor         <- function(prior, parameter, seed = NULL, n_samples = 10000){

  # validate the arguments
  check_list(prior, "prior")
  check_char(parameter, "parameter")
  check_real(seed, "seed", allow_NULL = TRUE)
  check_int(n_samples, "n_samples")

  # number of levels of the prior
  levels <- .get_prior_factor_levels(prior)

  # description of the factor prior
  prior_info <- list(
    "levels"      = .get_prior_factor_levels(prior),
    "level_names" = .get_prior_factor_level_names(prior),
    "interaction" = .is_prior_interaction(prior),
    "treatment"   = is.prior.treatment(prior),
    "independent" = is.prior.independent(prior),
    "orthonormal" = is.prior.orthonormal(prior),
    "meandif"     = is.prior.meandif(prior)
  )

  # samples each level separately and labels the columns with 'labels'
  # (shared by the treatment and independent contrasts)
  sample_by_level <- function(labels){

    if(levels == 1){
      draws <- matrix(.as_mixed_priors.simple(prior, parameter, seed, n_samples), ncol = 1)
    }else{
      per_level <- lapply(1:levels, function(i) .as_mixed_priors.simple(prior, paste0(parameter, "[", i, "]"), seed, n_samples))
      draws     <- do.call(cbind, per_level)
    }

    rownames(draws) <- NULL
    colnames(draws) <- paste0(parameter,"[",labels,"]")
    attr(draws, "sample_ind") <- FALSE
    attr(draws, "models_ind") <- FALSE
    attr(draws, "parameter")  <- parameter
    attr(draws, "prior_list") <- prior
    class(draws) <- c("mixed_posteriors", "mixed_posteriors.factor", "mixed_posteriors.vector")

    draws
  }

  if(prior_info[["treatment"]]){

    # treatment contrasts have no column for the first level name
    samples <- sample_by_level(prior_info$level_names[-1])

  }else if(prior_info[["independent"]]){

    samples <- sample_by_level(prior_info$level_names)

  }else if(prior_info[["orthonormal"]] | prior_info[["meandif"]]){

    # the vector sampler reads the dimensionality from the 'K' parameter
    prior$parameters[["K"]] <- levels
    samples <- .as_mixed_priors.vector(prior, parameter, seed, n_samples)
    class(samples) <- c(class(samples), "mixed_posteriors.factor")

  }

  # attach the factor description
  for(info_name in c("levels", "level_names", "interaction", "treatment", "independent", "orthonormal", "meandif")){
    attr(samples, info_name) <- prior_info[[info_name]]
  }

  samples
}
# Samples a single weight function prior distribution and formats the draws in
# the same way as the mixed posterior samples.
#
# prior     - weight function prior distribution
# parameter - parameter name, stored in the "parameter" attribute
# seed      - optional seed handed to set.seed() before sampling
# n_samples - number of draws requested from rng()
#
# Returns the draws with columns named "omega[lower,upper]" after the cut
# points and class 'mixed_posteriors' / 'mixed_posteriors.weightfunction';
# "sample_ind" and "models_ind" are set to FALSE (there is only one prior).
.as_mixed_priors.weightfunction <- function(prior, parameter, seed = NULL, n_samples = 10000){

  # validate the arguments
  check_list(prior, "prior")
  check_char(parameter, "parameter")
  check_real(seed, "seed", allow_NULL = TRUE)
  check_int(n_samples, "n_samples")

  # seeding is optional: re-using one seed for several parameters of the same
  # model would make their draws completely correlated
  if(!is.null(seed))
    set.seed(seed)

  # NOTE(review): the mapping is computed but never applied to the columns of
  # the draws below, whereas '.mix_priors.weightfunction' reorders the columns
  # by it -- confirm that rng() already returns the columns in the order of
  # the cut points
  omega_mapping <- weightfunctions_mapping(list(prior))

  # interval labels built from the neighbouring cut points
  cuts          <- weightfunctions_mapping(list(prior), cuts_only = TRUE)
  column_labels <- sapply(1:(length(cuts)-1), function(j) paste0("omega[",cuts[j],",",cuts[j+1],"]"))

  # draw from the prior
  draws <- rng(prior, n_samples)

  rownames(draws) <- NULL
  colnames(draws) <- column_labels
  attr(draws, "sample_ind") <- FALSE
  attr(draws, "models_ind") <- FALSE
  attr(draws, "parameter")  <- parameter
  attr(draws, "prior_list") <- prior
  class(draws) <- c("mixed_posteriors", "mixed_posteriors.weightfunction")

  draws
}
# Samples a spike and slab prior distribution and formats the draws in the
# same way as the mixed posterior samples.
#
# An inclusion indicator is drawn for every sample (Bernoulli with the
# probability sampled from the inclusion prior) and the draws of the variable
# prior are multiplied by it, i.e., excluded samples are set to zero.
#
# prior     - spike and slab prior distribution
# parameter - parameter name forwarded to the sampler of the variable prior
# seed      - optional seed handed to set.seed() before sampling
# n_samples - number of samples
#
# Returns the draws with class 'mixed_posteriors.spike_and_slab' appended; the
# 0/1 inclusion indicators are stored in the "models_ind" attribute. Stops if
# the variable prior is neither a factor nor a simple prior.
.as_mixed_priors.spike_and_slab <- function(prior, parameter, seed = NULL, n_samples = 10000){

  # check input
  check_list(prior, "prior")
  check_char(parameter, "parameter")
  check_real(seed, "seed", allow_NULL = TRUE)
  check_int(n_samples, "n_samples")

  # do not set seed when sampling multiple prior for the same model -- they will end up completely correlated
  if(!is.null(seed)){
    set.seed(seed)
  }

  prior_variable   <- .get_spike_and_slab_variable(prior)
  prior_inclusion  <- .get_spike_and_slab_inclusion(prior)

  # inclusion indicators (1 = included)
  inclusion <- stats::rbinom(n_samples, size = 1, prob = rng(prior_inclusion, n_samples))

  if(is.prior.factor(prior_variable)){

    samples <- .as_mixed_priors.factor(prior_variable, parameter, seed, n_samples)

  }else if(is.prior.simple(prior_variable)){

    samples <- .as_mixed_priors.simple(prior_variable, parameter, seed, n_samples)

  }else{

    # fail with an informative message instead of "object 'samples' not found"
    stop("spike and slab priors are supported only for simple and factor prior distributions")

  }

  # merge with names and attributes
  # (for matrices, the indicators are recycled over the columns so that whole
  #  rows are set to zero)
  samples       <- samples * inclusion

  class(samples) <- c(class(samples), "mixed_posteriors.spike_and_slab")
  attr(samples, "sample_ind") <- FALSE
  attr(samples, "models_ind") <- inclusion
  attr(samples, "prior_list") <- prior

  return(samples)
}
# Samples a mixture prior distribution and formats the draws in the same way
# as the mixed posterior samples.
#
# The components are sampled via '.mix_priors.factor' (if any component is a
# factor prior) or '.mix_priors.simple' and the resulting samples are randomly
# shuffled. Mixtures containing PET, PEESE, or weight function components are
# not implemented.
#
# prior     - mixture prior distribution (a list of component priors)
# parameter - parameter name forwarded to the component sampler
# seed      - optional seed forwarded to the component sampler
# n_samples - number of samples
#
# Returns the shuffled draws with class 'mixed_posteriors.mixture' appended;
# "models_ind" holds the index of the component each sample came from and
# "sample_ind" is set to FALSE.
.as_mixed_priors.mixture        <- function(prior, parameter, seed = NULL, n_samples = 10000){

  # check input
  check_list(prior, "prior")
  check_char(parameter, "parameter")
  check_real(seed, "seed", allow_NULL = TRUE)
  check_int(n_samples, "n_samples")

  is_PET            <- sapply(prior, is.prior.PET)
  is_PEESE          <- sapply(prior, is.prior.PEESE)
  is_weightfunction <- sapply(prior, is.prior.weightfunction)

  if(any(is_PET | is_PEESE | is_weightfunction)){

    stop("not implemented yet")  # probably not needed
    # samples <- NULL
    #
    # if(any(is_PET)){
    #   samples <- cbind(samples, .as_mixed_posteriors.simple(fit, prior[is_PET][[1]], "PET"))
    # }
    # if(any(is_PEESE)){
    #   samples <- cbind(.as_mixed_posteriors.simple(fit, prior[is_PEESE][[1]], "PEESE"))
    # }
    # if(any(is_weightfunction)){
    #   # create a dummy prior with all the cuts
    #   dummy_prior <- #TODO:
    #   samples     <- cbind(.as_mixed_posteriors.weightfunction(fit, dummy_prior, "omega"))
    # }
    #
    # samples <- .as_mixed_posteriors.factor(fit, prior_variable, parameter)
    # attr(samples, "models_ind") <- as.vector(model_samples[,paste0(parameter, "_indicator")])

  }

  # sample the components: the factor sampler is used as soon as any of the
  # components is a factor prior
  if(any(sapply(prior, is.prior.factor))){
    temp_samples <- .mix_priors.factor(prior, parameter = parameter, seed = seed, n_samples = n_samples)
  }else{
    temp_samples <- .mix_priors.simple(prior, parameter = parameter, seed = seed, n_samples = n_samples)
  }

  # the samples parameters need to be randomly shuffled
  # (the  .mix_priors.XXX functions generate the samples model by model to keep bridge-sampling model-averaging consistent structure,
  #  this however does not apply to the spike an slab priors)
  random_ind <- sample(n_samples)
  if(is.null(dim(temp_samples))){
    samples <- temp_samples[random_ind]
  }else{
    samples <- temp_samples[random_ind,,drop=FALSE]
  }

  # restore the attributes of the unshuffled samples and reorder the component
  # indicators in the same way as the samples
  attributes(samples) <- attributes(temp_samples)
  attr(samples, "sample_ind") <- FALSE
  attr(samples, "models_ind") <- attr(samples, "models_ind")[random_ind]

  # append classes and priors
  class(samples) <- c(class(samples), "mixed_posteriors.mixture")
  attr(samples, "prior_list") <- prior

  return(samples)
}

#' @title Compute Savage-Dickey inclusion Bayes factors
#'
#' @description Computes Savage-Dickey (density ratio) inclusion Bayes factors
#' based on the change of height from prior to posterior distribution at the
#' test value.
#'
#' @param posterior marginal posterior distribution generated via the
#' \code{marginal_posterior} function
#' @param null_hypothesis point null hypothesis to test. Defaults to \code{0}
#' @param normal_approximation whether the height of prior and posterior density should be
#' approximated via a normal distribution (rather than kernel density). Defaults to \code{FALSE}.
#' @param silent whether warnings should be returned silently. Defaults to \code{FALSE}
#'
#'
#' @return \code{Savage_Dickey_BF} returns a Bayes factor. If \code{posterior}
#' is a list of marginal posterior distributions, a list of Bayes factors with
#' the same names is returned.
#'
#' @export
Savage_Dickey_BF <- function(posterior, null_hypothesis = 0, normal_approximation = FALSE, silent = FALSE){

  # the message names the actual function and the class that is checked
  if(!inherits(posterior, "marginal_posterior"))
    stop("'Savage_Dickey_BF' function requires an object of class 'marginal_posterior'")
  check_real(null_hypothesis, "null_hypothesis")
  check_bool(normal_approximation, "normal_approximation")
  check_bool(silent, "silent")

  if(is.list(posterior)){
    # one Bayes factor per element of the list
    bf <- vector("list", length(posterior))
    for(i in seq_along(posterior)){
      bf[[i]] <- .Savage_Dickey_BF.fun(posterior[[i]], null_hypothesis, normal_approximation, silent)
    }
    names(bf) <- names(posterior)
  }else{
    bf <- .Savage_Dickey_BF.fun(posterior, null_hypothesis, normal_approximation, silent)
  }

  return(bf)
}

# Computes the Savage-Dickey inclusion Bayes factor for a single marginal
# posterior distribution: the ratio of the prior to the posterior density
# height at the null hypothesis.
#
# posterior            - posterior samples with the prior samples stored in the
#                        "prior_samples" attribute
# null_hypothesis      - tested value
# normal_approximation - TRUE: heights via '.Savage_Dickey_BF.normal',
#                        FALSE: heights via '.Savage_Dickey_BF.kd'
# silent               - TRUE suppresses the warnings (they are still attached
#                        to the output)
#
# Returns the Bayes factor with the collected warning messages in the
# "warnings" attribute (if there are any), or NA when the prior samples are
# flagged as "all_alternative".
.Savage_Dickey_BF.fun    <- function(posterior, null_hypothesis, normal_approximation, silent){

  prior <- attr(posterior, "prior_samples")
  if(is.null(prior))
    stop("there are no prior samples for the posterior distribution", call. = FALSE)

  # all prior samples come from alternative distributions --- there is no hypothesis to test
  all_alternative <- attr(prior, "all_alternative")
  if(!is.null(all_alternative) && all_alternative)
    return(NA)

  # collect the diagnostics of the density ratio
  issues <- NULL

  # more than 5% of the samples sit exactly at the tested value
  if(mean(posterior == null_hypothesis) > 0.05){
    issues <- c(issues, "There is a considerable cluster of posterior samples at the exact null hypothesis values. The Savage-Dickey density ratio is likely to be invalid.")
  }
  if(mean(prior == null_hypothesis) > 0.05){
    issues <- c(issues, "There is a considerable cluster of prior samples at the exact null hypothesis values. The Savage-Dickey density ratio is likely to be invalid.")
  }

  # the tested value lies outside of the range of the samples
  if(null_hypothesis < min(prior) || null_hypothesis > max(prior)){
    issues <- c(issues, "Prior samples do not span both sides of the null hypothesis. Check whether the prior distribution contain the null hypothesis in the first place. The Savage-Dickey density ratio is likely to be invalid.")
  }
  if(null_hypothesis < min(posterior) || null_hypothesis > max(posterior)){
    issues <- c(issues, "Posterior samples do not span both sides of the null hypothesis. The Savage-Dickey density ratio is likely to be overestimated.")
  }

  if(!silent && !is.null(issues)){
    for(issue in issues){
      warning(issue, call. = FALSE)
    }
  }

  # density height at the tested value (posterior first, then prior)
  height_at_null <- if(normal_approximation) .Savage_Dickey_BF.normal else .Savage_Dickey_BF.kd
  posterior_height <- height_at_null(posterior, null_hypothesis)
  prior_height     <- height_at_null(prior, null_hypothesis)

  # ratio of the heights computed on the log scale
  BF <- exp(log(prior_height) - log(posterior_height))

  if(!is.null(issues)){
    attr(BF, "warnings") <- issues
  }

  BF
}
# Density height at 'null_hypothesis' under a normal distribution with the
# mean and standard deviation of 'samples'.
.Savage_Dickey_BF.normal <- function(samples, null_hypothesis){

  sample_mean <- mean(samples)
  sample_sd   <- stats::sd(samples)

  stats::dnorm(null_hypothesis, mean = sample_mean, sd = sample_sd)
}
# Density height at 'null_hypothesis' based on a kernel density estimate of
# 'samples' (stats::density with default settings). The height is linearly
# interpolated between the two neighbouring grid points of the estimate;
# it is 0 when the tested value lies outside of the range of the samples.
.Savage_Dickey_BF.kd     <- function(samples, null_hypothesis){

  # the test value is outside of the samples
  if(null_hypothesis < min(samples) || null_hypothesis > max(samples)){
    return(0)
  }

  kde <- stats::density(samples)

  # first grid point above the tested value and its left neighbour
  upper     <- which.max(kde$x > null_hypothesis)
  neighbors <- c(upper - 1, upper)
  grid_x    <- kde$x[neighbors]
  grid_y    <- kde$y[neighbors]

  # linear interpolation between the two grid points
  rise <- grid_y[2] - grid_y[1]
  run  <- grid_x[2] - grid_x[1]

  grid_y[1] + rise * (null_hypothesis - grid_x[1])/run
}


#' @title Model-average marginal posterior distributions and
#' marginal Bayes factors
#'
#' @description Builds marginal model-averaged and conditional
#' posterior distributions from a list of models, a vector of parameters,
#' a formula, and a list indicating which models represent the null or the
#' alternative hypothesis for each parameter. An inclusion Bayes factor is
#' computed for each marginal estimate via a Savage-Dickey density approximation.
#'
#' @param marginal_parameters parameters for which the marginal summary
#' should be created
#' @param parameters all parameters included in the model_list that are
#' relevant for the formula (all of which need to have specification of
#' \code{is_null_list})
#' @param seed seed for random number generation
#' @inheritParams ensemble_inference
#' @inheritParams marginal_posterior
#' @inheritParams Savage_Dickey_BF
#'
#' @return \code{marginal_inference} returns an object of class 'marginal_inference':
#' a list with elements \code{conditional}, \code{averaged}, and \code{inference},
#' each holding one entry per processed marginal parameter (parameters without any
#' alternative hypothesis model are skipped with a warning). The \code{null_hypothesis}
#' and \code{normal_approximation} settings are stored as attributes.
#'
#' @seealso [ensemble_inference] [mix_posteriors] [BayesTools_ensemble_tables]
#'
#' @export
marginal_inference <- function(model_list, marginal_parameters, parameters, is_null_list, formula,
                               null_hypothesis = 0, normal_approximation = FALSE,
                               n_samples = 10000, seed = NULL, silent = FALSE){

  # validate the input (majority of the checks is performed within mix_posteriors)
  check_list(model_list, "model_list")
  check_char(parameters, "parameters", check_length = FALSE)
  check_char(marginal_parameters, "marginal_parameters", check_length = FALSE)
  check_list(is_null_list, "is_null_list", check_length = length(parameters))
  has_valid_priors <- sapply(model_list, function(m) sapply(attr(m[["fit"]], "prior_list"), function(p) is.prior(p)))
  if(!all(unlist(has_valid_priors)))
    stop("model_list:priors must contain 'BayesTools' priors")


  # model-averaged ensemble across all models (shared by all marginal parameters)
  full_ensemble <- mix_posteriors(
    model_list   = model_list,
    parameters   = parameters,
    is_null_list = is_null_list,
    seed         = seed,
    n_samples    = n_samples,
    conditional  = FALSE
  )

  # containers for the results
  out <- list(
    conditional = list(),
    averaged    = list(),
    inference   = list()
  )

  for(i in seq_along(marginal_parameters)){

    this_parameter <- marginal_parameters[i]
    this_is_null   <- is_null_list[[this_parameter]]

    # there is nothing to condition on if all models are null models for the parameter
    if(all(this_is_null)){
      warning(paste0("parameter '", this_parameter, "' does not contain any alternative hypothesis models."), immediate. = TRUE, call. = FALSE)
      next
    }

    # models that include the parameter of interest
    keep_model <- !this_is_null

    # model-averaged posterior restricted to the models including the parameter of interest
    # (different from individual conditionals)
    alternative_ensemble <- mix_posteriors(
      model_list   = model_list[keep_model],
      parameters   = parameters,
      is_null_list = lapply(is_null_list, function(l) l[keep_model]),
      seed         = seed,
      n_samples    = n_samples,
      conditional  = FALSE
    )

    # marginal distributions based on the full and the restricted ensemble
    out[["averaged"]][[this_parameter]] <- marginal_posterior(
      samples           = full_ensemble,
      parameter         = this_parameter,
      formula           = formula,
      prior_samples     = TRUE,
      n_samples         = n_samples
    )
    out[["conditional"]][[this_parameter]] <- marginal_posterior(
      samples           = alternative_ensemble,
      parameter         = this_parameter,
      formula           = formula,
      prior_samples     = TRUE,
      n_samples         = n_samples
    )

    # inclusion Bayes factor based on the conditional marginal distribution
    out[["inference"]][[this_parameter]] <- Savage_Dickey_BF(
      posterior            = out[["conditional"]][[this_parameter]],
      null_hypothesis      = null_hypothesis,
      normal_approximation = normal_approximation,
      silent               = silent
    )
  }

  # store the settings alongside the results
  attr(out, "null_hypothesis")      <- null_hypothesis
  attr(out, "normal_approximation") <- normal_approximation
  class(out) <- c(class(out), "marginal_inference")
  return(out)
}


#' @title Model-average marginal posterior distributions and
#' marginal Bayes factors based on BayesTools JAGS model via \code{marginal_inference}
#'
#' @description Creates marginal model-averaged and conditional
#' posterior distributions based on a BayesTools JAGS model, vector of parameters,
#' formula, and a list of conditional specifications for each parameter.
#' Computes inclusion Bayes factors for each marginal estimate via a Savage-Dickey
#' density approximation.
#'
#' @param marginal_parameters parameters for which the marginal summary
#' should be created
#' @param conditional_list list of conditional parameters for each marginal parameter.
#' The list is indexed by the names of the marginal parameters; each element must be
#' either \code{NULL} or a character vector with values from \code{parameters}.
#' @param parameters all parameters included in the model_list that are
#' relevant for the formula (all of which need to have specification of
#' \code{is_null_list})
#' @inheritParams as_mixed_posteriors
#' @inheritParams marginal_inference
#' @inheritParams Savage_Dickey_BF
#'
#' @return \code{as_marginal_inference} returns an object of class 'marginal_inference':
#' a list with elements \code{conditional}, \code{averaged}, and \code{inference},
#' each holding one entry per processed marginal parameter (parameters for which
#' no conditional samples could be obtained are skipped). The \code{null_hypothesis}
#' and \code{normal_approximation} settings are stored as attributes.
#'
#' @seealso [marginal_inference] [as_mixed_posteriors]
#'
#' @export
as_marginal_inference <- function(model, marginal_parameters, parameters, conditional_list, conditional_rule, formula,
                                  null_hypothesis = 0, normal_approximation = FALSE,
                                  n_samples = 10000, silent = FALSE, force_plots = FALSE){

  # check input
  if(!inherits(model, "BayesTools_fit"))
    stop("'model' must be a 'BayesTools_fit'")
  check_char(parameters, "parameters", check_length = FALSE)
  check_char(marginal_parameters, "marginal_parameters", check_length = FALSE)
  check_list(conditional_list, "conditional_list", check_length = length(marginal_parameters))
  check_char(conditional_rule, "conditional_rule")


  # create one full model-averaged ensemble
  averaged_posterior <- as_mixed_posteriors(
    model        = model,
    parameters   = parameters
  )

  # prepare output object
  out <- list(
    conditional = list(),
    averaged    = list(),
    inference   = list()
  )

  for(i in seq_along(marginal_parameters)){

    this_parameter <- marginal_parameters[i]

    # the conditional specification must be NULL or refer to known parameters
    check_char(conditional_list[[this_parameter]], sprintf("conditional_list[[%1$s]]", this_parameter), check_length = FALSE, allow_values = parameters, allow_NULL = TRUE)

    # obtain model-averaged posterior conditional on including the parameter of interest
    # (different from individual conditionals)
    temp_conditional_posterior <- as_mixed_posteriors(
      model            = model,
      parameters       = parameters,
      conditional      = conditional_list[[this_parameter]],
      conditional_rule = conditional_rule,
      force_plots      = force_plots
    )

    # skip the rest of the parameter because of impossibility of obtaining conditional samples
    if (length(temp_conditional_posterior) == 0)
      next

    # compute the marginals
    out[["averaged"]][[this_parameter]] <- marginal_posterior(
      samples           = averaged_posterior,
      parameter         = this_parameter,
      formula           = formula,
      prior_samples     = TRUE,
      n_samples         = n_samples
    )
    out[["conditional"]][[this_parameter]] <- marginal_posterior(
      samples           = temp_conditional_posterior,
      parameter         = this_parameter,
      formula           = formula,
      prior_samples     = TRUE,
      n_samples         = n_samples
    )

    # and inclusion Bayes factor (based on the conditional marginal distribution)
    out[["inference"]][[this_parameter]] <- Savage_Dickey_BF(
      posterior            = out[["conditional"]][[this_parameter]],
      null_hypothesis      = null_hypothesis,
      normal_approximation = normal_approximation,
      silent               = silent
    )
  }

  # store the settings alongside the results
  attr(out, "null_hypothesis")      <- null_hypothesis
  attr(out, "normal_approximation") <- normal_approximation
  class(out) <- c(class(out), "marginal_inference")
  return(out)
}
