# -*- coding: utf-8 -*-
#
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Readout for AttentiveFP
# pylint: disable= no-member, arguments-differ, invalid-name
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = ['AttentiveFPReadout']


# pylint: disable=W0221
class GlobalPool(nn.Module):
"""One-step readout in AttentiveFP
Parameters
----------
feat_size : int
Size for the input node features, graph features and output graph
representations.
dropout : float
The probability for performing dropout.
"""
    def __init__(self, feat_size, dropout):
        super(GlobalPool, self).__init__()

        # Attention logits computed from the concatenation of the broadcast
        # graph feature and each node feature
        self.compute_logits = nn.Sequential(
            nn.Linear(2 * feat_size, 1),
            nn.LeakyReLU()
        )
        # Projection of node features before the attention-weighted sum
        self.project_nodes = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(feat_size, feat_size)
        )
        # GRU cell for updating the graph feature with the pooled context
        self.gru = nn.GRUCell(feat_size, feat_size)

    def forward(self, g, node_feats, g_feats, get_node_weight=False):
"""Perform one-step readout
Parameters
----------
g : DGLGraph
DGLGraph for a batch of graphs.
node_feats : float32 tensor of shape (V, node_feat_size)
Input node features. V for the number of nodes.
g_feats : float32 tensor of shape (G, graph_feat_size)
Input graph features. G for the number of graphs.
get_node_weight : bool
Whether to get the weights of atoms during readout.
Returns
-------
float32 tensor of shape (G, graph_feat_size)
Updated graph features.
float32 tensor of shape (V, 1)
The weights of nodes in readout.
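
        Examples
        --------
        A minimal illustrative run with random graphs and features; the
        sizes below are arbitrary choices, and any batched ``DGLGraph``
        with matching feature sizes works the same way.

        >>> import dgl
        >>> import torch
        >>> g = dgl.batch([dgl.rand_graph(3, 4), dgl.rand_graph(4, 6)])
        >>> pool = GlobalPool(feat_size=16, dropout=0.)
        >>> node_feats = torch.randn(g.num_nodes(), 16)
        >>> g_feats = torch.randn(g.batch_size, 16)
        >>> pool(g, node_feats, g_feats).shape
        torch.Size([2, 16])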
"""
        with g.local_scope():
            # Compute an attention logit per node from the broadcast graph
            # context and the node's own features
            g.ndata['z'] = self.compute_logits(
                torch.cat([dgl.broadcast_nodes(g, F.relu(g_feats)), node_feats], dim=1))
            # Normalize the logits into attention weights within each graph
            g.ndata['a'] = dgl.softmax_nodes(g, 'z')
            g.ndata['hv'] = self.project_nodes(node_feats)
            # Attention-weighted sum of the projected node features per graph
            g_repr = dgl.sum_nodes(g, 'hv', 'a')
            context = F.elu(g_repr)

            if get_node_weight:
                return self.gru(context, g_feats), g.ndata['a']
            else:
                return self.gru(context, g_feats)


class AttentiveFPReadout(nn.Module):
"""Readout in AttentiveFP
AttentiveFP is introduced in `Pushing the Boundaries of Molecular Representation for
Drug Discovery with the Graph Attention Mechanism
<https://www.ncbi.nlm.nih.gov/pubmed/31408336>`__
This class computes graph representations out of node features.
Parameters
----------
feat_size : int
Size for the input node features, graph features and output graph
representations.
num_timesteps : int
Times of updating the graph representations with GRU. Default to 2.
dropout : float
The probability for performing dropout. Default to 0.
"""
    def __init__(self, feat_size, num_timesteps=2, dropout=0.):
        super(AttentiveFPReadout, self).__init__()

        # One GlobalPool readout step per timestep
        self.readouts = nn.ModuleList()
        for _ in range(num_timesteps):
            self.readouts.append(GlobalPool(feat_size, dropout))

    def forward(self, g, node_feats, get_node_weight=False):
"""Computes graph representations out of node features.
Parameters
----------
g : DGLGraph
DGLGraph for a batch of graphs.
node_feats : float32 tensor of shape (V, node_feat_size)
Input node features. V for the number of nodes.
get_node_weight : bool
Whether to get the weights of nodes in readout. Default to False.
Returns
-------
g_feats : float32 tensor of shape (G, graph_feat_size)
Graph representations computed. G for the number of graphs.
node_weights : list of float32 tensor of shape (V, 1), optional
This is returned when ``get_node_weight`` is ``True``.
The list has a length ``num_timesteps`` and ``node_weights[i]``
gives the node weights in the i-th update.
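
        Examples
        --------
        A minimal illustrative run with random graphs and features (the
        graph and feature sizes below are arbitrary choices for the example):

        >>> import dgl
        >>> import torch
        >>> g = dgl.batch([dgl.rand_graph(3, 4), dgl.rand_graph(4, 6)])
        >>> node_feats = torch.randn(g.num_nodes(), 16)
        >>> readout = AttentiveFPReadout(feat_size=16, num_timesteps=2)
        >>> readout(g, node_feats).shape
        torch.Size([2, 16])
        >>> g_feats, node_weights = readout(g, node_feats, get_node_weight=True)
        >>> len(node_weights), node_weights[0].shape
        (2, torch.Size([7, 1]))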
"""
        with g.local_scope():
            # Initialize the graph feature as the sum of node features
            g.ndata['hv'] = node_feats
            g_feats = dgl.sum_nodes(g, 'hv')

            if get_node_weight:
                node_weights = []

            # Refine the graph feature over multiple readout timesteps
            for readout in self.readouts:
                if get_node_weight:
                    g_feats, node_weights_t = readout(g, node_feats, g_feats, get_node_weight)
                    node_weights.append(node_weights_t)
                else:
                    g_feats = readout(g, node_feats, g_feats)

            if get_node_weight:
                return g_feats, node_weights
            else:
                return g_feats
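

if __name__ == '__main__':
    # A minimal smoke-test sketch with random inputs. The graph sizes and
    # feature size below are arbitrary choices for illustration; any batched
    # DGLGraph with matching feature sizes would work the same way.
    g = dgl.batch([dgl.rand_graph(3, 4), dgl.rand_graph(4, 6)])
    node_feats = torch.randn(g.num_nodes(), 16)
    readout = AttentiveFPReadout(feat_size=16, num_timesteps=2)
    g_feats, node_weights = readout(g, node_feats, get_node_weight=True)
    assert g_feats.shape == (2, 16)
    assert len(node_weights) == 2 and node_weights[0].shape == (7, 1)

    # softmax_nodes normalizes the attention logits within each graph, so
    # the node weights from each timestep should sum to 1 per graph.
    with g.local_scope():
        g.ndata['w'] = node_weights[0]
        per_graph_sum = dgl.sum_nodes(g, 'w')
    assert torch.allclose(per_graph_sum, torch.ones(2, 1))
    print('AttentiveFPReadout smoke test passed:', tuple(g_feats.shape))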