# Source code for dgllife.model.gnn.gat

# -*- coding: utf-8 -*-
#
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Graph Attention Networks
#
# pylint: disable= no-member, arguments-differ, invalid-name

import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch import GATConv

__all__ = ["GAT"]

# pylint: disable=W0221
class GATLayer(nn.Module):
    r"""Single GAT layer from `Graph Attention Networks <https://arxiv.org/abs/1710.10903>`__

    Wraps DGL's ``GATConv``, merges its per-head outputs according to
    ``agg_mode`` and optionally applies an activation to the merged result.

    Parameters
    ----------
    in_feats : int
        Number of input node features
    out_feats : int
        Number of output node features produced by each attention head
    num_heads : int
        Number of attention heads
    feat_drop : float
        Dropout applied to the input features
    attn_drop : float
        Dropout applied to attention values of edges
    alpha : float
        Hyperparameter in LeakyReLU, which is the slope for negative values.
        Default to 0.2.
    residual : bool
        Whether to perform skip connection, default to True.
    agg_mode : str
        The way to aggregate multi-head attention results, can be either
        'flatten' for concatenating all-head results or 'mean' for averaging
        all head results. Default to 'flatten'.
    activation : activation function or None
        Activation function applied to the aggregated multi-head results, default to None.
    bias : bool
        Whether to use bias in the GAT layer. Default to True.
    allow_zero_in_degree: bool
        Whether to allow zero in degree nodes in graph. Defaults to False.

    Raises
    ------
    ValueError
        If ``agg_mode`` is neither 'flatten' nor 'mean'.
    """

    def __init__(
        self,
        in_feats,
        out_feats,
        num_heads,
        feat_drop,
        attn_drop,
        alpha=0.2,
        residual=True,
        agg_mode="flatten",
        activation=None,
        bias=True,
        allow_zero_in_degree=False,
    ):
        super(GATLayer, self).__init__()

        # Validate before building GATConv so a bad configuration fails fast.
        # An explicit exception is used instead of ``assert`` because asserts
        # are stripped under ``python -O``, which would silently send an
        # unknown mode down the 'mean' branch of ``forward``.
        if agg_mode not in ("flatten", "mean"):
            raise ValueError(
                "Expect agg_mode to be 'flatten' or 'mean', got {!r}".format(agg_mode)
            )

        self.gat_conv = GATConv(
            in_feats=in_feats,
            out_feats=out_feats,
            num_heads=num_heads,
            feat_drop=feat_drop,
            attn_drop=attn_drop,
            negative_slope=alpha,
            residual=residual,
            bias=bias,
            allow_zero_in_degree=allow_zero_in_degree,
        )
        self.agg_mode = agg_mode
        self.activation = activation

    def reset_parameters(self):
        """Reinitialize model parameters."""
        self.gat_conv.reset_parameters()

    def forward(self, bg, feats):
        """Update node representations

        Parameters
        ----------
        bg : DGLGraph
            DGLGraph for a batch of graphs.
        feats : FloatTensor of shape (N, M1)
            * N is the total number of nodes in the batch of graphs
            * M1 is the input node feature size, which equals in_feats in initialization

        Returns
        -------
        feats : FloatTensor of shape (N, M2)
            * N is the total number of nodes in the batch of graphs
            * M2 is the output node representation size, which equals
              out_feats in initialization if self.agg_mode == 'mean' and
              out_feats * num_heads in initialization otherwise.
        """
        # GATConv keeps a separate head dimension at axis 1, i.e. results are
        # (N, num_heads, out_feats); merge it according to agg_mode.
        feats = self.gat_conv(bg, feats)
        if self.agg_mode == "flatten":
            feats = feats.flatten(1)
        else:
            feats = feats.mean(1)

        if self.activation is not None:
            feats = self.activation(feats)

        return feats


class GAT(nn.Module):
    r"""GAT from `Graph Attention Networks <https://arxiv.org/abs/1710.10903>`__

    Parameters
    ----------
    in_feats : int
        Number of input node features
    hidden_feats : list of int
        ``hidden_feats[i]`` gives the output size of an attention head in the i-th GAT layer.
        ``len(hidden_feats)`` equals the number of GAT layers. By default, we use ``[32, 32]``.
    num_heads : list of int
        ``num_heads[i]`` gives the number of attention heads in the i-th GAT layer.
        ``len(num_heads)`` equals the number of GAT layers. By default, we use 4 attention
        heads for each GAT layer.
    feat_drops : list of float
        ``feat_drops[i]`` gives the dropout applied to the input features in the i-th GAT layer.
        ``len(feat_drops)`` equals the number of GAT layers. By default, this will be zero for
        all GAT layers.
    attn_drops : list of float
        ``attn_drops[i]`` gives the dropout applied to attention values of edges in the i-th GAT
        layer. ``len(attn_drops)`` equals the number of GAT layers. By default, this will be zero
        for all GAT layers.
    alphas : list of float
        Hyperparameters in LeakyReLU, which are the slopes for negative values. ``alphas[i]``
        gives the slope for negative value in the i-th GAT layer. ``len(alphas)`` equals the
        number of GAT layers. By default, this will be 0.2 for all GAT layers.
    residuals : list of bool
        ``residuals[i]`` decides if residual connection is to be used for the i-th GAT layer.
        ``len(residuals)`` equals the number of GAT layers. By default, residual connection
        is performed for each GAT layer.
    agg_modes : list of str
        The way to aggregate multi-head attention results for each GAT layer, which can be either
        'flatten' for concatenating all-head results or 'mean' for averaging all-head results.
        ``agg_modes[i]`` gives the way to aggregate multi-head attention results for the i-th
        GAT layer. ``len(agg_modes)`` equals the number of GAT layers. By default, all-head
        results are flattened for every GAT layer except the last one, where they are averaged.
    activations : list of activation function or None
        ``activations[i]`` gives the activation function applied to the aggregated multi-head
        results for the i-th GAT layer. ``len(activations)`` equals the number of GAT layers.
        By default, ELU is applied after every GAT layer except the last one, which has no
        activation.
    biases : list of bool
        ``biases[i]`` gives whether to use bias for the i-th GAT layer. ``len(biases)`` equals
        the number of GAT layers. By default, we use bias for all GAT layers.
    allow_zero_in_degree: bool
        Whether to allow zero in degree nodes in graph for all layers. By default, will not
        allow zero in degree nodes.

    Raises
    ------
    ValueError
        If the per-layer argument lists (after filling in defaults) do not all have the
        same length.
    """

    def __init__(
        self,
        in_feats,
        hidden_feats=None,
        num_heads=None,
        feat_drops=None,
        attn_drops=None,
        alphas=None,
        residuals=None,
        agg_modes=None,
        activations=None,
        biases=None,
        allow_zero_in_degree=False,
    ):
        super(GAT, self).__init__()

        if hidden_feats is None:
            hidden_feats = [32, 32]
        n_layers = len(hidden_feats)

        # Fill unspecified per-layer settings with one default per layer.
        if num_heads is None:
            num_heads = [4] * n_layers
        if feat_drops is None:
            feat_drops = [0.0] * n_layers
        if attn_drops is None:
            attn_drops = [0.0] * n_layers
        if alphas is None:
            alphas = [0.2] * n_layers
        if residuals is None:
            residuals = [True] * n_layers
        if agg_modes is None:
            # Concatenate heads in the hidden layers, average them in the last layer.
            agg_modes = ["flatten"] * (n_layers - 1) + ["mean"]
        if activations is None:
            # ELU after every hidden layer, no activation after the last layer.
            activations = [F.elu] * (n_layers - 1) + [None]
        if biases is None:
            biases = [True] * n_layers

        lengths = [
            len(hidden_feats),
            len(num_heads),
            len(feat_drops),
            len(attn_drops),
            len(alphas),
            len(residuals),
            len(agg_modes),
            len(activations),
            len(biases),
        ]
        # Explicit exception rather than ``assert``: asserts are stripped under
        # ``python -O``, which would turn a mismatch into an IndexError later or
        # silently ignore trailing entries.
        if len(set(lengths)) != 1:
            raise ValueError(
                "Expect the lengths of hidden_feats, num_heads, "
                "feat_drops, attn_drops, alphas, residuals, "
                "agg_modes, activations, and biases to be the same, "
                "got {}".format(lengths)
            )

        self.hidden_feats = hidden_feats
        self.num_heads = num_heads
        self.agg_modes = agg_modes
        self.gnn_layers = nn.ModuleList()
        for i in range(n_layers):
            self.gnn_layers.append(
                GATLayer(
                    in_feats,
                    hidden_feats[i],
                    num_heads[i],
                    feat_drops[i],
                    attn_drops[i],
                    alphas[i],
                    residuals[i],
                    agg_modes[i],
                    activations[i],
                    biases[i],
                    allow_zero_in_degree=allow_zero_in_degree,
                )
            )
            # The next layer's input size depends on how this layer merges its heads.
            if agg_modes[i] == "flatten":
                in_feats = hidden_feats[i] * num_heads[i]
            else:
                in_feats = hidden_feats[i]

    def reset_parameters(self):
        """Reinitialize model parameters."""
        for gnn in self.gnn_layers:
            gnn.reset_parameters()

    def forward(self, g, feats):
        """Update node representations.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs
        feats : FloatTensor of shape (N, M1)
            * N is the total number of nodes in the batch of graphs
            * M1 is the input node feature size, which equals in_feats in initialization

        Returns
        -------
        feats : FloatTensor of shape (N, M2)
            * N is the total number of nodes in the batch of graphs
            * M2 is the output node representation size, which equals
              hidden_feats[-1] if agg_modes[-1] == 'mean' and
              hidden_feats[-1] * num_heads[-1] otherwise.
        """
        for gnn in self.gnn_layers:
            feats = gnn(g, feats)
        return feats