Source code for dgllife.model.gnn.nf

# -*- coding: utf-8 -*-
#
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Molecular Fingerprint
# pylint: disable= no-member, arguments-differ, invalid-name

import dgl
import dgl.function as fn
import torch.nn as nn
import torch.nn.functional as F

from torch.nn import ModuleList

__all__ = ['NFGNN']

class NFLayer(nn.Module):
    r"""Single convolutional layer from `Convolutional Networks on Graphs for Learning Molecular
    Fingerprints <https://arxiv.org/abs/1509.09292>`__

    NF stands for neural fingerprint.

    Parameters
    ----------
    in_feats : int
        Number of input node features.
    out_feats : int
        Number of output node features.
    max_degree : int, optional
        The maximum node degree to consider when updating weights. Default to be 10.
    activation : activation function, optional
        Default to be None.
    batchnorm : bool, optional
        Whether to apply batch normalization to the output. Default to be True.
    dropout : float, optional
        The probability of dropout for the output. Default to be 0.
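
    Examples
    --------
    A minimal sketch on a toy 3-node cycle (illustrative sizes; any DGLGraph
    with a matching input feature size works):

    >>> import dgl
    >>> import torch
    >>> g = dgl.graph(([0, 1, 2], [1, 2, 0]))
    >>> g_self = dgl.add_self_loop(g)
    >>> feats = torch.randn(3, 16)
    >>> layer = NFLayer(16, 32, activation=torch.relu)
    >>> _ = layer.eval()
    >>> out, deg, max_deg, membership = layer(g, g_self, feats)
    >>> out.shape
    torch.Size([3, 32])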
    """
    def __init__(self, in_feats, out_feats, max_degree=10, activation=None, batchnorm=True,
                 dropout=0.):
        super(NFLayer, self).__init__()

        self.in_feats = in_feats
        self.out_feats = out_feats

        self.activation = activation
        self.max_degree = max_degree
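        # Degree-specific weights: a separate linear layer for degree-0 nodes
        # and, for each degree d in [1, max_degree], one layer applied to the
        # summed neighbor features (lins_l) and one, without bias, applied to
        # the node's own features (lins_r)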
        self.lin_zero_deg = nn.Linear(in_feats, out_feats)
        self.lins_l = ModuleList([nn.Linear(in_feats, out_feats)
                                  for _ in range(1, max_degree + 1)])
        self.lins_r = ModuleList([nn.Linear(in_feats, out_feats, bias=False)
                                  for _ in range(1, max_degree + 1)])

        if batchnorm:
            self.bn = nn.BatchNorm1d(out_feats)
        else:
            self.bn = None
        if dropout > 0:
            self.dropout = nn.Dropout(dropout)
        else:
            self.dropout = None

        self.reset_parameters()

    def reset_parameters(self):
        """Reinitialize model parameters."""
        self.lin_zero_deg.reset_parameters()
        for lin in self.lins_l:
            lin.reset_parameters()
        for lin in self.lins_r:
            lin.reset_parameters()

        if self.bn is not None:
            self.bn.reset_parameters()

    def forward(self, g, g_self, feats, deg=None, max_deg=None, deg_membership=None):
        """Update node representations.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs
        g_self : DGLGraph
            DGLGraph for a batch of graphs with self loops added
        feats : FloatTensor of shape (N, M1)
            * N is the total number of nodes in the batch of graphs
            * M1 is the input node feature size, which must match in_feats in initialization
        deg : LongTensor of shape (N), optional
            In-degrees of the nodes in the graph, clamped at :attr:`max_degree`.
            If None, this will be computed from :attr:`g`; passing the value
            returned by a previous call allows reusing the computation.
        max_deg : int, optional
            Max value in :attr:`deg`.
        deg_membership : list of LongTensor, optional
            deg_membership[i] gives a 1D LongTensor for the IDs of nodes with in-degree i.

        Returns
        -------
        new_feats : FloatTensor of shape (N, M2)
            * M2 is the output node feature size, which must match out_feats in initialization
        deg : LongTensor of shape (N)
            In-degrees of the nodes in the graph, clamped at :attr:`max_degree`.
        max_deg : int
            Max value in :attr:`deg`.
        deg_membership : list of LongTensor
            deg_membership[i] gives a 1D LongTensor for the IDs of nodes with in-degree i.
        """
        if deg is None:
            deg = g.in_degrees().to(feats.device)
            deg = deg.clamp(max=self.max_degree)
            assert max_deg is None, 'Expect max_deg to be None when deg is None.'

        if max_deg is None:
            max_deg = deg.max().cpu().item()

        if deg_membership is None:
            # Bucket node IDs by (clamped) in-degree so that each group can be
            # transformed with degree-specific weights
            deg_membership = [
                (deg == i).nonzero(as_tuple=False).view(-1) for i in range(max_deg + 1)
            ]

        with g.local_scope():
            # Message passing: sum the features of incoming neighbors
            g.ndata['h'] = feats
            g.update_all(message_func=fn.copy_u('h', 'm'), reduce_func=fn.sum('m', 'h'))
            h = g.ndata.pop('h')

            out = h.new_empty(list(h.size())[:-1] + [self.out_feats])

            # Case for degree 0
            idx = deg_membership[0]
            deg_out = self.lin_zero_deg(feats.index_select(0, idx))
            out.index_copy_(0, idx, deg_out)

            # Degree-based transformation
            for i, (lin_l, lin_r) in enumerate(zip(self.lins_l, self.lins_r)):
                current_deg = i + 1

                if current_deg > max_deg:
                    break

                idx = deg_membership[current_deg]
                deg_out = lin_l(h.index_select(0, idx)) + lin_r(feats.index_select(0, idx))
                out.index_copy_(0, idx, deg_out)

            if self.activation is not None:
                out = self.activation(out)

            if self.bn is not None:
                out = self.bn(out)

            if self.dropout is not None:
                out = self.dropout(out)

        with g_self.local_scope():
            # Max pooling over each node's neighborhood; the self loops in
            # g_self ensure that a node's own representation is included
            g_self.ndata['h'] = out
            g_self.update_all(message_func=fn.copy_u('h', 'm'), reduce_func=fn.max('m', 'h'))
            out = g_self.ndata.pop('h')

        return out, deg, max_deg, deg_membership

class NFGNN(nn.Module):
    r"""GNN from `Convolutional Networks on Graphs for Learning Molecular
    Fingerprints <https://arxiv.org/abs/1509.09292>`__

    NF stands for neural fingerprint.

    Parameters
    ----------
    in_feats : int
        Number of input node features.
    hidden_feats : list of int, optional
        ``hidden_feats[i]`` gives the size of node representations after the i-th NF layer.
        ``len(hidden_feats)`` equals the number of NF layers. By default, we use ``[64, 64]``.
    max_degree : int, optional
        The maximum node degree to consider when updating weights. Default to be 10.
    activation : list of activation functions or None, optional
        If not None, ``activation[i]`` gives the activation function to be used for
        the i-th NF layer. ``len(activation)`` equals the number of NF layers.
        By default, ReLU is applied for all NF layers.
    batchnorm : list of bool, optional
        ``batchnorm[i]`` decides if batch normalization is to be applied on the output of
        the i-th NF layer. ``len(batchnorm)`` equals the number of NF layers. By default,
        batch normalization is applied for all NF layers.
    dropout : list of float, optional
        ``dropout[i]`` decides the dropout to be applied on the output of the i-th NF layer.
        ``len(dropout)`` equals the number of NF layers. By default, no dropout is applied
        for any NF layer.
    """
    def __init__(self, in_feats, hidden_feats=None, max_degree=10, activation=None,
                 batchnorm=None, dropout=None):
        super(NFGNN, self).__init__()

        if hidden_feats is None:
            hidden_feats = [64, 64]
        n_layers = len(hidden_feats)
        if activation is None:
            activation = [F.relu] * n_layers
        if batchnorm is None:
            batchnorm = [True] * n_layers
        if dropout is None:
            dropout = [0.] * n_layers
        lengths = [len(hidden_feats), len(activation), len(batchnorm), len(dropout)]
        assert len(set(lengths)) == 1, 'Expect the lengths of hidden_feats, activation, ' \
                                       'batchnorm, and dropout to be the same, ' \
                                       'got {}'.format(lengths)

        self.hidden_feats = hidden_feats
        self.gnn_layers = nn.ModuleList()
        for i in range(n_layers):
            self.gnn_layers.append(NFLayer(in_feats, hidden_feats[i], max_degree,
                                           activation[i], batchnorm[i], dropout[i]))
            in_feats = hidden_feats[i]
    def reset_parameters(self):
        """Reinitialize model parameters."""
        for gnn in self.gnn_layers:
            gnn.reset_parameters()
    def forward(self, g, feats):
        """Update node representations.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs
        feats : FloatTensor of shape (N, M1)
            * N is the total number of nodes in the batch of graphs
            * M1 is the input node feature size, which equals in_feats in initialization

        Returns
        -------
        feats : FloatTensor of shape (N, M2)
            * N is the total number of nodes in the batch of graphs
            * M2 is the output node representation size, which equals
              hidden_feats[-1] in initialization
        """
        deg = None
        max_deg = None
        deg_membership = None
        g_self = dgl.add_self_loop(g)
        for gnn in self.gnn_layers:
            feats, deg, max_deg, deg_membership = gnn(g, g_self, feats, deg,
                                                      max_deg, deg_membership)

        return feats
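
# A minimal end-to-end sketch (illustrative toy graph and feature sizes):
if __name__ == '__main__':
    import torch

    # A 3-node path with edges in both directions
    g = dgl.graph(([0, 1, 1, 2], [1, 0, 2, 1]))
    feats = torch.randn(3, 16)
    model = NFGNN(in_feats=16, hidden_feats=[64, 64])
    model.eval()  # use BatchNorm running statistics instead of batch statistics
    out = model(g, feats)
    print(out.shape)  # torch.Size([3, 64])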