Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor Transformations from closure into class #2859

Merged
merged 33 commits into from
Sep 8, 2020
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
05f2ddd
refactor transformations into class
yuxuanzhuang Jul 20, 2020
92a9653
pep8 and test fix
yuxuanzhuang Jul 20, 2020
9fc5071
move fit prep toinit
yuxuanzhuang Jul 20, 2020
6b7c45b
move rotate prep to init
yuxuanzhuang Jul 20, 2020
ae51c5b
move translate prep to init
yuxuanzhuang Jul 20, 2020
a542792
move wrap prep to init
yuxuanzhuang Jul 20, 2020
1729ee6
pep8
yuxuanzhuang Jul 20, 2020
26950b2
pep
yuxuanzhuang Jul 20, 2020
5aa8d2b
test pickle
yuxuanzhuang Jul 20, 2020
1c5bbdb
pep
yuxuanzhuang Jul 20, 2020
b09bef6
doc for each module
yuxuanzhuang Jul 27, 2020
235e8a9
read_offset issue?
yuxuanzhuang Jul 30, 2020
4447869
Merge remote-tracking branch 'mda_origin/develop' into new_transforma…
yuxuanzhuang Jul 30, 2020
5ef7d3f
change to dcd
yuxuanzhuang Jul 30, 2020
f6331d3
doc transformation
yuxuanzhuang Jul 31, 2020
1ebcc33
note revise
yuxuanzhuang Aug 1, 2020
c412104
doc for transformation
yuxuanzhuang Aug 3, 2020
3b1c470
changelog
yuxuanzhuang Aug 4, 2020
6e87625
Merge remote-tracking branch 'mda_origin/develop' into new_transforma…
yuxuanzhuang Aug 9, 2020
11f3644
enable universe pickle
yuxuanzhuang Aug 9, 2020
7c857f6
merge to develop
yuxuanzhuang Aug 12, 2020
58f6e4e
transformation doc
yuxuanzhuang Aug 12, 2020
26dab2f
Merge branch 'develop' into new_transformation
orbeckst Aug 16, 2020
648c9b0
merge to develop
yuxuanzhuang Aug 27, 2020
3f083ae
change universe pickle to reduce
yuxuanzhuang Aug 27, 2020
9ae7f83
remove pure pickle/unpickle tests
yuxuanzhuang Aug 27, 2020
26f68bc
Merge branch 'new_transformation' of https://github.com/yuxuanzhuang/…
yuxuanzhuang Aug 27, 2020
ae8a2a2
pep
yuxuanzhuang Aug 27, 2020
499efc1
example
yuxuanzhuang Aug 28, 2020
2793ba2
doc
yuxuanzhuang Aug 28, 2020
acb4488
Merge branch 'develop' into new_transformation
yuxuanzhuang Aug 30, 2020
1080678
Update package/MDAnalysis/transformations/__init__.py
yuxuanzhuang Sep 4, 2020
0655275
change snippet
yuxuanzhuang Sep 6, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 50 additions & 14 deletions package/MDAnalysis/transformations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,34 +26,70 @@
Trajectory transformations --- :mod:`MDAnalysis.transformations`
================================================================

The transformations submodule contains a collection of functions to modify the
trajectory. Coordinate transformations, such as PBC corrections and molecule fitting
are often required for some analyses and visualization, and the functions in this
module allow transformations to be applied on-the-fly.
These transformation functions can be called by the user for any given
timestep of the trajectory, added as a workflow using :meth:`add_transformations`
of the :mod:`~MDAnalysis.coordinates.base` module, or upon Universe creation using
The transformations submodule contains a collection of function-like classes to
modify the trajectory.
Coordinate transformations, such as PBC corrections and molecule fitting
are often required for some analyses and visualization, and the functions in
this module allow transformations to be applied on-the-fly.

A typical transformation class looks like this:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
A typical transformation class looks like this:
A typical transformation class looks like this (note that we keep its name
lowercase because we will treat it as a function, thanks to the ``__call__``
method):


.. code-blocks:: python

class transfomration(object):
orbeckst marked this conversation as resolved.
Show resolved Hide resolved
def __init__(self, *args, **kwargs):
# do some things
# save needed args as attributes.
self.needed_var = args[0]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a little abstract, a trivial example would be better


def __call__(self, ts):
# apply changes to the Timestep object
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changes needn't be applied to the timestep oddly enough. You could E.g. modify an AtomGroup and return ts

return ts

Copy link
Member

@orbeckst orbeckst Aug 28, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You showed the abstract class. Now show a concrete example to address @richardjgowers comment. For instance

Suggested change
As a concrete example we will write a transformation that rotates a group of atoms around
the z-axis through the center of geometry by a fixed increment for every time step. We will use
:meth:`MDAnalysis.core.groups.AtomGroup.rotateby` and simply increase the rotation angle
every time the transformation is called::
class spin_atoms(object):
def __init__(self, atoms, dphi):
"""Rotate atoms by dphi degrees for every time step (around the z axis)"""
self.atoms = atoms
self.dphi = dphi
self.axis = np.array([0, 0, 1])
def __call__(self, ts):
phi = self.dphi * ts.frame
self.atoms.rotateby(phi, self.axis)
return ts
This transformation can be used as ::
u = mda.Universe(PSF, DCD)
u.trajectory.add_transformations(spin_atoms(u.select_atoms("protein"), 1.0))

I currently can't get nglview to work in my notebook so you'll need to check that this actually works... or come up with another example.

EDIT: forgot to increment phi...

EDIT 2: yes, it's pretty dumb that the phi angle just keeps incrementing, no matter what you do with the trajectory. Perhaps better to do something like phi = ts.frame * self.dphi

EDIT 3: changed it to phi = ts.frame * self.dphi

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! it works. And I think it makes sense to just rotate by a fixed angle. (for visualization perhaps)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed the snippet.

See `MDAnalysis.transformations.translate` for a simple example.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The problem is they can't see the code from the docs, so instead have the above example be translate

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this part (inside __init__) of the doc is not present anywhere in the doc webpages. A more explicit description of how transformation is done is written separately. By which I guess makes it fine to be a little more abstract?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The above is not a valid sphinx link.

Suggested change
See `MDAnalysis.transformations.translate` for a simple example.
See :mod:`MDAnalysis.transformations.translate` for a simple example.


These transformation functions can be called by the user for any given timestep
of the trajectory, added as a workflow using :meth:`add_transformations`
of the :mod:`~MDAnalysis.coordinates.base`, or upon Universe creation using
the keyword argument `transformations`. Note that in the two latter cases, the
workflow cannot be changed after being defined.
workflow cannot be changed after being defined. for example:

In addition to the specific arguments that each transformation can take, they also
contain a wrapped function that takes a `Timestep` object as argument.
So, a transformation can be roughly defined as follows:
.. code-block:: python

u = mda.Universe(GRO, XTC)
ts = u.trajectory[0]
trans = transformation(args)
ts = trans(ts)

# or add as a workflow
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we only want to promote the workflow usage, else it's just a function they can (forget to) apply

u.trajectory.add_transformations(trans)

Transformations can also be created as a closure/nested function.
In addition to the specific arguments that each transformation can take, they
also contain a wrapped function that takes a `Timestep` object as argument.
So, a closure-style transformation can be roughly defined as follows:

.. code-block:: python

def transformations(*args,**kwargs):
def transformation(*args,**kwargs):
# do some things
def wrapped(ts):
# apply changes to the Timestep object
return ts

return wrapped


See `MDAnalysis.transformations.translate` for a simple example.
Note, to meet the need of serialization of universe, only transformation class
are used after MDAnlaysis 2.0.0. One can still write functions (closures) as in
MDA 1.x, but that these cannot be serialized and thus will not work with all
forms of parallel analysis. For detailed descriptions about how to write a
closure-style transformation, read the code in MDA 1.x as a reference
or read MDAnalysis UserGuide.
orbeckst marked this conversation as resolved.
Show resolved Hide resolved


.. versionchanged:: 2.0.0
Transformations should now be created as classes with a :meth:`__call__`
method instead of being written as a function/closure.
"""

from .translate import translate, center_in_box
Expand Down
171 changes: 102 additions & 69 deletions package/MDAnalysis/transformations/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,30 +27,27 @@
Translate and/or rotates the coordinates of a given trajectory to align
a given AtomGroup to a reference structure.

.. autofunction:: fit_translation
.. autoclass:: fit_translation

.. autofunction:: fit_rot_trans
.. autoclass:: fit_rot_trans

"""
import numpy as np
from functools import partial

from ..analysis import align
from ..lib.util import get_weights
from ..lib.transformations import euler_from_matrix, euler_matrix


def fit_translation(ag, reference, plane=None, weights=None):

class fit_translation(object):
"""Translates a given AtomGroup so that its center of geometry/mass matches
the respective center of the given reference. A plane can be given by the
user using the option `plane`, and will result in the removal of
the translation motions of the AtomGroup over that particular plane.

Example
-------
Removing the translations of a given AtomGroup `ag` on the XY plane by fitting
its center of mass to the center of mass of a reference `ref`:
Removing the translations of a given AtomGroup `ag` on the XY plane by
fitting its center of mass to the center of mass of a reference `ref`:

.. code-block:: python

Expand All @@ -67,11 +64,12 @@ def fit_translation(ag, reference, plane=None, weights=None):
:class:`~MDAnalysis.core.groups.AtomGroup` or a whole
:class:`~MDAnalysis.core.universe.Universe`
reference : Universe or AtomGroup
reference structure, a :class:`~MDAnalysis.core.groups.AtomGroup` or a whole
:class:`~MDAnalysis.core.universe.Universe`
reference structure, a :class:`~MDAnalysis.core.groups.AtomGroup` or a
whole :class:`~MDAnalysis.core.universe.Universe`
plane: str, optional
used to define the plane on which the translations will be removed. Defined as a
string of the plane. Suported planes are yz, xz and xy planes.
used to define the plane on which the translations will be removed.
Defined as a string of the plane.
Suported planes are yz, xz and xy planes.
weights : {"mass", ``None``} or array_like, optional
choose weights. With ``"mass"`` uses masses as weights; with ``None``
weigh each atom equally. If a float array of the same length as
Expand All @@ -81,39 +79,56 @@ def fit_translation(ag, reference, plane=None, weights=None):
Returns
-------
MDAnalysis.coordinates.base.Timestep
"""

if plane is not None:
axes = {'yz' : 0, 'xz' : 1, 'xy' : 2}

.. versionchanged:: 2.0.0
The transformation was changed from a function/closure to a class
with ``__call__``.
"""
def __init__(self, ag, reference, plane=None, weights=None):
self.ag = ag
self.reference = reference
self.plane = plane
self.weights = weights

if self.plane is not None:
richardjgowers marked this conversation as resolved.
Show resolved Hide resolved
axes = {'yz': 0, 'xz': 1, 'xy': 2}
try:
self.plane = axes[self.plane]
except (TypeError, KeyError):
raise ValueError(f'{self.plane} is not a valid plane') \
from None
try:
plane = axes[plane]
except (TypeError, KeyError):
raise ValueError(f'{plane} is not a valid plane') from None
try:
if ag.atoms.n_residues != reference.atoms.n_residues:
errmsg = f"{ag} and {reference} have mismatched number of residues"
raise ValueError(errmsg)
except AttributeError:
errmsg = f"{ag} or {reference} is not valid Universe/AtomGroup"
raise AttributeError(errmsg) from None
ref, mobile = align.get_matching_atoms(reference.atoms, ag.atoms)
weights = align.get_weights(ref.atoms, weights=weights)
ref_com = ref.center(weights)
ref_coordinates = ref.atoms.positions - ref_com

def wrapped(ts):
mobile_com = np.asarray(mobile.atoms.center(weights), np.float32)
vector = ref_com - mobile_com
if plane is not None:
vector[plane] = 0
if self.ag.atoms.n_residues != self.reference.atoms.n_residues:
errmsg = (
f"{self.ag} and {self.reference} have mismatched"
f"number of residues"
)

raise ValueError(errmsg)
except AttributeError:
errmsg = (
f"{self.ag} or {self.reference} is not valid"
f"Universe/AtomGroup"
)
raise AttributeError(errmsg) from None
self.ref, self.mobile = align.get_matching_atoms(self.reference.atoms,
self.ag.atoms)
self.weights = align.get_weights(self.ref.atoms, weights=self.weights)
self.ref_com = self.ref.center(self.weights)

def __call__(self, ts):
mobile_com = np.asarray(self.mobile.atoms.center(self.weights),
np.float32)
vector = self.ref_com - mobile_com
if self.plane is not None:
vector[self.plane] = 0
ts.positions += vector

return ts

return wrapped


def fit_rot_trans(ag, reference, plane=None, weights=None):
class fit_rot_trans(object):
"""Perform a spatial superposition by minimizing the RMSD.

Spatially align the group of atoms `ag` to `reference` by doing a RMSD
Expand Down Expand Up @@ -160,41 +175,59 @@ def fit_rot_trans(ag, reference, plane=None, weights=None):
-------
MDAnalysis.coordinates.base.Timestep
"""
if plane is not None:
axes = {'yz' : 0, 'xz' : 1, 'xy' : 2}
def __init__(self, ag, reference, plane=None, weights=None):
self.ag = ag
self.reference = reference
self.plane = plane
self.weights = weights

if self.plane is not None:
axes = {'yz': 0, 'xz': 1, 'xy': 2}
try:
self.plane = axes[self.plane]
except (TypeError, KeyError):
raise ValueError(f'{self.plane} is not a valid plane') \
from None
try:
plane = axes[plane]
except (TypeError, KeyError):
raise ValueError(f'{plane} is not a valid plane') from None
try:
if ag.atoms.n_residues != reference.atoms.n_residues:
errmsg = f"{ag} and {reference} have mismatched number of residues"
raise ValueError(errmsg)
except AttributeError:
errmsg = f"{ag} or {reference} is not valid Universe/AtomGroup"
raise AttributeError(errmsg) from None
ref, mobile = align.get_matching_atoms(reference.atoms, ag.atoms)
weights = align.get_weights(ref.atoms, weights=weights)
ref_com = ref.center(weights)
ref_coordinates = ref.atoms.positions - ref_com

def wrapped(ts):
mobile_com = mobile.atoms.center(weights)
mobile_coordinates = mobile.atoms.positions - mobile_com
rotation, dump = align.rotation_matrix(mobile_coordinates, ref_coordinates, weights=weights)
vector = ref_com
if plane is not None:
matrix = np.r_[rotation, np.zeros(3).reshape(1,3)]
if self.ag.atoms.n_residues != self.reference.atoms.n_residues:
errmsg = (
f"{self.ag} and {self.reference} have mismatched "
f"number of residues"
)
raise ValueError(errmsg)
except AttributeError:
errmsg = (
f"{self.ag} or {self.reference} is not valid "
f"Universe/AtomGroup"
)
raise AttributeError(errmsg) from None
self.ref, self.mobile = align.get_matching_atoms(self.reference.atoms,
self.ag.atoms)
self.weights = align.get_weights(self.ref.atoms, weights=self.weights)
self.ref_com = self.ref.center(self.weights)
self.ref_coordinates = self.ref.atoms.positions - self.ref_com

def __call__(self, ts):
mobile_com = self.mobile.atoms.center(self.weights)
mobile_coordinates = self.mobile.atoms.positions - mobile_com
rotation, dump = align.rotation_matrix(mobile_coordinates,
self.ref_coordinates,
weights=self.weights)
vector = self.ref_com
if self.plane is not None:
matrix = np.r_[rotation, np.zeros(3).reshape(1, 3)]
matrix = np.c_[matrix, np.zeros(4)]
euler_angs = np.asarray(euler_from_matrix(matrix, axes='sxyz'), np.float32)
euler_angs = np.asarray(euler_from_matrix(matrix, axes='sxyz'),
np.float32)
for i in range(0, euler_angs.size):
euler_angs[i] = ( euler_angs[plane] if i == plane else 0)
rotation = euler_matrix(euler_angs[0], euler_angs[1], euler_angs[2], axes='sxyz')[:3, :3]
vector[plane] = mobile_com[plane]
euler_angs[i] = (euler_angs[self.plane] if i == self.plane
else 0)
rotation = euler_matrix(euler_angs[0],
euler_angs[1],
euler_angs[2],
axes='sxyz')[:3, :3]
vector[self.plane] = mobile_com[self.plane]
ts.positions = ts.positions - mobile_com
ts.positions = np.dot(ts.positions, rotation.T)
ts.positions = ts.positions + vector

return ts

return wrapped
Loading