Skip to content

Commit 20dfbec

Browse files
authored
Merge pull request #125 from JuliaDecisionFocusedLearning/fix-perturbed-multiplicative
Fix bug in PerturbedMultiplicative
2 parents 01bbd5a + 5c18b0a commit 20dfbec

File tree

5 files changed

+17
-3
lines changed

5 files changed

+17
-3
lines changed

CITATION.bib

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ @misc{InferOpt.jl
22
author = {Guillaume Dalle, Léo Baty, Louis Bouvier and Axel Parmentier},
33
title = {InferOpt.jl},
44
url = {https://github.com/axelparmentier/InferOpt.jl},
5-
version = {v0.7.0},
5+
version = {v0.7.1},
66
year = {2025},
7-
month = {4}
7+
month = {7}
88
}

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "InferOpt"
22
uuid = "4846b161-c94e-4150-8dac-c7ae193c601f"
33
authors = ["Guillaume Dalle", "Léo Baty", "Louis Bouvier", "Axel Parmentier"]
4-
version = "0.7.0"
4+
version = "0.7.1"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/InferOpt.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ using DifferentiableExpectations:
1515
using Distributions:
1616
Distributions,
1717
ContinuousUnivariateDistribution,
18+
AffineDistribution,
19+
Dirac,
1820
LogNormal,
1921
Normal,
2022
product_distribution,

src/layers/perturbed/utils.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,7 @@ It is equal to ``logpdf(d, log(x)) - log(x)``.
2626
function Distributions.logpdf(d::ExponentialOf, x::Real)
2727
return logpdf(d.dist, log(x)) - log(x)
2828
end
29+
30+
function Base.:*(x::Real, d::ExponentialOf)
31+
return iszero(x) ? Dirac(0.0) : AffineDistribution(zero(x), x, d)
32+
end

test/perturbed.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,3 +112,11 @@ end
112112
@test Ja Jz rtol = 0.01
113113
@test Jm Jz rtol = 0.01
114114
end
115+
116+
@testitem "Perturbed - Misc" begin
117+
# Make sure that having 0s in θ works
118+
perturbed = PerturbedMultiplicative(identity; ε=1.0, nb_samples=1e4, seed=0)
119+
θ = [1.0, 0.0]
120+
y = perturbed(θ)
121+
@test iszero(y[2])
122+
end

0 commit comments

Comments
 (0)