# Hidden state prediction for discrete characters, based on ancestral state reconstruction (ASR) using a fixed-rates continuous-time Markov model (aka. "Mk model")
# States (or prior distributions) may be unknown for some tips (NA in tip_states, or NA in the corresponding row of tip_priors), but must be known for at least one tip.
# The transition matrix can either be provided, or can be estimated via maximum-likelihood fitting.
# Returns the loglikelihood of the model, the transition matrix and the state likelihoods for all tips and nodes of the tree (one row per tip, followed by one row per node), including tips with hidden state.
# Ancestral states are reconstructed on the subtree spanned by the tips with known state, using the rerooting method introduced by Yang et al (1995) to infer marginal likelihoods of ancestral states; the result is then propagated to the remaining tips and nodes.
# The ASR step works similarly to phytools::rerootingMethod().
hsp_mk_model = function(tree, 
						tip_states,							# 1D numeric vector of size Ntips, listing the state of each tip. NA marks tips whose state is hidden. Values are used as column indices into the likelihoods matrix, so they are presumably integers in 1,..,Nstates. Can also be NULL, in which case tip_priors must be provided.
						Nstates = NULL,						# number of possible states. Can be NULL. Passed on to asr_mk_model().
						tip_priors = NULL, 					# 2D numerical array of size Ntips x Nstates. Rows containing any NA mark tips whose state is hidden. Can be NULL, in which case tip_states must be provided.
						rate_model = "ARD",					# either "ER" or "SYM" or "ARD" or "SUEDE" or an integer vector mapping entries of the transition matrix to a set of independent rate parameters. The format and interpretation is the same as for index.matrix generated by the function get_transition_index_matrix(..).
						transition_matrix=NULL,				# either NULL, or a transition matrix of size Nstates x Nstates, such that transition_matrix * p gives the rate of change of probability vector p. If NULL, the transition matrix will be fitted via maximum-likelihood. The convention is that [i,j] gives the transition rate i-->j.
						root_prior = "empirical",			# can be 'flat', 'stationary', 'empirical' or a numeric vector of size Nstates
						Ntrials = 1,						# (int) number of trials (starting points) for fitting the transition matrix. Only relevant if transition_matrix=NULL.
						optimization_algorithm = "nlminb",	# either "optim" or "nlminb". What algorithm to use for fitting.
						store_exponentials = TRUE,			# passed on unchanged to asr_mk_model()
						check_input = TRUE,					# (bool) perform some basic sanity checks on the input data. Set this to FALSE if you're certain your input data is valid.
						Nthreads = 1){						# (integer) number of threads for running multiple fitting trials in parallel
	# Overview:
	#   1. Validate the input and determine which tips have a known state (or prior).
	#   2. Extract the subtree spanned by those tips and run asr_mk_model() on it.
	#   3. Copy the known tip likelihoods and the reconstructed node likelihoods into a table covering the full tree,
	#      and let apply_fixed_rate_Markov_model_to_missing_clades_CPP() fill in the remaining tips & nodes.
	# Returns a list with elements transition_matrix, loglikelihood and likelihoods (2D array of size (Ntips+Nnodes) x Nstates, tips first).
	# If the ASR step fails, all three list elements are NULL.
	Ntips	= length(tree$tip.label)
	Nnodes	= tree$Nnode
	Nedges	= nrow(tree$edge)
	has_tip_states = !is.null(tip_states)
	has_tip_priors = !is.null(tip_priors)
	
	# basic error checking
	if(has_tip_states && has_tip_priors){
		stop("ERROR: tip_states and tip_priors are both non-NULL, but exactly one of them should be NULL")
	}else if((!has_tip_states) && (!has_tip_priors)){
		stop("ERROR: tip_states and tip_priors are both NULL, but exactly one of them should be non-NULL")
	}
	if(has_tip_states){
		if(!is.numeric(tip_states)) stop("ERROR: tip_states must be integers")
		# a mismatched length would otherwise be silently interpreted as additional hidden tips, or trigger obscure indexing errors further below
		if(length(tip_states)!=Ntips) stop(sprintf("ERROR: Length of tip_states (%d) is not the same as the number of tips in the tree (%d)", length(tip_states), Ntips))
	}else{
		if(length(dim(tip_priors))!=2) stop("ERROR: tip_priors must be a 2D array of size Ntips x Nstates")
		if(nrow(tip_priors)!=Ntips) stop(sprintf("ERROR: Number of rows in tip_priors (%d) is not the same as the number of tips in the tree (%d)", nrow(tip_priors), Ntips))
	}
	if(check_input){
		# lengths are guaranteed to match at this point, so the elementwise comparison cannot recycle
		if((!is.null(names(tip_states))) && any(names(tip_states)!=tree$tip.label)) stop("ERROR: Names in tip_states and tip labels in tree don't match (must be in the same order).")
	}
	
	# find tips with known state (or known prior)
	if(has_tip_states){
		known_tips = which(!is.na(tip_states))
	}else{
		known_tips = which(rowSums(is.na(tip_priors))==0)
	}
	if(length(known_tips)==0) stop("ERROR: All tip states are hidden")
	
	# extract the subtree spanned by the known tips, and synchronize tip states/priors with that subtree
	extraction	 		= get_subtree_with_tips(tree, only_tips=known_tips, omit_tips=FALSE, collapse_monofurcations=TRUE, force_keep_root=TRUE)
	known_subtree		= extraction$subtree
	known2all_tips		= extraction$new2old_tip	# used below to map subtree tips to tip indices in the full tree
	known2all_nodes		= extraction$new2old_node	# used below to map subtree nodes to node indices in the full tree
	if(length(known_subtree$tip.label)==0) stop("ERROR: Subtree with known tip-states is empty")
	known_tip_states = NULL
	known_tip_priors = NULL
	if(has_tip_states){
		known_tip_states = tip_states[known2all_tips]
	}else{
		known_tip_priors = tip_priors[known2all_tips,,drop=FALSE]
	}
	
	# perform ancestral state reconstruction on known_subtree
	asr_results = asr_mk_model(	tree					= known_subtree, 
								tip_states		 		= known_tip_states,
								Nstates 				= Nstates,
								tip_priors 				= known_tip_priors,
								rate_model 				= rate_model,
								transition_matrix 		= transition_matrix,
								include_ancestral_likelihoods = TRUE,
								root_prior 				= root_prior,
								Ntrials 				= Ntrials,
								optimization_algorithm 	= optimization_algorithm,
								store_exponentials 		= store_exponentials,
								check_input 			= check_input,
								Nthreads				= Nthreads)
	loglikelihood 			= asr_results$loglikelihood
	ancestral_likelihoods	= asr_results$ancestral_likelihoods
	# bail out before touching any other ASR output if the reconstruction failed
	if(is.null(loglikelihood) || is.nan(loglikelihood) || is.null(ancestral_likelihoods)){
		return(list(loglikelihood=NULL, transition_matrix=NULL, likelihoods=NULL)) # ASR failed
	}
	transition_matrix 	= asr_results$transition_matrix
	Nstates 			= ncol(ancestral_likelihoods) # from here on, use the number of states actually considered by the ASR
	
	# set up likelihoods for the full tree (rows 1,..,Ntips are tips, rows Ntips+1,..,Ntips+Nnodes are nodes)
	# rows of tips & nodes not included in known_subtree remain zero for now
	likelihoods = matrix(0, nrow=(Ntips+Nnodes), ncol=Nstates)
	if(has_tip_states){
		# known tip states become indicator vectors: 1 in the column of the observed state, 0 elsewhere
		likelihoods[cbind(known2all_tips, known_tip_states)] = 1.0
	}else{
		likelihoods[known2all_tips, ] = known_tip_priors
	}
	likelihoods[known2all_nodes+Ntips, ] = ancestral_likelihoods
	
	# flag the rows that have been filled in above, so that only the remaining ones are treated as unknown
	likelihoods_known							= rep(FALSE, times=(Ntips+Nnodes))
	likelihoods_known[known2all_tips]			= TRUE
	likelihoods_known[known2all_nodes+Ntips]  	= TRUE
	
	# forward-project posteriors to tips with hidden state
	likelihoods = apply_fixed_rate_Markov_model_to_missing_clades_CPP(	Ntips 					= Ntips,
																		Nnodes					= Nnodes,
																		Nedges					= Nedges,
																		Nstates					= Nstates,
																		tree_edge				= as.vector(t(tree$edge))-1,	# flatten in row-major format and make indices 0-based
																		edge_length 			= (if(is.null(tree$edge.length)) numeric() else tree$edge.length),
																		transition_matrix 		= as.vector(t(transition_matrix)), # flatten in row-major format
																		exponentiation_accuracy	= 1e-3,
																		max_polynomials			= 1000,
																		likelihoods_known		= likelihoods_known,
																		likelihoods				= as.vector(t(likelihoods))) # flatten in row-major format	
	likelihoods = matrix(likelihoods, ncol=Nstates, byrow=TRUE) # unflatten returned table
	colnames(likelihoods) = colnames(ancestral_likelihoods)
		
	return(list(transition_matrix=transition_matrix, 
				loglikelihood=loglikelihood, 
				likelihoods=likelihoods))
}
