# -*- coding: utf-8 -*-
#
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
#
# SchNet
# pylint: disable=C0103, C0111, W0621, W0221, E1102, E1101
import numpy as np
import torch
import torch.nn as nn
from dgl.nn.pytorch import CFConv
__all__ = ['SchNetGNN']
class RBFExpansion(nn.Module):
    r"""Expand distances between nodes by radial basis functions.

    .. math::
        \exp(- \gamma * ||d - \mu||^2)

    where :math:`d` is the distance between two nodes and :math:`\mu` helps centralize
    the distances. We use multiple centers evenly distributed in the range of
    :math:`[\text{low}, \text{high}]` (both ends included), with the difference between
    two adjacent centers being approximately :math:`\text{gap}`.

    The number of centers is :math:`\lceil (\text{high} - \text{low}) / \text{gap} \rceil`.
    The centers are placed with ``numpy.linspace``, so the actual spacing is
    :math:`(\text{high} - \text{low}) / (\text{num\_centers} - 1)`, which can differ
    slightly from ``gap``. Choosing fewer centers corresponds to reducing the resolution
    of the filter.

    Parameters
    ----------
    low : float
        Smallest center. Default to 0.
    high : float
        Largest center. Default to 30.
    gap : float
        Approximate difference between two adjacent centers. :math:`\gamma` will be
        computed as the reciprocal of gap. Default to 0.1.
    """
    def __init__(self, low=0., high=30., gap=0.1):
        super(RBFExpansion, self).__init__()
        # Remember the hyperparameters so reset_parameters can rebuild the centers.
        self.low = low
        self.high = high
        self.gap = gap
        # Fixed (non-trainable) centers, registered as a Parameter with requires_grad=False.
        self.centers = nn.Parameter(
            torch.tensor(self._compute_centers(low, high, gap)).float(), requires_grad=False)
        self.gamma = 1 / gap

    @staticmethod
    def _compute_centers(low, high, gap):
        """Return the float64 numpy array of RBF centers for the given hyperparameters."""
        num_centers = int(np.ceil((high - low) / gap))
        return np.linspace(low, high, num_centers)

    def reset_parameters(self):
        """Reinitialize model parameters.

        Recomputes the centers from ``low``, ``high`` and ``gap`` and copies them into the
        existing ``centers`` Parameter in place. The Parameter object, its device and its
        dtype are kept, so a module moved with ``.to(...)``/``.double()`` stays consistent.
        """
        fresh = torch.tensor(self._compute_centers(self.low, self.high, self.gap))
        with torch.no_grad():
            # copy_ casts to the destination's dtype/device.
            self.centers.copy_(fresh)

    def forward(self, edge_dists):
        """Expand distances.

        Parameters
        ----------
        edge_dists : float32 tensor of shape (E, 1)
            Distances between end nodes of edges, E for the number of edges.

        Returns
        -------
        float32 tensor of shape (E, len(self.centers))
            Expanded distances.
        """
        # (E, 1) minus (num_centers,) broadcasts to (E, num_centers).
        radial = edge_dists - self.centers
        coef = - self.gamma
        return torch.exp(coef * (radial ** 2))
class Interaction(nn.Module):
    """Building block for SchNet.

    SchNet is introduced in `SchNet: A continuous-filter convolutional neural network for
    modeling quantum interactions <https://arxiv.org/abs/1706.08566>`__.

    The layer applies a continuous-filter convolution (DGL ``CFConv``) that mixes node and
    edge features, then passes the result through a linear projection of the same width,
    producing updated node representations.

    Parameters
    ----------
    node_feats : int
        Size for the input and output node features.
    edge_in_feats : int
        Size for the input edge features.
    hidden_feats : int
        Size for hidden representations.
    """
    def __init__(self, node_feats, edge_in_feats, hidden_feats):
        super(Interaction, self).__init__()
        # Input and output node widths of the convolution are both node_feats.
        self.conv = CFConv(node_feats, edge_in_feats, hidden_feats, node_feats)
        self.project_out = nn.Linear(node_feats, node_feats)

    def reset_parameters(self):
        """Reinitialize model parameters.

        Resets, in order: the Linear layers inside ``conv.project_edge``,
        ``conv.project_node``, ``conv.project_out[0]`` and finally ``project_out``.
        """
        cf = self.conv
        to_reset = [sub for sub in cf.project_edge if isinstance(sub, nn.Linear)]
        to_reset += [cf.project_node, cf.project_out[0], self.project_out]
        for linear in to_reset:
            linear.reset_parameters()

    def forward(self, g, node_feats, edge_feats):
        """Performs message passing and updates node representations.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs.
        node_feats : float32 tensor of shape (V, node_feats)
            Input node features, V for the number of nodes.
        edge_feats : float32 tensor of shape (E, edge_in_feats)
            Input edge features, E for the number of edges.

        Returns
        -------
        float32 tensor of shape (V, node_feats)
            Updated node representations.
        """
        return self.project_out(self.conv(g, node_feats, edge_feats))
class SchNetGNN(nn.Module):
    """SchNet.

    SchNet is introduced in `SchNet: A continuous-filter convolutional neural network for
    modeling quantum interactions <https://arxiv.org/abs/1706.08566>`__.

    This class performs message passing in SchNet and returns the updated node representations.

    Parameters
    ----------
    node_feats : int
        Size for node representations to learn. Default to 64.
    hidden_feats : list of int
        ``hidden_feats[i]`` gives the size of hidden representations for the i-th interaction
        layer. ``len(hidden_feats)`` equals the number of interaction layers.
        Default to ``[64, 64, 64]``.
    num_node_types : int
        Number of node types to embed. Default to 100.
    cutoff : float
        Largest center in RBF expansion. Default to 30.
    gap : float
        Difference between two adjacent centers in RBF expansion. Default to 0.1.
    """
    def __init__(self, node_feats=64, hidden_feats=None, num_node_types=100, cutoff=30., gap=0.1):
        super(SchNetGNN, self).__init__()

        # None sentinel avoids a shared mutable default list.
        if hidden_feats is None:
            hidden_feats = [64, 64, 64]

        self.embed = nn.Embedding(num_node_types, node_feats)
        self.rbf = RBFExpansion(high=cutoff, gap=gap)
        # The number of RBF centers is the edge feature size seen by every interaction layer.
        num_rbf_feats = len(self.rbf.centers)
        self.gnn_layers = nn.ModuleList(
            [Interaction(node_feats, num_rbf_feats, layer_hidden) for layer_hidden in hidden_feats])

    def reset_parameters(self):
        """Reinitialize model parameters: the embedding, the RBF expansion and every
        interaction layer."""
        self.embed.reset_parameters()
        self.rbf.reset_parameters()
        for layer in self.gnn_layers:
            layer.reset_parameters()

    def forward(self, g, node_types, edge_dists):
        """Performs message passing and updates node representations.

        Parameters
        ----------
        g : DGLGraph
            DGLGraph for a batch of graphs.
        node_types : int64 tensor of shape (V)
            Node types to embed, V for the number of nodes.
        edge_dists : float32 tensor of shape (E, 1)
            Distances between end nodes of edges, E for the number of edges.

        Returns
        -------
        node_feats : float32 tensor of shape (V, node_feats)
            Updated node representations.
        """
        node_feats = self.embed(node_types)
        # The distance expansion is computed once and shared by all interaction layers.
        expanded_dists = self.rbf(edge_dists)
        for gnn in self.gnn_layers:
            node_feats = gnn(g, node_feats, expanded_dists)
        return node_feats