Skip to content

Commit

Permalink
Add support for HOOMD 3 (#200)
Browse files Browse the repository at this point in the history
  • Loading branch information
pabloferz authored Sep 15, 2022
2 parents 9ec9b3a + c9841ac commit add0da4
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 19 deletions.
4 changes: 2 additions & 2 deletions pysages/backends/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,14 @@ def __enter__(self):
"""
Trampoline 'with statements' to the wrapped context when the backend supports it.
"""
if self.get_backend_name() == "hoomd":
if hasattr(self.context, "__enter__"):
return self.context.__enter__()

def __exit__(self, exc_type, exc_value, exc_traceback):
"""
Trampoline 'with statements' to the wrapped context when the backend supports it.
"""
if self.get_backend_name() == "hoomd":
if hasattr(self.context, "__exit__"):
return self.context.__exit__(exc_type, exc_value, exc_traceback)


Expand Down
79 changes: 62 additions & 17 deletions pysages/backends/hoomd.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from jax import jit, numpy as np
from jax.dlpack import from_dlpack as asarray
from hoomd import md
from hoomd.dlext import (
AccessLocation,
AccessMode,
Expand Down Expand Up @@ -41,15 +42,57 @@
CONTEXTS_SAMPLERS = {}


class Sampler(DLExtSampler):
if getattr(hoomd, "__version__", "").startswith("2."):
SamplerBase = DLExtSampler

def is_on_gpu(context):
return context.on_gpu()

def get_integrator(context):
return context.integrator

def get_run_method(context):
return hoomd.run

def get_system(context):
return context.system

def set_half_step_hook(context, half_step_hook):
context.integrator.cpp_integrator.setHalfStepHook(half_step_hook)

def remove_half_step_hook(context):
context.integrator.cpp_integrator.removeHalfStepHook()

else:

class SamplerBase(DLExtSampler, md.HalfStepHook):
def __init__(self, sysview, update, location, mode):
md.HalfStepHook.__init__(self)
DLExtSampler.__init__(self, sysview, update, location, mode)

def is_on_gpu(context):
return not isinstance(context.device, hoomd.device.CPU)

def get_integrator(context):
return context.operations.integrator

def get_run_method(context):
context.run(0) # ensure that the context is properly initialized
return context.run

def get_system(context):
return context._cpp_sys

def set_half_step_hook(context, half_step_hook):
context.operations.integrator.half_step_hook = half_step_hook

def remove_half_step_hook(context):
context.operations.integrator.half_step_hook = None


class Sampler(SamplerBase):
def __init__(self, sysview, method_bundle, bias, callback: Callable, restore):
initial_snapshot, initialize, method_update = method_bundle
self.state = initialize()
self.callback = callback
self.bias = bias
self.box = initial_snapshot.box
self.dt = initial_snapshot.dt
self._restore = restore

def update(positions, vel_mass, rtags, images, forces, timestep):
snapshot = self._pack_snapshot(positions, vel_mass, forces, rtags, images)
Expand All @@ -59,6 +102,12 @@ def update(positions, vel_mass, rtags, images, forces, timestep):
self.callback(snapshot, self.state, timestep)

super().__init__(sysview, update, default_location(), AccessMode.Read)
self.state = initialize()
self.callback = callback
self.bias = bias
self.box = initial_snapshot.box
self.dt = initial_snapshot.dt
self._restore = restore

def restore(self, prev_snapshot):
def restore_callback(positions, vel_mass, rtags, images, forces, n):
Expand Down Expand Up @@ -100,10 +149,6 @@ def default_location():
return AccessLocation.OnHost


def is_on_gpu(context):
return context.on_gpu()


def take_snapshot(wrapped_context, location=default_location()):
context = wrapped_context.context
sysview = wrapped_context.view
Expand All @@ -113,15 +158,15 @@ def take_snapshot(wrapped_context, location=default_location()):
ids = copy(asarray(rtags(sysview, location, AccessMode.Read)))
imgs = copy(asarray(images(sysview, location, AccessMode.Read)))

box = sysview.particle_data().getGlobalBox()
box = sysview.particle_data.getGlobalBox()
L = box.getL()
xy = box.getTiltFactorXY()
xz = box.getTiltFactorXZ()
yz = box.getTiltFactorYZ()
lo = box.getLo()
H = ((L.x, xy * L.y, xz * L.z), (0.0, L.y, yz * L.z), (0.0, 0.0, L.z))
origin = (lo.x, lo.y, lo.z)
dt = context.integrator.dt
dt = get_integrator(context).dt

return Snapshot(positions, vel_mass, forces, ids, imgs, Box(H, origin), dt)

Expand Down Expand Up @@ -197,9 +242,9 @@ def bind(
wrapped_context: ContextWrapper, sampling_method: SamplingMethod, callback: Callable, **kwargs
):
context = wrapped_context.context
sysview = SystemView(context.system_definition)
sysview = SystemView(get_system(context))
wrapped_context.view = sysview
wrapped_context.run = hoomd.run
wrapped_context.run = get_run_method(context)
helpers, restore, bias = build_helpers(context, sampling_method)

with sysview:
Expand All @@ -208,7 +253,7 @@ def bind(
method_bundle = sampling_method.build(snapshot, helpers)
sync_and_bias = partial(bias, sync_backend=sysview.synchronize)
sampler = Sampler(sysview, method_bundle, sync_and_bias, callback, restore)
context.integrator.cpp_integrator.setHalfStepHook(sampler)
set_half_step_hook(context, sampler)

CONTEXTS_SAMPLERS[context] = sampler

Expand All @@ -221,7 +266,7 @@ def detach(context):
`Sampler` object.
"""
if context in CONTEXTS_SAMPLERS:
context.integrator.cpp_integrator.removeHalfStepHook()
remove_half_step_hook(context)
del CONTEXTS_SAMPLERS[context]
else:
warn("This context has no sampler bound to it.")

0 comments on commit add0da4

Please sign in to comment.