
hurkanugur/MNIST-Digit-Classifier

This project implements a CNN for handwritten digit classification on the MNIST dataset using PyTorch. It uses stacked convolutional layers with dropout, batch normalization, and max pooling to classify 28×28 grayscale digits (0–9) with a Softmax output.

📚 MNIST Handwritten Digit Classifier

📖 Overview

This project predicts handwritten digit classes (0โ€“9) using the MNIST dataset and a convolutional neural network (CNN) built with PyTorch. It demonstrates a full machine learning pipeline from data loading to inference, including:

  • 🧠 CNN with stacked convolutional layers, Batch Normalization, Max Pooling, LeakyReLU activation, and Dropout
  • ⚖️ Cross-Entropy Loss for multi-class classification
  • 🚀 Adam optimizer for gradient updates
  • 🔀 Mini-batch training with DataLoader
  • 📊 Train/Validation/Test split for robust evaluation
  • 📈 Live training & validation loss monitoring
  • ✅ Softmax activation on the output, producing a probability distribution across the 10 classes
  • 🎨 Interactive Gradio interface for real-time prediction
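The training features listed above (Cross-Entropy Loss, Adam, mini-batch DataLoader, train and validation loss tracking) fit together roughly as in the minimal sketch below. It is not the code in src/train.py: the function names `train_one_epoch` and `evaluate`, the random stand-in data, the placeholder linear model, and the learning rate of 1e-3 are all assumptions for illustration. One PyTorch detail is worth knowing: `nn.CrossEntropyLoss` applies log-softmax internally and therefore expects raw logits, so in this sketch Softmax is reserved for prediction time rather than applied before the loss.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def train_one_epoch(model, loader, optimizer, loss_fn, device="cpu"):
    """Run one pass over `loader`, updating weights; return the mean loss."""
    model.train()  # enables Dropout and batch-statistics BatchNorm
    total_loss, total_samples = 0.0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        logits = model(images)          # raw scores, shape (batch, 10)
        loss = loss_fn(logits, labels)  # CrossEntropyLoss applies log-softmax itself
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * images.size(0)
        total_samples += images.size(0)
    return total_loss / total_samples


@torch.no_grad()
def evaluate(model, loader, loss_fn, device="cpu"):
    """Return the mean loss over `loader` without updating weights."""
    model.eval()  # disables Dropout, uses running BatchNorm statistics
    total_loss, total_samples = 0.0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        loss = loss_fn(model(images), labels)
        total_loss += loss.item() * images.size(0)
        total_samples += images.size(0)
    return total_loss / total_samples


if __name__ == "__main__":
    torch.manual_seed(0)
    # Random stand-in data shaped like MNIST: (N, 1, 28, 28) images, labels 0-9.
    images = torch.randn(256, 1, 28, 28)
    labels = torch.randint(0, 10, (256,))
    dataset = TensorDataset(images, labels)
    train_loader = DataLoader(dataset, batch_size=64, shuffle=True)
    val_loader = DataLoader(dataset, batch_size=64)  # same data, for brevity only

    # Placeholder model; the project uses its own CNN instead.
    model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = nn.CrossEntropyLoss()

    for epoch in range(3):
        train_loss = train_one_epoch(model, train_loader, optimizer, loss_fn)
        val_loss = evaluate(model, val_loader, loss_fn)
        print(f"epoch {epoch + 1}: train={train_loss:.4f} val={val_loss:.4f}")
```

Returning per-epoch averages like this is one straightforward way to collect the numbers a live loss plot needs.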

🖼️ Application Screenshot

Below is a preview of the Gradio Interface used for real-time classification:

Application Screenshot


🧩 Libraries

  • PyTorch – model, training, and inference
  • pandas – data handling
  • matplotlib – loss visualization
  • pickle – saving/loading normalization params and trained model
  • Gradio – interactive web interface for real-time model demos
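The pickle entry above can be illustrated with a small sketch that stores normalization parameters alongside the model. The helper names `save_params`, `load_params`, and `normalize`, the dictionary layout, and the values 0.1307 and 0.3081 (the MNIST mean and standard deviation commonly quoted for pixels scaled to [0, 1]) are assumptions; the project may compute and store different values.

```python
import pickle

# Commonly quoted MNIST pixel mean/std for pixels scaled to [0, 1].
# Assumption: the project may compute its own values from the training split.
NORMALIZATION_PARAMS = {"mean": 0.1307, "std": 0.3081}


def save_params(params, path):
    """Serialize a dict of normalization parameters with pickle."""
    with open(path, "wb") as f:
        pickle.dump(params, f)


def load_params(path):
    """Load parameters written by save_params (only unpickle trusted files)."""
    with open(path, "rb") as f:
        return pickle.load(f)


def normalize(pixel, params):
    """Apply (x - mean) / std to a single pixel value in [0, 1]."""
    return (pixel - params["mean"]) / params["std"]
```

Because `pickle.load` can execute code embedded in a file, it should only be used on files you created yourself.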

⚙️ Requirements

  • Python 3.13+
  • Recommended editor: VS Code

📦 Installation

  • Clone the repository
git clone https://github.com/hurkanugur/MNIST-Digit-Classifier.git
  • Navigate to the MNIST-Digit-Classifier directory
cd MNIST-Digit-Classifier
  • Install dependencies
pip install -r requirements.txt

🔧 Set Up a Python Environment in VS Code

  1. View → Command Palette → Python: Create Environment
  2. Choose Venv and your Python version
  3. Select requirements.txt to install dependencies
  4. Click OK

📂 Project Structure

assets/
├── app_screenshot.png                # Screenshot of the application
└── 1, 2, 3 ... 9.png                 # Digit samples

data/
└── MNIST                             # MNIST dataset

model/
└── mnist_digit_classifier.pth        # Trained model

src/
├── config.py                         # Paths, hyperparameters, split ratios
├── dataset.py                        # Data loading & preprocessing
├── device_manager.py                 # Selects and manages compute device
├── train.py                          # Training pipeline
├── inference.py                      # Inference pipeline
├── model.py                          # Neural network definition
└── visualize.py                      # Training/validation plots

main/
├── main_train.py                     # Entry point for training
└── main_inference.py                 # Entry point for inference

requirements.txt                      # Python dependencies
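`src/config.py` is described as holding the split ratios. As a hypothetical illustration of how such ratios might become subset sizes, the sketch below uses an invented `SplitConfig` dataclass with 80/10/10 ratios; the project's actual ratios and field names are not shown in this README.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SplitConfig:
    # Hypothetical ratios for illustration; not the values in src/config.py.
    train_ratio: float = 0.8
    val_ratio: float = 0.1
    test_ratio: float = 0.1


def split_sizes(n_samples, cfg):
    """Turn split ratios into integer subset sizes that sum to n_samples."""
    total = cfg.train_ratio + cfg.val_ratio + cfg.test_ratio
    if abs(total - 1.0) > 1e-9:
        raise ValueError(f"split ratios must sum to 1, got {total}")
    n_train = int(n_samples * cfg.train_ratio)
    n_val = int(n_samples * cfg.val_ratio)
    n_test = n_samples - n_train - n_val  # rounding remainder goes to test
    return n_train, n_val, n_test
```

For MNIST's 60,000 training images, `split_sizes(60000, SplitConfig())` returns `(48000, 6000, 6000)`. Assigning the rounding remainder to the last subset keeps the sizes summing exactly to the dataset length, which `torch.utils.data.random_split` requires when it is given integer lengths.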

📂 Model Architecture

Input (1×28×28)

Conv Block 1:
  → Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1, padding_mode="reflect")
  → BatchNorm2d(32)
  → ReLU
  → MaxPool2d(kernel_size=2, stride=2)
  → Dropout(0.25)

Conv Block 2:
  → Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1, padding_mode="reflect")
  → BatchNorm2d(64)
  → ReLU
  → MaxPool2d(kernel_size=2, stride=2)
  → Dropout(0.25)

Fully Connected:
  → Flatten
  → Linear(64×7×7, 128)
  → ReLU
  → BatchNorm1d(128)
  → Dropout(0.5)
  → Linear(128, 10)
  → Softmax(Output)
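The layer listing above translates into PyTorch roughly as follows. This is a reconstruction from the listing, not the contents of src/model.py: the class name `MNISTDigitCNN` and the method `predict_proba` are hypothetical, and one deliberate change is made: the final Softmax is moved out of `forward` into `predict_proba`, so that `forward` returns logits suitable for `nn.CrossEntropyLoss`. The comments trace the 28 → 14 → 7 spatial reduction (3×3 convolutions with padding 1 keep the size, each 2×2 pool halves it), which is where the 64×7×7 input size of the first Linear layer comes from.

```python
import torch
from torch import nn


class MNISTDigitCNN(nn.Module):
    """Layer stack transcribed from the architecture listing (hypothetical name)."""

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.features = nn.Sequential(
            # Conv Block 1: 1x28x28 -> 32x28x28 -> pooled to 32x14x14
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1, padding_mode="reflect"),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout(0.25),
            # Conv Block 2: 32x14x14 -> 64x14x14 -> pooled to 64x7x7
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, padding_mode="reflect"),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout(0.25),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),                 # 64 * 7 * 7 = 3136 features
            nn.Linear(64 * 7 * 7, 128),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Dropout(0.5),
            nn.Linear(128, num_classes),  # raw logits, no Softmax here
        )

    def forward(self, x):
        """Return logits of shape (batch, num_classes)."""
        return self.classifier(self.features(x))

    @torch.no_grad()
    def predict_proba(self, x):
        """Softmax over the logits: one probability per digit class."""
        return torch.softmax(self.forward(x), dim=1)
```

Calling `model.eval()` before inference matters with this stack: in training mode `BatchNorm1d` needs more than one sample per batch, and Dropout stays active.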

📂 Train the Model

Navigate to the project directory:

cd MNIST-Digit-Classifier

Run the training script:

python -m main.main_train

or

python3 -m main.main_train
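The project tree lists the trained weights at `model/mnist_digit_classifier.pth`. A common PyTorch pattern for writing and reading such a file is to save the model's `state_dict`; whether the project saves a `state_dict` or the full model object is not stated here, so treat the helpers `save_model` and `load_model` below as an assumed sketch. `torch.save` uses Python's pickle format under the hood.

```python
from pathlib import Path

import torch
from torch import nn


def save_model(model: nn.Module, path) -> None:
    """Write only the learned weights (state_dict), creating parent folders."""
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), path)


def load_model(model: nn.Module, path, device="cpu") -> nn.Module:
    """Load weights into an already-constructed model and switch it to eval mode."""
    state_dict = torch.load(path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    return model.eval()  # eval() returns the module itself
```

For example, `save_model(model, "model/mnist_digit_classifier.pth")` at the end of training, then `load_model(model, "model/mnist_digit_classifier.pth")` in the inference script, where `model` is a freshly constructed instance of the same architecture.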

📂 Run Inference / Make Predictions

Navigate to the project directory:

cd MNIST-Digit-Classifier

Run the app:

python -m main.main_inference

or

python3 -m main.main_inference
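The inference app serves predictions through Gradio. A hypothetical prediction function that such an interface could call is sketched below: it normalizes one 28×28 image, runs the model, and returns a dictionary from digit label to probability, a format that a Gradio `Label` output accepts. The name `predict_digit`, the [0, 1] pixel range, and the reuse of the commonly quoted MNIST mean/std are assumptions; real drawn input would also need resizing to 28×28 and possibly color inversion before reaching this function.

```python
import torch
from torch import nn

DIGIT_LABELS = [str(d) for d in range(10)]


@torch.no_grad()
def predict_digit(model: nn.Module, image: torch.Tensor, mean=0.1307, std=0.3081):
    """Map one 28x28 grayscale image (pixels in [0, 1]) to {label: probability}."""
    model.eval()                               # inference mode for Dropout/BatchNorm
    x = image.reshape(1, 1, 28, 28).float()    # add batch and channel dimensions
    x = (x - mean) / std                       # assumed to match training normalization
    probs = torch.softmax(model(x), dim=1)[0]  # probabilities for the single image
    return {label: float(p) for label, p in zip(DIGIT_LABELS, probs)}
```

It could then be wired up with something like `gr.Interface(fn=..., inputs=..., outputs=gr.Label(num_top_classes=3))`, where the `fn` wrapper converts the input image to a tensor and calls `predict_digit` with the loaded model.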