Skip to content

Commit

Permalink
Merge pull request #168 from njzjz/refactor
Browse files Browse the repository at this point in the history
refactor and add format plugin system
  • Loading branch information
amcadmus authored Jul 5, 2021
2 parents 59dfe7d + 4ca3706 commit ea32d45
Show file tree
Hide file tree
Showing 30 changed files with 1,105 additions and 731 deletions.
11 changes: 11 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -252,3 +252,14 @@ print(syst.get_charge()) # return the total charge of the system
```

If a valence of 3 is detected on carbon, the formal charge will be assigned to -1. Because for most cases (in alkynyl anion, isonitrile, cyclopentadienyl anion), the formal charge on 3-valence carbon is -1, and this is also consisent with the 8-electron rule.

# Plugins

One can follow [a simple example](plugin_example/) to add their own format by creating and installing plugins. It's crirical to add the [Format](dpdata/format.py) class to `entry_points['dpdata.plugins']` in `setup.py`:
```py
entry_points={
'dpdata.plugins': [
'random=dpdata_random:RandomFormat'
]
},
```
3 changes: 2 additions & 1 deletion dpdata/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@

# BondOrder System has dependency on rdkit
try:
import rdkit
# prevent conflict with dpdata.rdkit
import rdkit as _
USE_RDKIT = True
except ModuleNotFoundError:
USE_RDKIT = False
Expand Down
48 changes: 11 additions & 37 deletions dpdata/bond_order_system.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#%%
# Bond Order System
from dpdata.system import Register, System, LabeledSystem, check_System
import rdkit.Chem
from dpdata.system import System, LabeledSystem, check_System, load_format
import dpdata.rdkit.utils
from dpdata.rdkit.sanitize import Sanitizer, SanitizeError
from copy import deepcopy
Expand Down Expand Up @@ -87,8 +86,16 @@ def __init__(self,
if type_map:
self.apply_type_map(type_map)

register_from_funcs = Register()
register_to_funcs = System.register_to_funcs + Register()
def from_fmt_obj(self, fmtobj, file_name, **kwargs):
mol = fmtobj.from_bond_order_system(file_name, **kwargs)
self.from_rdkit_mol(mol)
if hasattr(fmtobj.from_bond_order_system, 'post_func'):
for post_f in fmtobj.from_bond_order_system.post_func:
self.post_funcs.get_plugin(post_f)(self)
return self

def to_fmt_obj(self, fmtobj, *args, **kwargs):
return fmtobj.to_bond_order_system(self.data, self.rdkit_mol, *args, **kwargs)

def __repr__(self):
return self.__str__()
Expand Down Expand Up @@ -164,36 +171,3 @@ def from_rdkit_mol(self, rdkit_mol):
self.data = dpdata.rdkit.utils.mol_to_system_data(rdkit_mol)
self.data['bond_dict'] = dict([(f'{int(bond[0])}-{int(bond[1])}', bond[2]) for bond in self.data['bonds']])
self.rdkit_mol = rdkit_mol

@register_from_funcs.register_funcs('mol')
def from_mol_file(self, file_name):
mol = rdkit.Chem.MolFromMolFile(file_name, sanitize=False, removeHs=False)
self.from_rdkit_mol(mol)

@register_to_funcs.register_funcs("mol")
def to_mol_file(self, file_name, frame_idx=0):
assert (frame_idx < self.get_nframes())
rdkit.Chem.MolToMolFile(self.rdkit_mol, file_name, confId=frame_idx)

@register_from_funcs.register_funcs("sdf")
def from_sdf_file(self, file_name):
'''
Note that it requires all molecules in .sdf file must be of the same topology
'''
mols = [m for m in rdkit.Chem.SDMolSupplier(file_name, sanitize=False, removeHs=False)]
if len(mols) > 1:
mol = dpdata.rdkit.utils.combine_molecules(mols)
else:
mol = mols[0]
self.from_rdkit_mol(mol)

@register_to_funcs.register_funcs("sdf")
def to_sdf_file(self, file_name, frame_idx=-1):
sdf_writer = rdkit.Chem.SDWriter(file_name)
if frame_idx == -1:
for ii in self.get_nframes():
sdf_writer.write(self.rdkit_mol, confId=ii)
else:
assert (frame_idx < self.get_nframes())
sdf_writer.write(self.rdkit_mol, confId=frame_idx)
sdf_writer.close()
8 changes: 4 additions & 4 deletions dpdata/cp2k/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,20 +278,20 @@ def get_frames (fname) :

#conver to float array and add extra dimension for nframes
cell = np.array(cell)
cell = cell.astype(np.float)
cell = cell.astype(float)
cell = cell[np.newaxis, :, :]
coord = np.array(coord)
coord = coord.astype(np.float)
coord = coord.astype(float)
coord = coord[np.newaxis, :, :]
atom_symbol_list = np.array(atom_symbol_list)
force = np.array(force)
force = force.astype(np.float)
force = force.astype(float)
force = force[np.newaxis, :, :]

# virial is not necessary
if stress:
stress = np.array(stress)
stress = stress.astype(np.float)
stress = stress.astype(float)
stress = stress[np.newaxis, :, :]
# stress to virial conversion, default unit in cp2k is GPa
# note the stress is virial = stress * volume
Expand Down
116 changes: 116 additions & 0 deletions dpdata/format.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
"""Implement the format plugin system."""
import os
from collections import abc
from abc import ABC

from .plugin import Plugin


class Format(ABC):
__FormatPlugin = Plugin()
__FromPlugin = Plugin()
__ToPlugin = Plugin()

@staticmethod
def register(key):
return Format.__FormatPlugin.register(key)

@staticmethod
def register_from(key):
return Format.__FromPlugin.register(key)

@staticmethod
def register_to(key):
return Format.__ToPlugin.register(key)

@staticmethod
def get_formats():
return Format.__FormatPlugin.plugins

@staticmethod
def get_from_methods():
return Format.__FromPlugin.plugins

@staticmethod
def get_to_methods():
return Format.__ToPlugin.plugins

@staticmethod
def post(func_name):
def decorator(object):
if not isinstance(func_name, (list, tuple, set)):
object.post_func = (func_name,)
else:
object.post_func = func_name
return object
return decorator

def from_system(self, file_name, **kwargs):
"""System.from
Parameters
----------
file_name: str
file name
Returns
-------
data: dict
system data
"""
raise NotImplementedError("%s doesn't support System.from" %(self.__class__.__name__))

def to_system(self, data, *args, **kwargs):
"""System.to
Parameters
----------
data: dict
system data
"""
raise NotImplementedError("%s doesn't support System.to" %(self.__class__.__name__))

def from_labeled_system(self, file_name, **kwargs):
raise NotImplementedError("%s doesn't support LabeledSystem.from" %(self.__class__.__name__))

def to_labeled_system(self, data, *args, **kwargs):
return self.to_system(data, *args, **kwargs)

def from_bond_order_system(self, file_name, **kwargs):
raise NotImplementedError("%s doesn't support BondOrderSystem.from" %(self.__class__.__name__))

def to_bond_order_system(self, data, rdkit_mol, *args, **kwargs):
return self.to_system(data, *args, **kwargs)

class MultiModes:
"""File mode for MultiSystems
0 (default): not implemented
1: every directory under the top-level directory is a system
"""
NotImplemented = 0
Directory = 1

MultiMode = MultiModes.NotImplemented

def from_multi_systems(self, directory, **kwargs):
"""MultiSystems.from
Parameters
----------
directory: str
directory of system
Returns
-------
filenames: list[str]
list of filenames
"""
if self.MultiMode == self.MultiModes.Directory:
return [name for name in os.listdir(directory) if os.path.isdir(os.path.join(directory, name))]
raise NotImplementedError("%s doesn't support MultiSystems.from" %(self.__class__.__name__))

def to_multi_systems(self, formulas, directory, **kwargs):
if self.MultiMode == self.MultiModes.Directory:
return [os.path.join(directory, ff) for ff in formulas]
raise NotImplementedError("%s doesn't support MultiSystems.to" %(self.__class__.__name__))

36 changes: 36 additions & 0 deletions dpdata/plugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""Base of plugin systems."""


class Plugin:
"""A class to register plugins.
Examples
--------
>>> Plugin = Register()
>>> @Plugin.register("xx")
def xxx():
pass
>>> print(Plugin.plugins['xx'])
"""
def __init__(self):
self.plugins = {}

def register(self, key):
"""Register a plugin.
Parameter
---------
key: str
Key of the plugin.
"""
def decorator(object):
self.plugins[key] = object
return object
return decorator

def get_plugin(self, key):
return self.plugins[key]

def __add__(self, other):
self.plugins.update(other.plugins)
return self
19 changes: 19 additions & 0 deletions dpdata/plugins/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import importlib
from pathlib import Path
try:
from importlib import metadata
except ImportError: # for Python<3.8
import importlib_metadata as metadata

PACKAGE_BASE = "dpdata.plugins"
NOT_LOADABLE = ("__init__.py",)

for module_file in Path(__file__).parent.glob("*.py"):
if module_file.name not in NOT_LOADABLE:
module_name = f".{module_file.stem}"
importlib.import_module(module_name, PACKAGE_BASE)

# https://setuptools.readthedocs.io/en/latest/userguide/entry_point.html
eps = metadata.entry_points().get('dpdata.plugins', [])
for ep in eps:
plugin = ep.load()
10 changes: 10 additions & 0 deletions dpdata/plugins/abacus.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import dpdata.abacus.scf
from dpdata.format import Format


@Format.register("abacus/scf")
@Format.register("abacus/pw/scf")
class AbacusSCFFormat(Format):
@Format.post("rot_lower_triangular")
def from_labeled_system(self, file_name, **kwargs):
return dpdata.abacus.scf.get_frame(file_name)
37 changes: 37 additions & 0 deletions dpdata/plugins/amber.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import dpdata.amber.md
import dpdata.amber.sqm
from dpdata.format import Format


@Format.register("amber/md")
class AmberMDFormat(Format):
def from_system(self, file_name=None, parm7_file=None, nc_file=None, use_element_symbols=None):
# assume the prefix is the same if the spefic name is not given
if parm7_file is None:
parm7_file = file_name + ".parm7"
if nc_file is None:
nc_file = file_name + ".nc"
return dpdata.amber.md.read_amber_traj(parm7_file=parm7_file, nc_file=nc_file, use_element_symbols=use_element_symbols, labeled=False)

def from_labeled_system(self, file_name=None, parm7_file=None, nc_file=None, mdfrc_file=None, mden_file=None, mdout_file=None, use_element_symbols=None, **kwargs):
# assume the prefix is the same if the spefic name is not given
if parm7_file is None:
parm7_file = file_name + ".parm7"
if nc_file is None:
nc_file = file_name + ".nc"
if mdfrc_file is None:
mdfrc_file = file_name + ".mdfrc"
if mden_file is None:
mden_file = file_name + ".mden"
if mdout_file is None:
mdout_file = file_name + ".mdout"
return dpdata.amber.md.read_amber_traj(parm7_file, nc_file, mdfrc_file, mden_file, mdout_file, use_element_symbols)


@Format.register("sqm/out")
class SQMOutFormat(Format):
def from_system(self, fname, **kwargs):
'''
Read from ambertools sqm.out
'''
return dpdata.amber.sqm.to_system_data(fname)
52 changes: 52 additions & 0 deletions dpdata/plugins/ase.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from dpdata.format import Format


@Format.register("ase/structure")
class ASEStructureFormat(Format):
def to_system(self, data, **kwargs):
'''
convert System to ASE Atom obj
'''
from ase import Atoms

structures = []
species = [data['atom_names'][tt] for tt in data['atom_types']]

for ii in range(data['coords'].shape[0]):
structure = Atoms(
symbols=species, positions=data['coords'][ii], pbc=not data.get('nopbc', False), cell=data['cells'][ii])
structures.append(structure)

return structures

def to_labeled_system(self, data, *args, **kwargs):
'''Convert System to ASE Atoms object.'''
from ase import Atoms
from ase.calculators.singlepoint import SinglePointCalculator

structures = []
species = [data['atom_names'][tt] for tt in data['atom_types']]

for ii in range(data['coords'].shape[0]):
structure = Atoms(
symbols=species,
positions=data['coords'][ii],
pbc=not data.get('nopbc', False),
cell=data['cells'][ii]
)

results = {
'energy': data["energies"][ii],
'forces': data["forces"][ii]
}
if "virials" in data:
# convert to GPa as this is ase convention
v_pref = 1 * 1e4 / 1.602176621e6
vol = structure.get_volume()
results['stress'] = data["virials"][ii] / (v_pref * vol)

structure.calc = SinglePointCalculator(structure, **results)
structures.append(structure)

return structures
Loading

0 comments on commit ea32d45

Please sign in to comment.