# Source code for dgllife.model.gnn.gcn

# -*- coding: utf-8 -*-
#
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Graph Convolutional Networks
# pylint: disable= no-member, arguments-differ, invalid-name

import torch.nn as nn
import torch.nn.functional as F

from dgl.nn.pytorch import GraphConv

# Public names exported by ``from <this module> import *``. GCNLayer is not listed,
# so star imports only re-export the multi-layer GCN model.
__all__ = ['GCN']

# pylint: disable=W0221, C0103
class GCNLayer(nn.Module):
    r"""Single GCN layer from `Semi-Supervised Classification with Graph Convolutional Networks
    <https://arxiv.org/abs/1609.02907>`__

    Parameters
    ----------
    in_feats : int
        Number of input node features.
    out_feats : int
        Number of output node features.
    gnn_norm : str
        The message passing normalizer, which can be `'right'`, `'both'` or `'none'`. The
        `'right'` normalizer divides the aggregated messages by each node's in-degree.
        The `'both'` normalizer corresponds to the symmetric adjacency normalization in
        the original GCN paper. The `'none'` normalizer simply sums the messages.
        Default to be 'none'.
    activation : activation function or None
        Passed to the underlying ``GraphConv``. When ``residual`` is True, it is also
        applied to the output of the residual projection. ``None`` means no activation
        is applied on the residual path. Default to be None.
    residual : bool
        Whether to use residual connection, default to be True.
    batchnorm : bool
        Whether to use batch normalization on the output,
        default to be True.
    dropout : float
        The probability for dropout. Default to be 0., i.e. no
        dropout is performed.
    """
    def __init__(self, in_feats, out_feats, gnn_norm='none', activation=None,
                 residual=True, batchnorm=True, dropout=0.):
        super(GCNLayer, self).__init__()

        self.activation = activation
        self.graph_conv = GraphConv(in_feats=in_feats, out_feats=out_feats,
                                    norm=gnn_norm, activation=activation)
        self.dropout = nn.Dropout(dropout)

        self.residual = residual
        if residual:
            # Linear projection so the residual matches out_feats even when
            # in_feats != out_feats.
            self.res_connection = nn.Linear(in_feats, out_feats)

        self.bn = batchnorm
        if batchnorm:
            self.bn_layer = nn.BatchNorm1d(out_feats)

    def reset_parameters(self):
        """Reinitialize model parameters."""
        self.graph_conv.reset_parameters()
        if self.residual:
            self.res_connection.reset_parameters()
        if self.bn:
            self.bn_layer.reset_parameters()

    def forward(self, g, feats):
        """Update node representations.

        The graph convolution output is summed with an (optionally activated) linear
        projection of the input when ``residual`` is enabled. Dropout is then applied,
        followed by batch normalization when ``batchnorm`` is enabled.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs
        feats : FloatTensor of shape (N, M1)
            * N is the total number of nodes in the batch of graphs
            * M1 is the input node feature size, which must match in_feats in initialization

        Returns
        -------
        new_feats : FloatTensor of shape (N, M2)
            * M2 is the output node feature size, which must match out_feats in initialization
        """
        new_feats = self.graph_conv(g, feats)
        if self.residual:
            res_feats = self.res_connection(feats)
            # ``activation`` defaults to None. Calling it unconditionally would raise
            # TypeError for the default configuration (residual=True, activation=None).
            if self.activation is not None:
                res_feats = self.activation(res_feats)
            new_feats = new_feats + res_feats
        new_feats = self.dropout(new_feats)

        if self.bn:
            new_feats = self.bn_layer(new_feats)

        return new_feats

class GCN(nn.Module):
    r"""GCN from `Semi-Supervised Classification with Graph Convolutional Networks
    <https://arxiv.org/abs/1609.02907>`__

    Parameters
    ----------
    in_feats : int
        Number of input node features.
    hidden_feats : list of int
        ``hidden_feats[i]`` gives the size of node representations after the i-th GCN layer.
        ``len(hidden_feats)`` equals the number of GCN layers. By default, we use
        ``[64, 64]``.
    gnn_norm : list of str
        ``gnn_norm[i]`` gives the message passing normalizer for the i-th GCN layer, which
        can be `'right'`, `'both'` or `'none'`. The `'right'` normalizer divides the aggregated
        messages by each node's in-degree. The `'both'` normalizer corresponds to the symmetric
        adjacency normalization in the original GCN paper. The `'none'` normalizer simply sums
        the messages. ``len(gnn_norm)`` equals the number of GCN layers. By default, we use
        ``['none', 'none']``.
    activation : list of activation functions or None
        If not None, ``activation[i]`` gives the activation function to be used for
        the i-th GCN layer. ``len(activation)`` equals the number of GCN layers.
        By default, ReLU is applied for all GCN layers.
    residual : list of bool
        ``residual[i]`` decides if residual connection is to be used for the i-th GCN layer.
        ``len(residual)`` equals the number of GCN layers. By default, residual connection
        is performed for each GCN layer.
    batchnorm : list of bool
        ``batchnorm[i]`` decides if batch normalization is to be applied on the output of
        the i-th GCN layer. ``len(batchnorm)`` equals the number of GCN layers. By default,
        batch normalization is applied for all GCN layers.
    dropout : list of float
        ``dropout[i]`` decides the dropout probability on the output of the i-th GCN layer.
        ``len(dropout)`` equals the number of GCN layers. By default, no dropout is
        performed for all layers.
    """
    def __init__(self, in_feats, hidden_feats=None, gnn_norm=None, activation=None,
                 residual=None, batchnorm=None, dropout=None):
        super(GCN, self).__init__()

        if hidden_feats is None:
            hidden_feats = [64, 64]

        # Fill every unspecified per-layer option with its default, one entry per layer.
        # All defaults are immutable (str, function, bool, float), so list repetition is safe.
        n_layers = len(hidden_feats)
        if gnn_norm is None:
            gnn_norm = ['none'] * n_layers
        if activation is None:
            activation = [F.relu] * n_layers
        if residual is None:
            residual = [True] * n_layers
        if batchnorm is None:
            batchnorm = [True] * n_layers
        if dropout is None:
            dropout = [0.] * n_layers
        lengths = [len(hidden_feats), len(gnn_norm), len(activation),
                   len(residual), len(batchnorm), len(dropout)]
        assert len(set(lengths)) == 1, 'Expect the lengths of hidden_feats, gnn_norm, ' \
                                       'activation, residual, batchnorm and dropout to ' \
                                       'be the same, got {}'.format(lengths)

        self.hidden_feats = hidden_feats
        self.gnn_layers = nn.ModuleList()
        # The lengths were checked above, so zip consumes every entry of every list.
        for out_feats, norm, act, res, bn, drop in zip(hidden_feats, gnn_norm, activation,
                                                       residual, batchnorm, dropout):
            self.gnn_layers.append(GCNLayer(in_feats, out_feats, norm, act, res, bn, drop))
            # The next layer consumes this layer's output size.
            in_feats = out_feats

    def reset_parameters(self):
        """Reinitialize model parameters."""
        for gnn in self.gnn_layers:
            gnn.reset_parameters()

    def forward(self, g, feats):
        """Update node representations.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs
        feats : FloatTensor of shape (N, M1)
            * N is the total number of nodes in the batch of graphs
            * M1 is the input node feature size, which equals in_feats in initialization

        Returns
        -------
        feats : FloatTensor of shape (N, M2)
            * N is the total number of nodes in the batch of graphs
            * M2 is the output node representation size, which equals
              hidden_feats[-1] in initialization.
        """
        for gnn in self.gnn_layers:
            feats = gnn(g, feats)
        return feats