# Source code for dgllife.model.gnn.graphsage

# -*- coding: utf-8 -*-
#
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
#
# GraphSAGE
# pylint: disable= no-member, arguments-differ, invalid-name

import torch.nn as nn
import torch.nn.functional as F

from dgl.nn.pytorch import SAGEConv

__all__ = ['GraphSAGE']

# pylint: disable=W0221, C0103
class GraphSAGE(nn.Module):
    r"""GraphSAGE from `Inductive Representation Learning on Large Graphs
    <https://arxiv.org/abs/1706.02216>`__

    Parameters
    ----------
    in_feats : int
        Number of input node features.
    hidden_feats : list of int
        ``hidden_feats[i]`` gives the size of node representations after the i-th GraphSAGE
        layer. ``len(hidden_feats)`` equals the number of GraphSAGE layers. By default, we
        use ``[64, 64]``.
    activation : list of activation functions or None
        If not None, ``activation[i]`` gives the activation function to be used for the
        i-th GraphSAGE layer. ``len(activation)`` equals the number of GraphSAGE layers.
        By default, ReLU is applied for all GraphSAGE layers.
    dropout : list of float or None
        ``dropout[i]`` is the feature dropout probability of the i-th GraphSAGE layer
        (forwarded as ``feat_drop`` to DGL's ``SAGEConv``, which applies it to the
        layer's input features). ``len(dropout)`` equals the number of GraphSAGE layers.
        By default, no dropout is performed for all layers.
    aggregator_type : list of str
        ``aggregator_type[i]`` decides the aggregator type for the i-th GraphSAGE layer,
        which can be one of ``'mean'``, ``'gcn'``, ``'pool'``, ``'lstm'``. By default, we
        use ``'mean'`` for all layers.
    """
    def __init__(self, in_feats, hidden_feats=None, activation=None, dropout=None,
                 aggregator_type=None):
        super().__init__()

        if hidden_feats is None:
            hidden_feats = [64, 64]
        n_layers = len(hidden_feats)
        if activation is None:
            activation = [F.relu for _ in range(n_layers)]
        if dropout is None:
            dropout = [0. for _ in range(n_layers)]
        if aggregator_type is None:
            aggregator_type = ['mean' for _ in range(n_layers)]

        # Every per-layer option must describe the same number of layers.
        lengths = [len(hidden_feats), len(activation),
                   len(dropout), len(aggregator_type)]
        assert len(set(lengths)) == 1, 'Expect the lengths of hidden_feats, activation, ' \
                                       'dropout and aggregator_type to be the same, ' \
                                       'got {}'.format(lengths)
        self.hidden_feats = hidden_feats

        self.gnn_layers = nn.ModuleList()
        for out_feats, agg, drop, act in zip(hidden_feats, aggregator_type,
                                             dropout, activation):
            # Pass feat_drop/activation by keyword: SAGEConv's positional order is
            # (in_feats, out_feats, aggregator_type, feat_drop, bias, norm, activation),
            # so a fifth positional argument would silently be taken as ``bias`` and
            # the activation would never be applied.
            self.gnn_layers.append(SAGEConv(in_feats, out_feats, agg,
                                            feat_drop=drop, activation=act))
            # The next layer consumes this layer's output size.
            in_feats = out_feats

    def reset_parameters(self):
        """Reinitialize model parameters."""
        for gnn in self.gnn_layers:
            gnn.reset_parameters()

    def forward(self, g, feats):
        """Update node representations.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs
        feats : FloatTensor of shape (N, M1)
            * N is the total number of nodes in the batch of graphs
            * M1 is the input node feature size, which equals in_feats in initialization

        Returns
        -------
        feats : FloatTensor of shape (N, M2)
            * N is the total number of nodes in the batch of graphs
            * M2 is the output node representation size, which equals
              hidden_feats[-1] in initialization.
        """
        # Apply the GraphSAGE layers in sequence, each one feeding the next.
        for gnn in self.gnn_layers:
            feats = gnn(g, feats)
        return feats