Classifying nodes with vanilla graph neural networks
Instead of directly introducing well-known GNN architectures, let’s try to build our own model to understand the thought process behind GNNs. First, we need to go back to the definition of a simple linear layer.
A basic neural network layer corresponds to a linear transformation $h_i = W x_i$, where $x_i$ is the input vector of node $i$ and $W$ is the weight matrix. In PyTorch, this equation can be implemented with the torch.mm() function, or with the nn.Linear class, which adds other parameters such as biases.
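As a quick illustration, here is a minimal sketch of both options; the feature values and dimensions are arbitrary and chosen only for the example:

```python
import torch
import torch.nn as nn

# A toy node feature vector (illustrative values only)
x = torch.tensor([[1.0, 2.0, 3.0]])   # shape: (1, dim_in)
W = torch.rand(4, 3)                  # weight matrix, shape: (dim_out, dim_in)

# Linear transformation with torch.mm() (no bias)
h = torch.mm(x, W.T)                  # shape: (1, dim_out)

# Equivalent transformation with nn.Linear, which also learns a bias term
linear = nn.Linear(3, 4, bias=True)
h = linear(x)
```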
With our graph datasets, the input vectors are node features. This means that each node is considered in complete isolation from the others, which is not enough to capture a good understanding of the graph: like a pixel in an image, the context of a node is essential to understanding it. If you look at a group of pixels instead of a single one, you can recognize edges, patterns, and so on. Likewise, to understand a node, you need to look at its neighborhood.
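To make this idea concrete, here is a minimal sketch of a graph linear layer that sums the transformed features of each node's neighbors. It assumes a dense adjacency matrix with self-loops (so each node also keeps its own features); the class and parameter names are illustrative, not a definitive implementation:

```python
import torch
import torch.nn as nn

class VanillaGNNLayer(nn.Module):
    """Graph linear layer: transform node features, then sum each
    node's transformed neighbors using the adjacency matrix."""
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.linear = nn.Linear(dim_in, dim_out, bias=False)

    def forward(self, x, adjacency):
        # x: node features, shape (num_nodes, dim_in)
        # adjacency: dense adjacency matrix with self-loops,
        # shape (num_nodes, num_nodes)
        x = self.linear(x)            # linear transformation of each node
        x = torch.mm(adjacency, x)    # aggregate over each node's neighborhood
        return x
```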