Authors: David M. Knigge*, David R. Wessels*, Riccardo Valperga, Samuele Papa, Jan-Jakob Sonke, Efstratios Gavves^, Erik J. Bekkers^
*equal contribution, ^equal advising
This is the reproducibility repository for the paper "Space-Time Continuous PDE Forecasting using Equivariant Neural Fields". All experiments in the paper should be reproducible with the code in this repository. Data for the experiments is generated with the solvers provided by Dedalus and py-pde, except for the Navier-Stokes equations.
To install the requirements, we use conda. We recommend creating a new environment for the project.
conda create -n enf-pde-jax python=3.11
conda activate enf-pde-jax
Install the relevant dependencies.
pip install "jax[cuda12]" flax optax orbax
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
pip install wandb matplotlib trimesh hydra-core tqdm netCDF4 py-pde
conda install -c conda-forge dedalus
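After installing, it can help to confirm that every package is importable from the enf-pde-jax environment. The script below is a minimal sketch and is not part of the repository. The function name `missing_packages` is hypothetical, and the mapping from package names to import names is our assumption (py-pde is imported as `pde`, hydra-core as `hydra`). It uses `importlib.util.find_spec`, so nothing heavy such as JAX or PyTorch is actually initialised.

```python
import importlib.util

# Package name -> import name. Assumed mapping: py-pde imports as `pde`,
# hydra-core imports as `hydra`; the others import under their own name.
REQUIRED = {
    "jax": "jax",
    "flax": "flax",
    "optax": "optax",
    "orbax": "orbax",
    "torch": "torch",
    "wandb": "wandb",
    "matplotlib": "matplotlib",
    "trimesh": "trimesh",
    "hydra-core": "hydra",
    "tqdm": "tqdm",
    "netCDF4": "netCDF4",
    "py-pde": "pde",
    "dedalus": "dedalus",
}


def missing_packages(required=REQUIRED):
    """Return the package names whose import name cannot be located."""
    missing = []
    for dist_name, import_name in required.items():
        # find_spec only locates the module; it does not import it.
        if importlib.util.find_spec(import_name) is None:
            missing.append(dist_name)
    return missing


if __name__ == "__main__":
    missing = missing_packages()
    if missing:
        print("Missing packages:", ", ".join(missing))
    else:
        print("All packages found.")
```

This only checks that the packages are installed. To confirm that the CUDA build of JAX sees your GPU, `jax.devices()` should list GPU devices rather than only a CPU.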
Apart from Navier-Stokes, all datasets are generated with Dedalus or py-pde before training. The dataset code can be found in experiments/fitting/datasets/pdes.py.
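Each generated sample is a trajectory: snapshots of a field on a spatial grid at successive time steps. Dedalus and py-pde do the actual solving in this repository. Purely as a self-contained illustration of what such a trajectory looks like, the sketch below replaces them with a hand-written explicit finite-difference solver for the 1D heat equation u_t = alpha * u_xx with periodic boundaries. This is not the repository's dataset code. The function name `heat_trajectory`, the choice of equation, and the grid, diffusivity and time-step values are all assumptions made for illustration.

```python
import math


def heat_trajectory(u0, alpha=0.1, dx=1.0, dt=1.0, steps=10):
    """Hypothetical helper: explicit Euler for u_t = alpha * u_xx (periodic 1D grid).

    Returns steps + 1 snapshots, each a list of len(u0) values, i.e. a
    (time, space) trajectory similar in shape to one PDE dataset sample.
    """
    r = alpha * dt / dx**2
    # Explicit Euler for the heat equation is only stable when r <= 1/2.
    if r > 0.5:
        raise ValueError(f"unstable step: alpha*dt/dx^2 = {r} > 0.5")
    n = len(u0)
    u = list(u0)
    traj = [u]
    for _ in range(steps):
        # Periodic boundary: neighbours wrap around via modulo indexing.
        u = [u[i] + r * (u[(i - 1) % n] - 2 * u[i] + u[(i + 1) % n]) for i in range(n)]
        traj.append(u)
    return traj


if __name__ == "__main__":
    n = 32
    u0 = [math.sin(2 * math.pi * i / n) for i in range(n)]
    traj = heat_trajectory(u0, alpha=0.1, dx=1.0, dt=1.0, steps=50)
    print(len(traj), len(traj[0]))  # 51 32
```

Real datasets use proper spectral or adaptive solvers and many initial conditions. The point here is only the (time, space) layout of a trajectory and the fact that it is produced once, ahead of training.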
Note: generating the Navier-Stokes dataset requires a CUDA-enabled build of PyTorch. Because PyTorch and JAX do not coexist well in a single environment, we recommend generating the Navier-Stokes dataset in a separate environment.
The following commands create a new environment that includes a CUDA-enabled PyTorch:
conda create -n enf-pde-jax-dset python=3.11
conda activate enf-pde-jax-dset
conda install pytorch==2.3.0 torchvision==0.18.0 torchaudio==2.3.0 pytorch-cuda=12.1 -c pytorch -c nvidia
pip install --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install optax flax orbax wandb matplotlib trimesh hydra-core tqdm netCDF4 py-pde
conda install -c conda-forge dedalus
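Before running the generation script, you may want to confirm that the active environment (enf-pde-jax-dset) really has a CUDA-enabled PyTorch. The following is a minimal sketch, not part of the repository; the function name `torch_cuda_status` and its return strings are hypothetical. It relies on `torch.version.cuda`, which is `None` for CPU-only builds, and on `torch.cuda.is_available()`.

```python
import importlib.util


def torch_cuda_status():
    """Hypothetical helper: describe the PyTorch build in the active environment."""
    if importlib.util.find_spec("torch") is None:
        return "torch not installed"
    import torch

    # CPU-only builds (such as the one in enf-pde-jax) report no CUDA version.
    if torch.version.cuda is None:
        return "cpu-only torch"
    # A CUDA build can still fail to see a GPU, e.g. because of driver issues.
    if not torch.cuda.is_available():
        return "cuda torch (no GPU visible)"
    return "cuda torch"


if __name__ == "__main__":
    print(torch_cuda_status())
```

In enf-pde-jax-dset this should report "cuda torch"; in the main enf-pde-jax environment, where PyTorch is installed from the CPU index, "cpu-only torch" is expected.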
Next, run the following command to generate the Navier-Stokes dataset:
export PYTHONPATH=. && python experiments/fitting/gen_navier_stokes.py
Afterwards, you can deactivate the environment and return to the main environment:
conda deactivate
conda activate enf-pde-jax
The main components of the repository are:
enf/: contains the code for the Equivariant Neural Field.
enf/steerable_attention/invariant/: contains the code for the bi-invariants used in the experiments.
experiments/fitting/: contains the code for all experiments in the paper.
The commands for each experiment are listed in the experiments README.
If you find this code useful, please consider citing our paper:
@article{knigge2024space,
title={Space-Time Continuous PDE Forecasting using Equivariant Neural Fields},
author={Knigge, David M and Wessels, David R and Valperga, Riccardo and Papa, Samuele and Sonke, Jan-Jakob and Gavves, Efstratios and Bekkers, Erik J},
journal={arXiv preprint arXiv:2406.06660},
year={2024}
}