
david-knigge/enf-pde

🚀 Space-Time Continuous PDE Forecasting using Equivariant Neural Fields

License: MIT

Authors: David M. Knigge*, David R. Wessels*, Riccardo Valperga, Samuele Papa, Jan-Jakob Sonke, Efstratios Gavves^, Erik J. Bekkers^

*equal contribution, ^equal advising


Overview

This is the reproducibility repo for the paper "Space-Time Continuous PDE Forecasting using Equivariant Neural Fields". All experiments in the paper should be reproducible using the code in this repository. Data for the experiments is generated with the solvers provided by Dedalus and py-pde, except for the Navier-Stokes dataset, which has its own generation script (see Data below).


Requirements

To install the requirements, we use conda. We recommend creating a new environment for the project.

conda create -n enf-pde-jax python=3.11
conda activate enf-pde-jax

Install the relevant dependencies.

pip install "jax[cuda12]" flax optax orbax
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
pip install wandb matplotlib trimesh hydra-core tqdm netCDF4 py-pde
conda install -c conda-forge dedalus
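To sanity-check the main environment after installation, a small standard-library script can report which packages fail to import and which backend JAX selected. This is a sketch and not part of the repository: the module list is our own assumption mapping the pip names above to their import names (for example, hydra-core is imported as `hydra` and py-pde as `pde`), and `missing_modules` and `jax_backend` are hypothetical helper names.

```python
import importlib
import importlib.util

# Assumed import names for the packages installed above; several differ
# from their pip names (hydra-core -> hydra, py-pde -> pde).
REQUIRED = [
    "jax", "flax", "optax", "orbax", "torch", "wandb", "matplotlib",
    "trimesh", "hydra", "tqdm", "netCDF4", "pde", "dedalus",
]


def missing_modules(names):
    """Return the subset of `names` that cannot be found on the import path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


def jax_backend():
    """Return JAX's default backend name, or None if JAX is not installed."""
    try:
        jax = importlib.import_module("jax")
    except ImportError:
        return None
    return jax.default_backend()


if __name__ == "__main__":
    missing = missing_modules(REQUIRED)
    print("missing:", missing or "none")
    print("jax backend:", jax_backend())
```

With the CUDA wheels installed and a GPU visible, `jax_backend()` should report `"gpu"`; `"cpu"` usually means JAX fell back to its CPU build.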

Data

Apart from Navier-Stokes (see the note below), all datasets are generated with either Dedalus or py-pde ahead of training. The dataset code can be found in experiments/fitting/datasets/pdes.py.
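For readers unfamiliar with solver-generated data, the sketch below illustrates the general shape of such a dataset: a trajectory of field snapshots paired with time stamps. It is a toy stand-in written for illustration only, an explicit finite-difference solver for the 1D heat equation u_t = d * u_xx on a periodic grid. It is not the code in pdes.py, does not use Dedalus or py-pde, does not reproduce any dataset from the paper, and `heat_1d_trajectory` is a hypothetical name.

```python
import math


def heat_1d_trajectory(n=64, d=0.1, t_end=0.5, n_snapshots=6):
    """Toy solver for u_t = d * u_xx on the periodic interval [0, 1).

    Returns a list of (t, u) pairs, where u is the field sampled on n grid
    points. Illustration only: not the repository's data pipeline.
    """
    assert n_snapshots >= 2, "need at least the initial and final snapshot"
    dx = 1.0 / n
    # Explicit Euler is stable for d * dt / dx**2 <= 1/2; stay safely below.
    dt = 0.4 * dx * dx / d
    steps = math.ceil(t_end / dt)
    dt = t_end / steps  # shrink dt slightly so the last step lands on t_end
    r = d * dt / (dx * dx)

    # Initial condition: a Gaussian bump centred in the domain.
    u = [math.exp(-((i * dx - 0.5) ** 2) / 0.01) for i in range(n)]

    # Step indices at which to store snapshots (always includes 0 and steps).
    save_at = {round(k * steps / (n_snapshots - 1)) for k in range(n_snapshots)}
    trajectory = []
    for step in range(steps + 1):
        if step in save_at:
            trajectory.append((step * dt, list(u)))
        if step < steps:
            # Periodic second difference: u[i - 1] wraps via negative
            # indexing, u[(i + 1) % n] wraps explicitly.
            u = [u[i] + r * (u[i - 1] - 2.0 * u[i] + u[(i + 1) % n])
                 for i in range(n)]
    return trajectory


if __name__ == "__main__":
    for t, u in heat_1d_trajectory():
        print(f"t={t:.3f}  peak={max(u):.4f}")
```

On a periodic grid the discrete Laplacian sums to zero, so the total of u is conserved while the peak decays, which gives a quick plausibility check on any such generated trajectory.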

Note: Generating the Navier-Stokes dataset requires PyTorch compiled with CUDA. Because CUDA builds of PyTorch and JAX do not play well together in a single environment, we recommend generating this dataset in a separate environment.

The following creates a new environment that includes a CUDA-compiled PyTorch:

conda create -n enf-pde-jax-dset python=3.11
conda activate enf-pde-jax-dset
conda install pytorch==2.3.0 torchvision==0.18.0 torchaudio==2.3.0 pytorch-cuda=12.1 -c pytorch -c nvidia
pip install --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install optax flax orbax wandb matplotlib trimesh hydra-core tqdm netCDF4 py-pde
conda install -c conda-forge dedalus
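Before generating data, it can help to confirm that the dataset environment actually picked up a CUDA build of PyTorch. A minimal sketch, relying only on the standard `torch.version.cuda` and `torch.cuda.is_available()` attributes; `torch_cuda_status` is a hypothetical helper, not part of the repository.

```python
import importlib


def torch_cuda_status():
    """Summarise the installed PyTorch build, or return None if torch is missing."""
    try:
        torch = importlib.import_module("torch")
    except ImportError:
        return None
    return {
        "torch_version": torch.__version__,
        # CUDA version PyTorch was compiled against; None for CPU-only builds.
        "built_with_cuda": torch.version.cuda,
        # True only if a usable GPU and driver are visible at runtime.
        "cuda_available": torch.cuda.is_available(),
    }


if __name__ == "__main__":
    print(torch_cuda_status())
```

`built_with_cuda` is `None` for a CPU-only build such as the one installed in the main environment, which is a quick way to tell the two environments apart.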

Next, from the repository root, run the following command to generate the Navier-Stokes dataset:

export PYTHONPATH=. && python experiments/fitting/gen_navier_stokes.py
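The `export PYTHONPATH=.` prefix puts the repository root on Python's import path so that the script can import the repository's own packages. If you prefer to launch generation from Python (for example from a job script), a rough equivalent is sketched below. It assumes you start from the repository root, and `env_with_repo_on_path` and `run_navier_stokes_generation` are hypothetical names, not part of the repository.

```python
import os
import subprocess
import sys
from pathlib import Path


def env_with_repo_on_path(repo_root=".", base_env=None):
    """Return a copy of the environment with repo_root prepended to PYTHONPATH."""
    env = dict(os.environ if base_env is None else base_env)
    root = str(Path(repo_root).resolve())
    existing = env.get("PYTHONPATH")
    env["PYTHONPATH"] = root if not existing else root + os.pathsep + existing
    return env


def run_navier_stokes_generation(repo_root="."):
    """Run the generation script with the active interpreter; raises on failure."""
    root = Path(repo_root).resolve()
    script = root / "experiments" / "fitting" / "gen_navier_stokes.py"
    return subprocess.run(
        [sys.executable, str(script)],
        cwd=root,
        env=env_with_repo_on_path(root),
        check=True,
    )
```

Unlike `export`, this only changes the environment of the child process, and it prepends the repository root instead of overwriting an existing PYTHONPATH. Using `sys.executable` runs the script with the interpreter of the currently active conda environment.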

Afterwards, you can deactivate the environment and return to the main environment:

conda deactivate
conda activate enf-pde-jax

Repo structure

We list relevant components of the repository here:

  • enf/ contains the code for the Equivariant Neural Field.
  • enf/steerable_attention/invariant/ contains the code for the bi-invariants used in the experiments.
  • experiments/fitting/ contains the code for all experiments in the paper.

Experiments

The commands for each experiment are listed in the experiments README.

Citation

If you find this code useful, please consider citing our paper:

@article{knigge2024space,
  title={Space-Time Continuous PDE Forecasting using Equivariant Neural Fields},
  author={Knigge, David M and Wessels, David R and Valperga, Riccardo and Papa, Samuele and Sonke, Jan-Jakob and Gavves, Efstratios and Bekkers, Erik J},
  journal={arXiv preprint arXiv:2406.06660},
  year={2024}
}

About

Code for reproducing the paper "Space-Time Continuous PDE Forecasting using Equivariant Neural Fields" (https://arxiv.org/abs/2406.06660).
