From 056437572dcd155f80b2c0615c3d71830470589b Mon Sep 17 00:00:00 2001 From: Kenny Kerr Date: Sun, 14 Jan 2024 20:15:28 -0600 Subject: [PATCH] Harden generic type parameter binding (#2791) --- .../extensions/mod/Win32/Foundation/BOOL.rs | 2 +- .../mod/Win32/Foundation/BOOLEAN.rs | 2 +- crates/libs/core/src/param.rs | 28 +++++++++---------- crates/libs/core/src/strings/hstring.rs | 2 +- crates/libs/core/src/type.rs | 8 ++---- .../src/Windows/Win32/Foundation/mod.rs | 4 +-- crates/tests/implement/tests/vector.rs | 15 ++++++++++ 7 files changed, 35 insertions(+), 26 deletions(-) diff --git a/crates/libs/bindgen/src/rust/extensions/mod/Win32/Foundation/BOOL.rs b/crates/libs/bindgen/src/rust/extensions/mod/Win32/Foundation/BOOL.rs index 2499fa964d..1fd761fdf7 100644 --- a/crates/libs/bindgen/src/rust/extensions/mod/Win32/Foundation/BOOL.rs +++ b/crates/libs/bindgen/src/rust/extensions/mod/Win32/Foundation/BOOL.rs @@ -67,7 +67,7 @@ impl ::core::ops::Not for BOOL { } } impl ::windows_core::IntoParam for bool { - fn into_param(self) -> ::windows_core::Param { + unsafe fn into_param(self) -> ::windows_core::Param { ::windows_core::Param::Owned(self.into()) } } diff --git a/crates/libs/bindgen/src/rust/extensions/mod/Win32/Foundation/BOOLEAN.rs b/crates/libs/bindgen/src/rust/extensions/mod/Win32/Foundation/BOOLEAN.rs index 33b32a8d62..b07562a15b 100644 --- a/crates/libs/bindgen/src/rust/extensions/mod/Win32/Foundation/BOOLEAN.rs +++ b/crates/libs/bindgen/src/rust/extensions/mod/Win32/Foundation/BOOLEAN.rs @@ -67,7 +67,7 @@ impl ::core::ops::Not for BOOLEAN { } } impl ::windows_core::IntoParam for bool { - fn into_param(self) -> ::windows_core::Param { + unsafe fn into_param(self) -> ::windows_core::Param { ::windows_core::Param::Owned(self.into()) } } diff --git a/crates/libs/core/src/param.rs b/crates/libs/core/src/param.rs index d90acbe38a..ea5b400060 100644 --- a/crates/libs/core/src/param.rs +++ b/crates/libs/core/src/param.rs @@ -28,17 +28,17 @@ pub trait IntoParam::TypeKind>: Sized where T: Type, { - fn into_param(self) -> Param; + unsafe fn into_param(self) -> Param; } impl IntoParam for Option<&T> where T: Type, { - fn into_param(self) -> Param { + unsafe fn into_param(self) -> Param { Param::Borrowed(match self { - Some(item) => item.abi(), - None => unsafe { std::mem::zeroed() }, + Some(item) => std::mem::transmute_copy(item), + None => std::mem::zeroed(), }) } } @@ -50,13 +50,11 @@ where U: Interface, U: CanInto, { - fn into_param(self) -> Param { - unsafe { - if U::QUERY { - self.cast().map_or(Param::Borrowed(std::mem::zeroed()), |ok| Param::Owned(ok)) - } else { - Param::Borrowed(std::mem::transmute_copy(self)) - } + unsafe fn into_param(self) -> Param { + if U::QUERY { + self.cast().map_or(Param::Borrowed(std::mem::zeroed()), |ok| Param::Owned(ok)) + } else { + Param::Borrowed(std::mem::transmute_copy(self)) } } } @@ -65,8 +63,8 @@ impl IntoParam for &T where T: TypeKind + Clone, { - fn into_param(self) -> Param { - Param::Borrowed(self.abi()) + unsafe fn into_param(self) -> Param { + Param::Borrowed(std::mem::transmute_copy(self)) } } @@ -76,7 +74,7 @@ where U: TypeKind + Clone, U: CanInto, { - fn into_param(self) -> Param { - Param::Owned(unsafe { std::mem::transmute_copy(&self) }) + unsafe fn into_param(self) -> Param { + Param::Owned(std::mem::transmute_copy(&self)) } } diff --git a/crates/libs/core/src/strings/hstring.rs b/crates/libs/core/src/strings/hstring.rs index 5eac52fd33..3082189353 100644 --- a/crates/libs/core/src/strings/hstring.rs +++ b/crates/libs/core/src/strings/hstring.rs @@ -398,7 +398,7 @@ impl From for std::ffi::OsString { } impl IntoParam for &HSTRING { - fn into_param(self) -> Param { + unsafe fn into_param(self) -> Param { Param::Owned(PCWSTR(self.as_ptr())) } } diff --git a/crates/libs/core/src/type.rs b/crates/libs/core/src/type.rs index c0d5aa5d61..3f4cd06ac5 100644 --- a/crates/libs/core/src/type.rs +++ b/crates/libs/core/src/type.rs @@ -19,10 +19,6 @@ pub trait Type::TypeKind>: TypeKind + Sized + C type Abi; type Default; - fn abi(&self) -> Self::Abi { - unsafe { std::mem::transmute_copy(self) } - } - /// # Safety unsafe fn from_abi(abi: Self::Abi) -> Result; fn from_default(default: &Self::Default) -> Result; @@ -55,7 +51,7 @@ where type Abi = std::mem::MaybeUninit; type Default = Self; - unsafe fn from_abi(abi: std::mem::MaybeUninit) -> Result { + unsafe fn from_abi(abi: Self::Abi) -> Result { Ok(abi.assume_init()) } @@ -71,7 +67,7 @@ where type Abi = Self; type Default = Self; - unsafe fn from_abi(abi: Self) -> Result { + unsafe fn from_abi(abi: Self::Abi) -> Result { Ok(abi) } diff --git a/crates/libs/windows/src/Windows/Win32/Foundation/mod.rs b/crates/libs/windows/src/Windows/Win32/Foundation/mod.rs index edcbdf804e..fc54b6c8f9 100644 --- a/crates/libs/windows/src/Windows/Win32/Foundation/mod.rs +++ b/crates/libs/windows/src/Windows/Win32/Foundation/mod.rs @@ -11572,7 +11572,7 @@ impl ::core::ops::Not for BOOL { } } impl ::windows_core::IntoParam for bool { - fn into_param(self) -> ::windows_core::Param { + unsafe fn into_param(self) -> ::windows_core::Param { ::windows_core::Param::Owned(self.into()) } } @@ -11645,7 +11645,7 @@ impl ::core::ops::Not for BOOLEAN { } } impl ::windows_core::IntoParam for bool { - fn into_param(self) -> ::windows_core::Param { + unsafe fn into_param(self) -> ::windows_core::Param { ::windows_core::Param::Owned(self.into()) } } diff --git a/crates/tests/implement/tests/vector.rs b/crates/tests/implement/tests/vector.rs index 3a175b1b2e..c799d3e88c 100644 --- a/crates/tests/implement/tests/vector.rs +++ b/crates/tests/implement/tests/vector.rs @@ -284,9 +284,24 @@ fn test_2759() -> Result<()> { v.Append(&uri)?; let uri = Uri::CreateUri(h!("https://microsoft.com/"))?; v.Append(&uri)?; + v.Append(&uri.cast::()?)?; assert_eq!(&v.GetAt(0)?.ToString()?, h!("https://github.com/")); assert_eq!(&v.GetAt(1)?.ToString()?, h!("https://microsoft.com/")); Ok(()) } + +#[test] +fn test_into_param() -> Result<()> { + let v: IVector = Vector::new(vec![]).into(); + v.Append(1)?; + v.Append(Some(&2))?; + v.Append(None)?; + + assert_eq!(v.GetAt(0)?, 1); + assert_eq!(v.GetAt(1)?, 2); + assert_eq!(v.GetAt(2)?, 0); + + Ok(()) +}