# OptimizationML

Generalization properties of learning algorithms and Sharpness-Aware Minimization based on Minimum Sharpness (project for CS-439: Optimization for Machine Learning, EPFL).
## Description
One of the biggest challenges in deep learning is understanding generalization. Sharpness is one indicator of generalization that performs well in practice, and sharpness-aware minimization (SAM) is a new state-of-the-art technique based on simultaneously minimizing both the loss and the sharpness. In this work, we investigate the recently introduced notion of minimum sharpness and its correlation with the generalization gap, considering many different optimizers as well as SAM. Finally, we tackle the question of adaptivity of learning algorithms, which also impacts generalization, and investigate how the choice of optimizer influences sharpness.
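The SAM update described above can be illustrated with a minimal, self-contained sketch: first ascend to the worst-case perturbation within a radius `rho`, then take the descent step using the gradient evaluated at that perturbed point. This toy pure-Python example on a 1D quadratic is our own illustration, not the PyTorch SAM implementation used in this repository (see `/optimizers`); the function names and hyperparameters are illustrative.

```python
def loss(w):
    # Toy 1D loss with its minimum at w = 2.0.
    return (w - 2.0) ** 2

def grad(w):
    return 2.0 * (w - 2.0)

def sam_step(w, lr=0.1, rho=0.05):
    """One SAM step: ascend to the worst-case neighbour within
    radius rho, then descend using the gradient taken there."""
    g = grad(w)
    scale = rho / (abs(g) + 1e-12)   # normalize the ascent direction
    w_adv = w + scale * g            # perturbed (worst-case) point
    return w - lr * grad(w_adv)      # descent step with the perturbed gradient

w = 0.0
for _ in range(100):
    w = sam_step(w)
print(abs(w - 2.0) < 0.05)  # True: the iterates settle near the minimum
```

Note that the iterates hover slightly around the minimum rather than converging to it exactly, because the gradient is always taken at a point perturbed by `rho`.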
## Project structure
- The folder `/checkpoints` contains results from the runs you perform within the repository once you fork it.
- The folder `/checkpoints_test` contains some precomputed checkpoints for the model illustrated in the `TrainingSample.ipynb` notebook. The checkpoints folder is structured as:
  - `/DATASET/MODEL_ARCHITECTURE/epochX`, where `DATASET` is FashionMNIST or CIFAR10, `MODEL_ARCHITECTURE` is SimpleBatch, MiddleBatch or ComplexBatch, and `X` is 50, 100, 150 or 200: we train all models for up to 200 epochs and save the checkpoints every 50 epochs.
  - `/converged`: whenever a model converges (its loss falls below the set tolerance), we save the checkpoints again.
- The folder `/data` is empty by default and is populated with data when training the models.
- The folder `/results` contains `.csv` files with the results for each dataset.
- The folder `/optimizers` contains PyTorch implementations of AdaBound, AdaShift and SAM, collected from external sources.
- The folder `/sharpness` contains the approximate computation of the Hessian and of sharpness.
- The notebooks `TrainingSample.ipynb` and `DataAnalysis.ipynb` illustrate our work and are presented below.
- The remaining files in the repository:
  - `models.py`: the architectures of the models we considered.
  - `main.py`: runs the training and computations for a given configuration (dataset, model architecture, optimizer).
  - `helpers.py`: various utilities for training, testing, computation and data preprocessing.
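As a rough illustration of how sharpness can be approximated without forming the full Hessian (the actual approach used in `/sharpness` may differ), the largest Hessian eigenvalue, a common sharpness proxy, can be estimated by power iteration on Hessian-vector products. All names below are our own; the toy loss has Hessian diag(3, 1), so the top eigenvalue is 3.

```python
# Hedged sketch: estimate the top Hessian eigenvalue by power iteration
# on Hessian-vector products. Illustrative only, not the repository's code.

def grad(w):
    # Gradient of the toy quadratic loss 0.5*(3*w0^2 + w1^2);
    # its Hessian is diag(3, 1), so the top eigenvalue is 3.
    return [3.0 * w[0], 1.0 * w[1]]

def hvp(w, v, eps=1e-5):
    # Hessian-vector product via a central finite difference of the gradient,
    # so the Hessian never has to be built explicitly.
    g_plus = grad([wi + eps * vi for wi, vi in zip(w, v)])
    g_minus = grad([wi - eps * vi for wi, vi in zip(w, v)])
    return [(gp - gm) / (2 * eps) for gp, gm in zip(g_plus, g_minus)]

def top_eigenvalue(w, iters=50):
    v = [1.0, 1.0]          # arbitrary starting direction
    lam = 0.0
    for _ in range(iters):
        hv = hvp(w, v)
        lam = max(abs(x) for x in hv)   # infinity-norm as the scale
        v = [x / lam for x in hv]       # renormalize for the next iteration
    return lam

print(round(top_eigenvalue([0.5, -0.3]), 3))  # prints 3.0
```

For a neural network, `grad` would be the gradient of the training loss with respect to the parameters, and the same power-iteration loop applies unchanged.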
## Running the code
Python is required. The needed libraries are listed in `requirements.txt`; to install them, run `pip install -r requirements.txt` (or `pip3 install -r requirements.txt` for Python 3).
- To explore our work, we encourage you to look through our notebooks:
  - `TrainingSample.ipynb` lets you train a model and compute sharpness for a given dataset, architecture and optimizer.
  - `DataAnalysis.ipynb` loads all the results from training and prepares the plots. If results are missing, it requires you to either download all the existing checkpoints from training to extract the results, which may take a long time, or retrain the models.

  All the existing checkpoints are available at: https://drive.google.com/drive/folders/10LuJDXzP6P_xH-z66Kh4KaWPfR1s0-t9?usp=sharing
  Due to GitHub's size limits, we have not added them to the repository (>30 GB).
- To run a model for a given configuration, we also offer a runnable Python file:
  - `python main.py train $dataset $model $optimizer $use_sam $load_existing` trains a model, where:
    - `dataset` should be `CIFAR10` or `FashionMNIST`
    - `model` should be `SimpleBatch`, `MiddleBatch` or `ComplexBatch`
    - `optimizer` should be `SGD`, `PHB`, `Adagrad`, `Adam`, `AdaShift` or `AdaBound`
    - `use_sam` should be `0` (do not use SAM) or `1` (use SAM)
    - `load_existing` is not used here; it can be `0` or `1`
  - `python main.py compute_sharpness $dataset $model $optimizer $use_sam $load_existing` computes sharpness for the given model. All parameters stay the same, except that `load_existing` should be `1` if you have already trained the model and would like to load it from file, and `0` otherwise.
  - `python main.py plot $dataset $model $optimizer $use_sam $load_existing` lets you visualize the computations.
- To automatically run all configurations (dataset, optimizer, architecture), we offer shell scripts which can be run as follows:
  - To train all models, run `chmod +x train_all` and `./train_all`.
  - To compute sharpness for all trained models, run `chmod +x compute_sharpness_all.sh` and `./compute_sharpness_all.sh`.
## Authors
- Jana Vuckovic: jana.vuckovic@epfl.ch
- Miguel-Angel Sanchez Ndoye: miguel-angel.sanchezndoye@epfl.ch
- Irina Bejan: irina.bejan@epfl.ch