Help Creating a XOR Neural Network in Java?

I have been trying to create a neural network in Java, but it doesn't quite work as intended. I am using a XOR test before I move on to more advanced problems, and it doesn't seem to be learning much. I may have the algorithms wrong, but as far as I can tell, they all seem fine (I am using a tutorial on Brilliant.org - https://brilliant.org/wiki/backpropagation/). I've provided my Network and Main class below. Thank you for any help!

import java.util.*;

////Neural network creator
public class Network {
  
  //Constant values for activation functions
  public static final int identity = 0;
  public static final int sigmoid = 1;
  public static final int tanh = 2;

  ////Arraylists for each layers' nodes, weights, and biases
  //List to hold all neural network node values (output)
  private ArrayListArrayListDouble nodes = new ArrayListArrayListDouble();
  //List for all error (delta) of nodes
  private ArrayListArrayListDouble error = new ArrayListArrayListDouble();
  //List for all biases from layer to layer (index 0 affects 0, index 1 affects 1... index n affects n to n)
  private ArrayListArrayListDouble biases = new ArrayListArrayListDouble();
  //List for changes to biases
  private ArrayListArrayListDouble biasChanges = new ArrayListArrayListDouble();
  //List for all weights (index n,k in layer l affects node n in layer l to node k in layer l+1)
  private ArrayListArrayListArrayListDouble weights = new ArrayListArrayListArrayListDouble();
  //List for changes to each weight
  private ArrayListArrayListArrayListDouble weightChanges = new ArrayListArrayListArrayListDouble();
  //List for activations of nodes for backpropagation algorithm
  private ArrayListArrayListDouble nodeActivations = new ArrayListArrayListDouble();
  //List for activation functions (first layer is always identity)
  private ArrayListInteger activations = new ArrayListInteger();

  //Learning rate
  private double LEARNING_RATE;

  ////Constructor
  /*
  * firstLayerLength: number of nodes in first layer
  * learningRate: rate at which network changes its values to fit examples
  */
  public Network(int firstLayerLength, double learningRate) {
    //Set learning rate
    LEARNING_RATE = learningRate;
    //Add only node layer and activation (identity)
    addNodeLayer(firstLayerLength);
    addActivation(identity);
    //Add empty bias layer (to make calculations easier)
    addBiasLayer(0);
  }

  ////Modifiers

  ////Add a layer of nodes to the network
  //For activation use Network.identity or Network.sigmoid... for exact index of function
  public void addLayer(int length, int activation) {
    //Add each layer
    addNodeLayer(length);
    addBiasLayer(length);
    addWeightLayer(length);
    addActivation(activation);
  }
  //Hidden functions to add each individual layer
  private void addNodeLayer(int length) {
    //Add ArrayList of size length to nodes ArrayList and nodeActivations ArrayList
    nodes.add(new ArrayListDouble());
    for (int i = 0; i  length; i++) nodes.get(nodes.size()-1).add(0.0);
    nodeActivations.add(new ArrayListDouble());
    for (int i = 0; i  length; i++) nodeActivations.get(nodeActivations.size()-1).add(0.0);
    error.add(new ArrayListDouble());
    for (int i = 0; i  length; i++) error.get(error.size()-1).add(0.0);
  }
  private void addBiasLayer(int length) {
    biasChanges.add(new ArrayListDouble());
    for (int i = 0; i  length; i++) biasChanges.get(biasChanges.size()-1).add(0.0);

    //Create ArrayList
    ArrayListDouble newList = new ArrayListDouble();
    //Randomly instantiate each value between -1 and 1
    for (int i = 0; i  length; i++) {
      newList.add(Math.random() * 2 - 1);
    }

    //Add ArrayList biases ArrayList
    biases.add(newList);
  }
  private void addWeightLayer(int length) {
    //Instantiate new Arraylist of size layer-1.size - 1
    ArrayListArrayListDouble newList = new ArrayListArrayListDouble();

    //Randomly instantiate each index
    for (int i = 0; i  nodes.get(nodes.size()-2).size(); i++) {
      ArrayListDouble newList2 = new ArrayListDouble();
      for (int n = 0; n  length; n++) {
        newList2.add(Math.random() * 2 - 1);
      }
      newList.add(newList2);
    }

    //Add newList to weights array
    weights.add(newList);
    newList = new ArrayListArrayListDouble();

    //Instantiate to 0
    for (int i = 0; i  nodes.get(nodes.size()-2).size(); i++) {
      ArrayListDouble newList2 = new ArrayListDouble();
      for (int n = 0; n  length; n++) {
        newList2.add(0.0);
      }
      newList.add(newList2);
    }

    weightChanges.add(newList);
    
  }
  private void addActivation(int activationIndex) {
    //Add activation
    activations.add(activationIndex);
  }

  //Print weights array (for debugging)
  public void printWeights() {
    System.out.println(weights);
  }
  //Print nodes array (debugging)
  public void printNodes() {
    System.out.println(nodes);
  }
  //Print biases array (debugging)
  public void printBiases() {
    System.out.println(biases);
  }
  //Print biasChanges array (debugging)
  public void printBiasChanges() {
    System.out.println(biasChanges);
  }
  //Print weightChanges array (debugging)
  public void printWeightChanges() {
    System.out.println(weightChanges);
  }


  //Feed forward and find error terms
  public void feedForward(double[] values, double[] answer) {

    //Set intput nodes
    for (int i = 0; i  values.length; i++) {
      nodes.get(0).set(i, values[i]);
      nodeActivations.get(0).set(i, values[i]);
    }
    //Loop through each layer, feeding the values forward
    for (int i = 1; i  nodes.size(); i++) {
      //Loop through each value of the layer
      for (int n = 0; n  nodes.get(i).size(); n++) {
        //Reset node value
        nodes.get(i).set(n, 0.0);
        //Loop through each value of the layer before
        for (int x = 0; x  nodes.get(i-1).size(); x++) {
          //Add previous activation times weight
          double newValue = nodes.get(i-1).get(x) * weights.get(i-1).get(x).get(n);
          nodes.get(i).set(n, nodes.get(i).get(n) + newValue);
        }
        //Activation function and bias
        nodeActivations.get(i).set(n, nodes.get(i).get(n) + biases.get(i).get(n));
        switch (activations.get(i)) {
          case identity:
            nodes.get(i).set(n, nodeActivations.get(i).get(n));
          break;
          case sigmoid:
            nodes.get(i).set(n, sigmoid(nodeActivations.get(i).get(n)));
          break;
          case tanh:
            nodes.get(i).set(n, tanh(nodeActivations.get(i).get(n)));
          break;
        }
      }
    }

    ////Backpropagate and add changes to total changes
    //Output error
    for (int i = 0; i  nodes.get(nodes.size()-1).size(); i++) {
      double singleError = 0;
      switch (activations.get(activations.size()-1)) {
        case identity:
          singleError = nodes.get(nodes.size()-1).get(i) - answer[i];
        break;
        case sigmoid:
          //if (nodeActivations.get(nodes.size()-1).get(i)  0) System.out.println(line 183);
          singleError = aSigmoid(nodes.get(nodes.size()-1).get(i)) * (nodes.get(nodes.size()-1).get(i) - answer[i]);
        break;
        case tanh:
          singleError = aTanh(nodeActivations.get(nodes.size()-1).get(i)) * (nodes.get(nodes.size()-1).get(i) - answer[i]);
        break;
      }
      error.get(error.size()-1).set(i, singleError);
      //Final weights backpropagation
      for (int n = 0; n  nodes.get(nodes.size()-2).size(); n++) {
        double change = error.get(error.size()-1).get(i) * nodes.get(nodes.size()-2).get(n);
        weightChanges.get(nodes.size()-2).get(n).set(i, weightChanges.get(nodes.size()-2).get(n).get(i) + change);
      }
    }

    //Every layer before output
    for (int layer = nodes.size() - 2; layer  0; layer--) {
      //Go through each value in layer
      for (int i = 0; i  nodes.get(layer).size(); i++) {
        double singleError = 0;
        switch (activations.get(layer)) {
          case identity:
            singleError = nodes.get(layer).get(i);
          break;
          case sigmoid:
            //if (nodeActivations.get(layer).get(i)  0) System.out.println(line 208);
            singleError = aSigmoid(nodes.get(layer).get(i));
          break;
          case tanh:
            singleError = aTanh(nodeActivations.get(layer).get(i));
          break;
        }
        //Loop through layer after
        double sum = 0;
        for (int n = 0; n  nodes.get(layer+1).size(); n++) {
          sum += error.get(layer+1).get(n) * weights.get(layer).get(i).get(n);
        }
        //Multiply single error by sum of errors times weights
        error.get(layer).set(i, singleError * sum);
        //Backpropagate to find change
        //Biases
        biasChanges.get(layer).set(i, biasChanges.get(layer).get(i) + error.get(layer).get(i));
        for (int n = 0; n  nodes.get(layer-1).size(); n++) {
          //Weights
          double change = error.get(layer).get(i) * nodes.get(layer-1).get(n);
          weightChanges.get(layer-1).get(n).set(i, weightChanges.get(layer-1).get(n).get(i) + change);
        }
      }
    }
    //After all backpropagations have been done, neural network must change weights
  }

  //Returns true if highest value of intputted array is highest value of output layer
  public boolean correctAnswerClass(double[] values) {
    int answer = 0;
    int highest = 0;
    for (int i = 1; i  values.length; i++) {
      if (values[i]  values[answer]) answer = i;
      if (nodes.get(nodes.size()-1).get(i)  nodes.get(nodes.size()-1).get(highest)) highest = i;
    }
    if (answer == highest) return true;
    return false;
  }

  //Returns error array between values
  public double[] errorArray(double[] values) {
    double[] errors = new double[values.length];
    for (int i = 0; i  values.length; i++) {
      errors[i] = values[i] - nodes.get(nodes.size()-1).get(i);
    }
    return errors;
  }

  //Implement weight changes
  public void learn(int batchSize) {
    //Loop through each layer
    for (int layer = 0; layer  nodes.size(); layer++) {
      for (int i = 0; i  nodes.get(layer).size(); i++) {

        //Change bias if layer != 0
        if (layer != 0) {
          double newBias = biases.get(layer).get(i) - biasChanges.get(layer).get(i) / batchSize * LEARNING_RATE;
          biases.get(layer).set(i, newBias);
          //Reset changes to zero
          biasChanges.get(layer).set(i, 0.0);
        }

        if (layer != nodes.size()-1) {
          for (int n = 0; n  nodes.get(layer+1).size(); n++) {
            //Change weight by negative of the gradient
            double newWeight = weights.get(layer).get(i).get(n) - weightChanges.get(layer).get(i).get(n) / batchSize * LEARNING_RATE;
            weights.get(layer).get(i).set(n, newWeight);
            //Reset changes
            weightChanges.get(layer).get(i).set(n, 0.0);
          }
        }
        
      }
    }
  }

  //Activation functions
  private double sigmoid(double x) {
    return 1 / (1 + Math.pow(2.71828, -x));
  }
  private double aSigmoid(double x) {
    return x * (1 - x);
  }
  private double tanh(double x) {
    return Math.tanh(x);
  }
  private double aTanh(double x) {
    return Math.log((1 + x) / (1 - x)) / 2;
  }

}
import java.util.Scanner;

class Main {
  public static void main(String[] args) {

    int batchSize = 100;
    int trainExamples = 100000;
    int testExamples = 1000;
    double accuracyBenchmark = 0.1;

    Network bobby = new Network(2, 0.01);
    bobby.addLayer(2, Network.sigmoid);
    bobby.addLayer(1, Network.identity);

    //Loop through all examples
    for (int i = 0; i  trainExamples; i += batchSize) {
      double accuracy = 0;
      for (int n = 0; n  batchSize; n++) {
        int firstValue = 0;
        if (Math.random()  0.5) firstValue++;
        int secondValue = 0;
        if (Math.random()  0.5) secondValue++;
        double[] output = {0};
        if (firstValue == 1 ^ secondValue == 1) output[0] = 1;
        double[] input = {firstValue, secondValue};
        bobby.feedForward(input, output);

        double[] array = bobby.errorArray(output);
        if (Math.abs(array[0])  accuracyBenchmark) accuracy++;
      }
      
      accuracy /= batchSize;
      System.out.println(Accuracy:  + accuracy);

      bobby.learn(batchSize);
    }
    
  }
}
```

Topic backpropagation java neural-network machine-learning

Category Data Science

About

Geeks Mental is a community that publishes articles and tutorials about Web, Android, Data Science, new techniques and Linux security.