Source code for dgllife.data.tox21

# -*- coding: utf-8 -*-
#
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
#
# The Toxicology in the 21st Century initiative.

import pandas as pd

from dgl.data.utils import get_download_dir, download, _get_dgl_url

from .csv_dataset import MoleculeCSVDataset
from ..utils.mol_to_graph import smiles_to_bigraph

__all__ = ['Tox21']

class Tox21(MoleculeCSVDataset):
    """Tox21 dataset.

    The Toxicology in the 21st Century (https://tripod.nih.gov/tox21/challenge/)
    initiative created a public database measuring toxicity of compounds, which
    has been used in the 2014 Tox21 Data Challenge. The dataset contains
    qualitative toxicity measurements for 8014 compounds on 12 different targets,
    including nuclear receptors and stress response pathways. Each target results
    in a binary label.

    A common issue for multi-task prediction is that some datapoints are not
    labeled for all tasks. This is also the case for Tox21. In data
    pre-processing, we set non-existing labels to be 0 so that they can be placed
    in tensors and used for masking in loss computation.

    All molecules are converted into DGLGraphs. After the first-time construction,
    the DGLGraphs will be saved for reloading so that we do not need to
    reconstruct them every time.

    Parameters
    ----------
    smiles_to_graph : callable, str -> DGLGraph
        A function turning a SMILES string into a DGLGraph.
        Default to :func:`dgllife.utils.smiles_to_bigraph`.
    node_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for nodes like atoms in a molecule, which can be used to
        update ndata for a DGLGraph. Default to None.
    edge_featurizer : callable, rdkit.Chem.rdchem.Mol -> dict
        Featurization for edges like bonds in a molecule, which can be used to
        update edata for a DGLGraph. Default to None.
    load : bool
        Whether to load the previously pre-processed dataset or pre-process from
        scratch. ``load`` should be False when we want to try different graph
        construction and featurization methods and need to preprocess from
        scratch. Default to False.
    log_every : int
        Print a message every time ``log_every`` molecules are processed.
        Default to 1000.
    cache_file_path : str
        Path to the cached DGLGraphs. Default to './tox21_dglgraph.bin'.
    n_jobs : int
        The maximum number of concurrently running jobs for graph construction
        and featurization, using joblib backend. Default to 1.

    Attributes
    ----------
    id : list
        Molecule ids (the ``mol_id`` column of the raw CSV) of the datapoints
        kept after pre-processing, aligned with the dataset indices.
    load_full : bool
        If True, ``__getitem__`` additionally returns the molecule id.
        Initialized to False.

    Examples
    --------

    >>> from dgllife.data import Tox21
    >>> from dgllife.utils import smiles_to_bigraph, CanonicalAtomFeaturizer

    >>> dataset = Tox21(smiles_to_bigraph, CanonicalAtomFeaturizer())
    >>> # Get size of the dataset
    >>> len(dataset)
    7831
    >>> # Get the 0th datapoint, consisting of SMILES, DGLGraph, labels, and masks
    >>> dataset[0]
    ('CCOc1ccc2nc(S(N)(=O)=O)sc2c1',
     DGLGraph(num_nodes=16, num_edges=34,
              ndata_schemes={}
              edata_schemes={}),
     tensor([0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0.]),
     tensor([1., 1., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1.]))

    The dataset instance also contains information about molecule ids.

    >>> dataset.id[0]
    'TOX3021'

    We can also get the id along with SMILES, DGLGraph, labels, and masks at once.

    >>> dataset.load_full = True
    >>> dataset[0]
    ('CCOc1ccc2nc(S(N)(=O)=O)sc2c1',
     DGLGraph(num_nodes=16, num_edges=34,
              ndata_schemes={}
              edata_schemes={}),
     tensor([0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0.]),
     tensor([1., 1., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1.]),
     'TOX3021')

    To address the imbalance between positive and negative samples, we can
    re-weight positive samples for each task based on the training datapoints.

    >>> import torch
    >>> train_ids = torch.arange(1000)
    >>> dataset.task_pos_weights(train_ids)
    tensor([26.9706, 35.3750,  5.9756, 21.6364,  6.4404, 21.4500,
            26.0000,  5.0826, 21.4390, 14.7692,  6.1442, 12.4308])
    """
    def __init__(self,
                 smiles_to_graph=smiles_to_bigraph,
                 node_featurizer=None,
                 edge_featurizer=None,
                 load=False,
                 log_every=1000,
                 cache_file_path='./tox21_dglgraph.bin',
                 n_jobs=1):

        self._url = 'dataset/tox21.csv.gz'
        data_path = get_download_dir() + '/tox21.csv.gz'
        # overwrite=False: reuse an already-downloaded archive if present.
        download(_get_dgl_url(self._url), path=data_path, overwrite=False)
        df = pd.read_csv(data_path)
        # Keep the molecule ids as a plain list so that the lookup below is
        # unambiguously positional, then drop the column so it is not handed to
        # the parent class with the label columns (presumably the parent treats
        # every non-SMILES column as a task -- confirm in MoleculeCSVDataset).
        self.id = df['mol_id'].tolist()
        df = df.drop(columns=['mol_id'])

        self.load_full = False

        super(Tox21, self).__init__(df, smiles_to_graph, node_featurizer,
                                    edge_featurizer, "smiles", cache_file_path,
                                    load=load, log_every=log_every, n_jobs=n_jobs)

        # ``self.valid_ids`` is set by MoleculeCSVDataset; it is assumed to hold
        # the row positions of ``df`` that survived graph construction. Restrict
        # the ids to those rows so they stay aligned with the dataset indices.
        self.id = [self.id[i] for i in self.valid_ids]

    def __getitem__(self, item):
        """Get datapoint with index

        Parameters
        ----------
        item : int
            Datapoint index

        Returns
        -------
        str
            SMILES for the ith datapoint
        DGLGraph
            DGLGraph for the ith datapoint
        Tensor of dtype float32 and shape (T)
            Labels of the ith datapoint for all tasks. T for the number of tasks.
        Tensor of dtype float32 and shape (T)
            Binary masks of the ith datapoint indicating the existence of labels
            for all tasks.
        str, optional
            Id for the ith datapoint, returned only when ``self.load_full`` is True.
        """
        datapoint = (self.smiles[item], self.graphs[item],
                     self.labels[item], self.mask[item])
        if self.load_full:
            return datapoint + (self.id[item],)
        return datapoint