-
Notifications
You must be signed in to change notification settings - Fork 137
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #168 from njzjz/refactor
refactor and add format plugin system
- Loading branch information
Showing
30 changed files
with
1,105 additions
and
731 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
"""Implement the format plugin system.""" | ||
import os | ||
from collections import abc | ||
from abc import ABC | ||
|
||
from .plugin import Plugin | ||
|
||
|
||
class Format(ABC): | ||
__FormatPlugin = Plugin() | ||
__FromPlugin = Plugin() | ||
__ToPlugin = Plugin() | ||
|
||
@staticmethod | ||
def register(key): | ||
return Format.__FormatPlugin.register(key) | ||
|
||
@staticmethod | ||
def register_from(key): | ||
return Format.__FromPlugin.register(key) | ||
|
||
@staticmethod | ||
def register_to(key): | ||
return Format.__ToPlugin.register(key) | ||
|
||
@staticmethod | ||
def get_formats(): | ||
return Format.__FormatPlugin.plugins | ||
|
||
@staticmethod | ||
def get_from_methods(): | ||
return Format.__FromPlugin.plugins | ||
|
||
@staticmethod | ||
def get_to_methods(): | ||
return Format.__ToPlugin.plugins | ||
|
||
@staticmethod | ||
def post(func_name): | ||
def decorator(object): | ||
if not isinstance(func_name, (list, tuple, set)): | ||
object.post_func = (func_name,) | ||
else: | ||
object.post_func = func_name | ||
return object | ||
return decorator | ||
|
||
def from_system(self, file_name, **kwargs): | ||
"""System.from | ||
Parameters | ||
---------- | ||
file_name: str | ||
file name | ||
Returns | ||
------- | ||
data: dict | ||
system data | ||
""" | ||
raise NotImplementedError("%s doesn't support System.from" %(self.__class__.__name__)) | ||
|
||
def to_system(self, data, *args, **kwargs): | ||
"""System.to | ||
Parameters | ||
---------- | ||
data: dict | ||
system data | ||
""" | ||
raise NotImplementedError("%s doesn't support System.to" %(self.__class__.__name__)) | ||
|
||
def from_labeled_system(self, file_name, **kwargs): | ||
raise NotImplementedError("%s doesn't support LabeledSystem.from" %(self.__class__.__name__)) | ||
|
||
def to_labeled_system(self, data, *args, **kwargs): | ||
return self.to_system(data, *args, **kwargs) | ||
|
||
def from_bond_order_system(self, file_name, **kwargs): | ||
raise NotImplementedError("%s doesn't support BondOrderSystem.from" %(self.__class__.__name__)) | ||
|
||
def to_bond_order_system(self, data, rdkit_mol, *args, **kwargs): | ||
return self.to_system(data, *args, **kwargs) | ||
|
||
class MultiModes: | ||
"""File mode for MultiSystems | ||
0 (default): not implemented | ||
1: every directory under the top-level directory is a system | ||
""" | ||
NotImplemented = 0 | ||
Directory = 1 | ||
|
||
MultiMode = MultiModes.NotImplemented | ||
|
||
def from_multi_systems(self, directory, **kwargs): | ||
"""MultiSystems.from | ||
Parameters | ||
---------- | ||
directory: str | ||
directory of system | ||
Returns | ||
------- | ||
filenames: list[str] | ||
list of filenames | ||
""" | ||
if self.MultiMode == self.MultiModes.Directory: | ||
return [name for name in os.listdir(directory) if os.path.isdir(os.path.join(directory, name))] | ||
raise NotImplementedError("%s doesn't support MultiSystems.from" %(self.__class__.__name__)) | ||
|
||
def to_multi_systems(self, formulas, directory, **kwargs): | ||
if self.MultiMode == self.MultiModes.Directory: | ||
return [os.path.join(directory, ff) for ff in formulas] | ||
raise NotImplementedError("%s doesn't support MultiSystems.to" %(self.__class__.__name__)) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
"""Base of plugin systems.""" | ||
|
||
|
||
class Plugin: | ||
"""A class to register plugins. | ||
Examples | ||
-------- | ||
>>> Plugin = Register() | ||
>>> @Plugin.register("xx") | ||
def xxx(): | ||
pass | ||
>>> print(Plugin.plugins['xx']) | ||
""" | ||
def __init__(self): | ||
self.plugins = {} | ||
|
||
def register(self, key): | ||
"""Register a plugin. | ||
Parameter | ||
--------- | ||
key: str | ||
Key of the plugin. | ||
""" | ||
def decorator(object): | ||
self.plugins[key] = object | ||
return object | ||
return decorator | ||
|
||
def get_plugin(self, key): | ||
return self.plugins[key] | ||
|
||
def __add__(self, other): | ||
self.plugins.update(other.plugins) | ||
return self |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
import importlib | ||
from pathlib import Path | ||
try: | ||
from importlib import metadata | ||
except ImportError: # for Python<3.8 | ||
import importlib_metadata as metadata | ||
|
||
PACKAGE_BASE = "dpdata.plugins" | ||
NOT_LOADABLE = ("__init__.py",) | ||
|
||
for module_file in Path(__file__).parent.glob("*.py"): | ||
if module_file.name not in NOT_LOADABLE: | ||
module_name = f".{module_file.stem}" | ||
importlib.import_module(module_name, PACKAGE_BASE) | ||
|
||
# https://setuptools.readthedocs.io/en/latest/userguide/entry_point.html | ||
eps = metadata.entry_points().get('dpdata.plugins', []) | ||
for ep in eps: | ||
plugin = ep.load() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
import dpdata.abacus.scf | ||
from dpdata.format import Format | ||
|
||
|
||
@Format.register("abacus/scf") | ||
@Format.register("abacus/pw/scf") | ||
class AbacusSCFFormat(Format): | ||
@Format.post("rot_lower_triangular") | ||
def from_labeled_system(self, file_name, **kwargs): | ||
return dpdata.abacus.scf.get_frame(file_name) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
import dpdata.amber.md | ||
import dpdata.amber.sqm | ||
from dpdata.format import Format | ||
|
||
|
||
@Format.register("amber/md") | ||
class AmberMDFormat(Format): | ||
def from_system(self, file_name=None, parm7_file=None, nc_file=None, use_element_symbols=None): | ||
# assume the prefix is the same if the spefic name is not given | ||
if parm7_file is None: | ||
parm7_file = file_name + ".parm7" | ||
if nc_file is None: | ||
nc_file = file_name + ".nc" | ||
return dpdata.amber.md.read_amber_traj(parm7_file=parm7_file, nc_file=nc_file, use_element_symbols=use_element_symbols, labeled=False) | ||
|
||
def from_labeled_system(self, file_name=None, parm7_file=None, nc_file=None, mdfrc_file=None, mden_file=None, mdout_file=None, use_element_symbols=None, **kwargs): | ||
# assume the prefix is the same if the spefic name is not given | ||
if parm7_file is None: | ||
parm7_file = file_name + ".parm7" | ||
if nc_file is None: | ||
nc_file = file_name + ".nc" | ||
if mdfrc_file is None: | ||
mdfrc_file = file_name + ".mdfrc" | ||
if mden_file is None: | ||
mden_file = file_name + ".mden" | ||
if mdout_file is None: | ||
mdout_file = file_name + ".mdout" | ||
return dpdata.amber.md.read_amber_traj(parm7_file, nc_file, mdfrc_file, mden_file, mdout_file, use_element_symbols) | ||
|
||
|
||
@Format.register("sqm/out") | ||
class SQMOutFormat(Format): | ||
def from_system(self, fname, **kwargs): | ||
''' | ||
Read from ambertools sqm.out | ||
''' | ||
return dpdata.amber.sqm.to_system_data(fname) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
from dpdata.format import Format | ||
|
||
|
||
@Format.register("ase/structure") | ||
class ASEStructureFormat(Format): | ||
def to_system(self, data, **kwargs): | ||
''' | ||
convert System to ASE Atom obj | ||
''' | ||
from ase import Atoms | ||
|
||
structures = [] | ||
species = [data['atom_names'][tt] for tt in data['atom_types']] | ||
|
||
for ii in range(data['coords'].shape[0]): | ||
structure = Atoms( | ||
symbols=species, positions=data['coords'][ii], pbc=not data.get('nopbc', False), cell=data['cells'][ii]) | ||
structures.append(structure) | ||
|
||
return structures | ||
|
||
def to_labeled_system(self, data, *args, **kwargs): | ||
'''Convert System to ASE Atoms object.''' | ||
from ase import Atoms | ||
from ase.calculators.singlepoint import SinglePointCalculator | ||
|
||
structures = [] | ||
species = [data['atom_names'][tt] for tt in data['atom_types']] | ||
|
||
for ii in range(data['coords'].shape[0]): | ||
structure = Atoms( | ||
symbols=species, | ||
positions=data['coords'][ii], | ||
pbc=not data.get('nopbc', False), | ||
cell=data['cells'][ii] | ||
) | ||
|
||
results = { | ||
'energy': data["energies"][ii], | ||
'forces': data["forces"][ii] | ||
} | ||
if "virials" in data: | ||
# convert to GPa as this is ase convention | ||
v_pref = 1 * 1e4 / 1.602176621e6 | ||
vol = structure.get_volume() | ||
results['stress'] = data["virials"][ii] / (v_pref * vol) | ||
|
||
structure.calc = SinglePointCalculator(structure, **results) | ||
structures.append(structure) | ||
|
||
return structures |
Oops, something went wrong.