Skip to content

Commit 71adcc9

Browse files
authored
Don't error if the loglikelihood isn't finite (#127)
* Don't error if the loglikelihood isn't finite * Fix * Fix
1 parent fc44a53 commit 71adcc9

File tree

7 files changed

+36
-13
lines changed

7 files changed

+36
-13
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "HiddenMarkovModels"
22
uuid = "84ca31d5-effc-45e0-bfda-5a68cd981f47"
33
authors = ["Guillaume Dalle"]
4-
version = "0.6.1"
4+
version = "0.6.2"
55

66
[deps]
77
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"

src/inference/forward.jl

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -68,11 +68,12 @@ function _forward_digest_observation!(
6868
current_obs_likelihoods::AbstractVector{<:Real},
6969
hmm::AbstractHMM,
7070
obs,
71-
control,
71+
control;
72+
error_if_not_finite::Bool,
7273
)
7374
a, b = current_state_marginals, current_obs_likelihoods
7475

75-
obs_logdensities!(b, hmm, obs, control)
76+
obs_logdensities!(b, hmm, obs, control; error_if_not_finite)
7677
logm = maximum(b)
7778
b .= exp.(b .- logm)
7879

@@ -90,7 +91,8 @@ function _forward!(
9091
obs_seq::AbstractVector,
9192
control_seq::AbstractVector,
9293
seq_ends::AbstractVectorOrNTuple{Int},
93-
k::Integer,
94+
k::Integer;
95+
error_if_not_finite::Bool,
9496
)
9597
(; α, B, c, logL) = storage
9698
t1, t2 = seq_limits(seq_ends, k)
@@ -104,12 +106,14 @@ function _forward!(
104106
αₜ₋₁ = view(α, :, t - 1)
105107
predict_next_state!(αₜ, hmm, αₜ₋₁, control_seq[t - 1])
106108
end
107-
cₜ, logLₜ = _forward_digest_observation!(αₜ, Bₜ, hmm, obs_seq[t], control_seq[t])
109+
cₜ, logLₜ = _forward_digest_observation!(
110+
αₜ, Bₜ, hmm, obs_seq[t], control_seq[t]; error_if_not_finite
111+
)
108112
c[t] = cₜ
109113
logL[k] += logLₜ
110114
end
111115

112-
@argcheck isfinite(logL[k])
116+
error_if_not_finite && @argcheck isfinite(logL[k])
113117
return nothing
114118
end
115119

@@ -122,14 +126,15 @@ function forward!(
122126
obs_seq::AbstractVector,
123127
control_seq::AbstractVector;
124128
seq_ends::AbstractVectorOrNTuple{Int},
129+
error_if_not_finite::Bool=true,
125130
)
126131
if seq_ends isa NTuple{1}
127132
for k in eachindex(seq_ends)
128-
_forward!(storage, hmm, obs_seq, control_seq, seq_ends, k)
133+
_forward!(storage, hmm, obs_seq, control_seq, seq_ends, k; error_if_not_finite)
129134
end
130135
else
131136
@threads for k in eachindex(seq_ends)
132-
_forward!(storage, hmm, obs_seq, control_seq, seq_ends, k)
137+
_forward!(storage, hmm, obs_seq, control_seq, seq_ends, k; error_if_not_finite)
133138
end
134139
end
135140
return nothing
@@ -147,8 +152,9 @@ function forward(
147152
obs_seq::AbstractVector,
148153
control_seq::AbstractVector=Fill(nothing, length(obs_seq));
149154
seq_ends::AbstractVectorOrNTuple{Int}=(length(obs_seq),),
155+
error_if_not_finite::Bool=true,
150156
)
151157
storage = initialize_forward(hmm, obs_seq, control_seq; seq_ends)
152-
forward!(storage, hmm, obs_seq, control_seq; seq_ends)
158+
forward!(storage, hmm, obs_seq, control_seq; seq_ends, error_if_not_finite)
153159
return storage.α, storage.logL
154160
end

src/inference/forward_backward.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ function _forward_backward!(
4242
t1, t2 = seq_limits(seq_ends, k)
4343

4444
# Forward (fill B, α, c and logL)
45-
_forward!(storage, hmm, obs_seq, control_seq, seq_ends, k)
45+
_forward!(storage, hmm, obs_seq, control_seq, seq_ends, k; error_if_not_finite=true)
4646

4747
# Backward
4848
β[:, t2] .= c[t2]

src/inference/logdensity.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ function DensityInterface.logdensityof(
99
control_seq::AbstractVector=Fill(nothing, length(obs_seq));
1010
seq_ends::AbstractVectorOrNTuple{Int}=(length(obs_seq),),
1111
)
12-
_, logL = forward(hmm, obs_seq, control_seq; seq_ends)
12+
_, logL = forward(hmm, obs_seq, control_seq; seq_ends, error_if_not_finite=false)
1313
return sum(logL)
1414
end
1515

src/types/abstract_hmm.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,13 +128,13 @@ StatsAPI.fit!
128128
## Fill logdensities
129129

130130
function obs_logdensities!(
131-
logb::AbstractVector{T}, hmm::AbstractHMM, obs, control
131+
logb::AbstractVector{T}, hmm::AbstractHMM, obs, control; error_if_not_finite::Bool=true
132132
) where {T}
133133
dists = obs_distributions(hmm, control)
134134
@simd for i in eachindex(logb, dists)
135135
logb[i] = logdensityof(dists[i], obs)
136136
end
137-
@argcheck maximum(logb) < typemax(T)
137+
error_if_not_finite && @argcheck maximum(logb) < typemax(T)
138138
return nothing
139139
end
140140

test/misc.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
using HiddenMarkovModels
2+
using HiddenMarkovModels: rand_prob_vec, rand_trans_mat
3+
using Distributions
4+
using Test
5+
6+
@testset "Allow NaN density" begin
7+
init = rand_prob_vec(2)
8+
trans = rand_trans_mat(2)
9+
dists = [Normal(i, Inf) for i in 1:2]
10+
hmm = HMM(init, trans, dists)
11+
obs_seq = rand(5)
12+
@test isnan(logdensityof(hmm, obs_seq))
13+
end

test/runtests.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,4 +53,8 @@ Pkg.develop(; path=joinpath(dirname(@__DIR__), "libs", "HMMTest"))
5353
@testset verbose = true "Correctness - $TEST_SUITE" begin
5454
include("correctness.jl")
5555
end
56+
57+
@testset verbose = true "Miscellaneous" begin
58+
include("misc.jl")
59+
end
5660
end

0 commit comments

Comments
 (0)