This library is intended to be a reimplementation of the 'TransformerLens' library[^1] using JAX and the Flax module system.
The main distinguishing features of this library are as follows.
- No dependencies on PyTorch; all numerical operations are performed using JAX.
- The Flax module system is used for defining networks and storing activations.
- Simplified low-level module implementations are provided in a 'single batch' style.
The following prerequisites are required to use the library.
- A Python 3.7+ installation.
- A working installation of `jax` and `jaxlib` (either CPU or GPU).
- Other module requirements (see `requirements.txt`).
Assuming you have a working Python 3.7+ installation, you should first clone this project into a new directory `<project_dir>`, and then create and activate a virtual environment in `<env_dir>`, upgrading `pip` within it.
```sh
git clone https://github.com/alexjackson1/tx.git <project_dir>
cd <project_dir>
python -m venv <env_dir>
source <env_dir>/bin/activate
pip install --upgrade pip
```
To install a version of JAX that is compatible with your hardware, please refer to the installation instructions in the JAX project README. Installation via the `pip` wheel(s) is highly recommended.
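For example, a CPU-only build can typically be installed as shown below. The wheel names for GPU (CUDA) builds are hardware-specific, so copy the exact command from the JAX README; the `jax[cpu]` extra here is the conventional CPU-only install target, not something specific to this project.

```sh
# CPU-only example; see the JAX README for GPU/TPU variants.
pip install --upgrade "jax[cpu]"
```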
Once you have installed a compatible version of JAX, you can install the remaining requirements as follows. This includes Flax, the module system used by this library for defining networks.
```sh
pip install -r requirements.txt
```
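At this point you can sanity-check the environment by confirming that JAX and Flax import correctly and that JAX can see your devices (these are plain JAX/Flax calls, nothing specific to this library).

```python
import jax
import flax

print(jax.__version__, flax.__version__)  # installed versions
print(jax.devices())  # e.g. [CpuDevice(id=0)], or the available GPUs
```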
This library is still in development and is not yet ready for use.
The notebook(s) in the `examples` directory follow the tutorials on mechanistic interpretability provided here.
The API of this library is intended to (eventually) expose the same functionality as the original 'TransformerLens' library, making some changes where appropriate.
- The library seeks to model Transformer architectures and enable users to inspect intermediate activations and other hidden information (e.g. attention weights).
- Modules are written 'from scratch', attempting to eliminate abstractions that obfuscate the underlying mathematics.
- GPU acceleration is supported as a first-class feature.
- The transformer architecture and related algorithms use JAX instead of PyTorch for better performance and hardware acceleration.
- In keeping with the functional paradigm of JAX, the library and API are designed to be more functional in nature and embrace the Flax philosophy.
- Module definitions use a 'single batch' style made possible by `jax.vmap`, reducing cognitive load and improving readability (see the sketch after this list).
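To make the last point (and the goal of inspecting intermediate activations) concrete, here is a minimal sketch using plain Flax rather than this library's actual API: a module written for a single example records a hidden state with `sow`, and `jax.vmap` lifts it over a batch. The names `MLP` and `pre_activation` are illustrative, not part of this library.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class MLP(nn.Module):
    """A toy single-example module: its input has no batch axis."""
    features: int

    @nn.compact
    def __call__(self, x):  # x: (d_in,)
        h = nn.Dense(self.features)(x)
        self.sow("intermediates", "pre_activation", h)  # record hidden state
        return nn.relu(h)


model = MLP(features=4)
xs = jnp.ones((3, 8))  # a batch of 3 examples, each of width 8

# Initialise with a single example, matching the single-batch style.
variables = model.init(jax.random.PRNGKey(0), xs[0])


def apply_one(x):
    # Returns (output, captured intermediates) for one example.
    return model.apply(variables, x, mutable=["intermediates"])


# vmap lifts the single-example function over the batch axis; the
# parameters are closed over and therefore shared across the batch.
ys, state = jax.vmap(apply_one)(xs)
print(ys.shape)  # (3, 4)
print(state["intermediates"]["pre_activation"][0].shape)  # (3, 4)
```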
This project is licensed under the terms of the MIT license.
The full license text can be found in the `LICENSE` file.
The original 'TransformerLens' library is also licensed under the terms of the MIT license. The full license text can be found here. Additionally, the original library can be cited as shown below.
```bibtex
@misc{nandatransformerlens2022,
  title = {TransformerLens},
  author = {Nanda, Neel and Bloom, Joseph},
  url = {https://github.com/neelnanda-io/TransformerLens},
  year = {2022}
}
```
[^1]: Formerly 'EasyTransformer', 'TransformerLens' is maintained by Joseph Bloom and was created by Neel Nanda.