
Implementation of Mixture of Experts in JAX using Flax #33


Open
wants to merge 9 commits into main

Conversation

RubensZimbres
Contributor

JAX/Flax implementation of a Mixture of Experts (MoE) layer.

This code provides a conceptual implementation of a Mixture of Experts layer,
a neural network architecture where multiple specialized "expert" sub-networks
are combined. A gating network determines which expert (or combination of
experts) processes a given input, allowing the model to learn to route
different parts of the input space to specialized modules. This can lead to
models with higher capacity and better efficiency, especially in sparse
formulations where only a subset of the experts is activated per input.
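
For readers new to the pattern, here is a minimal dense MoE sketch in flax.linen. The class names mirror the modules added in this PR (Expert, GatingNetwork, MixtureOfExperts), but the field names, signatures, and hyperparameters are assumptions for illustration, not the PR's actual code.

```python
# Minimal sketch of a dense Mixture of Experts layer in Flax linen.
# Field names and shapes are illustrative; the PR's modules may differ.
import jax
import jax.numpy as jnp
import flax.linen as nn


class Expert(nn.Module):
    hidden_dim: int
    output_dim: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(self.hidden_dim)(x)
        x = nn.relu(x)
        return nn.Dense(self.output_dim)(x)


class GatingNetwork(nn.Module):
    num_experts: int

    @nn.compact
    def __call__(self, x):
        # One logit per expert, turned into routing weights.
        return jax.nn.softmax(nn.Dense(self.num_experts)(x), axis=-1)


class MixtureOfExperts(nn.Module):
    num_experts: int
    hidden_dim: int
    output_dim: int

    @nn.compact
    def __call__(self, x):
        gate_weights = GatingNetwork(self.num_experts)(x)  # (batch, num_experts)
        expert_outputs = jnp.stack(
            [Expert(self.hidden_dim, self.output_dim)(x) for _ in range(self.num_experts)],
            axis=1,
        )  # (batch, num_experts, output_dim)
        # Weighted sum of expert outputs per input.
        return jnp.einsum("be,beo->bo", gate_weights, expert_outputs)


# Example usage:
# model = MixtureOfExperts(num_experts=4, hidden_dim=64, output_dim=32)
# params = model.init(jax.random.PRNGKey(0), jnp.ones((8, 16)))
# y = model.apply(params, jnp.ones((8, 16)))  # shape (8, 32)
```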

"Adaptive Mixtures of Local Experts"
by Robert A. Jacobs, Michael I. Jordan, Steven J. Nowlan, and Geoffrey E. Hinton.
Published in Neural Computation, Volume 3, Issue 1, Pages 79-87, 1991.
Available at: https://www.cs.toronto.edu/~hinton/absps/jjnh91.pdf

All pytest tests pass, and the ruff and mypy checks are clean.


@gemini-code-assist bot left a comment


Hello @RubensZimbres, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

Summary of Changes

This pull request introduces a JAX/Flax implementation of a Mixture of Experts (MoE) layer, along with a Rotary Positional Embedding (RoPE) multi-head attention mechanism. The MoE implementation includes Expert and GatingNetwork modules, giving an architecture in which multiple specialized sub-networks are combined under a learned gate. The pull request also includes comprehensive tests for both the MoE layer and the RoPE attention mechanism, and the README has been updated with example usages of both modules.

Highlights

  • Mixture of Experts (MoE) Implementation: Introduces a JAX/Flax implementation of a Mixture of Experts layer, including Expert and GatingNetwork modules. This allows the model to learn to route different parts of the input space to specialized modules.
  • RoPE Multi-Head Attention: Implements a multi-head attention mechanism with Rotary Positional Embeddings (RoPE) in JAX/Flax, including functions for precomputing the rotary embeddings and applying them inside the attention mechanism (a brief sketch of these helpers follows this list).
  • Comprehensive Testing: Adds thorough tests for both the Mixture of Experts layer and the RoPE Multi-Head Attention mechanism, covering forward passes, masking, and error conditions.
  • README Updates: Updates the README with example usages of the newly implemented Mixture of Experts layer and RoPE Multi-Head Attention module, providing clear instructions and code snippets for users.
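
To make the RoPE highlight concrete, here is a minimal sketch of the usual rotary-embedding helpers. The function names mirror those added in the PR's rope_multi_head_attention.py, but the signatures and shape conventions shown here are assumptions, not the PR's exact API.

```python
# Sketch of standard RoPE helpers (GPT-NeoX-style "rotate half" convention).
# head_dim is assumed to be even.
import jax.numpy as jnp


def rotate_half(x):
    # Split the last dimension in two and rotate: (x1, x2) -> (-x2, x1).
    x1, x2 = jnp.split(x, 2, axis=-1)
    return jnp.concatenate([-x2, x1], axis=-1)


def precompute_rotary_embeddings(seq_len, head_dim, base=10000.0):
    # Per-pair frequencies, then cos/sin tables for every position.
    inv_freq = 1.0 / (base ** (jnp.arange(0, head_dim, 2) / head_dim))
    positions = jnp.arange(seq_len)
    angles = jnp.einsum("s,d->sd", positions, inv_freq)   # (seq_len, head_dim / 2)
    angles = jnp.concatenate([angles, angles], axis=-1)    # (seq_len, head_dim)
    return jnp.cos(angles), jnp.sin(angles)


def apply_rotary_pos_emb(x, cos, sin):
    # x: (..., seq_len, head_dim); cos/sin: (seq_len, head_dim)
    return x * cos + rotate_half(x) * sin
```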

Changelog

Click here to see the changelog
  • .gitignore
    • Added jax_env to the .gitignore file to exclude the JAX environment from version control.
  • README.md
    • Added example usage of the Mixture of Experts (MoE) layer.
    • Added example usage of the RoPEMultiHeadAttention module.
  • jaxgarden/attention/rope_multi_head_attention.py
    • Implemented the RoPEMultiHeadAttention module, including functions for rotary positional embeddings.
    • Added rotate_half function to rotate half the hidden dimensions of the input tensor.
    • Added apply_rotary_pos_emb function to apply Rotary Positional Embedding to the input tensor.
    • Added precompute_rotary_embeddings function to precompute the RoPE cosine and sine embeddings.
    • Implemented the RoPEMultiHeadAttention Flax module.
  • jaxgarden/functional/mixture_of_experts.py
    • Implemented the Mixture of Experts (MoE) layer in JAX/Flax.
    • Added Expert module, a simple feed-forward expert network.
    • Added GatingNetwork module, a gating network that outputs weights for each expert.
    • Added MixtureOfExperts module, which combines the expert networks and gating network.
  • tests/attention/test_RoPEMultiHeadAttention.py
    • Added tests for the RoPEMultiHeadAttention module.
    • Included tests for rotate_half, precompute_rotary_embeddings, and apply_rotary_pos_emb functions.
    • Added tests for the forward pass, masking, and error conditions of the RoPEMultiHeadAttention module.
  • tests/functional/test_MoE.py
    • Added tests for the Mixture of Experts (MoE) layer.
    • Included tests for the Expert, GatingNetwork, and MixtureOfExperts modules.
    • Added tests for initialization, forward pass, output logic, and different parameter configurations (a pytest-style sketch follows this changelog).
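
To illustrate the kind of check such tests perform, here is a hypothetical pytest-style shape test. The import path follows the changelog above, while the constructor arguments and expected output shape are assumptions.

```python
# Hypothetical test sketch; not the PR's test code.
import jax
import jax.numpy as jnp

from jaxgarden.functional.mixture_of_experts import MixtureOfExperts  # constructor args assumed


def test_moe_forward_shape():
    model = MixtureOfExperts(num_experts=4, hidden_dim=64, output_dim=32)
    x = jnp.ones((8, 16))
    params = model.init(jax.random.PRNGKey(0), x)
    y = model.apply(params, x)
    assert y.shape == (8, 32)
    assert bool(jnp.all(jnp.isfinite(y)))
```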

In JAX and Flax, the experts rise,
A gating net, with keen, discerning eyes,
Routes inputs true,
For outputs new,
A mixture's wisdom, to the skies.



@gemini-code-assist bot left a comment


Code Review

This pull request introduces a Mixture of Experts (MoE) layer and a RoPE Multi-Head Attention mechanism in JAX/Flax. The implementation seems well-structured and includes corresponding tests. However, there are a few areas that could be improved for clarity, efficiency, and robustness.

Summary of Findings

  • Mask Shape Validation: The mask shape validation in RoPEMultiHeadAttention could be more robust to handle various mask shapes and provide more informative error messages.
  • Efficiency of Expert Initialization: The MixtureOfExperts layer could be made more efficient by using nn.scan or nn.vmap to create the experts, especially when the experts are identical (a sketch using nn.vmap follows this list).
  • RoPEMultiHeadAttention Masking Test: The masking test in test_RoPEMultiHeadAttention.py could be more rigorous by inspecting attention weights to ensure causal masking is working as expected (a behavioral test sketch appears after the Merge Readiness note below).
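
Following up on the expert-initialization finding, here is one way such a refactor could look with nn.vmap. This is only a sketch: it assumes an Expert module with hidden_dim and output_dim fields, as in the MoE sketch earlier in this thread, which may not match the PR's exact API.

```python
# Sketch: stack num_experts identical Expert modules with nn.vmap so their
# parameters are created and applied in one vectorized call instead of a
# Python loop. Field names (hidden_dim, output_dim) are assumptions.
import flax.linen as nn
from jaxgarden.functional.mixture_of_experts import Expert  # constructor args assumed


class VmappedExperts(nn.Module):
    num_experts: int
    hidden_dim: int
    output_dim: int

    @nn.compact
    def __call__(self, x):
        StackedExpert = nn.vmap(
            Expert,
            in_axes=None,                 # every expert sees the same input batch
            out_axes=1,                   # stack outputs as (batch, num_experts, output_dim)
            axis_size=self.num_experts,
            variable_axes={"params": 0},  # one parameter slice per expert
            split_rngs={"params": True},  # independent init RNG per expert
        )
        return StackedExpert(self.hidden_dim, self.output_dim)(x)
```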

Merge Readiness

The pull request is almost ready for merging. Addressing the medium severity comments would improve the robustness and maintainability of the code. I am unable to directly approve this pull request, so please have others review and approve this code before merging.
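
Regarding the masking-test finding above, one behavioral check that avoids inspecting attention weights directly is to verify that, under a causal mask, perturbing a later position leaves earlier outputs unchanged. The sketch below assumes the module's constructor and call signature (num_heads, d_model, a mask keyword); it is illustrative, not the PR's test code.

```python
# Hypothetical causal-masking test sketch: with a causal mask, the output at
# position t must not depend on inputs at positions > t.
import jax
import jax.numpy as jnp

from jaxgarden.attention.rope_multi_head_attention import RoPEMultiHeadAttention  # args assumed


def test_causal_mask_blocks_future_tokens():
    batch, seq_len, d_model, num_heads = 2, 8, 16, 4
    attn = RoPEMultiHeadAttention(num_heads=num_heads, d_model=d_model)
    causal = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))

    key = jax.random.PRNGKey(0)
    x = jax.random.normal(key, (batch, seq_len, d_model))
    params = attn.init(key, x, mask=causal)

    y1 = attn.apply(params, x, mask=causal)
    # Perturb only the last position; earlier outputs should be identical.
    x2 = x.at[:, -1, :].set(x[:, -1, :] + 1.0)
    y2 = attn.apply(params, x2, mask=causal)
    assert jnp.allclose(y1[:, :-1, :], y2[:, :-1, :], atol=1e-5)
```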

Comment on lines 181 to 187
```python
if mask.ndim == 2:  # Likely (seq_len, seq_len)
    mask = mask[None, None, :, :]  # -> (1, 1, seq_len, seq_len)
elif mask.ndim == 3 and mask.shape[1] != self.num_heads:
    # Likely (batch, seq_len, seq_len) or causal (1, sl, sl)
    mask = mask[:, None, :, :]
    # Assume (batch, seq_len, seq_len) -> (batch, 1, seq_len, seq_len)
```


Severity: medium

This logic assumes that if mask.ndim == 3 and mask.shape[1] != self.num_heads, then the mask is either (batch, seq_len, seq_len) or causal (1, seq_len, seq_len). However, it's possible that the mask has a different shape that is also incompatible with attn_scores. Consider adding a more robust check to ensure that the mask shape is compatible before attempting to broadcast it. Also, consider adding a comment to explain the assumption.

Suggested change:

```python
mask = mask[:, None, :, :]  # Assume (batch, seq_len, seq_len) -> (batch, 1, seq_len, seq_len)
```
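
One possible way to act on this suggestion (a sketch only, with a hypothetical helper name; not the PR's code) is to normalize the mask explicitly and fail fast on shapes that cannot broadcast against the attention scores:

```python
# Hypothetical helper that normalizes an attention mask to
# (batch, num_heads, seq_len, seq_len) and rejects incompatible shapes early.
def normalize_mask(mask, batch, num_heads, seq_len):
    if mask.ndim == 2 and mask.shape == (seq_len, seq_len):
        mask = mask[None, None, :, :]            # (1, 1, S, S)
    elif mask.ndim == 3 and mask.shape[-2:] == (seq_len, seq_len):
        mask = mask[:, None, :, :]               # (B or 1, 1, S, S)
    elif mask.ndim != 4:
        raise ValueError(
            f"Mask shape {mask.shape} cannot be broadcast to expected shape "
            f"({batch}, {num_heads}, {seq_len}, {seq_len})"
        )
    # Final check: every axis must be 1 or match the attention-score shape.
    expected = (batch, num_heads, seq_len, seq_len)
    for dim, exp in zip(mask.shape, expected):
        if dim not in (1, exp):
            raise ValueError(f"Mask shape {mask.shape} is incompatible with {expected}")
    return mask
```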

Comment on lines 200 to 201
```python
else:
    raise ValueError(f"Mask shape {mask.shape} != exp shape {mask_shape_expected}")
```


Severity: medium

This ValueError could be more informative; consider spelling out the expected shape in the error message.

Suggested change:

```python
raise ValueError(f"Mask shape {mask.shape} != exp shape {mask_shape_expected}")
```

@monatis
Member

monatis commented May 15, 2025

hi @RubensZimbres, thanks for this PR! It seems to have conflicts with the main branch. If the conflicts turn out to be hard to resolve, you may want to start a new branch from main and add your code in a new PR.

@RubensZimbres
Contributor Author

Conflicts resolved.
