TensorFlow: simple neural network performs very poorly on a regression (curve-fitting) problem

I'm trying to implement a very simple MLP with one hidden layer for a toy regression problem with one input variable (dimension = 1) and one target (dimension = 1). It's a simple curve-fitting problem with zero noise.

MATLAB - Deep Learning Toolbox

Using Levenberg-Marquardt backpropagation on an MLP with a single hidden layer of 100 neurons and hyperbolic tangent activation, I got pretty decent performance with almost zero effort:

MSE = 7.18e-08

Here's a plot of the fitting:

Edit: This is the working MATLAB code. Note that feedforwardnet(100) only creates a network object with one hidden layer of 100 neurons with tanh activation and an output layer with linear activation:

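net = feedforwardnet(100);               % 1 hidden layer of 100 tansig units, purelin output; trainFcn defaults to 'trainlm'
net.trainParam.min_grad = 1e-25;         % effectively disable the minimum-gradient stopping criterion
net.trainParam.max_fail = 50;            % allow up to 50 validation failures before stopping
net.trainParam.epochs = 500;
%net.trainParam.showWindow = false;
net.inputs{1,1}.processFcns = {};        % no input preprocessing (data are already normalized)
net.outputs{1,2}.processFcns = {};       % no output preprocessing
net = train(net,Train_Vars,Train_Target);
Test_Predictions = net(Test_Vars);
Accuracy = msemetric({Test_Predictions},{Test_Target});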
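For comparison with the Keras run below, it may help to see what feedforwardnet does by default: its training function is 'trainlm' (Levenberg-Marquardt), which on every epoch solves the damped Gauss-Newton system (J^T J + mu*I) dp = -J^T e over the full batch, whereas RMSprop and the other built-in tf.keras optimizers are first-order methods. The following is a minimal NumPy sketch of that update for a 1-input, h-tanh-unit, 1-output network. It is an illustration under assumptions, not MATLAB's implementation: the function names and the sin(3x) toy curve are hypothetical, the mu schedule (start at 1e-3, multiply by 0.1 on success and by 10 on failure, stop above 1e10) is what I understand trainlm's defaults to be, and there is no validation-based early stopping.

```python
import numpy as np


def unpack(p, h):
    """Split the flat parameter vector into W1, b1, W2, b2."""
    return p[:h], p[h:2 * h], p[2 * h:3 * h], p[3 * h]


def forward(p, x, h):
    """1 input -> h tanh units -> 1 linear output (feedforwardnet-like shape)."""
    W1, b1, W2, b2 = unpack(p, h)
    a = np.tanh(np.outer(x, W1) + b1)  # hidden activations, shape (N, h)
    return a @ W2 + b2, a


def jacobian(p, x, h):
    """d(prediction_n)/d(param_k), shape (N, 3h + 1), same order as p."""
    W1, b1, W2, b2 = unpack(p, h)
    _, a = forward(p, x, h)
    da = 1.0 - a ** 2  # derivative of tanh
    return np.hstack([
        da * W2 * x[:, None],   # w.r.t. W1
        da * W2,                # w.r.t. b1
        a,                      # w.r.t. W2
        np.ones((x.size, 1)),   # w.r.t. b2
    ])


def train_lm(x, t, h=10, epochs=200, mu=1e-3, mu_dec=0.1, mu_inc=10.0,
             mu_max=1e10, seed=0):
    """Full-batch Levenberg-Marquardt on the squared error (sketch only)."""
    rng = np.random.default_rng(seed)
    p = np.concatenate([rng.normal(size=h), rng.normal(size=h),
                        rng.normal(scale=1.0 / np.sqrt(h), size=h), [0.0]])
    for _ in range(epochs):
        y, _ = forward(p, x, h)
        e = y - t
        J = jacobian(p, x, h)
        JtJ, Jte = J.T @ J, J.T @ e
        mse_old = np.mean(e ** 2)
        while True:
            # Damped Gauss-Newton step: (J^T J + mu*I) dp = -J^T e
            dp = np.linalg.lstsq(JtJ + mu * np.eye(p.size), -Jte, rcond=None)[0]
            p_new = p + dp
            mse_new = np.mean((forward(p_new, x, h)[0] - t) ** 2)
            if mse_new < mse_old:  # accept: move toward Gauss-Newton
                p, mu = p_new, mu * mu_dec
                break
            mu *= mu_inc  # reject: move toward small gradient-descent steps
            if mu > mu_max:
                return p
    return p


if __name__ == "__main__":
    # Hypothetical noise-free 1-D curve standing in for the poster's data
    x = np.linspace(-1.0, 1.0, 200)
    t = np.sin(3.0 * x)
    p = train_lm(x, t, h=10)
    print("train MSE:", np.mean((forward(p, x, 10)[0] - t) ** 2))
```

Because each step uses curvature information from J^T J, a small network like this can usually reach a very low error on a smooth, noise-free curve in a few hundred full-batch iterations; the price is forming and solving a (3h+1) x (3h+1) system per step, which is why LM is typically reserved for small networks.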

Python - TensorFlow - Keras

With the same network settings I used in MATLAB, there's almost no training, no matter how hard I try to tune the training parameters or which optimizer I switch to:

MSE = 0.12900154

I can get something better using ReLU activations in the hidden layer, but it's still far off:

MSE = 0.0582045

This is the code I used in Python:

#  IMPORT LIBRARIES
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras


#  IMPORT DATASET FROM CSV FILE, SHUFFLE TRAINING SET
#  AND MAKE NUMPY ARRAY FOR TRAINING (DATA ARE ALREADY NORMALIZED)
dataset_path = "C:/Users/Rob/Desktop/Learning1.csv"
Learning_Dataset = pd.read_csv(dataset_path, comment='\t', sep=",",
                               skipinitialspace=False)
Learning_Dataset = Learning_Dataset.sample(frac=1)  # SHUFFLING

test_dataset_path = "C:/Users/Rob/Desktop/Test1.csv"
Test_Dataset = pd.read_csv(test_dataset_path, comment='\t', sep=",",
                           skipinitialspace=False)

Learning_Target = Learning_Dataset.pop('Target')
Test_Target = Test_Dataset.pop('Target')

Learning_Dataset = np.array(Learning_Dataset, dtype="float32")
Test_Dataset = np.array(Test_Dataset, dtype="float32")
Learning_Target = np.array(Learning_Target, dtype="float32")
Test_Target = np.array(Test_Target, dtype="float32")

#  DEFINE SIMPLE MLP MODEL
inputs = tf.keras.layers.Input(shape=(1,))
x = tf.keras.layers.Dense(100, activation='relu')(inputs)
y = tf.keras.layers.Dense(1)(x)
model = tf.keras.Model(inputs=inputs, outputs=y)

#  TRAIN MODEL
opt = tf.keras.optimizers.RMSprop(learning_rate = 0.001,
                                  rho = 0.9,
                                  momentum = 0.0,
                                  epsilon = 1e-07,
                                  centered = False)
early_stop = keras.callbacks.EarlyStopping(monitor='val_loss', patience=100)
model.compile(optimizer = opt,
              loss = 'mse',
              metrics = ['mse'])


model.fit(Learning_Dataset, 
          Learning_Target,  
          epochs=500, 
          validation_split = 0.2, 
          verbose=0, 
          callbacks=[early_stop], 
          shuffle = False, 
          batch_size = 100)

#  INFERENCE AND CHECK ACCURACY
Predictions = model.predict(Test_Dataset)
Predictions = Predictions.reshape(-1)  # (N, 1) -> (N,) for any test-set size

print(np.square(np.subtract(Test_Target,Predictions)).mean()) #  MSE

plt.plot(Test_Dataset,Test_Target,'o',Test_Dataset,Predictions,'o')
plt.legend(('Target','Model Prediction'))
plt.show()

What am I doing wrong?

Thanks

I am not sure which training function your MATLAB feedforwardnet is using (unless net.trainFcn is changed, feedforwardnet defaults to 'trainlm', i.e. Levenberg-Marquardt). Also, the hidden-layer activation function in your MATLAB net is most likely 'tansig'. Try using the 'tanh' activation function in your Python/Keras code.
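For reference, tansig is tanh written differently: 2/(1 + e^(-2n)) - 1 = (1 - e^(-2n))/(1 + e^(-2n)) = tanh(n), so the Keras counterpart is activation='tanh' in the hidden Dense layer instead of activation='relu'. A quick NumPy check of that identity; the tansig function here is a hypothetical re-implementation of the formula from the MATLAB documentation, not MATLAB's own function:

```python
import numpy as np


def tansig(n):
    # MATLAB's documented formula for tansig: 2 / (1 + exp(-2n)) - 1
    return 2.0 / (1.0 + np.exp(-2.0 * n)) - 1.0


n = np.linspace(-5.0, 5.0, 101)
max_diff = np.max(np.abs(tansig(n) - np.tanh(n)))
print(f"max |tansig - tanh| on [-5, 5]: {max_diff:.1e}")
```

The printed difference comes only from floating-point rounding, so switching the hidden layer to 'tanh' gives the same function as the MATLAB net; any remaining gap would have to come from elsewhere, such as the training algorithm.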