# -*- coding: utf-8 -*-
#
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Path-Augmented Graph Transformer Network
# pylint: disable= no-member, arguments-differ, invalid-name

import torch
import torch.nn as nn
from dgl.nn.functional import edge_softmax
import dgl.function as fn

__all__ = ['PAGTNGNN']

class PAGTNLayer(nn.Module):
    """
    Single PAGTN layer from `Path-Augmented Graph Transformer Network
    <https://arxiv.org/abs/1905.12712>`__

    This layer incorporates edge features into the node representations
    during message passing.

    Parameters
    ----------
    node_in_feats : int
        Size for the input node features.
    node_out_feats : int
        Size for the output node features.
    edge_feats : int
        Size for the input edge features.
    dropout : float
        The probability for performing dropout. Default to 0.1.
    activation : callable
        Activation function to apply. Default to LeakyReLU.
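
    Examples
    --------
    A minimal usage sketch, illustrative rather than from the original
    docs; feature sizes are arbitrary. The layer broadcasts over an
    explicit head dimension, so a singleton head axis is used here.

    >>> import dgl
    >>> import torch
    >>> g = dgl.graph(([0, 1, 2], [1, 2, 0]))
    >>> node_feats = torch.randn(3, 1, 16)
    >>> edge_feats = torch.randn(3, 8)
    >>> layer = PAGTNLayer(node_in_feats=16, node_out_feats=16, edge_feats=8)
    >>> layer(g, node_feats, edge_feats).shape
    torch.Size([3, 1, 16])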
    """
    def __init__(self,
                 node_in_feats,
                 node_out_feats,
                 edge_feats,
                 dropout=0.1,
                 activation=nn.LeakyReLU(0.2)):
        super(PAGTNLayer, self).__init__()
        self.attn_src = nn.Linear(node_in_feats, node_in_feats)
        self.attn_dst = nn.Linear(node_in_feats, node_in_feats)
        self.attn_edg = nn.Linear(edge_feats, node_in_feats)
        self.attn_dot = nn.Linear(node_in_feats, 1)
        self.msg_src = nn.Linear(node_in_feats, node_out_feats)
        self.msg_dst = nn.Linear(node_in_feats, node_out_feats)
        self.msg_edg = nn.Linear(edge_feats, node_out_feats)
        self.wgt_n = nn.Linear(node_in_feats, node_out_feats)
        self.dropout = nn.Dropout(dropout)
        self.act = activation
        self.reset_parameters()

    def reset_parameters(self):
        """Reinitialize learnable parameters."""
        gain = nn.init.calculate_gain('relu')
        nn.init.xavier_normal_(self.attn_src.weight, gain=gain)
        nn.init.xavier_normal_(self.attn_dst.weight, gain=gain)
        nn.init.xavier_normal_(self.attn_edg.weight, gain=gain)
        nn.init.xavier_normal_(self.attn_dot.weight, gain=gain)
        nn.init.xavier_normal_(self.msg_src.weight, gain=gain)
        nn.init.xavier_normal_(self.msg_dst.weight, gain=gain)
        nn.init.xavier_normal_(self.msg_edg.weight, gain=gain)
        nn.init.xavier_normal_(self.wgt_n.weight, gain=gain)

    def forward(self, g, node_feats, edge_feats):
        """Update node representations.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs
        node_feats : float32 tensor of shape (V, node_in_feats) or (V, n_head, node_in_feats)
            Input node features. V for the number of nodes in the batch of graphs.
        edge_feats : float32 tensor of shape (E, edge_feats)
            Input edge features. E for the number of edges in the batch of graphs.

        Returns
        -------
        float32 tensor of shape (V, node_out_feats) or (V, n_head, node_out_feats)
            Updated node features.
        """

        g = g.local_var()
        # In the paper, the source node, destination node and edge features
        # are concatenated and multiplied by a single weight matrix. We
        # optimize this step with three separate matrix multiplications
        # (a small equivalence sketch follows this class).
        g.ndata['src'] = self.dropout(self.attn_src(node_feats))
        g.ndata['dst'] = self.dropout(self.attn_dst(node_feats))
        edg_atn = self.dropout(self.attn_edg(edge_feats)).unsqueeze(-2)
        g.apply_edges(fn.u_add_v('src', 'dst', 'e'))
        atn_scores = self.act(g.edata.pop('e') + edg_atn)

        atn_scores = self.attn_dot(atn_scores)
        atn_scores = self.dropout(edge_softmax(g, atn_scores))

        g.ndata['src'] = self.msg_src(node_feats)
        g.ndata['dst'] = self.msg_dst(node_feats)
        g.apply_edges(fn.u_add_v('src', 'dst', 'e'))
        atn_inp = g.edata.pop('e') + self.msg_edg(edge_feats).unsqueeze(-2)
        atn_inp = self.act(atn_inp)
        g.edata['msg'] = atn_scores * atn_inp
        g.update_all(fn.copy_e('msg', 'm'), fn.sum('m', 'feat'))
        out = g.ndata.pop('feat') + self.wgt_n(node_feats)
        return self.act(out)
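

# A sanity sketch (illustrative helper, not part of the original module):
# the comment in PAGTNLayer.forward notes that applying one linear map to
# the concatenation [src; dst; edge] equals summing three separate linear
# maps whose weights are the column blocks of the single matrix.
def _concat_vs_split_demo(feat_size=4, num_rows=5):
    big = nn.Linear(3 * feat_size, feat_size)
    xs = torch.randn(num_rows, feat_size)
    xd = torch.randn(num_rows, feat_size)
    xe = torch.randn(num_rows, feat_size)
    # Split the weight matrix column-wise into per-input blocks.
    w_s, w_d, w_e = big.weight.split(feat_size, dim=1)
    split_out = xs @ w_s.t() + xd @ w_d.t() + xe @ w_e.t() + big.bias
    assert torch.allclose(big(torch.cat([xs, xd, xe], dim=1)), split_out,
                          atol=1e-6)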


class PAGTNGNN(nn.Module):
    """Multilayer PAGTN model for updating node representations.

    PAGTN is introduced in `Path-Augmented Graph Transformer Network
    <https://arxiv.org/abs/1905.12712>`__.

    Parameters
    ----------
    node_in_feats : int
        Size for the input node features.
    node_out_feats : int
        Size for the output node features.
    node_hid_feats : int
        Size for the hidden node features.
    edge_feats : int
        Size for the input edge features.
    depth : int
        Number of PAGTN layers to be applied. Default to 5.
    nheads : int
        Number of attention heads. Default to 1.
    dropout : float
        The probability for performing dropout. Default to 0.1.
    activation : callable
        Activation function to apply. Default to LeakyReLU.
    """
    def __init__(self,
                 node_in_feats,
                 node_out_feats,
                 node_hid_feats,
                 edge_feats,
                 depth=5,
                 nheads=1,
                 dropout=0.1,
                 activation=nn.LeakyReLU(0.2)):
        super(PAGTNGNN, self).__init__()
        self.depth = depth
        self.nheads = nheads
        self.node_hid_feats = node_hid_feats
        self.atom_inp = nn.Linear(node_in_feats, node_hid_feats * nheads)
        self.model = nn.ModuleList([PAGTNLayer(node_hid_feats, node_hid_feats,
                                               edge_feats, dropout, activation)
                                    for _ in range(depth)])
        self.atom_out = nn.Linear(node_in_feats + node_hid_feats * nheads,
                                  node_out_feats)
        self.act = activation

    def reset_parameters(self):
        """Reinitialize learnable parameters."""
        gain = nn.init.calculate_gain('relu')
        nn.init.xavier_normal_(self.atom_inp.weight, gain=gain)
        nn.init.xavier_normal_(self.atom_out.weight, gain=gain)
        # nn.ModuleList has no reset_parameters(); reinitialize each layer.
        for layer in self.model:
            layer.reset_parameters()

    def forward(self, g, node_feats, edge_feats):
        """Update node representations.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs
        node_feats : float32 tensor of shape (V, node_in_feats)
            Input node features. V for the number of nodes in the batch of graphs.
        edge_feats : float32 tensor of shape (E, edge_feats)
            Input edge features. E for the number of edges in the batch of graphs.

        Returns
        -------
        float32 tensor of shape (V, node_out_feats)
            Updated node features.
        """
        g = g.local_var()
        atom_input = self.atom_inp(node_feats).view(-1, self.nheads, self.node_hid_feats)
        atom_input = self.act(atom_input)
        atom_h = atom_input
        for i in range(self.depth):
            attn_h = self.model[i](g, atom_h, edge_feats)
            # Residual connection to the transformed input features.
            atom_h = torch.nn.functional.relu(attn_h + atom_input)
        atom_h = atom_h.view(-1, self.nheads * self.node_hid_feats)
        # Concatenate the raw input features with the multi-head outputs.
        atom_output = torch.cat([node_feats, atom_h], dim=1)
        atom_h = torch.nn.functional.relu(self.atom_out(atom_output))
        return atom_h
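

# A minimal end-to-end usage sketch (illustrative, not part of the original
# module): builds a toy batched graph and runs PAGTNGNN on it. All feature
# sizes and hyperparameters below are arbitrary choices for demonstration.
if __name__ == '__main__':
    import dgl

    g = dgl.batch([
        dgl.graph(([0, 1, 2], [1, 2, 0])),
        dgl.graph(([0, 1], [1, 0])),
    ])
    node_feats = torch.randn(g.num_nodes(), 16)
    edge_feats = torch.randn(g.num_edges(), 8)
    model = PAGTNGNN(node_in_feats=16, node_out_feats=32,
                     node_hid_feats=24, edge_feats=8,
                     depth=3, nheads=2)
    out = model(g, node_feats, edge_feats)
    # One updated representation per node across the batch.
    assert out.shape == (g.num_nodes(), 32)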