Skip to content

Commit

Permalink
Merge branch 'main' into get_feature_names
Browse files Browse the repository at this point in the history
  • Loading branch information
JochenSiegWork authored Nov 18, 2024
2 parents 05e041e + 78b0fe0 commit 8d1abd1
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 1 deletion.
86 changes: 85 additions & 1 deletion molpipeline/any2mol/smiles2mol.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,13 @@

from __future__ import annotations

from typing import Any, Optional

try:
from typing import Self # type: ignore[attr-defined]
except ImportError:
from typing_extensions import Self

from rdkit import Chem

from molpipeline.abstract_pipeline_elements.any2mol.string2mol import (
Expand All @@ -13,6 +20,42 @@
class SmilesToMol(SimpleStringToMolElement):
"""Transforms Smiles to RDKit Mol objects."""

def __init__(
self,
remove_hydrogens: bool = True,
name: str = "smiles2mol",
n_jobs: int = 1,
uuid: Optional[str] = None,
) -> None:
"""Initialize SmilesToMol object.
Parameters
----------
remove_hydrogens: bool
Whether to remove hydrogens from the molecule.
name: str
Name of the object.
n_jobs: int
Number of jobs to run in parallel.
uuid: Optional[str]
UUID of the object.
"""
super().__init__(name=name, n_jobs=n_jobs, uuid=uuid)
self._remove_hydrogens = remove_hydrogens

def _get_parser_config(self) -> Chem.SmilesParserParams:
"""Get parser configuration.
Returns
-------
dict[str, Any]
Configuration for the parser.
"""
# set up rdkit smiles parser parameters
parser_params = Chem.SmilesParserParams()
parser_params.removeHs = self._remove_hydrogens
return parser_params

def string_to_mol(self, value: str) -> RDKitMol:
"""Transform Smiles string to molecule.
Expand All @@ -26,4 +69,45 @@ def string_to_mol(self, value: str) -> RDKitMol:
RDKitMol
Rdkit molecule if valid SMILES, else None.
"""
return Chem.MolFromSmiles(value)
return Chem.MolFromSmiles(value, self._get_parser_config())

def get_params(self, deep: bool = True) -> dict[str, Any]:
"""Get parameters for this object.
Parameters
----------
deep: bool
If True, return a deep copy of the parameters.
Returns
-------
dict[str, Any]
Dictionary of parameters.
"""
parameters = super().get_params(deep)
if deep:
parameters["remove_hydrogens"] = bool(self._remove_hydrogens)

else:
parameters["remove_hydrogens"] = self._remove_hydrogens
return parameters

def set_params(self, **parameters: Any) -> Self:
"""Set parameters.
Parameters
----------
parameters: Any
Dictionary of parameter names and values.
Returns
-------
Self
SmilesToMol pipeline element with updated parameters.
"""
parameter_copy = dict(parameters)
remove_hydrogens = parameter_copy.pop("remove_hydrogens", None)
if remove_hydrogens is not None:
self._remove_hydrogens = remove_hydrogens
super().set_params(**parameter_copy)
return self
53 changes: 53 additions & 0 deletions tests/test_elements/test_any2mol/test_smiles2mol.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
"""Test smiles to mol pipeline element."""

import unittest
from typing import Any

from molpipeline import Pipeline
from molpipeline.any2mol import SmilesToMol


class TestSmiles2Mol(unittest.TestCase):
"""Test case for testing conversion of SMILES input to molecules."""

def test_smiles2mol_explict_hydrogens(self) -> None:
"""Test smiles reading with and without explicit smiles."""
smiles = "C[H]"

# test: remove explicit Hs
pipeline = Pipeline(
[
(
"Smiles2Mol",
SmilesToMol(remove_hydrogens=True),
),
]
)
mols = pipeline.fit_transform([smiles])
self.assertEqual(len(mols), 1)
self.assertIsNotNone(mols[0])
self.assertEqual(mols[0].GetNumAtoms(), 1)

# test: keep explicit Hs
pipeline2 = Pipeline(
[
(
"Smiles2Mol",
SmilesToMol(remove_hydrogens=False),
),
]
)
mols2 = pipeline2.fit_transform([smiles])
self.assertEqual(len(mols2), 1)
self.assertIsNotNone(mols2[0])
self.assertEqual(mols2[0].GetNumAtoms(), 2)

def test_getter_setter(self) -> None:
"""Test getter and setter methods."""
smiles2mol = SmilesToMol(remove_hydrogens=False)
self.assertEqual(smiles2mol.get_params()["remove_hydrogens"], False)
params: dict[str, Any] = {
"remove_hydrogens": True,
}
smiles2mol.set_params(**params)
self.assertEqual(smiles2mol.get_params()["remove_hydrogens"], True)

0 comments on commit 8d1abd1

Please sign in to comment.