GitHunt
SR

srijith1996/FSL-SAGE

Federated Split Learning via Smashed Activation Gradient Estimation

FSL-SAGE: Accelerating Federated Split Learning via Smashed Activation Gradient Estimation

Static Badge
OpenReview

Introduction

Our Federated Split Learning (FSL) algorithm cuts down on communication
overheads in traditional Split Learning methods by directly estimating
server-returned gradients at each client using auxiliary models. The auxiliary
models are much smaller versions of the server model which are explicitly
trained to estimate the gradients that the server model would return for the
client's local input.

The algorithm is summarized in the following schematic:

FSL-SAGE schematic

Please refer to our paper for further details, and consider citing the same if you find this work useful.

Requirements

The project requirements can be simply installed using the environment config
file conda_env.yaml as follows:

conda env create -f conda_env.yaml

which will create a conda environment by the name sage. You can activate the
conda environment using:

conda activate sage

and all dependency requirements should be met.

Configuration

This project is powered by Hydra, which allows
hierarchical configurations and easy running of multiple ML experiments.
The config files for hydra are located in the folder
hydra_config.

There is a high degree of customizability here; datasets, models and FL
algorithms can be plugged in using configs. Please check out our
contributing readme for more details.

Datasets

Currently datasets are read and imported from the datas folder in the root of
the repository. You can simply create a folder for the repository and download
the dataset there. After performing the necessary preprocessing, simply
use/extend the get_dataset() function in datasets/__init__.py

Training

To train FSL-SAGE with defaults from
config.yaml, you can simply run

python main.py

Training results are saved in a saves folder in the root, where they will be
saved in folders segregated based on the FL algorithm, model, dataset and
distribution used.

The default number of clients (num_clients) is set to 10 and the default
number of rounds is rounds=200.
Each method will train upto the fixed rounds or until the number of MBs
specified in comm_threshold_mb is reached.

To choose a specific model or algorithm, the Hydra
command-line override
functionality can be used as follows

python main.py model=resnet18 algorithm=cse_fsl dataset=cifar100 dataset.distribution=iid

The following options are currently supported, click them to reveal the details:

Algorithm

Syntax : algorithm=<key>.
The FL algorithm to use for training.
List of algorithms currently supported:

Key Algorithm
fed_avg FedAvg
sl_multi_server SplitFedv1
sl_single_server SplitFedv2
cse_fsl CSE-FSL
fsl_sage FSL-SAGE
Dataset

Syntax : dataset=<key>.
The dataset used in training.
List of datasets currently supported:

Key Dataset
cifar10 cifar10
cifar100 cifar100
Model

Syntax : model=<key>.
The ML model to use for training.
List of models currently supported:

Key Model
resnet18 ResNet-18
resnet50 ResNet-50
resnet56 ResNet-56
resnet110 ResNet-110

Note that currently the above resnet models apart from resnet18 haven't been
tuned yet, so the results may not optimally represent FSL-SAGE's communication
benefits.

Data distribution

Syntax : dataset.distribution=<key>.
Determines the distrbution of the dataset across clients List of distributions
currently supported:

Key Distribution
iid homogeneous
noniid_dirichlet heterogeneous

For noniid_dirichlet you can specify the value of alpha using the key
dataset.alpha, e.g., dataset.alpha=1.


We also support multiruns in parallel using the
hydra-joblib-launcher
Thus, it is possible to run multiple experiments for different combinations of
hyperparams, models, datasets or algorithms given sufficient GPU memory.

python main.py -m model=resnet18,simple_conv algorithm=fed_avg,sl_single_server,sl_multi_server,cse_fsl,fsl_sage

The above would create parallel jobs that would run main.py on all combinations
of specified options.
The number of jobs can be controlled by modifying the hydra.launcher.n_jobs
option in config.yaml or by specifying
hydra.launcher.n_jobs=<jobs> as an option to the script.

Inference Plots

Please check out the readme, the functions used in the
plot_results.py and the configs in
exp_configs.yaml on how to generate the plots
for accuracy, communication load, etc.

Note on source code for LLMs

This code has been extended to allow for LLMs and a basic setup on natural
language generation (NLG) using the WebNLG E2E dataset has been tested.
We use LoRA fine-tuning on the GPT-2 medium model to learn the E2E task.
The code is available on the
llm branch of this
repository and is built upon the LoRA
codebase.
The code currently does not use the
PEFT or
Transformers
libraries by HuggingFace, since building model splitting on top of those is
challenging.
However, contributions to this end on the LLM branch would be welcome.
We would like to encourage research on developing automatic model splitters for PyTorch and other popular deep learning frameworks, since these could become increasingly relevant as split learning or federated split learning methods become more popular.