# -*- coding: utf-8 -*-
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
# pylint: disable= no-member, arguments-differ, invalid-name
import torch.nn as nn
from ..gnn import AttentiveFPGNN
from ..readout import AttentiveFPReadout
__all__ = ['AttentiveFPPredictor']
# pylint: disable=W0221
"""AttentiveFP for regression and classification on graphs.
AttentiveFP is introduced in
`Pushing the Boundaries of Molecular Representation for Drug Discovery with the Graph
Attention Mechanism. <https://www.ncbi.nlm.nih.gov/pubmed/31408336>`__
Parameters
----------
node_feat_size : int
Size of the input node features.
edge_feat_size : int
Size of the input edge features.
num_layers : int
Number of GNN layers. Defaults to 2.
num_timesteps : int
Number of times the graph representations are updated with a GRU during readout. Defaults to 2.
graph_feat_size : int
Size of the learned graph representations. Defaults to 200.
n_tasks : int
Number of tasks, which is also the output size. Defaults to 1.
dropout : float
Dropout probability. Defaults to 0.
self.gnn = AttentiveFPGNN(node_feat_size=node_feat_size,
self.readout = AttentiveFPReadout(feat_size=graph_feat_size,
self.predict = nn.Sequential(
[docs] def forward(self, g, node_feats, edge_feats, get_node_weight=False):
"""Graph-level regression/soft classification.
Parameters
----------
g : DGLGraph
A DGLGraph holding a batch of graphs.
node_feats : float32 tensor of shape (V, node_feat_size)
Input node features. V for the number of nodes.
edge_feats : float32 tensor of shape (E, edge_feat_size)
Input edge features. E for the number of edges.
get_node_weight : bool
Whether to also return the node (atom) weights computed during readout. Defaults to False.
Returns
-------
float32 tensor of shape (G, n_tasks)
Predictions for the graphs in the batch. G is the number of graphs.
node_weights : list of float32 tensors of shape (V, 1), optional
This is returned when ``get_node_weight`` is ``True``.
The list has length ``num_timesteps``, and ``node_weights[i]``
gives the node weights from the i-th update.
node_feats = self.gnn(g, node_feats, edge_feats)
g_feats, node_weights = self.readout(g, node_feats, get_node_weight)
return self.predict(g_feats), node_weights
g_feats = self.readout(g, node_feats, get_node_weight)