Applied Deep Learning on Graphs

You're reading from Applied Deep Learning on Graphs: Leverage graph data for business applications using specialized deep learning architectures

Product type: Paperback
Published: Dec 2024
Publisher: Packt
ISBN-13: 9781835885963
Length: 250 pages
Edition: 1st Edition
Authors (2): Lakshya Khandelwal, Subhajoy Das
Table of Contents (19)

Preface
Part 1: Foundations of Graph Learning
Chapter 1: Introduction to Graph Learning
Chapter 2: Graph Learning in the Real World
Chapter 3: Graph Representation Learning
Part 2: Advanced Graph Learning Techniques
Chapter 4: Deep Learning Models for Graphs
Chapter 5: Graph Deep Learning Challenges
Chapter 6: Harnessing Large Language Models for Graph Learning
Part 3: Practical Applications and Implementation
Chapter 7: Graph Deep Learning in Practice
Chapter 8: Graph Deep Learning for Natural Language Processing
Chapter 9: Building Recommendation Systems Using Graph Deep Learning
Chapter 10: Graph Deep Learning for Computer Vision
Part 4: Future Directions
Chapter 11: Emerging Applications
Chapter 12: The Future of Graph Learning
Index
Other Books You May Enjoy

Node classification – predicting student interests

In this section, we’ll implement a graph convolutional network (GCN) to predict student interests based on their features and connections in the social network. This task demonstrates how graph neural networks (GNNs) can leverage both node attributes and network structure to make predictions about individual nodes.

Let’s start by defining our GCN model:

import torch
from torch_geometric.nn import GCNConv

class GCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        return x

# Initialize the model
model = GCN(in_channels=num_features,
            hidden_channels=16,
            out_channels=num_classes)

As you can see, our GCN model consists of two graph convolutional layers. The first layer takes the input features and produces hidden representations, while the second layer produces the final class predictions.

Now, let’s train the model. You can view the complete code at https://github.com/PacktPublishing/Applied-Deep-Learning-on-Graphs. Let’s break down the training process here:

  1. We define an optimizer (Adam) and a loss function (CrossEntropyLoss) for training our model:
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    criterion = torch.nn.CrossEntropyLoss()
  2. The train() function performs one step of training:
    def train():
        model.train()
        optimizer.zero_grad()
        out = model(data.x, data.edge_index)
        loss = criterion(out, data.y)
        loss.backward()
        optimizer.step()
        return loss
    • It sets the model to training mode.
    • It computes the forward pass of the model.
    • It calculates the loss between predictions and true labels.
    • It performs backpropagation and updates the model parameters.
  3. The test() function evaluates the model’s performance:
    def test():
        model.eval()
        out = model(data.x, data.edge_index)
        pred = out.argmax(dim=1)
        correct = (pred == data.y).sum()
        acc = int(correct) / int(data.num_nodes)
        return acc
    • It sets the model to evaluation mode.
    • It computes the forward pass.
    • It calculates the accuracy of predictions.
  4. We train the model for 200 epochs, printing the loss and accuracy every 10 epochs:
    for epoch in range(200):
        loss = train()
        if epoch % 10 == 0:
            acc = test()
        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, '
              f'Accuracy: {acc:.4f}')
  5. Finally, we evaluate the model one last time to get the final accuracy:
    final_acc = test()
    print(f"Final Accuracy: {final_acc:.4f}")
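Note that the loop above trains and evaluates on every node, which can overstate performance. A common variant, shown here as an illustrative sketch rather than the chapter's own code, holds out nodes with boolean masks (the mask names follow the usual PyTorch Geometric convention):

```python
import torch

# Illustrative train/test split via boolean node masks, a common
# pattern for node classification; sizes here are arbitrary.
num_nodes = 100
train_mask = torch.zeros(num_nodes, dtype=torch.bool)
train_mask[:80] = True          # first 80 nodes for training
test_mask = ~train_mask         # remaining 20 for evaluation

# Inside train(): loss = criterion(out[train_mask], data.y[train_mask])
# Inside test():  compare pred[test_mask] against data.y[test_mask] only
```

With masks, the reported accuracy reflects generalization to nodes the loss never saw, even though message passing still uses the full graph.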

To visualize how well our model is performing, let’s create a function to plot the predicted versus true interest groups:

def visualize_predictions(model, data):
    model.eval()
    out = model(data.x, data.edge_index)
    pred = out.argmax(dim=1)
    plt.figure(figsize=(10, 5))
    plt.subplot(121)
    plt.title("True Interests")
    plt.scatter(
        data.x[:, 0], data.x[:, 1], c=data.y, cmap='viridis')
    plt.colorbar()
    plt.subplot(122)
    plt.title("Predicted Interests")
    plt.scatter(
        data.x[:, 0], data.x[:, 1], c=pred, cmap='viridis')
    plt.colorbar()
    plt.tight_layout()
    plt.show()
visualize_predictions(model, data)

This function creates two scatter plots: one showing the true interest groups and another showing the predicted interest groups.

Figure 7.1: Model performance: true versus predicted interests

In Figure 7.1, each point represents a student, positioned by their first two features. The visualization consists of two scatter plots, True Interests and Predicted Interests. The colors in both plots, ranging from purple (0.0) to yellow (4.0), indicate the different interest groups. The left plot shows the students' actual interest groups, while the right plot shows the model's predictions. The close match between the two distributions indicates that the model has learned to predict interests effectively from node features and graph structure.
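Beyond the scatter plots, a per-class accuracy breakdown can reveal which interest groups the model confuses most. Here is a small illustrative helper (not from the book's repository; the toy tensors are purely for demonstration):

```python
import torch

def per_class_accuracy(pred, labels, num_classes):
    """Return a list with the prediction accuracy for each class."""
    accs = []
    for c in range(num_classes):
        mask = labels == c            # nodes whose true class is c
        if mask.sum() == 0:
            accs.append(0.0)          # no examples of this class
            continue
        accs.append((pred[mask] == c).float().mean().item())
    return accs

# Toy example: class 1 has one misclassified node
pred = torch.tensor([0, 0, 1, 1, 2])
labels = torch.tensor([0, 1, 1, 1, 2])
accs = per_class_accuracy(pred, labels, 3)
```

Applied to the real model output, `per_class_accuracy(pred, data.y, num_classes)` would highlight any interest group that drags down the overall accuracy.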

In a real-world application, this node classification task could be used to predict student interests based on their profile information and social connections. This could be valuable for personalized content recommendations, targeted advertising, or improving student engagement in university activities.

Remember that while our synthetic dataset provides a clean example, real-world data often requires more preprocessing, handling of missing values, and careful consideration of privacy and ethical concerns when working with personal data.

Another aspect of graph learning is the task of link prediction: inferring whether an edge should exist between two nodes. This comes up in many real-world scenarios, such as recommending new friends in a social network or forecasting future interactions.
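As a preview, link prediction is often framed as scoring pairs of node embeddings, for example with a dot product followed by a sigmoid. The sketch below illustrates that common formulation under assumed shapes; the chapter's actual approach may differ:

```python
import torch

# Minimal dot-product link scorer over node embeddings. The embeddings
# would typically come from a GNN such as the GCN above; here they are
# random placeholders purely for illustration.
emb = torch.randn(100, 16)              # one 16-dim embedding per node
pairs = torch.tensor([[0, 1], [2, 3]])  # candidate edges as (src, dst)

scores = (emb[pairs[:, 0]] * emb[pairs[:, 1]]).sum(dim=1)
probs = torch.sigmoid(scores)           # probability each link exists
```

Training such a scorer usually contrasts observed edges against sampled non-edges, but that is beyond the scope of this sketch.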
