diff --git a/src/dynsight/_internal/lens/lens.py b/src/dynsight/_internal/lens/lens.py index 81fee67..5bb011f 100644 --- a/src/dynsight/_internal/lens/lens.py +++ b/src/dynsight/_internal/lens/lens.py @@ -6,7 +6,7 @@ from MDAnalysis import AtomGroup, Universe import numpy as np -from MDAnalysis.lib.NeighborSearch import AtomNeighborSearch +from MDAnalysis.lib.nsgrid import FastNS def list_neighbours_along_trajectory( @@ -34,15 +34,22 @@ def list_neighbours_along_trajectory( if trajslice is None: trajslice = slice(None) neigh_list_per_frame = [] + for _ in input_universe.universe.trajectory[trajslice]: - neigh_search = AtomNeighborSearch( - input_universe.atoms, box=input_universe.dimensions - ) + atom_pos = input_universe.atoms.positions + box_dim = input_universe.dimensions + gridsearch = FastNS(cutoff, atom_pos, box=box_dim, pbc=True) + max_cutoff = gridsearch._prepare_box(box=box_dim, pbc=True) + if cutoff > max_cutoff: + cutoff = max_cutoff + fastns_results = gridsearch.self_search() + pairs = fastns_results.get_pairs() + neigh_list_per_atom = [[] for _ in range(len(input_universe.atoms))] + for x, y in pairs: + neigh_list_per_atom[x].append(y) + neigh_list_per_atom[y].append(x) + neigh_list_per_frame.append(neigh_list_per_atom) - neigh_list_per_atom = [ - neigh_search.search(atom, cutoff) for atom in input_universe.atoms - ] - neigh_list_per_frame.append([at.ix for at in neigh_list_per_atom]) return neigh_list_per_frame