Skip to content

Commit

Permalink
Adds ability to override GetTrustLevel return value (#2714)
Browse files Browse the repository at this point in the history
  • Loading branch information
kennykerr authored Nov 27, 2023
1 parent 0a64f0e commit 5c5c82a
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 7 deletions.
25 changes: 18 additions & 7 deletions crates/libs/core/src/inspectable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,15 @@ impl IInspectable {
Ok(std::mem::transmute(abi))
}
}

/// Gets the trust level of the current object.
pub fn GetTrustLevel(&self) -> Result<i32> {
unsafe {
let mut value = 0;
(self.vtable().GetTrustLevel)(std::mem::transmute_copy(self), &mut value).ok()?;
Ok(value)
}
}
}

#[doc(hidden)]
Expand Down Expand Up @@ -60,14 +69,16 @@ impl IInspectable_Vtbl {
*value = std::mem::transmute(h);
HRESULT(0)
}
unsafe extern "system" fn GetTrustLevel(_: *mut std::ffi::c_void, value: *mut i32) -> HRESULT {
// Note: even if we end up implementing this in future, it still doesn't need a this pointer
// since the data to be returned is type- not instance-specific so can be shared for all
// interfaces.
*value = 0;
HRESULT(0)
unsafe extern "system" fn GetTrustLevel<T: IUnknownImpl, const OFFSET: isize>(this: *mut std::ffi::c_void, value: *mut i32) -> HRESULT {
let this = (this as *mut *mut std::ffi::c_void).offset(OFFSET) as *mut T;
(*this).GetTrustLevel(value)
}
Self {
base: IUnknown_Vtbl::new::<Identity, OFFSET>(),
GetIids,
GetRuntimeClassName: GetRuntimeClassName::<Name>,
GetTrustLevel: GetTrustLevel::<Identity, OFFSET>,
}
Self { base: IUnknown_Vtbl::new::<Identity, OFFSET>(), GetIids, GetRuntimeClassName: GetRuntimeClassName::<Name>, GetTrustLevel }
}
}

Expand Down
5 changes: 5 additions & 0 deletions crates/libs/core/src/unknown.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,15 +72,20 @@ pub trait IUnknownImpl {
/// This function is safe to call as long as the interface pointer is non-null and valid for writes
/// of an interface pointer.
unsafe fn QueryInterface(&self, iid: *const GUID, interface: *mut *mut std::ffi::c_void) -> HRESULT;

/// Increments the reference count of the interface
fn AddRef(&self) -> u32;

/// Decrements the reference count causing the interface's memory to be freed when the count is 0
///
/// # Safety
///
/// This function should only be called when the interfacer pointer is no longer used as calling `Release`
/// on a non-aliased interface pointer and then using that interface pointer may result in use after free.
unsafe fn Release(&self) -> u32;

/// Gets the trust level of the current object.
unsafe fn GetTrustLevel(&self, value: *mut i32) -> HRESULT;
}

#[cfg(feature = "implement")]
Expand Down
25 changes: 25 additions & 0 deletions crates/libs/implement/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ pub fn implement(attributes: proc_macro::TokenStream, original_type: proc_macro:
}
});

let trust_level = proc_macro2::Literal::usize_unsuffixed(attributes.trust_level);

let conversions = attributes.implement.iter().enumerate().map(|(enumerate, implement)| {
let interface_ident = implement.to_ident();
let offset = proc_macro2::Literal::usize_unsuffixed(enumerate);
Expand Down Expand Up @@ -162,6 +164,13 @@ pub fn implement(attributes: proc_macro::TokenStream, original_type: proc_macro:
}
remaining
}
unsafe fn GetTrustLevel(&self, value: *mut i32) -> ::windows::core::HRESULT {
if value.is_null() {
return ::windows::core::HRESULT(-2147467261); // E_POINTER
}
*value = #trust_level;
::windows::core::HRESULT(0)
}
}
impl #generics #original_ident::#generics where #constraints {
/// Try casting as the provided interface
Expand Down Expand Up @@ -225,6 +234,7 @@ impl ImplementType {
#[derive(Default)]
struct ImplementAttributes {
pub implement: Vec<ImplementType>,
pub trust_level: usize,
}

impl syn::parse::Parse for ImplementAttributes {
Expand Down Expand Up @@ -269,6 +279,7 @@ impl ImplementAttributes {
self.walk_implement(tree, namespace)?;
}
}
UseTree2::TrustLevel(input) => self.trust_level = *input,
}

Ok(())
Expand All @@ -279,6 +290,7 @@ enum UseTree2 {
Path(UsePath2),
Name(UseName2),
Group(UseGroup2),
TrustLevel(usize),
}

impl UseTree2 {
Expand Down Expand Up @@ -308,6 +320,7 @@ impl UseTree2 {
Ok(ImplementType { type_name, generics })
}
UseTree2::Group(input) => Err(syn::parse::Error::new(input.brace_token.span.join(), "Syntax not supported")),
_ => unimplemented!(),
}
}
}
Expand Down Expand Up @@ -336,6 +349,18 @@ impl syn::parse::Parse for UseTree2 {
if input.peek(syn::Token![::]) {
input.parse::<syn::Token![::]>()?;
Ok(UseTree2::Path(UsePath2 { ident, tree: Box::new(input.parse()?) }))
} else if input.peek(syn::Token![=]) {
if ident != "TrustLevel" {
return Err(syn::parse::Error::new(ident.span(), "Unrecognized key-value pair"));
}
input.parse::<syn::Token![=]>()?;
let span = input.span();
let value = input.call(syn::Ident::parse_any)?;
match value.to_string().as_str() {
"Partial" => Ok(UseTree2::TrustLevel(1)),
"Full" => Ok(UseTree2::TrustLevel(2)),
_ => Err(syn::parse::Error::new(span, "`TrustLevel` must be `Partial` or `Full`")),
}
} else {
let generics = if input.peek(syn::Token![<]) {
input.parse::<syn::Token![<]>()?;
Expand Down
54 changes: 54 additions & 0 deletions crates/tests/implement/tests/trust_level.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
#![allow(non_snake_case)]

use windows::core::*;
use windows::Foundation::*;

#[implement(IStringable)]
struct BaseTrust;

impl IStringable_Impl for BaseTrust {
fn ToString(&self) -> Result<HSTRING> {
Ok("BaseTrust".into())
}
}

#[implement(IClosable, TrustLevel = Partial, IStringable)]
struct PartialTrust;

impl IStringable_Impl for PartialTrust {
fn ToString(&self) -> Result<HSTRING> {
Ok("PartialTrust".into())
}
}

impl IClosable_Impl for PartialTrust {
fn Close(&self) -> Result<()> {
Ok(())
}
}

#[implement(IStringable, TrustLevel = Full)]
struct FullTrust;

impl IStringable_Impl for FullTrust {
fn ToString(&self) -> Result<HSTRING> {
Ok("FullTrust".into())
}
}

#[test]
fn test() -> Result<()> {
let base: IStringable = BaseTrust.into();
assert_eq!(base.ToString()?, "BaseTrust");
assert_eq!(base.cast::<IInspectable>()?.GetTrustLevel()?, 0);

let partial: IStringable = PartialTrust.into();
assert_eq!(partial.ToString()?, "PartialTrust");
assert_eq!(partial.cast::<IInspectable>()?.GetTrustLevel()?, 1);

let full: IStringable = FullTrust.into();
assert_eq!(full.ToString()?, "FullTrust");
assert_eq!(full.cast::<IInspectable>()?.GetTrustLevel()?, 2);

Ok(())
}

0 comments on commit 5c5c82a

Please sign in to comment.