# -*- coding: utf-8 -*-
#
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Weave
# pylint: disable= no-member, arguments-differ, invalid-name
import dgl.function as fn
import torch
import torch.nn as nn
import torch.nn.functional as F
__all__ = ['WeaveGNN']
# pylint: disable=W0221, E1101
class WeaveLayer(nn.Module):
r"""Single Weave layer from `Molecular Graph Convolutions: Moving Beyond Fingerprints
<https://arxiv.org/abs/1603.00856>`__
Parameters
----------
node_in_feats : int
Size for the input node features.
edge_in_feats : int
Size for the input edge features.
node_node_hidden_feats : int
Size for the hidden node representations in updating node representations.
Default to 50.
edge_node_hidden_feats : int
Size for the hidden edge representations in updating node representations.
Default to 50.
node_out_feats : int
Size for the output node representations. Default to 50.
node_edge_hidden_feats : int
Size for the hidden node representations in updating edge representations.
Default to 50.
edge_edge_hidden_feats : int
Size for the hidden edge representations in updating edge representations.
Default to 50.
edge_out_feats : int
Size for the output edge representations. Default to 50.
activation : callable
Activation function to apply. Default to ReLU.
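
    Examples
    --------
    A minimal illustrative sketch of running a single layer on a toy graph; the
    graph and the input sizes 4 and 5 below are arbitrary, and ``dgl.graph`` is
    assumed to be available in the installed DGL version.

    >>> import dgl
    >>> import torch
    >>> g = dgl.graph(([0, 1, 2], [1, 2, 0]))  # 3 nodes, 3 directed edges
    >>> node_feats = torch.randn(3, 4)
    >>> edge_feats = torch.randn(3, 5)
    >>> layer = WeaveLayer(node_in_feats=4, edge_in_feats=5)
    >>> new_node_feats, new_edge_feats = layer(g, node_feats, edge_feats)
    >>> new_node_feats.shape, new_edge_feats.shape
    (torch.Size([3, 50]), torch.Size([3, 50]))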
"""
def __init__(self,
node_in_feats,
edge_in_feats,
node_node_hidden_feats=50,
edge_node_hidden_feats=50,
node_out_feats=50,
node_edge_hidden_feats=50,
edge_edge_hidden_feats=50,
edge_out_feats=50,
activation=F.relu):
super(WeaveLayer, self).__init__()
self.activation = activation
# Layers for updating node representations
self.node_to_node = nn.Linear(node_in_feats, node_node_hidden_feats)
self.edge_to_node = nn.Linear(edge_in_feats, edge_node_hidden_feats)
self.update_node = nn.Linear(
node_node_hidden_feats + edge_node_hidden_feats, node_out_feats)
# Layers for updating edge representations
self.left_node_to_edge = nn.Linear(node_in_feats, node_edge_hidden_feats)
self.right_node_to_edge = nn.Linear(node_in_feats, node_edge_hidden_feats)
self.edge_to_edge = nn.Linear(edge_in_feats, edge_edge_hidden_feats)
self.update_edge = nn.Linear(
2 * node_edge_hidden_feats + edge_edge_hidden_feats, edge_out_feats)
def reset_parameters(self):
"""Reinitialize model parameters."""
self.node_to_node.reset_parameters()
self.edge_to_node.reset_parameters()
self.update_node.reset_parameters()
self.left_node_to_edge.reset_parameters()
self.right_node_to_edge.reset_parameters()
self.edge_to_edge.reset_parameters()
self.update_edge.reset_parameters()
def forward(self, g, node_feats, edge_feats, node_only=False):
r"""Update node and edge representations.
Parameters
----------
g : DGLGraph
DGLGraph for a batch of graphs
node_feats : float32 tensor of shape (V, node_in_feats)
Input node features. V for the number of nodes in the batch of graphs.
edge_feats : float32 tensor of shape (E, edge_in_feats)
Input edge features. E for the number of edges in the batch of graphs.
node_only : bool
Whether to update node representations only. If False, edge representations
will be updated as well. Default to False.
Returns
-------
new_node_feats : float32 tensor of shape (V, node_out_feats)
Updated node representations.
        new_edge_feats : float32 tensor of shape (E, edge_out_feats)
            Updated edge representations, returned only when ``node_only`` is False.
"""
g = g.local_var()
# Update node features
node_node_feats = self.activation(self.node_to_node(node_feats))
g.edata['e2n'] = self.activation(self.edge_to_node(edge_feats))
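        # Sum the transformed features of incoming edges into each node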
g.update_all(fn.copy_edge('e2n', 'm'), fn.sum('m', 'e2n'))
edge_node_feats = g.ndata.pop('e2n')
new_node_feats = self.activation(self.update_node(
torch.cat([node_node_feats, edge_node_feats], dim=1)))
if node_only:
return new_node_feats
# Update edge features
g.ndata['left_hv'] = self.left_node_to_edge(node_feats)
g.ndata['right_hv'] = self.right_node_to_edge(node_feats)
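        # For each edge (u, v), combine the projected endpoint features in both
        # orders: 'first' is left_hv[u] + right_hv[v], 'second' is right_hv[u] + left_hv[v]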
g.apply_edges(fn.u_add_v('left_hv', 'right_hv', 'first'))
g.apply_edges(fn.u_add_v('right_hv', 'left_hv', 'second'))
first_edge_feats = self.activation(g.edata.pop('first'))
second_edge_feats = self.activation(g.edata.pop('second'))
third_edge_feats = self.activation(self.edge_to_edge(edge_feats))
new_edge_feats = self.activation(self.update_edge(
torch.cat([first_edge_feats, second_edge_feats, third_edge_feats], dim=1)))
return new_node_feats, new_edge_feats
class WeaveGNN(nn.Module):
r"""The component of Weave for updating node and edge representations.
Weave is introduced in `Molecular Graph Convolutions: Moving Beyond Fingerprints
<https://arxiv.org/abs/1603.00856>`__.
Parameters
----------
node_in_feats : int
Size for the input node features.
edge_in_feats : int
Size for the input edge features.
    num_layers : int
        Number of Weave layers to use, which equals the number of rounds of message passing.
        Default to 2.
hidden_feats : int
Size for the hidden node and edge representations. Default to 50.
activation : callable
Activation function to be used. It cannot be None. Default to ReLU.
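
    Examples
    --------
    A minimal illustrative sketch on a toy graph; the graph and the input sizes
    below are arbitrary. With ``node_only=True`` (the default) only the updated
    node representations are returned.

    >>> import dgl
    >>> import torch
    >>> g = dgl.graph(([0, 1, 2], [1, 2, 0]))
    >>> node_feats = torch.randn(3, 4)
    >>> edge_feats = torch.randn(3, 5)
    >>> model = WeaveGNN(node_in_feats=4, edge_in_feats=5, num_layers=2, hidden_feats=50)
    >>> model(g, node_feats, edge_feats).shape
    torch.Size([3, 50])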
"""
def __init__(self,
node_in_feats,
edge_in_feats,
num_layers=2,
hidden_feats=50,
activation=F.relu):
super(WeaveGNN, self).__init__()
self.gnn_layers = nn.ModuleList()
for i in range(num_layers):
if i == 0:
self.gnn_layers.append(WeaveLayer(node_in_feats=node_in_feats,
edge_in_feats=edge_in_feats,
node_node_hidden_feats=hidden_feats,
edge_node_hidden_feats=hidden_feats,
node_out_feats=hidden_feats,
node_edge_hidden_feats=hidden_feats,
edge_edge_hidden_feats=hidden_feats,
edge_out_feats=hidden_feats,
activation=activation))
else:
self.gnn_layers.append(WeaveLayer(node_in_feats=hidden_feats,
edge_in_feats=hidden_feats,
node_node_hidden_feats=hidden_feats,
edge_node_hidden_feats=hidden_feats,
node_out_feats=hidden_feats,
node_edge_hidden_feats=hidden_feats,
edge_edge_hidden_feats=hidden_feats,
edge_out_feats=hidden_feats,
activation=activation))
    def reset_parameters(self):
"""Reinitialize model parameters."""
for layer in self.gnn_layers:
layer.reset_parameters()
    def forward(self, g, node_feats, edge_feats, node_only=True):
"""Updates node representations (and edge representations).
Parameters
----------
g : DGLGraph
DGLGraph for a batch of graphs.
node_feats : float32 tensor of shape (V, node_in_feats)
Input node features. V for the number of nodes in the batch of graphs.
edge_feats : float32 tensor of shape (E, edge_in_feats)
Input edge features. E for the number of edges in the batch of graphs.
node_only : bool
Whether to return updated node representations only or to return both
node and edge representations. Default to True.
Returns
-------
        float32 tensor of shape (V, hidden_feats)
            Updated node representations.
        float32 tensor of shape (E, hidden_feats), optional
            Updated edge representations, returned only when ``node_only`` is False.
"""
for i in range(len(self.gnn_layers) - 1):
node_feats, edge_feats = self.gnn_layers[i](g, node_feats, edge_feats)
return self.gnn_layers[-1](g, node_feats, edge_feats, node_only)