HawkAaron/mxnet-transducer
Fast parallel RNN-Transducer.
mxnet-transducer
A fast parallel implementation of the RNN Transducer (RNN-T) loss (Graves 2013 joint network) for MXNet, running on both CPU and GPU.
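For reference, the quantity the operator computes can be sketched in pure Python for a single utterance. This is the standard RNN-T forward (alpha) recursion over the T × (U+1) output lattice, not the repository's CPU/GPU kernel. The names `rnnt_nll` and `log_add`, and the `log_probs[t][u][k]` layout (joint output already passed through log-softmax), are assumptions for illustration only.

```python
import math
from typing import List, Sequence

NEG_INF = float("-inf")


def log_add(a: float, b: float) -> float:
    """Numerically stable log(exp(a) + exp(b))."""
    if a == NEG_INF:
        return b
    if b == NEG_INF:
        return a
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def rnnt_nll(log_probs: Sequence[Sequence[Sequence[float]]],
             labels: Sequence[int], blank: int = 0) -> float:
    """Negative log-likelihood of `labels` under an RNN-T output lattice.

    Assumed layout: log_probs[t][u][k] = log Pr(k | t, u) for t < T, u <= U.
    """
    T = len(log_probs)
    U = len(labels)
    # alpha[t][u]: log prob of having emitted labels[:u] while at frame t
    alpha: List[List[float]] = [[NEG_INF] * (U + 1) for _ in range(T)]
    alpha[0][0] = 0.0
    for t in range(T):
        for u in range(U + 1):
            if t == 0 and u == 0:
                continue
            val = NEG_INF
            if t > 0:
                # a blank emitted at (t-1, u) advances one frame
                val = log_add(val, alpha[t - 1][u] + log_probs[t - 1][u][blank])
            if u > 0:
                # label u-1 emitted at (t, u-1) advances one output position
                val = log_add(val, alpha[t][u - 1] + log_probs[t][u - 1][labels[u - 1]])
            alpha[t][u] = val
    # every alignment ends with a final blank from the last lattice node
    return -(alpha[T - 1][U] + log_probs[T - 1][U][blank])
```

With a uniform distribution over V symbols, every alignment has probability (1/V)^(T+U) and there are C(T-1+U, U) of them, which gives a quick sanity check: for T=2, U=1, V=2 the loss is log 4.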
A GPU version is now also available for the Graves 2012 additive ("add") joint network.
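The two joint-network variants differ only in how the encoder output f_t and the prediction output g_u are combined before the softmax. The sketch below reads "add network" as the Graves 2012 form, where both outputs are already V-dimensional and are simply summed, and contrasts it with the Graves 2013 form (project, add, tanh, project). All function and weight names are hypothetical, and biases are omitted for brevity.

```python
import math
from typing import List, Sequence

Matrix = Sequence[Sequence[float]]


def log_softmax(z: Sequence[float]) -> List[float]:
    m = max(z)
    s = m + math.log(sum(math.exp(x - m) for x in z))
    return [x - s for x in z]


def matvec(W: Matrix, x: Sequence[float]) -> List[float]:
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def joint_add(f: Sequence[float], g: Sequence[float]) -> List[float]:
    """Graves 2012 style: V-dim encoder and prediction outputs are summed."""
    return log_softmax([a + b for a, b in zip(f, g)])


def joint_2013(f: Sequence[float], g: Sequence[float],
               W_f: Matrix, W_g: Matrix, W_out: Matrix) -> List[float]:
    """Graves 2013 style: project both inputs, add, tanh, project to V."""
    h = [math.tanh(a + b) for a, b in zip(matvec(W_f, f), matvec(W_g, g))]
    return log_softmax(matvec(W_out, h))
```

Because the additive form has no hidden layer per (t, u) pair, the full (T, U+1, V) lattice is just a broadcast sum of a (T, V) and a (U+1, V) tensor, which is one reason it is cheap to compute on the GPU.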
Install and Test
First, get MXNet and this code:

```bash
git clone --recursive https://github.com/apache/incubator-mxnet
git clone https://github.com/HawkAaron/mxnet-transducer
```

Copy all `rnnt*` files into the MXNet source tree:

```bash
cp -r mxnet-transducer/rnnt* incubator-mxnet/src/operator/contrib/
```

Then follow the MXNet installation instructions:
https://mxnet.incubator.apache.org/install/index.html
Finally, add the Python API to /path/to/mxnet_root/mxnet/gluon/loss.py:
```python
class RNNTLoss(Loss):
    def __init__(self, batch_first=True, blank_label=0, weight=None, **kwargs):
        # batch axis of `pred`: 0 if batch-major, otherwise 2
        batch_axis = 0 if batch_first else 2
        super(RNNTLoss, self).__init__(weight, batch_axis, **kwargs)
        self.batch_first = batch_first
        self.blank_label = blank_label

    def hybrid_forward(self, F, pred, label, pred_lengths, label_lengths):
        if not self.batch_first:
            # move the batch axis (axis 2) to the front before calling the operator
            pred = F.transpose(pred, (2, 0, 1, 3))
        # labels and lengths are cast to int32 for the contrib operator
        loss = F.contrib.RNNTLoss(pred, label.astype('int32', False),
                                  pred_lengths.astype('int32', False),
                                  label_lengths.astype('int32', False),
                                  blank_label=self.blank_label)
        return loss
```

From the repo, test with:

```bash
python test/test.py 10 300 100 50 --mx
```
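When calling the `RNNTLoss` class from Python, variable-length transcripts have to be packed into a rectangular `label` batch with a matching `label_lengths` vector. The helpers below are a hypothetical sketch of that preparation. They assume the operator reads only the first `label_lengths[i]` entries of each row (so the pad value should not matter) and that `pred` has a (N, T, U+1, V) layout when `batch_first=True`; neither assumption is stated by this repository.

```python
from typing import List, Sequence, Tuple


def pad_labels(seqs: Sequence[Sequence[int]],
               pad: int = 0) -> Tuple[List[List[int]], List[int]]:
    """Right-pad variable-length label sequences to a (N, U_max) batch."""
    lengths = [len(s) for s in seqs]
    u_max = max(lengths, default=0)
    batch = [list(s) + [pad] * (u_max - len(s)) for s in seqs]
    return batch, lengths


def expected_pred_shape(n: int, t_max: int, u_max: int, vocab: int,
                        batch_first: bool = True) -> Tuple[int, ...]:
    """Assumed `pred` shape: (N, T, U+1, V) or, if not batch-first, (T, U+1, N, V)."""
    if batch_first:
        return (n, t_max, u_max + 1, vocab)
    return (t_max, u_max + 1, n, vocab)


def permute(shape: Sequence[int], axes: Sequence[int]) -> Tuple[int, ...]:
    """Shape after a transpose with the given axis order."""
    return tuple(shape[a] for a in axes)
```

Applying `permute(..., (2, 0, 1, 3))` to the non-batch-first shape reproduces the batch-first shape, which mirrors the `F.transpose` call in `hybrid_forward`.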
Languages: C++ 73.8%, Python 22.0%, CUDA 3.6%, C 0.5%
MIT License
Created April 18, 2018
Updated February 4, 2023