#include "cfbgibbs.h"
#include "matrix.h"
#include "ghmm.h"
#include "mes.h"
#include "ghmm_internals.h"
#include "sequence.h"
#include "rng.h"
#include "randvar.h"
#include <math.h>
#include "fbgibbs.h"

#ifdef HAVE_CONFIG_H
#  include "../config.h"
#endif



/* Based on "Speeding Up Bayesian HMM by the Four Russians Method"
 *
 * M. Mahmud and A. Schliep
 *
 * In Algorithms in Bioinformatics, Springer Berlin / Heidelberg, 6833, 188-200, 2011.
 */

//=====================================================================
//====================precompute position==============================
//=====================================================================

// We need the matrix M^B(X), where B is M^maxorder and X is an observation
// tuple of length 1, 2, ..., R. These matrices are precomputed, and the index
// of the one needed at each observation is stored in storedpos.

/* Computes the lookup tables used by position() to map an observation tuple
 * to a flat index.
 * R: maximal tuple length (compression length, or the emission order)
 * M: alphabet size
 * tupleSize: output, R+1 entries; tupleSize[L] is the first index used by
 *            tuples of length L. With flag != 0 the layout is
 *            0, 0, M, M+M^2, ... (length-1 tuples start at index 0);
 *            with flag == 0 it is 0, 1, 1+M, 1+M+M^2, ... (index 0 is
 *            then reserved for the empty tuple).
 * preposition: output, R rows of M entries; preposition[i][j] = j*M^i
 * flag: selects one of the two tupleSize layouts above
 * When R <= 0 only tupleSize[0] is written (nothing else fits in an
 * R+1 sized tupleSize), and preposition is not touched. */
void precomputeposition(int R, int M, int *tupleSize, int **preposition, int flag){
  int i, j;
  int power = 1; /* M^i for the current row i */

  tupleSize[0] = 0;
  if (R <= 0)
    return;

  tupleSize[1] = flag ? 0 : 1;
  for (i = 0; i < R; i++) {
    if (i > 0) {
      power *= M;
      tupleSize[i + 1] = tupleSize[i] + power;
    }
    for (j = 0; j < M; j++)
      preposition[i][j] = j * power;
  }
}

/* Maps the observation tuple (obs[start], ..., obs[end-1]) to the flat index
 * of its matrix M^B(X).
 * obs: observation sequence
 * start: first index of the tuple (inclusive)
 * end: one past the last index of the tuple
 * tupleSize: tupleSize[L] is the first index used by tuples of length L
 * preposition: preposition[i][j] = j*M^i
 * The tuple is read as a base-M number whose least significant digit is
 * obs[start], offset by the start of the block for its length. */
int position(int *obs, int start, int end, int *tupleSize, int **preposition){
  const int length = end - start;
  int index = tupleSize[length];
  int digit;

  for (digit = 0; digit < length; digit++)
    index += preposition[digit][obs[start + digit]];

  return index;
}

/* Fills storedpos with the matrix index (see position()) needed at every
 * observation of a sequence compressed in blocks of R symbols
 * (first-order emissions).
 * R: compression length
 * T: length of obs
 * obs: observation sequence
 * storedpos: output, must hold at least T+1 entries. For a block [s, e),
 *            storedpos[j] (s <= j < e) is the index of the suffix
 *            (obs[j], ..., obs[e-1]). storedpos[0] holds the whole first
 *            block and storedpos[1] its suffix starting at obs[1].
 * M: alphabet size
 * tupleSize, preposition: tables from precomputeposition() with flag != 0;
 *            the O(1) suffix update below relies on that layout.
 * Blocks are [0, R), [R, 2R), ... with the last one clipped to T. */
void storeposition(int R, int T, int *obs, int *storedpos, int M, int *tupleSize, int **preposition){
    int j, s, e, pos;

    /* first block [0, R), clipped so short sequences are not read past T */
    e = (R < T) ? R : T;
    storedpos[0] = position(obs, 0, e, tupleSize, preposition);
    pos = position(obs, 1, e, tupleSize, preposition);
    storedpos[1] = pos;
    for (j = 2; j < e; j++){
        /* Drop the leading symbol: with the flag != 0 layout,
         * index(obs[j..e)) == index(obs[j-1..e)) / M - 1, an O(1) update
         * instead of an O(R) call to position(). */
        pos = pos/M - 1;
        storedpos[j] = pos;
    }
    if (e == T)
        return; /* the whole sequence fits into the first block */

    s = R;
    e = s + R;
    if (e > T)
        e = T; /* R < T < 2R: the second block is already the last one */
    while (1){
        pos = position(obs, s, e, tupleSize, preposition);
        storedpos[s] = pos;
        for (j = s+1; j < e; j++){
            pos = pos/M - 1; /* same O(1) suffix update as above */
            storedpos[j] = pos;
        }

        if (e == T)
            break;
        s += R;
        e += R;
        if (e > T)
            e = T;
    }
}

/* Like storeposition(), but for models with higher-order emissions: besides
 * the block positions it records the emission history of every observation.
 * R: compression length
 * M: alphabet size
 * order: maximal emission order (maximal history length)
 * T: length of obs
 * obs: observation sequence
 * storedpos: output (at least T+1 entries); for a block [s, e), storedpos[j]
 *            is the index of the suffix (obs[j], ..., obs[e-1]);
 *            storedpos[0] holds the whole first block
 * storedfpos: output (at least T+1 entries); storedfpos[j] is the index of
 *             the history (obs[max(0, j-order)], ..., obs[j-1]), and
 *             storedfpos[0] == 0 (empty history)
 * tupleSize, preposition: tables from precomputeposition(R, ..., flag != 0)
 * tupleSizeH, prepositionH: tables from precomputeposition(order, ..., flag == 0)
 * Blocks are [0, R), [R, 2R), ... with the last one clipped to T. */
void storepositionH(int R, int M, int order, int T, int *obs,
        int *storedpos, int *storedfpos, int *tupleSize, int **preposition,
        int *tupleSizeH, int **prepositionH){
    int j, s, e, pos;

    /* first block [0, R), clipped so short sequences are not read past T */
    e = (R < T) ? R : T;
    storedpos[0] = position(obs, 0, e, tupleSize, preposition);
    storedfpos[0] = 0;

    pos = position(obs, 1, e, tupleSize, preposition);
    storedpos[1] = pos;
    storedfpos[1] = position(obs, (1 > order) ? 1 - order : 0, 1, tupleSizeH, prepositionH);
    for (j = 2; j < e; j++){
        /* suffix starting one symbol later: O(1) update that is valid for
         * the flag != 0 tupleSize layout */
        pos = pos/M - 1;
        storedpos[j] = pos;
        storedfpos[j] = position(obs, (j > order) ? j - order : 0, j, tupleSizeH, prepositionH);
    }
    if (e == T)
        return; /* the whole sequence fits into the first block */

    s = R;
    e = s + R;
    if (e > T)
        e = T; /* R < T < 2R: the second block is already the last one */
    while (1){
        pos = position(obs, s, e, tupleSize, preposition);
        storedpos[s] = pos;
        storedfpos[s] = position(obs, (s > order) ? s - order : 0, s, tupleSizeH, prepositionH);
        for (j = s+1; j < e; j++){
            pos = pos/M - 1;
            storedpos[j] = pos;
            /* The history start is recomputed for every j. Incrementing a
             * start that had been clipped to 0 would drop obs[0] from the
             * history while j - order is still negative (order > s). */
            storedfpos[j] = position(obs, (j > order) ? j - order : 0, j, tupleSizeH, prepositionH);
        }

        if (e == T)
            break;
        s += R;
        e += R;
        if (e > T)
            e = T;
    }
}

//=====================================================================
//=================precomputation matrix ==============================
//=====================================================================
/* Builds the matrices for every observation tuple of length 1..R
 * (first-order emissions).
 * R: compression length
 * mo: model
 * mats: output; mats[pos] is the N x N matrix M^B(X) of the tuple with index
 *       pos (see position()). For a single symbol k,
 *       mats[k][i][j] = a_ij * b_j(k); longer tuples are the product of the
 *       matrix of their first symbol and the matrix of the remaining suffix.
 * rmats: output; rmats[pos][j][k][i] is the running sum over intermediate
 *        states i' <= i of mats[first][j][i'] * mats[rest][k'...] terms, i.e.
 *        the unnormalized CDF of the state after the first symbol, given
 *        state j before the tuple and state k at its end. */
void precompute(int R, ghmm_dmodel *mo, double ***mats, double**** rmats){
  const int nstates = mo->N;
  const int nsymbols = mo->M;
  int limit = (pow(mo->M, R+1)-1)/(mo->M-1) -1;
  int from, to, sym, mid, pos;

  /* tuples of length one */
  for (from = 0; from < nstates; from++)
    for (to = 0; to < nstates; to++)
      for (sym = 0; sym < nsymbols; sym++)
        mats[sym][from][to] = ghmm_dmodel_get_transition(mo, from, to) * mo->s[to].b[sym];

  /* Longer tuples, in index order: index pos splits into its first symbol
   * (pos % M) and the index of the remaining suffix (pos / M - 1), both of
   * which are smaller than pos and therefore already available. */
  for (pos = nsymbols; pos < limit; pos++) {
    const int first = pos % nsymbols;
    const int rest = pos / nsymbols - 1;

    for (from = 0; from < nstates; from++) {
      for (to = 0; to < nstates; to++) {
        double *cdf = rmats[pos][from][to];
        double acc = mats[first][from][0] * mats[rest][0][to];

        cdf[0] = acc;
        for (mid = 1; mid < nstates; mid++) {
          acc += mats[first][from][mid] * mats[rest][mid][to];
          cdf[mid] = acc;
        }
        mats[pos][from][to] = acc;
      }
    }
  }
}

//-----------------------HO-----------------------------------------

/* Lazily computes mats[pos][fpos] for the tuple obs[a..b-1] under emission
 * history fpos, together with its cumulative sampling table rmats[pos][fpos].
 * The suffix obs[a+1..b-1] is handled first by recursion; mflag[pos][fpos]
 * marks entries that are already available, so each (tuple, history) pair is
 * computed at most once until mflag is cleared.
 * pos, fpos: storedpos[a] and storedfpos[a] of the tuple
 * a, b: the tuple covers observation indices [a, b)
 * R, totalobs: only passed on to the recursive call
 * obs: observation sequence
 * mflag: memo flags, indexed like mats
 * mats, rmats: matrices and cumulative tables indexed [pos][fpos]
 * storedpos, storedfpos: positions from storepositionH()
 * mo: model; only mo->N is read here */
void recursivemats(int pos, int fpos, int a, int b, int R,
                int totalobs, int* obs, int **mflag,
                double ****mats, double *****rmats,
                int* storedpos, int *storedfpos, ghmm_dmodel *mo){
    int suffix, suffixhist, from, to, mid;
    double **head, **tail, **out;
    double ***cdfs;
    double acc;

    /* single symbols are base cases; finished entries are memoized */
    if (a == b - 1 || mflag[pos][fpos])
        return;

    suffix = storedpos[a+1];
    suffixhist = storedfpos[a+1];
    recursivemats(suffix, suffixhist, a+1, b, R, totalobs, obs, mflag, mats, rmats, storedpos, storedfpos, mo);

    /* matrix of the first symbol under this history, times the suffix matrix */
    head = mats[obs[a]][fpos];
    tail = mats[suffix][suffixhist];
    out = mats[pos][fpos];
    cdfs = rmats[pos][fpos];
    for (from = 0; from < mo->N; ++from){
        for (to = 0; to < mo->N; ++to){
            acc = head[from][0] * tail[0][to];
            cdfs[from][to][0] = acc;
            for (mid = 1; mid < mo->N; ++mid){
                acc += head[from][mid] * tail[mid][to];
                cdfs[from][to][mid] = acc;
            }
            out[from][to] = acc;
        }
    }
    mflag[pos][fpos] = 1;
}


/* Makes sure mats/rmats hold the matrix of every block of obs, computing
 * missing ones through recursivemats().
 * R: compression length
 * totalobs: length of obs
 * obs: observation sequence
 * mflag: memo flags for recursivemats()
 * mats, rmats: matrices and cumulative tables indexed [pos][fpos]
 * storedpos, storedfpos: positions from storepositionH()
 * mo: model
 * Blocks are [0, R), [R, 2R), ..., each clipped to totalobs; clipping the
 * first block keeps sequences shorter than R from reading positions that
 * storepositionH() never wrote. */
void lazyrecmats(int R, int totalobs, int *obs, int **mflag,
                double ****mats, double *****rmats, 
                int *storedpos, int *storedfpos, ghmm_dmodel* mo
                ){
    int start = 0;
    int stop = (R < totalobs) ? R : totalobs;

    do {
        recursivemats(storedpos[start], storedfpos[start], start, stop, R,
                      totalobs, obs, mflag, mats, rmats, storedpos, storedfpos, mo);
        start += R;
        stop = (stop + R > totalobs) ? totalobs : stop + R;
    } while (start < totalobs);
}



/* Prepares mats/rmats for one sequence of a model with higher-order emissions.
 * Steps performed below:
 *  1. clears mflag for every tuple index, then marks every single-symbol
 *     entry (index < M) as available for all histories;
 *  2. fills mats[sym][fpos] for every single symbol sym and every emission
 *     history fpos of length 0..maxorder (fpos enumerated in the order
 *     length 0, length 1, ..., as in the flag == 0 tupleSize layout).
 *     History length 0 uses pi_k * b_k(sym) for states of order 0 and 0
 *     otherwise; longer histories use a_jk * b_k[e] with
 *     e = get_emission_index(mo, k, sym, length), or 0 when that returns -1;
 *  3. calls lazyrecmats() to build the multi-symbol matrices needed by obs.
 * Side effect: overwrites mo->emission_history.
 * totalobs: length of obs
 * obs: observation sequence
 * R: compression length
 * mats, rmats: matrices and cumulative tables indexed [pos][fpos]
 * mflag: memo flags for lazyrecmats(), at least size x dsize entries
 * storedpos, storedfpos: positions from storepositionH()
 * mo: model with higher-order emissions */
void precomputedmatsH(int totalobs, int *obs, int R,
        double ****mats, double *****rmats,
        int **mflag,
        int *storedpos, int *storedfpos, ghmm_dmodel* mo){

    int i,j,k,l,m,n,e,pos,fpos;
    /* number of histories of length 0..maxorder */
    int dsize = (pow(mo->M, mo->maxorder+1)-1)/(mo->M-1);
    /* number of tuple indices covered by the mflag rows cleared below */
    int size = (pow(mo->M, R+1)-1)/(mo->M-1);
    /* tmp1 holds the emission_history value of every history of the current
     * length (in fpos order); tmp2 collects the values produced for the next
     * length; the two buffers are swapped after each length */
    int read[(int)pow(mo->M, mo->maxorder+1)];
    int write[(int)pow(mo->M, mo->maxorder+1)];
    int *tmp1 = read;
    int *tmp2 = write;
    int *tmp3;
    for(i=0; i<size; i++) 
        for(j=0;j<dsize;j++)
            mflag[i][j] = 0;
    /* single symbols are filled directly below, never by recursion */
    for(i=0;i<mo->M;i++)
        for(j=0;j<dsize;j++)
            mflag[i][j] = 1;

    fpos = 0;
    n = 0;
    for(i=0; i <= mo->maxorder;i++) /* i: history length */
    {
        for(l = 0; l<pow(mo->M, i); l++) /* l: history of length i; fpos advances with it */
        {
            for(pos = 0; pos < mo->M; pos++) /* pos: emitted symbol */
            { 
                
                for(j=0; j < mo->N; j++)
                {
                    for (k=0; k < mo->N; k++)
                    {
                        if(i == 0){
                            /* empty history: initial probabilities, only order-0
                             * states can emit; every row j is identical */
                            mo->emission_history = 0;
                            if(mo->order[k] ==0){
                                mats[pos][fpos][j][k] =
                                    mo->s[k].pi*mo->s[k].b[pos];
                            }
                            else{
                                mats[pos][fpos][j][k] = 0;
                            } 
                        }
                        else{
                            mo->emission_history = tmp1[l];
                            e = get_emission_index(mo, k, pos, i);
                            /* -1: state k cannot emit pos with this history */
                            if(e == -1)
                                mats[pos][fpos][j][k] = 0;
                            else{
                                mats[pos][fpos][j][k] =
                                    ghmm_dmodel_get_transition(mo, j,k)*mo->s[k].b[e];
                            }
                        }
                    }
                }
                /* presumably extends the current history by symbol pos;
                 * the result becomes a history of length i+1 -- TODO confirm
                 * against update_emission_history() */
                update_emission_history(mo, pos);
                tmp2[n] = mo->emission_history;                    
                n++;
            }
            fpos++;
        }
        tmp3 = tmp2;
        tmp2 = tmp1; 
        tmp1 = tmp3;
        n=0;
    }
    lazyrecmats(R, totalobs, obs, mflag, mats, rmats, storedpos, storedfpos, mo);
}


//===================================================================
//============= compressed sample state path ========================
//===================================================================
/* Samples an index from a discrete distribution given by its cumulative sums.
 * seed: seed forwarded to ighmm_rand_uniform_cont()
 * distribution: non-decreasing cumulative weights; distribution[N-1] is the
 *               total mass (it does not have to be normalized)
 * N: number of entries, must be >= 1
 * Draws u = ighmm_rand_uniform_cont(seed, total, 0) and returns the smallest
 * index i with distribution[i] > u, or N-1 if there is none. For u in
 * [0, total) this is exact inverse-CDF sampling. Unlike an "equal means
 * found" search, it can never return an entry whose own probability is zero,
 * since such an entry has the same cumulative value as its predecessor. */
int samplebinsearch(int seed, double *distribution, int N){
    double total = distribution[N-1];
    double rn = ighmm_rand_uniform_cont(seed, total, 0);
    int lo = 0;
    int hi = N - 1;

    /* invariant: the answer lies in [lo, hi] */
    while (lo < hi) {
        int mid = lo + (hi - lo) / 2;
        if (distribution[mid] > rn)
            hi = mid;      /* mid qualifies; look for an earlier one */
        else
            lo = mid + 1;  /* mass up to mid does not exceed rn */
    }
    return lo;
}
/* Samples a state path backwards from the compressed forward variables
 * (first-order emissions).
 * T: length of observation sequence
 * obs: observation sequence (not read in this function)
 * fwds: forward variables from cforwards(); fwds[p] is the last column
 * R: length of compression
 * mats: block matrices (not read in this function)
 * rmats: cumulative tables from precompute(); rmats[pos][j][k] is the CDF
 *        used to draw the state reached after the first symbol of tuple
 *        pos, given state j before the tuple and state k at its end
 * states: output, the T sampled states
 * storedpos: tuple positions from storeposition()
 * sneak: cumulative tables from cforwards() (delta in Pavel's paper);
 *        sneak[i][k] is the CDF used to draw the state just before block i
 *        given state k at the end of block i
 * N: number of states */
void csamplestatepath(int T, int *obs,
        double **fwds, int R,
        double ***mats, double ****rmats,
        int *states, int* storedpos, double ***sneak, int N){
    double *distribution;
    int pos, cs, js, je;
    int p, s, e;
    /* md: length of the last block; p: index of its forward column */
    int md = T%R ;
    
    if (md == 0){
        md = R;
        p = T/R;
    }
    else{
        p = T/R + 1;
    }
    /* the last block covers observations [s, e) */
    e = T;
    s = T - md;
    /* cumulative sums of the last forward column, consumed by sample() */
    double tmp[N];
    tmp[0] = fwds[p][0];
    int i;
    for(i = 1; i <N; i++)
       tmp[i] = tmp[i-1] + fwds[p][i];

    /* final state */
    states[e-1] = sample(0, tmp, N);
           
    /* walk the blocks from last to first */
    while (s>=0){
        p--;
        cs = states[e-1]; /* already sampled state at the end of this block */

        if(s>0){
            /* state s-1 precedes the block starting at s */
            js = s - 1;
            pos = storedpos[s];                       
            states[js] = samplebinsearch(0, sneak[p+1][cs], N);
        }
        else{
            /* first block: state 0 precedes the tuple obs[1..e-1] stored at index 1 */
            js = 0 ;
            pos = storedpos[1] ;
            states[js] = samplebinsearch(0, sneak[p+1][cs], N);
        }
        
        /* je == e - 2: draw states js+1 .. e-2 one at a time, each
         * conditioned on its predecessor and on cs */
        je = md + s -2;
        
        for (;js<je;js++){
            distribution = rmats[pos][states[js]][cs];
            states[js+1] = samplebinsearch(0, distribution, N);
            /* tuple starting one symbol later */
            pos = storedpos[js+2];
        }

        /* every earlier block has full length R */
        md = R;
        s -= md;
        e = s + md;
    }
}

/* Samples a state path backwards from the compressed forward variables of a
 * model with higher-order emissions.
 * T: length of observation sequence
 * obs: observation sequence (not read in this function)
 * fwds: forward variables from cforwardsH(); fwds[p] is the last column
 * R: length of compression
 * N: number of states
 * mats: block matrices (not read in this function)
 * rmats: cumulative tables indexed [pos][fpos]; rmats[pos][fpos][j][k] is the
 *        CDF used to draw the state reached after the first symbol of tuple
 *        pos (history fpos), given state j before it and state k at its end
 * states: output, the T sampled states
 * storedpos, storedfpos: positions from storepositionH()
 * sneak: cumulative tables from cforwardsH(); sneak[i][k] is the CDF used to
 *        draw the state just before block i given state k at its end */
void csamplestatepathH(int T, int *obs,
        double **fwds, int R, int N, 
        double ****mats, double *****rmats,
        int *states, int *storedpos, int *storedfpos,double ***sneak){

    int j, s, e, p, md;
    int pos, cs, js, je, fpos, si;
    double *distribution;

    /* md: length of the last block; p: index of its forward column */
    md = T%R;
    if (md == 0){
        md = R;
        p = T/R;
    }
    else{
        p = T/R + 1;
    }
    /* the last block covers observations [s, e); je == e - 2 */
    e = T;
    s = T - md;
    je = T - 2;
 
    /* cumulative sums of the last forward column, consumed by sample() */
    double tmp[N];
    tmp[0] = fwds[p][0];
    int i;
    for(i = 1; i <N; i++)
       tmp[i] = tmp[i-1] + fwds[p][i];

   
    /* final state */
    states[e-1] = sample(0, tmp, N);

    /* walk the blocks from last to first */
    while ( s >= 0 ){
        p-- ;
        cs = states[e-1]; /* already sampled state at the end of this block */
        if(s>0)
        {
            /* state s-1 precedes the block starting at s */
            js = s - 1 ;
            pos = storedpos[s];
            fpos = storedfpos[s];
            states[js] = samplebinsearch(0, sneak[p+1][cs], N);
        }
        else
        {
            /* first block: state 0 precedes the tuple stored at index 1 */
            js = 0;
            pos = storedpos[1];
            fpos = storedfpos[1];
            states[js] = samplebinsearch(0, sneak[p+1][cs], N);
        }        

        /* draw states js+1 .. e-2 one at a time, each conditioned on its
         * predecessor and on cs */
        for (;js<je;js++){
            distribution = rmats[pos][fpos][states[js]][cs];
            states[js+1] = samplebinsearch(0, distribution, N);
            /* tuple (and its history) starting one symbol later */
            pos = storedpos[js+2];
            fpos = storedfpos[js+2];
        }
        /* every earlier block has full length R */
        s -= R;
        e = s + R;
        je = R + s - 2;
    }
}


//===================================================================
//==================compressed forwards==============================
//===================================================================
/* Compressed forward pass (first-order emissions); also produces the
 * cumulative tables used for backward sampling.
 * totalobs: length of observation sequence
 * obs: observation sequence (only obs[0] is read directly)
 * mo: model
 * R: compression length
 * fwds: output; fwds[0] is the forward vector after obs[0], fwds[1] after the
 *       first block obs[1..R-1] and fwds[i] (i >= 2) after the block starting
 *       at (i-1)*R. Each column is normalized by its own sum when that sum
 *       exceeds GHMM_EPS_PREC.
 * mats: block matrices from precompute()
 * storedpos: positions from storeposition()
 * sneak: output; sneak[i][j][k] is the running sum over k' <= k of
 *        fwds[i-1][k'] * M[k'][j], M being the matrix of block i */
void cforwards(int totalobs, int* obs, ghmm_dmodel *mo, int R, double **fwds, 
               double ***mats, int *storedpos, double ***sneak){
#define CUR_PROC "cforwards"
    int i, j, k, s;
    double sum, tv;
    double **block;

    /* column 0: initial distribution times the emission of obs[0] */
    sum = 0;
    for (j = 0; j < mo->N; j++){
        fwds[0][j] = mo->s[j].pi*mo->s[j].b[obs[0]];
        sum += fwds[0][j];
    }
    if (sum > GHMM_EPS_PREC){
        for (j = 0; j < mo->N; j++)
            fwds[0][j] /= sum;
    }

    for (i = 1; ; i++){
        /* Block 1 (obs[1..R-1]) is always processed; block i >= 2 starts at
         * (i-1)*R and exists only while that start lies inside the sequence. */
        s = (i == 1) ? 1 : (i - 1) * R;
        if (i > 1 && s >= totalobs)
            break;
        block = mats[storedpos[s]];

        sum = 0; /* each column is normalized by its own mass only */
        for (j = 0; j < mo->N; j++){
            tv = fwds[i-1][0]*block[0][j];
            sneak[i][j][0] = tv;
            for (k = 1; k < mo->N; k++){
                tv += fwds[i-1][k]*block[k][j];
                sneak[i][j][k] = tv;
            }
            fwds[i][j] = tv;
            sum += tv;
        }
        if (sum > GHMM_EPS_PREC){
            for (j = 0; j < mo->N; j++)
                fwds[i][j] /= sum;
        }
    }
#undef CUR_PROC
}

/* Compressed forward pass for models with higher-order emissions; also
 * produces the cumulative tables used for backward sampling.
 * T: length of observation sequence
 * obs: observation sequence (only obs[0] is read directly)
 * mo: model
 * R: compression length
 * fwds: output; fwds[0] is the forward vector after obs[0], fwds[1] after the
 *       first block obs[1..R-1] and fwds[i] (i >= 2) after the block starting
 *       at (i-1)*R. Each column is normalized by its own sum when that sum
 *       exceeds GHMM_EPS_PREC.
 * mats: block matrices indexed [pos][fpos] (see precomputedmatsH)
 * storedpos, storedfpos: positions from storepositionH()
 * sneak: output; sneak[i][j][k] is the running sum over k' <= k of
 *        fwds[i-1][k'] * M[k'][j], M being the matrix of block i */
void cforwardsH(int T, int *obs, ghmm_dmodel *mo,
        int R, double **fwds,
        double ****mats, int *storedpos, int *storedfpos, double ***sneak
){
    int i, j, k, s;
    double sum, tv;
    double **block;

    /* column 0: only states of order 0 may emit the first symbol */
    sum = 0;
    for (j = 0; j < mo->N; j++){
        if (mo->order[j] == 0)
            fwds[0][j] = mo->s[j].pi*mo->s[j].b[obs[0]];
        else
            fwds[0][j] = 0;
        sum += fwds[0][j];
    }
    if (sum > GHMM_EPS_PREC){
        for (j = 0; j < mo->N; j++)
            fwds[0][j] /= sum;
    }

    for (i = 1; ; i++){
        /* Block 1 (obs[1..R-1]) is always processed; block i >= 2 starts at
         * (i-1)*R and exists only while that start lies inside the sequence.
         * Testing the block start (instead of waiting for the block end to
         * reach T) keeps R < T < 2R from processing one block too many. */
        s = (i == 1) ? 1 : (i - 1) * R;
        if (i > 1 && s >= T)
            break;
        block = mats[storedpos[s]][storedfpos[s]];

        sum = 0;
        for (j = 0; j < mo->N; j++){
            tv = fwds[i-1][0]*block[0][j];
            sneak[i][j][0] = tv;
            for (k = 1; k < mo->N; k++){
                tv += fwds[i-1][k]*block[k][j];
                sneak[i][j][k] = tv;
            }
            fwds[i][j] = tv;
            sum += tv;
        }
        if (sum > GHMM_EPS_PREC){
            for (j = 0; j < mo->N; j++)
                fwds[i][j] /= sum;
        }
    }
}
//===================================================================
//===================== compressed gibbs ============================
//===================================================================
/* Runs one compressed forward-filtering / backward-sampling sweep over a
 * single sequence of a model with first-order emissions.
 * mo: model
 * obs: observation sequence
 * totalobs: length of obs
 * pA, pB, pPi: prior counts (not used by this step)
 * Q: output, the sampled state path (totalobs entries)
 * R: compression length
 * fwds, sneak: workspace filled by cforwards()
 * mats, rmats: workspace filled by precompute()
 * storedpos: positions from storeposition() for this sequence */
void ghmm_dmodel_cfbgibbstep(ghmm_dmodel *mo, int *obs, int totalobs,
        double **pA, double **pB, double *pPi, int* Q, int R, double**fwds,
        double ***sneak, double ***mats, double ****rmats, int *storedpos){
    /* 1. block matrices for the current parameters */
    precompute(R, mo, mats, rmats);

    /* 2. forward pass over the blocks */
    cforwards(totalobs, obs, mo, R, fwds, mats, storedpos, sneak);

    /* 3. sample the state path backwards into Q */
    csamplestatepath(totalobs, obs, fwds, R, mats, rmats, Q, storedpos, sneak, mo->N);
}
/* Runs compressed forward-backward Gibbs sampling for burnIn sweeps.
 * mo: model; its parameters are updated after every sweep
 * seq: observation sequences
 * pA: prior counts for A
 * pB: prior counts for B
 * pPi: prior counts for pi
 * R: length of compression
 * burnIn: number of sweeps
 * seed: seed for the random number generator
 * Returns the state paths from the last sweep (one array per sequence,
 * allocated here), or NULL if an allocation fails or the library was built
 * without GSL. */
int** ghmm_dmodel_cfbgibbs(ghmm_dmodel* mo, ghmm_dseq* seq, double **pA, double **pB, double *pPi, int R, int burnIn, int seed){
#ifdef DO_WITH_GSL
#define CUR_PROC "ghmm_dmodel_cfbgibbs"
    int **Q = NULL;
    double **transitions, **obsinstatealpha;
    double *obsinstate;
    int i;
    int len = 0;

    GHMM_RNG_SET (RNG, seed);
    ARRAY_CALLOC (Q ,seq->seq_number);
    for(i = 0; i < seq->seq_number; i++){
        ARRAY_CALLOC (Q[i] ,seq->seq_len[i]);
        if(len < seq->seq_len[i])
            len = seq->seq_len[i];
    }
    if(mo->model_type & GHMM_kHigherOrderEmissions){//higher order
        //forwards
        int shtsize = len/R+2;
        double **fwds = ighmm_cmatrix_alloc(shtsize, mo->N);
        double ***sneak = ighmm_cmatrix_3d_alloc(shtsize, mo->N, mo->N);
        double ****mats;
        double *****rmats;
        int j;
        int limit = (pow(mo->M, R+1)-1)/(mo->M-1);
        int d = (pow(mo->M, mo->maxorder+1)-1)/(mo->M-1);
        ARRAY_MALLOC (mats, limit+1);
        ARRAY_MALLOC (rmats, limit+1);
        int **mflag = ighmm_dmatrix_alloc(limit, d);
        for(i = 0; i < limit+1; i++){
            mats[i] = ighmm_cmatrix_3d_alloc(d, mo->N, mo->N);
            /* sized by the element type (double ***) and checked */
            ARRAY_MALLOC (rmats[i], d);
            for(j = 0; j < d; j++)
               rmats[i][j] = ighmm_cmatrix_3d_alloc(mo->N, mo->N, mo->N);
        }
        //positions
        int tupleSize[R+1];
        int tupleSizeH[mo->maxorder+1];
        int **preposition = ighmm_dmatrix_alloc(R, mo->M);
        int **prepositionH = ighmm_dmatrix_alloc(mo->maxorder, mo->M);
        /* one row of positions per sequence; on the heap because
         * seq_number * (len+1) ints can exceed the stack for long sequences */
        int **storedpos = ighmm_dmatrix_alloc(seq->seq_number, len+1);
        int **storedfpos = ighmm_dmatrix_alloc(seq->seq_number, len+1);
        if (!storedpos || !storedfpos)
            goto STOP;

        //precompute and store positions for matrices corresponding to observations
        precomputeposition(R, mo->M, tupleSize, preposition, 1);
        precomputeposition(mo->maxorder, mo->M, tupleSizeH, prepositionH, 0);
        for(i=0;i<seq->seq_number;i++){
            storepositionH(R, mo->M, mo->maxorder, seq->seq_len[i], seq->seq[i], storedpos[i],
                    storedfpos[i], tupleSize, preposition, tupleSizeH, prepositionH);
        }

        //counts
        allocCountsH(mo, &transitions, &obsinstate, &obsinstatealpha);

        for(;burnIn > 0; burnIn--){
            if(burnIn%100==0)
                printf("iter %d\n", burnIn);
            initCountsH(mo, transitions, obsinstate, obsinstatealpha, pA, pB, pPi);
            for(i = 0; i < seq->seq_number; i++){
                /* NOTE(review): counts are taken before resampling here, while
                 * the first-order branch samples first -- confirm intended */
                getCountsH(mo, Q[i], seq->seq[i], seq->seq_len[i], transitions, obsinstate, obsinstatealpha);
                /* every sequence uses its own row of positions */
                ghmm_dmodel_cfbgibbstepH(mo, seq->seq[i], seq->seq_len[i], pA, pB, pPi, Q[i], R,
                    fwds, sneak, mats, rmats, mflag, storedpos[i], storedfpos[i]);
            }
            updateH(mo, transitions, obsinstate, obsinstatealpha);
        }
        //clean up
        freeCountsH(mo, &transitions, &obsinstate, &obsinstatealpha);
        ighmm_cmatrix_3d_free(&sneak, shtsize, mo->N);
        ighmm_dmatrix_free(&preposition, R);
        ighmm_dmatrix_free(&prepositionH, mo->maxorder);
        ighmm_dmatrix_free(&storedpos, seq->seq_number);
        ighmm_dmatrix_free(&storedfpos, seq->seq_number);
        ighmm_dmatrix_free(&mflag, limit);
        ighmm_cmatrix_free(&fwds, shtsize);
        for( i = 0 ; i < limit+1; i++){
           ighmm_cmatrix_3d_free(&mats[i],d, mo->N);
           for(j =0; j < d; j++)
              ighmm_cmatrix_3d_free(&(rmats[i][j]), mo->N, mo->N);
           m_free(rmats[i]);
        }
        m_free(rmats);
        m_free(mats);
    }
    else{//not higher order
        //forwards
        int shtsize = len/R+2;
        double **fwds = ighmm_cmatrix_alloc(shtsize, mo->N);
        double ***sneak = ighmm_cmatrix_3d_alloc(shtsize, mo->N, mo->N);
        double ***mats;
        double ****rmats;
        int limit = (pow(mo->M, R+1)-1)/(mo->M-1) -1;
        mats = ighmm_cmatrix_3d_alloc(limit, mo->N, mo->N);
        ARRAY_MALLOC (rmats, limit);
        for(i = 0; i < limit; i++)
            rmats[i] = ighmm_cmatrix_3d_alloc(mo->N, mo->N, mo->N);
        int tupleSize[R+1];
        int **preposition = ighmm_dmatrix_alloc(R, mo->M);
        /* one row of positions per sequence, on the heap (see above) */
        int **storedpos = ighmm_dmatrix_alloc(seq->seq_number, len+1);
        if (!storedpos)
            goto STOP;

        //position
        precomputeposition(R, mo->M, tupleSize, preposition, 1);
        for(i = 0; i < seq->seq_number; i++){
            storeposition(R, seq->seq_len[i], seq->seq[i], storedpos[i], mo->M, tupleSize, preposition);
        }
        //counts
        allocCounts(mo, &transitions, &obsinstate, &obsinstatealpha);
        for(;burnIn > 0; burnIn--){
            if(burnIn % 100==0)
                printf("iter %d\n", burnIn);
            initCounts(mo, transitions, obsinstate, obsinstatealpha, pA, pB, pPi);
            for(i = 0; i < seq->seq_number;i++){
                ghmm_dmodel_cfbgibbstep(mo, seq->seq[i], seq->seq_len[i], pA, pB, pPi, Q[i], R,
                      fwds, sneak, mats, rmats, storedpos[i]);
                getCounts(Q[i], seq->seq[i], seq->seq_len[i], transitions, obsinstate, obsinstatealpha);
            }
            update(mo, transitions, obsinstate, obsinstatealpha);
        }
        //clean up
        freeCounts(mo, &transitions, &obsinstate, &obsinstatealpha);
        ighmm_cmatrix_3d_free(&sneak, shtsize, mo->N);
        ighmm_dmatrix_free(&preposition, R);
        ighmm_dmatrix_free(&storedpos, seq->seq_number);
        ighmm_cmatrix_free(&fwds, shtsize);
        ighmm_cmatrix_3d_free(&mats,limit, mo->N);
        for( i = 0 ; i < limit; i++)
           ighmm_cmatrix_3d_free(&(rmats[i]), mo->N, mo->N);
        m_free(rmats);
    }
    return Q;
STOP:
    /* release the partially built result; rows not allocated yet are NULL
     * because Q comes from ARRAY_CALLOC */
    if (Q) {
        for (i = 0; i < seq->seq_number; i++) {
            if (Q[i]) {
                m_free(Q[i]);
            }
        }
        m_free(Q);
    }
    return NULL;
#undef CUR_PROC
#else
   printf("cfbgibbs uses gsl for dirichlete distrubutions, compile with gsl\n");
   return NULL;
#endif
}


/* Runs one compressed forward-filtering / backward-sampling sweep over a
 * single sequence of a model with higher-order emissions.
 * mo: model
 * obs: observation sequence
 * totalobs: length of obs
 * pA, pB, pPi: prior counts (not used by this step)
 * Q: output, the sampled state path (totalobs entries)
 * R: compression length
 * fwds, sneak: workspace filled by cforwardsH()
 * mats, rmats, mflag: workspace filled by precomputedmatsH()
 * storedpos, storedfpos: positions from storepositionH() for this sequence */
void ghmm_dmodel_cfbgibbstepH(ghmm_dmodel *mo, int *obs, int totalobs,
        double **pA, double **pB, double *pPi, int* Q, int R,
        double**fwds, double ***sneak, double ****mats, double *****rmats, int **mflag, 
        int *storedpos, int *storedfpos){
    /* 1. block matrices needed by this sequence, built lazily */
    precomputedmatsH(totalobs, obs, R, mats, rmats, mflag, storedpos, storedfpos, mo);

    /* 2. forward pass over the blocks */
    cforwardsH(totalobs, obs, mo, R, fwds, mats, storedpos, storedfpos, sneak);

    /* 3. sample the state path backwards into Q */
    csamplestatepathH(totalobs, obs, fwds, R, mo->N, mats, rmats, Q, storedpos, storedfpos, sneak);
}
