Source code for dgllife.model.model_zoo.jtnn.jtnn_vae

# -*- coding: utf-8 -*-
#
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
#
# pylint: disable=C0111, C0103, E1101, W0611, W0612, C0200, W0221, E1102

import copy
import rdkit.Chem as Chem
import torch
import torch.nn as nn
import torch.nn.functional as F

from dgl import batch, unbatch
from dgl.data.utils import _get_dgl_url, download, get_download_dir, extract_archive

from .jtmpn import DGLJTMPN
from .jtnn_dec import DGLJTNNDecoder
from .jtnn_enc import DGLJTNNEncoder
from .vocab import Vocab
from .mpn import DGLMPN
from ....data.jtvae import set_atommap, decode_stereo, copy_edit_mol, attach_mols_nx, \
    enum_assemble_nx, get_atom_featurizer_enc, get_bond_featurizer_enc, mol2dgl_enc, \
    get_atom_featurizer_dec, get_bond_featurizer_dec, mol2dgl_dec

class DGLJTNNVAE(nn.Module):
    """
    `Junction Tree Variational Autoencoder for Molecular Graph Generation
    <https://arxiv.org/abs/1802.04364>`__

    Parameters
    ----------
    hidden_size : int
        Size for hidden representations.
    latent_size : int
        Size for the latent representation, split evenly between the junction
        tree and the molecular graph.
    depth : int
        The number of message passing steps.
    vocab_file : str, optional
        The path to a vocabulary file, with one SMILES per line. If not specified,
        the vocabulary extracted from the ZINC dataset will be used.
    """
    def __init__(self, hidden_size, latent_size, depth, vocab_file=None):
        super(DGLJTNNVAE, self).__init__()

        if vocab_file is None:
            default_dir = get_download_dir()
            vocab_file = '{}/jtvae/{}.txt'.format(default_dir, 'vocab')
            zip_file_path = '{}/jtvae.zip'.format(default_dir)
            download(_get_dgl_url('dataset/jtvae.zip'), path=zip_file_path)
            extract_archive(zip_file_path, '{}/jtvae'.format(default_dir))
        with open(vocab_file, 'r') as f:
            self.vocab = Vocab([x.strip("\r\n ") for x in f])

        self.hidden_size = hidden_size
        self.latent_size = latent_size
        self.depth = depth

        self.embedding = nn.Embedding(self.vocab.size(), hidden_size)
        self.mpn = DGLMPN(hidden_size, depth)
        self.jtnn = DGLJTNNEncoder(self.vocab, hidden_size, self.embedding)
        self.decoder = DGLJTNNDecoder(
            self.vocab, hidden_size, latent_size // 2, self.embedding)
        self.jtmpn = DGLJTMPN(hidden_size, depth)

        self.T_mean = nn.Linear(hidden_size, latent_size // 2)
        self.T_var = nn.Linear(hidden_size, latent_size // 2)
        self.G_mean = nn.Linear(hidden_size, latent_size // 2)
        self.G_var = nn.Linear(hidden_size, latent_size // 2)

        self.atom_featurizer_enc = get_atom_featurizer_enc()
        self.bond_featurizer_enc = get_bond_featurizer_enc()
        self.atom_featurizer_dec = get_atom_featurizer_dec()
        self.bond_featurizer_dec = get_bond_featurizer_dec()
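
    # --- Usage sketch (not part of the original source) ---
    # A minimal construction example. The hyperparameter values follow the
    # JT-VAE paper and are illustrative assumptions, not defaults of this
    # class. With vocab_file=None, the ZINC vocabulary is downloaded on
    # first use:
    #
    #     model = DGLJTNNVAE(hidden_size=450, latent_size=56, depth=3)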

    def reset_parameters(self):
        """Reinitialize model parameters."""
        self.embedding.reset_parameters()
        self.mpn.reset_parameters()
        self.jtnn.reset_parameters()
        self.decoder.reset_parameters()
        self.jtmpn.reset_parameters()
        self.T_mean.reset_parameters()
        self.T_var.reset_parameters()
        self.G_mean.reset_parameters()
        self.G_var.reset_parameters()

    def encode(self, mol_batch):
        # Encode the batched molecular graphs and junction trees separately.
        mol_graphs = mol_batch['mol_graph_batch']
        mol_vec = self.mpn(mol_graphs)

        mol_trees = [tree.g for tree in mol_batch['mol_trees']]
        mol_tree_batch = batch(mol_trees)
        mol_tree_batch, tree_vec = self.jtnn(mol_tree_batch)

        return mol_tree_batch, tree_vec, mol_vec

    def sample(self, tree_vec, mol_vec, e1=None, e2=None):
        device = tree_vec.device
        tree_mean = self.T_mean(tree_vec)
        tree_log_var = -torch.abs(self.T_var(tree_vec))
        mol_mean = self.G_mean(mol_vec)
        mol_log_var = -torch.abs(self.G_var(mol_vec))

        # Reparameterization trick: z = mean + exp(log_var / 2) * epsilon,
        # with epsilon ~ N(0, I) unless fixed noise e1/e2 is supplied.
        epsilon = torch.randn(*tree_mean.shape).to(device) if e1 is None else e1
        tree_vec = tree_mean + torch.exp(tree_log_var / 2) * epsilon
        epsilon = torch.randn(*mol_mean.shape).to(device) if e2 is None else e2
        mol_vec = mol_mean + torch.exp(mol_log_var / 2) * epsilon

        z_mean = torch.cat([tree_mean, mol_mean], 1)
        z_log_var = torch.cat([tree_log_var, mol_log_var], 1)

        return tree_vec, mol_vec, z_mean, z_log_var
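
    # --- Usage sketch (not part of the original source) ---
    # Assuming ``mol_batch`` is a batch dict from the JTVAE data pipeline,
    # ``encode`` and ``sample`` can be chained manually. Supplying fixed
    # noise via ``e1``/``e2`` makes the reparameterized draw deterministic,
    # which helps in tests:
    #
    #     mol_tree_batch, tree_vec, mol_vec = model.encode(mol_batch)
    #     zeros = torch.zeros(tree_vec.shape[0], model.latent_size // 2)
    #     tree_vec, mol_vec, z_mean, z_log_var = model.sample(
    #         tree_vec, mol_vec, e1=zeros, e2=zeros)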

    def forward(self, mol_batch, beta=0, e1=None, e2=None):
        mol_trees = mol_batch['mol_trees']
        batch_size = len(mol_trees)

        mol_tree_batch, tree_vec, mol_vec = self.encode(mol_batch)
        tree_vec, mol_vec, z_mean, z_log_var = self.sample(
            tree_vec, mol_vec, e1, e2)
        # KL divergence between the posterior N(z_mean, exp(z_log_var)) and
        # the standard normal prior, averaged over the batch.
        kl_loss = -0.5 * torch.sum(
            1.0 + z_log_var - z_mean * z_mean - torch.exp(z_log_var)) / batch_size

        word_loss, topo_loss, word_acc, topo_acc = self.decoder(
            mol_trees, tree_vec)
        assm_loss, assm_acc = self.assm(mol_batch, mol_tree_batch, mol_vec)
        stereo_loss, stereo_acc = self.stereo(mol_batch, mol_vec)

        loss = word_loss + topo_loss + assm_loss + 2 * stereo_loss + beta * kl_loss
        return loss, kl_loss, word_acc, topo_acc, assm_acc, stereo_acc
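
    # --- Training-step sketch (not part of the original source) ---
    # Assumes ``mol_batch`` comes from a DataLoader whose collate function
    # builds the dict this class expects; the ``beta`` value is illustrative
    # (it weights the KL term and is typically annealed during training):
    #
    #     optimizer.zero_grad()
    #     loss, kl_loss, word_acc, topo_acc, assm_acc, stereo_acc = \
    #         model(mol_batch, beta=0.005)
    #     loss.backward()
    #     optimizer.step()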

    def assm(self, mol_batch, mol_tree_batch, mol_vec):
        device = mol_vec.device
        cands = [mol_batch['cand_graph_batch'],
                 mol_batch['tree_mess_src_e'],
                 mol_batch['tree_mess_tgt_e'],
                 mol_batch['tree_mess_tgt_n']]
        cand_vec = self.jtmpn(cands, mol_tree_batch)
        cand_vec = self.G_mean(cand_vec)

        batch_idx = torch.LongTensor(mol_batch['cand_batch_idx']).to(device)
        mol_vec = mol_vec[batch_idx]

        # Score each candidate attachment by a dot product with the latent
        # vector of the molecule it belongs to.
        mol_vec = mol_vec.view(-1, 1, self.latent_size // 2)
        cand_vec = cand_vec.view(-1, self.latent_size // 2, 1)
        scores = (mol_vec @ cand_vec)[:, 0, 0]

        cnt, tot, acc = 0, 0, 0
        all_loss = []
        for i, mol_tree in enumerate(mol_batch['mol_trees']):
            comp_nodes = [node_id for node_id, node in mol_tree.nodes_dict.items()
                          if len(node['cands']) > 1 and not node['is_leaf']]
            cnt += len(comp_nodes)
            # segmented accuracy and cross entropy
            for node_id in comp_nodes:
                node = mol_tree.nodes_dict[node_id]
                label = node['cands'].index(node['label'])
                ncand = len(node['cands'])
                cur_score = scores[tot:tot + ncand]
                tot += ncand

                if cur_score[label].item() >= cur_score.max().item():
                    acc += 1

                label = torch.LongTensor([label]).to(device)
                all_loss.append(
                    F.cross_entropy(cur_score.view(1, -1), label, reduction='sum'))

        all_loss = sum(all_loss) / len(mol_batch['mol_trees'])
        return all_loss, acc / cnt

    def stereo(self, mol_batch, mol_vec):
        device = mol_vec.device
        stereo_cands = mol_batch['stereo_cand_graph_batch']
        batch_idx = mol_batch['stereo_cand_batch_idx']
        labels = mol_batch['stereo_cand_labels']
        lengths = mol_batch['stereo_cand_lengths']

        if len(labels) == 0:
            # Only one stereoisomer exists; do nothing
            return torch.tensor(0.).to(device), 1.

        batch_idx = torch.LongTensor(batch_idx).to(device)
        stereo_cands = self.mpn(stereo_cands)
        stereo_cands = self.G_mean(stereo_cands)
        stereo_labels = mol_vec[batch_idx]
        # Score stereoisomer candidates by cosine similarity with the latent
        # vector of the target molecule.
        scores = F.cosine_similarity(stereo_cands, stereo_labels)

        st, acc = 0, 0
        all_loss = []
        for label, le in zip(labels, lengths):
            cur_scores = scores[st:st + le]
            if cur_scores.data[label].item() >= cur_scores.max().item():
                acc += 1
            label = torch.LongTensor([label]).to(device)
            all_loss.append(
                F.cross_entropy(cur_scores.view(1, -1), label, reduction='sum'))
            st += le

        all_loss = sum(all_loss) / len(labels)
        return all_loss, acc / len(labels)

    def decode(self, tree_vec, mol_vec):
        device = mol_vec.device
        # Decode the junction tree first, then assemble the molecular graph.
        mol_tree, nodes_dict, effective_nodes = self.decoder.decode(tree_vec)
        effective_nodes_list = effective_nodes.tolist()
        nodes_dict = [nodes_dict[v] for v in effective_nodes_list]

        for i, (node_id, node) in enumerate(zip(effective_nodes_list, nodes_dict)):
            node['idx'] = i
            node['nid'] = i + 1
            node['is_leaf'] = True
            if mol_tree.g.in_degrees(node_id) > 1:
                node['is_leaf'] = False
                set_atommap(node['mol'], node['nid'])

        mol_tree_sg = mol_tree.g.subgraph(effective_nodes)
        mol_tree_msg, _ = self.jtnn(mol_tree_sg)
        mol_tree_msg = unbatch(mol_tree_msg)[0]
        mol_tree_msg.nodes_dict = nodes_dict

        cur_mol = copy_edit_mol(nodes_dict[0]['mol'])
        global_amap = [{}] + [{} for _ in nodes_dict]
        global_amap[1] = {atom.GetIdx(): atom.GetIdx()
                          for atom in cur_mol.GetAtoms()}

        cur_mol = self.dfs_assemble(
            mol_tree_msg, mol_vec, cur_mol, global_amap, [], 0, None)

        if cur_mol is None:
            return None

        cur_mol = cur_mol.GetMol()
        set_atommap(cur_mol)
        cur_mol = Chem.MolFromSmiles(Chem.MolToSmiles(cur_mol))
        if cur_mol is None:
            return None

        # Pick the best stereoisomer for the assembled 2D structure.
        smiles_2d = Chem.MolToSmiles(cur_mol)
        stereo_cands = decode_stereo(smiles_2d)
        if len(stereo_cands) == 1:
            return stereo_cands[0]

        stereo_cand_graphs = [mol2dgl_enc(c, self.atom_featurizer_enc,
                                          self.bond_featurizer_enc)
                              for c in stereo_cands]
        stereo_cand_graphs = batch(stereo_cand_graphs).to(device)
        stereo_cand_graphs.edata['src_x'] = torch.zeros(
            stereo_cand_graphs.num_edges(),
            stereo_cand_graphs.ndata['x'].shape[1]).to(device)
        stereo_vecs = self.mpn(stereo_cand_graphs)
        stereo_vecs = self.G_mean(stereo_vecs)
        scores = F.cosine_similarity(stereo_vecs, mol_vec)
        _, max_id = scores.max(0)
        return stereo_cands[max_id.item()]

    def dfs_assemble(self, mol_tree_msg, mol_vec, cur_mol,
                     global_amap, fa_amap, cur_node_id, fa_node_id):
        device = mol_vec.device
        nodes_dict = mol_tree_msg.nodes_dict
        fa_node = nodes_dict[fa_node_id] if fa_node_id is not None else None
        cur_node = nodes_dict[cur_node_id]

        fa_nid = fa_node['nid'] if fa_node is not None else -1
        prev_nodes = [fa_node] if fa_node is not None else []

        children_node_id = [v for v in mol_tree_msg.successors(cur_node_id).tolist()
                            if nodes_dict[v]['nid'] != fa_nid]
        children = [nodes_dict[v] for v in children_node_id]
        # Attach singleton (single-atom) neighbors first, then the larger
        # neighbors in decreasing size.
        neighbors = [nei for nei in children if nei['mol'].GetNumAtoms() > 1]
        neighbors = sorted(
            neighbors, key=lambda x: x['mol'].GetNumAtoms(), reverse=True)
        singletons = [nei for nei in children if nei['mol'].GetNumAtoms() == 1]
        neighbors = singletons + neighbors

        cur_amap = [(fa_nid, a2, a1)
                    for nid, a1, a2 in fa_amap if nid == cur_node['nid']]
        cands = enum_assemble_nx(cur_node, neighbors, prev_nodes, cur_amap)
        if len(cands) == 0:
            return None

        cand_smiles, cand_mols, cand_amap = list(zip(*cands))
        cands = [(candmol, mol_tree_msg, cur_node_id) for candmol in cand_mols]

        cand_graphs, tree_mess_src_edges, tree_mess_tgt_edges, tree_mess_tgt_nodes = \
            mol2dgl_dec(cands, self.atom_featurizer_dec, self.bond_featurizer_dec)
        cand_graphs = batch(cand_graphs).to(device)
        tree_mess_src_edges = tree_mess_src_edges.to(device)
        tree_mess_tgt_edges = tree_mess_tgt_edges.to(device)
        tree_mess_tgt_nodes = tree_mess_tgt_nodes.to(device)
        cand_graphs.edata['src_x'] = torch.zeros(
            cand_graphs.num_edges(), cand_graphs.ndata['x'].shape[1]).to(device)

        cand_vecs = self.jtmpn(
            (cand_graphs, tree_mess_src_edges,
             tree_mess_tgt_edges, tree_mess_tgt_nodes),
            mol_tree_msg,
        )
        cand_vecs = self.G_mean(cand_vecs)
        mol_vec = mol_vec.squeeze()
        scores = cand_vecs @ mol_vec

        # Try candidates from the highest score down; backtrack on failure.
        _, cand_idx = torch.sort(scores, descending=True)

        backup_mol = Chem.RWMol(cur_mol)
        for i in range(len(cand_idx)):
            cur_mol = Chem.RWMol(backup_mol)
            pred_amap = cand_amap[cand_idx[i].item()]
            new_global_amap = copy.deepcopy(global_amap)

            for nei_id, ctr_atom, nei_atom in pred_amap:
                if nei_id == fa_nid:
                    continue
                new_global_amap[nei_id][nei_atom] = \
                    new_global_amap[cur_node['nid']][ctr_atom]

            cur_mol = attach_mols_nx(cur_mol, children, [], new_global_amap)
            new_mol = cur_mol.GetMol()
            new_mol = Chem.MolFromSmiles(Chem.MolToSmiles(new_mol))
            if new_mol is None:
                continue

            result = True
            for nei_node_id, nei_node in zip(children_node_id, children):
                if nei_node['is_leaf']:
                    continue
                cur_mol = self.dfs_assemble(
                    mol_tree_msg, mol_vec, cur_mol, new_global_amap,
                    pred_amap, nei_node_id, cur_node_id)
                if cur_mol is None:
                    result = False
                    break
            if result:
                return cur_mol

        return None
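

# A minimal end-to-end sampling sketch (not part of the original source).
# It assumes the default ZINC vocabulary download succeeds and that latent
# vectors drawn from the standard normal prior decode one molecule at a
# time; the hyperparameter values follow the JT-VAE paper and are
# illustrative. ``decode`` may return None if graph assembly fails.
if __name__ == '__main__':
    model = DGLJTNNVAE(hidden_size=450, latent_size=56, depth=3)
    model.eval()
    with torch.no_grad():
        # One draw from the prior N(0, I) for each half of the latent space
        tree_vec = torch.randn(1, model.latent_size // 2)
        mol_vec = torch.randn(1, model.latent_size // 2)
        smiles = model.decode(tree_vec, mol_vec)
    print(smiles)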