Source code for dgllife.model.model_zoo.pagtn_predictor

# -*- coding: utf-8 -*-
#
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Path-Augmented Graph Transformer Network
# pylint: disable= no-member, arguments-differ, invalid-name

import torch
import torch.nn as nn
from ..gnn import PAGTNGNN
from ..readout import MLPNodeReadout


class PAGTNPredictor(nn.Module):
    """PAGTN model for regression and classification on graphs.

    PAGTN is introduced in `Path-Augmented Graph Transformer Network
    <https://arxiv.org/abs/1905.12712>`__.

    The predictor runs a PAGTN GNN over the input graph, concatenates the
    updated node representations with the raw input node features, and feeds
    the result to an MLP-based node readout to obtain graph-level outputs.

    Parameters
    ----------
    node_in_feats : int
        Size for the input node features.
    node_out_feats : int
        Size for the output node features in PAGTN layers.
    node_hid_feats : int
        Size for the hidden node features in PAGTN layers.
    edge_feats : int
        Size for the input edge features.
    depth : int
        Number of PAGTN layers to be applied. Default to 5.
    nheads : int
        Number of attention heads. Default to 1.
    dropout : float
        The probability for performing dropout. Default to 0.1.
    activation : callable
        Activation function to apply. Default to LeakyReLU with a negative
        slope of 0.2. Note that the default is a single module instance
        created at definition time and shared by every predictor that does
        not pass its own activation.
    n_tasks : int
        Number of tasks, which is also the output size. Default to 1.
    mode : 'max' or 'mean' or 'sum'
        Whether to compute elementwise maximum, mean or sum of the node
        representations. Default to 'sum'.
    """

    def __init__(self,
                 node_in_feats,
                 node_out_feats,
                 node_hid_feats,
                 edge_feats,
                 depth=5,
                 nheads=1,
                 dropout=0.1,
                 activation=nn.LeakyReLU(0.2),
                 n_tasks=1,
                 mode='sum'):
        super(PAGTNPredictor, self).__init__()

        self.model = PAGTNGNN(node_in_feats, node_out_feats,
                              node_hid_feats, edge_feats,
                              depth, nheads, dropout, activation)
        # The readout consumes the GNN output concatenated with the raw input
        # node features (see ``forward``), hence the summed input size.
        self.readout = MLPNodeReadout(node_out_feats + node_in_feats,
                                      node_out_feats,
                                      n_tasks,
                                      mode=mode)

    def forward(self, g, node_feats, edge_feats):
        """Graph-level regression/soft classification.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs
        node_feats : float32 tensor of shape (V, node_in_feats)
            Input node features. V for the number of nodes in the batch of graphs.
        edge_feats : float32 tensor of shape (E, edge_feats)
            Input edge features. E for the number of edges in the batch of graphs.

        Returns
        -------
        float32 tensor
            Graph-level predictions computed by the readout from the
            per-node representations. Expected to have shape (G, n_tasks),
            where G is the number of graphs in the batch; the exact shape is
            determined by ``MLPNodeReadout`` (confirm against its
            documentation).
        """
        atom_h = self.model(g, node_feats, edge_feats)
        # Skip connection: keep the raw input features next to the learned
        # ones so the readout sees both; result is (V, out + in) wide.
        atom_h = torch.cat([atom_h, node_feats], dim=1)
        return self.readout(g, atom_h)