# -*- coding: utf-8 -*-
#
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
#
# pylint: disable= no-member, arguments-differ, invalid-name
#
# JTVAE
from functools import partial
import copy
import dgl
import dgl.function as fn
import rdkit.Chem as Chem
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.traversal import bfs_edges_generator, dfs_labeled_edges_generator
from ...data.jtvae import get_atom_featurizer_enc, get_bond_featurizer_enc
from ...utils.featurizers import ConcatFeaturizer, atom_type_one_hot, atom_degree_one_hot,\
atom_formal_charge_one_hot, atom_is_aromatic, bond_type_one_hot, bond_is_in_ring
from ...utils.jtvae.chemutils import enum_assemble, set_atommap, copy_edit_mol, attach_mols, \
decode_stereo, get_mol
from ...utils.mol_to_graph import mol_to_bigraph
__all__ = ['JTNNVAE']
MAX_NB = 8
MAX_DECODE_LEN = 100
class GRUMessage(nn.Module):
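    """Message function of the tree GRU (a summary of the code below).

    Operating on the line graph, each incoming hidden state ``h`` is gated by
    ``r = sigmoid(src_x_r + U_r(h))``, where ``src_x_r`` is the precomputed
    ``W_r``-projection of the receiving edge's source-node feature. The
    message sent is ``r * h``.
    """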
def __init__(self, hidden_size, msg_field='m'):
super(GRUMessage, self).__init__()
self.U_r = nn.Linear(hidden_size, hidden_size)
self.msg_field = msg_field
def forward(self, edges):
r_1 = edges.dst['src_x_r']
r_2 = self.U_r(edges.src['h'])
r = torch.sigmoid(r_1 + r_2)
return {self.msg_field: r * edges.src['h']}
class GRUUpdate(nn.Module):
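    """Node update function of the tree GRU (a summary of the code below).

    Given the aggregated neighbor states ``sum_h`` and gated states
    ``sum_gated_h``, it computes::

        z = sigmoid(W_z [src_x, sum_h])
        pre_h = tanh(W_h [src_x, sum_gated_h])
        h = (1 - z) * sum_h + z * pre_h
    """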
def __init__(self, hidden_size):
super(GRUUpdate, self).__init__()
self.W_z = nn.Linear(2 * hidden_size, hidden_size)
self.W_h = nn.Linear(2 * hidden_size, hidden_size)
def forward(self, node):
z = torch.sigmoid(self.W_z(
torch.cat([node.data['src_x'], node.data['sum_h']], dim=1)))
h_input = torch.cat([node.data['src_x'], node.data['sum_gated_h']], dim=1)
pre_h = torch.tanh(self.W_h(h_input))
new_h = (torch.tensor(1.0).to(z.device) - z) * node.data['sum_h'] + z * pre_h
return {'h': new_h}
# pylint: disable=R1710
def level_order(forest, roots):
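    """Yield edge frontiers for two sweeps over the batched trees.

    The frontiers of a reversed BFS from the roots are yielded in reverse
    order (from the leaves up to the roots), followed by the forward BFS
    frontiers (from the roots down to the leaves), so that messages are
    propagated bottom-up and then top-down.
    """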
device = forest.device
edges = list(bfs_edges_generator(forest, roots))
if len(edges) == 0:
return []
edges = [e.to(device) for e in edges]
edges_back = list(bfs_edges_generator(forest, roots, reverse=True))
edges_back = [e.to(device) for e in edges_back]
yield from reversed(edges_back)
yield from edges
def get_root_ids(graphs):
    # Get the IDs of the root nodes, i.e. the first node of each tree
batch_num_nodes = graphs.batch_num_nodes().cpu()
batch_num_nodes = torch.cat([torch.tensor([0]), batch_num_nodes], dim=0)
root_ids = torch.cumsum(batch_num_nodes, dim=0)[:-1]
return root_ids
class JTNNEncoder(nn.Module):
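    """Encoder for junction trees.

    Runs the tree GRU over the line graph of each junction tree, propagating
    messages from the leaves to the root and back, and returns the per-edge
    hidden states together with a root-level tree representation.
    """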
def __init__(self, vocab, hidden_size, embedding=None):
super(JTNNEncoder, self).__init__()
self.hidden_size = hidden_size
self.vocab = vocab
if embedding is None:
self.embedding = nn.Embedding(vocab.size(), hidden_size)
else:
self.embedding = embedding
self.W_r = nn.Linear(hidden_size, hidden_size, bias=False)
self.gru_message = GRUMessage(hidden_size)
self.gru_update = GRUUpdate(hidden_size)
self.W = nn.Sequential(
nn.Linear(2 * hidden_size, hidden_size),
nn.ReLU()
)
def forward(self, tree_graphs):
device = tree_graphs.device
if 'x' not in tree_graphs.ndata:
tree_graphs.ndata['x'] = self.embedding(tree_graphs.ndata['wid'])
tree_graphs.apply_edges(fn.copy_u('x', 'src_x'))
tree_graphs = tree_graphs.local_var()
line_tree_graphs = dgl.line_graph(tree_graphs, backtracking=False)
line_tree_graphs.ndata.update({
'src_x': tree_graphs.edata['src_x'],
'src_x_r': self.W_r(tree_graphs.edata['src_x']),
# Exploit the fact that the reduce function is a sum of incoming messages,
# and uncomputed messages are zero vectors.
'h': torch.zeros(line_tree_graphs.num_nodes(), self.hidden_size).to(device),
'sum_h': torch.zeros(line_tree_graphs.num_nodes(), self.hidden_size).to(device),
'sum_gated_h': torch.zeros(line_tree_graphs.num_nodes(), self.hidden_size).to(device)
})
        # Get the IDs of the root nodes, i.e. the first node of each tree
root_ids = get_root_ids(tree_graphs)
for eid in level_order(tree_graphs, root_ids.to(dtype=tree_graphs.idtype)):
line_tree_graphs.pull(v=eid, message_func=fn.copy_u('h', 'h_nei'),
reduce_func=fn.sum('h_nei', 'sum_h'))
line_tree_graphs.pull(v=eid, message_func=self.gru_message,
reduce_func=fn.sum('m', 'sum_gated_h'))
line_tree_graphs.apply_nodes(self.gru_update, v=eid)
# Readout
tree_graphs.ndata['h'] = torch.zeros(tree_graphs.num_nodes(), self.hidden_size).to(device)
tree_graphs.edata['h'] = line_tree_graphs.ndata['h']
root_ids = root_ids.to(device)
tree_graphs.pull(v=root_ids.to(dtype=tree_graphs.idtype),
message_func=fn.copy_e('h', 'm'),
reduce_func=fn.sum('m', 'h'))
root_vec = torch.cat([
tree_graphs.ndata['x'][root_ids],
tree_graphs.ndata['h'][root_ids]
], dim=1)
root_vec = self.W(root_vec)
return tree_graphs.edata['h'], root_vec
def have_slots(fa_slots, ch_slots):
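    """Check whether a parent fragment (``fa_slots``) and a candidate child
    fragment (``ch_slots``) have compatible attachment slots.

    Each slot is an (atom symbol, formal charge, num hydrogens) triple. Two
    ring fragments (more than two slots each) are always treated as
    compatible; otherwise at least one pair of slots with matching atom type
    and formal charge is required, and a uniquely matched slot of a two-slot
    fragment is consumed so that it is not reused later.
    """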
if len(fa_slots) > 2 and len(ch_slots) > 2:
return True
matches = []
for i, s1 in enumerate(fa_slots):
a1, c1, h1 = s1
for j, s2 in enumerate(ch_slots):
a2, c2, h2 = s2
if a1 == a2 and c1 == c2 and (a1 != "C" or h1 + h2 >= 4):
matches.append((i, j))
if len(matches) == 0:
return False
fa_match, ch_match = zip(*matches)
    if len(set(fa_match)) == 1 and 1 < len(fa_slots) <= 2:  # never remove atom from ring
        fa_slots.pop(fa_match[0])
    if len(set(ch_match)) == 1 and 1 < len(ch_slots) <= 2:  # never remove atom from ring
        ch_slots.pop(ch_match[0])
return True
def can_assemble(node_x, node_y):
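    """Check whether ``node_y`` can be attached to ``node_x``, given
    ``node_x``'s existing neighbors, by enumerating candidate attachment
    configurations with ``enum_assemble``.
    """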
neis = node_x['neighbors'] + [node_y]
for i, nei in enumerate(neis):
nei['nid'] = i
neighbors = [nei for nei in neis if nei['mol'].GetNumAtoms() > 1]
neighbors = sorted(neighbors, key=lambda x: x['mol'].GetNumAtoms(), reverse=True)
singletons = [nei for nei in neis if nei['mol'].GetNumAtoms() == 1]
neighbors = singletons + neighbors
cands = enum_assemble(node_x, neighbors)
return len(cands) > 0
def dfs_order(forest, roots):
edges = dfs_labeled_edges_generator(forest, roots, has_reverse_edge=True)
for e, l in zip(*edges):
        # Exploit the fact that the reverse edge ID equals the forward edge ID
        # xor 1. Normally, this should be done using find_edges().
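        # For example, edges are added in pairs (i1, j1), (j1, i1), ..., so edge
        # 2k and edge 2k + 1 are mutual reverses and xor with 1 maps between them.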
yield e ^ l, l
def mol_tree_node(smiles, wid=None, idx=None, nbrs=None):
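    """Construct a junction tree node, represented as a dict with keys
    'smiles', 'mol', 'wid', 'idx' and 'neighbors'."""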
if nbrs is None:
nbrs = []
return {'smiles': smiles, 'mol': get_mol(smiles), 'wid': wid, 'idx': idx, 'neighbors': nbrs}
def gru_functional(x, h_nei, wz, wr, ur, wh):
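    """Functional form of the tree GRU cell.

    Performs the same computation as :class:`GRUMessage` followed by
    :class:`GRUUpdate`, but on explicitly stacked neighbor hidden states
    ``h_nei`` of shape (batch, num_neighbors, hidden_size), with the gate
    parameters passed in explicitly. Used during decoding, where messages
    are maintained manually rather than on a DGL graph.
    """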
hidden_size = x.size()[-1]
sum_h = h_nei.sum(dim=1)
z_input = torch.cat([x, sum_h], dim=1)
z = torch.sigmoid(wz(z_input))
r_1 = wr(x).view(-1, 1, hidden_size)
r_2 = ur(h_nei)
r = torch.sigmoid(r_1 + r_2)
gated_h = r * h_nei
sum_gated_h = gated_h.sum(dim=1)
h_input = torch.cat([x, sum_gated_h], dim=1)
pre_h = torch.tanh(wh(h_input))
new_h = (1.0 - z) * sum_h + z * pre_h
return new_h
class JTNNDecoder(nn.Module):
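    """Decoder that generates a junction tree from a latent tree vector.

    During training (:meth:`forward`), it traverses the ground-truth trees in
    DFS order and accumulates a clique-label prediction loss and a stop
    (topology) prediction loss. During inference (:meth:`decode`), it grows a
    tree node by node, deciding at each step whether to expand with a new
    clique or backtrack.
    """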
def __init__(self, vocab, hidden_size, latent_size, embedding=None):
super(JTNNDecoder, self).__init__()
self.hidden_size = hidden_size
self.vocab_size = vocab.size()
self.vocab = vocab
if embedding is None:
self.embedding = nn.Embedding(self.vocab_size, hidden_size)
else:
self.embedding = embedding
# GRU Weights
self.W_r = nn.Linear(hidden_size, hidden_size, bias=False)
self.gru_message = GRUMessage(hidden_size)
self.gru_update = GRUUpdate(hidden_size)
# Feature Aggregate Weights
self.W = nn.Linear(latent_size + hidden_size, hidden_size)
self.U = nn.Linear(latent_size + 2 * hidden_size, hidden_size)
# Output Weights
self.W_o = nn.Linear(hidden_size, self.vocab_size)
self.U_s = nn.Linear(hidden_size, 1)
# Loss Functions
self.pred_loss = nn.CrossEntropyLoss(reduction='sum')
self.stop_loss = nn.BCEWithLogitsLoss(reduction='sum')
def forward(self, tree_graphs, tree_vec):
device = tree_vec.device
batch_size = tree_graphs.batch_size
root_ids = get_root_ids(tree_graphs)
if 'x' not in tree_graphs.ndata:
tree_graphs.ndata['x'] = self.embedding(tree_graphs.ndata['wid'])
if 'src_x' not in tree_graphs.edata:
tree_graphs.apply_edges(fn.copy_u('x', 'src_x'))
tree_graphs = tree_graphs.local_var()
tree_graphs.apply_edges(func=lambda edges: {'dst_wid': edges.dst['wid']})
line_tree_graphs = dgl.line_graph(tree_graphs, backtracking=False, shared=True)
line_num_nodes = line_tree_graphs.num_nodes()
line_tree_graphs.ndata.update({
'src_x_r': self.W_r(line_tree_graphs.ndata['src_x']),
# Exploit the fact that the reduce function is a sum of incoming messages,
# and uncomputed messages are zero vectors.
'h': torch.zeros(line_num_nodes, self.hidden_size).to(device),
'vec': dgl.broadcast_edges(tree_graphs, tree_vec),
'sum_h': torch.zeros(line_num_nodes, self.hidden_size).to(device),
'sum_gated_h': torch.zeros(line_num_nodes, self.hidden_size).to(device)
})
# input tensors for stop prediction (p) and label prediction (q)
pred_hiddens, pred_mol_vecs, pred_targets = [], [], []
stop_hiddens, stop_targets = [], []
# Predict root
pred_hiddens.append(torch.zeros(batch_size, self.hidden_size).to(device))
pred_targets.append(tree_graphs.ndata['wid'][root_ids.to(device)])
pred_mol_vecs.append(tree_vec)
# Traverse the tree and predict on children
for eid, p in dfs_order(tree_graphs, root_ids.to(dtype=tree_graphs.idtype)):
eid = eid.to(device=device, dtype=tree_graphs.idtype)
p = p.to(device=device, dtype=tree_graphs.idtype)
# Message passing excluding the target
line_tree_graphs.pull(v=eid, message_func=fn.copy_u('h', 'h_nei'),
reduce_func=fn.sum('h_nei', 'sum_h'))
line_tree_graphs.pull(v=eid, message_func=self.gru_message,
reduce_func=fn.sum('m', 'sum_gated_h'))
line_tree_graphs.apply_nodes(self.gru_update, v=eid)
# Node aggregation including the target
# By construction, the edges of the raw graph follow the order of
# (i1, j1), (j1, i1), (i2, j2), (j2, i2), ... The order of the nodes
# in the line graph corresponds to the order of the edges in the raw graph.
eid = eid.long()
reverse_eid = torch.bitwise_xor(eid, torch.tensor(1).to(device))
cur_o = line_tree_graphs.ndata['sum_h'][eid] + \
line_tree_graphs.ndata['h'][reverse_eid]
# Gather targets
mask = (p == torch.tensor(0).to(device))
pred_list = eid[mask]
stop_target = torch.tensor(1).to(device) - p
# Hidden states for stop prediction
stop_hidden = torch.cat([line_tree_graphs.ndata['src_x'][eid],
cur_o, line_tree_graphs.ndata['vec'][eid]], dim=1)
stop_hiddens.append(stop_hidden)
stop_targets.extend(stop_target)
            # Hidden states for clique prediction
if len(pred_list) > 0:
pred_mol_vecs.append(line_tree_graphs.ndata['vec'][pred_list])
pred_hiddens.append(line_tree_graphs.ndata['h'][pred_list])
pred_targets.append(line_tree_graphs.ndata['dst_wid'][pred_list])
        # Last stop at root
root_ids = root_ids.to(device)
cur_x = tree_graphs.ndata['x'][root_ids]
tree_graphs.edata['h'] = line_tree_graphs.ndata['h']
tree_graphs.pull(v=root_ids.to(dtype=tree_graphs.idtype),
message_func=fn.copy_e('h', 'm'), reduce_func=fn.sum('m', 'cur_o'))
stop_hidden = torch.cat([cur_x, tree_graphs.ndata['cur_o'][root_ids], tree_vec], dim=1)
stop_hiddens.append(stop_hidden)
stop_targets.extend(torch.zeros(batch_size).to(device))
# Predict next clique
pred_hiddens = torch.cat(pred_hiddens, dim=0)
pred_mol_vecs = torch.cat(pred_mol_vecs, dim=0)
pred_vecs = torch.cat([pred_hiddens, pred_mol_vecs], dim=1)
pred_vecs = F.relu(self.W(pred_vecs))
pred_scores = self.W_o(pred_vecs)
pred_targets = torch.cat(pred_targets, dim=0)
pred_loss = self.pred_loss(pred_scores, pred_targets) / batch_size
_, preds = torch.max(pred_scores, dim=1)
pred_acc = torch.eq(preds, pred_targets).float()
pred_acc = torch.sum(pred_acc) / pred_targets.nelement()
# Predict stop
stop_hiddens = torch.cat(stop_hiddens, dim=0)
stop_vecs = F.relu(self.U(stop_hiddens))
stop_scores = self.U_s(stop_vecs).squeeze()
stop_targets = torch.Tensor(stop_targets).to(device)
stop_loss = self.stop_loss(stop_scores, stop_targets) / batch_size
stops = torch.ge(stop_scores, 0).float()
stop_acc = torch.eq(stops, stop_targets).float()
stop_acc = torch.sum(stop_acc) / stop_targets.nelement()
return pred_loss, stop_loss, pred_acc.item(), stop_acc.item()
def decode(self, mol_vec, prob_decode):
device = mol_vec.device
stack = []
init_hidden = torch.zeros(1, self.hidden_size).to(device)
zero_pad = torch.zeros(1, 1, self.hidden_size).to(device)
# Root Prediction
root_hidden = torch.cat([init_hidden, mol_vec], dim=1)
root_hidden = F.relu(self.W(root_hidden))
root_score = self.W_o(root_hidden)
_, root_wid = torch.max(root_score, dim=1)
root_wid = root_wid.item()
root = mol_tree_node(smiles=self.vocab.get_smiles(root_wid), wid=root_wid, idx=0)
stack.append((root, self.vocab.get_slots(root['wid'])))
all_nodes = [root]
h = {}
for step in range(MAX_DECODE_LEN):
node_x, fa_slot = stack[-1]
cur_h_nei = [h[(node_y['idx'], node_x['idx'])] for node_y in node_x['neighbors']]
if len(cur_h_nei) > 0:
cur_h_nei = torch.stack(cur_h_nei, dim=0).view(1, -1, self.hidden_size)
else:
cur_h_nei = zero_pad
cur_x = torch.LongTensor([node_x['wid']]).to(device)
cur_x = self.embedding(cur_x)
# Predict stop
cur_h = cur_h_nei.sum(dim=1)
stop_hidden = torch.cat([cur_x, cur_h, mol_vec], dim=1)
stop_hidden = F.relu(self.U(stop_hidden))
stop_score = torch.sigmoid(self.U_s(stop_hidden) * 20).squeeze()
if prob_decode:
                backtrack = (torch.bernoulli(1.0 - stop_score.data).item() == 1)
else:
backtrack = (stop_score.item() < 0.5)
if not backtrack: # Forward: Predict next clique
new_h = gru_functional(cur_x, cur_h_nei, self.gru_update.W_z, self.W_r,
self.gru_message.U_r, self.gru_update.W_h)
pred_hidden = torch.cat([new_h, mol_vec], dim=1)
pred_hidden = F.relu(self.W(pred_hidden))
pred_score = torch.softmax(self.W_o(pred_hidden) * 20, dim=1)
if prob_decode:
sort_wid = torch.multinomial(pred_score.data.squeeze(), 5)
else:
_, sort_wid = torch.sort(pred_score, dim=1, descending=True)
sort_wid = sort_wid.data.squeeze()
next_wid = None
for wid in sort_wid[:5]:
slots = self.vocab.get_slots(wid)
node_y = mol_tree_node(smiles=self.vocab.get_smiles(wid))
if have_slots(fa_slot, slots) and can_assemble(node_x, node_y):
next_wid = wid
next_slots = slots
break
if next_wid is None:
backtrack = True # No more children can be added
else:
node_y = mol_tree_node(smiles=self.vocab.get_smiles(next_wid),
wid=next_wid, idx=step + 1, nbrs=[node_x])
h[(node_x['idx'], node_y['idx'])] = new_h[0]
stack.append((node_y, next_slots))
all_nodes.append(node_y)
            if backtrack:  # Backtrack: an `if` rather than `else`, since backtrack may be set in the branch above
if len(stack) == 1:
break # At root, terminate
node_fa, _ = stack[-2]
cur_h_nei = [h[(node_y['idx'], node_x['idx'])] for node_y in node_x['neighbors']
if node_y['idx'] != node_fa['idx']]
if len(cur_h_nei) > 0:
cur_h_nei = torch.stack(cur_h_nei, dim=0).view(1, -1, self.hidden_size)
else:
cur_h_nei = zero_pad
new_h = gru_functional(cur_x, cur_h_nei, self.gru_update.W_z, self.W_r,
self.gru_message.U_r, self.gru_update.W_h)
h[(node_x['idx'], node_fa['idx'])] = new_h[0]
node_fa['neighbors'].append(node_x)
stack.pop()
return root, all_nodes
class MPN(nn.Module):
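    """Message passing network over molecular graphs, used as the graph encoder.

    Edge messages are refined on the line graph for ``depth - 1`` rounds,
    aggregated onto atoms, and mean-pooled into one vector per molecule.
    """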
def __init__(self, hidden_size, depth, in_node_feats=39, in_edge_feats=50):
super(MPN, self).__init__()
self.W_i = nn.Linear(in_edge_feats, hidden_size, bias=False)
self.W_h = nn.Linear(hidden_size, hidden_size, bias=False)
self.W_o = nn.Sequential(
nn.Linear(in_node_feats + hidden_size, hidden_size),
nn.ReLU()
)
self.depth = depth
def forward(self, mol_graph):
mol_graph = mol_graph.local_var()
line_mol_graph = dgl.line_graph(mol_graph, backtracking=False)
line_input = self.W_i(mol_graph.edata['x'])
line_mol_graph.ndata['msg_input'] = line_input
line_mol_graph.ndata['msg'] = F.relu(line_input)
# Message passing over the line graph
for _ in range(self.depth - 1):
line_mol_graph.update_all(message_func=fn.copy_u('msg', 'msg'),
reduce_func=fn.sum('msg', 'nei_msg'))
nei_msg = self.W_h(line_mol_graph.ndata['nei_msg'])
line_mol_graph.ndata['msg'] = F.relu(line_input + nei_msg)
# Message passing over the raw graph
mol_graph.edata['msg'] = line_mol_graph.ndata['msg']
mol_graph.update_all(message_func=fn.copy_e('msg', 'msg'),
reduce_func=fn.sum('msg', 'nei_msg'))
raw_input = torch.cat([mol_graph.ndata['x'], mol_graph.ndata['nei_msg']], dim=1)
mol_graph.ndata['atom_hiddens'] = self.W_o(raw_input)
# Readout
mol_vecs = dgl.mean_nodes(mol_graph, 'atom_hiddens')
return mol_vecs
def index_select_ND(source, dim, index):
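    """Index ``source`` along ``dim`` with a multi-dimensional ``index``.

    For example, with ``source`` of shape (M, H), ``dim=0`` and ``index`` of
    shape (N, MAX_NB), the result has shape (N, MAX_NB, H).
    """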
index_size = index.size()
suffix_dim = source.size()[1:]
final_size = index_size + suffix_dim
target = source.index_select(dim, index.view(-1))
return target.view(final_size)
class JTMPN(nn.Module):
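    """Message passing network over candidate attachment subgraphs.

    Similar to :class:`MPN`, but the message set is augmented with junction
    tree messages (``tree_mess``) so that candidate scoring during assembly is
    conditioned on the encoded tree. Returns one vector per candidate.
    """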
def __init__(self, hidden_size, depth, in_node_feats=35, in_edge_feats=40):
super(JTMPN, self).__init__()
self.hidden_size = hidden_size
self.depth = depth
self.W_i = nn.Linear(in_edge_feats, hidden_size, bias=False)
self.W_h = nn.Linear(hidden_size, hidden_size, bias=False)
self.W_o = nn.Linear(in_node_feats + hidden_size, hidden_size)
self.atom_featurizer = ConcatFeaturizer([
partial(atom_type_one_hot,
allowable_set=['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na',
'Ca', 'Fe', 'Al', 'I', 'B', 'K', 'Se', 'Zn', 'H', 'Cu', 'Mn'],
encode_unknown=True),
partial(atom_degree_one_hot, allowable_set=[0, 1, 2, 3, 4], encode_unknown=True),
partial(atom_formal_charge_one_hot, allowable_set=[-1, -2, 1, 2],
encode_unknown=True),
atom_is_aromatic
])
self.bond_featurizer = ConcatFeaturizer([bond_type_one_hot, bond_is_in_ring])
def forward(self, cand_batch, tree_mess, device='cpu'):
fatoms, fbonds = [], []
in_bonds, all_bonds = [], []
# Ensure index 0 is vec(0)
mess_dict, all_mess = {}, [torch.zeros(self.hidden_size).to(device)]
total_atoms = 0
scope = []
for e, vec in tree_mess.items():
mess_dict[e] = len(all_mess)
all_mess.append(vec)
for mol, all_nodes, _ in cand_batch:
n_atoms = mol.GetNumAtoms()
for atom in mol.GetAtoms():
fatoms.append(torch.Tensor(self.atom_featurizer(atom)))
in_bonds.append([])
for bond in mol.GetBonds():
a1 = bond.GetBeginAtom()
a2 = bond.GetEndAtom()
x = a1.GetIdx() + total_atoms
y = a2.GetIdx() + total_atoms
# Here x_nid,y_nid could be 0
x_nid, y_nid = a1.GetAtomMapNum(), a2.GetAtomMapNum()
x_bid = all_nodes[x_nid - 1]['idx'] if x_nid > 0 else -1
y_bid = all_nodes[y_nid - 1]['idx'] if y_nid > 0 else -1
bfeature = torch.Tensor(self.bond_featurizer(bond))
                b = len(all_mess) + len(all_bonds)  # bond index offset by len(all_mess)
all_bonds.append((x, y))
fbonds.append(torch.cat([fatoms[x], bfeature], 0))
in_bonds[y].append(b)
b = len(all_mess) + len(all_bonds)
all_bonds.append((y, x))
fbonds.append(torch.cat([fatoms[y], bfeature], 0))
in_bonds[x].append(b)
if x_bid >= 0 and y_bid >= 0 and x_bid != y_bid:
if (x_bid, y_bid) in mess_dict:
mess_idx = mess_dict[(x_bid, y_bid)]
in_bonds[y].append(mess_idx)
if (y_bid, x_bid) in mess_dict:
mess_idx = mess_dict[(y_bid, x_bid)]
in_bonds[x].append(mess_idx)
scope.append((total_atoms, n_atoms))
total_atoms += n_atoms
total_bonds = len(all_bonds)
total_mess = len(all_mess)
fatoms = torch.stack(fatoms, 0).to(device)
fbonds = torch.stack(fbonds, 0).to(device)
agraph = torch.zeros(total_atoms, MAX_NB).long().to(device)
bgraph = torch.zeros(total_bonds, MAX_NB).long().to(device)
tree_message = torch.stack(all_mess, dim=0)
for a in range(total_atoms):
for i, b in enumerate(in_bonds[a]):
if i == MAX_NB:
break
agraph[a, i] = b
for b1 in range(total_bonds):
x, y = all_bonds[b1]
            for i, b2 in enumerate(in_bonds[x]):  # b2 is offset by len(all_mess)
if i == MAX_NB:
break
if b2 < total_mess or all_bonds[b2 - total_mess][0] != y:
bgraph[b1, i] = b2
binput = self.W_i(fbonds)
graph_message = F.relu(binput)
for i in range(self.depth - 1):
message = torch.cat([tree_message, graph_message], dim=0)
nei_message = index_select_ND(message, 0, bgraph)
nei_message = nei_message.sum(dim=1)
nei_message = self.W_h(nei_message)
graph_message = F.relu(binput + nei_message)
message = torch.cat([tree_message, graph_message], dim=0)
nei_message = index_select_ND(message, 0, agraph)
nei_message = nei_message.sum(dim=1)
ainput = torch.cat([fatoms, nei_message], dim=1)
atom_hiddens = F.relu(self.W_o(ainput))
mol_vecs = []
for st, le in scope:
mol_vec = atom_hiddens.narrow(0, st, le).sum(dim=0) / le
mol_vecs.append(mol_vec)
mol_vecs = torch.stack(mol_vecs, dim=0)
return mol_vecs
class JTNNVAE(nn.Module):
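    """Junction Tree Variational Autoencoder (JT-VAE) for molecular graph
    generation, following Jin et al. (2018).

    The latent vector is split evenly into a tree part and a graph part
    (``latent_size // 2`` each). A :class:`JTNNEncoder` / :class:`JTNNDecoder`
    pair handles the junction tree, :class:`MPN` encodes the molecular graph,
    and :class:`JTMPN` scores candidate attachments during graph assembly.
    When ``stereo`` is True, stereoisomer candidates are also scored and a
    stereochemistry loss is added.
    """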
def __init__(self, vocab, hidden_size, latent_size, depth, stereo=True):
super(JTNNVAE, self).__init__()
self.vocab = vocab
self.hidden_size = hidden_size
self.latent_size = latent_size
self.depth = depth
self.embedding = nn.Embedding(self.vocab.size(), hidden_size)
self.jtnn = JTNNEncoder(self.vocab, hidden_size, self.embedding)
self.jtmpn = JTMPN(hidden_size, depth)
self.mpn = MPN(hidden_size, depth)
self.decoder = JTNNDecoder(self.vocab, hidden_size, latent_size // 2, self.embedding)
self.T_mean = nn.Linear(hidden_size, latent_size // 2)
self.T_var = nn.Linear(hidden_size, latent_size // 2)
self.G_mean = nn.Linear(hidden_size, latent_size // 2)
self.G_var = nn.Linear(hidden_size, latent_size // 2)
self.assm_loss = nn.CrossEntropyLoss(reduction='sum')
self.use_stereo = stereo
if stereo:
self.stereo_loss = nn.CrossEntropyLoss(reduction='sum')
self.atom_featurizer = get_atom_featurizer_enc()
self.bond_featurizer = get_bond_featurizer_enc()
def reset_parameters(self):
for param in self.parameters():
if param.dim() == 1:
nn.init.constant_(param.data, 0)
else:
nn.init.xavier_normal_(param.data)
def encode(self, batch_tree_graphs, batch_mol_graphs):
tree_mess, tree_vec = self.jtnn(batch_tree_graphs)
mol_vec = self.mpn(batch_mol_graphs)
return tree_mess, tree_vec, mol_vec
    def forward(self, batch_trees, batch_tree_graphs, batch_mol_graphs, stereo_cand_batch_idx,
stereo_cand_labels, batch_stereo_cand_graphs, beta=0):
batch_size = batch_tree_graphs.batch_size
device = batch_tree_graphs.device
tree_mess, tree_vec, mol_vec = self.encode(batch_tree_graphs, batch_mol_graphs)
tree_mean = self.T_mean(tree_vec)
tree_log_var = -torch.abs(self.T_var(tree_vec)) # Following Mueller et al.
mol_mean = self.G_mean(mol_vec)
mol_log_var = -torch.abs(self.G_var(mol_vec)) # Following Mueller et al.
z_mean = torch.cat([tree_mean, mol_mean], dim=1)
z_log_var = torch.cat([tree_log_var, mol_log_var], dim=1)
kl_loss = -0.5 * torch.sum(torch.tensor(1.0).to(device) + z_log_var - z_mean * z_mean -
torch.exp(z_log_var)) / batch_size
epsilon = torch.randn(batch_size, self.latent_size // 2).to(device)
tree_vec = tree_mean + torch.exp(tree_log_var / 2) * epsilon
epsilon = torch.randn(batch_size, self.latent_size // 2).to(device)
mol_vec = mol_mean + torch.exp(mol_log_var / 2) * epsilon
word_loss, topo_loss, word_acc, topo_acc = self.decoder(batch_tree_graphs, tree_vec)
assm_loss, assm_acc = self.assm(batch_trees, batch_tree_graphs, mol_vec, tree_mess)
if self.use_stereo:
stereo_loss, stereo_acc = self.stereo(stereo_cand_batch_idx, stereo_cand_labels,
batch_stereo_cand_graphs, mol_vec)
else:
stereo_loss, stereo_acc = torch.tensor(0.).to(device), 0
loss = word_loss + topo_loss + assm_loss + 2 * stereo_loss + beta * kl_loss
return loss, kl_loss.item(), word_acc, topo_acc, assm_acc, stereo_acc
def edata_to_dict(self, g, tree_mess):
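        """Convert the per-edge tree message tensor into a dict keyed by
        (src, dst) node ID pairs, in the form consumed by :class:`JTMPN`."""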
tree_mess_ = dict()
src, dst = g.edges()
for i, edge in enumerate(tuple(zip(src.tolist(), dst.tolist()))):
tree_mess_[edge] = tree_mess[i]
return tree_mess_
def assm(self, batch_trees, tree_graphs, mol_vec, tree_mess):
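        """Attachment (assembly) loss over candidate subgraphs.

        For every non-leaf tree node with more than one candidate attachment,
        the candidates are scored against the molecule latent vector and a
        cross-entropy loss is applied with the ground-truth attachment as the
        label. Returns the loss (normalized by the number of trees) and the
        accuracy over such nodes.
        """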
device = tree_graphs.device
cands = []
cand_batch_idx = []
for i, tree in enumerate(batch_trees):
for _, node in tree.nodes_dict.items():
# Leaf node's attachment is determined by neighboring node's attachment
if node['is_leaf'] or len(node['cands']) == 1:
continue
cands.extend([(cand, tree.nodes_dict, node) for cand in node['cand_mols']])
cand_batch_idx.extend([i] * len(node['cands']))
tree_mess = self.edata_to_dict(tree_graphs, tree_mess)
cand_vec = self.jtmpn(cands, tree_mess, device)
cand_vec = self.G_mean(cand_vec)
if len(cand_batch_idx) == 0:
cand_batch_idx = torch.zeros(0).long().to(device)
else:
cand_batch_idx = torch.LongTensor(cand_batch_idx).to(device)
mol_vec = mol_vec[cand_batch_idx]
mol_vec = mol_vec.view(-1, 1, self.latent_size // 2)
cand_vec = cand_vec.view(-1, self.latent_size // 2, 1)
scores = torch.bmm(mol_vec, cand_vec).squeeze()
cnt, tot, acc = 0, 0, 0
all_loss = []
for tree in batch_trees:
for _, node in tree.nodes_dict.items():
num_cands = len(node['cands'])
if node['is_leaf'] or num_cands == 1:
continue
cnt += 1
label = node['cands'].index(node['label'])
cur_score = scores.narrow(0, tot, num_cands)
tot += num_cands
if cur_score[label].item() >= cur_score.max().item():
acc += 1
label = torch.LongTensor([label]).to(device)
all_loss.append(self.assm_loss(cur_score.view(1, -1), label))
if len(all_loss) > 0:
all_loss = sum(all_loss) / len(batch_trees)
else:
all_loss = torch.zeros(1).to(device)
return all_loss, acc * 1.0 / cnt
def stereo(self, batch_idx, batch_labels, batch_stereo_cand_graphs, mol_vec):
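        """Stereochemistry loss over stereoisomer candidates.

        Candidate stereoisomers are embedded with the MPN, scored against the
        corresponding molecule latent vectors by cosine similarity, and a
        cross-entropy loss is applied with the true stereoisomer as the label.
        """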
device = batch_stereo_cand_graphs.device
if len(batch_labels) == 0:
return torch.zeros(1).to(device), 1.0
stereo_cands = self.mpn(batch_stereo_cand_graphs)
stereo_cands = self.G_mean(stereo_cands)
stereo_labels = mol_vec[batch_idx]
scores = nn.CosineSimilarity()(stereo_cands, stereo_labels)
st, acc = 0, 0
all_loss = []
for label, le in batch_labels:
cur_scores = scores.narrow(0, st, le)
if cur_scores[label].item() >= cur_scores.max().item():
acc += 1
label = torch.LongTensor([label]).to(device)
all_loss.append(self.stereo_loss(cur_scores.view(1, -1), label))
st += le
all_loss = sum(all_loss) / len(batch_labels)
return all_loss, acc * 1.0 / len(batch_labels)
def reconstruct(self, tree_graph, mol_graph, prob_decode=False):
device = tree_graph.device
_, tree_vec, mol_vec = self.encode(tree_graph, mol_graph)
tree_mean = self.T_mean(tree_vec)
tree_log_var = -torch.abs(self.T_var(tree_vec)) # Following Mueller et al.
mol_mean = self.G_mean(mol_vec)
mol_log_var = -torch.abs(self.G_var(mol_vec)) # Following Mueller et al.
epsilon = torch.randn(1, self.latent_size // 2).to(device)
tree_vec = tree_mean + torch.exp(tree_log_var / 2) * epsilon
epsilon = torch.randn(1, self.latent_size // 2).to(device)
mol_vec = mol_mean + torch.exp(mol_log_var / 2) * epsilon
return self.decode(tree_vec, mol_vec, prob_decode)
def sample_prior(self, prob_decode=False):
tree_vec = torch.randn(1, self.latent_size // 2)
mol_vec = torch.randn(1, self.latent_size // 2)
return self.decode(tree_vec, mol_vec, prob_decode)
def sample_eval(self):
tree_vec = torch.randn(1, self.latent_size // 2)
mol_vec = torch.randn(1, self.latent_size // 2)
all_smiles = []
for _ in range(100):
s = self.decode(tree_vec, mol_vec, prob_decode=True)
all_smiles.append(s)
return all_smiles
def decode(self, tree_vec, mol_vec, prob_decode):
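        """Decode tree and graph latent vectors into a SMILES string.

        First decodes a junction tree with :class:`JTNNDecoder`, then
        assembles the molecular graph with :meth:`dfs_assemble`, and, if
        stereochemistry is enabled, picks the stereoisomer whose embedding is
        most similar to ``mol_vec``. Returns ``None`` if assembly fails.
        """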
device = tree_vec.device
pred_root, pred_nodes = self.decoder.decode(tree_vec, prob_decode)
# Mark nid & is_leaf & atommap
for i, node in enumerate(pred_nodes):
node['nid'] = i + 1
node['is_leaf'] = (len(node['neighbors']) == 1)
if len(node['neighbors']) > 1:
set_atommap(node['mol'], node['nid'])
src = []
dst = []
for node in pred_nodes:
cur_id = node['idx']
for nbr in node['neighbors']:
nbr_id = nbr['idx']
src.extend([cur_id])
dst.extend([nbr_id])
if len(src) == 0:
tree_graph = dgl.graph((src, dst), idtype=torch.int32, device=device,
num_nodes=max([node['idx'] + 1 for node in pred_nodes]))
else:
tree_graph = dgl.graph((src, dst), idtype=torch.int32, device=device)
node_ids = torch.LongTensor([node['idx'] for node in pred_nodes]).to(device)
node_wid = torch.LongTensor([node['wid'] for node in pred_nodes]).to(device)
tree_graph_x = torch.zeros(tree_graph.num_nodes(), self.hidden_size).to(device)
tree_graph_x[node_ids] = self.embedding(node_wid)
tree_graph.ndata['x'] = tree_graph_x
tree_mess = self.jtnn(tree_graph)[0]
tree_mess = self.edata_to_dict(tree_graph, tree_mess)
cur_mol = copy_edit_mol(pred_root['mol'])
global_amap = [{}] + [{} for _ in pred_nodes]
global_amap[1] = {atom.GetIdx(): atom.GetIdx() for atom in cur_mol.GetAtoms()}
cur_mol = self.dfs_assemble(tree_mess, mol_vec, pred_nodes, cur_mol, global_amap, [],
pred_root, None, prob_decode)
if cur_mol is None:
return None
cur_mol = cur_mol.GetMol()
set_atommap(cur_mol)
cur_mol = Chem.MolFromSmiles(Chem.MolToSmiles(cur_mol))
if cur_mol is None:
return None
if not self.use_stereo:
return Chem.MolToSmiles(cur_mol)
smiles2D = Chem.MolToSmiles(cur_mol)
stereo_cands = decode_stereo(smiles2D)
if len(stereo_cands) == 1:
return stereo_cands[0]
stereo_cand_graphs = []
for cand in stereo_cands:
cand = get_mol(cand)
cg = mol_to_bigraph(cand, node_featurizer=self.atom_featurizer,
edge_featurizer=self.bond_featurizer,
canonical_atom_order=False)
cg.apply_edges(fn.copy_u('x', 'src'))
cg.edata['x'] = torch.cat([cg.edata.pop('src'), cg.edata['x']], dim=1)
stereo_cand_graphs.append(cg)
stereo_cand_graphs = dgl.batch(stereo_cand_graphs).to(device)
stereo_vecs = self.mpn(stereo_cand_graphs)
stereo_vecs = self.G_mean(stereo_vecs)
scores = nn.CosineSimilarity()(stereo_vecs, mol_vec)
_, max_id = scores.max(dim=0)
return stereo_cands[max_id.item()]
def dfs_assemble(self, tree_mess, mol_vec, all_nodes, cur_mol, global_amap, fa_amap,
cur_node, fa_node, prob_decode):
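        """Recursively assemble the molecular graph from the decoded junction tree.

        Candidate attachments between ``cur_node`` and its children are scored
        against ``mol_vec``; candidates are tried in order of their score (or
        sampled when ``prob_decode`` is True) until the subtree rooted at
        ``cur_node`` can be realized as a valid molecule. Returns ``None`` if
        no candidate works.
        """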
fa_nid = fa_node['nid'] if fa_node is not None else -1
prev_nodes = [fa_node] if fa_node is not None else []
children = [nei for nei in cur_node['neighbors'] if nei['nid'] != fa_nid]
neighbors = [nei for nei in children if nei['mol'].GetNumAtoms() > 1]
neighbors = sorted(neighbors, key=lambda x: x['mol'].GetNumAtoms(), reverse=True)
singletons = [nei for nei in children if nei['mol'].GetNumAtoms() == 1]
neighbors = singletons + neighbors
cur_amap = [(fa_nid, a2, a1) for nid, a1, a2 in fa_amap if nid == cur_node['nid']]
cands = enum_assemble(cur_node, neighbors, prev_nodes, cur_amap)
if len(cands) == 0:
return None
_, cand_mols, cand_amap = zip(*cands)
cands = [(candmol, all_nodes, cur_node) for candmol in cand_mols]
cand_vecs = self.jtmpn(cands, tree_mess, mol_vec.device)
cand_vecs = self.G_mean(cand_vecs)
mol_vec = mol_vec.squeeze()
scores = torch.mv(cand_vecs, mol_vec) * 20
if prob_decode:
            probs = torch.softmax(scores.view(1, -1), dim=1).squeeze() + 1e-5  # prevent prob = 0
cand_idx = torch.multinomial(probs, probs.numel())
else:
_, cand_idx = torch.sort(scores, descending=True)
backup_mol = Chem.RWMol(cur_mol)
for i in range(cand_idx.numel()):
cur_mol = Chem.RWMol(backup_mol)
pred_amap = cand_amap[cand_idx[i].item()]
new_global_amap = copy.deepcopy(global_amap)
for nei_id, ctr_atom, nei_atom in pred_amap:
if nei_id == fa_nid:
continue
new_global_amap[nei_id][nei_atom] = new_global_amap[cur_node['nid']][ctr_atom]
# father is already attached
cur_mol = attach_mols(cur_mol, children, [], new_global_amap)
new_mol = cur_mol.GetMol()
new_mol = Chem.MolFromSmiles(Chem.MolToSmiles(new_mol))
if new_mol is None:
continue
result = True
for nei_node in children:
if nei_node['is_leaf']:
continue
cur_mol = self.dfs_assemble(tree_mess, mol_vec, all_nodes, cur_mol,
new_global_amap, pred_amap, nei_node,
cur_node, prob_decode)
if cur_mol is None:
result = False
break
if result:
return cur_mol
return None