# -*- coding: utf-8 -*-
#
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
#
# MPNN
# pylint: disable= no-member, arguments-differ, invalid-name
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch import NNConv
__all__ = ['MPNNGNN']
# pylint: disable=W0221
class MPNNGNN(nn.Module):
    """MPNN.

    MPNN is introduced in `Neural Message Passing for Quantum Chemistry
    <https://arxiv.org/abs/1704.01212>`__.

    This class performs message passing in MPNN and returns the updated node representations.

    Parameters
    ----------
    node_in_feats : int
        Size for the input node features.
    edge_in_feats : int
        Size for the input edge features.
    node_out_feats : int
        Size for the output node representations. Default to 64.
    edge_hidden_feats : int
        Size for the hidden edge representations. Default to 128.
    num_step_message_passing : int
        Number of message passing steps. Default to 6.
    """
    def __init__(self, node_in_feats, edge_in_feats, node_out_feats=64,
                 edge_hidden_feats=128, num_step_message_passing=6):
        super(MPNNGNN, self).__init__()

        self.project_node_feats = nn.Sequential(
            nn.Linear(node_in_feats, node_out_feats),
            nn.ReLU()
        )
        self.num_step_message_passing = num_step_message_passing
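        # The edge network below maps each edge's input features to a flattened
        # node_out_feats x node_out_feats matrix. NNConv reshapes this output per
        # edge and multiplies it with the neighboring node features when computing
        # messages, which is the edge-conditioned convolution used by MPNN.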
        edge_network = nn.Sequential(
            nn.Linear(edge_in_feats, edge_hidden_feats),
            nn.ReLU(),
            nn.Linear(edge_hidden_feats, node_out_feats * node_out_feats)
        )
        self.gnn_layer = NNConv(
            in_feats=node_out_feats,
            out_feats=node_out_feats,
            edge_func=edge_network,
            aggregator_type='sum'
        )
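        # A GRU carries a hidden state across the message passing steps,
        # implementing the recurrent node update of the MPNN paper.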
        self.gru = nn.GRU(node_out_feats, node_out_feats)

    def reset_parameters(self):
        """Reinitialize model parameters."""
        self.project_node_feats[0].reset_parameters()
        self.gnn_layer.reset_parameters()
        for layer in self.gnn_layer.edge_func:
            if isinstance(layer, nn.Linear):
                layer.reset_parameters()
        self.gru.reset_parameters()

    def forward(self, g, node_feats, edge_feats):
        """Performs message passing and updates node representations.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs.
        node_feats : float32 tensor of shape (V, node_in_feats)
            Input node features. V for the number of nodes in the batch of graphs.
        edge_feats : float32 tensor of shape (E, edge_in_feats)
            Input edge features. E for the number of edges in the batch of graphs.

        Returns
        -------
        node_feats : float32 tensor of shape (V, node_out_feats)
            Output node representations.
        """
        node_feats = self.project_node_feats(node_feats)  # (V, node_out_feats)
        hidden_feats = node_feats.unsqueeze(0)            # (1, V, node_out_feats)

        for _ in range(self.num_step_message_passing):
            node_feats = F.relu(self.gnn_layer(g, node_feats, edge_feats))
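            # nn.GRU expects input of shape (seq_len, batch, feat); each step is
            # treated as a length-1 sequence over the batch of nodes, with
            # hidden_feats reused as the recurrent state across steps.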
            node_feats, hidden_feats = self.gru(node_feats.unsqueeze(0), hidden_feats)
            node_feats = node_feats.squeeze(0)

        return node_feats
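
if __name__ == '__main__':
    # Minimal usage sketch (not part of the library API): build a toy DGLGraph
    # with random node/edge features and run one forward pass. The feature sizes
    # below are illustrative placeholders, not values prescribed by the module.
    import dgl
    import torch

    g = dgl.graph(([0, 1, 1, 2], [1, 0, 2, 1]))   # 3 nodes, 4 directed edges
    node_feats = torch.randn(g.num_nodes(), 15)   # (V, node_in_feats)
    edge_feats = torch.randn(g.num_edges(), 5)    # (E, edge_in_feats)

    gnn = MPNNGNN(node_in_feats=15, edge_in_feats=5,
                  node_out_feats=64, edge_hidden_feats=128,
                  num_step_message_passing=6)
    out = gnn(g, node_feats, edge_feats)
    print(out.shape)  # torch.Size([3, 64])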