Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Contact Object Serialization #29

Merged
merged 10 commits into from
Jan 8, 2018
236 changes: 224 additions & 12 deletions contact_map/contact_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,14 @@
import collections
import itertools
import pickle
import json
import scipy
import numpy as np
import pandas as pd
import mdtraj as md

from .py_2_3 import inspect_method_arguments

# matplotlib is technically optional, but required for plotting
try:
import matplotlib
Expand Down Expand Up @@ -59,6 +63,7 @@ def _residue_and_index(residue, topology):
return (res, res_idx)



class ContactCount(object):
"""Return object when dealing with contacts (residue or atom).

Expand Down Expand Up @@ -264,6 +269,141 @@ def __init__(self, topology, query, haystack, cutoff, n_neighbors_ignored):
self._atom_idx_to_residue_idx = {atom.index: atom.residue.index
for atom in self.topology.atoms}

def __hash__(self):
    """Hash built from the cutoff, ignored-neighbor count, atoms, and topology.

    The query and haystack are converted to frozensets before hashing, so
    the order in which their atom indices are stored does not matter.
    """
    settings = (self.cutoff, self.n_neighbors_ignored)
    atom_sets = (frozenset(self._query), frozenset(self._haystack))
    # same 5-tuple as (cutoff, n_neighbors_ignored, query, haystack, topology)
    return hash(settings + atom_sets + (self.topology,))

def __eq__(self, other):
    """Compare the contact settings of two contact objects.

    Parameters
    ----------
    other : object
        object to compare against

    Returns
    -------
    bool
        True if ``other`` is a ContactObject with the same cutoff,
        number of ignored neighbors, query, haystack, and topology;
        False otherwise (including when ``other`` is of another type)
    """
    # Unrelated objects (e.g. None) lack these attributes; report them as
    # unequal instead of raising AttributeError. A plain False is used
    # rather than NotImplemented because ContactMap.__eq__ chains this
    # result with ``and``.
    if not isinstance(other, ContactObject):
        return False
    is_equal = (self.cutoff == other.cutoff
                and self.n_neighbors_ignored == other.n_neighbors_ignored
                and self.query == other.query
                and self.haystack == other.haystack
                and self.topology == other.topology)
    return is_equal

def to_dict(self):
    """Convert object to a dict.

    Keys should be strings; values should be (JSON-) serializable.

    See also
    --------
    from_dict
    """
    dct = {}
    dct['topology'] = self._serialize_topology(self.topology)
    dct['cutoff'] = self._cutoff
    # the atom indices may be np.int64, which json cannot serialize, so
    # they are explicitly converted to int
    dct['query'] = [int(idx) for idx in self._query]
    dct['haystack'] = [int(idx) for idx in self._haystack]
    dct['n_neighbors_ignored'] = self._n_neighbors_ignored
    dct['atom_idx_to_residue_idx'] = self._atom_idx_to_residue_idx
    dct['atom_contacts'] = self._serialize_contact_counter(
        self._atom_contacts
    )
    dct['residue_contacts'] = self._serialize_contact_counter(
        self._residue_contacts
    )
    return dct

@classmethod
def from_dict(cls, dct):
    """Create object from dict.

    The input dict is left unmodified; deserialization is performed on a
    shallow copy, so the same dict can be passed to this method again.

    Parameters
    ----------
    dct : dict
        dict-formatted serialization (see to_dict for details)

    Returns
    -------
    instance of ``cls``, created without calling ``__init__``, with one
    underscore-prefixed attribute per key; arguments of ``cls.__init__``
    that are missing from ``dct`` are set to None

    See also
    --------
    to_dict
    """
    def int_keyed(mapping):
        # keys come back from JSON as strings; restore the integer keys
        return {int(key): mapping[key] for key in mapping}

    deserialization_helpers = {
        'topology': cls._deserialize_topology,
        'atom_contacts': cls._deserialize_contact_counter,
        'residue_contacts': cls._deserialize_contact_counter,
        'query': set,
        'haystack': set,
        'atom_idx_to_residue_idx': int_keyed,
    }

    # work on a copy: the serialized values are replaced by deserialized
    # ones below, which must not leak back into the caller's dict
    attributes = dict(dct)
    for key, helper in deserialization_helpers.items():
        if key in attributes:
            attributes[key] = helper(attributes[key])

    # every __init__ argument that was not serialized defaults to None
    for arg_name in inspect_method_arguments(cls.__init__):
        attributes.setdefault(arg_name, None)

    # bypass __init__ and set the private attributes directly
    instance = cls.__new__(cls)
    for name, value in attributes.items():
        setattr(instance, "_" + name, value)
    return instance

@staticmethod
def _deserialize_topology(topology_json):
    """Create MDTraj topology from JSON-serialized version

    Parameters
    ----------
    topology_json : str
        JSON string as created by ``_serialize_topology``: a two-item
        list containing the JSON of the atom table and the list of bonds

    Returns
    -------
    topology built with ``md.Topology.from_dataframe``
    """
    # imported locally so this method does not depend on the module-level
    # imports
    from io import StringIO

    table, bonds = json.loads(topology_json)
    # Recent pandas versions deprecate passing a literal JSON string to
    # read_json; wrapping it in a file-like object works across versions.
    topology_df = pd.read_json(StringIO(table))
    topology = md.Topology.from_dataframe(topology_df,
                                          np.array(bonds))
    return topology

@staticmethod
def _serialize_topology(topology):
    """Serialize MDTraj topology (to JSON)

    The result is the JSON encoding of a pair: the JSON string of the
    atom table and the bonds as a nested list.
    """
    atom_table, bond_array = topology.to_dataframe()
    return json.dumps((atom_table.to_json(), bond_array.tolist()))

# TODO: adding a separate object for these frozenset counters will be
# useful for many things, and this serialization should be moved there
@staticmethod
def _serialize_contact_counter(counter):
    """JSON string from contact counter"""
    serializable = {}
    for contact, count in counter.items():
        # json doesn't know how to serialize np.int64 objects (which we
        # get in Python 3), hence the explicit conversion to int
        contact_key = json.dumps([int(idx) for idx in contact])
        serializable[contact_key] = count
    return json.dumps(serializable)

@staticmethod
def _deserialize_contact_counter(json_string):
    """Contact counter from JSON string"""
    decoded = json.loads(json_string)
    # each key is itself a JSON list of indices; turn it back into the
    # frozenset used as the counter key
    return collections.Counter(
        dict((frozenset(json.loads(key)), count)
             for key, count in decoded.items())
    )

def to_json(self):
    """JSON-serialized version of this object.

    This is the JSON encoding of the dict returned by ``to_dict``.

    See also
    --------
    from_json
    """
    return json.dumps(self.to_dict())

@classmethod
def from_json(cls, json_string):
    """Create object from JSON string

    The string is decoded with ``json.loads`` and the resulting dict is
    handed to ``from_dict``.

    Parameters
    ----------
    json_string : str
        JSON-serialized version of the object

    See also
    --------
    to_json
    """
    return cls.from_dict(json.loads(json_string))

def _check_compatibility(self, other):
assert self.cutoff == other.cutoff
assert self.topology == other.topology
Expand Down Expand Up @@ -517,14 +657,35 @@ class ContactMap(ContactObject):
"""
def __init__(self, frame, query=None, haystack=None, cutoff=0.45,
             n_neighbors_ignored=2):
    self._frame = frame  # TODO: remove this?
    super(ContactMap, self).__init__(frame.topology, query, haystack,
                                     cutoff, n_neighbors_ignored)
    # contact_map returns the pair (atom contacts, residue contacts)
    # NOTE(review): the literal 0 is presumably the frame number -- confirm
    # against contact_map's signature
    atom_contacts, residue_contacts = self.contact_map(
        frame, 0, self.residue_query_atom_idxs,
        self.residue_ignore_atom_idxs
    )
    self._atom_contacts = atom_contacts
    self._residue_contacts = residue_contacts

def __hash__(self):
    """Hash from the base-class hash plus the atom and residue contacts.

    The contact counters are hashed as frozensets of their
    (contact, count) items rather than as tuples. A tuple of items
    depends on the insertion order of the counter, so two maps holding
    the same contacts (which compare equal) could otherwise produce
    different hashes.
    """
    return hash((super(ContactMap, self).__hash__(),
                 frozenset(self._atom_contacts.items()),
                 frozenset(self._residue_contacts.items())))

def __eq__(self, other):
    """Equal when the base-class comparison and both contact counters match."""
    base_equal = super(ContactMap, self).__eq__(other)
    if not base_equal:
        return base_equal
    if not self._atom_contacts == other._atom_contacts:
        return False
    return self._residue_contacts == other._residue_contacts

#def to_dict(self):
#dct = super(ContactMap, self).to_dict()
#atom_cntcts = self._serialize_contact_counter(self._atom_contacts)
#res_cntcts = self._serialize_contact_counter(self._residue_contacts)
#dct.update({
#'atom_contacts': atom_cntcts,
#'residue_contacts': res_cntcts
#})
#return dct


class ContactFrequency(ContactObject):
"""
Expand Down Expand Up @@ -560,7 +721,9 @@ def __init__(self, trajectory, query=None, haystack=None, cutoff=0.45,
super(ContactFrequency, self).__init__(trajectory.topology,
query, haystack, cutoff,
n_neighbors_ignored)
self._build_contact_map(trajectory)
contacts = self._build_contact_map(trajectory)
(self._atom_contacts, self._residue_contacts) = contacts


def _build_contact_map(self, trajectory):
# We actually build the contact map on a per-residue basis, although
Expand All @@ -569,8 +732,8 @@ def _build_contact_map(self, trajectory):
# TODO: this whole thing should be cleaned up and should replace
# MDTraj's really slow old compute_contacts by using MDTraj's new
# neighborlists (unless the MDTraj people do that first).
self._atom_contacts_count = collections.Counter([])
self._residue_contacts_count = collections.Counter([])
atom_contacts_count = collections.Counter([])
residue_contacts_count = collections.Counter([])

# cache things that can be calculated once based on the topology
# (namely, which atom indices matter for each residue)
Expand All @@ -583,8 +746,10 @@ def _build_contact_map(self, trajectory):
frame_atom_contacts = frame_contacts[0]
frame_residue_contacts = frame_contacts[1]
# self._atom_contacts_count += frame_atom_contacts
self._atom_contacts_count.update(frame_atom_contacts)
self._residue_contacts_count += frame_residue_contacts
atom_contacts_count.update(frame_atom_contacts)
residue_contacts_count += frame_residue_contacts

return (atom_contacts_count, residue_contacts_count)

@property
def n_frames(self):
Expand All @@ -601,8 +766,8 @@ def add_contact_frequency(self, other):
contact frequency
"""
self._check_compatibility(other)
self._atom_contacts_count += other._atom_contacts_count
self._residue_contacts_count += other._residue_contacts_count
self._atom_contacts += other._atom_contacts
self._residue_contacts += other._residue_contacts
self._n_frames += other._n_frames


Expand All @@ -621,8 +786,8 @@ def subtract_contact_frequency(self, other):
contact frequency
"""
self._check_compatibility(other)
self._atom_contacts_count -= other._atom_contacts_count
self._residue_contacts_count -= other._residue_contacts_count
self._atom_contacts -= other._atom_contacts
self._residue_contacts -= other._residue_contacts
self._n_frames -= other._n_frames

@property
Expand All @@ -632,7 +797,7 @@ def atom_contacts(self):
n_y = self.topology.n_atoms
return ContactCount(collections.Counter({
item[0]: float(item[1])/self.n_frames
for item in self._atom_contacts_count.items()
for item in self._atom_contacts.items()
}), self.topology.atom, n_x, n_y)

@property
Expand All @@ -642,7 +807,7 @@ def residue_contacts(self):
n_y = self.topology.n_residues
return ContactCount(collections.Counter({
item[0]: float(item[1])/self.n_frames
for item in self._residue_contacts_count.items()
for item in self._residue_contacts.items()
}), self.topology.residue, n_x, n_y)


Expand All @@ -665,6 +830,53 @@ def __init__(self, positive, negative):
positive.cutoff,
positive.n_neighbors_ignored)

def to_dict(self):
    """Convert object to a dict.

    Keys should be strings; values should be (JSON-) serializable. The
    positive and negative parts are stored as their JSON strings, along
    with the names of their classes.

    See also
    --------
    from_dict
    """
    dct = {}
    dct['positive'] = self.positive.to_json()
    dct['negative'] = self.negative.to_json()
    # class names are recorded alongside the JSON of each part
    dct['positive_cls'] = self.positive.__class__.__name__
    dct['negative_cls'] = self.negative.__class__.__name__
    return dct

@classmethod
def from_dict(cls, dct):
    """Create object from dict.

    Parameters
    ----------
    dct : dict
        dict-formatted serialization (see to_dict for details)

    Raises
    ------
    RuntimeError
        if a stored class name is not one of the supported classes

    See also
    --------
    to_dict
    """
    # TODO: add searching for subclasses (http://code.activestate.com/recipes/576949-find-all-subclasses-of-a-given-class/)
    supported_classes_dict = {}
    for class_ in (ContactMap, ContactFrequency):
        supported_classes_dict[class_.__name__] = class_

    # rebuild the positive part first, then the negative part
    rebuilt = []
    for pos_neg in ('positive', 'negative'):
        class_name = dct[pos_neg + "_cls"]
        try:
            cls_ = supported_classes_dict[class_name]
        except KeyError:  # pragma: no cover
            raise RuntimeError("Can't rebuild class " + class_name)
        rebuilt.append(cls_.from_json(dct[pos_neg]))

    positive, negative = rebuilt
    return cls(positive, negative)

def __sub__(self, other):
    """Subtraction is not supported; always raises NotImplementedError."""
    raise NotImplementedError()

Expand Down
13 changes: 13 additions & 0 deletions contact_map/py_2_3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import inspect

# Use inspect.getfullargspec when the interpreter provides it; otherwise
# fall back to inspect.getargspec.
getargspec = getattr(inspect, 'getfullargspec', None)
if getargspec is None:
    getargspec = inspect.getargspec


def inspect_method_arguments(method, no_self=True):
    """Return the names of the named positional arguments of ``method``.

    Parameters
    ----------
    method : callable
        function or method to inspect
    no_self : bool
        if True (default), any argument named ``self`` is left out

    Returns
    -------
    list of str
        argument names in the order they are declared
    """
    arg_names = getargspec(method).args
    if not no_self:
        return arg_names
    return [name for name in arg_names if name != 'self']

Loading