Skip to content

Commit

Permalink
Polish PackedArray a bit
Browse files Browse the repository at this point in the history
  • Loading branch information
cschwan committed Oct 4, 2024
1 parent c196936 commit 9d63ba7
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 85 deletions.
6 changes: 3 additions & 3 deletions pineappl/src/lagrange_subgrid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ impl Subgrid for LagrangeSubgridV2 {
// we cannot use `Self::indexed_iter` because it multiplies with `reweight`
if let SubgridEnum::LagrangeSubgridV2(other) = other {
// TODO: make sure `other` has the same interpolation as `self`
for (mut index, value) in other.array.indexed_iter3() {
for (mut index, value) in other.array.indexed_iter() {
if let Some((a, b)) = transpose {
index.swap(a, b);
}
Expand All @@ -75,7 +75,7 @@ impl Subgrid for LagrangeSubgridV2 {
fn symmetrize(&mut self, a: usize, b: usize) {
let mut new_array = PackedArray::new(self.array.shape().to_vec());

for (mut index, sigma) in self.array.indexed_iter3() {
for (mut index, sigma) in self.array.indexed_iter() {
// TODO: why not the other way around?
if index[b] < index[a] {
index.swap(a, b);
Expand All @@ -90,7 +90,7 @@ impl Subgrid for LagrangeSubgridV2 {
fn indexed_iter(&self) -> SubgridIndexedIter {
let nodes: Vec<_> = self.interps.iter().map(Interp::node_values).collect();

Box::new(self.array.indexed_iter3().map(move |(indices, weight)| {
Box::new(self.array.indexed_iter().map(move |(indices, weight)| {
let reweight = self
.interps
.iter()
Expand Down
179 changes: 100 additions & 79 deletions pineappl/src/packed_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,22 +78,7 @@ impl<T: Copy + Default + PartialEq> PackedArray<T> {
}

/// TODO
pub fn indexed_iter2(&self) -> impl Iterator<Item = (usize, T)> + '_ {
self.start_indices
.iter()
.zip(&self.lengths)
.flat_map(|(&start_index, &length)| start_index..(start_index + length))
.zip(&self.entries)
.filter(|&(_, entry)| *entry != Default::default())
.map(|(indices, entry)| (indices, *entry))
}

/// Returns an `Iterator` over the non-default (non-zero) elements of this array. The type of
/// an iterator element is `([usize; D], T)` where the first element of the tuple is the index
/// and the second element is the value.
pub fn indexed_iter<const D: usize>(&self) -> impl Iterator<Item = ([usize; D], T)> + '_ {
assert_eq!(self.shape.len(), D);

pub fn indexed_iter(&self) -> impl Iterator<Item = (Vec<usize>, T)> + '_ {
self.start_indices
.iter()
.zip(&self.lengths)
Expand All @@ -105,19 +90,6 @@ impl<T: Copy + Default + PartialEq> PackedArray<T> {
.map(|(indices, entry)| (indices, *entry))
}

/// TODO
pub fn indexed_iter3(&self) -> impl Iterator<Item = (Vec<usize>, T)> + '_ {
self.start_indices
.iter()
.zip(&self.lengths)
.flat_map(|(&start_index, &length)| {
(start_index..(start_index + length)).map(|i| unravel_index2(i, &self.shape))
})
.zip(&self.entries)
.filter(|&(_, entry)| *entry != Default::default())
.map(|(indices, entry)| (indices, *entry))
}

/// TODO
// TODO: rewrite this method into `sub_block_iter_mut() -> impl Iterator<Item = &mut f64>`
pub fn sub_block_idx(
Expand Down Expand Up @@ -181,24 +153,8 @@ fn ravel_multi_index(multi_index: &[usize], shape: &[usize]) -> usize {
.fold(0, |acc, (i, d)| acc * d + i)
}

/// Converts a flat `index` into a `multi_index`.
///
/// # Panics
///
/// Panics when `index` is out of range.
#[must_use]
fn unravel_index<const D: usize>(mut index: usize, shape: &[usize]) -> [usize; D] {
assert!(index < shape.iter().product());
let mut indices = [0; D];
for (i, d) in indices.iter_mut().zip(shape).rev() {
*i = index % d;
index /= d;
}
indices
}

/// TODO
pub fn unravel_index2(mut index: usize, shape: &[usize]) -> Vec<usize> {
pub fn unravel_index(mut index: usize, shape: &[usize]) -> Vec<usize> {
assert!(index < shape.iter().product());
let mut indices = vec![0; shape.len()];
for (i, d) in indices.iter_mut().zip(shape).rev() {
Expand Down Expand Up @@ -255,13 +211,11 @@ impl<T: Copy + Default + PartialEq> Index<usize> for PackedArray<T> {
type Output = T;

fn index(&self, index: usize) -> &Self::Output {
// assert_eq!(index.len(), self.shape.len());
// assert!(
// index.iter().zip(self.shape.iter()).all(|(&i, &d)| i < d),
// "index {:?} is out of bounds for array of shape {:?}",
// index,
// self.shape
// );
assert!(
index < self.shape.iter().product(),
"index {index} is out of bounds for array of shape {:?}",
self.shape
);

let raveled_index = index;
// let raveled_index = ravel_multi_index(&index, &self.shape);
Expand Down Expand Up @@ -534,16 +488,6 @@ mod tests {
use ndarray::Array3;
use std::mem;

#[test]
fn unravel_index() {
assert_eq!(super::unravel_index(0, &[3, 2]), [0, 0]);
assert_eq!(super::unravel_index(1, &[3, 2]), [0, 1]);
assert_eq!(super::unravel_index(2, &[3, 2]), [1, 0]);
assert_eq!(super::unravel_index(3, &[3, 2]), [1, 1]);
assert_eq!(super::unravel_index(4, &[3, 2]), [2, 0]);
assert_eq!(super::unravel_index(5, &[3, 2]), [2, 1]);
}

#[test]
fn ravel_multi_index() {
assert_eq!(super::ravel_multi_index(&[0, 0], &[3, 2]), 0);
Expand Down Expand Up @@ -608,6 +552,61 @@ mod tests {
assert_eq!(a.lengths, vec![8]);
}

#[test]
fn flat_index() {
let shape = vec![4, 2];
let mut a = PackedArray::new(shape.clone());

a[[0, 0]] = 1;
assert_eq!(a[super::ravel_multi_index(&[0, 0], &shape)], 1);
assert_eq!(a.entries, vec![1]);
assert_eq!(a.start_indices, vec![0]);
assert_eq!(a.lengths, vec![1]);

a[[3, 0]] = 2;
assert_eq!(a[super::ravel_multi_index(&[0, 0], &shape)], 1);
assert_eq!(a[super::ravel_multi_index(&[3, 0], &shape)], 2);
assert_eq!(a.entries, vec![1, 2]);
assert_eq!(a.start_indices, vec![0, 6]);
assert_eq!(a.lengths, vec![1, 1]);

a[[3, 1]] = 3;
assert_eq!(a[super::ravel_multi_index(&[0, 0], &shape)], 1);
assert_eq!(a[super::ravel_multi_index(&[3, 0], &shape)], 2);
assert_eq!(a[super::ravel_multi_index(&[3, 1], &shape)], 3);
assert_eq!(a.entries, vec![1, 2, 3]);
assert_eq!(a.start_indices, vec![0, 6]);
assert_eq!(a.lengths, vec![1, 2]);

a[[2, 0]] = 9;
assert_eq!(a[super::ravel_multi_index(&[0, 0], &shape)], 1);
assert_eq!(a[super::ravel_multi_index(&[3, 0], &shape)], 2);
assert_eq!(a[super::ravel_multi_index(&[3, 1], &shape)], 3);
assert_eq!(a[super::ravel_multi_index(&[2, 0], &shape)], 9);
assert_eq!(a.entries, vec![1, 9, 0, 2, 3]);
assert_eq!(a.start_indices, vec![0, 4]);
assert_eq!(a.lengths, vec![1, 4]);

a[[2, 0]] = 4;
assert_eq!(a[super::ravel_multi_index(&[0, 0], &shape)], 1);
assert_eq!(a[super::ravel_multi_index(&[3, 0], &shape)], 2);
assert_eq!(a[super::ravel_multi_index(&[3, 1], &shape)], 3);
assert_eq!(a[super::ravel_multi_index(&[2, 0], &shape)], 4);
assert_eq!(a.entries, vec![1, 4, 0, 2, 3]);
assert_eq!(a.start_indices, vec![0, 4]);
assert_eq!(a.lengths, vec![1, 4]);

a[[1, 0]] = 5;
assert_eq!(a[super::ravel_multi_index(&[0, 0], &shape)], 1);
assert_eq!(a[super::ravel_multi_index(&[3, 0], &shape)], 2);
assert_eq!(a[super::ravel_multi_index(&[3, 1], &shape)], 3);
assert_eq!(a[super::ravel_multi_index(&[2, 0], &shape)], 4);
assert_eq!(a[super::ravel_multi_index(&[1, 0], &shape)], 5);
assert_eq!(a.entries, vec![1, 0, 5, 0, 4, 0, 2, 3]);
assert_eq!(a.start_indices, vec![0]);
assert_eq!(a.lengths, vec![8]);
}

#[test]
fn iter() {
let mut a = PackedArray::new(vec![6, 5]);
Expand All @@ -619,11 +618,11 @@ mod tests {
assert_eq!(
a.indexed_iter().collect::<Vec<_>>(),
&[
([2, 2], 1),
([2, 4], 2),
([4, 1], 3),
([4, 4], 4),
([5, 0], 5),
(vec![2, 2], 1),
(vec![2, 4], 2),
(vec![4, 1], 3),
(vec![4, 4], 4),
(vec![5, 0], 5),
]
);
}
Expand Down Expand Up @@ -846,6 +845,28 @@ mod tests {
assert_eq!(array[[0, 0, 2]], 0);
}

#[test]
#[should_panic(expected = "entry at index 0 is implicitly set to the default value")]
fn flat_index_panic_0() {
let shape = vec![40, 50, 50];
let mut array = PackedArray::new(shape.clone());

array[[1, 0, 0]] = 1;

let _ = array[super::ravel_multi_index(&[0, 0, 0], &shape)];
}

#[test]
#[should_panic(expected = "index 102550 is out of bounds for array of shape [40, 50, 50]")]
fn flat_index_panic_dim1() {
let shape = vec![40, 50, 50];
let mut array = PackedArray::new(shape.clone());

array[[1, 0, 0]] = 1;

let _ = array[super::ravel_multi_index(&[40, 50, 50], &shape)];
}

#[test]
fn indexed_iter() {
let mut array = PackedArray::new(vec![40, 50, 50]);
Expand All @@ -854,15 +875,15 @@ mod tests {
assert_eq!(array.shape(), [40, 50, 50]);

// check empty iterator
assert_eq!(array.indexed_iter::<3>().next(), None);
assert_eq!(array.indexed_iter().next(), None);

// insert an element
array[[2, 3, 4]] = 1;

let mut iter = array.indexed_iter();

// check iterator with one element
assert_eq!(iter.next(), Some(([2, 3, 4], 1)));
assert_eq!(iter.next(), Some((vec![2, 3, 4], 1)));
assert_eq!(iter.next(), None);

mem::drop(iter);
Expand All @@ -872,8 +893,8 @@ mod tests {

let mut iter = array.indexed_iter();

assert_eq!(iter.next(), Some(([2, 3, 4], 1)));
assert_eq!(iter.next(), Some(([2, 3, 6], 2)));
assert_eq!(iter.next(), Some((vec![2, 3, 4], 1)));
assert_eq!(iter.next(), Some((vec![2, 3, 6], 2)));
assert_eq!(iter.next(), None);

mem::drop(iter);
Expand All @@ -883,9 +904,9 @@ mod tests {

let mut iter = array.indexed_iter();

assert_eq!(iter.next(), Some(([2, 3, 4], 1)));
assert_eq!(iter.next(), Some(([2, 3, 6], 2)));
assert_eq!(iter.next(), Some(([4, 5, 7], 3)));
assert_eq!(iter.next(), Some((vec![2, 3, 4], 1)));
assert_eq!(iter.next(), Some((vec![2, 3, 6], 2)));
assert_eq!(iter.next(), Some((vec![4, 5, 7], 3)));
assert_eq!(iter.next(), None);

mem::drop(iter);
Expand All @@ -895,10 +916,10 @@ mod tests {

let mut iter = array.indexed_iter();

assert_eq!(iter.next(), Some(([2, 0, 0], 4)));
assert_eq!(iter.next(), Some(([2, 3, 4], 1)));
assert_eq!(iter.next(), Some(([2, 3, 6], 2)));
assert_eq!(iter.next(), Some(([4, 5, 7], 3)));
assert_eq!(iter.next(), Some((vec![2, 0, 0], 4)));
assert_eq!(iter.next(), Some((vec![2, 3, 4], 1)));
assert_eq!(iter.next(), Some((vec![2, 3, 6], 2)));
assert_eq!(iter.next(), Some((vec![4, 5, 7], 3)));
assert_eq!(iter.next(), None);
}

Expand Down
6 changes: 3 additions & 3 deletions pineappl/src/packed_subgrid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ impl Subgrid for PackedQ1X2SubgridV1 {

let mut array = PackedArray::new(new_node_values.iter().map(NodeValues::len).collect());

for (indices, value) in self.array.indexed_iter3() {
for (indices, value) in self.array.indexed_iter() {
let target: Vec<_> = izip!(indices, &new_node_values, &lhs_node_values)
.map(|(index, new, lhs)| new.find(lhs.get(index)).unwrap())
.collect();
Expand Down Expand Up @@ -82,7 +82,7 @@ impl Subgrid for PackedQ1X2SubgridV1 {
fn symmetrize(&mut self, a: usize, b: usize) {
let mut new_array = PackedArray::new(self.array.shape().to_vec());

for (mut index, sigma) in self.array.indexed_iter3() {
for (mut index, sigma) in self.array.indexed_iter() {
// TODO: why not the other way around?
if index[b] < index[a] {
index.swap(a, b);
Expand All @@ -95,7 +95,7 @@ impl Subgrid for PackedQ1X2SubgridV1 {
}

fn indexed_iter(&self) -> SubgridIndexedIter {
Box::new(self.array.indexed_iter3())
Box::new(self.array.indexed_iter())
}

fn stats(&self) -> Stats {
Expand Down

0 comments on commit 9d63ba7

Please sign in to comment.