# -*- coding: utf-8 -*-
#
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Graph Isomorphism Networks.
# pylint: disable= no-member, arguments-differ, invalid-name
import dgl.function as fn
import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = ['GIN']

# pylint: disable=W0221, C0103
class GINLayer(nn.Module):
r"""Single Layer GIN from `Strategies for
Pre-training Graph Neural Networks <https://arxiv.org/abs/1905.12265>`__
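
    In this implementation, the representation of each node :math:`v` is updated as

    .. math::
        h_v = \mathrm{MLP}\Big(\sum_{u \in \mathcal{N}(v)} \big(h_u + e_{uv}\big)\Big),

    optionally followed by batch normalization and an activation, where
    :math:`e_{uv}` is the sum of the embeddings of the categorical features of
    edge :math:`(u, v)` and :math:`\mathcal{N}(v)` consists of the in-neighbors
    of :math:`v`.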

    Parameters
----------
num_edge_emb_list : list of int
num_edge_emb_list[i] gives the number of items to embed for the
        i-th categorical edge feature variable. E.g. num_edge_emb_list[0] can be
the number of bond types and num_edge_emb_list[1] can be the number of
bond direction types.
emb_dim : int
The size of each embedding vector.
batch_norm : bool
Whether to apply batch normalization to the output of message passing.
Default to True.
activation : None or callable
Activation function to apply to the output node representations.
Default to None.
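
    Examples
    --------

    A minimal usage sketch, not taken from the library docs; the graph and all
    sizes below are made up for illustration.

    >>> import dgl
    >>> import torch
    >>> g = dgl.graph(([0, 1, 2], [1, 2, 0]))
    >>> # assume two categorical edge features with vocabulary sizes 4 and 3
    >>> layer = GINLayer(num_edge_emb_list=[4, 3], emb_dim=16)
    >>> node_feats = torch.randn(3, 16)
    >>> edge_feats = [torch.LongTensor([0, 1, 2]), torch.LongTensor([0, 0, 1])]
    >>> layer(g, node_feats, edge_feats).shape
    torch.Size([3, 16])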
"""

    def __init__(self, num_edge_emb_list, emb_dim, batch_norm=True, activation=None):
super(GINLayer, self).__init__()
self.mlp = nn.Sequential(
nn.Linear(emb_dim, 2 * emb_dim),
nn.ReLU(),
nn.Linear(2 * emb_dim, emb_dim)
)
self.edge_embeddings = nn.ModuleList()
for num_emb in num_edge_emb_list:
emb_module = nn.Embedding(num_emb, emb_dim)
self.edge_embeddings.append(emb_module)
if batch_norm:
self.bn = nn.BatchNorm1d(emb_dim)
else:
self.bn = None
self.activation = activation
self.reset_parameters()

    def reset_parameters(self):
"""Reinitialize model parameters."""
for layer in self.mlp:
if isinstance(layer, nn.Linear):
layer.reset_parameters()
for emb_module in self.edge_embeddings:
nn.init.xavier_uniform_(emb_module.weight.data)
if self.bn is not None:
self.bn.reset_parameters()

    def forward(self, g, node_feats, categorical_edge_feats):
        """Update node representations.

        Parameters
        ----------
g : DGLGraph
DGLGraph for a batch of graphs
node_feats : FloatTensor of shape (N, emb_dim)
* Input node features
* N is the total number of nodes in the batch of graphs
* emb_dim is the input node feature size, which must match emb_dim in initialization
categorical_edge_feats : list of LongTensor of shape (E)
* Input categorical edge features
* len(categorical_edge_feats) should be the same as len(self.edge_embeddings)
* E is the total number of edges in the batch of graphs

        Returns
        -------
node_feats : float32 tensor of shape (N, emb_dim)
Output node representations
"""
        # Embed each categorical edge feature and sum the embeddings
        edge_embeds = []
        for i, feats in enumerate(categorical_edge_feats):
            edge_embeds.append(self.edge_embeddings[i](feats))
        edge_embeds = torch.stack(edge_embeds, dim=0).sum(0)

        # Message passing: each edge sends its source node feature plus its
        # edge embedding and each node sums up the incoming messages
        g = g.local_var()
        g.ndata['feat'] = node_feats
        g.edata['feat'] = edge_embeds
        g.update_all(fn.u_add_e('feat', 'feat', 'm'), fn.sum('m', 'feat'))
        node_feats = self.mlp(g.ndata.pop('feat'))
if self.bn is not None:
node_feats = self.bn(node_feats)
if self.activation is not None:
node_feats = self.activation(node_feats)
return node_feats


class GIN(nn.Module):
r"""Graph Isomorphism Network from `Strategies for
Pre-training Graph Neural Networks <https://arxiv.org/abs/1905.12265>`__
This module is for updating node representations only.
Parameters
----------
num_node_emb_list : list of int
num_node_emb_list[i] gives the number of items to embed for the
        i-th categorical node feature variable. E.g. num_node_emb_list[0] can be
the number of atom types and num_node_emb_list[1] can be the number of
atom chirality types.
num_edge_emb_list : list of int
num_edge_emb_list[i] gives the number of items to embed for the
        i-th categorical edge feature variable. E.g. num_edge_emb_list[0] can be
the number of bond types and num_edge_emb_list[1] can be the number of
bond direction types.
num_layers : int
Number of GIN layers to use. Default to 5.
emb_dim : int
The size of each embedding vector. Default to 300.
    JK : str
        JK for jumping knowledge as in `Representation Learning on Graphs with
        Jumping Knowledge Networks <https://arxiv.org/abs/1806.03536>`__. It decides
        how to combine the node representations across all GIN layers for the final
        output. There are four options for this argument: ``'concat'``, ``'last'``,
        ``'max'`` and ``'sum'``. Default to 'last'.

* ``'concat'``: concatenate the output node representations from all GIN layers
* ``'last'``: use the node representations from the last GIN layer
* ``'max'``: apply max pooling to the node representations across all GIN layers
* ``'sum'``: sum the output node representations from all GIN layers
dropout : float
        Dropout to apply to the output of each GIN layer. Default to 0.5.
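
    Examples
    --------

    A minimal usage sketch, not taken from the library docs; the graph,
    vocabulary sizes and feature values below are made up for illustration.

    >>> import dgl
    >>> import torch
    >>> g = dgl.graph(([0, 1, 2], [1, 2, 0]))
    >>> # e.g. 120 atom types and 3 chirality types for nodes,
    >>> # 6 bond types and 3 bond direction types for edges
    >>> model = GIN(num_node_emb_list=[120, 3], num_edge_emb_list=[6, 3],
    ...             num_layers=2, emb_dim=8)
    >>> node_feats = [torch.LongTensor([0, 1, 2]), torch.LongTensor([0, 0, 1])]
    >>> edge_feats = [torch.LongTensor([0, 1, 2]), torch.LongTensor([0, 0, 1])]
    >>> model(g, node_feats, edge_feats).shape
    torch.Size([3, 8])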
"""

    def __init__(self, num_node_emb_list, num_edge_emb_list,
num_layers=5, emb_dim=300, JK='last', dropout=0.5):
super(GIN, self).__init__()
self.num_layers = num_layers
self.JK = JK
self.dropout = nn.Dropout(dropout)
if num_layers < 2:
raise ValueError('Number of GNN layers must be greater '
'than 1, got {:d}'.format(num_layers))
self.node_embeddings = nn.ModuleList()
for num_emb in num_node_emb_list:
emb_module = nn.Embedding(num_emb, emb_dim)
self.node_embeddings.append(emb_module)
self.gnn_layers = nn.ModuleList()
for layer in range(num_layers):
            if layer == num_layers - 1:
                # No activation after the final GIN layer
                self.gnn_layers.append(GINLayer(num_edge_emb_list, emb_dim))
else:
self.gnn_layers.append(GINLayer(num_edge_emb_list, emb_dim, activation=F.relu))
self.reset_parameters()

    def reset_parameters(self):
"""Reinitialize model parameters."""
for emb_module in self.node_embeddings:
nn.init.xavier_uniform_(emb_module.weight.data)
for layer in self.gnn_layers:
layer.reset_parameters()

    def forward(self, g, categorical_node_feats, categorical_edge_feats):
"""Update node representations
Parameters
----------
g : DGLGraph
DGLGraph for a batch of graphs
categorical_node_feats : list of LongTensor of shape (N)
* Input categorical node features
* len(categorical_node_feats) should be the same as len(self.node_embeddings)
* N is the total number of nodes in the batch of graphs
categorical_edge_feats : list of LongTensor of shape (E)
* Input categorical edge features
            * len(categorical_edge_feats) should be the same as
              len(num_edge_emb_list) used in initialization
* E is the total number of edges in the batch of graphs

        Returns
        -------
        final_node_feats : float32 tensor of shape (N, M)
            Output node representations, where N is the total number of nodes and
            M is the output size. In particular, M will be emb_dim * (num_layers + 1)
            if self.JK == 'concat' and emb_dim otherwise.
"""
        # Embed each categorical node feature and sum the embeddings
        node_embeds = []
        for i, feats in enumerate(categorical_node_feats):
            node_embeds.append(self.node_embeddings[i](feats))
        node_embeds = torch.stack(node_embeds, dim=0).sum(0)

        # Apply the GIN layers, keeping the output of every layer
        all_layer_node_feats = [node_embeds]
        for layer in range(self.num_layers):
            node_feats = self.gnn_layers[layer](g, all_layer_node_feats[layer],
                                                categorical_edge_feats)
            node_feats = self.dropout(node_feats)
            all_layer_node_feats.append(node_feats)

        # Combine the per-layer node representations as specified by JK
        if self.JK == 'concat':
final_node_feats = torch.cat(all_layer_node_feats, dim=1)
elif self.JK == 'last':
final_node_feats = all_layer_node_feats[-1]
elif self.JK == 'max':
all_layer_node_feats = [h.unsqueeze(0) for h in all_layer_node_feats]
final_node_feats = torch.max(torch.cat(all_layer_node_feats, dim=0), dim=0)[0]
elif self.JK == 'sum':
all_layer_node_feats = [h.unsqueeze(0) for h in all_layer_node_feats]
final_node_feats = torch.sum(torch.cat(all_layer_node_feats, dim=0), dim=0)
        else:
            raise ValueError("Expect self.JK to be 'concat', 'last', "
                             "'max' or 'sum', got {}".format(self.JK))
return final_node_feats