# -*- coding: utf-8 -*-
#
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Dataset for inference on smiles
from functools import partial
from rdkit import Chem
from ..utils.mol_to_graph import ToGraph, MolToBigraph
__all__ = ['UnlabeledSMILES']
[docs]class UnlabeledSMILES(object):
"""Construct a SMILES dataset without labels for inference.
We will 1) Filter out invalid SMILES strings and record canonical SMILES strings
for valid ones 2) Construct a DGLGraph for each valid one and feature its node/edge
Parameters
----------
smiles_list : list of str
List of SMILES strings
mol_to_graph: callable, rdkit.Chem.rdchem.Mol -> DGLGraph
A function turning an RDKit molecule object into a DGLGraph.
Default to :func:`dgllife.utils.mol_to_bigraph`.
node_featurizer : None or callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for nodes like atoms in a molecule, which can be used to update
ndata for a DGLGraph. Default to None.
edge_featurizer : None or callable, rdkit.Chem.rdchem.Mol -> dict
Featurization for edges like bonds in a molecule, which can be used to update
edata for a DGLGraph. Default to None.
log_every : bool
Print a message every time ``log_every`` molecules are processed. Default to 1000.
"""
def __init__(self, smiles_list, mol_to_graph=None, node_featurizer=None,
edge_featurizer=None, log_every=1000):
super(UnlabeledSMILES, self).__init__()
canonical_smiles = []
mol_list = []
for smi in smiles_list:
mol = Chem.MolFromSmiles(smi)
if mol is None:
continue
mol_list.append(mol)
canonical_smiles.append(Chem.MolToSmiles(mol))
self.smiles = canonical_smiles
self.graphs = []
if mol_to_graph is None:
mol_to_graph = MolToBigraph()
# Check for backward compatibility
if isinstance(mol_to_graph, ToGraph):
assert node_featurizer is None, \
'Initialize mol_to_graph object with node_featurizer=node_featurizer'
assert edge_featurizer is None, \
'Initialize mol_to_graph object with edge_featurizer=edge_featurizer'
else:
mol_to_graph = partial(mol_to_graph, node_featurizer=node_featurizer,
edge_featurizer=edge_featurizer)
for i, mol in enumerate(mol_list):
if (i + 1) % log_every == 0:
print('Processing molecule {:d}/{:d}'.format(i + 1, len(self)))
self.graphs.append(mol_to_graph(mol))
[docs] def __getitem__(self, item):
"""Get datapoint with index
Parameters
----------
item : int
Datapoint index
Returns
-------
str
SMILES for the ith datapoint
DGLGraph
DGLGraph for the ith datapoint
"""
return self.smiles[item], self.graphs[item]
[docs] def __len__(self):
"""Size for the dataset
Returns
-------
int
Size for the dataset
"""
return len(self.smiles)