Help Creating an XOR Neural Network in Java?
I have been trying to create a neural network in Java, but it doesn't quite work as intended. I am using an XOR test before I move on to more advanced problems, and it doesn't seem to be learning much. I may have the algorithms wrong, but as far as I can tell they all seem fine (I am following the backpropagation tutorial on Brilliant.org: https://brilliant.org/wiki/backpropagation/). I've provided my Network and Main classes below. Thank you for any help!
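To spell out what I think the tutorial is saying (this is my own notation, so treat it as my reading of the page rather than a quote): the output-layer error is $\delta_i^L = f'(z_i^L)\,(o_i - t_i)$, each hidden-layer error is $\delta_i^l = f'(z_i^l)\sum_k w_{ik}^l\,\delta_k^{l+1}$, and for every training example I accumulate $\Delta w_{ik}^l \leftarrow \Delta w_{ik}^l + o_i^l\,\delta_k^{l+1}$ and $\Delta b_i^l \leftarrow \Delta b_i^l + \delta_i^l$; learn() then applies $w \leftarrow w - \eta\,\Delta w/\text{batchSize}$ and resets the accumulators.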
```
import java.util.*;
////Neural network creator
public class Network {
//Constant values for activation functions
public static final int identity = 0;
public static final int sigmoid = 1;
public static final int tanh = 2;
////ArrayLists for each layer's nodes, weights, and biases
//List to hold all neural network node values (output)
private ArrayList<ArrayList<Double>> nodes = new ArrayList<ArrayList<Double>>();
//List for all error (delta) values of nodes
private ArrayList<ArrayList<Double>> error = new ArrayList<ArrayList<Double>>();
//List for all biases from layer to layer (index n holds the biases applied to layer n)
private ArrayList<ArrayList<Double>> biases = new ArrayList<ArrayList<Double>>();
//List for changes to biases
private ArrayList<ArrayList<Double>> biasChanges = new ArrayList<ArrayList<Double>>();
//List for all weights (index n,k in layer l connects node n in layer l to node k in layer l+1)
private ArrayList<ArrayList<ArrayList<Double>>> weights = new ArrayList<ArrayList<ArrayList<Double>>>();
//List for changes to each weight
private ArrayList<ArrayList<ArrayList<Double>>> weightChanges = new ArrayList<ArrayList<ArrayList<Double>>>();
//List for pre-activation values of nodes, kept for the backpropagation algorithm
private ArrayList<ArrayList<Double>> nodeActivations = new ArrayList<ArrayList<Double>>();
//List for activation functions (first layer is always identity)
private ArrayList<Integer> activations = new ArrayList<Integer>();
//Learning rate
private double LEARNING_RATE;
////Constructor
/*
* firstLayerLength: number of nodes in first layer
* learningRate: rate at which network changes its values to fit examples
*/
public Network(int firstLayerLength, double learningRate) {
//Set learning rate
LEARNING_RATE = learningRate;
//Add only node layer and activation (identity)
addNodeLayer(firstLayerLength);
addActivation(identity);
//Add empty bias layer (to make calculations easier)
addBiasLayer(0);
}
////Modifiers
////Add a layer of nodes to the network
//For the activation, pass Network.identity, Network.sigmoid, etc. (the constants are the function indices)
public void addLayer(int length, int activation) {
//Add each layer
addNodeLayer(length);
addBiasLayer(length);
addWeightLayer(length);
addActivation(activation);
}
//Hidden functions to add each individual layer
private void addNodeLayer(int length) {
//Add ArrayList of size length to nodes ArrayList and nodeActivations ArrayList
nodes.add(new ArrayList<Double>());
for (int i = 0; i < length; i++) nodes.get(nodes.size()-1).add(0.0);
nodeActivations.add(new ArrayList<Double>());
for (int i = 0; i < length; i++) nodeActivations.get(nodeActivations.size()-1).add(0.0);
error.add(new ArrayList<Double>());
for (int i = 0; i < length; i++) error.get(error.size()-1).add(0.0);
}
private void addBiasLayer(int length) {
biasChanges.add(new ArrayList<Double>());
for (int i = 0; i < length; i++) biasChanges.get(biasChanges.size()-1).add(0.0);
//Create ArrayList
ArrayList<Double> newList = new ArrayList<Double>();
//Randomly instantiate each value between -1 and 1
for (int i = 0; i < length; i++) {
newList.add(Math.random() * 2 - 1);
}
//Add ArrayList to biases ArrayList
biases.add(newList);
}
private void addWeightLayer(int length) {
//Instantiate new ArrayList with one row per node in the previous layer
ArrayList<ArrayList<Double>> newList = new ArrayList<ArrayList<Double>>();
//Randomly instantiate each index
for (int i = 0; i < nodes.get(nodes.size()-2).size(); i++) {
ArrayList<Double> newList2 = new ArrayList<Double>();
for (int n = 0; n < length; n++) {
newList2.add(Math.random() * 2 - 1);
}
newList.add(newList2);
}
//Add newList to weights array
weights.add(newList);
newList = new ArrayList<ArrayList<Double>>();
//Instantiate to 0
for (int i = 0; i < nodes.get(nodes.size()-2).size(); i++) {
ArrayList<Double> newList2 = new ArrayList<Double>();
for (int n = 0; n < length; n++) {
newList2.add(0.0);
}
newList.add(newList2);
}
weightChanges.add(newList);
}
private void addActivation(int activationIndex) {
//Add activation
activations.add(activationIndex);
}
//Print weights array (for debugging)
public void printWeights() {
System.out.println(weights);
}
//Print nodes array (debugging)
public void printNodes() {
System.out.println(nodes);
}
//Print biases array (debugging)
public void printBiases() {
System.out.println(biases);
}
//Print biasChanges array (debugging)
public void printBiasChanges() {
System.out.println(biasChanges);
}
//Print weightChanges array (debugging)
public void printWeightChanges() {
System.out.println(weightChanges);
}
//Feed forward and find error terms
public void feedForward(double[] values, double[] answer) {
//Set input nodes
for (int i = 0; i < values.length; i++) {
nodes.get(0).set(i, values[i]);
nodeActivations.get(0).set(i, values[i]);
}
//Loop through each layer, feeding the values forward
for (int i = 1; i < nodes.size(); i++) {
//Loop through each value of the layer
for (int n = 0; n < nodes.get(i).size(); n++) {
//Reset node value
nodes.get(i).set(n, 0.0);
//Loop through each value of the layer before
for (int x = 0; x < nodes.get(i-1).size(); x++) {
//Add previous activation times weight
double newValue = nodes.get(i-1).get(x) * weights.get(i-1).get(x).get(n);
nodes.get(i).set(n, nodes.get(i).get(n) + newValue);
}
//Activation function and bias
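//nodeActivations stores the pre-activation z = weighted sum + bias; nodes is then overwritten with f(z)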
nodeActivations.get(i).set(n, nodes.get(i).get(n) + biases.get(i).get(n));
switch (activations.get(i)) {
case identity:
nodes.get(i).set(n, nodeActivations.get(i).get(n));
break;
case sigmoid:
nodes.get(i).set(n, sigmoid(nodeActivations.get(i).get(n)));
break;
case tanh:
nodes.get(i).set(n, tanh(nodeActivations.get(i).get(n)));
break;
}
}
}
////Backpropagate and add changes to total changes
//Output error
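//Output delta: delta_i = f'(z_i) * (output_i - target_i), matching the formula above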
for (int i = 0; i < nodes.get(nodes.size()-1).size(); i++) {
double singleError = 0;
switch (activations.get(activations.size()-1)) {
case identity:
singleError = nodes.get(nodes.size()-1).get(i) - answer[i];
break;
case sigmoid:
singleError = aSigmoid(nodes.get(nodes.size()-1).get(i)) * (nodes.get(nodes.size()-1).get(i) - answer[i]);
break;
case tanh:
singleError = aTanh(nodeActivations.get(nodes.size()-1).get(i)) * (nodes.get(nodes.size()-1).get(i) - answer[i]);
break;
}
error.get(error.size()-1).set(i, singleError);
//Final weights backpropagation
for (int n = 0; n < nodes.get(nodes.size()-2).size(); n++) {
double change = error.get(error.size()-1).get(i) * nodes.get(nodes.size()-2).get(n);
weightChanges.get(nodes.size()-2).get(n).set(i, weightChanges.get(nodes.size()-2).get(n).get(i) + change);
}
}
//Every layer before output
for (int layer = nodes.size() - 2; layer > 0; layer--) {
//Go through each value in layer
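//Hidden delta: delta_i = f'(z_i) * sum over the next layer of (w_ij * delta_j)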
for (int i = 0; i < nodes.get(layer).size(); i++) {
double singleError = 0;
switch (activations.get(layer)) {
case identity:
singleError = nodes.get(layer).get(i);
break;
case sigmoid:
singleError = aSigmoid(nodes.get(layer).get(i));
break;
case tanh:
singleError = aTanh(nodeActivations.get(layer).get(i));
break;
}
//Loop through layer after
double sum = 0;
for (int n = 0; n < nodes.get(layer+1).size(); n++) {
sum += error.get(layer+1).get(n) * weights.get(layer).get(i).get(n);
}
//Multiply single error by sum of errors times weights
error.get(layer).set(i, singleError * sum);
//Backpropagate to find change
//Biases
biasChanges.get(layer).set(i, biasChanges.get(layer).get(i) + error.get(layer).get(i));
for (int n = 0; n < nodes.get(layer-1).size(); n++) {
//Weights
double change = error.get(layer).get(i) * nodes.get(layer-1).get(n);
weightChanges.get(layer-1).get(n).set(i, weightChanges.get(layer-1).get(n).get(i) + change);
}
}
}
//After all of a batch's backpropagation passes, learn() must be called to actually change the weights
}
//Returns true if the index of the highest value in the inputted array matches the index of the highest output node
public boolean correctAnswerClass(double[] values) {
int answer = 0;
int highest = 0;
for (int i = 1; i < values.length; i++) {
if (values[i] > values[answer]) answer = i;
if (nodes.get(nodes.size()-1).get(i) > nodes.get(nodes.size()-1).get(highest)) highest = i;
}
return answer == highest;
}
//Returns the error (expected value minus network output) for each output node
public double[] errorArray(double[] values) {
double[] errors = new double[values.length];
for (int i = 0; i < values.length; i++) {
errors[i] = values[i] - nodes.get(nodes.size()-1).get(i);
}
return errors;
}
//Implement weight changes
public void learn(int batchSize) {
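//Gradient step: each value -= LEARNING_RATE * (accumulated change / batchSize), then the accumulators reset to zero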
//Loop through each layer
for (int layer = 0; layer < nodes.size(); layer++) {
for (int i = 0; i < nodes.get(layer).size(); i++) {
//Change bias if layer != 0
if (layer != 0) {
double newBias = biases.get(layer).get(i) - biasChanges.get(layer).get(i) / batchSize * LEARNING_RATE;
biases.get(layer).set(i, newBias);
//Reset changes to zero
biasChanges.get(layer).set(i, 0.0);
}
if (layer != nodes.size()-1) {
for (int n = 0; n < nodes.get(layer+1).size(); n++) {
//Change weight by negative of the gradient
double newWeight = weights.get(layer).get(i).get(n) - weightChanges.get(layer).get(i).get(n) / batchSize * LEARNING_RATE;
weights.get(layer).get(i).set(n, newWeight);
//Reset changes
weightChanges.get(layer).get(i).set(n, 0.0);
}
}
}
}
}
//Activation functions and helpers for backpropagation
private double sigmoid(double x) {
return 1 / (1 + Math.exp(-x));
}
//Derivative of sigmoid written in terms of its output: s'(z) = s(z) * (1 - s(z))
private double aSigmoid(double x) {
return x * (1 - x);
}
private double tanh(double x) {
return Math.tanh(x);
}
//Inverse hyperbolic tangent: artanh(x) = ln((1 + x) / (1 - x)) / 2
private double aTanh(double x) {
return Math.log((1 + x) / (1 - x)) / 2;
}
}
```

```
import java.util.Scanner;
class Main {
public static void main(String[] args) {
int batchSize = 100;
int trainExamples = 100000;
int testExamples = 1000;
double accuracyBenchmark = 0.1;
Network bobby = new Network(2, 0.01);
bobby.addLayer(2, Network.sigmoid);
bobby.addLayer(1, Network.identity);
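//Network shape: 2 inputs -> 2 sigmoid hidden nodes -> 1 identity output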
//Loop through all examples
for (int i = 0; i < trainExamples; i += batchSize) {
double accuracy = 0;
for (int n = 0; n < batchSize; n++) {
int firstValue = 0;
if (Math.random() < 0.5) firstValue++;
int secondValue = 0;
if (Math.random() < 0.5) secondValue++;
double[] output = {0};
if (firstValue == 1 ^ secondValue == 1) output[0] = 1;
double[] input = {firstValue, secondValue};
bobby.feedForward(input, output);
double[] array = bobby.errorArray(output);
if (Math.abs(array[0]) < accuracyBenchmark) accuracy++;
}
accuracy /= batchSize;
System.out.println("Accuracy: " + accuracy);
bobby.learn(batchSize);
}
}
}
```
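In case it helps, here is the quick spot check I run at the very end of main (just a sketch reusing the feedForward and errorArray methods shown above, nothing new):

```
//Spot-check the four XOR cases after training
double[][] testInputs = {{0, 0}, {0, 1}, {1, 0}, {1, 1}};
double[] testTargets = {0, 1, 1, 0};
for (int t = 0; t < testInputs.length; t++) {
double[] expected = {testTargets[t]};
//Note: feedForward also accumulates backprop changes, but learn() is not called again here
bobby.feedForward(testInputs[t], expected);
double[] err = bobby.errorArray(expected);
System.out.println((int) testInputs[t][0] + " XOR " + (int) testInputs[t][1] + " -> error " + err[0]);
}
```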