nanoTabPFN
This repository provides a fully open-source playground for tabular foundation models.
It contains a much smaller and simpler implementation of the TabPFNv2 architecture as well as a training loop and code for loading data that was pre-generated by a prior. We are planning to rapidly extend the repository with more features (e.g. regression, missing values, categorical features), prior interfaces and architectures.
It is intended as a starting point for students and researchers who want to learn how TabPFN works under the hood.
Clone the repository, then install the dependencies with:
pip install -e .
We offer the same interface as TabPFN:
from sklearn.datasets import load_breast_cancer
from sklearn.metrics import accuracy_score, roc_auc_score
from sklearn.model_selection import train_test_split
from nanotabpfn import NanoTabPFNClassifier
# Load data
X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=42)
# Initialize a classifier
clf = NanoTabPFNClassifier()
clf.fit(X_train, y_train)
# Predict probabilities
prediction_probabilities = clf.predict_proba(X_test)
print("ROC AUC:", roc_auc_score(y_test, prediction_probabilities[:, 1]))
# Predict labels
predictions = clf.predict(X_test)
print("Accuracy", accuracy_score(y_test, predictions))
Our Code
nanotabpfn/model.py contains the implementation of the architecture in fewer than 250 lines of code. nanotabpfn/train.py implements a simple training loop in under 100 lines. nanotabpfn/priors.py implements a dataloader that loads a dump pre-generated from a prior.
We will release multiple dumps of different scales soon. We also offer an interface where you can provide your own get_batch function.
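To give a feel for what a custom get_batch function might look like, here is a minimal sketch of a toy prior written with NumPy only. The signature, the argument names (batch_size, num_datapoints, num_features, max_classes, seed) and the returned dictionary layout are assumptions made for illustration; the interface that nanotabpfn actually expects is not described in this README, so check nanotabpfn/priors.py before adapting it.

```python
import numpy as np


def get_batch(batch_size=4, num_datapoints=50, num_features=3, max_classes=3, seed=None):
    """Hypothetical toy prior: labels are the argmax of random linear scores.

    Returns a dict with
      "x": float32 array of shape (batch_size, num_datapoints, num_features)
      "y": int64 array of shape (batch_size, num_datapoints)
    This layout is an assumption, not the repository's documented format.
    """
    rng = np.random.default_rng(seed)

    # One synthetic dataset per batch entry, features drawn from a standard normal.
    x = rng.normal(size=(batch_size, num_datapoints, num_features))

    # Draw the number of classes for this batch (between 2 and max_classes).
    num_classes = int(rng.integers(2, max_classes + 1))

    # A separate random linear map per dataset turns features into class scores.
    w = rng.normal(size=(batch_size, num_features, num_classes))
    scores = x @ w  # shape: (batch_size, num_datapoints, num_classes)

    # The class with the highest score becomes the label.
    y = scores.argmax(axis=-1)

    return {"x": x.astype(np.float32), "y": y.astype(np.int64)}


if __name__ == "__main__":
    batch = get_batch(seed=0)
    print(batch["x"].shape, batch["y"].shape)
```

The defaults mirror the dump used in the pretraining example (50 datapoints, 3 features, up to 3 classes), so a sketch like this could stand in for a dump while experimenting, once it is adapted to the real interface.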
Pretrain your own nanoTabPFN
First, download 100k pre-generated datasets, each with 50 datapoints, 3 features and up to 3 classes, from here.
Then run:
python pretrain_classification.py -epochs 80 -steps 25 -batchsize 50 -priordump 50x3_3_100k_classification.h5
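The flag values in this command appear to be chosen so that training consumes the dump exactly once. The short calculation below checks this, assuming that -steps is the number of steps per epoch and that each step draws -batchsize datasets from the dump; both readings are assumptions, as the README does not define the flags.

```python
# Values taken from the command line above.
epochs = 80
steps_per_epoch = 25      # assumed meaning of -steps
batch_size = 50           # assumed: datasets drawn per step
datasets_in_dump = 100_000

# Total number of datasets consumed over the whole run.
datasets_consumed = epochs * steps_per_epoch * batch_size

# How many times the run would cycle through the dump.
passes_over_dump = datasets_consumed / datasets_in_dump

print(datasets_consumed)  # 100000
print(passes_over_dump)   # 1.0
```

Under these assumptions, changing any one of the three flags for the same dump would mean either leaving part of the dump unused or needing more datasets than it contains.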