# Source code for dgllife.model.model_zoo.wln_reaction_center

# -*- coding: utf-8 -*-
#
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Weisfeiler-Lehman Network (WLN) for Reaction Center Prediction.
# pylint: disable= no-member, arguments-differ, invalid-name

import dgl.function as fn
import torch
import torch.nn as nn

from ..gnn.wln import WLNLinear, WLN

# Public API of this module: only the reaction center model is exported;
# WLNContext is an internal building block.
__all__ = ['WLNReactionCenter']

# pylint: disable=W0221, E1101
class WLNContext(nn.Module):
    """Attention-based context computation for each node.

    A context vector is computed by taking a weighted sum of node representations,
    with weights computed from an attention module.

    Parameters
    ----------
    node_in_feats : int
        Size for the input node features.
    node_pair_in_feats : int
        Size for the input features of node pairs.
    """
    def __init__(self, node_in_feats, node_pair_in_feats):
        super(WLNContext, self).__init__()

        self.project_feature_sum = WLNLinear(node_in_feats, node_in_feats, bias=False)
        self.project_node_pair_feature = WLNLinear(node_pair_in_feats, node_in_feats)
        # Maps the combined pair representation to a single attention weight;
        # the final Sigmoid squashes it into (0, 1).
        self.compute_attention = nn.Sequential(
            nn.ReLU(),
            WLNLinear(node_in_feats, 1),
            nn.Sigmoid()
        )

    def forward(self, batch_complete_graphs, node_feats, feat_sum, node_pair_feat):
        """Compute context vectors for each node.

        Parameters
        ----------
        batch_complete_graphs : DGLGraph
            A batch of fully connected graphs.
        node_feats : float32 tensor of shape (V, node_in_feats)
            Input node features. V for the number of nodes.
        feat_sum : float32 tensor of shape (E_full, node_in_feats)
            Sum of node_feats between each pair of nodes. E_full for the number of
            edges in the batch of complete graphs.
        node_pair_feat : float32 tensor of shape (E_full, node_pair_in_feats)
            Input features for each pair of nodes. E_full for the number of edges in
            the batch of complete graphs.

        Returns
        -------
        node_contexts : float32 tensor of shape (V, node_in_feats)
            Context vectors for nodes.
        """
        # local_scope() ensures the temporary 'hv' / 'a' / 'context' fields do
        # not leak into the caller's graph.
        with batch_complete_graphs.local_scope():
            batch_complete_graphs.ndata['hv'] = node_feats
            # One attention weight per edge (i.e. per node pair).
            batch_complete_graphs.edata['a'] = self.compute_attention(
                self.project_feature_sum(feat_sum) +
                self.project_node_pair_feature(node_pair_feat)
            )
            # For every destination node, sum the source node features scaled by
            # the attention weight of the connecting edge. ``u_mul_e`` is the
            # current name of the deprecated ``src_mul_edge`` builtin.
            batch_complete_graphs.update_all(
                fn.u_mul_e('hv', 'a', 'm'), fn.sum('m', 'context'))
            node_contexts = batch_complete_graphs.ndata.pop('context')

        return node_contexts

class WLNReactionCenter(nn.Module):
    r"""Weisfeiler-Lehman Network (WLN) for Reaction Center Prediction.

    The model is introduced in `Predicting Organic Reaction Outcomes with
    Weisfeiler-Lehman Network <https://arxiv.org/abs/1709.04555>`__.

    The model uses WLN to update atom representations and then predicts the
    score for each pair of atoms to form a bond.

    Parameters
    ----------
    node_in_feats : int
        Size for the input node features.
    edge_in_feats : int
        Size for the input edge features.
    node_pair_in_feats : int
        Size for the input features of node pairs.
    node_out_feats : int
        Size for the output node representations. Default to 300.
    n_layers : int
        Number of times for message passing. Note that same parameters
        are shared across n_layers message passing. Default to 3.
    n_tasks : int
        Number of tasks for prediction, i.e. the number of scores predicted
        for each pair of nodes. Default to 5.
    """
    def __init__(self,
                 node_in_feats,
                 edge_in_feats,
                 node_pair_in_feats,
                 node_out_feats=300,
                 n_layers=3,
                 n_tasks=5):
        super(WLNReactionCenter, self).__init__()

        self.gnn = WLN(node_in_feats=node_in_feats,
                       edge_in_feats=edge_in_feats,
                       node_out_feats=node_out_feats,
                       n_layers=n_layers)
        self.context_module = WLNContext(node_in_feats=node_out_feats,
                                         node_pair_in_feats=node_pair_in_feats)
        self.project_feature_sum = WLNLinear(node_out_feats, node_out_feats, bias=False)
        self.project_node_pair_feature = WLNLinear(node_pair_in_feats, node_out_feats,
                                                   bias=False)
        self.project_context_sum = WLNLinear(node_out_feats, node_out_feats)
        self.predict = nn.Sequential(
            nn.ReLU(),
            WLNLinear(node_out_feats, n_tasks)
        )

    def forward(self, batch_mol_graphs, batch_complete_graphs,
                node_feats, edge_feats, node_pair_feats):
        r"""Predict score for each pair of nodes.

        Parameters
        ----------
        batch_mol_graphs : DGLGraph
            A batch of molecular graphs.
        batch_complete_graphs : DGLGraph
            A batch of fully connected graphs.
        node_feats : float32 tensor of shape (V, node_in_feats)
            Input node features. V for the number of nodes.
        edge_feats : float32 tensor of shape (E, edge_in_feats)
            Input edge features. E for the number of edges.
        node_pair_feats : float32 tensor of shape (E_full, node_pair_in_feats)
            Input features for each pair of nodes. E_full for the number of edges in
            the batch of complete graphs.

        Returns
        -------
        scores : float32 tensor of shape (E_full, n_tasks)
            Predicted scores for each pair of atoms. With the default
            ``n_tasks=5``, the scores correspond to the following 5 actions
            in reaction:

            * The bond between them gets broken
            * Forming a single bond
            * Forming a double bond
            * Forming a triple bond
            * Forming an aromatic bond
        biased_scores : float32 tensor of shape (E_full, n_tasks)
            Same as ``scores``, except that 1e4 is subtracted from every score
            of a pair formed by a node and itself (a self loop in the complete
            graph), pushing those pairs to very low values.
        """
        node_feats = self.gnn(batch_mol_graphs, node_feats, edge_feats)

        # Compute context vectors for all atoms, which are weighted sum of atom
        # representations in all reactants.
        with batch_complete_graphs.local_scope():
            batch_complete_graphs.ndata['hv'] = node_feats
            batch_complete_graphs.apply_edges(fn.u_add_v('hv', 'hv', 'feature_sum'))
            feat_sum = batch_complete_graphs.edata.pop('feature_sum')
        node_contexts = self.context_module(batch_complete_graphs, node_feats,
                                            feat_sum, node_pair_feats)

        # Predict score
        with batch_complete_graphs.local_scope():
            batch_complete_graphs.ndata['context'] = node_contexts
            batch_complete_graphs.apply_edges(
                fn.u_add_v('context', 'context', 'context_sum'))
            scores = self.predict(
                self.project_feature_sum(feat_sum) +
                self.project_node_pair_feature(node_pair_feats) +
                self.project_context_sum(batch_complete_graphs.edata['context_sum'])
            )

        # Masking self loops.
        # NOTE(review): edge_ids(nodes, nodes) assumes every node in the batch of
        # complete graphs has a self-loop edge — confirm against graph construction.
        nodes = batch_complete_graphs.nodes()
        e_ids = batch_complete_graphs.edge_ids(nodes, nodes)
        # zeros_like matches scores' shape (E_full, n_tasks), dtype and device;
        # a hard-coded width of 5 would break for any other n_tasks.
        bias = torch.zeros_like(scores)
        bias[e_ids.long(), :] = 1e4
        biased_scores = scores - bias

        return scores, biased_scores