milad1378yz/MOTFM
Flow Matching for Medical Image Synthesis: Bridging the Gap Between Speed and Quality
MOTFM (Medical Optimal Transport Flow Matching)
MOTFM (Medical Optimal Transport Flow Matching) accelerates medical image generation while preserving, and often improving, image quality across 2D and 3D data and both class- and mask-conditional setups.
Paper
Checkpoints
Synthetic Data
Requirements
- Python: 3.9 - 3.12
- Core pinned stack (from pyproject.toml):
  - torch==2.5.1
  - flow_matching==1.0.10
  - pytorch-lightning==2.5.6
  - numpy==1.26.4
  - monai_generative==0.2.3
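After installing, you can check that the environment actually matches these pins. The sketch below is not part of the repository. It uses only importlib.metadata from the standard library, and find_mismatches is a hypothetical helper name. It assumes importlib.metadata resolves the distribution names as spelled in pyproject.toml, so adjust a name if a package you know is installed is reported as missing.

```python
from importlib import metadata
from typing import Callable, Dict

# Pins copied from the Requirements list above.
PINS: Dict[str, str] = {
    "torch": "2.5.1",
    "flow_matching": "1.0.10",
    "pytorch-lightning": "2.5.6",
    "numpy": "1.26.4",
    "monai_generative": "0.2.3",
}


def find_mismatches(
    pins: Dict[str, str],
    get_version: Callable[[str], str] = metadata.version,
) -> Dict[str, str]:
    """Return {package: problem} for every pin that is missing or differs."""
    problems: Dict[str, str] = {}
    for name, wanted in pins.items():
        try:
            found = get_version(name)
        except metadata.PackageNotFoundError:
            problems[name] = "not installed"
            continue
        # Local build tags such as "+cu121" on torch wheels are ignored here.
        if found.split("+")[0] != wanted:
            problems[name] = f"found {found}, expected {wanted}"
    return problems
```

Running print(find_mismatches(PINS) or "all pins satisfied") prints either the offending packages or a confirmation. Stripping local build tags before comparing is a judgment call; remove the split if you want exact string matches.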
To install from pyproject.toml, run:
pip install -e .
Data Preparation
Important Note:
- Your training data must be stored in a single .pkl file that follows the structure below.
Within that .pkl file, your data dictionary should look like:
{
"train": [ # List of training samples
{
"image": "Tensor[Channels, Height, Width, ...] (float32, normalized)",
"mask": "Tensor[1, Height, Width, ...] (int32)",
"class": "Scalar integer (int32)",
"metadata": "Structured data (dict or other format)"
},
...
],
"valid": [ # List of validation samples
{
"image": "Tensor[Channels, Height, Width, ...] (float32, normalized)",
"mask": "Tensor[1, Height, Width, ...] (int32)",
"class": "Scalar integer (int32)",
"metadata": "Structured data (dict or other format)"
},
...
],
"test": [ # List of test samples
{
"image": "Tensor[Channels, Height, Width, ...] (float32, normalized)",
"mask": "Tensor[1, Height, Width, ...] (int32)",
"class": "Scalar integer (int32)",
"metadata": "Structured data (dict or other format)"
},
...
]
}
Make sure your dataset follows this structure and is saved as a single .pkl file before running the training or inference pipelines.
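Before training, it can help to check a dataset against this layout programmatically. The sketch below is an illustration, not the loader used by trainer.py: make_dummy_sample, check_dataset, and save_dataset are hypothetical helpers, and NumPy arrays stand in for the tensors to keep it lightweight. The dtype check accepts both NumPy and torch dtype names, so it should also work on a dictionary of torch tensors; it does not verify value ranges or the contents of metadata.

```python
import pickle

import numpy as np

REQUIRED_SPLITS = ("train", "valid", "test")
REQUIRED_KEYS = {"image", "mask", "class", "metadata"}


def make_dummy_sample(rng, channels=1, spatial=(32, 32), num_classes=2):
    """Hypothetical helper: one random sample in the documented layout."""
    image = rng.random((channels, *spatial), dtype=np.float32)  # float32 in [0, 1)
    mask = (rng.random((1, *spatial)) > 0.5).astype(np.int32)   # [1, H, W, ...] integer mask
    label = np.int32(rng.integers(num_classes))                 # scalar class index
    return {"image": image, "mask": mask, "class": label, "metadata": {"id": "dummy"}}


def _dtype_name(array):
    # NumPy prints "float32", torch prints "torch.float32"; keep the part after the last dot.
    return str(array.dtype).split(".")[-1]


def check_dataset(data):
    """Return a list of problems; an empty list means the layout matches the documented one."""
    problems = []
    for split in REQUIRED_SPLITS:
        if split not in data:
            problems.append(f"missing split '{split}'")
            continue
        for i, sample in enumerate(data[split]):
            where = f"{split}[{i}]"
            missing = REQUIRED_KEYS - set(sample)
            if missing:
                problems.append(f"{where}: missing keys {sorted(missing)}")
                continue
            image, mask = sample["image"], sample["mask"]
            if _dtype_name(image) != "float32":
                problems.append(f"{where}: image dtype {image.dtype}, expected float32")
            if _dtype_name(mask) != "int32":
                problems.append(f"{where}: mask dtype {mask.dtype}, expected int32")
            if mask.shape[0] != 1:
                problems.append(f"{where}: mask has {mask.shape[0]} channels, expected 1")
            if tuple(mask.shape[1:]) != tuple(image.shape[1:]):
                problems.append(
                    f"{where}: mask spatial shape {tuple(mask.shape[1:])} "
                    f"!= image spatial shape {tuple(image.shape[1:])}"
                )
    return problems


def save_dataset(data, path):
    """Refuse to write a file that does not match the layout, then pickle it."""
    problems = check_dataset(data)
    if problems:
        raise ValueError("dataset layout problems:\n" + "\n".join(problems))
    with open(path, "wb") as f:
        pickle.dump(data, f)
```

To check an existing file, load it with pickle.load and pass the dictionary to check_dataset; an empty list means the layout looks right. This README does not say whether the training code accepts NumPy arrays in place of torch tensors, so converting with torch.from_numpy before saving is the safer choice.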
Configuration Files
You must either create or modify a YAML configuration file to suit your dataset paths, model parameters, and hyperparameters. Some sample configuration files are provided in the configs/ folder. By default, configs/default.yaml is used if no custom path is provided.
Training
To train the model, run:
python trainer.py --config_path configs/default.yaml
or (after installation):
motfm-train --config_path configs/default.yaml
--config_path: Path to your YAML configuration file. Defaults to configs/default.yaml if not provided.
Note: Make sure you have prepared your dataset (as a single .pkl file) and configuration file properly before starting training.
Inference
Use inferer.py to generate synthetic samples from a trained checkpoint and save them as a .pkl file.
Quick start
Run with your config and checkpoint directory:
python inferer.py \
--config_path configs/default.yaml \
--model_path mask_class_conditioning_checkpoints/default \
--num_samples 200
or (after installation):
motfm-infer \
--config_path configs/default.yaml \
--model_path mask_class_conditioning_checkpoints/default \
--num_samples 200
Arguments
- --config_path (str, default: configs/default.yaml): Config file used for model and data setup.
- --model_path (str, optional): Checkpoint .ckpt file or directory.
- --num_samples (int, optional): Number of samples to save. If omitted, all validation samples are saved.
- --num_inference_steps (int, optional): Number of solver time points used during sampling. If omitted, solver_args.time_points from the config is used.
- --output_path (str, optional): Explicit output .pkl path.
- --overwrite (flag): Overwrite an existing file at --output_path.
- --output_norm (str, default: per_sample_minmax): One of clip_0_1, per_sample_minmax, global_minmax, none.
- --allow_config_mismatch (flag): Allow loading a checkpoint whose saved critical model fields differ from the current config.
- --seed (int, optional): Override the RNG seed for reproducible inference. Defaults to train_args.seed if provided.
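The --output_norm modes are documented only by name. The NumPy sketch below shows one plausible reading of them for a batch shaped [N, C, ...]; it is an assumption, not code from inferer.py, and the axes and epsilon used there may differ. normalize_outputs is a hypothetical name.

```python
import numpy as np


def normalize_outputs(samples: np.ndarray, mode: str = "per_sample_minmax",
                      eps: float = 1e-8) -> np.ndarray:
    """Assumed behavior of each --output_norm value for a batch shaped [N, C, ...]."""
    if mode == "none":
        return samples  # leave raw model outputs untouched
    x = samples.astype(np.float32)
    if mode == "clip_0_1":
        return np.clip(x, 0.0, 1.0)  # values outside [0, 1] are cut off, others unchanged
    if mode == "global_minmax":
        lo, hi = x.min(), x.max()  # one range shared by the whole batch
        return (x - lo) / (hi - lo + eps)
    if mode == "per_sample_minmax":
        axes = tuple(range(1, x.ndim))  # every axis except the batch axis
        lo = x.min(axis=axes, keepdims=True)
        hi = x.max(axis=axes, keepdims=True)
        return (x - lo) / (hi - lo + eps)
    raise ValueError(f"unknown output_norm: {mode!r}")
```

Under this reading, per_sample_minmax stretches every sample to the full [0, 1] range independently, which makes samples easy to view side by side but discards their relative intensity; global_minmax keeps relative intensity across the batch, and clip_0_1 leaves in-range values untouched.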
Checkpoint resolution behavior
If --model_path is omitted, inferer searches:
train_args.checkpoint_dir/<config_basename>
If --model_path is provided, inferer checks (in order):
- <model_path>
- <model_path>/<config_basename>
- <model_path>/latest
If a directory is selected, checkpoint preference is:
- last.ckpt (if present)
- otherwise, the most recently modified *.ckpt file
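The lookup order above can be summarized in a short pathlib sketch. It paraphrases the documented behavior and is not the actual code in inferer.py: resolve_checkpoint and pick_checkpoint are hypothetical names, <config_basename> is assumed to be the config file name without its extension, and moving on to the next candidate when a directory holds no .ckpt file is also an assumption.

```python
from pathlib import Path
from typing import Optional


def pick_checkpoint(directory: Path) -> Optional[Path]:
    """Prefer last.ckpt, otherwise the most recently modified *.ckpt in the directory."""
    last = directory / "last.ckpt"
    if last.is_file():
        return last
    ckpts = [p for p in directory.glob("*.ckpt") if p.is_file()]
    return max(ckpts, key=lambda p: p.stat().st_mtime) if ckpts else None


def resolve_checkpoint(config_path: str, checkpoint_dir: str,
                       model_path: Optional[str] = None) -> Optional[Path]:
    """Walk the documented candidate list; return None if nothing usable is found."""
    basename = Path(config_path).stem  # assumption: "configs/default.yaml" -> "default"
    if model_path is None:
        candidates = [Path(checkpoint_dir) / basename]  # train_args.checkpoint_dir/<config_basename>
    else:
        root = Path(model_path)
        candidates = [root, root / basename, root / "latest"]
    for cand in candidates:
        if cand.is_file() and cand.suffix == ".ckpt":
            return cand  # an explicit checkpoint file wins immediately
        if cand.is_dir():
            found = pick_checkpoint(cand)
            if found is not None:
                return found
    return None
```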
Output behavior
- If --output_path is omitted, output is saved in the resolved checkpoint directory as: samples_<config_basename>_<checkpoint_name>_steps<time_points>.pkl
- If the output file already exists and --overwrite is not set, a timestamp suffix is appended automatically.
- Generated samples are produced from the validation split and saved under:
  - data_args.split_train
  - and also data_args.split_val, if that key is different.
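The naming rule above can be sketched as follows. default_output_path is a hypothetical helper, and several details are assumptions because the README does not spell them out: <config_basename> and <checkpoint_name> are taken to be file names without extensions, the timestamp suffix is assumed to apply to an explicit --output_path as well, and the timestamp format shown is invented for illustration.

```python
from datetime import datetime
from pathlib import Path
from typing import Optional


def default_output_path(checkpoint: Path, config_path: str, time_points: int,
                        output_path: Optional[str] = None, overwrite: bool = False,
                        now: Optional[datetime] = None) -> Path:
    """Mimic the documented output naming; the details noted above are assumptions."""
    if output_path is not None:
        target = Path(output_path)
    else:
        name = f"samples_{Path(config_path).stem}_{checkpoint.stem}_steps{time_points}.pkl"
        target = checkpoint.parent / name  # the resolved checkpoint directory
    if target.exists() and not overwrite:
        stamp = (now or datetime.now()).strftime("%Y%m%d_%H%M%S")  # hypothetical format
        target = target.with_name(f"{target.stem}_{stamp}{target.suffix}")
    return target
```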
CPU-only note
If you run inference on CPU, set model_args.use_flash_attention: false in your config.
Flash attention requires CUDA and will raise an error otherwise.
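To keep a GPU config untouched, you can generate a CPU variant from it. The sketch below assumes PyYAML is installed (it is not among the pinned packages listed above) and that the config is a plain YAML mapping with a model_args section. disable_flash_attention and write_cpu_config are hypothetical helpers, and the output file name is your choice.

```python
from pathlib import Path


def disable_flash_attention(config: dict) -> dict:
    """Return a copy of the config with model_args.use_flash_attention set to False."""
    updated = dict(config)
    model_args = dict(updated.get("model_args") or {})  # copy so the input is not mutated
    model_args["use_flash_attention"] = False
    updated["model_args"] = model_args
    return updated


def write_cpu_config(src: str, dst: str) -> None:
    """Read a YAML config, turn off flash attention, and write the result to dst."""
    import yaml  # PyYAML, imported lazily so disable_flash_attention has no dependency

    cfg = yaml.safe_load(Path(src).read_text())
    Path(dst).write_text(yaml.safe_dump(disable_flash_attention(cfg), sort_keys=False))
```

For example, write_cpu_config("configs/default.yaml", "configs/default_cpu.yaml"). Because the default checkpoint search and the default output file name both depend on the config's base name, a renamed copy will not find checkpoints trained with the original config on its own; pass --model_path explicitly (for instance mask_class_conditioning_checkpoints/default, as in the quick start).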
3D Evaluation
A dedicated script is available in evaluation_3d/ to compute 3D metrics between two datasets:
- MMD
- MS-SSIM
- 3D-FID (R3D-18 features + MONAI FIDMetric)
python evaluation_3d/evaluate_3d.py \
--generated_path /path/to/generated.pkl \
--reference_path /path/to/reference.pkl \
--generated_split train \
--reference_split valid \
--num_samples 200
Use --skip_fid to skip 3D-FID when torchvision video weights are unavailable.
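For intuition about the first metric, the sketch below computes a biased estimate of the squared maximum mean discrepancy (MMD) between two sets of flattened samples, using a Gaussian kernel: the mean kernel value within each set minus twice the mean kernel value across sets. It is a generic textbook formulation, not the implementation in evaluation_3d/evaluate_3d.py, whose kernel, bandwidth, and any feature extraction are not described here; gaussian_mmd2 and sigma are hypothetical names.

```python
import numpy as np


def gaussian_mmd2(x: np.ndarray, y: np.ndarray, sigma: float = 1.0) -> float:
    """Biased squared MMD between x [n, d] and y [m, d] with an RBF kernel of width sigma."""
    def kernel(a: np.ndarray, b: np.ndarray) -> np.ndarray:
        # Pairwise squared Euclidean distances, clamped at 0 to absorb rounding error.
        d2 = (a ** 2).sum(1)[:, None] + (b ** 2).sum(1)[None, :] - 2.0 * a @ b.T
        return np.exp(-np.maximum(d2, 0.0) / (2.0 * sigma ** 2))

    return float(kernel(x, x).mean() + kernel(y, y).mean() - 2.0 * kernel(x, y).mean())
```

Volumes shaped [N, C, D, H, W] would be flattened with v.reshape(len(v), -1) first. Values near zero mean the two sets are hard to tell apart under this kernel; the result depends strongly on sigma, so only compare numbers computed with the same bandwidth.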
News
- 2025-05-27 | The paper was accepted to MICCAI 2025.
- 2025-04-09 | Code released.
- 2025-03-29 | The paper became available on arXiv.
Citation
If you find this code or our work useful in your research, please cite:
@inproceedings{yazdani2025flow,
title={Flow matching for medical image synthesis: Bridging the gap between speed and quality},
author={Yazdani, Milad and Medghalchi, Yasamin and Ashrafian, Pooria and Hacihaliloglu, Ilker and Shahriari, Dena},
booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention},
pages={216--226},
year={2025},
organization={Springer}
}
Enjoy working with MOTFM! Feel free to open an issue or pull request if you have any questions or suggestions.
