# -*- coding: utf-8 -*-
#
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
#
# pylint: disable=C0103, W0622, R1710, W0104, E1101, W0221, C0411
# Learning Deep Generative Models of Graphs
# https://arxiv.org/pdf/1803.03324.pdf
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from functools import partial
from rdkit import Chem
from torch.distributions import Categorical
__all__ = ['DGMG']
class MoleculeEnv(object):
"""MDP environment for generating molecules.
Parameters
----------
atom_types : list
E.g. ['C', 'N']
bond_types : list
E.g. [Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE,
Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC]
"""
def __init__(self, atom_types, bond_types):
super(MoleculeEnv, self).__init__()
self.atom_types = atom_types
self.bond_types = bond_types
self.atom_type_to_id = dict()
self.bond_type_to_id = dict()
for id, a_type in enumerate(atom_types):
self.atom_type_to_id[a_type] = id
for id, b_type in enumerate(bond_types):
self.bond_type_to_id[b_type] = id
def get_decision_sequence(self, mol, atom_order):
"""Extract a decision sequence with which DGMG can generate the
molecule with a specified atom order.
Parameters
----------
mol : Chem.rdchem.Mol
atom_order : list
Specifies a mapping between the original atom
indices and the new atom indices. In particular,
atom_order[i] is re-labeled as i.
Returns
-------
decisions : list
            decisions[i] is a 2-tuple (t, j)
            - If t = 0, j specifies either the type of the atom to add,
              i.e. self.atom_types[j], or termination with j = len(self.atom_types)
            - If t = 1, j specifies either the type of the bond to add,
              i.e. self.bond_types[j], or termination with j = len(self.bond_types)
            - If t = 2, j specifies the destination atom id for the bond to add.
              With the formulation of DGMG, atom j must have been created before
              this decision.
"""
decisions = []
old2new = dict()
for new_id, old_id in enumerate(atom_order):
atom = mol.GetAtomWithIdx(old_id)
a_type = atom.GetSymbol()
decisions.append((0, self.atom_type_to_id[a_type]))
for bond in atom.GetBonds():
u = bond.GetBeginAtomIdx()
v = bond.GetEndAtomIdx()
if v == old_id:
u, v = v, u
if v in old2new:
decisions.append((1, self.bond_type_to_id[bond.GetBondType()]))
decisions.append((2, old2new[v]))
decisions.append((1, len(self.bond_types)))
old2new[old_id] = new_id
decisions.append((0, len(self.atom_types)))
return decisions
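    # Illustrative example (an assumption for demonstration, not from the
    # original source): with atom_types=['C', 'O'] and
    # bond_types=[Chem.rdchem.BondType.SINGLE], methanol ('CO') under
    # atom_order=[0, 1] yields the decision sequence
    # [(0, 0),   # add a C atom
    #  (1, 1),   # stop adding bonds to it
    #  (0, 1),   # add an O atom
    #  (1, 0),   # add a single bond from the O atom
    #  (2, 0),   # connect it to atom 0 (the C atom)
    #  (1, 1),   # stop adding bonds
    #  (0, 2)]   # stop adding atoms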
def reset(self, rdkit_mol=False):
"""Setup for generating a new molecule
Parameters
----------
rdkit_mol : bool
Whether to keep a Chem.rdchem.Mol object so
that we know what molecule is being generated
"""
self.dgl_graph = dgl.graph(([], []), idtype=torch.int32)
        # If node and edge features have been set, newly added nodes and
        # edges will have their features initialized as zero tensors.
self.dgl_graph.set_n_initializer(dgl.frame.zero_initializer)
self.dgl_graph.set_e_initializer(dgl.frame.zero_initializer)
self.mol = None
if rdkit_mol:
# RWMol is a molecule class that is intended to be edited.
self.mol = Chem.RWMol(Chem.MolFromSmiles(''))
def num_atoms(self):
"""Get the number of atoms for the current molecule.
Returns
-------
int
"""
return self.dgl_graph.num_nodes()
def add_atom(self, type):
"""Add an atom of the specified type.
Parameters
----------
type : int
Should be in the range of [0, len(self.atom_types) - 1]
"""
self.dgl_graph.add_nodes(1)
if self.mol is not None:
self.mol.AddAtom(Chem.Atom(self.atom_types[type]))
def add_bond(self, u, v, type, bi_direction=True):
"""Add a bond of the specified type between atom u and v.
Parameters
----------
u : int
Index for the first atom
v : int
Index for the second atom
type : int
Index for the bond type
bi_direction : bool
Whether to add edges for both directions in the DGLGraph.
If not, we will only add the edge (u, v).
"""
if bi_direction:
self.dgl_graph.add_edges([u, v], [v, u])
else:
self.dgl_graph.add_edge(u, v)
if self.mol is not None:
self.mol.AddBond(u, v, self.bond_types[type])
def get_current_smiles(self):
"""Get the generated molecule in SMILES
Returns
-------
s : str
SMILES
"""
        assert self.mol is not None, 'Expected a Chem.rdchem.Mol object to be initialized.'
s = Chem.MolToSmiles(self.mol)
return s
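# Minimal usage sketch for MoleculeEnv (illustrative only; the atom and bond
# vocabulary below is an assumption, not part of the original source):
#
#   env = MoleculeEnv(['C', 'O'], [Chem.rdchem.BondType.SINGLE])
#   env.reset(rdkit_mol=True)
#   env.add_atom(0)              # carbon
#   env.add_atom(1)              # oxygen
#   env.add_bond(0, 1, 0)        # single bond between atoms 0 and 1
#   env.get_current_smiles()     # expected to return 'CO' (methanol)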
class GraphEmbed(nn.Module):
"""Compute a molecule representations out of atom representations.
Parameters
----------
node_hidden_size : int
Size of atom representation
"""
def __init__(self, node_hidden_size):
super(GraphEmbed, self).__init__()
# Setting from the paper
self.graph_hidden_size = 2 * node_hidden_size
# Embed graphs
self.node_gating = nn.Sequential(
nn.Linear(node_hidden_size, 1),
nn.Sigmoid()
)
self.node_to_graph = nn.Linear(node_hidden_size,
self.graph_hidden_size)
def forward(self, g):
"""
Parameters
----------
g : DGLGraph
Current molecule graph
Returns
-------
tensor of dtype float32 and shape (1, self.graph_hidden_size)
Computed representation for the current molecule graph
"""
if g.num_nodes() == 0:
# Use a zero tensor for an empty molecule.
return torch.zeros(1, self.graph_hidden_size)
else:
# Node features are stored as hv in ndata.
hvs = g.ndata['hv']
return (self.node_gating(hvs) *
self.node_to_graph(hvs)).sum(0, keepdim=True)
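# Note: GraphEmbed implements the gated-sum readout used by DGMG, roughly
#     h_G = sum_v sigmoid(node_gating(h_v)) * node_to_graph(h_v)
# so atoms contribute to the molecule representation in proportion to their
# learned gates.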
class GraphProp(nn.Module):
"""Perform message passing over a molecule graph and update its atom representations.
Parameters
----------
num_prop_rounds : int
        Number of message passing rounds performed each time graph
        propagation is invoked
node_hidden_size : int
Size of atom representation
edge_hidden_size : int
Size of bond representation
"""
def __init__(self, num_prop_rounds, node_hidden_size, edge_hidden_size):
super(GraphProp, self).__init__()
self.num_prop_rounds = num_prop_rounds
# Setting from the paper
self.node_activation_hidden_size = 2 * node_hidden_size
message_funcs = []
self.reduce_funcs = []
node_update_funcs = []
for t in range(num_prop_rounds):
# input being [hv, hu, xuv]
message_funcs.append(nn.Linear(2 * node_hidden_size + edge_hidden_size,
self.node_activation_hidden_size))
self.reduce_funcs.append(partial(self.dgmg_reduce, round=t))
node_update_funcs.append(
nn.GRUCell(self.node_activation_hidden_size,
node_hidden_size))
self.message_funcs = nn.ModuleList(message_funcs)
self.node_update_funcs = nn.ModuleList(node_update_funcs)
def dgmg_msg(self, edges):
"""For an edge u->v, send a message concat([h_u, x_uv])
Parameters
----------
edges : batch of edges
Returns
-------
dict
Dictionary containing messages for the edge batch,
with the messages being tensors of shape (B, F1),
B for the number of edges and F1 for the message size.
"""
return {'m': torch.cat([edges.src['hv'],
edges.data['he']],
dim=1)}
def dgmg_reduce(self, nodes, round):
"""Aggregate messages.
Parameters
----------
nodes : batch of nodes
round : int
Update round
Returns
-------
dict
Dictionary containing aggregated messages for each node
in the batch, with the messages being tensors of shape
(B, F2), B for the number of nodes and F2 for the aggregated
message size
"""
hv_old = nodes.data['hv']
m = nodes.mailbox['m']
# Make copies of original atom representations to match the
# number of messages.
message = torch.cat([
hv_old.unsqueeze(1).expand(-1, m.size(1), -1), m], dim=2)
node_activation = (self.message_funcs[round](message)).sum(1)
return {'a': node_activation}
def forward(self, g):
"""
Parameters
----------
g : DGLGraph
"""
if g.num_edges() == 0:
return
else:
for t in range(self.num_prop_rounds):
g.update_all(message_func=self.dgmg_msg,
reduce_func=self.reduce_funcs[t])
g.ndata['hv'] = self.node_update_funcs[t](
g.ndata['a'], g.ndata['hv'])
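# Note: in each propagation round t, GraphProp computes for every atom v
#     a_v = sum_{u in N(v)} message_funcs[t]([h_v, h_u, x_uv])
#     h_v = node_update_funcs[t](a_v, h_v)
# where x_uv is the bond feature stored in edata['he'].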
class AddNode(nn.Module):
"""Stop or add an atom of a particular type.
Parameters
----------
env : MoleculeEnv
Environment for generating molecules
graph_embed_func : callable taking g as input
Function for computing molecule representation
node_hidden_size : int
Size of atom representation
dropout : float
Probability for dropout
"""
def __init__(self, env, graph_embed_func, node_hidden_size, dropout):
super(AddNode, self).__init__()
self.env = env
n_node_types = len(env.atom_types)
self.graph_op = {'embed': graph_embed_func}
self.stop = n_node_types
self.add_node = nn.Sequential(
nn.Linear(graph_embed_func.graph_hidden_size, graph_embed_func.graph_hidden_size),
nn.Dropout(p=dropout),
nn.Linear(graph_embed_func.graph_hidden_size, n_node_types + 1)
)
        # If a node is to be added, initialize its representation hv
self.node_type_embed = nn.Embedding(n_node_types, node_hidden_size)
        self.initialize_hv = nn.Linear(node_hidden_size +
graph_embed_func.graph_hidden_size,
node_hidden_size)
self.init_node_activation = torch.zeros(1, 2 * node_hidden_size)
self.dropout = nn.Dropout(p=dropout)
def _initialize_node_repr(self, g, node_type, graph_embed):
"""Initialize atom representation
Parameters
----------
g : DGLGraph
node_type : int
Index for the type of the new atom
graph_embed : tensor of dtype float32
Molecule representation
"""
num_nodes = g.num_nodes()
hv_init = torch.cat([
self.node_type_embed(torch.LongTensor([node_type])),
graph_embed], dim=1)
hv_init = self.dropout(hv_init)
hv_init = self.initialize_hv(hv_init)
g.nodes[num_nodes - 1].data['hv'] = hv_init
g.nodes[num_nodes - 1].data['a'] = self.init_node_activation
def prepare_log_prob(self, compute_log_prob):
"""Setup for returning log likelihood
Parameters
----------
compute_log_prob : bool
Whether to compute log likelihood
"""
if compute_log_prob:
self.log_prob = []
self.compute_log_prob = compute_log_prob
def forward(self, action=None):
"""
Parameters
----------
action : None or int
If None, a new action will be sampled. If not None,
teacher forcing will be used to enforce the decision of the
corresponding action.
Returns
-------
stop : bool
Whether we stop adding new atoms
"""
g = self.env.dgl_graph
graph_embed = self.graph_op['embed'](g)
logits = self.add_node(graph_embed).view(1, -1)
probs = F.softmax(logits, dim=1)
if action is None:
action = Categorical(probs).sample().item()
stop = bool(action == self.stop)
if not stop:
self.env.add_atom(action)
self._initialize_node_repr(g, action, graph_embed)
if self.compute_log_prob:
sample_log_prob = F.log_softmax(logits, dim=1)[:, action: action + 1]
self.log_prob.append(sample_log_prob)
return stop
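# Note: AddNode predicts a distribution over len(atom_types) + 1 classes from
# the graph embedding; index j < len(atom_types) adds atom_types[j], while the
# last index means "stop adding atoms". For example, assuming
# atom_types=['C', 'N', 'O'], sampling action 3 terminates atom addition.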
class AddEdge(nn.Module):
"""Stop or add a bond of a particular type.
Parameters
----------
env : MoleculeEnv
Environment for generating molecules
graph_embed_func : callable taking g as input
Function for computing molecule representation
node_hidden_size : int
Size of atom representation
dropout : float
Probability for dropout
"""
def __init__(self, env, graph_embed_func, node_hidden_size, dropout):
super(AddEdge, self).__init__()
self.env = env
n_bond_types = len(env.bond_types)
self.stop = n_bond_types
self.graph_op = {'embed': graph_embed_func}
self.add_edge = nn.Sequential(
nn.Linear(graph_embed_func.graph_hidden_size + node_hidden_size,
graph_embed_func.graph_hidden_size + node_hidden_size),
nn.Dropout(p=dropout),
nn.Linear(graph_embed_func.graph_hidden_size + node_hidden_size, n_bond_types + 1)
)
def prepare_log_prob(self, compute_log_prob):
"""Setup for returning log likelihood
Parameters
----------
compute_log_prob : bool
Whether to compute log likelihood
"""
if compute_log_prob:
self.log_prob = []
self.compute_log_prob = compute_log_prob
def forward(self, action=None):
"""
Parameters
----------
action : None or int
If None, a new action will be sampled. If not None,
teacher forcing will be used to enforce the decision of the
corresponding action.
Returns
-------
stop : bool
Whether we stop adding new bonds
action : int
The type for the new bond
"""
g = self.env.dgl_graph
graph_embed = self.graph_op['embed'](g)
src_embed = g.nodes[g.num_nodes() - 1].data['hv']
logits = self.add_edge(
torch.cat([graph_embed, src_embed], dim=1))
probs = F.softmax(logits, dim=1)
if action is None:
action = Categorical(probs).sample().item()
stop = bool(action == self.stop)
if self.compute_log_prob:
sample_log_prob = F.log_softmax(logits, dim=1)[:, action: action + 1]
self.log_prob.append(sample_log_prob)
return stop, action
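# Note: AddEdge conditions on both the graph embedding and the representation
# of the most recently added atom, and predicts len(bond_types) + 1 classes;
# index j < len(bond_types) picks bond_types[j] and the last index stops
# adding bonds to the current atom.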
class ChooseDestAndUpdate(nn.Module):
"""Choose the atom to connect for the new bond.
Parameters
----------
env : MoleculeEnv
Environment for generating molecules
graph_prop_func : callable taking g as input
Function for performing message passing
and updating atom representations
node_hidden_size : int
Size of atom representation
dropout : float
Probability for dropout
"""
def __init__(self, env, graph_prop_func, node_hidden_size, dropout):
super(ChooseDestAndUpdate, self).__init__()
self.env = env
n_bond_types = len(self.env.bond_types)
# To be used for one-hot encoding of bond type
self.bond_embedding = torch.eye(n_bond_types)
self.graph_op = {'prop': graph_prop_func}
self.choose_dest = nn.Sequential(
nn.Linear(2 * node_hidden_size + n_bond_types, 2 * node_hidden_size + n_bond_types),
nn.Dropout(p=dropout),
nn.Linear(2 * node_hidden_size + n_bond_types, 1)
)
def _initialize_edge_repr(self, g, src_list, dest_list, edge_embed):
"""Initialize bond representation
Parameters
----------
g : DGLGraph
src_list : list of int
source atoms for new bonds
dest_list : list of int
destination atoms for new bonds
edge_embed : 2D tensor of dtype float32
Embeddings for the new bonds
"""
g.edges[src_list, dest_list].data['he'] = edge_embed.expand(len(src_list), -1)
def prepare_log_prob(self, compute_log_prob):
"""Setup for returning log likelihood
Parameters
----------
compute_log_prob : bool
Whether to compute log likelihood
"""
if compute_log_prob:
self.log_prob = []
self.compute_log_prob = compute_log_prob
def forward(self, bond_type, dest):
"""
Parameters
----------
bond_type : int
The type for the new bond
dest : int or None
If None, a new action will be sampled. If not None,
teacher forcing will be used to enforce the decision of the
corresponding action.
"""
g = self.env.dgl_graph
src = g.num_nodes() - 1
possible_dests = range(src)
src_embed_expand = g.nodes[src].data['hv'].expand(src, -1)
possible_dests_embed = g.nodes[possible_dests].data['hv']
edge_embed = self.bond_embedding[bond_type: bond_type + 1]
dests_scores = self.choose_dest(
torch.cat([possible_dests_embed,
src_embed_expand,
edge_embed.expand(src, -1)], dim=1)).view(1, -1)
dests_probs = F.softmax(dests_scores, dim=1)
if dest is None:
dest = Categorical(dests_probs).sample().item()
if not g.has_edges_between(src, dest):
# For undirected graphs, we add edges for both directions
# so that we can perform graph propagation.
src_list = [src, dest]
dest_list = [dest, src]
self.env.add_bond(src, dest, bond_type)
self._initialize_edge_repr(g, src_list, dest_list, edge_embed)
# Perform message passing when new bonds are added.
self.graph_op['prop'](g)
if self.compute_log_prob:
if dests_probs.nelement() > 1:
self.log_prob.append(
F.log_softmax(dests_scores, dim=1)[:, dest: dest + 1])
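# Note: ChooseDestAndUpdate scores every existing atom u < src as a candidate
# destination with choose_dest([h_u, h_src, one_hot(bond_type)]), applies a
# softmax over the candidates, and adds the chosen bond in both directions
# before refreshing atom representations via message passing.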
def weights_init(m):
    """Function to initialize weights for models
    Code from https://gist.github.com/jeasinema/ed9236ce743c8efaf30fa2ff732749f5
    Usage:
        model = Model()
        model.apply(weights_init)
    """
if isinstance(m, nn.Linear):
init.xavier_normal_(m.weight.data)
init.normal_(m.bias.data)
elif isinstance(m, nn.GRUCell):
for param in m.parameters():
if len(param.shape) >= 2:
init.orthogonal_(param.data)
else:
init.normal_(param.data)
def dgmg_message_weight_init(m):
"""Weight initialization for graph propagation module
    These are suggested by the authors. This should only be used for
    the message passing functions, i.e. the f_e modules in the paper.
"""
def _weight_init(m):
if isinstance(m, nn.Linear):
init.normal_(m.weight.data, std=1./10)
init.normal_(m.bias.data, std=1./10)
else:
raise ValueError('Expected the input to be of type nn.Linear!')
if isinstance(m, nn.ModuleList):
for layer in m:
layer.apply(_weight_init)
else:
m.apply(_weight_init)
class DGMG(nn.Module):
"""DGMG model
`Learning Deep Generative Models of Graphs <https://arxiv.org/abs/1803.03324>`__
Users only need to initialize an instance of this class.
Parameters
----------
atom_types : list
E.g. ['C', 'N'].
bond_types : list
E.g. [Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE,
Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC].
node_hidden_size : int
Size of atom representation. Default to 128.
num_prop_rounds : int
        Number of message passing rounds performed each time graph
        propagation is invoked. Default to 2.
dropout : float
Probability for dropout. Default to 0.2.
"""
def __init__(self, atom_types, bond_types, node_hidden_size=128,
num_prop_rounds=2, dropout=0.2):
super(DGMG, self).__init__()
self.env = MoleculeEnv(atom_types, bond_types)
# Graph embedding module
self.graph_embed = GraphEmbed(node_hidden_size)
# Graph propagation module
# For one-hot encoding, edge_hidden_size is just the number of bond types
self.graph_prop = GraphProp(num_prop_rounds, node_hidden_size, len(self.env.bond_types))
# Actions
self.add_node_agent = AddNode(
self.env, self.graph_embed, node_hidden_size, dropout)
self.add_edge_agent = AddEdge(
self.env, self.graph_embed, node_hidden_size, dropout)
self.choose_dest_agent = ChooseDestAndUpdate(
self.env, self.graph_prop, node_hidden_size, dropout)
# Weight initialization
self.init_weights()
    def init_weights(self):
"""Initialize model weights"""
self.graph_embed.apply(weights_init)
self.graph_prop.apply(weights_init)
self.add_node_agent.apply(weights_init)
self.add_edge_agent.apply(weights_init)
self.choose_dest_agent.apply(weights_init)
self.graph_prop.message_funcs.apply(dgmg_message_weight_init)
    def count_step(self):
"""Increment the step by 1."""
self.step_count += 1
    def prepare_log_prob(self, compute_log_prob):
"""Setup for returning log likelihood
Parameters
----------
compute_log_prob : bool
Whether to compute log likelihood
"""
self.compute_log_prob = compute_log_prob
self.add_node_agent.prepare_log_prob(compute_log_prob)
self.add_edge_agent.prepare_log_prob(compute_log_prob)
self.choose_dest_agent.prepare_log_prob(compute_log_prob)
    def add_node_and_update(self, a=None):
"""Decide if to add a new atom.
If a new atom should be added, update the graph.
Parameters
----------
a : None or int
If None, a new action will be sampled. If not None,
teacher forcing will be used to enforce the decision of the
corresponding action.
"""
self.count_step()
return self.add_node_agent(a)
    def add_edge_or_not(self, a=None):
"""Decide if to add a new bond.
Parameters
----------
a : None or int
If None, a new action will be sampled. If not None,
teacher forcing will be used to enforce the decision of the
corresponding action.
"""
self.count_step()
return self.add_edge_agent(a)
    def choose_dest_and_update(self, bond_type, a=None):
"""Choose destination and connect it to the latest atom.
Add edges for both directions and update the graph.
Parameters
----------
bond_type : int
The type of the new bond to add
a : None or int
If None, a new action will be sampled. If not None,
teacher forcing will be used to enforce the decision of the
corresponding action.
"""
self.count_step()
self.choose_dest_agent(bond_type, a)
    def get_log_prob(self):
"""Compute the log likelihood for the decision sequence,
typically corresponding to the generation of a molecule.
Returns
-------
torch.tensor consisting of a float only
"""
return torch.cat(self.add_node_agent.log_prob).sum()\
+ torch.cat(self.add_edge_agent.log_prob).sum()\
+ torch.cat(self.choose_dest_agent.log_prob).sum()
    def teacher_forcing(self, actions):
"""Generate a molecule according to a sequence of actions.
Parameters
----------
actions : list of 2-tuples of int
actions[t] gives (i, j), the action to execute by DGMG at timestep t.
- If i = 0, j specifies either the type of the atom to add or termination
- If i = 1, j specifies either the type of the bond to add or termination
- If i = 2, j specifies the destination atom id for the bond to add.
              With the formulation of DGMG, atom j must have been created before this decision.
"""
stop_node = self.add_node_and_update(a=actions[self.step_count][1])
while not stop_node:
# A new atom was just added.
stop_edge, bond_type = self.add_edge_or_not(a=actions[self.step_count][1])
while not stop_edge:
# A new bond is to be added.
self.choose_dest_and_update(bond_type, a=actions[self.step_count][1])
stop_edge, bond_type = self.add_edge_or_not(a=actions[self.step_count][1])
stop_node = self.add_node_and_update(a=actions[self.step_count][1])
    def rollout(self, max_num_steps):
        """Sample a molecule from the distribution learned by DGMG.
        Parameters
        ----------
        max_num_steps : int
            Maximum number of steps allowed during generation
        """
stop_node = self.add_node_and_update()
while (not stop_node) and (self.step_count <= max_num_steps):
stop_edge, bond_type = self.add_edge_or_not()
if self.env.num_atoms() == 1:
stop_edge = True
while (not stop_edge) and (self.step_count <= max_num_steps):
self.choose_dest_and_update(bond_type)
stop_edge, bond_type = self.add_edge_or_not()
stop_node = self.add_node_and_update()
    def forward(self, actions=None, rdkit_mol=False, compute_log_prob=False, max_num_steps=400):
"""
Parameters
----------
actions : list of 2-tuples or None.
If actions are not None, generate a molecule according to actions.
Otherwise, a molecule will be generated based on sampled actions.
rdkit_mol : bool
            Whether to maintain a Chem.rdchem.Mol object. This brings extra
            computational cost, but is necessary if we want to obtain the
            generated molecule in SMILES.
compute_log_prob : bool
Whether to compute log likelihood
max_num_steps : int
            Maximum number of steps allowed. This only comes into effect
            during inference and prevents the generation process from
            running indefinitely.
Returns
-------
torch.tensor consisting of a float only, optional
The log likelihood for the actions taken
str, optional
The generated molecule in the form of SMILES
"""
# Initialize an empty molecule
self.step_count = 0
self.env.reset(rdkit_mol=rdkit_mol)
self.prepare_log_prob(compute_log_prob)
if actions is not None:
# A sequence of decisions is given, use teacher forcing
self.teacher_forcing(actions)
else:
# Sample a molecule from the distribution learned by DGMG
self.rollout(max_num_steps)
if compute_log_prob and rdkit_mol:
return self.get_log_prob(), self.env.get_current_smiles()
if compute_log_prob:
return self.get_log_prob()
if rdkit_mol:
return self.env.get_current_smiles()
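if __name__ == '__main__':
    # Minimal usage sketch (illustrative only, not part of the original
    # module): the atom/bond vocabulary and hyperparameters below are
    # assumptions chosen for a quick smoke test with an untrained model.
    example_atom_types = ['C', 'N', 'O']
    example_bond_types = [Chem.rdchem.BondType.SINGLE,
                          Chem.rdchem.BondType.DOUBLE,
                          Chem.rdchem.BondType.TRIPLE]
    model = DGMG(example_atom_types, example_bond_types,
                 node_hidden_size=16, num_prop_rounds=2, dropout=0.2)
    model.eval()
    with torch.no_grad():
        # Sample a molecule by rolling out the (untrained) policy and
        # return it as a SMILES string.
        smiles = model(rdkit_mol=True, max_num_steps=20)
    print('Sampled molecule:', smiles)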