# Source code for dgllife.utils.complex_to_graph

# -*- coding: utf-8 -*-
#
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Convert complexes into DGLHeteroGraphs
# pylint: disable= no-member, arguments-differ, invalid-name
import numpy as np
import dgl.backend as F
from functools import partial
from dgl import graph, heterograph, batch

from ..utils.mol_to_graph import k_nearest_neighbors, mol_to_bigraph
from ..utils.featurizers import BaseAtomFeaturizer, BaseBondFeaturizer, ConcatFeaturizer, atom_type_one_hot, atom_total_degree_one_hot, atom_formal_charge_one_hot, atom_is_aromatic, atom_implicit_valence_one_hot, atom_explicit_valence_one_hot, bond_type_one_hot, bond_is_in_ring

# Public names exported by a star import of this module.
__all__ = [
    'ACNN_graph_construction_and_featurization',
    'PN_graph_construction_and_featurization',
]

def filter_out_hydrogens(mol):
    """Collect the positions of all non-hydrogen atoms in a molecule.

    Parameters
    ----------
    mol : rdkit.Chem.rdchem.Mol
        RDKit molecule instance.

    Returns
    -------
    list of int
        Positions, in the iteration order of ``mol.GetAtoms()``, of the atoms
        whose atomic number differs from 1 (hydrogen).
    """
    hydrogen_atomic_num = 1
    return [position
            for position, atom in enumerate(mol.GetAtoms())
            if atom.GetAtomicNum() != hydrogen_atomic_num]

def get_atomic_numbers(mol, indices):
    """Look up the atomic number of each requested atom.

    Parameters
    ----------
    mol : rdkit.Chem.rdchem.Mol
        RDKit molecule instance.
    indices : list of int
        Indices of the atoms to look up, passed to ``mol.GetAtomWithIdx``.

    Returns
    -------
    list of int
        Atomic numbers, in the same order as ``indices``.
    """
    return [mol.GetAtomWithIdx(atom_idx).GetAtomicNum() for atom_idx in indices]

def int_2_one_hot(a, num_classes=None):
    """Convert a vector of integer codes into a matrix of one-hot rows.

    Parameters
    ----------
    a : array-like of int, shape (N,)
        Non-negative integer codes, one per row of the output.
    num_classes : int or None
        Number of columns of the output. If None, it is inferred from the
        data as ``a.max() + 1`` (and 0 for an empty input), which means the
        width may differ between inputs. Pass an explicit value to get a
        fixed width. Default to None.

    Returns
    -------
    numpy.ndarray of shape (N, num_classes)
        Float matrix whose ``i``-th row has a single 1 in column ``a[i]``.
    """
    a = np.asarray(a)
    n = len(a)
    if num_classes is None:
        # ``max`` of an empty array raises, so handle that case explicitly.
        num_classes = int(a.max()) + 1 if n > 0 else 0
    b = np.zeros((n, num_classes))
    if n > 0:
        b[np.arange(n), a] = 1
    return b

def PN_graph_construction_and_featurization(ligand_mol,
                                            protein_mol,
                                            ligand_coordinates,
                                            protein_coordinates,
                                            max_num_ligand_atoms=None,
                                            max_num_protein_atoms=None,
                                            max_num_neighbors=4,
                                            distance_bins=[1.5, 2.5, 3.5, 4.5],
                                            strip_hydrogens=False):
    """Graph construction and featurization for `PotentialNet for Molecular Property Prediction
     <https://pubs.acs.org/doi/10.1021/acscentsci.8b00507>`__.

    Parameters
    ----------
    ligand_mol : rdkit.Chem.rdchem.Mol
        RDKit molecule instance.
    protein_mol : rdkit.Chem.rdchem.Mol
        RDKit molecule instance.
    ligand_coordinates : Float Tensor of shape (V1, 3)
        Atom coordinates in a ligand.
    protein_coordinates : Float Tensor of shape (V2, 3)
        Atom coordinates in a protein.
    max_num_ligand_atoms : int or None
        If not None, it must be no smaller than ligand_mol.GetNumAtoms().
        This function only validates the value; it performs no zero padding.
        Default to None.
    max_num_protein_atoms : int or None
        If not None, it must be no smaller than protein_mol.GetNumAtoms().
        This function only validates the value; it performs no zero padding.
        Default to None.
    max_num_neighbors : int
        Maximum number of neighbors allowed for each atom when constructing KNN graph. Default to 4.
    distance_bins : list of float
        Distance bins to determine the edge types, in increasing order.
        Edges of the first edge type are added between pairs of atoms whose
        distances are no greater than `distance_bins[0]`; the last bin is also
        used as the neighbor cutoff of the KNN graph.
        The length matches the number of distance edge types to be constructed.
        The default list is never mutated by this function.
        Default `[1.5, 2.5, 3.5, 4.5]`.
    strip_hydrogens : bool
        Whether to exclude hydrogen atoms. Default to False.

    Returns
    -------
    complex_bigraph : DGLGraph
        Bigraph with the ligand and the protein (pocket) combined and canonical features extracted.
        The atom features are stored as DGLGraph.ndata['h'].
        The edge types are stored as DGLGraph.edata['e'].
        The bigraphs of the ligand and the protein are batched together as one complex graph.
    complex_knn_graph : DGLGraph
        K-nearest-neighbor graph built on the concatenated ligand and protein
        coordinates, to which the bond edges of ``complex_bigraph`` are appended.
        The edge features are stored as DGLGraph.edata['e'] as int64: the first
        ``len(distance_bins)`` columns are the one-hot distance bin (all zero
        for bond edges) and the remaining columns are the bond features of
        ``complex_bigraph`` (all zero for KNN edges).
    """

    assert ligand_coordinates is not None, 'Expect ligand_coordinates to be provided.'
    assert protein_coordinates is not None, 'Expect protein_coordinates to be provided.'
    if max_num_ligand_atoms is not None:
        assert max_num_ligand_atoms >= ligand_mol.GetNumAtoms(), \
            'Expect max_num_ligand_atoms to be no smaller than ligand_mol.GetNumAtoms(), ' \
            'got {:d} and {:d}'.format(max_num_ligand_atoms, ligand_mol.GetNumAtoms())
    if max_num_protein_atoms is not None:
        assert max_num_protein_atoms >= protein_mol.GetNumAtoms(), \
            'Expect max_num_protein_atoms to be no smaller than protein_mol.GetNumAtoms(), ' \
            'got {:d} and {:d}'.format(max_num_protein_atoms, protein_mol.GetNumAtoms())

    if strip_hydrogens:
        # Remove the coordinates of hydrogen atoms.
        # NOTE(review): the bigraphs below are still built from the full
        # molecules, so if the molecules carry explicit hydrogens the node
        # indices of complex_bigraph and complex_knn_graph will not line up.
        # TODO confirm callers pass hydrogen-free molecules in this mode.
        ligand_atom_indices_left = filter_out_hydrogens(ligand_mol)
        protein_atom_indices_left = filter_out_hydrogens(protein_mol)
        ligand_coordinates = ligand_coordinates.take(ligand_atom_indices_left, axis=0)
        protein_coordinates = protein_coordinates.take(protein_atom_indices_left, axis=0)

    # Node featurizer for stage 1
    atoms = ['H', 'N', 'O', 'C', 'P', 'S', 'F', 'Br', 'Cl', 'I', 'Fe', 'Zn', 'Mg',
             'Na', 'Mn', 'Ca', 'Co', 'Ni', 'Se', 'Cu', 'Cd', 'Hg', 'K']
    atom_total_degrees = list(range(5))
    atom_formal_charges = [-1, 0, 1]
    atom_implicit_valence = list(range(4))
    atom_explicit_valence = list(range(8))
    atom_concat_featurizer = ConcatFeaturizer([
        partial(atom_type_one_hot, allowable_set=atoms),
        partial(atom_total_degree_one_hot, allowable_set=atom_total_degrees),
        partial(atom_formal_charge_one_hot, allowable_set=atom_formal_charges),
        atom_is_aromatic,
        partial(atom_implicit_valence_one_hot, allowable_set=atom_implicit_valence),
        partial(atom_explicit_valence_one_hot, allowable_set=atom_explicit_valence)])
    PN_atom_featurizer = BaseAtomFeaturizer({'h': atom_concat_featurizer})

    # Bond featurizer for stage 1
    bond_concat_featurizer = ConcatFeaturizer([bond_type_one_hot, bond_is_in_ring])
    PN_bond_featurizer = BaseBondFeaturizer({'e': bond_concat_featurizer})

    # Construct graphs for stage 1. The original atom order is kept so that
    # node i of each bigraph corresponds to atom i of the molecule.
    ligand_bigraph = mol_to_bigraph(ligand_mol, add_self_loop=False,
                                    node_featurizer=PN_atom_featurizer,
                                    edge_featurizer=PN_bond_featurizer,
                                    canonical_atom_order=False)
    protein_bigraph = mol_to_bigraph(protein_mol, add_self_loop=False,
                                     node_featurizer=PN_atom_featurizer,
                                     edge_featurizer=PN_bond_featurizer,
                                     canonical_atom_order=False)
    complex_bigraph = batch([ligand_bigraph, protein_bigraph])

    # Construct knn graphs for stage 2; the largest bin doubles as the cutoff.
    complex_coordinates = np.concatenate([ligand_coordinates, protein_coordinates])
    complex_srcs, complex_dsts, complex_dists = k_nearest_neighbors(
        complex_coordinates, distance_bins[-1], max_num_neighbors)
    complex_srcs = np.array(complex_srcs)
    complex_dsts = np.array(complex_dsts)
    complex_dists = np.array(complex_dists)

    complex_knn_graph = graph((complex_srcs, complex_dsts), num_nodes=len(complex_coordinates))
    d_features = np.digitize(complex_dists, bins=distance_bins, right=True)
    # One-hot encode the distance bins with a fixed width of len(distance_bins)
    # so the edge feature size does not depend on which bins happen to occur in
    # this complex (and so that a graph without KNN edges does not fail).
    num_knn_edges = len(d_features)
    d_one_hot = np.zeros((num_knn_edges, len(distance_bins)))
    d_one_hot[np.arange(num_knn_edges), d_features] = 1

    # add bond types and bonds (from bigraph) to stage 2
    u, v = complex_bigraph.edges()
    complex_knn_graph.add_edges(u.to(F.int64), v.to(F.int64))
    n_d, f_d = d_one_hot.shape
    n_e, f_e = complex_bigraph.edata['e'].shape
    # Block layout: KNN edges fill the distance columns, bond edges fill the
    # bond-feature columns. np.int64 is used explicitly because the np.long
    # alias is not available in every NumPy release.
    complex_knn_graph.edata['e'] = F.zerocopy_from_numpy(
        np.block([
            [d_one_hot, np.zeros((n_d, f_e))],
            [np.zeros((n_e, f_d)), np.array(complex_bigraph.edata['e'])]
        ]).astype(np.int64)
    )
    return complex_bigraph, complex_knn_graph


# pylint: disable=C0326
def _float32_column(values):
    """Convert a 1-D sequence of numbers into a float32 backend tensor of shape (N, 1)."""
    return F.reshape(
        F.zerocopy_from_numpy(np.asarray(values).astype(np.float32)), (-1, 1))


def ACNN_graph_construction_and_featurization(ligand_mol,
                                              protein_mol,
                                              ligand_coordinates,
                                              protein_coordinates,
                                              max_num_ligand_atoms=None,
                                              max_num_protein_atoms=None,
                                              neighbor_cutoff=12.,
                                              max_num_neighbors=12,
                                              strip_hydrogens=False):
    """Graph construction and featurization for `Atomic Convolutional Networks for
    Predicting Protein-Ligand Binding Affinity <https://arxiv.org/abs/1703.10603>`__.

    Parameters
    ----------
    ligand_mol : rdkit.Chem.rdchem.Mol
        RDKit molecule instance.
    protein_mol : rdkit.Chem.rdchem.Mol
        RDKit molecule instance.
    ligand_coordinates : Float Tensor of shape (V1, 3)
        Atom coordinates in a ligand.
    protein_coordinates : Float Tensor of shape (V2, 3)
        Atom coordinates in a protein.
    max_num_ligand_atoms : int or None
        Maximum number of atoms in ligands for zero padding, which should be no smaller than
        ligand_mol.GetNumAtoms() if not None. If None, no zero padding will be performed.
        Default to None.
    max_num_protein_atoms : int or None
        Maximum number of atoms in proteins for zero padding, which should be no smaller than
        protein_mol.GetNumAtoms() if not None. If None, no zero padding will be performed.
        Default to None.
    neighbor_cutoff : float
        Distance cutoff to define 'neighboring'. Default to 12.
    max_num_neighbors : int
        Maximum number of neighbors allowed for each atom. Default to 12.
    strip_hydrogens : bool
        Whether to exclude hydrogen atoms. Default to False.

    Returns
    -------
    g : DGLHeteroGraph
        Heterogeneous graph with node types ``'ligand_atom'`` and ``'protein_atom'``.

        * Edges of type ``'ligand'`` / ``'protein'`` come from the neighbor search
          within the ligand / within the protein.
        * The four ``'complex'`` edge types come from one neighbor search over the
          concatenated ligand and protein coordinates, split by the molecule that
          the source and the destination atom belong to.
        * Every edge type stores the distances as ``edata['distance']`` of shape (E, 1).
        * Every node type stores ``ndata['atomic_number']`` and ``ndata['mask']`` of
          shape (V, 1); padded nodes have atomic number 0 and mask 0.
    """
    assert ligand_coordinates is not None, 'Expect ligand_coordinates to be provided.'
    assert protein_coordinates is not None, 'Expect protein_coordinates to be provided.'
    if max_num_ligand_atoms is not None:
        assert max_num_ligand_atoms >= ligand_mol.GetNumAtoms(), \
            'Expect max_num_ligand_atoms to be no smaller than ligand_mol.GetNumAtoms(), ' \
            'got {:d} and {:d}'.format(max_num_ligand_atoms, ligand_mol.GetNumAtoms())
    if max_num_protein_atoms is not None:
        assert max_num_protein_atoms >= protein_mol.GetNumAtoms(), \
            'Expect max_num_protein_atoms to be no smaller than protein_mol.GetNumAtoms(), ' \
            'got {:d} and {:d}'.format(max_num_protein_atoms, protein_mol.GetNumAtoms())

    if strip_hydrogens:
        # Remove hydrogen atoms and their corresponding coordinates
        ligand_atom_indices_left = filter_out_hydrogens(ligand_mol)
        protein_atom_indices_left = filter_out_hydrogens(protein_mol)
        ligand_coordinates = ligand_coordinates.take(ligand_atom_indices_left, axis=0)
        protein_coordinates = protein_coordinates.take(protein_atom_indices_left, axis=0)
    else:
        ligand_atom_indices_left = list(range(ligand_mol.GetNumAtoms()))
        protein_atom_indices_left = list(range(protein_mol.GetNumAtoms()))

    # Compute number of nodes for each type (padded size if padding is requested)
    if max_num_ligand_atoms is None:
        num_ligand_atoms = len(ligand_atom_indices_left)
    else:
        num_ligand_atoms = max_num_ligand_atoms

    if max_num_protein_atoms is None:
        num_protein_atoms = len(protein_atom_indices_left)
    else:
        num_protein_atoms = max_num_protein_atoms

    data_dict = {}
    num_nodes_dict = {}

    # graph data for atoms in the ligand
    ligand_srcs, ligand_dsts, ligand_dists = k_nearest_neighbors(
        ligand_coordinates, neighbor_cutoff, max_num_neighbors)
    data_dict[('ligand_atom', 'ligand', 'ligand_atom')] = (ligand_srcs, ligand_dsts)
    num_nodes_dict['ligand_atom'] = num_ligand_atoms

    # graph data for atoms in the protein
    protein_srcs, protein_dsts, protein_dists = k_nearest_neighbors(
        protein_coordinates, neighbor_cutoff, max_num_neighbors)
    data_dict[('protein_atom', 'protein', 'protein_atom')] = (protein_srcs, protein_dsts)
    num_nodes_dict['protein_atom'] = num_protein_atoms

    # 4 graphs for complex representation, including the connection within
    # protein atoms, the connection within ligand atoms and the connection between
    # protein and ligand atoms.
    complex_srcs, complex_dsts, complex_dists = k_nearest_neighbors(
        np.concatenate([ligand_coordinates, protein_coordinates]),
        neighbor_cutoff, max_num_neighbors)
    complex_srcs = np.array(complex_srcs)
    complex_dsts = np.array(complex_dsts)
    complex_dists = np.array(complex_dists)

    # Indices returned by the complex neighbor search refer to rows of the
    # concatenated coordinates, so the ligand/protein boundary is the number of
    # ligand coordinate rows. The padded node count must not be used here: with
    # padding it would label the first protein atoms as ligand atoms and shift
    # every remaining protein index.
    offset = len(ligand_coordinates)
    src_is_ligand = complex_srcs < offset
    dst_is_ligand = complex_dsts < offset
    src_is_protein = complex_srcs >= offset
    dst_is_protein = complex_dsts >= offset

    # ('ligand_atom', 'complex', 'ligand_atom')
    inter_ligand_indices = (src_is_ligand & dst_is_ligand).nonzero()[0]
    data_dict[('ligand_atom', 'complex', 'ligand_atom')] = \
        (complex_srcs[inter_ligand_indices].tolist(),
         complex_dsts[inter_ligand_indices].tolist())

    # ('protein_atom', 'complex', 'protein_atom')
    inter_protein_indices = (src_is_protein & dst_is_protein).nonzero()[0]
    data_dict[('protein_atom', 'complex', 'protein_atom')] = \
        ((complex_srcs[inter_protein_indices] - offset).tolist(),
         (complex_dsts[inter_protein_indices] - offset).tolist())

    # ('ligand_atom', 'complex', 'protein_atom')
    ligand_protein_indices = (src_is_ligand & dst_is_protein).nonzero()[0]
    data_dict[('ligand_atom', 'complex', 'protein_atom')] = \
        (complex_srcs[ligand_protein_indices].tolist(),
         (complex_dsts[ligand_protein_indices] - offset).tolist())

    # ('protein_atom', 'complex', 'ligand_atom')
    protein_ligand_indices = (src_is_protein & dst_is_ligand).nonzero()[0]
    data_dict[('protein_atom', 'complex', 'ligand_atom')] = \
        ((complex_srcs[protein_ligand_indices] - offset).tolist(),
         complex_dsts[protein_ligand_indices].tolist())

    g = heterograph(data_dict, num_nodes_dict=num_nodes_dict)

    # Edge features: one distance per edge, stored as a column vector
    g.edges['ligand'].data['distance'] = _float32_column(ligand_dists)
    g.edges['protein'].data['distance'] = _float32_column(protein_dists)
    g.edges[('ligand_atom', 'complex', 'ligand_atom')].data['distance'] = \
        _float32_column(complex_dists[inter_ligand_indices])
    g.edges[('protein_atom', 'complex', 'protein_atom')].data['distance'] = \
        _float32_column(complex_dists[inter_protein_indices])
    g.edges[('ligand_atom', 'complex', 'protein_atom')].data['distance'] = \
        _float32_column(complex_dists[ligand_protein_indices])
    g.edges[('protein_atom', 'complex', 'ligand_atom')].data['distance'] = \
        _float32_column(complex_dists[protein_ligand_indices])

    # Get atomic numbers for all atoms left and set node features
    ligand_atomic_numbers = np.array(get_atomic_numbers(ligand_mol, ligand_atom_indices_left))
    # zero padding
    ligand_atomic_numbers = np.concatenate([
        ligand_atomic_numbers, np.zeros(num_ligand_atoms - len(ligand_atom_indices_left))])
    protein_atomic_numbers = np.array(get_atomic_numbers(protein_mol, protein_atom_indices_left))
    # zero padding
    protein_atomic_numbers = np.concatenate([
        protein_atomic_numbers, np.zeros(num_protein_atoms - len(protein_atom_indices_left))])

    g.nodes['ligand_atom'].data['atomic_number'] = _float32_column(ligand_atomic_numbers)
    g.nodes['protein_atom'].data['atomic_number'] = _float32_column(protein_atomic_numbers)

    # Prepare mask indicating the existence of nodes (1 for real atoms, 0 for padding)
    ligand_masks = np.zeros((num_ligand_atoms, 1))
    ligand_masks[:len(ligand_atom_indices_left), :] = 1
    g.nodes['ligand_atom'].data['mask'] = F.zerocopy_from_numpy(
        ligand_masks.astype(np.float32))
    protein_masks = np.zeros((num_protein_atoms, 1))
    protein_masks[:len(protein_atom_indices_left), :] = 1
    g.nodes['protein_atom'].data['mask'] = F.zerocopy_from_numpy(
        protein_masks.astype(np.float32))

    return g