
VBT vs brownian path slowdown #489

@lockwo


While implementing some weak solver schemes, I noticed that using a VBT (VirtualBrownianTree) instead of a UBP (UnsafeBrownianPath) made the solve substantially slower (roughly 5 minutes vs. 10 seconds). Using a UBP is fine for us in this case (since it's a fixed-step solver and we aren't differentiating through the equation), but it won't be ideal in the future. Below is an MVC; in summary:

  • This is not the real solver; I ripped out almost everything to keep the code small.
  • These numbers look a bit small/microbenchmark-y, but I have verified them on some larger problems as well; I am just using this small problem for speed and demonstration.
  • VBT is about 10x slower here, and the cost seems dominated by the line u += g1 @ (_dW + chi1). Specifically, if I comment that line out, the runtime drops from ~11 s to ~3 s, whereas with the UBP it only drops from ~1.1 s to ~0.8 s (so removing one matmul does not account for the whole gap).
  • The surprising thing is not that VBT is slower (I expected some overhead), but that the overhead seems to scale: if I decrease dt, the gap is not a constant overhead but grows. Maybe this is expected, but even for small problems with large dt values it becomes prohibitive (see the original 5 min vs. 10 s).

All of this is surprising, since I only call the diffusion control once per step. Is there a way of using VBTs, or integrating them into new solvers, that avoids this slowdown, or am I making a mistake in how I use the VBT?
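For intuition on the dt scaling, here is a rough back-of-the-envelope model. It is an assumption on my part, not taken from the diffrax source: it assumes each VBT query bisects [t0, t1] until intervals are no wider than tol (the MVC below passes tol=dt/2), so the tree depth is ceil(log2((t1 - t0) / tol)), while a UBP query costs about the same at any dt. The function name vbt_cost_model is hypothetical.

```python
import math


def vbt_cost_model(t0: float, t1: float, dt: float) -> tuple[int, int, int]:
    """Hypothetical cost model for one solve with tol = dt / 2, as in the MVC.

    Assumes each VBT query bisects [t0, t1] until intervals are no wider than
    tol, i.e. ceil(log2((t1 - t0) / tol)) levels, while a UBP query is O(1).
    Returns (number of steps, tree depth, levels visited per solve).
    """
    num_steps = math.ceil((t1 - t0) / dt)
    tol = dt / 2
    depth = math.ceil(math.log2((t1 - t0) / tol))
    # One control query per step (the single diffusion.contr call in step()).
    return num_steps, depth, num_steps * depth


# Power-of-two step sizes keep the floating-point arithmetic exact.
for dt in (1.0, 0.25, 0.015625):
    steps, depth, levels = vbt_cost_model(0.0, 3.0, dt)
    print(f"dt={dt}: steps={steps}, depth={depth}, levels visited={levels}")
```

Under this model the UBP work is proportional to the number of steps, while the VBT work is steps times depth, so the VBT/UBP ratio grows with the depth (3, 5 and 9 levels for the three step sizes above), i.e. logarithmically in 1/dt. That would be consistent with the gap widening as dt shrinks, but the model ignores constant factors such as how many random draws each level costs.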

Here is the full code:

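import jax
from jax import numpy as jnp
import diffrax
from typing import ClassVar

# The Brownian shapes below request float64, which JAX only honours with x64 enabled.
jax.config.update("jax_enable_x64", True)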

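# The 1/6 quantile of the standard normal distribution.
_NORMAL_ONESIX_QUANTILE = -0.9674215661017014


def calc_threepoint_random(x):
    # Map a standard normal x to {-1, 0, 1} with probabilities 1/6, 2/3, 1/6.
    return jnp.where(
        jnp.abs(x) > -_NORMAL_ONESIX_QUANTILE,
        jnp.where(x < _NORMAL_ONESIX_QUANTILE, -1.0, 1.0),
        0.0,
    )


def calc_twopoint_random(x):
    # Map a standard normal x to {-1, 1} with probability 1/2 each.
    return jnp.where(x > 0, 1.0, -1.0)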


class Solver(diffrax.AbstractSolver):

    term_structure: ClassVar = diffrax.AbstractTerm
    interpolation_cls: ClassVar = diffrax.LocalLinearInterpolation

    def func(self, terms, t0, y0, args):
        return terms.vf(t0, y0, args)

    def init(self, terms, t0, t1, y0, args):
        return None

    def step(self, terms, t0, t1, y0, args, solver_state, made_jump):
        drift = terms.terms[0]
        diffusion = terms.terms[1]
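        # The only Brownian query per step: returns both the "dW" and "dZ" increments.
        cont = diffusion.contr(t0, t1)
        dt = t1 - t0
        # Standardise dW to N(0, 1), then map it to a three-point variable
        # taking values in {-sqrt(3 dt), 0, sqrt(3 dt)}.
        dW_scaled = cont["dW"] / jnp.sqrt(dt)
        sq3dt = jnp.sqrt(3 * dt)
        _dW = sq3dt * calc_threepoint_random(dW_scaled)
        # Only the sign of dZ is used, so it needs no rescaling.
        dZ_scaled = cont["dZ"]
        _dZ = calc_twopoint_random(dZ_scaled)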
        xi = jnp.sqrt(dt) * _dZ[0]
        chi1 = (_dW**2 / xi - xi) / 2
        k1 = drift.vf(t0, y0, args)
        g1 = diffusion.vf(t0, y0, args)
        H02 = y0 + k1 * dt + g1 @ _dW
        k2 = drift.vf(t0, H02, args)
        H03 = y0 + k2 * dt + k1 * dt + g1 @ _dW
        k3 = drift.vf(t0, H03, args)
        u = y0 + k1 * dt + k2 * dt + k3 * dt
        u += g1 @ (_dW + chi1)
        dense_info = dict(y0=y0, y1=u)

        return u, None, dense_info, None, diffrax.RESULTS.successful

def drift(t, X, args):
    y1, y2 = X
    dy1 = -273 / 512 * y1
    # True division: with //, -1 // 160 floors to -1 and -785 // 512 to -2.
    dy2 = -1 / 160 * y1 - (-785 / 512 + jnp.sqrt(2) / 8) * y2
    return jnp.array([dy1, dy2])


def diffusion(t, X, args):
    y1, y2 = X
    g11 = 1 / 4 * y1
    g12 = 1 / 16 * y1
    g21 = (1 - 2 * jnp.sqrt(2)) / 4 * y1
    # True division: with //, 1 // 10 and 1 // 16 both floor to 0.
    g22 = 1 / 10 * y1 + 1 / 16 * y2

    return jnp.array([[g11, g12], [g21, g22]])


t0, t1 = 0.0, 3.0
y0 = jnp.array([1.0, 1.0])


def solve_wrapper(dt, num_samples, use_tree):
    keys = jax.random.split(jax.random.key(42), num_samples)
    solver = Solver()
    saveat = diffrax.SaveAt(t1=True)

    def solve(key):
        if not use_tree:
            tree = diffrax.UnsafeBrownianPath(
                shape={
                    "dW": jax.ShapeDtypeStruct((2,), dtype=jnp.float64),
                    "dZ": jax.ShapeDtypeStruct((2,), dtype=jnp.float64),
                },
                key=key,
            )
            terms = diffrax.MultiTerm(
                diffrax.ODETerm(drift), diffrax.ControlTerm(diffusion, tree)
            )
            return diffrax.diffeqsolve(
                terms,
                solver,
                t0,
                t1,
                dt0=dt,
                y0=y0,
                saveat=saveat,
                adjoint=diffrax.DirectAdjoint(),
            )
        else:
            tree = diffrax.VirtualBrownianTree(
                t0,
                t1,
                tol=dt / 2,
                shape={
                    "dW": jax.ShapeDtypeStruct((2,), dtype=jnp.float64),
                    "dZ": jax.ShapeDtypeStruct((2,), dtype=jnp.float64),
                },
                key=key,
            )
            terms = diffrax.MultiTerm(
                diffrax.ODETerm(drift), diffrax.ControlTerm(diffusion, tree)
            )
            return diffrax.diffeqsolve(
                terms, solver, t0, t1, dt0=dt, y0=y0, saveat=saveat
            )
            # Passing adjoint=diffrax.DirectAdjoint() here as well makes this ~2x slower.

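    return jax.jit(jax.vmap(solve))(keys).ys.squeeze(axis=1)

Timed in a separate Jupyter cell (%%timeit has to be the first line of the cell; pass False as the last argument to time the UBP instead):

%%timeit
_ = solve_wrapper(1.0, 20 * 100_000, True).block_until_ready()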

With use_tree=True (VBT) this yields 9.37 s ± 281 ms per loop (mean ± std. dev. of 7 runs, 1 loop each); with use_tree=False (UBP) it yields 1.03 s ± 8.17 ms per loop (mean ± std. dev. of 7 runs, 1 loop each).
