tx

This library is intended to be a reimplementation of the 'TransformerLens' library¹ using JAX and the Flax module system.

The main distinguishing features of this library are as follows.

  • No dependency on PyTorch: all numerical operations are performed using JAX.
  • The Flax module system is used for defining networks and storing activations.
  • Simplified low-level module implementations are provided in a 'single batch' style (see the sketch after this list).
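
As a rough illustration of the 'single batch' style (a sketch only, not tx's actual module definitions), the toy Flax module below is written for one unbatched input vector, and batching is recovered with jax.vmap:

import jax
import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):
    """A toy module written for one unbatched input vector."""
    features: int

    @nn.compact
    def __call__(self, x):          # x: (d_model,), no batch dimension
        d_model = x.shape[-1]
        h = nn.relu(nn.Dense(self.features)(x))
        return nn.Dense(d_model)(h)

model = MLP(features=32)
x = jnp.ones((16,))                              # a single example
params = model.init(jax.random.PRNGKey(0), x)

# Batching is recovered by mapping apply over the leading axis of the inputs.
batched_apply = jax.vmap(model.apply, in_axes=(None, 0))
batch = jnp.ones((8, 16))                        # batch of 8 examples
out = batched_apply(params, batch)               # shape (8, 16)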

Installation

The following prerequisites are required to use the library.

  1. A Python 3.7+ installation.
  2. A working installation of jax and jaxlib (either CPU or GPU).
  3. The remaining package requirements (see requirements.txt).

1. Create Virtual Environment

Assuming you have a working Python 3.7+ installation, first clone this project into a new directory <project_dir>, then create a virtual environment in <env_dir>, activate it, and upgrade pip.

git clone https://github.com/alexjackson1/tx.git <project_dir>
cd <project_dir>
python -m venv <env_dir>
source <env_dir>/bin/activate
pip install --upgrade pip

2. Install a Compatible Version of JAX

To install a version of JAX that is compatible with your hardware, please refer to the JAX installation instructions on the project README. Installation via the pip wheel(s) is highly recommended.
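
For example, a CPU-only build can typically be installed with the command below. Treat this as illustrative: the exact command for GPU wheels depends on your CUDA version, so defer to the JAX instructions for your platform.

pip install --upgrade "jax[cpu]"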

3. Install the Remaining Requirements

Once you have installed a compatible version of JAX, you can install the remaining requirements as follows. This includes Flax, the module system used by this library for defining networks.

pip install -r requirements.txt

Usage

This library is still in development and is not yet ready for use.

The notebooks in the examples directory follow publicly available tutorials on mechanistic interpretability.
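
Because networks are Flax modules, inspecting hidden activations is expected to follow Flax's standard patterns. The snippet below is a generic Flax example (not tx's own API) showing how intermediate values can be recorded with sow and retrieved after a forward pass:

import jax
import jax.numpy as jnp
import flax.linen as nn

class TwoLayer(nn.Module):
    @nn.compact
    def __call__(self, x):
        h = nn.relu(nn.Dense(32)(x))
        self.sow('intermediates', 'hidden', h)   # record an activation
        return nn.Dense(8)(h)

model = TwoLayer()
x = jnp.ones((4, 16))
variables = model.init(jax.random.PRNGKey(0), x)

# Passing mutable=['intermediates'] returns the sown values with the output.
y, state = model.apply(
    {'params': variables['params']}, x, mutable=['intermediates'])
hidden, = state['intermediates']['hidden']       # the recorded array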

Relation to TransformerLens

The API of this library is intended to (eventually) expose the same functionality as the original 'TransformerLens' library, making some changes where appropriate.

Similarities

  1. The library seeks to model Transformer architectures and enable users to inspect intermediate activations and other hidden information (e.g. attention weights).
  2. Modules are written 'from scratch', attempting to eliminate abstractions that obfuscate the underlying mathematics.
  3. GPU acceleration is supported as a first-class feature.

Differences

  1. The transformer architecture and related algorithms use JAX, instead of PyTorch, for better performance and hardware acceleration.
  2. In keeping with the functional paradigm of JAX, the library and API are designed to be more functional in nature and to embrace the Flax philosophy (see the sketch after this list).
  3. Module definitions use a 'single batch' style made possible by jax.vmap (reducing cognitive load and improving readability).
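
To illustrate the functional style, here is a minimal generic Flax sketch (again, not tx's own API): parameters are explicit values returned by init and passed to apply, rather than state hidden inside objects.

import jax
import jax.numpy as jnp
import flax.linen as nn

model = nn.Dense(features=4)
x = jnp.ones((3,))

# Parameters are an explicit pytree returned by init, not hidden object state.
params = model.init(jax.random.PRNGKey(0), x)

# apply is a pure function of (parameters, inputs): same inputs, same outputs.
y = model.apply(params, x)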

License

This project is licensed under the terms of the MIT license. The full license text can be found in the LICENSE file.

The original 'TransformerLens' library is also licensed under the terms of the MIT license; the full license text can be found in its repository. Additionally, the original library can be cited as shown below.

@misc{nandatransformerlens2022,
    title  = {TransformerLens},
    author = {Nanda, Neel and Bloom, Joseph},
    url    = {https://github.com/neelnanda-io/TransformerLens},
    year   = {2022}
}

Footnotes

  1. Formerly 'EasyTransformer', 'TransformerLens' is maintained by Joseph Bloom and was created by Neel Nanda.
