From 567e4a38d454d7af332144ef3635f8e647290aa8 Mon Sep 17 00:00:00 2001 From: Marcel Langer Date: Mon, 7 Aug 2023 11:07:24 +0200 Subject: [PATCH] Remove jax-md and replace with glp --- mlff/mdx/atoms.py | 12 ++++++------ mlff/mdx/calculator.py | 15 +++++++-------- 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/mlff/mdx/atoms.py b/mlff/mdx/atoms.py index 88b81fc..a4b97cd 100644 --- a/mlff/mdx/atoms.py +++ b/mlff/mdx/atoms.py @@ -11,8 +11,6 @@ from mlff.utils import Graph, Neighbors, System -from jax_md.space import periodic_general, free - SpatialPartitioning = namedtuple( "SpatialPartitioning", ("allocate_fn", "update_fn", "cutoff", "skin", "capacity_multiplier") @@ -512,7 +510,9 @@ def _update(x: AtomsX, neighbors): def to_displacement(atoms): - if atoms.get_cell() is not None: - return periodic_general(atoms.get_cell(), fractional_coordinates=False)[0] - else: - return free()[0] + from glp.periodic import make_displacement + + displacement = make_displacement(atoms.cell) + + # reverse sign convention for backwards compatibility + return lambda Ra, Rb: raw_disp(Rb, Ra) diff --git a/mlff/mdx/calculator.py b/mlff/mdx/calculator.py index 0f8ecee..6da8aeb 100644 --- a/mlff/mdx/calculator.py +++ b/mlff/mdx/calculator.py @@ -6,11 +6,8 @@ from flax import struct -from jax_md.space import periodic_general, free - from mlff.utils import Graph - logging.basicConfig(level=logging.INFO) StackNet = Any @@ -136,8 +133,10 @@ def atomsx_to_graph(atoms: Any): return Graph(edges, nodes, neighbors['idx_j'], neighbors['idx_i'], mask) -def to_displacement(atoms: Any): - if atoms.get_cell() is not None: - return periodic_general(atoms.get_cell(), fractional_coordinates=False)[0] - else: - return free()[0] +def to_displacement(atoms): + from glp.periodic import make_displacement + + displacement = make_displacement(atoms.cell) + + # reverse sign convention for backwards compatibility + return lambda Ra, Rb: raw_disp(Rb, Ra)