Readout for Computing Graph Representations¶
After updating node/edge representations with graph neural networks (GNNs), a common operation is to compute graph representations out of the updated node/edge representations. For example, we need to compute molecular representations out of atom/bond representations in molecular property prediction. Following Neural Message Passing for Quantum Chemistry, we call the various modules for computing graph-level representations readout. This section lists the readout modules implemented in DGL-LifeSci.
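As a minimal sketch of the idea (the toy graphs, the 16-dimensional random features, and the use of DGL's built-in dgl.sum_nodes are illustrative assumptions, not part of DGL-LifeSci's API), a sum readout collapses the node features of each graph in a batch into one vector per graph:

    import dgl
    import torch

    # Two toy graphs with 3 and 2 nodes, batched into a single DGLGraph
    g1 = dgl.graph(([0, 1], [1, 2]), num_nodes=3)
    g2 = dgl.graph(([0], [1]), num_nodes=2)
    bg = dgl.batch([g1, g2])

    # Stand-in for node representations produced by a GNN (5 nodes, 16 features)
    bg.ndata['h'] = torch.randn(bg.num_nodes(), 16)

    # Sum readout: one vector per graph -> shape (2, 16)
    graph_feats = dgl.sum_nodes(bg, 'h')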
AttentiveFP Readout¶
- class dgllife.model.readout.attentivefp_readout.AttentiveFPReadout(feat_size, num_timesteps=2, dropout=0.0)¶
Readout in AttentiveFP.
AttentiveFP is introduced in Pushing the Boundaries of Molecular Representation for Drug Discovery with the Graph Attention Mechanism.
This class computes graph representations out of node features.
- Parameters
feat_size (int) – Size for the input node features, graph features and intermediate graph representations.
num_timesteps (int) – Number of times to update the graph representations with a GRU. Default to 2.
dropout (float) – The probability for performing dropout. Default to 0.
- forward(g, node_feats, get_node_weight=False)¶
Computes graph representations out of node features.
- Parameters
g (DGLGraph) – DGLGraph for a batch of graphs.
node_feats (float32 tensor of shape (V, node_feat_size)) – Input node features. V for the number of nodes.
get_node_weight (bool) – Whether to get the weights of nodes in readout. Default to False.
- Returns
g_feats (float32 tensor of shape (G, graph_feat_size)) – Graph representations computed. G for the number of graphs.
node_weights (list of float32 tensor of shape (V, 1), optional) – This is returned when get_node_weight is True. The list has a length num_timesteps and node_weights[i] gives the node weights in the i-th update.
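A minimal usage sketch for AttentiveFPReadout (the toy batched graph, the feature size of 64 and the dropout of 0.1 below are arbitrary choices for illustration):

    import dgl
    import torch
    from dgllife.model.readout.attentivefp_readout import AttentiveFPReadout

    # A batch of two small graphs with 64-dimensional node features
    bg = dgl.batch([
        dgl.graph(([0, 1], [1, 2]), num_nodes=3),
        dgl.graph(([0], [1]), num_nodes=2),
    ])
    node_feats = torch.randn(bg.num_nodes(), 64)

    readout = AttentiveFPReadout(feat_size=64, num_timesteps=2, dropout=0.1)
    # Graph representations, one row per graph in the batch
    g_feats = readout(bg, node_feats)
    # Also retrieve per-node attention weights, one (V, 1) tensor per timestep
    g_feats, node_weights = readout(bg, node_feats, get_node_weight=True)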
MLP Readout¶
- class dgllife.model.readout.mlp_readout.MLPNodeReadout(node_feats, hidden_feats, graph_feats, activation=None, mode='sum')¶
MLP-based Readout.
This layer updates node representations with an MLP and computes graph representations out of node representations with max, mean or sum pooling.
- Parameters
node_feats (int) – Size for the input node features.
hidden_feats (int) – Size for the hidden representations.
graph_feats (int) – Size for the output graph representations.
activation (callable) – Activation function. Default to None.
mode ('max' or 'mean' or 'sum') – The pooling operation applied to the node representations: elementwise maximum, mean or sum. Default to 'sum'.
- forward(g, node_feats)¶
Computes graph representations out of node features.
- Parameters
g (DGLGraph) – DGLGraph for a batch of graphs.
node_feats (float32 tensor of shape (V, node_feats)) – Input node features, V for the number of nodes.
- Returns
graph_feats – Graph representations computed. G for the number of graphs.
- Return type
float32 tensor of shape (G, graph_feats)
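A minimal sketch for MLPNodeReadout (the toy batched graph, the 64-dimensional input features, the 128-dimensional graph features and the ReLU activation are illustrative assumptions):

    import dgl
    import torch
    import torch.nn.functional as F
    from dgllife.model.readout.mlp_readout import MLPNodeReadout

    bg = dgl.batch([
        dgl.graph(([0, 1], [1, 2]), num_nodes=3),
        dgl.graph(([0], [1]), num_nodes=2),
    ])
    node_feats = torch.randn(bg.num_nodes(), 64)

    # Project node features 64 -> 128 with an MLP, then sum them per graph
    readout = MLPNodeReadout(node_feats=64, hidden_feats=64, graph_feats=128,
                             activation=F.relu, mode='sum')
    graph_feats = readout(bg, node_feats)   # shape (2, 128)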
Weighted Sum and Max Readout¶
- class dgllife.model.readout.weighted_sum_and_max.WeightedSumAndMax(in_feats)¶
Apply weighted sum and max pooling to the node representations and concatenate the results.
- Parameters
in_feats (int) – Input node feature size
- forward(bg, feats)¶
Computes graph representations out of node representations.
- Parameters
bg (DGLGraph) – DGLGraph for a batch of graphs.
feats (FloatTensor of shape (N, M1)) –
N is the total number of nodes in the batch of graphs
M1 is the input node feature size, which must match in_feats in initialization
- Returns
h_g – Computed graph representations.
B is the number of graphs in the batch
M1 is the input node feature size, which must match in_feats in initialization
- Return type
FloatTensor of shape (B, 2 * M1)
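A minimal sketch for WeightedSumAndMax (the toy batched graph and the 64-dimensional node features are illustrative assumptions; note that the output size is twice the input feature size):

    import dgl
    import torch
    from dgllife.model.readout.weighted_sum_and_max import WeightedSumAndMax

    bg = dgl.batch([
        dgl.graph(([0, 1], [1, 2]), num_nodes=3),
        dgl.graph(([0], [1]), num_nodes=2),
    ])
    feats = torch.randn(bg.num_nodes(), 64)   # M1 = 64 must match in_feats

    readout = WeightedSumAndMax(in_feats=64)
    h_g = readout(bg, feats)                  # shape (2, 2 * 64)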
Weave Readout¶
- class dgllife.model.readout.weave_readout.WeaveGather(node_in_feats, gaussian_expand=True, gaussian_memberships=None, activation=Tanh())¶
Readout in Weave.
- Parameters
node_in_feats (int) – Size for the input node features.
gaussian_expand (bool) – Whether to expand each dimension of node features by gaussian histogram. Default to True.
gaussian_memberships (list of 2-tuples) – For each tuple, the first and second element separately specify the mean and std for constructing a normal distribution. This argument comes into effect only when gaussian_expand==True. By default, we set this to be a list consisting of (-1.645, 0.283), (-1.080, 0.170), (-0.739, 0.134), (-0.468, 0.118), (-0.228, 0.114), (0., 0.114), (0.228, 0.114), (0.468, 0.118), (0.739, 0.134), (1.080, 0.170), (1.645, 0.283).
activation (callable) – Activation function to apply. Default to tanh.
- forward(g, node_feats)¶
Computes graph representations out of node representations.
- Parameters
g (DGLGraph) – DGLGraph for a batch of graphs.
node_feats (float32 tensor of shape (V, node_in_feats)) – Input node features. V for the number of nodes in the batch of graphs.
- Returns
g_feats – Output graph representations. G for the number of graphs in the batch.
- Return type
float32 tensor of shape (G, node_in_feats)
- gaussian_histogram(node_feats)¶
Constructs a gaussian histogram to capture the distribution of features.
- Parameters
node_feats (float32 tensor of shape (V, node_in_feats)) – Input node features. V for the number of nodes in the batch of graphs.
- Returns
Updated node representations
- Return type
float32 tensor of shape (V, node_in_feats * len(self.means))
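A minimal sketch for WeaveGather (the toy batched graph and the 64-dimensional node features are illustrative assumptions):

    import dgl
    import torch
    from dgllife.model.readout.weave_readout import WeaveGather

    bg = dgl.batch([
        dgl.graph(([0, 1], [1, 2]), num_nodes=3),
        dgl.graph(([0], [1]), num_nodes=2),
    ])
    node_feats = torch.randn(bg.num_nodes(), 64)

    # With gaussian_expand=True, each feature dimension is first expanded into
    # membership values over the default normal distributions before readout
    readout = WeaveGather(node_in_feats=64, gaussian_expand=True)
    g_feats = readout(bg, node_feats)   # shape (2, 64)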