Source code for dgllife.model.model_zoo.wln_reaction_ranking

# -*- coding: utf-8 -*-
#
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Weisfeiler-Lehman Network (WLN) for ranking candidate products
# pylint: disable= no-member, arguments-differ, invalid-name

import torch
import torch.nn as nn

from dgl.nn.pytorch import SumPooling

from ..gnn.wln import WLN

__all__ = ['WLNReactionRanking']

# pylint: disable=W0221, E1101
class WLNReactionRanking(nn.Module):
    r"""Weisfeiler-Lehman Network (WLN) for Candidate Product Ranking

    The model is introduced in `Predicting Organic Reaction Outcomes with
    Weisfeiler-Lehman Network <https://arxiv.org/abs/1709.04555>`__ and then further
    improved in `A graph-convolutional neural network model for the prediction of
    chemical reactivity
    <https://pubs.rsc.org/en/content/articlelanding/2019/sc/c8sc04228d#!divAbstract>`__.

    The model updates representations of nodes in candidate products with WLN and
    predicts the score for candidate products to be the real product.

    Parameters
    ----------
    node_in_feats : int
        Size for the input node features.
    edge_in_feats : int
        Size for the input edge features.
    node_hidden_feats : int
        Size for the hidden node representations. Default to 500.
    num_encode_gnn_layers : int
        Number of WLN layers for updating node representations. Default to 3.
    """
    def __init__(self,
                 node_in_feats,
                 edge_in_feats,
                 node_hidden_feats=500,
                 num_encode_gnn_layers=3):
        super(WLNReactionRanking, self).__init__()

        self.gnn = WLN(node_in_feats=node_in_feats,
                       edge_in_feats=edge_in_feats,
                       node_out_feats=node_hidden_feats,
                       n_layers=num_encode_gnn_layers,
                       set_comparison=False)
        self.diff_gnn = WLN(node_in_feats=node_hidden_feats,
                            edge_in_feats=edge_in_feats,
                            node_out_feats=node_hidden_feats,
                            n_layers=1,
                            project_in_feats=False,
                            set_comparison=False)
        self.readout = SumPooling()
        self.predict = nn.Sequential(
            nn.Linear(node_hidden_feats, node_hidden_feats),
            nn.ReLU(),
            nn.Linear(node_hidden_feats, 1)
        )
    def forward(self, reactant_graph, reactant_node_feats, reactant_edge_feats,
                product_graphs, product_node_feats, product_edge_feats,
                candidate_scores, batch_num_candidate_products):
        r"""Predicts the score for candidate products to be the true product

        Parameters
        ----------
        reactant_graph : DGLGraph
            DGLGraph for a batch of reactants.
        reactant_node_feats : float32 tensor of shape (V1, node_in_feats)
            Input node features for the reactants. V1 for the number of nodes.
        reactant_edge_feats : float32 tensor of shape (E1, edge_in_feats)
            Input edge features for the reactants. E1 for the number of edges in
            reactant_graph.
        product_graphs : DGLGraph
            DGLGraph for the candidate products in a batch of reactions.
        product_node_feats : float32 tensor of shape (V2, node_in_feats)
            Input node features for the candidate products. V2 for the number of nodes.
        product_edge_feats : float32 tensor of shape (E2, edge_in_feats)
            Input edge features for the candidate products. E2 for the number of edges
            in the graphs for candidate products.
        candidate_scores : float32 tensor of shape (B, 1)
            Scores for candidate products based on the model for reaction center prediction.
        batch_num_candidate_products : list of int
            Number of candidate products for the reactions in the batch.

        Returns
        -------
        float32 tensor of shape (B, 1)
            Predicted scores for candidate products.
        """
        # Update representations for nodes in both reactants and candidate products
        batch_reactant_node_feats = self.gnn(
            reactant_graph, reactant_node_feats, reactant_edge_feats)
        batch_product_node_feats = self.gnn(
            product_graphs, product_node_feats, product_edge_feats)

        # Iterate over the reactions in the batch
        reactant_node_start = 0
        product_graph_start = 0
        product_node_start = 0
        batch_diff_node_feats = []
        for i, num_candidate_products in enumerate(batch_num_candidate_products):
            reactant_node_end = reactant_node_start + reactant_graph.batch_num_nodes()[i]
            product_graph_end = product_graph_start + num_candidate_products
            product_node_end = product_node_start + sum(
                product_graphs.batch_num_nodes()[product_graph_start: product_graph_end])

            # (N, node_out_feats)
            reactant_node_feats = batch_reactant_node_feats[
                reactant_node_start: reactant_node_end, :]
            product_node_feats = batch_product_node_feats[
                product_node_start: product_node_end, :]

            old_feats_shape = reactant_node_feats.shape
            # (1, N, node_out_feats)
            expanded_reactant_node_feats = reactant_node_feats.reshape((1,) + old_feats_shape)
            # (B, N, node_out_feats)
            expanded_reactant_node_feats = expanded_reactant_node_feats.expand(
                (num_candidate_products,) + old_feats_shape)
            # (B, N, node_out_feats)
            candidate_product_node_feats = product_node_feats.reshape(
                (num_candidate_products,) + old_feats_shape)

            # Get the node representation difference between candidate products and reactants
            diff_node_feats = candidate_product_node_feats - expanded_reactant_node_feats
            diff_node_feats = diff_node_feats.reshape(-1, diff_node_feats.shape[-1])
            batch_diff_node_feats.append(diff_node_feats)

            reactant_node_start = reactant_node_end
            product_graph_start = product_graph_end
            product_node_start = product_node_end

        batch_diff_node_feats = torch.cat(batch_diff_node_feats, dim=0)

        # One more GNN layer for message passing with the node representation difference
        diff_node_feats = self.diff_gnn(product_graphs, batch_diff_node_feats,
                                        product_edge_feats)
        candidate_product_feats = self.readout(product_graphs, diff_node_feats)

        return self.predict(candidate_product_feats) + candidate_scores
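

# A minimal, self-contained usage sketch (not part of the library source). The graph
# structures, feature sizes and random features below are made up for illustration;
# in practice the graphs, features and candidate_scores come from dgllife's reaction
# prediction pipeline (candidate_scores from a reaction center prediction model).
if __name__ == '__main__':
    import dgl

    node_in_feats, edge_in_feats = 10, 5

    # One reaction whose reactants form a toy 4-node, 4-edge graph
    reactant_g = dgl.batch([dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0]), num_nodes=4)])
    reactant_node_feats = torch.randn(reactant_g.num_nodes(), node_in_feats)
    reactant_edge_feats = torch.randn(reactant_g.num_edges(), edge_in_feats)

    # Two candidate products; each candidate keeps the same 4 atoms as the reactants
    candidate_1 = dgl.graph(([0, 1, 2], [1, 2, 3]), num_nodes=4)
    candidate_2 = dgl.graph(([0, 2, 3], [1, 3, 0]), num_nodes=4)
    product_gs = dgl.batch([candidate_1, candidate_2])
    product_node_feats = torch.randn(product_gs.num_nodes(), node_in_feats)
    product_edge_feats = torch.randn(product_gs.num_edges(), edge_in_feats)

    # One score per candidate product and the number of candidates per reaction
    candidate_scores = torch.randn(2, 1)
    batch_num_candidate_products = [2]

    model = WLNReactionRanking(node_in_feats=node_in_feats,
                               edge_in_feats=edge_in_feats,
                               node_hidden_feats=500)
    scores = model(reactant_g, reactant_node_feats, reactant_edge_feats,
                   product_gs, product_node_feats, product_edge_feats,
                   candidate_scores, batch_num_candidate_products)
    print(scores.shape)  # torch.Size([2, 1])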