Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor Transformations from closure into class #2859

Merged
merged 33 commits into from
Sep 8, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
05f2ddd
refactor transformations into class
yuxuanzhuang Jul 20, 2020
92a9653
pep8 and test fix
yuxuanzhuang Jul 20, 2020
9fc5071
move fit prep toinit
yuxuanzhuang Jul 20, 2020
6b7c45b
move rotate prep to init
yuxuanzhuang Jul 20, 2020
ae51c5b
move translate prep to init
yuxuanzhuang Jul 20, 2020
a542792
move wrap prep to init
yuxuanzhuang Jul 20, 2020
1729ee6
pep8
yuxuanzhuang Jul 20, 2020
26950b2
pep
yuxuanzhuang Jul 20, 2020
5aa8d2b
test pickle
yuxuanzhuang Jul 20, 2020
1c5bbdb
pep
yuxuanzhuang Jul 20, 2020
b09bef6
doc for each module
yuxuanzhuang Jul 27, 2020
235e8a9
read_offset issue?
yuxuanzhuang Jul 30, 2020
4447869
Merge remote-tracking branch 'mda_origin/develop' into new_transforma…
yuxuanzhuang Jul 30, 2020
5ef7d3f
change to dcd
yuxuanzhuang Jul 30, 2020
f6331d3
doc transformation
yuxuanzhuang Jul 31, 2020
1ebcc33
note revise
yuxuanzhuang Aug 1, 2020
c412104
doc for transformation
yuxuanzhuang Aug 3, 2020
3b1c470
changelog
yuxuanzhuang Aug 4, 2020
6e87625
Merge remote-tracking branch 'mda_origin/develop' into new_transforma…
yuxuanzhuang Aug 9, 2020
11f3644
enable universe pickle
yuxuanzhuang Aug 9, 2020
7c857f6
merge to develop
yuxuanzhuang Aug 12, 2020
58f6e4e
transformation doc
yuxuanzhuang Aug 12, 2020
26dab2f
Merge branch 'develop' into new_transformation
orbeckst Aug 16, 2020
648c9b0
merge to develop
yuxuanzhuang Aug 27, 2020
3f083ae
change universe pickle to reduce
yuxuanzhuang Aug 27, 2020
9ae7f83
remove pure pickle/unpickle tests
yuxuanzhuang Aug 27, 2020
26f68bc
Merge branch 'new_transformation' of https://github.com/yuxuanzhuang/…
yuxuanzhuang Aug 27, 2020
ae8a2a2
pep
yuxuanzhuang Aug 27, 2020
499efc1
example
yuxuanzhuang Aug 28, 2020
2793ba2
doc
yuxuanzhuang Aug 28, 2020
acb4488
Merge branch 'develop' into new_transformation
yuxuanzhuang Aug 30, 2020
1080678
Update package/MDAnalysis/transformations/__init__.py
yuxuanzhuang Sep 4, 2020
0655275
change snippet
yuxuanzhuang Sep 6, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 54 additions & 48 deletions package/MDAnalysis/transformations/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,7 @@
from ..lib.transformations import euler_from_matrix, euler_matrix


def fit_translation(ag, reference, plane=None, weights=None):

class fit_translation(object):
"""Translates a given AtomGroup so that its center of geometry/mass matches
the respective center of the given reference. A plane can be given by the
user using the option `plane`, and will result in the removal of
Expand Down Expand Up @@ -82,38 +81,42 @@ def fit_translation(ag, reference, plane=None, weights=None):
-------
MDAnalysis.coordinates.base.Timestep
"""

if plane is not None:
axes = {'yz' : 0, 'xz' : 1, 'xy' : 2}
def __init__(self, ag, reference, plane=None, weights=None):
self.ag = ag
self.reference = reference
self.plane = plane
self.weights = weights


def __call__(self, ts):
if self.plane is not None:
richardjgowers marked this conversation as resolved.
Show resolved Hide resolved
axes = {'yz' : 0, 'xz' : 1, 'xy' : 2}
try:
plane = axes[self.plane]
except (TypeError, KeyError):
raise ValueError(f'{self.plane} is not a valid plane') from None
try:
plane = axes[plane]
except (TypeError, KeyError):
raise ValueError(f'{plane} is not a valid plane') from None
try:
if ag.atoms.n_residues != reference.atoms.n_residues:
errmsg = f"{ag} and {reference} have mismatched number of residues"
raise ValueError(errmsg)
except AttributeError:
errmsg = f"{ag} or {reference} is not valid Universe/AtomGroup"
raise AttributeError(errmsg) from None
ref, mobile = align.get_matching_atoms(reference.atoms, ag.atoms)
weights = align.get_weights(ref.atoms, weights=weights)
ref_com = ref.center(weights)
ref_coordinates = ref.atoms.positions - ref_com

def wrapped(ts):
if self.ag.atoms.n_residues != self.reference.atoms.n_residues:
errmsg = f"{self.ag} and {self.reference} have mismatched number of residues"
raise ValueError(errmsg)
except AttributeError:
errmsg = f"{self.ag} or {self.reference} is not valid Universe/AtomGroup"
raise AttributeError(errmsg) from None
ref, mobile = align.get_matching_atoms(self.reference.atoms, self.ag.atoms)
weights = align.get_weights(ref.atoms, weights=self.weights)
ref_com = ref.center(weights)
ref_coordinates = ref.atoms.positions - ref_com

mobile_com = np.asarray(mobile.atoms.center(weights), np.float32)
vector = ref_com - mobile_com
if plane is not None:
if self.plane is not None:
vector[plane] = 0
ts.positions += vector

return ts

return wrapped


def fit_rot_trans(ag, reference, plane=None, weights=None):
class fit_rot_trans(object):
"""Perform a spatial superposition by minimizing the RMSD.

Spatially align the group of atoms `ag` to `reference` by doing a RMSD
Expand Down Expand Up @@ -160,30 +163,35 @@ def fit_rot_trans(ag, reference, plane=None, weights=None):
-------
MDAnalysis.coordinates.base.Timestep
"""
if plane is not None:
axes = {'yz' : 0, 'xz' : 1, 'xy' : 2}
def __init__(self, ag, reference, plane=None, weights=None):
self.ag = ag
self.reference = reference
self.plane = plane
self.weights = weights
def __call__(self, ts):
if self.plane is not None:
axes = {'yz' : 0, 'xz' : 1, 'xy' : 2}
try:
plane = axes[self.plane]
except (TypeError, KeyError):
raise ValueError(f'{self.plane} is not a valid plane') from None
try:
plane = axes[plane]
except (TypeError, KeyError):
raise ValueError(f'{plane} is not a valid plane') from None
try:
if ag.atoms.n_residues != reference.atoms.n_residues:
errmsg = f"{ag} and {reference} have mismatched number of residues"
raise ValueError(errmsg)
except AttributeError:
errmsg = f"{ag} or {reference} is not valid Universe/AtomGroup"
raise AttributeError(errmsg) from None
ref, mobile = align.get_matching_atoms(reference.atoms, ag.atoms)
weights = align.get_weights(ref.atoms, weights=weights)
ref_com = ref.center(weights)
ref_coordinates = ref.atoms.positions - ref_com

def wrapped(ts):
mobile_com = mobile.atoms.center(weights)
if self.ag.atoms.n_residues != self.reference.atoms.n_residues:
errmsg = f"{self.ag} and {self.reference} have mismatched number of residues"
raise ValueError(errmsg)
except AttributeError:
errmsg = f"{self.ag} or {self.reference} is not valid Universe/AtomGroup"
raise AttributeError(errmsg) from None
ref, mobile = align.get_matching_atoms(self.reference.atoms, self.ag.atoms)
weights = align.get_weights(ref.atoms, weights=self.weights)
ref_com = ref.center(self.weights)
ref_coordinates = ref.atoms.positions - ref_com

mobile_com = mobile.atoms.center(self.weights)
mobile_coordinates = mobile.atoms.positions - mobile_com
rotation, dump = align.rotation_matrix(mobile_coordinates, ref_coordinates, weights=weights)
rotation, dump = align.rotation_matrix(mobile_coordinates, ref_coordinates, weights=self.weights)
vector = ref_com
if plane is not None:
if self.plane is not None:
matrix = np.r_[rotation, np.zeros(3).reshape(1,3)]
matrix = np.c_[matrix, np.zeros(4)]
euler_angs = np.asarray(euler_from_matrix(matrix, axes='sxyz'), np.float32)
Expand All @@ -196,5 +204,3 @@ def wrapped(ts):
ts.positions = ts.positions + vector

return ts

return wrapped
22 changes: 15 additions & 7 deletions package/MDAnalysis/transformations/positionaveraging.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ def __init__(self, avg_frames, check_reset=True):
self.check_reset = check_reset
self.current_avg = 0
self.resetarrays()
self.current_frame = 0

def resetarrays(self):
self.idx_array = np.empty(self.avg_frames)
Expand All @@ -164,24 +165,31 @@ def rollposx(self,ts):


def __call__(self, ts):
# calling the same timestep will not add new data to coord_array
# This can prevent from getting different values when
# call `u.trajectory[i]` multiple times.
if (ts.frame == self.current_frame
and hasattr(self, 'coord_array')
and not np.isnan(self.idx_array).all()):
test = ~np.isnan(self.idx_array)
ts.positions = np.mean(self.coord_array[...,test], axis=2)
return ts
else:
self.current_frame = ts.frame

self.rollidx(ts)
test = ~np.isnan(self.idx_array)
self.current_avg = sum(test)
if self.current_avg == 1:
return ts

if self.check_reset:
sign = np.sign(np.diff(self.idx_array[test]))

if not (np.all(sign == 1) or np.all(sign==-1)):
warnings.warn('Cannot average position for non sequential'
'iterations. Averager will be reset.',
Warning)
self.resetarrays()
return self(ts)

self.rollposx(ts)
ts.positions = np.mean(self.coord_array[...,test], axis=2)

return ts

79 changes: 43 additions & 36 deletions package/MDAnalysis/transformations/rotate.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@
Trajectory rotation --- :mod:`MDAnalysis.transformations.rotate`
================================================================

Rotates the coordinates by a given angle arround an axis formed by a direction and a
point
Rotates the coordinates by a given angle arround an axis formed by a direction
and a point.

.. autofunction:: rotateby

Expand All @@ -38,7 +38,8 @@
from ..lib.transformations import rotation_matrix
from ..lib.util import get_weights

def rotateby(angle, direction, point=None, ag=None, weights=None, wrap=False):

class rotateby(object):
'''
Rotates the trajectory by a given angle on a given axis. The axis is defined by
the user, combining the direction vector and a point. This point can be the center
Expand Down Expand Up @@ -104,37 +105,47 @@ def rotateby(angle, direction, point=None, ag=None, weights=None, wrap=False):
after rotating the trajectory.

'''
angle = np.deg2rad(angle)
try:
direction = np.asarray(direction, np.float32)
if direction.shape != (3, ) and direction.shape != (1, 3):
raise ValueError('{} is not a valid direction'.format(direction))
direction = direction.reshape(3, )
except ValueError:
raise ValueError(f'{direction} is not a valid direction') from None
if point is not None:
point = np.asarray(point, np.float32)
if point.shape != (3, ) and point.shape != (1, 3):
raise ValueError('{} is not a valid point'.format(point))
point = point.reshape(3, )
elif ag:
def __init__(self, angle, direction, point=None, ag=None, weights=None, wrap=False):
self.angle = angle
self.direction = direction
self.point = point
self.ag = ag
self.weights = weights
self.wrap = wrap

def __call__(self, ts):
angle = np.deg2rad(self.angle)
try:
atoms = ag.atoms
except AttributeError:
raise ValueError(f'{ag} is not an AtomGroup object') from None
else:
direction = np.asarray(self.direction, np.float32)
if direction.shape != (3, ) and direction.shape != (1, 3):
raise ValueError('{} is not a valid direction'
.format(direction))
direction = direction.reshape(3, )
except ValueError:
raise ValueError(f'{self.direction} is not a valid direction') from None
if self.point is not None:
point = np.asarray(self.point, np.float32)
if point.shape != (3, ) and point.shape != (1, 3):
raise ValueError('{} is not a valid point'.format(point))
point = point.reshape(3, )
elif self.ag:
try:
weights = get_weights(atoms, weights=weights)
except (ValueError, TypeError):
errmsg = ("weights must be {'mass', None} or an iterable of "
"the same size as the atomgroup.")
raise TypeError(errmsg) from None
center_method = partial(atoms.center, weights, pbc=wrap)
else:
raise ValueError('A point or an AtomGroup must be specified')

def wrapped(ts):
if point is None:
atoms = self.ag.atoms
except AttributeError:
raise ValueError(f'{self.ag} is not an AtomGroup object') \
from None
else:
try:
weights = get_weights(atoms, weights=self.weights)
except (ValueError, TypeError):
errmsg = ("weights must be {'mass', None} or an iterable of"
"the same size as the atomgroup.")
raise TypeError(errmsg) from None
center_method = partial(atoms.center, weights, pbc=self.wrap)
else:
raise ValueError('A point or an AtomGroup must be specified')

if self.point is None:
position = center_method()
else:
position = point
Expand All @@ -143,8 +154,4 @@ def wrapped(ts):
translation = matrix[:3, 3]
ts.positions= np.dot(ts.positions, rotation)
ts.positions += translation

return ts

return wrapped

72 changes: 37 additions & 35 deletions package/MDAnalysis/transformations/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@

from ..lib.mdamath import triclinic_vectors

def translate(vector):
class translate(object):
"""
Translates the coordinates of a given :class:`~MDAnalysis.coordinates.base.Timestep`
instance by a given vector.
Expand All @@ -60,20 +60,20 @@ def translate(vector):
:class:`~MDAnalysis.coordinates.base.Timestep` object

"""
if len(vector)>2:
vector = np.float32(vector)
else:
raise ValueError("{} vector is too short".format(vector))
def __init__(self, vector):
self.vector = vector

def __call__(self, ts):
if len(self.vector)>2:
vector = np.float32(self.vector)
else:
raise ValueError("{} vector is too short".format(self.vector))

def wrapped(ts):
ts.positions += vector

return ts

return wrapped


def center_in_box(ag, center='geometry', point=None, wrap=False):
class center_in_box(object):
"""
Translates the coordinates of a given :class:`~MDAnalysis.coordinates.base.Timestep`
instance so that the center of geometry/mass of the given :class:`~MDAnalysis.core.groups.AtomGroup`
Expand Down Expand Up @@ -109,28 +109,33 @@ def center_in_box(ag, center='geometry', point=None, wrap=False):
:class:`~MDAnalysis.coordinates.base.Timestep` object

"""

pbc_arg = wrap
if point:
point = np.asarray(point, np.float32)
if point.shape != (3, ) and point.shape != (1, 3):
raise ValueError('{} is not a valid point'.format(point))
try:
if center == 'geometry':
center_method = partial(ag.center_of_geometry, pbc=pbc_arg)
elif center == 'mass':
center_method = partial(ag.center_of_mass, pbc=pbc_arg)
else:
raise ValueError('{} is not a valid argument for center'.format(center))
except AttributeError:
if center == 'mass':
errmsg = f'{ag} is not an AtomGroup object with masses'
raise AttributeError(errmsg) from None
else:
raise ValueError(f'{ag} is not an AtomGroup object') from None

def wrapped(ts):
if point is None:
def __init__(self, ag, center='geometry', point=None, wrap=False):
self.ag = ag
self.center = center
self.point = point
self.wrap = wrap

def __call__(self, ts):
pbc_arg = self.wrap
if self.point:
point = np.asarray(self.point, np.float32)
if point.shape != (3, ) and point.shape != (1, 3):
raise ValueError('{} is not a valid point'.format(point))
try:
if self.center == 'geometry':
center_method = partial(self.ag.center_of_geometry, pbc=pbc_arg)
elif self.center == 'mass':
center_method = partial(self.ag.center_of_mass, pbc=pbc_arg)
else:
raise ValueError('{} is not a valid argument for center'.format(self.center))
except AttributeError:
if self.center == 'mass':
errmsg = f'{self.ag} is not an AtomGroup object with masses'
raise AttributeError(errmsg) from None
else:
raise ValueError(f'{self.ag} is not an AtomGroup object') from None

if self.point is None:
boxcenter = np.sum(ts.triclinic_dimensions, axis=0) / 2
else:
boxcenter = point
Expand All @@ -141,6 +146,3 @@ def wrapped(ts):
ts.positions += vector

return ts

return wrapped

Loading