Keras RNN (batch_size)

I created an RNN model for text classification with an LSTM layer, but when I set batch_size in the fit method, my model seems to train on the whole training set at once instead of on mini-batches of batch_size samples. The same thing happens when I use a GRU or Bidirectional layer instead of LSTM. What could be wrong?

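from keras import layers, models, optimizers

# word_index, embedding_matrix, X_train_seq and y_train come from earlier preprocessing (not shown)
def create_rnn_lstm():
    # Input sequences padded/truncated to 70 tokens
    input_layer = layers.Input((70,))
    # Frozen pretrained 300-dimensional embeddings
    embedding_layer = layers.Embedding(len(word_index) + 1, 300, weights=[embedding_matrix], trainable=False)(input_layer)
    embedding_layer = layers.SpatialDropout1D(0.3)(embedding_layer)
    lstm_layer = layers.LSTM(100)(embedding_layer)
    # Activations are passed as strings; bare relu/softmax names would be undefined
    output_layer1 = layers.Dense(70, activation='relu')(lstm_layer)
    output_layer1 = layers.Dropout(0.25)(output_layer1)
    # Two-unit softmax output: y_train is expected to be one-hot encoded
    output_layer2 = layers.Dense(2, activation='softmax')(output_layer1)
    model = models.Model(inputs=input_layer, outputs=output_layer2)
    model.compile(optimizer=optimizers.Adam(), loss='binary_crossentropy')
    return model

LSTM_classifier = create_rnn_lstm()
LSTM_classifier.fit(X_train_seq, y_train, batch_size=128, epochs=10, shuffle=True)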


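In Keras, with verbose=1 (the default setting of the fit method), the progress bar counts the number of samples processed so far, not the batch number. (Newer versions, such as tf.keras in TensorFlow 2.x, count batches, or steps, instead.) Either way, fit still trains on mini-batches of batch_size samples.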
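
To make the difference between the two counters concrete, here is a small stdlib-only sketch. progress_counts is a hypothetical helper, not a Keras API; it only mimics what a samples-based counter (older standalone Keras) and a steps-based counter (tf.keras in TensorFlow 2.x) would display after each batch, assuming 13714 training samples and batch_size=128.

```python
import math

def progress_counts(n_samples, batch_size, unit="samples"):
    """Hypothetical helper (not part of Keras): the value a progress bar
    would show after each batch.

    unit="samples" mimics a samples-based counter (older standalone Keras);
    any other value mimics a steps-based counter (tf.keras in TF 2.x).
    """
    n_batches = math.ceil(n_samples / batch_size)
    counts = []
    seen = 0
    for step in range(1, n_batches + 1):
        # The last batch may be smaller than batch_size
        seen += min(batch_size, n_samples - seen)
        counts.append(seen if unit == "samples" else step)
    return counts

samples = progress_counts(13714, 128, unit="samples")
steps = progress_counts(13714, 128, unit="steps")
print(samples[:3], samples[-2:])      # [128, 256, 384] [13696, 13714]
print(len(steps), steps[-1])          # 108 108
print(progress_counts(13714, 13714))  # [13714]
```

Both counters describe the same 108 mini-batches; only the unit on the display changes.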

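If your batch size is 128, the progress bar advances by 128 samples at a time, because the weights are updated once per mini-batch. With 13714 training samples, one epoch consists of ceil(13714 / 128) = 108 batches: 107 full batches of 128 samples and a final batch of 18. To check this, set batch_size to 13714: the progress bar then jumps straight from 0/13714 to 13714/13714, because the batch now contains your whole training set and the model makes only one weight update per epoch.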
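
The batching itself can be sketched in a few lines. This is a simplified illustration, not Keras's actual implementation: iterate_minibatches is a hypothetical helper that, like fit(..., shuffle=True), shuffles the sample order for an epoch and then slices it into consecutive chunks of batch_size indices. Each yielded list corresponds to one weight update.

```python
import random

def iterate_minibatches(n_samples, batch_size, shuffle=True, seed=None):
    """Hypothetical helper (not part of Keras): yield one list of sample
    indices per mini-batch for a single epoch."""
    indices = list(range(n_samples))
    if shuffle:
        # Shuffle the sample order once per epoch
        random.Random(seed).shuffle(indices)
    # Consecutive slices; the last one may be shorter than batch_size
    for start in range(0, n_samples, batch_size):
        yield indices[start:start + batch_size]

batches = list(iterate_minibatches(13714, 128, seed=0))
print(len(batches))                        # 108 weight updates per epoch
print(len(batches[0]), len(batches[-1]))   # 128 18

full = list(iterate_minibatches(13714, 13714, seed=0))
print(len(full))                           # 1 weight update per epoch
```

Every sample is still seen exactly once per epoch; batch_size only controls how many samples share each update.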

Finally, if you want to train your model on only one mini-batch, you can extract that batch from your training set and fit the model on it.

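import random

# Draw 128 distinct sample indices from the training set
indices = random.sample(range(len(X_train_seq)), 128)

# Fancy indexing with a list assumes X_train_seq and y_train are NumPy arrays
X_train_seq_one_batch = X_train_seq[indices]
y_train_one_batch = y_train[indices]
LSTM_classifier.fit(X_train_seq_one_batch, y_train_one_batch, batch_size=128, epochs=10)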

The code above trains your network on a single random batch of 128 samples, repeated for 10 epochs (so 10 weight updates in total). The random module is used to draw a random set of indices from your training set.
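
If you want the same batch on every run, and want to be sure that inputs and labels stay aligned, you can use a seeded random generator. Below is a minimal stdlib sketch: sample_one_batch is a hypothetical helper, and the toy lists X and y are hypothetical stand-ins for X_train_seq and y_train. With NumPy arrays, X_train_seq[indices] performs the same selection as the list comprehensions here.

```python
import random

def sample_one_batch(X, y, batch_size, seed=None):
    """Hypothetical helper: draw one random mini-batch of aligned (X, y)
    pairs without replacement."""
    if len(X) != len(y):
        raise ValueError("X and y must have the same number of samples")
    rng = random.Random(seed)  # seeded generator -> same batch on every run
    indices = rng.sample(range(len(X)), batch_size)
    return [X[i] for i in indices], [y[i] for i in indices], indices

# Toy stand-ins (hypothetical): 13714 sequences of length 70, label = i % 2
X = [[i] * 70 for i in range(13714)]
y = [i % 2 for i in range(13714)]

Xb, yb, idx = sample_one_batch(X, y, 128, seed=42)
print(len(Xb), len(yb), len(set(idx)))  # 128 128 128
```

Seeding the generator only makes the choice of batch reproducible; it does not change how fit trains on it.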

About

Geeks Mental is a community that publishes articles and tutorials about Web, Android, Data Science, new techniques and Linux security.