Skip to content

Commit

Permalink
Change signature of get_many_mut APIs
Browse files Browse the repository at this point in the history
by 1) panicking on overlapping keys and 2) returning an array of Option
rather than an Option of array.

And address a potential future UB by only using raw pointers
rust-lang/unsafe-code-guidelines#346
  • Loading branch information
Urgau committed Sep 30, 2024
1 parent 7cf51ea commit e78ea1f
Show file tree
Hide file tree
Showing 3 changed files with 156 additions and 86 deletions.
150 changes: 98 additions & 52 deletions src/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1467,8 +1467,11 @@ where
/// Attempts to get mutable references to `N` values in the map at once.
///
/// Returns an array of length `N` with the results of each query. For soundness, at most one
/// mutable reference will be returned to any value. `None` will be returned if any of the
/// keys are duplicates or missing.
/// mutable reference will be returned to any value. `None` will be used if the key is missing.
///
/// # Panics
///
/// Panics if any keys are overlapping.
///
/// # Examples
///
Expand All @@ -1481,33 +1484,46 @@ where
/// libraries.insert("Herzogin-Anna-Amalia-Bibliothek".to_string(), 1691);
/// libraries.insert("Library of Congress".to_string(), 1800);
///
/// // Get Athenæum and Bodleian Library
/// let [Some(a), Some(b)] = libraries.get_many_mut([
/// "Athenæum",
/// "Bodleian Library",
/// ]) else { panic!() };
///
/// // Assert values of Athenæum and Library of Congress
/// let got = libraries.get_many_mut([
/// "Athenæum",
/// "Library of Congress",
/// ]);
/// assert_eq!(
/// got,
/// Some([
/// &mut 1807,
/// &mut 1800,
/// ]),
/// [
/// Some(&mut 1807),
/// Some(&mut 1800),
/// ],
/// );
///
/// // Missing keys result in None
/// let got = libraries.get_many_mut([
/// "Athenæum",
/// "Athenæum1",
/// "New York Public Library",
/// ]);
/// assert_eq!(got, None);
/// assert_eq!(got, [None, None]);
/// ```
///
/// ```should_panic
/// use hashbrown::HashMap;
///
/// let mut libraries = HashMap::new();
/// libraries.insert("Athenæum".to_string(), 1807);
///
/// // Duplicate keys result in None
/// // Duplicate keys panic!
/// let got = libraries.get_many_mut([
/// "Athenæum",
/// "Athenæum",
/// ]);
/// assert_eq!(got, None);
/// ```
pub fn get_many_mut<Q, const N: usize>(&mut self, ks: [&Q; N]) -> Option<[&'_ mut V; N]>
pub fn get_many_mut<Q, const N: usize>(&mut self, ks: [&Q; N]) -> [Option<&'_ mut V>; N]
where
Q: Hash + Equivalent<K> + ?Sized,
{
Expand All @@ -1517,8 +1533,8 @@ where
/// Attempts to get mutable references to `N` values in the map at once, without validating that
/// the values are unique.
///
/// Returns an array of length `N` with the results of each query. `None` will be returned if
/// any of the keys are missing.
/// Returns an array of length `N` with the results of each query. `None` will be used if
/// the key is missing.
///
/// For a safe alternative see [`get_many_mut`](`HashMap::get_many_mut`).
///
Expand All @@ -1540,29 +1556,37 @@ where
/// libraries.insert("Herzogin-Anna-Amalia-Bibliothek".to_string(), 1691);
/// libraries.insert("Library of Congress".to_string(), 1800);
///
/// let got = libraries.get_many_mut([
/// // SAFETY: The keys do not overlap.
/// let [Some(a), Some(b)] = (unsafe { libraries.get_many_unchecked_mut([
/// "Athenæum",
/// "Bodleian Library",
/// ]) }) else { panic!() };
///
/// // SAFETY: The keys do not overlap.
/// let got = unsafe { libraries.get_many_unchecked_mut([
/// "Athenæum",
/// "Library of Congress",
/// ]);
/// ]) };
/// assert_eq!(
/// got,
/// Some([
/// &mut 1807,
/// &mut 1800,
/// ]),
/// [
/// Some(&mut 1807),
/// Some(&mut 1800),
/// ],
/// );
///
/// // Missing keys result in None
/// let got = libraries.get_many_mut([
/// // SAFETY: The keys do not overlap.
/// let got = unsafe { libraries.get_many_unchecked_mut([
/// "Athenæum",
/// "New York Public Library",
/// ]);
/// assert_eq!(got, None);
/// ]) };
/// // Missing keys result in None
/// assert_eq!(got, [Some(&mut 1807), None]);
/// ```
pub unsafe fn get_many_unchecked_mut<Q, const N: usize>(
&mut self,
ks: [&Q; N],
) -> Option<[&'_ mut V; N]>
) -> [Option<&'_ mut V>; N]
where
Q: Hash + Equivalent<K> + ?Sized,
{
Expand All @@ -1574,8 +1598,11 @@ where
/// references to the corresponding keys.
///
/// Returns an array of length `N` with the results of each query. For soundness, at most one
/// mutable reference will be returned to any value. `None` will be returned if any of the keys
/// are duplicates or missing.
/// mutable reference will be returned to any value. `None` will be used if the key is missing.
///
/// # Panics
///
/// Panics if any keys are overlapping.
///
/// # Examples
///
Expand All @@ -1594,30 +1621,37 @@ where
/// ]);
/// assert_eq!(
/// got,
/// Some([
/// (&"Bodleian Library".to_string(), &mut 1602),
/// (&"Herzogin-Anna-Amalia-Bibliothek".to_string(), &mut 1691),
/// ]),
/// [
/// Some((&"Bodleian Library".to_string(), &mut 1602)),
/// Some((&"Herzogin-Anna-Amalia-Bibliothek".to_string(), &mut 1691)),
/// ],
/// );
/// // Missing keys result in None
/// let got = libraries.get_many_key_value_mut([
/// "Bodleian Library",
/// "Gewandhaus",
/// ]);
/// assert_eq!(got, None);
/// assert_eq!(got, [Some((&"Bodleian Library".to_string(), &mut 1602)), None]);
/// ```
///
/// // Duplicate keys result in None
/// ```should_panic
/// use hashbrown::HashMap;
///
/// let mut libraries = HashMap::new();
/// libraries.insert("Bodleian Library".to_string(), 1602);
/// libraries.insert("Herzogin-Anna-Amalia-Bibliothek".to_string(), 1691);
///
/// // Duplicate keys result in panic!
/// let got = libraries.get_many_key_value_mut([
/// "Bodleian Library",
/// "Herzogin-Anna-Amalia-Bibliothek",
/// "Herzogin-Anna-Amalia-Bibliothek",
/// ]);
/// assert_eq!(got, None);
/// ```
pub fn get_many_key_value_mut<Q, const N: usize>(
&mut self,
ks: [&Q; N],
) -> Option<[(&'_ K, &'_ mut V); N]>
) -> [Option<(&'_ K, &'_ mut V)>; N]
where
Q: Hash + Equivalent<K> + ?Sized,
{
Expand Down Expand Up @@ -1657,30 +1691,36 @@ where
/// ]);
/// assert_eq!(
/// got,
/// Some([
/// (&"Bodleian Library".to_string(), &mut 1602),
/// (&"Herzogin-Anna-Amalia-Bibliothek".to_string(), &mut 1691),
/// ]),
/// [
/// Some((&"Bodleian Library".to_string(), &mut 1602)),
/// Some((&"Herzogin-Anna-Amalia-Bibliothek".to_string(), &mut 1691)),
/// ],
/// );
/// // Missing keys result in None
/// let got = libraries.get_many_key_value_mut([
/// "Bodleian Library",
/// "Gewandhaus",
/// ]);
/// assert_eq!(got, None);
/// assert_eq!(
/// got,
/// [
/// Some((&"Bodleian Library".to_string(), &mut 1602)),
/// None,
/// ],
/// );
/// ```
pub unsafe fn get_many_key_value_unchecked_mut<Q, const N: usize>(
&mut self,
ks: [&Q; N],
) -> Option<[(&'_ K, &'_ mut V); N]>
) -> [Option<(&'_ K, &'_ mut V)>; N]
where
Q: Hash + Equivalent<K> + ?Sized,
{
self.get_many_unchecked_mut_inner(ks)
.map(|res| res.map(|(k, v)| (&*k, v)))
}

fn get_many_mut_inner<Q, const N: usize>(&mut self, ks: [&Q; N]) -> Option<[&'_ mut (K, V); N]>
fn get_many_mut_inner<Q, const N: usize>(&mut self, ks: [&Q; N]) -> [Option<&'_ mut (K, V)>; N]
where
Q: Hash + Equivalent<K> + ?Sized,
{
Expand All @@ -1692,7 +1732,7 @@ where
unsafe fn get_many_unchecked_mut_inner<Q, const N: usize>(
&mut self,
ks: [&Q; N],
) -> Option<[&'_ mut (K, V); N]>
) -> [Option<&'_ mut (K, V)>; N]
where
Q: Hash + Equivalent<K> + ?Sized,
{
Expand Down Expand Up @@ -5937,33 +5977,39 @@ mod test_map {
}

#[test]
fn test_get_each_mut() {
fn test_get_many_mut() {
let mut map = HashMap::new();
map.insert("foo".to_owned(), 0);
map.insert("bar".to_owned(), 10);
map.insert("baz".to_owned(), 20);
map.insert("qux".to_owned(), 30);

let xs = map.get_many_mut(["foo", "qux"]);
assert_eq!(xs, Some([&mut 0, &mut 30]));
assert_eq!(xs, [Some(&mut 0), Some(&mut 30)]);

let xs = map.get_many_mut(["foo", "dud"]);
assert_eq!(xs, None);

let xs = map.get_many_mut(["foo", "foo"]);
assert_eq!(xs, None);
assert_eq!(xs, [Some(&mut 0), None]);

let ys = map.get_many_key_value_mut(["bar", "baz"]);
assert_eq!(
ys,
Some([(&"bar".to_owned(), &mut 10), (&"baz".to_owned(), &mut 20),]),
[
Some((&"bar".to_owned(), &mut 10)),
Some((&"baz".to_owned(), &mut 20))
],
);

let ys = map.get_many_key_value_mut(["bar", "dip"]);
assert_eq!(ys, None);
assert_eq!(ys, [Some((&"bar".to_string(), &mut 10)), None]);
}

#[test]
#[should_panic = "duplicate keys found"]
fn test_get_many_mut_duplicate() {
let mut map = HashMap::new();
map.insert("foo".to_owned(), 0);

let ys = map.get_many_key_value_mut(["baz", "baz"]);
assert_eq!(ys, None);
let _xs = map.get_many_mut(["foo", "foo"]);
}

#[test]
Expand Down
39 changes: 21 additions & 18 deletions src/raw/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ use crate::TryReserveError;
use core::iter::FusedIterator;
use core::marker::PhantomData;
use core::mem;
use core::mem::MaybeUninit;
use core::ptr::NonNull;
use core::{hint, ptr};

Expand Down Expand Up @@ -484,6 +483,13 @@ impl<T> Bucket<T> {
}
}

/// Acquires the underlying non-null pointer `*mut T` to `data`.
#[inline]
fn as_non_null(&self) -> NonNull<T> {
// SAFETY: `self.ptr` is already a `NonNull`
unsafe { NonNull::new_unchecked(self.as_ptr()) }
}

/// Create a new [`Bucket`] that is offset from the `self` by the given
/// `offset`. The pointer calculation is performed by calculating the
/// offset from `self` pointer (convenience for `self.ptr.as_ptr().sub(offset)`).
Expand Down Expand Up @@ -1291,48 +1297,45 @@ impl<T, A: Allocator> RawTable<T, A> {
&mut self,
hashes: [u64; N],
eq: impl FnMut(usize, &T) -> bool,
) -> Option<[&'_ mut T; N]> {
) -> [Option<&'_ mut T>; N] {
unsafe {
let ptrs = self.get_many_mut_pointers(hashes, eq)?;
let ptrs = self.get_many_mut_pointers(hashes, eq);

for (i, &cur) in ptrs.iter().enumerate() {
if ptrs[..i].iter().any(|&prev| ptr::eq::<T>(prev, cur)) {
return None;
for (i, cur) in ptrs.iter().enumerate() {
if cur.is_some() && ptrs[..i].contains(cur) {
panic!("duplicate keys found");
}
}
// All bucket are distinct from all previous buckets so we're clear to return the result
// of the lookup.

// TODO use `MaybeUninit::array_assume_init` here instead once that's stable.
Some(mem::transmute_copy(&ptrs))
ptrs.map(|ptr| ptr.map(|mut ptr| ptr.as_mut()))
}
}

pub unsafe fn get_many_unchecked_mut<const N: usize>(
&mut self,
hashes: [u64; N],
eq: impl FnMut(usize, &T) -> bool,
) -> Option<[&'_ mut T; N]> {
let ptrs = self.get_many_mut_pointers(hashes, eq)?;
Some(mem::transmute_copy(&ptrs))
) -> [Option<&'_ mut T>; N] {
let ptrs = self.get_many_mut_pointers(hashes, eq);
ptrs.map(|ptr| ptr.map(|mut ptr| ptr.as_mut()))
}

unsafe fn get_many_mut_pointers<const N: usize>(
&mut self,
hashes: [u64; N],
mut eq: impl FnMut(usize, &T) -> bool,
) -> Option<[*mut T; N]> {
// TODO use `MaybeUninit::uninit_array` here instead once that's stable.
let mut outs: MaybeUninit<[*mut T; N]> = MaybeUninit::uninit();
) -> [Option<NonNull<T>>; N] {
let mut outs: [Option<NonNull<T>>; N] = [None; N];
let outs_ptr = outs.as_mut_ptr();

for (i, &hash) in hashes.iter().enumerate() {
let cur = self.find(hash, |k| eq(i, k))?;
*(*outs_ptr).get_unchecked_mut(i) = cur.as_mut();
let cur = self.find(hash, |k| eq(i, k)).map(|cur| cur.as_non_null());
outs_ptr.add(i).write(cur);
}

// TODO use `MaybeUninit::array_assume_init` here instead once that's stable.
Some(outs.assume_init())
outs
}

/// Returns the number of elements the map can hold without reallocating.
Expand Down
Loading

0 comments on commit e78ea1f

Please sign in to comment.