diff --git a/src/annealing/graph.rs b/src/annealing/graph.rs index 1dd0018a..57536e9f 100644 --- a/src/annealing/graph.rs +++ b/src/annealing/graph.rs @@ -1,5 +1,5 @@ use core::slice::Iter; -use std::{sync::Arc, collections::HashMap}; +use std::sync::Arc; use numpy::{ ndarray::{Array1, Array2, Array, s, ArcArray, ArcArray2}, Ix3, Ix4 }; @@ -9,6 +9,7 @@ use crate::{ value_error, coordinates::{Vector3D, CoordinateSystem, list_neighbors}, cylindric::Index, + hash_2d::HashMap2D, annealing::{ potential::{TrapezoidalPotential2D, BindingPotential2D, EdgeType}, random::RandomNumberGenerator, @@ -111,8 +112,8 @@ pub struct NodeState { #[derive(Clone)] pub struct CylindricGraph { components: GraphComponents, - coords: Arc>>, - energy: Arc>>, + coords: Arc>>, + energy: Arc>>, pub binding_potential: TrapezoidalPotential2D, pub local_shape: Vector3D, } @@ -122,8 +123,8 @@ impl CylindricGraph { pub fn empty() -> Self { Self { components: GraphComponents::empty(), - coords: Arc::new(HashMap::new()), - energy: Arc::new(HashMap::new()), + coords: Arc::new(HashMap2D::new()), + energy: Arc::new(HashMap2D::new()), binding_potential: TrapezoidalPotential2D::unbounded(), local_shape: Vector3D::new(0, 0, 0), } @@ -137,18 +138,17 @@ impl CylindricGraph { nrise: isize, ) -> PyResult<&Self> { self.components.clear(); - - let mut index_to_id: HashMap = HashMap::new(); + let ny = indices.len() / na as usize; + let mut index_to_id: HashMap2D = HashMap2D::from_shape(ny, na as usize); for i in 0..indices.len() { let idx = indices[i].clone(); - index_to_id.insert(idx.clone(), i); + index_to_id.insert(idx.as_tuple_usize(), i); self.components.add_node(NodeState { index: idx, shift: Vector3D::new(0, 0, 0) }); } - for (idx, i) in index_to_id.iter() { - let neighbors = idx.get_neighbors(na, nrise); + let neighbors = Index::new(idx.0 as isize, idx.1 as isize).get_neighbors(na, nrise); for neighbor in neighbors.y_iter() { - match index_to_id.get(&neighbor) { + match index_to_id.get((neighbor.y, neighbor.a)) { Some(j) => { if i < j { self.components.add_edge(*i, *j, EdgeType::Longitudinal); @@ -158,7 +158,7 @@ impl CylindricGraph { } } for neighbor in neighbors.a_iter() { - match index_to_id.get(&neighbor) { + match index_to_id.get((neighbor.y, neighbor.a)) { Some(j) => { if i < j { self.components.add_edge(*i, *j, EdgeType::Lateral); @@ -188,13 +188,13 @@ impl CylindricGraph { } else if xvec.shape() != [n_nodes, 3] { return value_error!("xvec has wrong shape"); } - let mut _coords: HashMap> = HashMap::new(); + + let (ny, na) = self.outer_shape(); + let mut _coords: HashMap2D> = HashMap2D::from_shape(ny, na); for i in 0..n_nodes { let node = self.components.node_state(i); - let idx = node.index.clone(); - _coords.insert( - idx, + node.index.as_tuple_usize(), CoordinateSystem::new( origin.slice(s![i, ..]).into(), zvec.slice(s![i, ..]).into(), @@ -220,11 +220,12 @@ impl CylindricGraph { let (_nz, _ny, _nx) = (shape[1], shape[2], shape[3]); self.local_shape = Vector3D::new(_nz, _ny, _nx).into(); let center: Vector3D = Vector3D::new(_nz / 2, _ny / 2, _nx / 2).into(); - let mut _energy: HashMap> = HashMap::new(); + let (ny_out, na_out) = self.outer_shape(); + let mut _energy: HashMap2D> = HashMap2D::from_shape(ny_out, na_out); for i in 0..n_nodes { let node_state = self.components.node_state(i); let idx = &node_state.index; - _energy.insert(idx.clone(), energy.slice(s![i, .., .., ..]).to_owned()); + _energy.insert(idx.as_tuple_usize(), energy.slice(s![i, .., .., ..]).to_owned()); self.components.set_node_state(i, NodeState { index: idx.clone(), shift: center.clone() }) } self.energy = Arc::new(_energy); @@ -247,7 +248,7 @@ impl CylindricGraph { pub fn internal(&self, node_state: &NodeState) -> f32 { let idx = &node_state.index; let vec = node_state.shift; - self.energy[&idx][[vec.z as usize, vec.y as usize, vec.x as usize]] + self.energy[(idx.y, idx.a)][[vec.z as usize, vec.y as usize, vec.x as usize]] } /// Calculate the binding energy between two nodes. @@ -258,8 +259,8 @@ impl CylindricGraph { pub fn binding(&self, node_state0: &NodeState, node_state1: &NodeState, typ: &EdgeType) -> f32 { let vec1 = node_state0.shift; let vec2 = node_state1.shift; - let coord1 = &self.coords[&node_state0.index]; - let coord2 = &self.coords[&node_state1.index]; + let coord1 = &self.coords[(node_state0.index.y, node_state0.index.a)]; + let coord2 = &self.coords[(node_state1.index.y, node_state1.index.a)]; let dr = coord1.at_vec(vec1.into()) - coord2.at_vec(vec2.into()); // ey is required for the angle constraint. let ey = coord2.origin - coord1.origin; @@ -320,6 +321,22 @@ impl CylindricGraph { Ok(self) } + /// If the graph is a cylinder with (ny, na) nodes, return (ny, na). + fn outer_shape(&self) -> (usize, usize) { + let mut ny = 0; + let mut na = 0; + for node in self.components.node_states.iter() { + let idx = &node.index; + if idx.y > ny { + ny = idx.y; + } + if idx.a > na { + na = idx.a; + } + } + (ny as usize + 1, na as usize + 1) + } + fn get_distances(&self, typ: &EdgeType) -> Array1 { if self.coords.len() == 0 { panic!("Coordinates not set.") @@ -334,8 +351,8 @@ impl CylindricGraph { let pos0 = graph.node_state(edge.0); let pos1 = graph.node_state(edge.1); - let coord0 = &self.coords[&pos0.index]; - let coord1 = &self.coords[&pos1.index]; + let coord0 = &self.coords[(pos0.index.y, pos0.index.a)]; + let coord1 = &self.coords[(pos1.index.y, pos1.index.a)]; let dr = coord0.at_vec(pos0.shift.into()) - coord1.at_vec(pos1.shift.into()); distances.push(dr.length()) } @@ -371,9 +388,9 @@ impl CylindricGraph { let pos_l = graph.node_state(neighbors[0]); let pos_r = graph.node_state(neighbors[1]); - let coord_c = &self.coords[&pos_c.index]; - let coord_l = &self.coords[&pos_l.index]; - let coord_r = &self.coords[&pos_r.index]; + let coord_c = &self.coords[(pos_c.index.y, pos_c.index.a)]; + let coord_l = &self.coords[(pos_l.index.y, pos_l.index.a)]; + let coord_r = &self.coords[(pos_r.index.y, pos_r.index.a)]; let dr_l = coord_c.at_vec(pos_c.shift.into()) - coord_l.at_vec(pos_l.shift.into()); let dr_r = coord_c.at_vec(pos_c.shift.into()) - coord_r.at_vec(pos_r.shift.into()); @@ -532,8 +549,8 @@ impl CylindricGraph { let ends = self.components.edge_end(i); let node0 = self.components.node_state(ends.0); let node1 = self.components.node_state(ends.1); - let coord0 = self.coords[&node0.index].at_vec(node0.shift.into()); - let coord1 = self.coords[&node1.index].at_vec(node1.shift.into()); + let coord0 = self.coords[(node0.index.y, node0.index.a)].at_vec(node0.shift.into()); + let coord1 = self.coords[(node1.index.y, node1.index.a)].at_vec(node1.shift.into()); out0[[i, 0]] = coord0.z; out0[[i, 1]] = coord0.y; out0[[i, 2]] = coord0.x; diff --git a/src/coordinates/coordinate_system.rs b/src/coordinates/coordinate_system.rs index 43c37519..1562a544 100644 --- a/src/coordinates/coordinate_system.rs +++ b/src/coordinates/coordinate_system.rs @@ -14,6 +14,17 @@ pub struct CoordinateSystem { pub ex: Vector3D, } +impl Default for CoordinateSystem { + fn default() -> Self { + Self { + origin: Vector3D::default(), + ez: Vector3D::default(), + ey: Vector3D::default(), + ex: Vector3D::default(), + } + } +} + impl CoordinateSystem { pub fn new(origin: Vector3D, ez: Vector3D, ey: Vector3D, ex: Vector3D) -> Self { Self { origin, ez, ey, ex } diff --git a/src/coordinates/vector.rs b/src/coordinates/vector.rs index f2d8f538..b0d4cdf9 100644 --- a/src/coordinates/vector.rs +++ b/src/coordinates/vector.rs @@ -15,6 +15,16 @@ impl Vector3D { } } +impl Default for Vector3D { + fn default() -> Self { + Vector3D { + z: T::default(), + y: T::default(), + x: T::default(), + } + } +} + ///////////////////////////////////////////////////////////////////////////////////////// /////////// Casting ///////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////////////// diff --git a/src/cylindric.rs b/src/cylindric.rs index 35b2383b..08cbee80 100644 --- a/src/cylindric.rs +++ b/src/cylindric.rs @@ -35,6 +35,10 @@ impl Index { } impl Index { + pub fn as_tuple_usize(&self) -> (usize, usize) { + (self.y as usize, self.a as usize) + } + pub fn get_neighbors(&self, na: isize, nrise: isize) -> Neighbors { let y_fw = Index::new(self.y + 1, self.a); let y_bw = Index::new(self.y - 1, self.a); diff --git a/src/hash_2d.rs b/src/hash_2d.rs new file mode 100644 index 00000000..759b669c --- /dev/null +++ b/src/hash_2d.rs @@ -0,0 +1,65 @@ +use std::ops; +use numpy::ndarray::Array2; + +pub struct HashMap2D { + arrays: Array2>, + len: usize, +} + +impl HashMap2D { + pub fn new() -> Self { + Self { + arrays: Array2::default((0, 0)), + len: 0, + } + } + + pub fn from_shape(n0: usize, n1: usize) -> Self { + Self { + arrays: Array2::default((n0, n1)), + len: 0, + } + } + + pub fn insert(&mut self, index: (usize, usize), value: V) { + self.arrays[index] = Some(value); + self.len += 1; + } + + pub fn get(&self, index: (isize, isize)) -> &Option { + let n0 = self.arrays.shape()[0] as isize; + let n1 = self.arrays.shape()[1] as isize; + if index.0 < 0 || index.1 < 0 || index.0 >= n0 || index.1 >= n1 { + return &None; + } + &self.arrays[[index.0 as usize, index.1 as usize]] + } + + pub fn iter(&self) -> impl Iterator { + self.arrays.indexed_iter().filter_map(|(index, value)| { + value.as_ref().map(|v| (index, v)) + }) + } + + pub fn len(&self) -> usize { + self.len + } + + pub fn shape(&self) -> (usize, usize) { + (self.arrays.shape()[0], self.arrays.shape()[1]) + } +} + +impl ops::Index<(isize, isize)> for HashMap2D { + type Output = V; + + fn index(&self, index: (isize, isize)) -> &Self::Output { + self.get(index).as_ref().unwrap() + } +} + +impl ops::IndexMut<(isize, isize)> for HashMap2D { + fn index_mut(&mut self, index: (isize, isize)) -> &mut Self::Output { + self.arrays[[index.0 as usize, index.1 as usize]].as_mut().unwrap() + } +} diff --git a/src/lib.rs b/src/lib.rs index 72a72dff..b61adfec 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,6 +8,7 @@ pub mod annealing; pub mod filters; pub mod exceptions; pub mod regionprops; +pub mod hash_2d; // Python module #[pymodule]