# FedGAT
**Generative Autoregressive Transformers for Model-Agnostic Federated MRI Reconstruction**

UMRAM, Bilkent University

[arXiv:2502.04521](https://arxiv.org/abs/2502.04521)
Official PyTorch implementation of FedGAT, a novel model-agnostic federated learning technique based on generative autoregressive transformers for MRI reconstruction. Unlike conventional federated learning that requires homogeneous model architectures across sites, FedGAT enables flexible collaborations among sites with distinct reconstruction models by decentralizing the training of a global generative prior. This prior captures the distribution of multi-site MRI data via autoregressive prediction across spatial scales, guided by a site-specific prompt. Site-specific reconstruction models are trained using hybrid datasets combining local and synthetic samples. Comprehensive experiments demonstrate that FedGAT achieves superior within-site and across-site reconstruction performance compared to state-of-the-art FL baselines while preserving privacy.
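The hybrid-dataset idea above (training each site's reconstruction model on a mix of local scans and prior-synthesized samples) can be sketched in a few lines of PyTorch. This is an illustrative toy, not the repository's pipeline; `ToyMRIDataset` is a hypothetical stand-in for a site's real data loader.

```python
# Illustrative sketch of hybrid-dataset training (NOT the official pipeline):
# each site combines its local samples with synthetic samples drawn from the
# shared generative prior, then trains its own reconstruction model on the mix.
import torch
from torch.utils.data import Dataset, ConcatDataset, DataLoader

class ToyMRIDataset(Dataset):
    """Hypothetical stand-in for local (or prior-synthesized) MRI slices."""
    def __init__(self, n, size=32, seed=0):
        g = torch.Generator().manual_seed(seed)
        self.images = torch.rand(n, 1, size, size, generator=g)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, i):
        return self.images[i]

local = ToyMRIDataset(n=100, seed=0)        # site's own scans
synthetic = ToyMRIDataset(n=200, seed=1)    # samples from the global prior
hybrid = ConcatDataset([local, synthetic])  # hybrid training set
loader = DataLoader(hybrid, batch_size=16, shuffle=True)
```

Because only synthetic samples (not raw patient data) leave the prior, each site can train an arbitrary model architecture on `loader`, which is what makes the scheme model-agnostic.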
## ⚙️ Installation

```bash
# Clone repo
git clone https://github.com/icon-lab/FedGAT.git
cd FedGAT

# Create and activate conda environment
conda env create -f environment.yml
conda activate fedgat
```

## 📚 Data Preparation
Expected dataset structure:
```
data/
├── Site_0/
│   ├── train/
│   │   └── data/
│   └── val/
│       └── data/
├── Site_1/
│   ├── train/
│   │   └── data/
│   └── val/
│       └── data/
└── Site_2/
    ├── train/
    │   └── data/
    └── val/
        └── data/
```
Each `train/` and `val/` folder contains MRI images (e.g., `.png` files) for each site.
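Before launching training, it can help to verify that every site directory matches the layout above. The helper below is a hypothetical convenience script (not part of this repository), assuming the `data/` root shown in the tree.

```python
# Hypothetical layout checker (not shipped with the repo): reports any
# missing Site_*/train|val/data directories under the dataset root.
from pathlib import Path

def check_layout(root, n_sites=3, splits=("train", "val")):
    """Return paths of expected Site_*/split/data dirs that are missing."""
    root = Path(root)
    missing = []
    for i in range(n_sites):
        for split in splits:
            d = root / f"Site_{i}" / split / "data"
            if not d.is_dir():
                missing.append(str(d))
    return missing

if __name__ == "__main__":
    missing = check_layout("data")
    if missing:
        raise SystemExit(f"Missing folders: {missing}")
    print("Dataset layout OK")
```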
## 🏋️ Training

### Basic Training Command
```bash
torchrun --nproc_per_node=1 train.py \
  --case='multicoil' \
  --depth=16 \
  --bs=16 \
  --ep=500 \
  --fp16=1 \
  --alng=1e-3 \
  --wpe=0.1 \
  --client_num=3 \
  --comm_round=1
```
### Training Parameters

| Parameter | Description | Default |
|---|---|---|
| `--case` | Dataset type (`'singlecoil'` or `'multicoil'`) | - |
| `--client_num` | Number of federated sites | 3 |
| `--comm_round` | Number of communication rounds | 1 |
| `--depth` | Model depth | 16 |
| `--bs` | Batch size | 16 |
| `--ep` | Number of epochs | 500 |
| `--fp16` | Mixed-precision training | 1 |
| `--alng` | AdaLN gamma | 1e-3 |
| `--wpe` | Final learning rate ratio at the end of training | 0.1 |
FedGAT will create a `fedGAT_output/` directory to store all checkpoints and logs. You can monitor training by:
- Inspecting `fedGAT_output/log.txt` and `fedGAT_output/stdout.txt`
If your run is interrupted, simply re-execute the same training command: FedGAT will automatically pick up from the latest `fedGAT_output/ckpt*.pth` checkpoint (see `utils/misc.py`, lines 344–357).
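The auto-resume behavior boils down to finding the newest `ckpt*.pth` in the output directory. The sketch below mirrors that idea but is a hypothetical re-implementation, not the actual code in `utils/misc.py`.

```python
# Hypothetical sketch of checkpoint auto-resume: pick the most recently
# modified ckpt*.pth under the output directory, or None if none exist.
import glob
import os

def latest_checkpoint(out_dir="fedGAT_output"):
    """Return the path of the newest ckpt*.pth in out_dir, or None."""
    ckpts = glob.glob(os.path.join(out_dir, "ckpt*.pth"))
    return max(ckpts, key=os.path.getmtime) if ckpts else None
```

If a path is returned, training code would load it (e.g., via `torch.load`) and continue from the stored epoch rather than starting over.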
## 📖 Citation
You are welcome to use, modify, and distribute this code. We kindly request that you acknowledge this repository and cite our paper appropriately.
```bibtex
@article{nezhad2025generative,
  title={Generative Autoregressive Transformers for Model-Agnostic Federated MRI Reconstruction},
  author={Nezhad, Valiyeh A and Elmas, Gokberk and Kabas, Bilal and Arslan, Fuat and {\c{C}}ukur, Tolga},
  journal={arXiv preprint arXiv:2502.04521},
  year={2025}
}
```

## 🙏 Acknowledgments
This repository uses code from the following projects:
Copyright © 2025, ICON Lab.