# Source code for dgllife.model.model_zoo.attentivefp_predictor

# -*- coding: utf-8 -*-
# Copyright, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
# AttentiveFP
# pylint: disable= no-member, arguments-differ, invalid-name

import torch.nn as nn

from ..gnn import AttentiveFPGNN
from ..readout import AttentiveFPReadout

__all__ = ['AttentiveFPPredictor']

# pylint: disable=W0221
class AttentiveFPPredictor(nn.Module):
    """AttentiveFP for regression and classification on graphs.

    AttentiveFP is introduced in the paper "Pushing the Boundaries of Molecular
    Representation for Drug Discovery with the Graph Attention Mechanism".

    The model is composed of three stages:

    1. ``self.gnn`` (an ``AttentiveFPGNN``) updates node representations
       from the input node and edge features.
    2. ``self.readout`` (an ``AttentiveFPReadout``) pools the node
       representations into one graph representation per graph.
    3. ``self.predict`` applies dropout followed by a linear layer mapping
       the ``graph_feat_size``-dimensional graph representation to
       ``n_tasks`` outputs.

    Parameters
    ----------
    node_feat_size : int
        Size for the input node features.
    edge_feat_size : int
        Size for the input edge features.
    num_layers : int
        Number of GNN layers. Defaults to 2.
    num_timesteps : int
        Times of updating the graph representations with GRU. Defaults to 2.
    graph_feat_size : int
        Size for the learned graph representations. Defaults to 200.
    n_tasks : int
        Number of tasks, which is also the output size. Defaults to 1.
    dropout : float
        Probability for performing the dropout. Defaults to 0.
    """
    def __init__(self,
                 node_feat_size,
                 edge_feat_size,
                 num_layers=2,
                 num_timesteps=2,
                 graph_feat_size=200,
                 n_tasks=1,
                 dropout=0.):
        super(AttentiveFPPredictor, self).__init__()

        # Node-level message passing; produces graph_feat_size-dim node features.
        self.gnn = AttentiveFPGNN(node_feat_size=node_feat_size,
                                  edge_feat_size=edge_feat_size,
                                  num_layers=num_layers,
                                  graph_feat_size=graph_feat_size,
                                  dropout=dropout)
        # Graph-level readout over the GNN output, run for num_timesteps updates.
        self.readout = AttentiveFPReadout(feat_size=graph_feat_size,
                                          num_timesteps=num_timesteps,
                                          dropout=dropout)
        # Prediction head: dropout, then a linear map to one output per task.
        self.predict = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(graph_feat_size, n_tasks)
        )

    def forward(self, g, node_feats, edge_feats, get_node_weight=False):
        """Graph-level regression/soft classification.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs.
        node_feats : float32 tensor of shape (V, node_feat_size)
            Input node features. V for the number of nodes.
        edge_feats : float32 tensor of shape (E, edge_feat_size)
            Input edge features. E for the number of edges.
        get_node_weight : bool
            Whether to get the weights of atoms during readout. Defaults to False.

        Returns
        -------
        float32 tensor of shape (G, n_tasks)
            Prediction for the graphs in the batch. G for the number of graphs.
        node_weights : list of float32 tensor of shape (V, 1), optional
            This is returned only when ``get_node_weight`` is ``True``, in which
            case the method returns the tuple ``(predictions, node_weights)``.
            The list has a length ``num_timesteps`` and ``node_weights[i]``
            gives the node weights in the i-th update.
        """
        node_feats = self.gnn(g, node_feats, edge_feats)

        if get_node_weight:
            # With get_node_weight=True the readout returns a
            # (graph_feats, node_weights) pair instead of graph_feats alone.
            g_feats, node_weights = self.readout(g, node_feats, get_node_weight)
            return self.predict(g_feats), node_weights

        g_feats = self.readout(g, node_feats, get_node_weight)
        return self.predict(g_feats)