Skip to content

Commit

Permalink
Merge pull request #87 from hanjinliu/rma-no-hash
Browse files Browse the repository at this point in the history
Use `Array2` instead of `HashMap` to speed up RMA
  • Loading branch information
hanjinliu authored Jun 29, 2024
2 parents 5db4eaf + 678d82a commit 22f5d1b
Show file tree
Hide file tree
Showing 6 changed files with 136 additions and 28 deletions.
73 changes: 45 additions & 28 deletions src/annealing/graph.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use core::slice::Iter;
use std::{sync::Arc, collections::HashMap};
use std::sync::Arc;
use numpy::{
ndarray::{Array1, Array2, Array, s, ArcArray, ArcArray2}, Ix3, Ix4
};
Expand All @@ -9,6 +9,7 @@ use crate::{
value_error,
coordinates::{Vector3D, CoordinateSystem, list_neighbors},
cylindric::Index,
hash_2d::HashMap2D,
annealing::{
potential::{TrapezoidalPotential2D, BindingPotential2D, EdgeType},
random::RandomNumberGenerator,
Expand Down Expand Up @@ -111,8 +112,8 @@ pub struct NodeState {
#[derive(Clone)]
pub struct CylindricGraph {
components: GraphComponents<NodeState, EdgeType>,
coords: Arc<HashMap<Index, CoordinateSystem<f32>>>,
energy: Arc<HashMap<Index, Array<f32, Ix3>>>,
coords: Arc<HashMap2D<CoordinateSystem<f32>>>,
energy: Arc<HashMap2D<Array<f32, Ix3>>>,
pub binding_potential: TrapezoidalPotential2D,
pub local_shape: Vector3D<isize>,
}
Expand All @@ -122,8 +123,8 @@ impl CylindricGraph {
pub fn empty() -> Self {
Self {
components: GraphComponents::empty(),
coords: Arc::new(HashMap::new()),
energy: Arc::new(HashMap::new()),
coords: Arc::new(HashMap2D::new()),
energy: Arc::new(HashMap2D::new()),
binding_potential: TrapezoidalPotential2D::unbounded(),
local_shape: Vector3D::new(0, 0, 0),
}
Expand All @@ -137,18 +138,17 @@ impl CylindricGraph {
nrise: isize,
) -> PyResult<&Self> {
self.components.clear();

let mut index_to_id: HashMap<Index, usize> = HashMap::new();
let ny = indices.len() / na as usize;
let mut index_to_id: HashMap2D<usize> = HashMap2D::from_shape(ny, na as usize);
for i in 0..indices.len() {
let idx = indices[i].clone();
index_to_id.insert(idx.clone(), i);
index_to_id.insert(idx.as_tuple_usize(), i);
self.components.add_node(NodeState { index: idx, shift: Vector3D::new(0, 0, 0) });
}

for (idx, i) in index_to_id.iter() {
let neighbors = idx.get_neighbors(na, nrise);
let neighbors = Index::new(idx.0 as isize, idx.1 as isize).get_neighbors(na, nrise);
for neighbor in neighbors.y_iter() {
match index_to_id.get(&neighbor) {
match index_to_id.get((neighbor.y, neighbor.a)) {
Some(j) => {
if i < j {
self.components.add_edge(*i, *j, EdgeType::Longitudinal);
Expand All @@ -158,7 +158,7 @@ impl CylindricGraph {
}
}
for neighbor in neighbors.a_iter() {
match index_to_id.get(&neighbor) {
match index_to_id.get((neighbor.y, neighbor.a)) {
Some(j) => {
if i < j {
self.components.add_edge(*i, *j, EdgeType::Lateral);
Expand Down Expand Up @@ -188,13 +188,13 @@ impl CylindricGraph {
} else if xvec.shape() != [n_nodes, 3] {
return value_error!("xvec has wrong shape");
}
let mut _coords: HashMap<Index, CoordinateSystem<f32>> = HashMap::new();

let (ny, na) = self.outer_shape();
let mut _coords: HashMap2D<CoordinateSystem<f32>> = HashMap2D::from_shape(ny, na);
for i in 0..n_nodes {
let node = self.components.node_state(i);
let idx = node.index.clone();

_coords.insert(
idx,
node.index.as_tuple_usize(),
CoordinateSystem::new(
origin.slice(s![i, ..]).into(),
zvec.slice(s![i, ..]).into(),
Expand All @@ -220,11 +220,12 @@ impl CylindricGraph {
let (_nz, _ny, _nx) = (shape[1], shape[2], shape[3]);
self.local_shape = Vector3D::new(_nz, _ny, _nx).into();
let center: Vector3D<isize> = Vector3D::new(_nz / 2, _ny / 2, _nx / 2).into();
let mut _energy: HashMap<Index, Array<f32, Ix3>> = HashMap::new();
let (ny_out, na_out) = self.outer_shape();
let mut _energy: HashMap2D<Array<f32, Ix3>> = HashMap2D::from_shape(ny_out, na_out);
for i in 0..n_nodes {
let node_state = self.components.node_state(i);
let idx = &node_state.index;
_energy.insert(idx.clone(), energy.slice(s![i, .., .., ..]).to_owned());
_energy.insert(idx.as_tuple_usize(), energy.slice(s![i, .., .., ..]).to_owned());
self.components.set_node_state(i, NodeState { index: idx.clone(), shift: center.clone() })
}
self.energy = Arc::new(_energy);
Expand All @@ -247,7 +248,7 @@ impl CylindricGraph {
pub fn internal(&self, node_state: &NodeState) -> f32 {
let idx = &node_state.index;
let vec = node_state.shift;
self.energy[&idx][[vec.z as usize, vec.y as usize, vec.x as usize]]
self.energy[(idx.y, idx.a)][[vec.z as usize, vec.y as usize, vec.x as usize]]
}

/// Calculate the binding energy between two nodes.
Expand All @@ -258,8 +259,8 @@ impl CylindricGraph {
pub fn binding(&self, node_state0: &NodeState, node_state1: &NodeState, typ: &EdgeType) -> f32 {
let vec1 = node_state0.shift;
let vec2 = node_state1.shift;
let coord1 = &self.coords[&node_state0.index];
let coord2 = &self.coords[&node_state1.index];
let coord1 = &self.coords[(node_state0.index.y, node_state0.index.a)];
let coord2 = &self.coords[(node_state1.index.y, node_state1.index.a)];
let dr = coord1.at_vec(vec1.into()) - coord2.at_vec(vec2.into());
// ey is required for the angle constraint.
let ey = coord2.origin - coord1.origin;
Expand Down Expand Up @@ -320,6 +321,22 @@ impl CylindricGraph {
Ok(self)
}

/// If the graph is a cylinder with (ny, na) nodes, return (ny, na).
fn outer_shape(&self) -> (usize, usize) {
let mut ny = 0;
let mut na = 0;
for node in self.components.node_states.iter() {
let idx = &node.index;
if idx.y > ny {
ny = idx.y;
}
if idx.a > na {
na = idx.a;
}
}
(ny as usize + 1, na as usize + 1)
}

fn get_distances(&self, typ: &EdgeType) -> Array1<f32> {
if self.coords.len() == 0 {
panic!("Coordinates not set.")
Expand All @@ -334,8 +351,8 @@ impl CylindricGraph {
let pos0 = graph.node_state(edge.0);
let pos1 = graph.node_state(edge.1);

let coord0 = &self.coords[&pos0.index];
let coord1 = &self.coords[&pos1.index];
let coord0 = &self.coords[(pos0.index.y, pos0.index.a)];
let coord1 = &self.coords[(pos1.index.y, pos1.index.a)];
let dr = coord0.at_vec(pos0.shift.into()) - coord1.at_vec(pos1.shift.into());
distances.push(dr.length())
}
Expand Down Expand Up @@ -371,9 +388,9 @@ impl CylindricGraph {
let pos_l = graph.node_state(neighbors[0]);
let pos_r = graph.node_state(neighbors[1]);

let coord_c = &self.coords[&pos_c.index];
let coord_l = &self.coords[&pos_l.index];
let coord_r = &self.coords[&pos_r.index];
let coord_c = &self.coords[(pos_c.index.y, pos_c.index.a)];
let coord_l = &self.coords[(pos_l.index.y, pos_l.index.a)];
let coord_r = &self.coords[(pos_r.index.y, pos_r.index.a)];

let dr_l = coord_c.at_vec(pos_c.shift.into()) - coord_l.at_vec(pos_l.shift.into());
let dr_r = coord_c.at_vec(pos_c.shift.into()) - coord_r.at_vec(pos_r.shift.into());
Expand Down Expand Up @@ -532,8 +549,8 @@ impl CylindricGraph {
let ends = self.components.edge_end(i);
let node0 = self.components.node_state(ends.0);
let node1 = self.components.node_state(ends.1);
let coord0 = self.coords[&node0.index].at_vec(node0.shift.into());
let coord1 = self.coords[&node1.index].at_vec(node1.shift.into());
let coord0 = self.coords[(node0.index.y, node0.index.a)].at_vec(node0.shift.into());
let coord1 = self.coords[(node1.index.y, node1.index.a)].at_vec(node1.shift.into());
out0[[i, 0]] = coord0.z;
out0[[i, 1]] = coord0.y;
out0[[i, 2]] = coord0.x;
Expand Down
11 changes: 11 additions & 0 deletions src/coordinates/coordinate_system.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,17 @@ pub struct CoordinateSystem<T> {
pub ex: Vector3D<T>,
}

impl<T: Default> Default for CoordinateSystem<T> {
fn default() -> Self {
Self {
origin: Vector3D::default(),
ez: Vector3D::default(),
ey: Vector3D::default(),
ex: Vector3D::default(),
}
}
}

impl<T> CoordinateSystem<T> {
pub fn new(origin: Vector3D<T>, ez: Vector3D<T>, ey: Vector3D<T>, ex: Vector3D<T>) -> Self {
Self { origin, ez, ey, ex }
Expand Down
10 changes: 10 additions & 0 deletions src/coordinates/vector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,16 @@ impl<T> Vector3D<T> {
}
}

impl<T: Default> Default for Vector3D<T> {
fn default() -> Self {
Vector3D {
z: T::default(),
y: T::default(),
x: T::default(),
}
}
}

/////////////////////////////////////////////////////////////////////////////////////////
/////////// Casting /////////////////////////////////////////////////////////////////
/////////////////////////////////////////////////////////////////////////////////////////
Expand Down
4 changes: 4 additions & 0 deletions src/cylindric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ impl Index {
}

impl Index {
pub fn as_tuple_usize(&self) -> (usize, usize) {
(self.y as usize, self.a as usize)
}

pub fn get_neighbors(&self, na: isize, nrise: isize) -> Neighbors {
let y_fw = Index::new(self.y + 1, self.a);
let y_bw = Index::new(self.y - 1, self.a);
Expand Down
65 changes: 65 additions & 0 deletions src/hash_2d.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
use std::ops;
use numpy::ndarray::Array2;

pub struct HashMap2D<V> {
arrays: Array2<Option<V>>,
len: usize,
}

impl<V> HashMap2D<V> {
pub fn new() -> Self {
Self {
arrays: Array2::default((0, 0)),
len: 0,
}
}

pub fn from_shape(n0: usize, n1: usize) -> Self {
Self {
arrays: Array2::default((n0, n1)),
len: 0,
}
}

pub fn insert(&mut self, index: (usize, usize), value: V) {
self.arrays[index] = Some(value);
self.len += 1;
}

pub fn get(&self, index: (isize, isize)) -> &Option<V> {
let n0 = self.arrays.shape()[0] as isize;
let n1 = self.arrays.shape()[1] as isize;
if index.0 < 0 || index.1 < 0 || index.0 >= n0 || index.1 >= n1 {
return &None;
}
&self.arrays[[index.0 as usize, index.1 as usize]]
}

pub fn iter(&self) -> impl Iterator<Item=((usize, usize), &V)> {
self.arrays.indexed_iter().filter_map(|(index, value)| {
value.as_ref().map(|v| (index, v))
})
}

pub fn len(&self) -> usize {
self.len
}

pub fn shape(&self) -> (usize, usize) {
(self.arrays.shape()[0], self.arrays.shape()[1])
}
}

impl<V> ops::Index<(isize, isize)> for HashMap2D<V> {
type Output = V;

fn index(&self, index: (isize, isize)) -> &Self::Output {
self.get(index).as_ref().unwrap()
}
}

impl<V> ops::IndexMut<(isize, isize)> for HashMap2D<V> {
fn index_mut(&mut self, index: (isize, isize)) -> &mut Self::Output {
self.arrays[[index.0 as usize, index.1 as usize]].as_mut().unwrap()
}
}
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ pub mod annealing;
pub mod filters;
pub mod exceptions;
pub mod regionprops;
pub mod hash_2d;

// Python module
#[pymodule]
Expand Down

0 comments on commit 22f5d1b

Please sign in to comment.