Source code for dgllife.model.gnn.gcn

# -*- coding: utf-8 -*-
#
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Graph Convolutional Networks
# pylint: disable= no-member, arguments-differ, invalid-name

import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch import GraphConv

__all__ = ["GCN"]

# pylint: disable=W0221, C0103
class GCNLayer(nn.Module):
    r"""Single GCN layer from `Semi-Supervised Classification with Graph Convolutional Networks
    <https://arxiv.org/abs/1609.02907>`__

    Parameters
    ----------
    in_feats : int
        Number of input node features.
    out_feats : int
        Number of output node features.
    gnn_norm : str
        The message passing normalizer, which can be `'right'`, `'both'` or `'none'`. The
        `'right'` normalizer divides the aggregated messages by each node's in-degree.
        The `'both'` normalizer corresponds to the symmetric adjacency normalization in
        the original GCN paper. The `'none'` normalizer simply sums the messages.
        Default to ``'none'``.
    activation : activation function or None
        Activation function applied to the updated node representations.
        Default to None.
    residual : bool
        Whether to use a residual connection. Default to True.
    batchnorm : bool
        Whether to apply batch normalization to the output. Default to True.
    dropout : float
        The probability for dropout. Default to 0., i.e. no dropout is performed.
    allow_zero_in_degree : bool
        Whether to allow zero-in-degree nodes in the graph. Default to False.
    """

    def __init__(
        self,
        in_feats,
        out_feats,
        gnn_norm="none",
        activation=None,
        residual=True,
        batchnorm=True,
        dropout=0.0,
        allow_zero_in_degree=False,
    ):
        super(GCNLayer, self).__init__()

        self.activation = activation
        self.graph_conv = GraphConv(
            in_feats=in_feats,
            out_feats=out_feats,
            norm=gnn_norm,
            activation=activation,
            allow_zero_in_degree=allow_zero_in_degree,
        )
        self.dropout = nn.Dropout(dropout)

        self.residual = residual
        if residual:
            self.res_connection = nn.Linear(in_feats, out_feats)

        self.bn = batchnorm
        if batchnorm:
            self.bn_layer = nn.BatchNorm1d(out_feats)

    def reset_parameters(self):
        """Reinitialize model parameters."""
        self.graph_conv.reset_parameters()
        if self.residual:
            self.res_connection.reset_parameters()
        if self.bn:
            self.bn_layer.reset_parameters()

    def forward(self, g, feats):
        """Update node representations.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs
        feats : FloatTensor of shape (N, M1)
            * N is the total number of nodes in the batch of graphs
            * M1 is the input node feature size, which must match in_feats in initialization

        Returns
        -------
        new_feats : FloatTensor of shape (N, M2)
            * M2 is the output node feature size, which must match out_feats in initialization
        """
        new_feats = self.graph_conv(g, feats)
        if self.residual:
            res_feats = self.res_connection(feats)
            # Apply the activation to the residual branch only when one is given;
            # with the default ``activation=None`` this would otherwise fail.
            if self.activation is not None:
                res_feats = self.activation(res_feats)
            new_feats = new_feats + res_feats
        new_feats = self.dropout(new_feats)

        if self.bn:
            new_feats = self.bn_layer(new_feats)

        return new_feats
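
# A minimal usage sketch for GCNLayer (editor's addition, not part of the
# library), assuming ``dgl`` and ``torch`` are available. The toy graph is a
# 3-node cycle, so every node has a nonzero in-degree and the default
# ``allow_zero_in_degree=False`` is satisfied.
#
#     import dgl
#     import torch
#     import torch.nn.functional as F
#
#     g = dgl.graph(([0, 1, 2], [1, 2, 0]))   # 3 nodes, 3 edges
#     feats = torch.randn(3, 16)              # (N, in_feats)
#     layer = GCNLayer(in_feats=16, out_feats=32, activation=F.relu)
#     out = layer(g, feats)                   # shape (3, 32)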


class GCN(nn.Module):
    r"""GCN from `Semi-Supervised Classification with Graph Convolutional Networks
    <https://arxiv.org/abs/1609.02907>`__

    Parameters
    ----------
    in_feats : int
        Number of input node features.
    hidden_feats : list of int
        ``hidden_feats[i]`` gives the size of node representations after the i-th GCN layer.
        ``len(hidden_feats)`` equals the number of GCN layers. By default, we use ``[64, 64]``.
    gnn_norm : list of str
        ``gnn_norm[i]`` gives the message passing normalizer for the i-th GCN layer, which
        can be `'right'`, `'both'` or `'none'`. The `'right'` normalizer divides the
        aggregated messages by each node's in-degree. The `'both'` normalizer corresponds
        to the symmetric adjacency normalization in the original GCN paper. The `'none'`
        normalizer simply sums the messages. ``len(gnn_norm)`` equals the number of GCN
        layers. By default, we use ``['none', 'none']``.
    activation : list of activation functions or None
        If not None, ``activation[i]`` gives the activation function to be used for
        the i-th GCN layer. ``len(activation)`` equals the number of GCN layers.
        By default, ReLU is applied for all GCN layers.
    residual : list of bool
        ``residual[i]`` decides if a residual connection is to be used for the i-th GCN
        layer. ``len(residual)`` equals the number of GCN layers. By default, a residual
        connection is used in each GCN layer.
    batchnorm : list of bool
        ``batchnorm[i]`` decides if batch normalization is to be applied on the output of
        the i-th GCN layer. ``len(batchnorm)`` equals the number of GCN layers. By default,
        batch normalization is applied for all GCN layers.
    dropout : list of float
        ``dropout[i]`` decides the dropout probability on the output of the i-th GCN layer.
        ``len(dropout)`` equals the number of GCN layers. By default, no dropout is
        performed in any layer.
    allow_zero_in_degree : bool
        Whether to allow zero-in-degree nodes in the graph for all layers. By default,
        zero-in-degree nodes are not allowed.
    """

    def __init__(
        self,
        in_feats,
        hidden_feats=None,
        gnn_norm=None,
        activation=None,
        residual=None,
        batchnorm=None,
        dropout=None,
        allow_zero_in_degree=None,
    ):
        super(GCN, self).__init__()

        if hidden_feats is None:
            hidden_feats = [64, 64]

        n_layers = len(hidden_feats)
        if gnn_norm is None:
            gnn_norm = ["none" for _ in range(n_layers)]
        if activation is None:
            activation = [F.relu for _ in range(n_layers)]
        if residual is None:
            residual = [True for _ in range(n_layers)]
        if batchnorm is None:
            batchnorm = [True for _ in range(n_layers)]
        if dropout is None:
            dropout = [0.0 for _ in range(n_layers)]
        lengths = [
            len(hidden_feats),
            len(gnn_norm),
            len(activation),
            len(residual),
            len(batchnorm),
            len(dropout),
        ]
        assert len(set(lengths)) == 1, (
            "Expect the lengths of hidden_feats, gnn_norm, "
            "activation, residual, batchnorm and dropout to "
            "be the same, got {}".format(lengths)
        )

        self.hidden_feats = hidden_feats
        self.gnn_layers = nn.ModuleList()
        for i in range(n_layers):
            self.gnn_layers.append(
                GCNLayer(
                    in_feats,
                    hidden_feats[i],
                    gnn_norm[i],
                    activation[i],
                    residual[i],
                    batchnorm[i],
                    dropout[i],
                    allow_zero_in_degree=allow_zero_in_degree,
                )
            )
            in_feats = hidden_feats[i]

    def reset_parameters(self):
        """Reinitialize model parameters."""
        for gnn in self.gnn_layers:
            gnn.reset_parameters()

    def forward(self, g, feats):
        """Update node representations.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs
        feats : FloatTensor of shape (N, M1)
            * N is the total number of nodes in the batch of graphs
            * M1 is the input node feature size, which equals in_feats in initialization

        Returns
        -------
        feats : FloatTensor of shape (N, M2)
            * N is the total number of nodes in the batch of graphs
            * M2 is the output node representation size, which equals
              hidden_feats[-1] in initialization.
        """
        for gnn in self.gnn_layers:
            feats = gnn(g, feats)
        return feats
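
if __name__ == "__main__":
    # Quick smoke test (editor's sketch, not part of the library): run the
    # default two-layer GCN on a toy 4-node cycle graph with random
    # 16-dimensional node features. Assumes ``dgl`` and ``torch`` are installed.
    import dgl
    import torch

    g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0]))
    feats = torch.randn(4, 16)

    model = GCN(in_feats=16)   # hidden_feats defaults to [64, 64]
    out = model(g, feats)
    print(out.shape)           # expected: torch.Size([4, 64])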