-
Notifications
You must be signed in to change notification settings - Fork 9
Implementation of Mixture of Experts in JAX using Flax #33
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hello @RubensZimbres, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
Summary of Changes
This pull request introduces a JAX/Flax implementation of a Mixture of Experts (MoE) layer, along with a Rotary Positional Embedding (RoPE) Multi-Head Attention mechanism. The MoE implementation includes Expert
and GatingNetwork
modules, allowing for a neural network architecture where multiple specialized sub-networks are combined. The RoPE implementation provides a multi-head attention mechanism with rotary positional embeddings. The pull request also includes comprehensive tests for both the MoE layer and the RoPE attention mechanism, ensuring their correct functionality. Additionally, the README has been updated with example usages of the MoE layer and RoPE attention module.
Highlights
- Mixture of Experts (MoE) Implementation: Introduces a JAX/Flax implementation of a Mixture of Experts layer, including
Expert
andGatingNetwork
modules. This allows the model to learn to route different parts of the input space to specialized modules. - RoPE Multi-Head Attention: Implements a Multi-Head Attention mechanism with Rotary Positional Embeddings (RoPE) in JAX/Flax. This includes functions for precomputing rotary embeddings and applying them to the attention mechanism.
- Comprehensive Testing: Adds thorough tests for both the Mixture of Experts layer and the RoPE Multi-Head Attention mechanism, covering forward passes, masking, and error conditions.
- README Updates: Updates the README with example usages of the newly implemented Mixture of Experts layer and RoPE Multi-Head Attention module, providing clear instructions and code snippets for users.
Changelog
Click here to see the changelog
- .gitignore
- Added
jax_env
to the.gitignore
file to exclude the JAX environment from version control.
- Added
- README.md
- Added example usage of the Mixture of Experts (MoE) layer.
- Added example usage of the RoPEMultiHeadAttention module.
- jaxgarden/attention/rope_multi_head_attention.py
- Implemented the RoPEMultiHeadAttention module, including functions for rotary positional embeddings.
- Added
rotate_half
function to rotate half the hidden dimensions of the input tensor. - Added
apply_rotary_pos_emb
function to apply Rotary Positional Embedding to the input tensor. - Added
precompute_rotary_embeddings
function to precompute the RoPE cosine and sine embeddings. - Implemented the
RoPEMultiHeadAttention
Flax module.
- jaxgarden/functional/mixture_of_experts.py
- Implemented the Mixture of Experts (MoE) layer in JAX/Flax.
- Added
Expert
module, a simple feed-forward expert network. - Added
GatingNetwork
module, a gating network that outputs weights for each expert. - Added
MixtureOfExperts
module, which combines the expert networks and gating network.
- tests/attention/test_RoPEMultiHeadAttention.py
- Added tests for the RoPEMultiHeadAttention module.
- Included tests for
rotate_half
,precompute_rotary_embeddings
, andapply_rotary_pos_emb
functions. - Added tests for the forward pass, masking, and error conditions of the RoPEMultiHeadAttention module.
- tests/functional/test_MoE.py
- Added tests for the Mixture of Experts (MoE) layer.
- Included tests for the
Expert
,GatingNetwork
, andMixtureOfExperts
modules. - Added tests for initialization, forward pass, output logic, and different parameter configurations.
Using Gemini Code Assist
The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.
Invoking Gemini
You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command>
or @gemini-code-assist <command>
. Below is a summary of the supported commands.
Feature | Command | Description |
---|---|---|
Code Review | /gemini review |
Performs a code review for the current pull request in its current state. |
Pull Request Summary | /gemini summary |
Provides a summary of the current pull request in its current state. |
Comment | @gemini-code-assist | Responds in comments when explicitly tagged, both in issue comments and review comments. |
Help | /gemini help |
Displays a list of available commands. |
Customization
To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/
folder in the base of the repository. Detailed instructions can be found here.
Limitations & Feedback
Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.
You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.
In JAX and Flax, the experts rise,
A gating net, with keen, discerning eyes,
Routes inputs true,
For outputs new,
A mixture's wisdom, to the skies.
Footnotes
-
Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution. ↩
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request introduces a Mixture of Experts (MoE) layer and a RoPE Multi-Head Attention mechanism in JAX/Flax. The implementation seems well-structured and includes corresponding tests. However, there are a few areas that could be improved for clarity, efficiency, and robustness.
Summary of Findings
- Mask Shape Validation: The mask shape validation in
RoPEMultiHeadAttention
could be more robust to handle various mask shapes and provide more informative error messages. - Efficiency of Expert Initialization: The
MixtureOfExperts
layer could be made more efficient by usingnn.scan
orvmap
for initializing the experts, especially when the experts are identical. - RoPEMultiHeadAttention Masking Test: The masking test in
test_RoPEMultiHeadAttention.py
could be more rigorous by inspecting attention weights to ensure causal masking is working as expected.
Merge Readiness
The pull request is almost ready for merging. Addressing the medium severity comments would improve the robustness and maintainability of the code. I am unable to directly approve this pull request, so please have others review and approve this code before merging.
if mask.ndim == 2: # Likely (seq_len, seq_len) | ||
mask = mask[None, None, :, :] # -> (1, 1, seq_len, seq_len) | ||
elif mask.ndim == 3 and mask.shape[1] != self.num_heads: | ||
# Likely (batch, seq_len, seq_len) or causal (1, sl, sl) | ||
mask = mask[:, None, :, :] | ||
# Assume (batch, seq_len, seq_len) -> (batch, 1, seq_len, seq_len) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This logic assumes that if mask.ndim == 3
and mask.shape[1] != self.num_heads
, then the mask is either (batch, seq_len, seq_len)
or causal (1, seq_len, seq_len)
. However, it's possible that the mask has a different shape that is also incompatible with attn_scores
. Consider adding a more robust check to ensure that the mask shape is compatible before attempting to broadcast it. Also, consider adding a comment to explain the assumption.
mask = mask[:, None, :, :] # Assume (batch, seq_len, seq_len) -> (batch, 1, seq_len, seq_len)
else: | ||
raise ValueError(f"Mask shape {mask.shape} != exp shape {mask_shape_expected}") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hi @RubensZimbres, thanks for this PR! seems like it has conflicts with the main branch. You may want to start a new branch from the main and add your code in a new PR --if it turns out be harder to solve conflicts . |
Conflicts solved |
JAX/Flax implementation of a Mixture of Experts (MoE) layer.
This code provides a conceptual implementation of a Mixture of Experts layer,
a neural network architecture where multiple specialized "expert" sub-networks
are combined. A gating network determines which expert (or combination of
experts) processes a given input, allowing the model to learn to route
different parts of the input space to specialized modules. This can lead to
models with higher capacity and better efficiency, especially in sparse
formulations where only a subset of experts are activated per input.
"Adaptive Mixtures of Local Experts"
by Robert A. Jacobs, Michael I. Jordan, Steven J. Nowlan, and Geoffrey E. Hinton.
Published in Neural Computation, Volume 3, Issue 1, Pages 79-87, 1991.
Available at: https://www.cs.toronto.edu/~hinton/absps/jjnh91.pdf
All pytest tests passed, as well as ruff and mypy.