Description
While implementing some weak solver schemes, I noticed that using a `VirtualBrownianTree` (VBT) instead of an `UnsafeBrownianPath` (UBP) made them substantially slower (~5 min vs. ~10 s). Using a UBP is fine in our case, since it's a fixed-step solver and we aren't differentiating through the equation, but it isn't ideal in general. Below is a minimal example, but in summary:
- This is not the real solver; I stripped out most of it to keep the code small.
- These numbers look a bit small/microbenchmark-y, but I have verified them on larger problems as well. I'm just using this small problem for speed and demonstration.
- VBT is ~10x slower here, and the time seems dominated by the line `u += g1 @ (_dW + chi1)`. Specifically, if I comment it out, the VBT runtime drops from ~11 s to ~3 s, whereas the UBP runtime only goes from ~1.1 s to ~0.8 s. Removing one matmul therefore doesn't explain the whole gap.
- The surprising thing is not that VBT is slower (I expected some overhead), but that the overhead seems to scale: if I decrease `dt`, the gap grows instead of staying constant. Maybe this is expected, but even for small problems with large `dt` it becomes prohibitive (see the original 5 min vs. 10 s).
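A rough cost model may help frame the `dt` scaling. The sketch below is a simplification, not a description of diffrax's actual implementation. It assumes each VBT increment query bisects [t0, t1] until the interval is no wider than `tol` and draws one conditional Gaussian per level. It also assumes the UBP draws one Gaussian per step. The helpers `num_steps`, `bisection_depth`, and `tree_samples` are hypothetical names used only for illustration.

```python
import math


def num_steps(t0: float, t1: float, dt: float) -> int:
    """Number of fixed steps of size dt covering [t0, t1]."""
    return round((t1 - t0) / dt)


def bisection_depth(t0: float, t1: float, tol: float) -> int:
    """Levels needed to halve [t0, t1] until the interval width is <= tol.

    Assumption: one Brownian-tree query descends this many levels,
    drawing one conditional Gaussian sample per level.
    """
    return math.ceil(math.log2((t1 - t0) / tol))


def tree_samples(t0: float, t1: float, dt: float) -> int:
    """Rough total Gaussian draws per path for a tree with tol = dt / 2.

    Assumes one full descent per step; constant factors are ignored.
    The unsafe path would need roughly num_steps(t0, t1, dt) draws instead.
    """
    return num_steps(t0, t1, dt) * bisection_depth(t0, t1, dt / 2)


for dt in (1.0, 0.1, 0.01):
    steps = num_steps(0.0, 3.0, dt)
    depth = bisection_depth(0.0, 3.0, dt / 2)
    print(f"dt={dt}: steps={steps}, depth={depth}, tree draws={tree_samples(0.0, 3.0, dt)}")
```

With `tol = dt / 2` on [0, 3], this model gives a depth of 3 at `dt = 1.0`, 6 at `dt = 0.1`, and 10 at `dt = 0.01`. The per-step tree cost therefore grows only logarithmically as `dt` shrinks. Under these assumptions it could account for a gap that widens with smaller `dt`, but not on its own for a ~10x gap at `dt = 1.0`.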
All of this is a bit surprising, since I only call the diffusion control once per step. Is there a way to use VBTs, or integrate them into new solvers, that avoids this slowdown? Or am I making a mistake in how I use the VBT?
Here is the full code:
import jax
from jax import numpy as jnp
import diffrax
from typing import ClassVar

# The Brownian shapes below are float64, which requires JAX's x64 mode.
jax.config.update("jax_enable_x64", True)

# 1/6 quantile of the standard normal distribution.
_NORMAL_ONESIX_QUANTILE = -0.9674215661017014


def calc_threepoint_random(x):
    # Map a standard normal sample to -1, 0, +1 with probabilities 1/6, 2/3, 1/6.
    return jnp.where(
        jnp.abs(x) > -_NORMAL_ONESIX_QUANTILE,
        jnp.where(x < _NORMAL_ONESIX_QUANTILE, -1.0, 1.0),
        0.0,
    )


def calc_twopoint_random(x):
    # Keep only the sign of the sample: -1 or +1.
    return jnp.where(x > 0, 1.0, -1.0)


class Solver(diffrax.AbstractSolver):
    term_structure: ClassVar = diffrax.AbstractTerm
    interpolation_cls: ClassVar = diffrax.LocalLinearInterpolation

    def func(self, terms, t0, y0, args):
        return terms.vf(t0, y0, args)

    def init(self, terms, t0, t1, y0, args):
        return None

    def step(self, terms, t0, t1, y0, args, solver_state, made_jump):
        drift = terms.terms[0]
        diffusion = terms.terms[1]
        # The only query of the Brownian control in this step.
        cont = diffusion.contr(t0, t1)
        dt = t1 - t0

        # Simplified three-point increment with variance dt.
        dW_scaled = cont["dW"] / jnp.sqrt(dt)
        sq3dt = jnp.sqrt(3 * dt)
        _dW = sq3dt * calc_threepoint_random(dW_scaled)

        # Two-point random variable built from the sign of dZ.
        dZ_scaled = cont["dZ"]
        _dZ = calc_twopoint_random(dZ_scaled)
        xi = jnp.sqrt(dt) * _dZ[0]
        chi1 = (_dW**2 / xi - xi) / 2

        # Cut-down version of the scheme's stages.
        k1 = drift.vf(t0, y0, args)
        g1 = diffusion.vf(t0, y0, args)
        H02 = y0 + k1 * dt + g1 @ _dW
        k2 = drift.vf(t0, H02, args)
        H03 = y0 + k2 * dt + k1 * dt + g1 @ _dW
        k3 = drift.vf(t0, H03, args)
        u = y0 + k1 * dt + k2 * dt + k3 * dt
        u += g1 @ (_dW + chi1)

        dense_info = dict(y0=y0, y1=u)
        return u, None, dense_info, None, diffrax.RESULTS.successful


def drift(t, X, args):
    y1, y2 = X
    dy1 = -273 / 512 * y1
    dy2 = -1 / 160 * y1 - (-785 / 512 + jnp.sqrt(2) / 8) * y2
    return jnp.array([dy1, dy2])


def diffusion(t, X, args):
    y1, y2 = X
    g11 = 1 / 4 * y1
    g12 = 1 / 16 * y1
    g21 = (1 - 2 * jnp.sqrt(2)) / 4 * y1
    g22 = 1 / 10 * y1 + 1 / 16 * y2
    return jnp.array([[g11, g12], [g21, g22]])


t0, t1 = 0.0, 3.0
y0 = jnp.array([1.0, 1.0])


def solve_wrapper(dt, num_samples, use_tree):
    keys = jax.random.split(jax.random.key(42), num_samples)
    solver = Solver()
    saveat = diffrax.SaveAt(t1=True)

    def solve(key):
        if not use_tree:
            tree = diffrax.UnsafeBrownianPath(
                shape={
                    "dW": jax.ShapeDtypeStruct((2,), dtype=jnp.float64),
                    "dZ": jax.ShapeDtypeStruct((2,), dtype=jnp.float64),
                },
                key=key,
            )
            terms = diffrax.MultiTerm(
                diffrax.ODETerm(drift), diffrax.ControlTerm(diffusion, tree)
            )
            return diffrax.diffeqsolve(
                terms,
                solver,
                t0,
                t1,
                dt0=dt,
                y0=y0,
                saveat=saveat,
                adjoint=diffrax.DirectAdjoint(),
            )
        else:
            tree = diffrax.VirtualBrownianTree(
                t0,
                t1,
                tol=dt / 2,
                shape={
                    "dW": jax.ShapeDtypeStruct((2,), dtype=jnp.float64),
                    "dZ": jax.ShapeDtypeStruct((2,), dtype=jnp.float64),
                },
                key=key,
            )
            terms = diffrax.MultiTerm(
                diffrax.ODETerm(drift), diffrax.ControlTerm(diffusion, tree)
            )
            # Also passing adjoint=diffrax.DirectAdjoint() here is ~2x slower.
            return diffrax.diffeqsolve(
                terms, solver, t0, t1, dt0=dt, y0=y0, saveat=saveat
            )

    return jax.jit(jax.vmap(solve))(keys).ys.squeeze(axis=1)
%%timeit
_ = solve_wrapper(1.0, 20 * 100_000, True).block_until_ready()
This yields 9.37 s ± 281 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) for VBT, and 1.03 s ± 8.17 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) for UBP (`use_tree=False`).
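For reference, `calc_threepoint_random` and the `sqrt(3 * dt)` scaling in `step` appear to be there so the simplified increment matches the low-order moments of a Brownian increment. That is the usual requirement for simplified increments in weak schemes. The sketch below checks this exactly, without sampling, using only the standard library. `threepoint_moments` is a hypothetical helper. The sketch assumes the input to `calc_threepoint_random` is standard normal, as `dW / sqrt(dt)` is.

```python
import math

# Same constant as in the issue: the 1/6 quantile of the standard normal.
_NORMAL_ONESIX_QUANTILE = -0.9674215661017014


def normal_cdf(x: float) -> float:
    """CDF of the standard normal distribution."""
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))


def threepoint_moments(dt: float) -> tuple[float, float, float]:
    """Exact moments of sqrt(3 * dt) * calc_threepoint_random(Z), Z ~ N(0, 1).

    calc_threepoint_random returns -1 when Z < q, +1 when Z > -q, and 0
    otherwise (q is the 1/6 quantile), so -1 and +1 each have probability Phi(q).
    Returns (tail probability, second moment, fourth moment).
    """
    p_tail = normal_cdf(_NORMAL_ONESIX_QUANTILE)
    scale = math.sqrt(3.0 * dt)
    second = 2.0 * p_tail * scale**2
    fourth = 2.0 * p_tail * scale**4
    return p_tail, second, fourth


dt = 0.1
p_tail, m2, m4 = threepoint_moments(dt)
print(f"P(+1) = P(-1) = {p_tail:.6f}")
print(f"E[dW^2] = {m2:.6f} (Brownian: {dt:.6f})")
print(f"E[dW^4] = {m4:.6f} (Brownian: {3 * dt**2:.6f})")
```

Both moments agree with E[ΔW²] = dt and E[ΔW⁴] = 3 dt² for ΔW ~ N(0, dt). Odd moments vanish by symmetry.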