Skip to content

Commit 1656f7d

Browse files
THargreavesgdalle
andauthored
Decompose forward function into initialize, predict, update (#105)
* Decompose forward function into initialize, predict, update * Fixes * Remove unused predict --------- Co-authored-by: Guillaume Dalle <[email protected]>
1 parent 38eb996 commit 1656f7d

File tree

4 files changed

+51
-30
lines changed

4 files changed

+51
-30
lines changed

src/HiddenMarkovModels.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ include("utils/lightdiagnormal.jl")
3838
include("utils/lightcategorical.jl")
3939
include("utils/limits.jl")
4040

41+
include("inference/predict.jl")
4142
include("inference/forward.jl")
4243
include("inference/viterbi.jl")
4344
include("inference/forward_backward.jl")

src/inference/forward.jl

Lines changed: 34 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,27 @@ function initialize_forward(
6363
return ForwardStorage(α, logL, B, c)
6464
end
6565

66+
function _forward_digest_observation!(
67+
current_state_marginals::AbstractVector{<:Real},
68+
current_obs_likelihoods::AbstractVector{<:Real},
69+
hmm::AbstractHMM,
70+
obs,
71+
control,
72+
)
73+
a, b = current_state_marginals, current_obs_likelihoods
74+
75+
obs_logdensities!(b, hmm, obs, control)
76+
logm = maximum(b)
77+
b .= exp.(b .- logm)
78+
79+
a .*= b
80+
c = inv(sum(a))
81+
lmul!(c, a)
82+
83+
logL = -log(c) + logm
84+
return c, logL
85+
end
86+
6687
function _forward!(
6788
storage::ForwardOrForwardBackwardStorage,
6889
hmm::AbstractHMM,
@@ -73,36 +94,19 @@ function _forward!(
7394
)
7495
(; α, B, c, logL) = storage
7596
t1, t2 = seq_limits(seq_ends, k)
76-
77-
# Initialization
78-
Bₜ₁ = view(B, :, t1)
79-
obs_logdensities!(Bₜ₁, hmm, obs_seq[t1], control_seq[t1])
80-
logm = maximum(Bₜ₁)
81-
Bₜ₁ .= exp.(Bₜ₁ .- logm)
82-
83-
init = initialization(hmm)
84-
αₜ₁ = view(α, :, t1)
85-
αₜ₁ .= init .* Bₜ₁
86-
c[t1] = inv(sum(αₜ₁))
87-
lmul!(c[t1], αₜ₁)
88-
89-
logL[k] = -log(c[t1]) + logm
90-
91-
# Loop
92-
for t in t1:(t2 - 1)
93-
Bₜ₊₁ = view(B, :, t + 1)
94-
obs_logdensities!(Bₜ₊₁, hmm, obs_seq[t + 1], control_seq[t + 1])
95-
logm = maximum(Bₜ₊₁)
96-
Bₜ₊₁ .= exp.(Bₜ₊₁ .- logm)
97-
98-
trans = transition_matrix(hmm, control_seq[t])
99-
αₜ, αₜ₊₁ = view(α, :, t), view(α, :, t + 1)
100-
mul!(αₜ₊₁, transpose(trans), αₜ)
101-
αₜ₊₁ .*= Bₜ₊₁
102-
c[t + 1] = inv(sum(αₜ₊₁))
103-
lmul!(c[t + 1], αₜ₊₁)
104-
105-
logL[k] += -log(c[t + 1]) + logm
97+
logL[k] = zero(eltype(logL))
98+
for t in t1:t2
99+
αₜ = view(α, :, t)
100+
Bₜ = view(B, :, t)
101+
if t == t1
102+
copyto!(αₜ, initialization(hmm))
103+
else
104+
αₜ₋₁ = view(α, :, t - 1)
105+
predict_next_state!(αₜ, hmm, αₜ₋₁, control_seq[t - 1])
106+
end
107+
cₜ, logLₜ = _forward_digest_observation!(αₜ, Bₜ, hmm, obs_seq[t], control_seq[t])
108+
c[t] = cₜ
109+
logL[k] += logLₜ
106110
end
107111

108112
@argcheck isfinite(logL[k])

src/inference/predict.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
function predict_next_state!(
2+
next_state_marginals::AbstractVector{<:Real},
3+
hmm::AbstractHMM,
4+
current_state_marginals::AbstractVector{<:Real},
5+
control=nothing,
6+
)
7+
trans = transition_matrix(hmm, control)
8+
mul!(next_state_marginals, transpose(trans), current_state_marginals)
9+
return next_state_marginals
10+
end

src/types/abstract_hmm.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,9 @@ log_initialization(hmm::AbstractHMM) = elementwise_log(initialization(hmm))
7272
transition_matrix(hmm, control)
7373
7474
Return the matrix of state transition probabilities for `hmm` (possibly when `control` is applied).
75+
76+
!!! note
77+
When processing sequences, the control at time `t` influences the transition from time `t` to `t+1` (and not from time `t-1` to `t`).
7578
"""
7679
function transition_matrix end
7780

@@ -82,6 +85,9 @@ function transition_matrix end
8285
Return the matrix of state transition log-probabilities for `hmm` (possibly when `control` is applied).
8386
8487
Falls back on `transition_matrix`.
88+
89+
!!! note
90+
When processing sequences, the control at time `t` influences the transition from time `t` to `t+1` (and not from time `t-1` to `t`).
8591
"""
8692
function log_transition_matrix(hmm::AbstractHMM, control)
8793
return elementwise_log(transition_matrix(hmm, control))

0 commit comments

Comments
 (0)