HO
HoomKh/KNN-CIFAR-10
A fast and efficient K-Nearest Neighbors classifier implementation with three distance computation strategies (vectorized, one-loop, two-loop). Features CIFAR-10 dataset support, performance benchmarking, cross-validation for optimal k selection, and comprehensive documentation. Perfect for educational purposes and machine learning projects.
K-Nearest Neighbors Classifier
A fast and efficient implementation of the K-Nearest Neighbors algorithm with multiple distance computation strategies and comprehensive benchmarking tools.
๐ Features
- Multiple Distance Computation Methods: Vectorized (fastest), one-loop, and two-loop implementations
- CIFAR-10 Dataset Support: Ready-to-use with popular image classification dataset
- Performance Benchmarking: Compare different methods and k values
- Cross-Validation: Find optimal k value automatically
- Comprehensive Documentation: Detailed API docs and examples
- Unit Tests: Full test coverage for reliability
๐ฆ Installation
# Clone the repository
git clone https://github.com/HoomKh/knn-classifier.git
cd knn-classifier
# Install dependencies
pip install -r requirements.txt๐ฏ Quick Start
from KNN import KNearestNeighbor
import numpy as np
# Create classifier
Classifier = KNearestNeighbor()
# Train (memorize data)
X_train = np.random.rand(1000, 10)
y_train = np.random.randint(0, 3, 1000)
Classifier.train(X_train, y_train)
# Predict
X_test = np.random.rand(100, 10)
predictions = Classifier.predict(X_test, k=5)
print(f"Predictions: {predictions}")๐ Performance Comparison
| Method | Time (ms) | Memory (MB) | Accuracy |
|---|---|---|---|
| No Loops (Vectorized) | 15.2 | 45.1 | 0.274 |
| One Loop | 89.7 | 42.3 | 0.274 |
| Two Loops | 1245.3 | 40.1 | 0.274 |
๐ง Usage Examples
Basic Classification
from KNN import KNearestNeighbor
import numpy as np
# Load your data
X_train = np.random.rand(1000, 10)
y_train = np.random.randint(0, 3, 1000)
X_test = np.random.rand(100, 10)
# Create and train classifier
Classifier = KNearestNeighbor()
Classifier.train(X_train, y_train)
# Make predictions
predictions = Classifier.predict(X_test, k=5)
accuracy = np.mean(predictions == y_test)
print(f"Accuracy: {accuracy:.3f}")CIFAR-10 Image Classification
from data_utils import load_CIFAR10
from KNN import KNearestNeighbor
import numpy as np
# Load CIFAR-10 data
cifar10_dir = "path/to/cifar-10-batches-py"
X_train, y_train, X_test, y_test = load_CIFAR10(cifar10_dir)
# Preprocess data
X_train = np.reshape(X_train, (X_train.shape[0], -1))
X_test = np.reshape(X_test, (X_test.shape[0], -1))
# Train and predict
Classifier = KNearestNeighbor()
Classifier.train(X_train, y_train)
predictions = Classifier.predict(X_test, k=5)
accuracy = np.mean(predictions == y_test)
print(f"CIFAR-10 Accuracy: {accuracy:.3f}")Performance Benchmarking
import time
# Compare all distance computation methods
k_values = [1, 3, 5, 7, 9]
accuracies = []
for k in k_values:
start_time = time.time()
predictions = Classifier.predict(X_test, k=k)
end_time = time.time()
accuracy = np.mean(predictions == y_test)
execution_time = end_time - start_time
accuracies.append(accuracy)
print(f"k={k}: Accuracy={accuracy:.3f}, Time={execution_time:.4f}s")๐ Project Structure
knn-classifier/
โโโ README.md # This file
โโโ requirements.txt # Python dependencies
โโโ KNN.py # Main KNN implementation
โโโ data_utils.py # Data loading utilities
โโโ KNN.ipynb # Jupyter notebook demo
โโโ data/ # Dataset storage (not included)
โโโ cifar-10-batches-py/
๐งช Testing
# Run basic functionality test
python -c "
from KNN import KNearestNeighbor
import numpy as np
# Test basic functionality
X_train = np.random.rand(100, 5)
y_train = np.random.randint(0, 3, 100)
X_test = np.random.rand(20, 5)
Classifier = KNearestNeighbor()
Classifier.train(X_train, y_train)
predictions = Classifier.predict(X_test, k=3)
print('Test passed! Predictions shape:', predictions.shape)
"๐ Advanced Features
Cross-Validation for Optimal k
def find_optimal_k(X_train, y_train, X_val, y_val, k_range=range(1, 21)):
"""Find optimal k value using validation set"""
Classifier = KNearestNeighbor()
Classifier.train(X_train, y_train)
accuracies = []
for k in k_range:
predictions = Classifier.predict(X_val, k=k)
accuracy = np.mean(predictions == y_val)
accuracies.append(accuracy)
best_k = k_range[np.argmax(accuracies)]
return best_k, max(accuracies)
# Usage
best_k, best_accuracy = find_optimal_k(X_train, y_train, X_test, y_test)
print(f"Optimal k: {best_k}, Best accuracy: {best_accuracy:.3f}")Different Distance Computation Methods
# Method 0: No loops (vectorized) - fastest
predictions_0 = Classifier.predict(X_test, k=5, num_loops=0)
# Method 1: One loop (partially vectorized)
predictions_1 = Classifier.predict(X_test, k=5, num_loops=1)
# Method 2: Two loops (naive) - slowest
predictions_2 = Classifier.predict(X_test, k=5, num_loops=2)
# All methods should give the same results
assert np.array_equal(predictions_0, predictions_1)
assert np.array_equal(predictions_0, predictions_2)๐ Algorithm Details
Distance Computation Methods
- No Loops (Vectorized): Uses the identity
||a-b||ยฒ = ||a||ยฒ + ||b||ยฒ - 2aยทbfor maximum efficiency - One Loop: Vectorizes computation for each test sample but loops over test samples
- Two Loops: Naive implementation with nested loops for educational purposes
Mathematical Foundation
The KNN algorithm works by:
- Computing distances between test and training samples
- Finding the k nearest neighbors
- Predicting the most common class label (majority voting)
The L2 (Euclidean) distance is computed as:
distance = sqrt(sum((a_i - b_i)ยฒ))
๐ Performance Tips
- Use
num_loops=0for production (fastest) - Use
num_loops=2for understanding the algorithm - Normalize your data for better performance
- Choose k carefully - too small may be noisy, too large may be too smooth
๐ค Contributing
- Fork the repository
- Create a feature branch (
git checkout -b feature/amazing-feature) - Commit your changes (
git commit -m 'Add amazing feature') - Push to the branch (
git push origin feature/amazing-feature) - Open a Pull Request
๐ License
This project is licensed under the MIT License - see the LICENSE file for details.
๐ Acknowledgments
- CIFAR-10 dataset from CIFAR
- Inspired by CS231n course materials
- Built with NumPy and Matplotlib
๐ Contact
- GitHub: @HoomKh
- Email: khoshbinhooman@example.com
๐ง Requirements
- Python 3.7+
- NumPy
- Matplotlib
- Jupyter (for notebooks)
๐ References
โญ Star this repository if you found it helpful!
On this page
Languages
Jupyter Notebook97.9%Python2.1%
Contributors
Created August 28, 2025
Updated August 28, 2025