Skip to content

Commit

Permalink
Merge pull request #34 from CosmWasm/typed-bounds
Browse files Browse the repository at this point in the history
Provide typed bounds for iteration
  • Loading branch information
uint committed May 7, 2024
2 parents 201e7f5 + 35c7e50 commit 43f68fa
Show file tree
Hide file tree
Showing 5 changed files with 151 additions and 59 deletions.
58 changes: 51 additions & 7 deletions packages/storey/src/containers/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::encoding::{DecodableWith, EncodableWith};
use crate::storage::{IterableStorage, StorageBranch};
use crate::storage::{Storage, StorageMut};

use super::{IterableAccessor, Storable};
use super::{BoundFor, BoundedIterableAccessor, IterableAccessor, Storable};

const META_LAST_IX: &[u8] = &[0];
const META_LEN: &[u8] = &[1];
Expand Down Expand Up @@ -135,6 +135,20 @@ where
}
}

impl<E, T, S> BoundedIterableAccessor for ColumnAccess<E, T, S>
where
E: Encoding,
T: EncodableWith<E> + DecodableWith<E>,
S: IterableStorage,
{
}

impl<T, E> BoundFor<Column<T, E>> for u32 {
fn into_bytes(self) -> Vec<u8> {
self.to_be_bytes().to_vec()
}
}

impl<E, T, S> ColumnAccess<E, T, S>
where
E: Encoding,
Expand Down Expand Up @@ -428,28 +442,58 @@ mod tests {
access.push(&9001).unwrap();
access.remove(1).unwrap();

assert_eq!(
access.pairs().collect::<Result<Vec<_>, _>>().unwrap(),
vec![(0, 1337), (2, 9001)]
);

assert_eq!(
access.keys().collect::<Result<Vec<_>, _>>().unwrap(),
vec![0, 2]
);

assert_eq!(
access.values().collect::<Result<Vec<_>, _>>().unwrap(),
vec![1337, 9001]
);
}

#[test]
fn bounded_iteration() {
let mut storage = TestStorage::new();

let column = Column::<u64, TestEncoding>::new(0);
let mut access = column.access(&mut storage);

access.push(&1337).unwrap();
access.push(&42).unwrap();
access.push(&9001).unwrap();
access.push(&1).unwrap();
access.push(&2).unwrap();
access.remove(2).unwrap();

assert_eq!(
access
.pairs(None, None)
.bounded_pairs(Some(1), Some(4))
.collect::<Result<Vec<_>, _>>()
.unwrap(),
vec![(0, 1337), (2, 9001)]
vec![(1, 42), (3, 1)]
);

assert_eq!(
access
.keys(None, None)
.bounded_keys(Some(1), Some(4))
.collect::<Result<Vec<_>, _>>()
.unwrap(),
vec![0, 2]
vec![1, 3]
);

assert_eq!(
access
.values(None, None)
.bounded_values(Some(1), Some(4))
.collect::<Result<Vec<_>, _>>()
.unwrap(),
vec![1337, 9001]
vec![42, 1]
);
}
}
45 changes: 21 additions & 24 deletions packages/storey/src/containers/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,12 +188,7 @@ where
K: Borrow<Q>,
Q: Key + ?Sized,
{
let len = key.bytes().len();
let bytes = key.bytes();
let mut key = Vec::with_capacity(len + 1);

key.push(len as u8);
key.extend_from_slice(bytes);
let key = length_prefixed_key(key);

V::access_impl(StorageBranch::new(&self.storage, key))
}
Expand Down Expand Up @@ -243,6 +238,17 @@ where
}
}

fn length_prefixed_key<K: Key + ?Sized>(key: &K) -> Vec<u8> {
let len = key.bytes().len();
let bytes = key.bytes();
let mut key = Vec::with_capacity(len + 1);

key.push(len as u8);
key.extend_from_slice(bytes);

key
}

impl<K, V, S> IterableAccessor for MapAccess<K, V, S>
where
K: OwnedKey,
Expand Down Expand Up @@ -276,6 +282,12 @@ impl Key for String {
}
}

impl Key for str {
fn bytes(&self) -> &[u8] {
self.as_bytes()
}
}

#[derive(Debug, PartialEq, Eq, Clone, Copy, thiserror::Error)]
#[error("invalid UTF8")]
pub struct InvalidUtf8;
Expand All @@ -293,12 +305,6 @@ impl OwnedKey for String {
}
}

impl Key for str {
fn bytes(&self) -> &[u8] {
self.as_bytes()
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -338,10 +344,7 @@ mod tests {
access.entry_mut("foo").set(&1337).unwrap();
access.entry_mut("bar").set(&42).unwrap();

let items = access
.pairs(None, None)
.collect::<Result<Vec<_>, _>>()
.unwrap();
let items = access.pairs().collect::<Result<Vec<_>, _>>().unwrap();
assert_eq!(
items,
vec![
Expand All @@ -361,10 +364,7 @@ mod tests {
access.entry_mut("foo").set(&1337).unwrap();
access.entry_mut("bar").set(&42).unwrap();

let keys = access
.keys(None, None)
.collect::<Result<Vec<_>, _>>()
.unwrap();
let keys = access.keys().collect::<Result<Vec<_>, _>>().unwrap();
assert_eq!(keys, vec![("bar".to_string(), ()), ("foo".to_string(), ())])
}

Expand All @@ -378,10 +378,7 @@ mod tests {
access.entry_mut("foo").set(&1337).unwrap();
access.entry_mut("bar").set(&42).unwrap();

let values = access
.values(None, None)
.collect::<Result<Vec<_>, _>>()
.unwrap();
let values = access.values().collect::<Result<Vec<_>, _>>().unwrap();
assert_eq!(values, vec![42, 1337])
}
}
95 changes: 76 additions & 19 deletions packages/storey/src/containers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ pub enum KVDecodeError<K, V> {

/// A trait for collection accessors (see [`Storable::AccessorT`]) that provide iteration over
/// their contents.
pub trait IterableAccessor {
pub trait IterableAccessor: Sized {
/// The [`Storable`] type this accessor is associated with.
type StorableT: Storable;

Expand All @@ -81,42 +81,99 @@ pub trait IterableAccessor {
fn storage(&self) -> &Self::StorageT;

/// Iterate over key-value pairs in this collection.
fn pairs<'s>(
&'s self,
start: Option<&[u8]>,
end: Option<&[u8]>,
) -> StorableIter<'s, Self::StorableT, Self::StorageT> {
fn pairs(&self) -> StorableIter<'_, Self::StorableT, Self::StorageT> {
StorableIter {
inner: self.storage().pairs(start, end),
inner: self.storage().pairs(None, None),
phantom: PhantomData,
}
}

/// Iterate over keys in this collection.
fn keys<'s>(
&'s self,
start: Option<&[u8]>,
end: Option<&[u8]>,
) -> StorableKeys<'s, Self::StorableT, Self::StorageT> {
fn keys(&self) -> StorableKeys<'_, Self::StorableT, Self::StorageT> {
StorableKeys {
inner: self.storage().keys(start, end),
inner: self.storage().keys(None, None),
phantom: PhantomData,
}
}

/// Iterate over values in this collection.
fn values<'s>(
&'s self,
start: Option<&[u8]>,
end: Option<&[u8]>,
) -> StorableValues<'s, Self::StorableT, Self::StorageT> {
fn values(&self) -> StorableValues<'_, Self::StorableT, Self::StorageT> {
StorableValues {
inner: self.storage().values(start, end),
inner: self.storage().values(None, None),
phantom: PhantomData,
}
}
}

pub trait BoundedIterableAccessor: IterableAccessor {
/// Iterate over key-value pairs in this collection, respecting the given bounds.
fn bounded_pairs<S, E>(
&self,
start: Option<S>,
end: Option<E>,
) -> StorableIter<'_, Self::StorableT, Self::StorageT>
where
S: BoundFor<Self::StorableT>,
E: BoundFor<Self::StorableT>,
{
let start = start.map(|b| b.into_bytes());
let end = end.map(|b| b.into_bytes());

StorableIter {
inner: self.storage().pairs(start.as_deref(), end.as_deref()),
phantom: PhantomData,
}
}

/// Iterate over keys in this collection, respecting the given bounds.
fn bounded_keys<S, E>(
&self,
start: Option<S>,
end: Option<E>,
) -> StorableKeys<'_, Self::StorableT, Self::StorageT>
where
S: BoundFor<Self::StorableT>,
E: BoundFor<Self::StorableT>,
{
let start = start.map(|b| b.into_bytes());
let end = end.map(|b| b.into_bytes());

StorableKeys {
inner: self.storage().keys(start.as_deref(), end.as_deref()),
phantom: PhantomData,
}
}

/// Iterate over values in this collection, respecting the given bounds.
fn bounded_values<S, E>(
&self,
start: Option<S>,
end: Option<E>,
) -> StorableValues<'_, Self::StorableT, Self::StorageT>
where
S: BoundFor<Self::StorableT>,
E: BoundFor<Self::StorableT>,
{
let start = start.map(|b| b.into_bytes());
let end = end.map(|b| b.into_bytes());

StorableValues {
inner: self.storage().values(start.as_deref(), end.as_deref()),
phantom: PhantomData,
}
}
}

/// A type that can be used as bounds for iteration over a given collection.
///
/// As an example, a collection `Foo` with string-y keys can accept both `String` and
/// `&str` bounds by providing these impls:
/// - `impl BoundFor<Foo> for &str`
/// - `impl BoundFor<Foo> for String`
pub trait BoundFor<T> {
fn into_bytes(self) -> Vec<u8>;
}

/// The iterator over key-value pairs in a collection.
pub struct StorableIter<'i, S, B>
where
Expand Down
5 changes: 1 addition & 4 deletions packages/storey/tests/composition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,7 @@ fn map_of_column() {
assert_eq!(access.entry("bar").get(0).unwrap(), Some(9001));
assert_eq!(access.entry("bar").len().unwrap(), 1);

let all = access
.pairs(None, None)
.collect::<Result<Vec<_>, _>>()
.unwrap();
let all = access.pairs().collect::<Result<Vec<_>, _>>().unwrap();
assert_eq!(
all,
vec![
Expand Down
7 changes: 2 additions & 5 deletions packages/storey/tests/iteration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,7 @@ fn map_of_map_iteration() {
.unwrap();

// iterate over all items
let items = access
.pairs(None, None)
.collect::<Result<Vec<_>, _>>()
.unwrap();
let items = access.pairs().collect::<Result<Vec<_>, _>>().unwrap();
assert_eq!(
items,
vec![
Expand All @@ -36,7 +33,7 @@ fn map_of_map_iteration() {
// iterate over items under "foo"
let items = access
.entry("foo")
.pairs(None, None)
.pairs()
.collect::<Result<Vec<_>, _>>()
.unwrap();
assert_eq!(
Expand Down

0 comments on commit 43f68fa

Please sign in to comment.