connor-mccarthy/tf-ssgan
An implementation of the GANBERT model from a recent deep learning research paper, packaged as an easy-to-use TensorFlow model.
A simple API for a complex idea in current deep learning research: semi-supervised classification using generative adversarial networks (SSGANs).
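The following is a minimal, library-free sketch of the k+1-class discriminator objective that is commonly used for semi-supervised GANs and described in the GAN-BERT paper cited below. The discriminator predicts either one of the k real classes or an extra "fake" class. Labeled data therefore trains it as an ordinary classifier, while unlabeled and generated data train it to separate real inputs from fake ones. The helper names, the use of plain Python lists, and the convention that the fake class is the last logit are all assumptions made for illustration. This is not the tf_ssgan implementation.

```python
import math
from typing import List, Optional, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax over a single logit vector."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def discriminator_loss(
    real_logits: Sequence[Sequence[float]],
    labels: Sequence[Optional[int]],
    fake_logits: Sequence[Sequence[float]],
    eps: float = 1e-8,
) -> float:
    """Hypothetical k+1-class SSGAN discriminator loss (not the tf_ssgan API).

    Assumes each logit vector has k+1 entries and that the LAST entry is the
    extra "fake" class. labels[i] is a class index in [0, k) for labeled real
    examples and None for unlabeled real examples.
    """
    # Supervised term: cross-entropy over the k real classes, labeled rows only.
    sup_terms = [
        -math.log(softmax(logits[:-1])[label] + eps)
        for logits, label in zip(real_logits, labels)
        if label is not None
    ]
    sup = sum(sup_terms) / len(sup_terms) if sup_terms else 0.0

    # Unsupervised term on real data (labeled or not): should NOT look fake.
    p_fake_real = [softmax(logits)[-1] for logits in real_logits]
    unsup_real = -sum(math.log(1.0 - p + eps) for p in p_fake_real) / len(p_fake_real)

    # Unsupervised term on generated data: should be assigned to the fake class.
    p_fake_gen = [softmax(logits)[-1] for logits in fake_logits]
    unsup_fake = -sum(math.log(p + eps) for p in p_fake_gen) / len(p_fake_gen)

    return sup + unsup_real + unsup_fake


# Two classes (k = 2): one labeled and one unlabeled real example, one generated example.
loss = discriminator_loss(
    real_logits=[[2.0, 0.0, -2.0], [0.0, 1.0, -2.0]],
    labels=[0, None],
    fake_logits=[[-1.0, -1.0, 2.0]],
)
print(f"discriminator loss: {loss:.4f}")
```

Unlabeled examples still contribute through the real-vs-fake term. This is what makes the setup semi-supervised.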
This particular flavor of SSGAN is modeled after the 2020 GANBERT research paper (see Citation). The ganbert/ directory contains an implementation of the GANBERT model described in the paper, built with the tf-ssgan library.
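The GAN-BERT paper trains the generator with two terms. The first is an adversarial term, which pushes generated examples to avoid being classified as fake. The second is a feature-matching term, which pulls the mean of the discriminator's intermediate features on generated examples toward the mean on real examples. The sketch below writes that objective out in plain Python. `generator_loss`, `feature_matching_loss`, `mean_vector`, and the last-logit-is-fake convention are hypothetical and are not part of tf_ssgan.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax over a single logit vector."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def mean_vector(rows: Sequence[Sequence[float]]) -> List[float]:
    """Column-wise mean of a batch of feature vectors."""
    return [sum(col) / len(rows) for col in zip(*rows)]


def feature_matching_loss(
    real_features: Sequence[Sequence[float]],
    fake_features: Sequence[Sequence[float]],
) -> float:
    """Squared L2 distance between mean real and mean generated features."""
    real_mean = mean_vector(real_features)
    fake_mean = mean_vector(fake_features)
    return sum((r - f) ** 2 for r, f in zip(real_mean, fake_mean))


def generator_loss(
    fake_logits: Sequence[Sequence[float]],
    real_features: Sequence[Sequence[float]],
    fake_features: Sequence[Sequence[float]],
    eps: float = 1e-8,
) -> float:
    """Hypothetical GAN-BERT-style generator loss (not the tf_ssgan API).

    Assumes the LAST logit of each k+1 vector is the discriminator's "fake"
    class, and that *_features are the discriminator's intermediate
    representations of real and generated examples.
    """
    # Adversarial term: reward generated examples the discriminator does NOT call fake.
    p_fake = [softmax(logits)[-1] for logits in fake_logits]
    adversarial = -sum(math.log(1.0 - p + eps) for p in p_fake) / len(p_fake)
    # Feature-matching term: pull generated feature statistics toward the real ones.
    return adversarial + feature_matching_loss(real_features, fake_features)
```

Suppose the generator fools the discriminator, so the fake probability is near zero, and also matches the real feature means. In that case both terms approach zero.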
Getting Started
Installation
pip install git+https://github.com/connor-mccarthy/tf-ssgan.git
Code
SSGAN is built on the standard Keras Model API (compile, then fit). As a result, you can build an SSGAN for a wide range of classification problems by supplying your own generator and discriminator.
import tensorflow as tf
from tf_ssgan import SSGAN

# see ./ganbert/model_components.py for generator/discriminator details
generator = make_generator(...)
discriminator = make_discriminator(...)

ssgan = SSGAN(
    generator=generator,
    discriminator=discriminator,
    name="my_ssgan",
)
# separate optimizers for the generator and the discriminator
ssgan.compile(
    g_optimizer=tf.keras.optimizers.Adam(1e-4),
    d_optimizer=tf.keras.optimizers.Adam(1e-4),
)
ssgan.fit(
    train_ds,
    validation_data=val_ds,
    epochs=1000,
)
Reproducing GANBERT
With Python 3.8.10:
python -m venv .venv
source .venv/bin/activate
pip install -r ganbert/ganbert_requirements.txt
python ganbert
Citation
GANBERT paper:
@inproceedings{croce-etal-2020-gan,
    title = "{GAN}-{BERT}: Generative Adversarial Learning for Robust Text Classification with a Bunch of Labeled Examples",
    author = "Croce, Danilo and
      Castellucci, Giuseppe and
      Basili, Roberto",
    booktitle = "Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics",
    month = jul,
    year = "2020",
    address = "Online",
    publisher = "Association for Computational Linguistics",
    url = "https://www.aclweb.org/anthology/2020.acl-main.191",
    pages = "2114--2119"
}
Created June 14, 2021
Updated March 8, 2023