Force Matching in Coarse Grained Molecular Dynamics with Jax - Forces do not match when neglecting energy loss
I am currently exploring force-matching approaches for molecular dynamics simulations. As I am still in an exploratory phase, I investigated the
Force Matching Neural Network Colab Notebook
corresponding to Unveiling the predictive power of static structure in glassy systems.
They train a graph neural network to estimate forces from positions.
To do so, they compute a loss that matches both energies and forces:
$L = (E_{predicted} - E_{target})^2 + (F_{predicted} - F_{target})^2$
where the energy is defined as $U(x,\phi)$ and the force is defined as $F = -\frac{dU}{dx}$.
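To make the loss concrete, here is a minimal JAX sketch of how I understand it. The toy harmonic `energy_fn` and the parameter names `k` and `r0` are my own stand-ins for the notebook's graph network, and I am assuming a mean over force components; the notebook may reduce differently.

```python
import jax
import jax.numpy as jnp

def energy_fn(params, R):
    # Hypothetical toy energy U(x, phi): a harmonic well, standing in for the GNN.
    return 0.5 * params["k"] * jnp.sum((R - params["r0"]) ** 2)

def force_fn(params, R):
    # F = -dU/dx, obtained by differentiating the energy w.r.t. positions.
    return -jax.grad(energy_fn, argnums=1)(params, R)

def energy_loss(params, R, E_target):
    return (energy_fn(params, R) - E_target) ** 2

def force_loss(params, R, F_target):
    # Assumption: mean squared error over all force components.
    return jnp.mean((force_fn(params, R) - F_target) ** 2)

def loss(params, R, targets):
    # Combined loss: energy term plus force term.
    E_target, F_target = targets
    return energy_loss(params, R, E_target) + force_loss(params, R, F_target)

params = {"k": jnp.array(2.0), "r0": jnp.array(0.0)}
R = jnp.array([[1.0, 0.0], [0.0, -1.0]])
E = energy_fn(params, R)
F = force_fn(params, R)
total = loss(params, R, (E, F))
```

With targets generated by the same parameters, both terms vanish, so `total` is zero.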
When I neglect the energy loss and match forces only, the predictions seem to converge to 0. This is not intuitive, as Coarse Graining Molecular Dynamics with Graph Neural Networks seems to train its GNN with force matching alone. My question is: does anybody know why the neural network behaves like this?
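As a sanity check on my expectation, here is a toy sketch (not the notebook's model) in which force-only training does work. The harmonic energy, the offset parameter `c`, the learning rate, and the step count are all assumptions of mine. It also shows that a force-only loss gives no gradient to a constant energy offset, since forces only determine the energy up to an additive constant.

```python
import jax
import jax.numpy as jnp

def energy_fn(params, R):
    # Hypothetical toy energy with a stiffness k and a constant offset c.
    return 0.5 * params["k"] * jnp.sum(R ** 2) + params["c"]

def force_fn(params, R):
    # F = -dU/dx; the offset c drops out of the forces entirely.
    return -jax.grad(energy_fn, argnums=1)(params, R)

def force_loss(params, R, F_target):
    return jnp.mean((force_fn(params, R) - F_target) ** 2)

# Synthetic data from a "true" model with k = 2.
R = jax.random.normal(jax.random.PRNGKey(0), (32, 3))
true_params = {"k": jnp.array(2.0), "c": jnp.array(0.0)}
F_target = force_fn(true_params, R)

# Start from a wrong stiffness and an arbitrary offset.
params = {"k": jnp.array(0.5), "c": jnp.array(5.0)}
grad_fn = jax.jit(jax.grad(force_loss))

learning_rate = 0.1
for _ in range(200):
    grads = grad_fn(params, R, F_target)
    # Plain gradient descent on the force-only loss.
    params = jax.tree_util.tree_map(
        lambda p, g: p - learning_rate * g, params, grads)

final_loss = force_loss(params, R, F_target)
```

In this toy case `k` is recovered from forces alone, while `c` never moves from its initial value.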
To reproduce my observations, change the loss function in the notebook to:
@jit
def loss(params, R, targets):
  return force_loss(params, R, targets[1]) 
and the training to:
train_epochs = 20
opt_state = opt.init(params)
train_energy_error = []
test_energy_error = []
for iteration in range(train_epochs):
  train_energy_error += [float(np.sqrt(force_loss(params, batch_Rs[0], batch_Fs[0])))]
  test_energy_error += [float(np.sqrt(force_loss(params, test_positions, test_forces)))]
 
  draw_training(params)
  params, opt_state = update_epoch((params, opt_state), 
                                   (batch_Rs, (batch_Es, batch_Fs)))
  onp.random.shuffle(lookup)
  batch_Rs, batch_Es, batch_Fs = make_batches(lookup)
as well as the visualization of the training:
def draw_training(params):
  display.clear_output(wait=True)
  display.display(plt.gcf())
  plt.subplot(1, 2, 1)
  plt.semilogy(train_energy_error)
  plt.semilogy(test_energy_error)
  plt.xlim([0, train_epochs])
  format_plot('Epoch', '$L$')
  plt.subplot(1, 2, 2)
  predicted = vmap(force_fn, (None, 0))(params, example_positions).reshape((-1,))
  plt.plot(example_forces.reshape((-1,)), predicted, 'o')
  #plt.plot(np.linspace(-400, -300, 10), np.linspace(-400, -300, 10), '--')
  format_plot('$F_{label}$', '$F_{prediction}$')
  finalize_plot((2, 1))
  plt.show()
Topic graph-neural-network convergence neural-network machine-learning
Category Data Science