# Source code for dgllife.model.gnn.wln

# -*- coding: utf-8 -*-
# Copyright, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
# pylint: disable= no-member, arguments-differ, invalid-name

import math
import dgl.function as fn
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.nn import Parameter

__all__ = ['WLN']

class WLNLinear(nn.Module):
    r"""Linear layer for WLN

    Let stddev be

    .. math::
        \min(\frac{1.0}{\sqrt{\text{in\_feats}}}, 0.1)

    The weight of the linear layer is initialized from a normal distribution
    with mean 0 and std as specified in stddev. The bias, when present, is
    initialized to zero.

    Parameters
    ----------
    in_feats : int
        Size for the input.
    out_feats : int
        Size for the output.
    bias : bool
        Whether bias will be added to the output. Default to True.
    """
    def __init__(self, in_feats, out_feats, bias=True):
        super(WLNLinear, self).__init__()

        self.in_feats = in_feats
        self.out_feats = out_feats
        # Allocated uninitialized here; reset_parameters() below fills the values.
        self.weight = Parameter(torch.empty(out_feats, in_feats))
        if bias:
            self.bias = Parameter(torch.empty(out_feats))
        else:
            # Register the name so that ``self.bias`` exists and is None,
            # which is what forward() and extra_repr() check for.
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Initialize model parameters."""
        stddev = min(1.0 / math.sqrt(self.in_feats), 0.1)
        nn.init.normal_(self.weight, std=stddev)
        if self.bias is not None:
            nn.init.constant_(self.bias, 0.0)

    def forward(self, feats):
        """Applies the layer.

        Parameters
        ----------
        feats : float32 tensor of shape (N, *, in_feats)
            N for the number of samples, * for any additional dimensions.

        Returns
        -------
        float32 tensor of shape (N, *, out_feats)
            Result of the layer.
        """
        return F.linear(feats, self.weight, self.bias)

    def extra_repr(self):
        """Return a description of the layer."""
        return 'in_feats={}, out_feats={}, bias={}'.format(
            self.in_feats, self.out_feats, self.bias is not None
        )

class WLN(nn.Module):
    """Weisfeiler-Lehman Network (WLN)

    WLN is introduced in "Predicting Organic Reaction Outcomes with
    Weisfeiler-Lehman Network".

    This class performs message passing and updates node representations.

    Parameters
    ----------
    node_in_feats : int
        Size for the input node features.
    edge_in_feats : int
        Size for the input edge features.
    node_out_feats : int
        Size for the output node representations. Default to 300.
    n_layers : int
        Number of times for message passing. Note that same parameters
        are shared across n_layers message passing. Default to 3.
    project_in_feats : bool
        Whether to project input node features. If this is False, we expect
        node_in_feats to be the same as node_out_feats. Default to True.
    set_comparison : bool
        Whether to perform final node representation update mimicking
        set comparison. Default to True.
    """
    def __init__(self,
                 node_in_feats,
                 edge_in_feats,
                 node_out_feats=300,
                 n_layers=3,
                 project_in_feats=True,
                 set_comparison=True):
        super(WLN, self).__init__()

        self.n_layers = n_layers
        self.project_in_feats = project_in_feats
        if project_in_feats:
            self.project_node_in_feats = nn.Sequential(
                WLNLinear(node_in_feats, node_out_feats, bias=False),
                nn.ReLU()
            )
        else:
            assert node_in_feats == node_out_feats, \
                'Expect input node features to have the same size as that of output ' \
                'node features, got {:d} and {:d}'.format(node_in_feats, node_out_feats)
            self.project_node_in_feats = None

        # Maps [source node feature || edge feature] to a message.
        self.project_concatenated_messages = nn.Sequential(
            WLNLinear(edge_in_feats + node_out_feats, node_out_feats),
            nn.ReLU()
        )
        # Maps [node feature || aggregated messages] to the new node feature.
        self.get_new_node_feats = nn.Sequential(
            WLNLinear(2 * node_out_feats, node_out_feats),
            nn.ReLU()
        )

        self.set_comparison = set_comparison
        if set_comparison:
            self.project_edge_messages = WLNLinear(edge_in_feats, node_out_feats, bias=False)
            self.project_node_messages = WLNLinear(node_out_feats, node_out_feats, bias=False)
            self.project_self = WLNLinear(node_out_feats, node_out_feats, bias=False)

    def reset_parameters(self):
        """Initialize model parameters."""
        if self.project_node_in_feats is not None:
            self.project_node_in_feats[0].reset_parameters()
        self.project_concatenated_messages[0].reset_parameters()
        self.get_new_node_feats[0].reset_parameters()
        if self.set_comparison:
            self.project_edge_messages.reset_parameters()
            self.project_node_messages.reset_parameters()
            self.project_self.reset_parameters()

    def forward(self, g, node_feats, edge_feats):
        """Performs message passing and updates node representations.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs
        node_feats : float32 tensor of shape (V, node_in_feats)
            Input node features. V for the number of nodes.
        edge_feats : float32 tensor of shape (E, edge_in_feats)
            Input edge features. E for the number of edges.

        Returns
        -------
        float32 tensor of shape (V, node_out_feats)
            Updated node representations.
        """
        if self.project_in_feats:
            node_feats = self.project_node_in_feats(node_feats)

        # Work on a local view so the features written below do not leak
        # into the caller's graph.
        g = g.local_var()
        # The graph structure does not change during message passing, so the
        # edge check can be done once.
        has_edges = g.num_edges() > 0

        for _ in range(self.n_layers):
            if has_edges:
                # The following lines do not work for a graph without edges.
                g.ndata['hv'] = node_feats
                # copy_u / copy_e are the current names of the copy_src /
                # copy_edge built-ins.
                g.apply_edges(fn.copy_u('hv', 'he_src'))
                concat_edge_feats = torch.cat([g.edata['he_src'], edge_feats], dim=1)
                g.edata['he'] = self.project_concatenated_messages(concat_edge_feats)
                g.update_all(fn.copy_e('he', 'm'), fn.sum('m', 'hv_new'))
                node_feats = self.get_new_node_feats(
                    torch.cat([node_feats, g.ndata['hv_new']], dim=1))
            else:
                # If we don't have edges, above formula becomes very simple.
                # The sum over the neighbors is zero then
                # (cf. section S2.2 of the paper's supplementary information).
                node_feats = self.get_new_node_feats(
                    torch.cat([node_feats, torch.zeros_like(node_feats)], dim=1))

        if not self.set_comparison:
            return node_feats

        if has_edges:
            # The following lines don't work for a graph without edges
            g.ndata['hv'] = self.project_node_messages(node_feats)
            g.edata['he'] = self.project_edge_messages(edge_feats)
            g.update_all(fn.u_mul_e('hv', 'he', 'm'), fn.sum('m', 'h_nbr'))
            h_self = self.project_self(node_feats)  # (V, node_out_feats)
            return g.ndata['h_nbr'] * h_self

        # If the graph has no edges, the formula becomes very simple.
        # The sum over the neighbors is zero then
        # (cf. section S2.5 of the paper's supplementary information).
        # new_zeros keeps the dtype and device of node_feats.
        return node_feats.new_zeros((g.num_nodes(), self.project_self.out_feats))