Skip to content

Commit

Permalink
change ase to avoid observers
Browse files Browse the repository at this point in the history
  • Loading branch information
Ragzouken committed Sep 30, 2024
1 parent dbb3210 commit 5be80ce
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 27 deletions.
70 changes: 47 additions & 23 deletions python-libraries/nanover-omni/src/nanover/omni/ase.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from dataclasses import dataclass
from typing import Optional, Any, Callable
from typing import Optional, Any

import numpy as np
from ase.md import MDLogger
from ase.md.md import MolecularDynamics

from nanover.app import NanoverImdApplication
from nanover.ase import send_ase_frame
from nanover.ase.converter import EV_TO_KJMOL
from nanover.ase.converter import EV_TO_KJMOL, ase_to_frame_data
from nanover.ase.imd_calculator import ImdCalculator
from nanover.ase.wall_constraint import VelocityWallConstraint
from nanover.utilities.event import Event
Expand Down Expand Up @@ -76,7 +76,7 @@ def __init__(self, name: Optional[str] = None):
self.checkpoint: Optional[InitialState] = None

self.frame_method = send_ase_frame
self._frame_adapter: Optional[Callable] = None
self.frame_index = 0

def load(self):
"""
Expand Down Expand Up @@ -107,18 +107,21 @@ def reset(self, app_server: NanoverImdApplication):

self.app_server = app_server

# reset atoms to initial state
self.atoms.set_positions(self.checkpoint.positions)
self.atoms.set_velocities(self.checkpoint.velocities)
self.atoms.set_cell(self.checkpoint.cell)

self.atoms.calc = ImdCalculator(
self.app_server.imd,
self.atoms.calc,
dynamics=self.dynamics,
)

self._frame_adapter = self.frame_method(
self.atoms,
self.app_server.frame_publisher,
include_velocities=self.include_velocities,
include_forces=self.include_forces,
)
# send the initial topology frame
frame_data = self.make_topology_frame()
self.app_server.frame_publisher.send_frame(0, frame_data)
self.frame_index = 1

if self.verbose:
self.dynamics.attach(
Expand All @@ -133,11 +136,6 @@ def reset(self, app_server: NanoverImdApplication):
interval=100,
)

# reset atoms to initial state
self.atoms.set_positions(self.checkpoint.positions)
self.atoms.set_velocities(self.checkpoint.velocities)
self.atoms.set_cell(self.checkpoint.cell)

def advance_by_one_step(self):
"""
Advance the simulation to the next point a frame should be reported, and send that frame.
Expand All @@ -155,7 +153,7 @@ def advance_to_next_report(self):
"""
Step the simulation to the next point a frame should be reported, and send that frame.
"""
assert self.dynamics is not None
assert self.dynamics is not None and self.app_server is not None

# determine step count for next frame
steps_to_next_frame = (
Expand All @@ -166,8 +164,12 @@ def advance_to_next_report(self):
# advance the simulation
self.dynamics.run(steps_to_next_frame)

# call frame adapter to send frame
self._frame_adapter()
# generate the next frame
frame_data = self.make_regular_frame()

# send the next frame
self.app_server.frame_publisher.send_frame(self.frame_index, frame_data)
self.frame_index += 1

# check if excessive energy requires sim reset
if self.reset_energy is not None and self.app_server is not None:
Expand All @@ -176,10 +178,32 @@ def advance_to_next_report(self):
self.on_reset_energy_exceeded.invoke()
self.reset(self.app_server)

def make_topology_frame(self):
"""
Make a NanoVer FrameData corresponding to the current particle positions and topology of the simulation.
"""
assert self.atoms is not None

frame_data = ase_to_frame_data(
self.atoms,
topology=True,
include_velocities=self.include_velocities,
include_forces=self.include_forces,
)

return frame_data

def make_regular_frame(self):
"""
Make a NanoVer FrameData corresponding to the current state of the simulation.
"""
assert self.atoms is not None

frame_data = ase_to_frame_data(
self.atoms,
topology=False,
include_velocities=self.include_velocities,
include_forces=self.include_forces,
)

def remove_observer(dynamics: MolecularDynamics, func: Callable):
entry = next(entry for entry in dynamics.observers if entry[0] == func)
try:
dynamics.observers.remove(entry)
except StopIteration:
pass
return frame_data
11 changes: 7 additions & 4 deletions python-libraries/nanover-omni/src/nanover/omni/ase_omm.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import warnings
from os import PathLike
from pathlib import Path
from typing import Optional, Callable
from typing import Optional

import numpy as np
from ase import units, Atoms
Expand Down Expand Up @@ -85,7 +85,6 @@ def __init__(self, name: Optional[str] = None):
self.checkpoint: Optional[InitialState] = None

self.frame_index = 0
self._frame_adapter: Optional[Callable] = None

def load(self):
"""
Expand Down Expand Up @@ -223,7 +222,6 @@ def make_topology_frame(self):
imd_calculator = self.atoms.calc
topology = imd_calculator.calculator.topology
frame_data = openmm_to_frame_data(
state=None,
topology=topology,
include_velocities=self.include_velocities,
include_forces=self.include_forces,
Expand All @@ -238,6 +236,11 @@ def make_regular_frame(self):
"""
assert self.atoms is not None

frame_data = ase_to_frame_data(self.atoms, topology=False)
frame_data = ase_to_frame_data(
self.atoms,
topology=False,
include_velocities=self.include_velocities,
include_forces=self.include_forces,
)

return frame_data

0 comments on commit 5be80ce

Please sign in to comment.