RNN-T joint network consuming significant memory and causing OOM

I am trying to implement the RNN-T model for an ASR task. For the encoder, I am using CNNs, RNNs, and Linear layers with activation. For the decoder, I am using an LSTM. The model is relatively small, with about 9.8M trainable parameters. The problem is that when training approaches the end of an epoch, the memory consumption suddenly increases and an OOM error occurs.

So I observed the memory consumption at several points and found that it increases significantly after the joint network operations. Here is a trace:

-- BEFORE FORWARD
allocated: 154M, max allocated: 13403M, cached: 19736M, max cached: 21898M
-- AFTER FEATURES
allocated: 160M, max allocated: 13403M, cached: 19736M, max cached: 21898M
-- AFTER ENCODER
allocated: 378M, max allocated: 13403M, cached: 19736M, max cached: 21898M
-- AFTER EMBEDDING
allocated: 378M, max allocated: 13403M, cached: 19736M, max cached: 21898M
-- AFTER DECODER
allocated: 387M, max allocated: 13403M, cached: 19736M, max cached: 21898M
-- AFTER JOIN NET
allocated: 4697M, max allocated: 13403M, cached: 19736M, max cached: 21898M
-- AFTER FINAL LIN
allocated: 5749M, max allocated: 13403M, cached: 19736M, max cached: 21898M
-- AFTER LOG SOFTMAX
allocated: 6802M, max allocated: 13403M, cached: 19736M, max cached: 21898M
-- AFTER FORWARD and BEFORE COMPUTE OBJs
allocated: 6773M, max allocated: 13403M, cached: 19736M, max cached: 21898M
-- AFTER COMPUTE OBJS
allocated: 7825M, max allocated: 13403M, cached: 19736M, max cached: 21898M
-- AFTER LOSS BACKWARD
allocated: 1211M, max allocated: 13403M, cached: 19736M, max cached: 21898M
-- AFTER OPTIM STEP
allocated: 1211M, max allocated: 13403M, cached: 19736M, max cached: 21898M
-- AFTER ZERO GRAD
allocated: 1211M, max allocated: 13403M, cached: 19736M, max cached: 21898M
-- BEFORE FORWARD (for next batch)
allocated: 154M, max allocated: 13403M, cached: 19736M, max cached: 21898M

I am using speechbrain.nnet.transducer.transducer_joint.Transducer_joint as the joint network with the following configuration:

join_network: !new:speechbrain.nnet.transducer.transducer_joint.Transducer_joint
    joint: concat  # sum or concat
    nonlinearity: !ref <activation>  # torch.nn.LeakyReLU

I am tracing the memory stats with this function:

import sys
import torch

def stat_cuda(msg, stat_file=sys.stdout):
    print('--', msg, file=stat_file)
    # memory_cached/max_memory_cached are the older names of
    # memory_reserved/max_memory_reserved in recent PyTorch versions
    print('allocated: %dM, max allocated: %dM, cached: %dM, max cached: %dM' % (
        torch.cuda.memory_allocated() / 1024 / 1024,
        torch.cuda.max_memory_allocated() / 1024 / 1024,
        torch.cuda.memory_cached() / 1024 / 1024,
        torch.cuda.max_memory_cached() / 1024 / 1024
    ), file=stat_file)

The compute_forward function is as follows:

def compute_forward(self, batch, stage):
    batch = batch.to(self.device)
        
    # get features
    audio_features, self.audio_feature_lengths = self.prepare_features(stage, batch.sig)
    tokens, token_lens = batch.tokens_bos
    stat_cuda('AFTER FEATURES')
    
    # feed data to model
    encoder_output = self.modules.encoder(audio_features.detach())
    stat_cuda('AFTER ENCODER')
    embeddings = self.modules.embedding(tokens.detach())
    stat_cuda('AFTER EMBEDDING')
    decoder_output, _ = self.modules.decoder(embeddings)
    stat_cuda('AFTER DECODER')
    # joint output is a 4-D tensor [B, T, U, D_joint] -- this is what blows up memory
    joint = self.modules.join_net(encoder_output.unsqueeze(2), decoder_output.unsqueeze(1))
    stat_cuda('AFTER JOIN NET')
    logits = self.modules.final_lin(joint)
    stat_cuda('AFTER FINAL LIN')
    log_prob = self.hparams.log_softmax(logits)
    stat_cuda('AFTER LOG SOFTMAX')
    
    output = {'log_prob': log_prob}

    if stage != sb.Stage.TRAIN:
        # some stuff (omitted)
        pass

    return output

What can I do here to avoid this problem? Thanks in advance!

@aheba may know more about this.

Hi Shahad_Mahmud, thank you for your interest in using our RNN-T implementation.

Let me share some basics that will help us move forward together :)
1- The encoder tensor has shape [B, T, 1, DIM_e] and the predictor tensor [B, 1, N, DIM_p].
2- In the case of sum, the operation is very easy thanks to broadcasting.
3- In the case of concat, it is more tricky; we apply this kind of solution: speechbrain/transducer_joint.py at 7443adca55c1c673360b4d382e525413aa1de312 · speechbrain/speechbrain · GitHub (which may be hungrier in terms of memory consumption :) ). See the sketch below.
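
To make the difference concrete, here is a minimal sketch in plain PyTorch (not the SpeechBrain code itself) contrasting the two joint modes and the tensors they materialise:

import torch

B, T, N, D = 8, 100, 20, 512
enc = torch.randn(B, T, 1, D)     # encoder output after unsqueeze(2)
pred = torch.randn(B, 1, N, D)    # predictor output after unsqueeze(1)

# "sum": broadcasting adds the two tensors directly, giving [B, T, N, D]
joint_sum = enc + pred

# "concat": both tensors have to be expanded to [B, T, N, D] first and then
# concatenated on the last dim, giving a dense [B, T, N, 2*D] tensor
joint_cat = torch.cat((enc.expand(B, T, N, D), pred.expand(B, T, N, D)), dim=-1)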

I will take a deeper look at this next week; if you want to contribute to that part, we can discuss it further. There is a solution from Microsoft to reduce the memory consumption (see the paper "Improving RNN Transducer Modeling for End-to-End Speech Recognition", section 3).
We can consider working together on that part…
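
For reference, the general trick (as I understand it, not necessarily the paper's exact formulation) is that a linear layer applied to a concatenation equals the sum of two linear layers applied to the parts, so each stream can be projected while it is still 3-D and only the projected tensors need to be broadcast-added:

import torch
import torch.nn as nn

B, T, N, D, J = 8, 100, 20, 512, 640
enc = torch.randn(B, T, D)        # encoder output
pred = torch.randn(B, N, D)       # predictor output

# Linear(concat(enc, pred)) == Linear_e(enc) + Linear_p(pred) (with a single bias)
proj_enc = nn.Linear(D, J)
proj_pred = nn.Linear(D, J, bias=False)

# only the projected tensors are broadcast-added -> [B, T, N, J];
# the [B, T, N, 2*D] concatenation is never materialised
joint = proj_enc(enc).unsqueeze(2) + proj_pred(pred).unsqueeze(1)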

Thank you @aheba and @titouan.parcollet. Actually, I found the cause of the problem shortly after posting here. Let me share what I have found.

While training, I am sorting the dataset by audio length, so the longer audios end up at the end of the epoch. Also, I am using kernels of odd size (say F) with padding of (F - 1) / 2, so that the encoder output length is the same as the input feature length. As far as I have explored, this is needed for streaming ASR.
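
For reference, a quick check of the same-length padding (the channel count and lengths below are just placeholders):

import torch
import torch.nn as nn

F = 5                                                          # odd kernel size
conv = nn.Conv1d(80, 80, kernel_size=F, stride=1, padding=(F - 1) // 2)

x = torch.randn(1, 80, 1301)                                   # [batch, features, time]
print(conv(x).shape)                                           # torch.Size([1, 80, 1301]) -- time length preserved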

Now, toward the end of an epoch I get the longer audio samples, and the joint network output shape for a batch is (8, 1301, 51, 2048). Using float32, this needs 8 * 1301 * 51 * 2048 * 32 = 34,787,033,088 bits, or about 4.05 GB. This is exactly the amount of memory the system is trying to allocate when it hits OOM.
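
The same arithmetic in bytes:

elements = 8 * 1301 * 51 * 2048        # joint tensor shape from the trace
bytes_fp32 = elements * 4              # float32 = 4 bytes per element
print(elements)                        # 1087094784
print(bytes_fp32 / 1024 ** 3)          # ~4.05 GiB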

So, what I have done now: I have halved the batch size and reduced the output size of the encoder and predictor to 768. I have also removed the longer audios (more than 10 s) from the dataset, since in this paper from Google the authors used a dataset with shorter audio files (summary shown in the following figure) for streaming data.
[figure: dataset summary from the Google paper]
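
The filtering/sorting step looks roughly like this (assuming a SpeechBrain DynamicItemDataset named train_data with a "duration" key in seconds; adapt to your own pipeline):

train_data = train_data.filtered_sorted(
    sort_key="duration",                   # sort by audio length
    key_max_value={"duration": 10.0},      # drop utterances longer than 10 s
)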

Please let me know if I am doing anything wrong. Obviously, I would love to work with you to improve this, and I'll also look at the paper. Thank you again for your suggestions, @aheba.