# Source code for dgllife.model.readout.attentivefp_readout

# -*- coding: utf-8 -*-
#
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Readout for AttentiveFP
# pylint: disable= no-member, arguments-differ, invalid-name

import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = ['AttentiveFPReadout']

# pylint: disable=W0221
class GlobalPool(nn.Module):
    """Single attention-weighted readout step used by AttentiveFP.

    Every node is scored against the current representation of the graph it
    belongs to. The scores are normalised per graph with
    ``dgl.softmax_nodes``, and the projected node features are pooled with
    those weights. The pooled (ELU-activated) context is then fed through a
    GRU cell to refresh the graph representation.

    Parameters
    ----------
    feat_size : int
        Dimension shared by the input node features, the incoming graph
        features and the updated graph features.
    dropout : float
        Dropout probability applied to node features before they are
        projected.
    """
    def __init__(self, feat_size, dropout):
        super().__init__()

        # The attribute names and the position of each layer inside the
        # Sequential containers determine the state_dict keys
        # (e.g. ``compute_logits.0.weight``); keep them stable so that saved
        # checkpoints keep loading.
        self.compute_logits = nn.Sequential(
            nn.Linear(2 * feat_size, 1),
            nn.LeakyReLU(),
        )
        self.project_nodes = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(feat_size, feat_size),
        )
        self.gru = nn.GRUCell(feat_size, feat_size)

    def forward(self, g, node_feats, g_feats, get_node_weight=False):
        """Run one readout step and return the refreshed graph features.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs.
        node_feats : float32 tensor of shape (V, feat_size)
            Input node features. V for the number of nodes.
        g_feats : float32 tensor of shape (G, feat_size)
            Current graph features. G for the number of graphs.
        get_node_weight : bool
            If True, also return the per-node attention weights.

        Returns
        -------
        float32 tensor of shape (G, feat_size)
            Updated graph features.
        float32 tensor of shape (V, 1)
            Node weights used for pooling; only returned when
            ``get_node_weight`` is True.
        """
        with g.local_scope():
            # Give each node a copy of its graph's (ReLU-activated) feature so
            # node and graph can be scored jointly.
            graph_per_node = dgl.broadcast_nodes(g, F.relu(g_feats))
            joint = torch.cat([graph_per_node, node_feats], dim=1)
            g.ndata['z'] = self.compute_logits(joint)

            # Normalise the one-column scores within each graph of the batch.
            g.ndata['a'] = dgl.softmax_nodes(g, 'z')
            g.ndata['hv'] = self.project_nodes(node_feats)

            # Weighted sum of projected node features per graph.
            pooled = dgl.sum_nodes(g, 'hv', 'a')
            refreshed = self.gru(F.elu(pooled), g_feats)

            if not get_node_weight:
                return refreshed
            return refreshed, g.ndata['a']

class AttentiveFPReadout(nn.Module):
    """Readout in AttentiveFP

    AttentiveFP is introduced in `Pushing the Boundaries of Molecular Representation
    for Drug Discovery with the Graph Attention Mechanism
    <https://www.ncbi.nlm.nih.gov/pubmed/31408336>`__

    This class computes graph representations out of node features.

    Parameters
    ----------
    feat_size : int
        Size for the input node features, graph features and output graph
        representations.
    num_timesteps : int
        Number of times the graph representations are updated with a GRU.
        Defaults to 2.
    dropout : float
        The probability for performing dropout. Defaults to 0.
    """
    def __init__(self, feat_size, num_timesteps=2, dropout=0.):
        super(AttentiveFPReadout, self).__init__()

        # One independent GlobalPool per update step. The ``readouts``
        # attribute name fixes the state_dict keys (``readouts.0.*``,
        # ``readouts.1.*``, ...), so it must stay unchanged for checkpoints.
        self.readouts = nn.ModuleList()
        for _ in range(num_timesteps):
            self.readouts.append(GlobalPool(feat_size, dropout))

    def forward(self, g, node_feats, get_node_weight=False):
        """Computes graph representations out of node features.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs.
        node_feats : float32 tensor of shape (V, node_feat_size)
            Input node features. V for the number of nodes.
        get_node_weight : bool
            Whether to get the weights of nodes in readout. Defaults to False.

        Returns
        -------
        g_feats : float32 tensor of shape (G, graph_feat_size)
            Graph representations computed. G for the number of graphs.
            When ``num_timesteps`` is 0 this is the plain per-graph sum of
            the node features.
        node_weights : list of float32 tensor of shape (V, 1), optional
            This is returned when ``get_node_weight`` is ``True``. The list
            has a length ``num_timesteps`` and ``node_weights[i]`` gives the
            node weights in the i-th update.
        """
        # Initial graph representation: sum of node features per graph.
        # local_scope keeps the temporary 'hv' field off the caller's graph.
        with g.local_scope():
            g.ndata['hv'] = node_feats
            g_feats = dgl.sum_nodes(g, 'hv')

        node_weights = []
        for readout in self.readouts:
            if get_node_weight:
                g_feats, node_weights_t = readout(g, node_feats, g_feats, get_node_weight)
                node_weights.append(node_weights_t)
            else:
                g_feats = readout(g, node_feats, g_feats)

        if get_node_weight:
            return g_feats, node_weights
        return g_feats