# Source code for dgllife.model.gnn.mpnn

# -*- coding: utf-8 -*-
# Copyright, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
# pylint: disable= no-member, arguments-differ, invalid-name

import torch.nn as nn
import torch.nn.functional as F

from dgl.nn.pytorch import NNConv

__all__ = ['MPNNGNN']

# pylint: disable=W0221
class MPNNGNN(nn.Module):
    """MPNN.

    MPNN is introduced in `Neural Message Passing for Quantum Chemistry
    <https://arxiv.org/abs/1704.01212>`__.

    This class performs message passing in MPNN and returns the updated node
    representations.

    Parameters
    ----------
    node_in_feats : int
        Size for the input node features.
    edge_in_feats : int
        Size for the input edge features.
    node_out_feats : int
        Size for the output node representations. Default to 64.
    edge_hidden_feats : int
        Size for the hidden edge representations. Default to 128.
    num_step_message_passing : int
        Number of message passing steps. Default to 6.
    """
    def __init__(self, node_in_feats, edge_in_feats, node_out_feats=64,
                 edge_hidden_feats=128, num_step_message_passing=6):
        super(MPNNGNN, self).__init__()

        # Project raw node features to the hidden size used for every
        # message-passing step.
        self.project_node_feats = nn.Sequential(
            nn.Linear(node_in_feats, node_out_feats),
            nn.ReLU()
        )
        self.num_step_message_passing = num_step_message_passing
        # Edge network: maps each edge's features to a flat vector of size
        # node_out_feats * node_out_feats. NNConv presumably reshapes this into
        # a per-edge (node_out_feats, node_out_feats) weight matrix — see the
        # DGL NNConv docs.
        edge_network = nn.Sequential(
            nn.Linear(edge_in_feats, edge_hidden_feats),
            nn.ReLU(),
            nn.Linear(edge_hidden_feats, node_out_feats * node_out_feats)
        )
        self.gnn_layer = NNConv(
            in_feats=node_out_feats,
            out_feats=node_out_feats,
            edge_func=edge_network,
            aggregator_type='sum'
        )
        # Single GRU cell reused at every step to fold the new messages into
        # the running node state.
        self.gru = nn.GRU(node_out_feats, node_out_feats)

    def reset_parameters(self):
        """Reinitialize model parameters."""
        self.project_node_feats[0].reset_parameters()
        self.gnn_layer.reset_parameters()
        # The edge network is reset explicitly, layer by layer. This suggests
        # NNConv.reset_parameters does not reinitialize the user-supplied
        # edge_func (TODO confirm against the installed DGL version).
        for layer in self.gnn_layer.edge_func:
            if isinstance(layer, nn.Linear):
                layer.reset_parameters()
        self.gru.reset_parameters()

    def forward(self, g, node_feats, edge_feats):
        """Performs message passing and updates node representations.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs.
        node_feats : float32 tensor of shape (V, node_in_feats)
            Input node features. V for the number of nodes in the batch of graphs.
        edge_feats : float32 tensor of shape (E, edge_in_feats)
            Input edge features. E for the number of edges in the batch of graphs.

        Returns
        -------
        node_feats : float32 tensor of shape (V, node_out_feats)
            Output node representations.
        """
        node_feats = self.project_node_feats(node_feats)  # (V, node_out_feats)
        # nn.GRU uses the (seq_len, batch, feat) layout by default. The V nodes
        # therefore form the batch, and each step is a length-1 sequence.
        hidden_feats = node_feats.unsqueeze(0)  # (1, V, node_out_feats)

        # The same gnn_layer and gru are applied at every step, so weights are
        # shared across message-passing iterations.
        for _ in range(self.num_step_message_passing):
            node_feats = F.relu(self.gnn_layer(g, node_feats, edge_feats))
            node_feats, hidden_feats = self.gru(node_feats.unsqueeze(0), hidden_feats)
            node_feats = node_feats.squeeze(0)

        return node_feats