Muon fsdp 2
This codebase implements the Muon optimizer in a form compatible with FSDP2, as described in this blog post.
Most of the core code was developed by main-horse here.
This repo integrates that code into a training codebase and optimizes the communication step: it uses a gather followed by a scatter instead of an all_gather.
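Muon orthogonalizes each weight matrix as a whole, so under FSDP2 the gradient shards of a parameter must be reassembled before the update. A back-of-the-envelope count helps show why a gather/scatter scheme can beat an all_gather here. The sketch below is a single-process simulation under simplifying assumptions: every parameter is sharded evenly across `world_size` ranks, each parameter gets one owner rank chosen round-robin (a hypothetical scheduling choice, not necessarily what this repo does), and only tensor elements received are counted, ignoring latency and compute/communication overlap. The function names are illustrative, not part of `muon_fsdp2`.

```python
from collections import Counter


def all_gather_cost(param_sizes, world_size):
    """Every rank rebuilds every full parameter and orthogonalizes it itself."""
    received = 0
    for size in param_sizes:
        shard = size // world_size  # assumes size is divisible by world_size
        # each of the world_size ranks receives the (world_size - 1) shards it lacks
        received += world_size * (world_size - 1) * shard
    # every rank runs the orthogonalization on every parameter
    ortho_calls = {rank: len(param_sizes) for rank in range(world_size)}
    return received, ortho_calls


def gather_scatter_cost(param_sizes, world_size):
    """One owner rank per parameter gathers, orthogonalizes, and scatters back."""
    received = 0
    ortho_calls = Counter()
    for i, size in enumerate(param_sizes):
        shard = size // world_size  # assumes size is divisible by world_size
        owner = i % world_size  # hypothetical round-robin owner assignment
        # gather: the owner receives the (world_size - 1) shards it lacks
        received += (world_size - 1) * shard
        # scatter: each non-owner receives its updated shard back
        received += (world_size - 1) * shard
        ortho_calls[owner] += 1
    return received, {rank: ortho_calls[rank] for rank in range(world_size)}


if __name__ == "__main__":
    sizes, world = [100, 200, 300], 4
    print(all_gather_cost(sizes, world))
    print(gather_scatter_cost(sizes, world))
```

Under these assumptions, for a total parameter size P on W ranks the all_gather scheme moves (W-1)·P elements while gather/scatter moves 2(W-1)/W·P, a reduction by a factor of W/2 for W > 2, and the orthogonalization work is split across ranks instead of being repeated on each of them.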
This repo is composed of two parts:
- `src/muon_fsdp2` is the implementation of the Muon optimizer compatible with FSDP2
- `src/zeroband` is the training codebase
Muon FSDP2 package
This is a standalone package that can be used to train models with the Muon optimizer.
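For reference, the core of a Muon step for one weight matrix can be sketched in NumPy: accumulate momentum, then replace the momentum matrix by an approximately orthogonal matrix with the same singular vectors, using a quintic Newton-Schulz iteration. The coefficients (3.4445, -4.7750, 2.0315) are the ones used in Keller Jordan's public Muon reference implementation. The function names, the default `lr` and `momentum` values, and the omission of Nesterov momentum and shape-dependent learning-rate scaling are simplifications for illustration; this is not the `muon_fsdp2` API.

```python
import numpy as np


def newton_schulz(g, steps=5, eps=1e-7):
    """Approximately replace g = U S V^T by U V^T (singular values pushed toward 1)."""
    a, b, c = 3.4445, -4.7750, 2.0315
    # The Frobenius norm bounds the spectral norm, so all singular values become <= 1.
    x = g / (np.linalg.norm(g) + eps)
    transposed = x.shape[0] > x.shape[1]
    if transposed:  # iterate on the wide orientation so x @ x.T is the smaller Gram matrix
        x = x.T
    for _ in range(steps):
        gram = x @ x.T
        poly = b * gram + c * gram @ gram
        x = a * x + poly @ x  # applies a*s + b*s^3 + c*s^5 to each singular value s
    return x.T if transposed else x


def muon_step(weight, grad, momentum_buf, lr=0.02, momentum=0.95):
    """One simplified Muon step; returns the new weight and momentum buffer."""
    momentum_buf = momentum * momentum_buf + grad
    update = newton_schulz(momentum_buf)
    return weight - lr * update, momentum_buf


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    g = rng.standard_normal((64, 32))
    print(np.linalg.svd(newton_schulz(g), compute_uv=False).round(2))
```

The iteration uses only matrix multiplications, which keeps it cheap on GPUs, but it needs the whole matrix at once; that is the requirement FSDP2's per-parameter sharding makes nontrivial.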
Install the package from PyPI:
uv pip install muon-fsdp2
or from source:
uv pip install git+https://github.com/samsja/muon_fsdp_2.git@main
Example usage:
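from muon_fsdp2 import Muon

# square_params() / non_square_params() are not torch.nn.Module methods; they stand
# for helpers on your model that return the two parameter groups below.
optimizer = Muon([
    dict(
        params=model.square_params(),  # group updated with Muon
        lr=1e-3,
        use_muon=True,
    ),
    dict(
        params=model.non_square_params(),  # group not updated with Muon
        lr=1e-3,
        use_muon=False,
    ),
])
ZeroBand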
ZeroBand is a fork of this repo: a standalone training codebase for LLMs, specifically designed for use with FSDP2 and the Muon optimizer.
Install
export HF_TOKEN=
Then use the default install script to install the dependencies:
curl -sSL https://raw.githubusercontent.com/samsja/muon_fsdp_2/main/install.sh | bash
source $HOME/.local/bin/env
Or do it manually:
git clone https://github.com/samsja/muon_fsdp_2
cd muon_fsdp_2
uv sync
Run debug
PRIME_DEBUG=1 uv run torchrun --nproc_per_node=2 train_fsdp.py @ configs/debug/normal.toml
Run 150M
uv run torchrun --nproc_per_node=8 train_fsdp.py @ configs/150M/H100.toml
Run 1B
uv run torchrun --nproc_per_node=8 train_fsdp.py @ configs/1B/H100.toml
Run 7B
uv run torchrun --nproc_per_node=8 train_fsdp.py @ configs/7B/H100.toml
Benchmark
Measured with older code; not updated yet.
| Model Size | GPUs | GPU Type | MFU |
|---|---|---|---|
| 1B | 8 | H100 SXM | 45% |
| 7B | 8 | H100 SXM | 49% |
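The README does not state how MFU (model FLOPs utilization) was computed. A common definition divides the model FLOPs actually executed per second by the hardware peak. The sketch below assumes the usual approximation of 6N FLOPs per training token for an N-parameter model (forward plus backward, ignoring attention-score FLOPs and activation recomputation) and a peak of roughly 989 TFLOP/s dense BF16 per H100 SXM; both are assumptions, and the table above may have been computed differently. The function names are hypothetical.

```python
# Approximate NVIDIA spec-sheet peak for H100 SXM, dense BF16 (no sparsity).
H100_SXM_BF16_DENSE_FLOPS = 989e12


def model_flops_per_token(n_params):
    # ~2N FLOPs in the forward pass + ~4N in the backward pass per token;
    # attention-score FLOPs and recomputation are ignored.
    return 6 * n_params


def mfu(n_params, tokens_per_sec, n_gpus, peak_flops_per_gpu=H100_SXM_BF16_DENSE_FLOPS):
    """Fraction of aggregate peak FLOP/s spent on model FLOPs."""
    achieved = model_flops_per_token(n_params) * tokens_per_sec
    return achieved / (n_gpus * peak_flops_per_gpu)


def tokens_per_sec_at(mfu_value, n_params, n_gpus, peak_flops_per_gpu=H100_SXM_BF16_DENSE_FLOPS):
    """Inverse of mfu(): the throughput needed to reach a given MFU."""
    return mfu_value * n_gpus * peak_flops_per_gpu / model_flops_per_token(n_params)
```

For example, `mfu(n_params, measured_tokens_per_sec, 8)` turns a measured cluster throughput into an MFU figure comparable to the table, as long as the same FLOP-counting convention is used on both sides.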
Convergence 150M
Measured with older code; not updated yet.
To reproduce the convergence comparison, run both the FSDP and the DDP trainings:
uv run torchrun --nproc_per_node=8 train_fsdp.py @ configs/150M/H100.toml
uv run torchrun --nproc_per_node=8 train_ddp.py @ configs/150M/H100.toml
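The FSDP and DDP runs train the same model with the same optimizer. One property such a comparison relies on is that the sharded optimizer step produces the same update as the unsharded one. Because Muon orthogonalizes each weight matrix as a whole, this only holds if the full matrix is reassembled before orthogonalizing. The NumPy sketch below illustrates that point under simplifying assumptions: row (dim-0) sharding, which is FSDP2's default; an exact SVD-based orthogonalization as a stand-in for Newton-Schulz; and hypothetical helper names that are not part of this repo.

```python
import numpy as np


def orthogonalize(m):
    # Exact polar factor U @ Vt; Muon's Newton-Schulz iteration approximates this.
    u, _, vt = np.linalg.svd(m, full_matrices=False)
    return u @ vt


def sharded_update(grad, world_size):
    """FSDP-style: gather row shards, orthogonalize the full matrix, re-shard."""
    shards = np.array_split(grad, world_size, axis=0)  # what each rank holds
    full = np.concatenate(shards, axis=0)  # gather onto one rank
    update = orthogonalize(full)
    return np.array_split(update, world_size, axis=0)  # scatter back to the ranks


def naive_local_update(grad, world_size):
    """Incorrect shortcut: each rank orthogonalizes only its own shard."""
    shards = np.array_split(grad, world_size, axis=0)
    return [orthogonalize(s) for s in shards]


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    g = rng.standard_normal((8, 4))
    ref = orthogonalize(g)  # what DDP computes with the full gradient
    fsdp = np.concatenate(sharded_update(g, 4), axis=0)
    naive = np.concatenate(naive_local_update(g, 4), axis=0)
    print(np.abs(fsdp - ref).max(), np.abs(naive - ref).max())
```

The gathered update matches the unsharded reference, while orthogonalizing shards independently gives a different matrix, so a sharded Muon step that skipped the gather would not be expected to track the DDP run.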