# -*- coding: utf-8 -*-
#
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Graph Convolutional Networks
# pylint: disable= no-member, arguments-differ, invalid-name
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch import GraphConv
__all__ = ["GCN"]
# pylint: disable=W0221, C0103
class GCNLayer(nn.Module):
r"""Single GCN layer from `Semi-Supervised Classification with Graph Convolutional Networks
<https://arxiv.org/abs/1609.02907>`__
Parameters
----------
in_feats : int
Number of input node features.
out_feats : int
Number of output node features.
gnn_norm : str
The message passing normalizer, which can be `'right'`, `'both'` or `'none'`. The
`'right'` normalizer divides the aggregated messages by each node's in-degree.
The `'both'` normalizer corresponds to the symmetric adjacency normalization in
the original GCN paper. The `'none'` normalizer simply sums the messages.
Default to 'none'.
activation : activation function or None
Activation function applied to the updated node features. Default to None.
residual : bool
Whether to use a residual connection. Default to True.
batchnorm : bool
Whether to apply batch normalization to the output. Default to True.
dropout : float
The probability for dropout. Default to 0., i.e. no dropout is performed.
allow_zero_in_degree : bool
Whether to allow nodes with zero in-degree in the graph. Default to False.
"""
def __init__(
self,
in_feats,
out_feats,
gnn_norm="none",
activation=None,
residual=True,
batchnorm=True,
dropout=0.0,
allow_zero_in_degree=False,
):
super(GCNLayer, self).__init__()
self.activation = activation
self.graph_conv = GraphConv(
in_feats=in_feats,
out_feats=out_feats,
norm=gnn_norm,
activation=activation,
allow_zero_in_degree=allow_zero_in_degree,
)
self.dropout = nn.Dropout(dropout)
self.residual = residual
if residual:
self.res_connection = nn.Linear(in_feats, out_feats)
self.bn = batchnorm
if batchnorm:
self.bn_layer = nn.BatchNorm1d(out_feats)
def reset_parameters(self):
"""Reinitialize model parameters."""
self.graph_conv.reset_parameters()
if self.residual:
self.res_connection.reset_parameters()
if self.bn:
self.bn_layer.reset_parameters()
def forward(self, g, feats):
"""Update node representations.
Parameters
----------
g : DGLGraph
DGLGraph for a batch of graphs
feats : FloatTensor of shape (N, M1)
* N is the total number of nodes in the batch of graphs
* M1 is the input node feature size, which must match in_feats in initialization
Returns
-------
new_feats : FloatTensor of shape (N, M2)
* M2 is the output node feature size, which must match out_feats in initialization
"""
new_feats = self.graph_conv(g, feats)
if self.residual:
    # Only apply the activation when one is given so a None activation does not crash.
    res_feats = self.res_connection(feats)
    if self.activation is not None:
        res_feats = self.activation(res_feats)
    new_feats = new_feats + res_feats
new_feats = self.dropout(new_feats)
if self.bn:
new_feats = self.bn_layer(new_feats)
return new_feats
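# A minimal usage sketch for GCNLayer (illustrative only; assumes DGL and PyTorch are
# installed, and the toy graph and feature sizes below are made up):
#
#     import dgl
#     import torch
#
#     g = dgl.graph(([0, 1, 2], [1, 2, 0]))                    # toy 3-node graph
#     layer = GCNLayer(in_feats=4, out_feats=8,
#                      gnn_norm="both", activation=F.relu)
#     out = layer(g, torch.randn(3, 4))                        # -> shape (3, 8)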
class GCN(nn.Module):
r"""GCN from `Semi-Supervised Classification with Graph Convolutional Networks
<https://arxiv.org/abs/1609.02907>`__
Parameters
----------
in_feats : int
Number of input node features.
hidden_feats : list of int
``hidden_feats[i]`` gives the size of node representations after the i-th GCN layer.
``len(hidden_feats)`` equals the number of GCN layers. By default, we use
``[64, 64]``.
gnn_norm : list of str
``gnn_norm[i]`` gives the message passing normalizer for the i-th GCN layer, which
can be `'right'`, `'both'` or `'none'`. The `'right'` normalizer divides the aggregated
messages by each node's in-degree. The `'both'` normalizer corresponds to the symmetric
adjacency normalization in the original GCN paper. The `'none'` normalizer simply sums
the messages. ``len(gnn_norm)`` equals the number of GCN layers. By default, we use
``['none', 'none']``.
activation : list of activation functions or None
If not None, ``activation[i]`` gives the activation function to be used for
the i-th GCN layer. ``len(activation)`` equals the number of GCN layers.
By default, ReLU is applied for all GCN layers.
residual : list of bool
``residual[i]`` decides if residual connection is to be used for the i-th GCN layer.
``len(residual)`` equals the number of GCN layers. By default, residual connection
is performed for each GCN layer.
batchnorm : list of bool
``batchnorm[i]`` decides if batch normalization is to be applied on the output of
the i-th GCN layer. ``len(batchnorm)`` equals the number of GCN layers. By default,
batch normalization is applied for all GCN layers.
dropout : list of float
``dropout[i]`` decides the dropout probability on the output of the i-th GCN layer.
``len(dropout)`` equals the number of GCN layers. By default, no dropout is
performed for all layers.
allow_zero_in_degree : bool
Whether to allow nodes with zero in-degree in the graph for all layers. By default,
zero-in-degree nodes are not allowed.
"""
def __init__(
self,
in_feats,
hidden_feats=None,
gnn_norm=None,
activation=None,
residual=None,
batchnorm=None,
dropout=None,
allow_zero_in_degree=None,
):
super(GCN, self).__init__()
if hidden_feats is None:
hidden_feats = [64, 64]
n_layers = len(hidden_feats)
if gnn_norm is None:
gnn_norm = ["none" for _ in range(n_layers)]
if activation is None:
activation = [F.relu for _ in range(n_layers)]
if residual is None:
residual = [True for _ in range(n_layers)]
if batchnorm is None:
batchnorm = [True for _ in range(n_layers)]
if dropout is None:
dropout = [0.0 for _ in range(n_layers)]
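# Every per-layer setting must be a list with one entry per GCN layer.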
lengths = [
len(hidden_feats),
len(gnn_norm),
len(activation),
len(residual),
len(batchnorm),
len(dropout),
]
assert len(set(lengths)) == 1, (
"Expect the lengths of hidden_feats, gnn_norm, "
"activation, residual, batchnorm and dropout to "
"be the same, got {}".format(lengths)
)
self.hidden_feats = hidden_feats
self.gnn_layers = nn.ModuleList()
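# Stack the GCN layers; the input size of layer i + 1 equals the output size of layer i.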
for i in range(n_layers):
self.gnn_layers.append(
GCNLayer(
in_feats,
hidden_feats[i],
gnn_norm[i],
activation[i],
residual[i],
batchnorm[i],
dropout[i],
allow_zero_in_degree=allow_zero_in_degree,
)
)
in_feats = hidden_feats[i]
def reset_parameters(self):
"""Reinitialize model parameters."""
for gnn in self.gnn_layers:
gnn.reset_parameters()
def forward(self, g, feats):
"""Update node representations.
Parameters
----------
g : DGLGraph
DGLGraph for a batch of graphs
feats : FloatTensor of shape (N, M1)
* N is the total number of nodes in the batch of graphs
* M1 is the input node feature size, which equals in_feats in initialization
Returns
-------
feats : FloatTensor of shape (N, M2)
* N is the total number of nodes in the batch of graphs
* M2 is the output node representation size, which equals
hidden_feats[-1] in initialization.
"""
for gnn in self.gnn_layers:
feats = gnn(g, feats)
return feats
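# A minimal, self-contained smoke test (a sketch, not part of the published API; the
# toy graphs, feature sizes and hidden sizes below are arbitrary and assume DGL and
# PyTorch are installed).
if __name__ == "__main__":
    import dgl
    import torch

    # Two small graphs batched together, as done for batched molecular graphs.
    g1 = dgl.graph(([0, 1, 2], [1, 2, 0]))
    g2 = dgl.graph(([0, 1], [1, 0]))
    bg = dgl.batch([g1, g2])

    model = GCN(in_feats=16, hidden_feats=[32, 32], gnn_norm=["both", "both"])
    feats = torch.randn(bg.num_nodes(), 16)
    out = model(bg, feats)
    print(out.shape)  # torch.Size([5, 32])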