Graph Neural Network fails to generalize to unseen graph topologies
I'm using PyTorch Geometric to train a graph convolutional network for a node-level regression problem. The graph models physical phenomena in a network of sensors: the sensors are measurements distributed across a power grid (powers, currents, voltages), and the goal of the GNN is to predict some unmeasured variables in the graph. The training dataset contains graphs with different topologies (i.e. different edge_index tensors), and each graph has input and label tensors, which consist of float values for each node in the graph.
The training curves look good: the loss converges to a small value and there are neither exploding nor vanishing gradients.
There are 1000 different graph topologies in the training set and around 2000 training samples. When the trained model is tested on graphs whose topology occurs 2 or 3 times in the training set, the results are great: the predictions almost match the test labels for each node (the node input values are different; only the topology has already been seen). When the trained model is tested on graphs whose topology occurs once in the training set, the results are slightly worse.
But when the model is tested on an unseen (but similar) graph topology, the results are completely wrong. All of the graphs in the training and test sets are generated synthetically, using the same random process, so the topologies come from the same graph distribution.
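To make the comparison above concrete, here is a minimal sketch of how test graphs could be bucketed by how often their topology occurs in the training set. The helper names (`topology_key`, `bucket_by_train_frequency`) and the toy graphs are made up for illustration and are not part of PyTorch Geometric. It assumes each edge_index is given as two plain Python lists (e.g. via `edge_index.tolist()`), and it treats two graphs as the same topology only if their node numbering matches, so it does not detect isomorphic graphs with relabelled nodes.

```python
from collections import Counter, defaultdict


def topology_key(edge_index):
    """Return a hashable key for an undirected edge list.

    edge_index is assumed to be a pair (sources, targets) of equal-length
    lists of ints, mirroring the 2 x E layout. Edge direction and
    duplicate edges are ignored; node numbering is taken as given.
    """
    sources, targets = edge_index
    return frozenset((min(u, v), max(u, v)) for u, v in zip(sources, targets))


def bucket_by_train_frequency(train_edge_indices, test_edge_indices):
    """Map 'occurrences in the training set' -> positions of test samples."""
    train_counts = Counter(topology_key(e) for e in train_edge_indices)
    buckets = defaultdict(list)
    for position, edge_index in enumerate(test_edge_indices):
        # Counter returns 0 for topologies never seen during training.
        buckets[train_counts[topology_key(edge_index)]].append(position)
    return dict(buckets)


# Toy graphs (hypothetical data, only to show the bucketing).
path = ([0, 1, 1, 2], [1, 0, 2, 1])   # 0-1-2, stored in both directions
star = ([0, 0, 0], [1, 2, 3])         # node 0 connected to 1, 2, 3
triangle = ([0, 1, 2], [1, 2, 0])     # never appears in training

train = [path, path, star]
test = [path, star, triangle]

buckets = bucket_by_train_frequency(train, test)
print(buckets)  # {2: [0], 1: [1], 0: [2]}
```

Reporting the error separately for each bucket (0, 1, 2 or more occurrences) gives the breakdown described above, and holding out whole topologies for validation would expose the gap during training rather than only at test time.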
Since the graph models physical phenomena in the network of sensors, I would expect that the GNN should be able to learn how sensor information impacts the neighboring variables, even for the unseen graphs.
I've tried making the model deeper by adding more convolutional layers. I used these convolutional layers: https://github.com/rusty1s/pytorch_geometric/blob/master/torch_geometric/nn/conv/gcn_conv.py#L188
Has anyone had a similar problem? Are there GNN models that are better at generalizing to unseen graph topologies?
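For reference, this is a rough pure-Python sketch of what a single GCN-style layer computes, assuming the standard Kipf and Welling propagation rule (self-loops added, symmetric degree normalization) with no bias and no nonlinearity. To my understanding this matches the default behaviour of the linked layer, but the function below (`gcn_propagate`) is an illustration written for this post, not the library code.

```python
import math


def gcn_propagate(num_nodes, edges, x, weight):
    """One GCN-style propagation step: D^-1/2 (A + I) D^-1/2 X W.

    edges  : undirected pairs (u, v), each listed once
    x      : per-node feature lists, shape [num_nodes][in_dim]
    weight : shared weight matrix, shape [in_dim][out_dim]
    """
    # Neighbour sets including a self-loop for every node.
    neighbours = [{i} for i in range(num_nodes)]
    for u, v in edges:
        neighbours[u].add(v)
        neighbours[v].add(u)
    degree = [len(n) for n in neighbours]

    in_dim, out_dim = len(weight), len(weight[0])
    # Linear transform X W; the same weights are applied to every node.
    xw = [
        [sum(x[i][k] * weight[k][j] for k in range(in_dim)) for j in range(out_dim)]
        for i in range(num_nodes)
    ]

    out = []
    for i in range(num_nodes):
        row = [0.0] * out_dim
        for j in neighbours[i]:
            # Symmetric normalization coefficient for edge (i, j).
            coef = 1.0 / math.sqrt(degree[i] * degree[j])
            for d in range(out_dim):
                row[d] += coef * xw[j][d]
        out.append(row)
    return out


# Toy example: path graph 0-1-2 with one feature per node.
result = gcn_propagate(
    num_nodes=3,
    edges=[(0, 1), (1, 2)],
    x=[[1.0], [2.0], [3.0]],
    weight=[[1.0]],
)
print(result)
```

The only learned parameters are in `weight`, which does not depend on the graph, so nothing in the layer itself is tied to a particular edge_index; the topology enters only through the neighbour sets and degree normalization.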
Cheers!
Topic graph-neural-network generalization machine-learning
Category Data Science