0.1 Finding Model Parameters: The Baum-Welsh Training Algorithm

0.1.1 Setup

Now, for the last frontier in tackling HMMs: traininga model. Suppose that we are trying to figure out the probabilites of tranistioning dies in our dishonest casino example. Now, in the example we were looking at before, we knew two very handy things: both the tranition probabilities and the emission probabilities. Over the past sections, we learned how to deal with various types of missing data: finding the hidden state at a particular time, the optimal path through hidden states, etc. But all that relyed on having the TPM and the emission probabilites. What if we knew neither?

Lets make a new example. You have walked into a casino, and are standing, scotch in hand, watching the crooked croupier. You watch for 30 throws of the die, and observe the following sequence of throws:

new_pattern <- c(6,3,6,1,6,6,1,2,3,1,5,6,6,6,1,6,6,2,6,2,5,2,5,6,2,2,3,6,5,4)

You watch, sipping pensivly, wondering if this is indeed the crooked croupier that you hear about for all those years. If only you have a way to tell!

Well you do, thanks to the Baum-Welsh algorithm, which is a Expectation Maximization algorithm for HMMs. Here is what we shall do:

  1. Initiate: set your TPM and emission probabilities to some arbitraty (but legal) values
  2. Expectation: "decode" the observed symbol sequence (emmitted) given those parameters
  3. Maximization: Use this to update the parameters
  4. repeat until convergence.

0.1.2 Nomenclature

We are going to be refering to our parameters of the model as the Greek leter theta (θ). This means our TPM \(a\) and emission matrix \(e\). Don't let the shorthand get in the way.

How do we do step 2? How do we esimate param values based on current params and the sequence?

say aij is the estimated matrix. estimate it from P(ρi = k,ρi+1 = m| X,θ)

  • where θ is the current set of parameters

derive another formula (dont need to knwo) that has the fowraed backward probabilities

(f_k (i) akm e_m (xx+i) b_m (i+1)) / P(x)

a <- matrix(c(.5,.5,.5,.5),2)
findGamma <- function(x, f, b){
  #  returns a new A matrix
  res <- matrix(rep(0, (length(x) * nrow(f))), nrow(f))
  total_prob <- 0
  for(state1a in 1:nrow(f)){
    for (i in 1:(length(x))){
      total_prob <- total_prob + (f[state1a, x[i]]  * b[state1a, x[i]])
    }
  }
  for(state1b in 1:nrow(f)){
    for (j in 1:(length(x))){
      res[state1b, j] <- (f[state1b, x[j]]  * b[state1b, x[j]]  ) / total_prob
    }
  }
  return(res)
}
(diceGamma <- findGamma(x=pattern, f=forward_results, b=backward_results))
findXi <- function(x, a, e, f, b){
  reslist = list()
  total_prob_with_trans <- 0
  for (i in 1:(length(x)-1)){
    for (state1a in 1:nrow(f)){
      for (state2a in 1:nrow(f)){
        total_prob_with_trans <- total_prob_with_trans + (f[state1a, x[i]] * a[state1a, state2a] * e[state2a, x[i+1]] * b[state2a, x[i+1]])
      }
    }
  }
  for (j in 1:(length(x)-1)){
    res = matrix(rep(0, (nrow(a) * ncol(a))), ncol(a))
    for (state1b in 1:nrow(f)){
      for (state2b in 1:nrow(f)){
        temp_numerator <- f[state1b, x[j]] * a[state1b, state2b] * e[state2b, x[j+1]] * b[state2b, x[j+1]]
        res[state1b, state2b] <- temp_numerator / total_prob_with_trans
      }
    }
    reslist[[j]] <- res
  }
  return(reslist)
}
(diceXi <- findXi(x=pattern, a=a, e=dice_probs, f=forward_results, b=backward_results))
           [,1]       [,2]       [,3]       [,4]       [,5]      [,6]
[1,] 0.09738410 0.09199935 0.10084981 0.10951318 0.10084981 0.1150910
[2,] 0.06928257 0.07466732 0.06581685 0.05715348 0.06581685 0.0515757
[[1]]
             [,1]         [,2]
[1,] 0.0009284284 0.0003637924
[2,] 0.0003714959 0.0001455657

[[2]]
             [,1]        [,2]
[1,] 0.0009379149 0.002813745
[2,] 0.0011656155 0.003496847

[[3]]
             [,1]         [,2]
[1,] 1.480039e-07 2.365514e-07
[2,] 9.659067e-08 1.543788e-07

[[4]]
          [,1]      [,2]
[1,] 0.2069071 0.6207212
[2,] 0.0405369 0.1216107

[[5]]
             [,1]         [,2]
[1,] 2.540521e-08 2.276967e-08
[2,] 1.658001e-08 1.486000e-08

Right, so that was the expectation step. Now, for the maximization:

updateTPM = function(g, xi){
  #  returns a new A matrix
  nstates <- nrow(g)
  res = matrix(rep(0, (nstates * nstates)), nstates)
  norm_res = matrix(rep(0, (nstates * nstates)), nstates)
  for(state1 in 1:nstates){
    for(state2 in 1:nstates){
      temp_num <- 0
      temp_denom <- 0
      for( i in 1:length(xi)){   # should be T -1
            temp_num <- temp_num + xi[[i]][state1, state2]
            temp_denom <- temp_denom + g[state1, i]
      }
      res[state1, state2] <- temp_num / temp_denom
    }
  }
  # normalize
  print(res)
  for (j in 1:nrow(g)){
    norm_res[j,] <- res[j,] / sum(res[j,])
  }
  return(norm_res)
}
(aprime <- updateTPM(xi=diceXi, g=diceGamma))
          [,1]      [,2]
[1,] 0.4170498 1.2463118
[2,] 0.1264485 0.3764332
          [,1]      [,2]
[1,] 0.2507271 0.7492729
[2,] 0.2514479 0.7485521

And, now to iupdate our emmission probabilities:

updateEmissions = function(x, g, xi, olde){
  #  returns a new e matrix
  nstates <- nrow(g)
  nobs <- ncol(olde)
  res = matrix(rep(0, (nstates * nobs)), nstates)
  norm_res = matrix(rep(0, (nstates * nobs)), nstates)
  for(state1 in 1:nstates){
    total_state_prob <- sum(g[state1,])
    for(ob in 1:nobs){
      match_indexes <- x == ob  # bool vector of data that matches e
      prob_for_emission <- sum(g[state1, match_indexes])
      res[state1, ob] <- prob_for_emission / total_state_prob
    }
  }
  # print(res)
  #normalize
  for (j in 1:nrow(res)){
    norm_res[j,] <- res[j,] / sum(res[j,])
  }
  return(norm_res)
}
(eprime <- updateEmissions(x=pattern, xi=diceXi, g=diceGamma, olde=dice_probs))
          [,1]      [,2] [,3]      [,4]      [,5]      [,6]
[1,] 0.1869309 0.1778715    0 0.1581714 0.1494255 0.3276008
[2,] 0.1342024 0.1487161    0 0.1802765 0.1942879 0.3425171

Whew, that was not fun! Now, lets do a few rounds rounds of training on our data:

new_pattern
colnames(a) <- c("F", "L")
rownames(a) <- c("F", "L")
a
new_dice_probs <- matrix(c(1/6, 1/6, 1/6, 1/6, 1/6, 1/6,
                          1/10, 1/10, 1/10, 1/10,  2/5, 1/5), 2, byrow=T)
print("Iteration 1")
f1 <- forward(data = new_pattern, init_probs = c(.5, .5), TPM=a, emission=new_dice_probs)
b1 <- backward(data = new_pattern, TPM=a, emission=new_dice_probs)
g1 <- findGamma(x=new_pattern, f=f1, b=b1)
x1 <- findXi(x=new_pattern, a=a, e=new_dice_probs, f=f1, b=b1)
(a1 <- updateTPM(xi=x1, g=g1))
(e1 <- updateEmissions(x=new_pattern, xi=x1, g=g1, olde=new_dice_probs))
print("Iteration 2")
f2 <- forward(data = new_pattern, init_probs = g1[,1], TPM=a1, emission=e1)
b2 <- backward(data = new_pattern, TPM=a1, emission=e1)
g2 <- findGamma(x=new_pattern, f=f2, b=b2)
x2 <- findXi(x=new_pattern, a=a1, e=e1, f=f2, b=b2)
(a2 <- updateTPM(xi=x2, g=g2))
(e2 <- updateEmissions(x=new_pattern, xi=x2, g=g2, olde=e1))
 [1] 6 3 6 1 6 6 1 2 3 1 5 6 6 6 1 6 6 2 6 2 5 2 5 6 2 2 3 6 5 4
    F   L
F 0.5 0.5
L 0.5 0.5
[1] "Iteration 1"
          [,1]      [,2]
[1,] 0.4265419 0.5631482
[2,] 0.4664774 0.6110091
          [,1]      [,2]
[1,] 0.4309853 0.5690147
[2,] 0.4329311 0.5670689
          [,1]      [,2]       [,3]       [,4]      [,5]      [,6]
[1,] 0.1226054 0.2528736 0.09195402 0.04214559 0.1226054 0.3678161
[2,] 0.1438202 0.1483146 0.10786517 0.02471910 0.1438202 0.4314607
[1] "Iteration 2"
          [,1]      [,2]
[1,] 0.4062225 0.6280243
[2,] 0.4083338 0.6263021
          [,1]      [,2]
[1,] 0.3927713 0.6072287
[2,] 0.3946643 0.6053357
          [,1]      [,2]       [,3]       [,4]      [,5]      [,6]
[1,] 0.1332993 0.2000078 0.10000384 0.03333461 0.1333385 0.4000160
[2,] 0.1333554 0.1999949 0.09999751 0.03333250 0.1333300 0.3999896