diff --git a/crates/libs/core/src/unknown.rs b/crates/libs/core/src/unknown.rs index 1f16b35a0d4..bdbaf84bc69 100644 --- a/crates/libs/core/src/unknown.rs +++ b/crates/libs/core/src/unknown.rs @@ -47,11 +47,17 @@ impl Drop for IUnknown { impl PartialEq for IUnknown { fn eq(&self, other: &Self) -> bool { + // First we test for ordinary pointer equality. If two COM interface pointers have the + // same pointer value, then they are the same object. This can save us a lot of time, + // since calling QueryInterface is much more expensive than a single pointer comparison. + // + // However, interface pointers may have different values and yet point to the same object. // Since COM objects may implement multiple interfaces, COM identity can only // be determined by querying for `IUnknown` explicitly and then comparing the // pointer values. This works since `QueryInterface` is required to return // the same pointer value for queries for `IUnknown`. - self.cast::().unwrap().0 == other.cast::().unwrap().0 + self.as_raw() == other.as_raw() + || self.cast::().unwrap().0 == other.cast::().unwrap().0 } } diff --git a/crates/tests/implement_core/src/com_object.rs b/crates/tests/implement_core/src/com_object.rs index 6cf227be893..9bc86c2bd79 100644 --- a/crates/tests/implement_core/src/com_object.rs +++ b/crates/tests/implement_core/src/com_object.rs @@ -1,5 +1,6 @@ //! Unit tests for `windows_core::ComObject` +use core::ffi::c_void; use std::borrow::Borrow; use std::sync::atomic::{AtomicBool, Ordering::SeqCst}; use std::sync::Arc; @@ -30,6 +31,8 @@ unsafe trait IBar2: IBar { const APP_SIGNATURE: [u8; 8] = *b"cafef00d"; +// We are intentionally declaring IBar twice (in two different interface chains). +// If you change this, you will need to update tests. #[implement(IFoo, IBar, IBar2)] struct MyApp { // We use signature to verify field offsets for dynamic casts @@ -332,7 +335,7 @@ fn construct_with_into() { } #[test] -fn debug() { +fn com_object_debug() { let app = MyApp::new(100); let s = format!("{:?}", app); assert_eq!(s, "x = 100"); @@ -417,12 +420,80 @@ fn common_method_name() { } #[test] -fn debug_fmt() { +fn interface_debug_fmt() { let app = MyApp::new(42); + + let iunknown: IUnknown = app.to_interface(); + let unknown_dbg = format!("{iunknown:?}"); + assert!(unknown_dbg.starts_with("IUnknown(0x"), "{unknown_dbg:?}"); + + let ifoo: IFoo = app.to_interface(); + let foo_dbg = format!("{ifoo:?}"); + assert!(foo_dbg.starts_with("IFoo(0x"), "{foo_dbg:?}"); +} + +// Test that we always get the same IUnknown pointer for an object, regardless of which +// interface we use to query for it. +#[test] +fn iunknown_identity() { + let app = MyApp::new(0); + + println!("identity = {:?}", &app.identity as *const _); + println!("vtables.0 = {:?}", &app.vtables.0 as *const _); + println!("vtables.1 = {:?}", &app.vtables.1 as *const _); + println!("vtables.2 = {:?}", &app.vtables.2 as *const _); + + let raw_identity: *mut c_void = &app.identity as *const _ as *mut c_void; + let raw_foo: *mut c_void = &app.vtables.0 as *const _ as *mut c_void; + let raw_bar: *mut c_void = &app.vtables.1 as *const _ as *mut c_void; + let raw_bar2: *mut c_void = &app.vtables.2 as *const _ as *mut c_void; + + // iunknown is from the identity vtable slot. It is the canonical IUnknown pointer. let iunknown: IUnknown = app.to_interface(); - println!("IUnknown = {iunknown:?}"); + assert_eq!( + iunknown.as_raw(), + raw_identity, + "IUnknown should come from primary interface" + ); + + // Get the most-derived interface of each interface chain. let ifoo: IFoo = app.to_interface(); - println!("IFoo = {ifoo:?}"); + let ibar: IBar = app.to_interface(); + let ibar2: IBar2 = app.to_interface(); + + assert_eq!(ifoo.as_raw(), raw_foo, "IFoo interface chain"); + assert_eq!(ibar.as_raw(), raw_bar, "IBar interface chain"); + assert_eq!(ibar2.as_raw(), raw_bar2, "IBar2 interface chain"); + + // Do a static cast from IBar to IUnknown and verify that the interface pointer for the + // resulting IUnknown is _not_ the same as the canonical IUnknown pointer. + { + let ibar_iunknown_static: IUnknown = (*ibar).clone(); + assert_ne!( + ibar_iunknown_static.as_raw(), + iunknown.as_raw(), + "IBar-to-IUnknown is non-canonical interface chain" + ); + + // The equality implementation uses QueryInterface to check that they are the same pointer. + assert_eq!( + ibar_iunknown_static, iunknown, + "QueryInterface for IUnknown yields same pointer" + ); + } + + // ibar and ibar2_ibar point to different interface chains, even though they have the same + // COM interface type. + let ibar2_ibar: &IBar = &ibar2; // static cast + assert_ne!( + ibar.as_raw(), + ibar2_ibar.as_raw(), + "IBar from different interface chains have different pointer values" + ); + assert_eq!( + ibar, *ibar2_ibar, + "IBar from different interface chains are equal (using QueryInterface)" + ); } // This tests that we can place a type that is not Send in a ComObject.