Novice machine learner: how should I interpret large variance in per-batch error in an MNIST perceptron?

I'm trying to get a better understanding of basic neural networks by implementing a little framework in C++.

I've started with the classic MNIST exercise. I get to 91% accuracy on the test set, which I'm already pretty happy about.

The thing is, the maximum accuracy is almost reached after just one epoch; subsequent epochs do not seem to improve it much.

I am optimizing using stochastic gradient descent with a batch size of 40.

During training, I plot the average error of each batch. It stops diminishing rather quickly, and once it stabilizes on average, it oscillates much more at the end of training than it does at the beginning.
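Right now I plot the raw per-batch average, which is noisy by nature. To separate the trend from the noise, one could smooth the curve with an exponential moving average; a minimal sketch (the smoothing factor 0.9 is an assumption, not something from my setup):

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Exponential moving average of per-batch losses. Smooths batch-to-batch
// noise so the underlying trend is easier to read. beta = 0.9 keeps
// roughly the last 10 batches of history.
std::vector<double> smooth_losses(const std::vector<double>& losses,
                                  double beta = 0.9) {
    std::vector<double> out;
    out.reserve(losses.size());
    double ema = 0.0;
    for (std::size_t t = 0; t < losses.size(); ++t) {
        ema = beta * ema + (1.0 - beta) * losses[t];
        // Bias correction so the first values aren't dragged toward zero.
        out.push_back(ema / (1.0 - std::pow(beta, static_cast<double>(t) + 1.0)));
    }
    return out;
}
```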

Here is a screenshot describing the situation; the batch error is in white, and the test accuracy is in green:

The network has the following configuration:

  • the input layer has 784 neurons as usual; inputs are normalized by dividing each pixel value by 255
  • there is a single hidden layer of 32 units with biases, using a leaky ReLU activation (slope 0.01)
  • the output layer has 10 units and uses a softmax activation
  • all layers are fully connected, with weights randomly initialized and then rescaled so that the incoming weights of any neuron sum to one
  • I'm using the categorical cross-entropy loss to compute the error
  • batch size is 40
  • I train on the full 60,000 images for 10 epochs, shuffling the training samples after each epoch
  • the learning rate starts at 0.01 and linearly decays to 0.001 over the epochs (I have tried fiddling with this setting; it doesn't seem to affect the results much)
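To make the last two points concrete, here is a minimal sketch of the initialization and learning-rate schedule as described above. Assumptions I'm making for illustration: the pre-normalization weights are drawn uniformly (the draw distribution isn't essential to the scheme), and the function names are mine:

```cpp
#include <cstddef>
#include <random>
#include <vector>

// Initialization as described: draw random weights, then rescale each
// neuron's incoming weights so they sum to one.
std::vector<double> init_incoming_weights(std::size_t fan_in,
                                          std::mt19937& rng) {
    std::uniform_real_distribution<double> dist(0.0, 1.0);
    std::vector<double> w(fan_in);
    double sum = 0.0;
    for (double& x : w) { x = dist(rng); sum += x; }
    for (double& x : w) x /= sum;  // incoming weights now sum to 1
    return w;
}

// Linear decay from 0.01 at the first epoch to 0.001 at the last epoch.
double learning_rate(int epoch, int num_epochs) {
    double t = static_cast<double>(epoch) / (num_epochs - 1);
    return 0.01 + t * (0.001 - 0.01);
}
```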

I would appreciate any pointers as to what might be happening. What should I look for to understand the behavior of the network? What seems wrong to you?

Tags: mnist, perceptron, backpropagation, neural-network

Category Data Science
