diff --git a/Cargo.toml b/Cargo.toml index faa25d0a87..4b01632c2c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,7 +13,9 @@ exclude = [".github", ".windows", "docs", "examples"] [dependencies] windows_macros = { path = "crates/macros", version = "0.10.0", optional = true } gen = { package = "windows_gen", path = "crates/gen", version = "0.10.0", optional = true } + const-sha1 = "0.2" +widestring = "0.4" [dev-dependencies] gen = { package = "windows_gen", path = "crates/gen" } diff --git a/crates/gen/src/types/bstr.rs b/crates/gen/src/types/bstr.rs index 90527c80f3..c51a4785a6 100644 --- a/crates/gen/src/types/bstr.rs +++ b/crates/gen/src/types/bstr.rs @@ -4,24 +4,35 @@ pub fn gen_bstr() -> TokenStream { quote! { #[repr(transparent)] #[derive(::std::cmp::Eq)] - pub struct BSTR(*mut u16); + /// https://docs.microsoft.com/en-us/previous-versions/windows/desktop/automat/bstr#remarks + /// Uses [`::windows::widestring::UStr`] and not [`::windows::widestring::UCstr`], the latter checks for internal nulls. + pub struct BSTR(*mut ::windows::widestring::WideChar); impl BSTR { pub fn is_empty(&self) -> bool { + // TODO: Should possibly also check length! self.0.is_null() } - fn from_wide(value: &[u16]) -> Self { - if value.len() == 0 { + pub fn len(&self) -> usize { + unsafe { SysStringLen(self) as usize } + } + fn from_wide(value: &::windows::widestring::WideStr) -> Self { + if value.is_empty() { return Self(::std::ptr::null_mut()); } - - unsafe { SysAllocStringLen(super::SystemServices::PWSTR(value.as_ptr() as _), value.len() as u32) } + unsafe { + SysAllocStringLen( + super::SystemServices::PWSTR(value.as_ptr() as _), + value.len() as u32, + ) + } } - fn as_wide(&self) -> &[u16] { + fn as_wide(&self) -> &::windows::widestring::WideStr { if self.0.is_null() { - return &[]; + // `UStr` unlike `UCStr` doesn't implement an empty-string default yet + ::windows::widestring::WideStr::from_slice(&[]) + } else { + unsafe { ::windows::widestring::WideStr::from_ptr(self.0, self.len()) } } - - unsafe { ::std::slice::from_raw_parts(self.0 as *const u16, SysStringLen(self) as usize) } } } impl ::std::clone::Clone for BSTR { @@ -29,51 +40,51 @@ pub fn gen_bstr() -> TokenStream { Self::from_wide(self.as_wide()) } } + impl ::std::convert::From<&str> for BSTR { fn from(value: &str) -> Self { - let value: ::std::vec::Vec = value.encode_utf16().collect(); + // TODO: This allocates+copies twice. + let value = ::windows::widestring::WideString::from_str(value); Self::from_wide(&value) } } - impl ::std::convert::From<::std::string::String> for BSTR { fn from(value: ::std::string::String) -> Self { value.as_str().into() } } - impl ::std::convert::From<&::std::string::String> for BSTR { fn from(value: &::std::string::String) -> Self { value.as_str().into() } } - impl<'a> ::std::convert::TryFrom<&'a BSTR> for ::std::string::String { - type Error = ::std::string::FromUtf16Error; + #[cfg(windows)] + type FromWidestringError = ::std::string::FromUtf16Error; + #[cfg(not(windows))] + type FromWidestringError = ::windows::widestring::FromUtf32Error; + impl<'a> ::std::convert::TryFrom<&'a BSTR> for ::std::string::String { + type Error = FromWidestringError; fn try_from(value: &BSTR) -> ::std::result::Result { - ::std::string::String::from_utf16(value.as_wide()) + value.as_wide().to_string() } } - impl ::std::convert::TryFrom for ::std::string::String { - type Error = ::std::string::FromUtf16Error; - + type Error = FromWidestringError; fn try_from(value: BSTR) -> ::std::result::Result { - ::std::string::String::try_from(&value) + value.as_wide().to_string() } } + impl ::std::default::Default for BSTR { fn default() -> Self { Self(::std::ptr::null_mut()) } } + impl ::std::fmt::Display for BSTR { fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { - use ::std::fmt::Write; - for c in ::std::char::decode_utf16(self.as_wide().iter().cloned()) { - f.write_char(c.map_err(|_| ::std::fmt::Error)?)? - } - Ok(()) + f.write_str(&self.as_wide().to_string().unwrap()) } } impl ::std::fmt::Debug for BSTR { @@ -81,6 +92,7 @@ pub fn gen_bstr() -> TokenStream { ::std::write!(f, "{}", self) } } + impl ::std::cmp::PartialEq for BSTR { fn eq(&self, other: &Self) -> bool { self.as_wide() == other.as_wide() @@ -98,30 +110,32 @@ pub fn gen_bstr() -> TokenStream { } impl ::std::cmp::PartialEq<&str> for BSTR { fn eq(&self, other: &&str) -> bool { - self.as_wide().iter().copied().eq(other.encode_utf16()) + let other = ::windows::widestring::WideString::from_str(other); + self.as_wide().eq(&other) } } - impl ::std::cmp::PartialEq for &str { fn eq(&self, other: &BSTR) -> bool { other == self } } + impl ::std::ops::Drop for BSTR { fn drop(&mut self) { if !self.0.is_null() { - unsafe { SysFreeString(self as &Self); } + unsafe { SysFreeString(self as &Self) }; } } } + unsafe impl ::windows::Abi for BSTR { - type Abi = *mut u16; + type Abi = *mut ::windows::widestring::WideChar; - fn set_abi(&mut self) -> *mut *mut u16 { + fn set_abi(&mut self) -> *mut *mut ::windows::widestring::WideChar { debug_assert!(self.0.is_null()); &mut self.0 as *mut _ as _ } } - pub type BSTR_abi = *mut u16; + pub type BSTR_abi = *mut ::windows::widestring::WideChar; } } diff --git a/crates/gen/src/types/callback.rs b/crates/gen/src/types/callback.rs index aa3b99de7f..f496833b4b 100644 --- a/crates/gen/src/types/callback.rs +++ b/crates/gen/src/types/callback.rs @@ -40,8 +40,24 @@ impl Callback { quote! {} }; + let query_interface_fn = if signature.has_query_interface() { + let constraints = signature.gen_constraints(&signature.params); + let leading_params = &signature.params[..signature.params.len() - 2]; + let params = signature.gen_win32_params(leading_params, gen); + let args = leading_params.iter().map(|p| p.gen_win32_abi_arg()); + quote! { + pub unsafe fn #name<#constraints T: ::windows::Interface>(func: &#name, #params) -> ::windows::Result { + let mut result__ = ::std::option::Option::None; + (func)(#(#args,)* &::IID, ::windows::Abi::set_abi(&mut result__)).and_some(result__) + } + } + } else { + quote!() + }; + quote! { pub type #name = unsafe extern "system" fn(#(#params),*) #return_type; + #query_interface_fn } } } diff --git a/crates/gen/src/types/pwstr.rs b/crates/gen/src/types/pwstr.rs index 712a44d75a..6df2c5f19e 100644 --- a/crates/gen/src/types/pwstr.rs +++ b/crates/gen/src/types/pwstr.rs @@ -3,46 +3,92 @@ use super::*; pub fn gen_pwstr() -> TokenStream { quote! { #[repr(transparent)] - #[derive(::std::clone::Clone, ::std::marker::Copy, ::std::cmp::Eq, ::std::fmt::Debug)] - pub struct PWSTR(pub *mut u16); + #[derive(::std::clone::Clone, ::std::marker::Copy, ::std::cmp::Eq)] + /// Uses [`::windows::widestring::UCStr`] for null-terminated, checked strings. + pub struct PWSTR(pub *mut ::windows::widestring::WideChar); impl PWSTR { pub const NULL: Self = Self(::std::ptr::null_mut()); pub fn is_null(&self) -> bool { self.0.is_null() } + pub fn is_empty(&self) -> bool { + // TODO: Should possibly also check length! + self.is_null() + } + pub fn len(&self) -> usize { + self.as_wide().len() + } + fn as_wide(&self) -> &::windows::widestring::WideCStr { + if self.is_null() { + Default::default() + } else { + unsafe { ::windows::widestring::WideCStr::from_ptr_str(self.0) } + } + } } + impl ::std::default::Default for PWSTR { fn default() -> Self { Self(::std::ptr::null_mut()) } } - // TODO: impl Debug and Display to display value and PartialEq etc + + impl ::std::fmt::Display for PWSTR { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + f.write_str(&self.as_wide().to_string().unwrap()) + } + } + impl ::std::fmt::Debug for PWSTR { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + ::std::write!(f, "{}", self) + } + } + impl ::std::cmp::PartialEq for PWSTR { fn eq(&self, other: &Self) -> bool { - // TODO: do value compare - self.0 == other.0 + self.as_wide().eq(other.as_wide()) + } + } + impl ::std::cmp::PartialEq<&str> for PWSTR { + fn eq(&self, other: &&str) -> bool { + let other = unsafe { ::windows::widestring::WideCString::from_str_unchecked(other) }; + self.as_wide().eq(&other) + } + } + impl ::std::cmp::PartialEq for &str { + fn eq(&self, other: &PWSTR) -> bool { + other.eq(self) } } + unsafe impl ::windows::Abi for PWSTR { type Abi = Self; fn drop_param(param: &mut ::windows::Param) { if let ::windows::Param::Boxed(value) = param { - if !value.0.is_null() { - unsafe { ::std::boxed::Box::from_raw(value.0); } + if !value.is_null() { + unsafe { ::windows::widestring::WideCString::from_raw(value.0) }; } } } } impl<'a> ::windows::IntoParam<'a, PWSTR> for &'a str { fn into_param(self) -> ::windows::Param<'a, PWSTR> { - ::windows::Param::Boxed(PWSTR(::std::boxed::Box::<[u16]>::into_raw(self.encode_utf16().chain(::std::iter::once(0)).collect::>().into_boxed_slice()) as _)) + ::windows::Param::Boxed(PWSTR( + ::windows::widestring::WideCString::from_str(self) + .unwrap() + .into_raw(), + )) } } impl<'a> ::windows::IntoParam<'a, PWSTR> for String { fn into_param(self) -> ::windows::Param<'a, PWSTR> { // TODO: call variant above - ::windows::Param::Boxed(PWSTR(::std::boxed::Box::<[u16]>::into_raw(self.encode_utf16().chain(::std::iter::once(0)).collect::>().into_boxed_slice()) as _)) + ::windows::Param::Boxed(PWSTR( + ::windows::widestring::WideCString::from_str(self) + .unwrap() + .into_raw(), + )) } } } diff --git a/examples/dxc/Cargo.toml b/examples/dxc/Cargo.toml new file mode 100644 index 0000000000..c2b65cb047 --- /dev/null +++ b/examples/dxc/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "dxc" +version = "0.0.0" +authors = ["Microsoft"] +edition = "2018" + +[dependencies] +bindings = { package = "dxc_bindings", path = "bindings" } +libloading = "0.7" +windows = { path = "../.." } diff --git a/examples/dxc/bindings/Cargo.toml b/examples/dxc/bindings/Cargo.toml new file mode 100644 index 0000000000..438fef584e --- /dev/null +++ b/examples/dxc/bindings/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "dxc_bindings" +version = "0.0.0" +authors = ["Microsoft"] +edition = "2018" + +[dependencies] +windows = { path = "../../.." } + +[build-dependencies] +windows = { path = "../../.." } diff --git a/examples/dxc/bindings/build.rs b/examples/dxc/bindings/build.rs new file mode 100644 index 0000000000..8406daaef9 --- /dev/null +++ b/examples/dxc/bindings/build.rs @@ -0,0 +1,7 @@ +fn main() { + windows::build!( + Windows::Win32::Globalization::CP_UTF8, + Windows::Win32::Graphics::Hlsl::*, + Windows::Win32::System::Diagnostics::Debug::ERROR_FILE_NOT_FOUND, + ); +} diff --git a/examples/dxc/bindings/src/lib.rs b/examples/dxc/bindings/src/lib.rs new file mode 100644 index 0000000000..d9ddca77a8 --- /dev/null +++ b/examples/dxc/bindings/src/lib.rs @@ -0,0 +1 @@ +windows::include_bindings!(); diff --git a/examples/dxc/src/copy.hlsl b/examples/dxc/src/copy.hlsl new file mode 100644 index 0000000000..1d36a625ad --- /dev/null +++ b/examples/dxc/src/copy.hlsl @@ -0,0 +1,8 @@ +Texture2D g_input : register(t0, space0); +RWTexture2D g_output : register(u0, space0); + +[numthreads(8, 8, 1)] +void copyCs(uint3 dispatchThreadId : SV_DispatchThreadID) +{ + g_output[dispatchThreadId.xy] = g_input[dispatchThreadId.xy]; +} diff --git a/examples/dxc/src/main.rs b/examples/dxc/src/main.rs new file mode 100644 index 0000000000..0a2f06b326 --- /dev/null +++ b/examples/dxc/src/main.rs @@ -0,0 +1,222 @@ +use bindings::Windows::Win32::{ + Globalization::CP_UTF8, + Graphics::Hlsl::*, + System::{ + Diagnostics::Debug::ERROR_FILE_NOT_FOUND, + SystemServices::{BOOL, PWSTR}, + }, +}; +use libloading::{Library, Symbol}; +use std::path::Path; +use std::rc::Rc; +use windows::*; + +// Only exists in newer DXC headers +const DXC_CP_UTF8: u32 = CP_UTF8; + +#[cfg(target_os = "windows")] +fn dxcompiler_lib_name() -> &'static Path { + Path::new("dxcompiler.dll") +} + +#[cfg(target_os = "linux")] +fn dxcompiler_lib_name() -> &'static Path { + Path::new("./libdxcompiler.so") +} + +#[cfg(target_os = "macos")] +fn dxcompiler_lib_name() -> &'static Path { + Path::new("./libdxcompiler.dynlib") +} + +fn blob_encoding_as_str(blob: &IDxcBlobEncoding) -> &str { + let mut known: BOOL = false.into(); + let mut cp = 0; + unsafe { blob.GetEncoding(known.set_abi(), cp.set_abi()) }.unwrap(); + assert!(bool::from(known)); + assert_eq!(cp, DXC_CP_UTF8); + unsafe { + let slice = std::slice::from_raw_parts( + blob.GetBufferPointer() as *const u8, + blob.GetBufferSize() - 1, + ); + std::str::from_utf8_unchecked(slice) + } +} + +fn create_blob(library: &IDxcLibrary, data: &str) -> windows::Result { + let mut blob = None; + unsafe { + library.CreateBlobWithEncodingFromPinned( + data.as_ptr() as *const _, + data.len() as u32, + DXC_CP_UTF8, + &mut blob, + ) + } + .and_some(blob) +} + +#[allow(non_snake_case)] +fn main() -> windows::Result<()> { + let lib = unsafe { Library::new(dxcompiler_lib_name()) }.unwrap(); + let create: Symbol = unsafe { lib.get(b"DxcCreateInstance\0") }.unwrap(); + dbg!(&create); + + let compiler: IDxcCompiler2 = unsafe { DxcCreateInstanceProc(&create, &CLSID_DxcCompiler) }?; + let library: IDxcLibrary = unsafe { DxcCreateInstanceProc(&create, &CLSID_DxcLibrary) }?; + + dbg!(&compiler, &library); + + let main_shader = "#include \"copy.hlsl\""; + let shader_blob = create_blob(&library, main_shader)?; + + unsafe extern "system" fn LoadSource( + this: RawPtr, + pfilename: PWSTR, + ppincludesource: *mut RawPtr, + ) -> HRESULT { + let this = &mut *(this as *mut IncludeHandlerData); + dbg!(&this, pfilename, ppincludesource); + + if pfilename != "./foo/bar/copy.hlsl" { + return HRESULT::from_win32(ERROR_FILE_NOT_FOUND.0); + } + let copy_shader = include_str!("copy.hlsl"); + let shader_blob = create_blob(&this.library, copy_shader).unwrap(); + *ppincludesource = shader_blob.abi(); + this.alive_shaders.push(shader_blob); + HRESULT(0) + } + + unsafe extern "system" fn QueryInterface( + _ptr: RawPtr, + _iid: &Guid, + _interface: *mut RawPtr, + ) -> HRESULT { + todo!() + } + + unsafe extern "system" fn AddRef(this: ::windows::RawPtr) -> u32 { + let this = &mut *(this as *mut IncludeHandlerData); + dbg!(this.refs.add_ref()) + } + + unsafe extern "system" fn Release(this: ::windows::RawPtr) -> u32 { + let this = &mut *(this as *mut IncludeHandlerData); + dbg!(this.refs.release()) + } + + #[cfg(not(windows))] + unsafe extern "system" fn dtor(_ptr: ::windows::RawPtr) { + todo!() + } + + let include_handler = IDxcIncludeHandler_abi( + // IUnknown + QueryInterface, + AddRef, + Release, + // Virtual destructors breaking the vtable, hack for !windows DXC + #[cfg(not(windows))] + dtor, + #[cfg(not(windows))] + dtor, + LoadSource, + ); + + #[derive(Debug)] + struct IncludeHandlerData { + vtable: *const IDxcIncludeHandler_abi, + refs: RefCount, + library: Rc, + alive_shaders: Vec, + } + + #[derive(Debug)] + struct IncludeHandler(std::ptr::NonNull); + + let library = Rc::new(library); + + let handler_data = Box::new(IncludeHandlerData { + vtable: &include_handler, + refs: RefCount::new(1), + library, + alive_shaders: vec![], + }); + + let handler = + IncludeHandler(unsafe { std::ptr::NonNull::new_unchecked(Box::into_raw(handler_data)) }); + + unsafe impl ::windows::Interface for IncludeHandler { + type Vtable = IDxcIncludeHandler_abi; + const IID: ::windows::Guid = ::windows::Guid::from_values( + 2137128061, + 38157, + 18047, + [179, 227, 60, 2, 251, 73, 24, 124], + ); + } + impl ::std::convert::From for IDxcIncludeHandler { + fn from(value: IncludeHandler) -> Self { + unsafe { ::std::mem::transmute(value) } + } + } + impl ::std::convert::From<&IncludeHandler> for &IDxcIncludeHandler { + fn from(value: &IncludeHandler) -> Self { + unsafe { ::std::mem::transmute(value) } + } + } + impl<'a> IntoParam<'a, IDxcIncludeHandler> for IncludeHandler { + fn into_param(self) -> Param<'a, IDxcIncludeHandler> { + Param::Owned(::std::convert::Into::::into(self)) + } + } + impl<'a> IntoParam<'a, IDxcIncludeHandler> for &'a IncludeHandler { + fn into_param(self) -> Param<'a, IDxcIncludeHandler> { + Param::Borrowed(::std::convert::Into::<&IDxcIncludeHandler>::into(self)) + } + } + + dbg!(&handler); + + let mut args = vec![]; + let defines = vec![]; + let mut result = None; + let result = unsafe { + compiler.Compile( + shader_blob, + "foo/bar/baz.hlsl", + "copyCs", + "cs_6_5", + args.as_mut_ptr(), // Should not be mut? + args.len() as u32, + defines.as_ptr(), + defines.len() as u32, + // TODO: This also accepts a borrow which does not decrease our refcount to 0 + handler, + &mut result, + ) + } + .and_some(result)?; + + let mut status = HRESULT::default(); + unsafe { result.GetStatus(&mut status) }.ok()?; + if status.is_err() { + let mut errors = None; + let errors = unsafe { result.GetErrorBuffer(&mut errors) }.and_some(errors)?; + let errors = blob_encoding_as_str(&errors); + eprintln!("Compilation failed with {:?}: `{}`", status, errors); + status.ok() + } else { + let mut blob = None; + let blob = unsafe { result.GetResult(&mut blob) }.and_some(blob)?; + let shader = unsafe { + std::slice::from_raw_parts(blob.GetBufferPointer().cast::(), blob.GetBufferSize()) + }; + for c in shader.chunks(16) { + println!("{:02x?}", c); + } + Ok(()) + } +} diff --git a/src/bindings.rs b/src/bindings.rs index 0c160d0f02..960d836356 100644 --- a/src/bindings.rs +++ b/src/bindings.rs @@ -2282,13 +2282,18 @@ pub mod Windows { } #[repr(transparent)] #[derive(:: std :: cmp :: Eq)] - pub struct BSTR(*mut u16); + #[doc = r" https://docs.microsoft.com/en-us/previous-versions/windows/desktop/automat/bstr#remarks"] + #[doc = r" Uses [`::windows::widestring::UStr`] and not [`::windows::widestring::UCstr`], the latter checks for internal nulls."] + pub struct BSTR(*mut ::windows::widestring::WideChar); impl BSTR { pub fn is_empty(&self) -> bool { self.0.is_null() } - fn from_wide(value: &[u16]) -> Self { - if value.len() == 0 { + pub fn len(&self) -> usize { + unsafe { SysStringLen(self) as usize } + } + fn from_wide(value: &::windows::widestring::WideStr) -> Self { + if value.is_empty() { return Self(::std::ptr::null_mut()); } unsafe { @@ -2298,15 +2303,11 @@ pub mod Windows { ) } } - fn as_wide(&self) -> &[u16] { + fn as_wide(&self) -> &::windows::widestring::WideStr { if self.0.is_null() { - return &[]; - } - unsafe { - ::std::slice::from_raw_parts( - self.0 as *const u16, - SysStringLen(self) as usize, - ) + ::windows::widestring::WideStr::from_slice(&[]) + } else { + unsafe { ::windows::widestring::WideStr::from_ptr(self.0, self.len()) } } } } @@ -2317,7 +2318,7 @@ pub mod Windows { } impl ::std::convert::From<&str> for BSTR { fn from(value: &str) -> Self { - let value: ::std::vec::Vec = value.encode_utf16().collect(); + let value = ::windows::widestring::WideString::from_str(value); Self::from_wide(&value) } } @@ -2331,16 +2332,20 @@ pub mod Windows { value.as_str().into() } } + #[cfg(windows)] + type FromWidestringError = ::std::string::FromUtf16Error; + #[cfg(not(windows))] + type FromWidestringError = ::windows::widestring::FromUtf32Error; impl<'a> ::std::convert::TryFrom<&'a BSTR> for ::std::string::String { - type Error = ::std::string::FromUtf16Error; + type Error = FromWidestringError; fn try_from(value: &BSTR) -> ::std::result::Result { - ::std::string::String::from_utf16(value.as_wide()) + value.as_wide().to_string() } } impl ::std::convert::TryFrom for ::std::string::String { - type Error = ::std::string::FromUtf16Error; + type Error = FromWidestringError; fn try_from(value: BSTR) -> ::std::result::Result { - ::std::string::String::try_from(&value) + value.as_wide().to_string() } } impl ::std::default::Default for BSTR { @@ -2350,11 +2355,7 @@ pub mod Windows { } impl ::std::fmt::Display for BSTR { fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { - use std::fmt::Write; - for c in ::std::char::decode_utf16(self.as_wide().iter().cloned()) { - f.write_char(c.map_err(|_| ::std::fmt::Error)?)? - } - Ok(()) + f.write_str(&self.as_wide().to_string().unwrap()) } } impl ::std::fmt::Debug for BSTR { @@ -2379,7 +2380,8 @@ pub mod Windows { } impl ::std::cmp::PartialEq<&str> for BSTR { fn eq(&self, other: &&str) -> bool { - self.as_wide().iter().copied().eq(other.encode_utf16()) + let other = ::windows::widestring::WideString::from_str(other); + self.as_wide().eq(&other) } } impl ::std::cmp::PartialEq for &str { @@ -2390,20 +2392,18 @@ pub mod Windows { impl ::std::ops::Drop for BSTR { fn drop(&mut self) { if !self.0.is_null() { - unsafe { - SysFreeString(self as &Self); - } + unsafe { SysFreeString(self as &Self) }; } } } unsafe impl ::windows::Abi for BSTR { - type Abi = *mut u16; - fn set_abi(&mut self) -> *mut *mut u16 { + type Abi = *mut ::windows::widestring::WideChar; + fn set_abi(&mut self) -> *mut *mut ::windows::widestring::WideChar { debug_assert!(self.0.is_null()); &mut self.0 as *mut _ as _ } } - pub type BSTR_abi = *mut u16; + pub type BSTR_abi = *mut ::windows::widestring::WideChar; #[repr(transparent)] #[derive( :: std :: cmp :: PartialEq, @@ -2565,58 +2565,88 @@ pub mod Windows { pub mod SystemServices { #[repr(transparent)] #[derive( - :: std :: clone :: Clone, - :: std :: marker :: Copy, - :: std :: cmp :: Eq, - :: std :: fmt :: Debug, + :: std :: clone :: Clone, :: std :: marker :: Copy, :: std :: cmp :: Eq, )] - pub struct PWSTR(pub *mut u16); + #[doc = r" Uses [`::windows::widestring::UCStr`] for null-terminated, checked strings."] + pub struct PWSTR(pub *mut ::windows::widestring::WideChar); impl PWSTR { pub const NULL: Self = Self(::std::ptr::null_mut()); pub fn is_null(&self) -> bool { self.0.is_null() } + pub fn is_empty(&self) -> bool { + self.is_null() + } + pub fn len(&self) -> usize { + self.as_wide().len() + } + fn as_wide(&self) -> &::windows::widestring::WideCStr { + if self.is_null() { + Default::default() + } else { + unsafe { ::windows::widestring::WideCStr::from_ptr_str(self.0) } + } + } } impl ::std::default::Default for PWSTR { fn default() -> Self { Self(::std::ptr::null_mut()) } } + impl ::std::fmt::Display for PWSTR { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + f.write_str(&self.as_wide().to_string().unwrap()) + } + } + impl ::std::fmt::Debug for PWSTR { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + ::std::write!(f, "{}", self) + } + } impl ::std::cmp::PartialEq for PWSTR { fn eq(&self, other: &Self) -> bool { - self.0 == other.0 + self.as_wide().eq(other.as_wide()) + } + } + impl ::std::cmp::PartialEq<&str> for PWSTR { + fn eq(&self, other: &&str) -> bool { + let other = unsafe { + ::windows::widestring::WideCString::from_str_unchecked(other) + }; + self.as_wide().eq(&other) + } + } + impl ::std::cmp::PartialEq for &str { + fn eq(&self, other: &PWSTR) -> bool { + other.eq(self) } } unsafe impl ::windows::Abi for PWSTR { type Abi = Self; fn drop_param(param: &mut ::windows::Param) { if let ::windows::Param::Boxed(value) = param { - if !value.0.is_null() { - unsafe { - ::std::boxed::Box::from_raw(value.0); - } + if !value.is_null() { + unsafe { ::windows::widestring::WideCString::from_raw(value.0) }; } } } } impl<'a> ::windows::IntoParam<'a, PWSTR> for &'a str { fn into_param(self) -> ::windows::Param<'a, PWSTR> { - ::windows::Param::Boxed(PWSTR(::std::boxed::Box::<[u16]>::into_raw( - self.encode_utf16() - .chain(::std::iter::once(0)) - .collect::>() - .into_boxed_slice(), - ) as _)) + ::windows::Param::Boxed(PWSTR( + ::windows::widestring::WideCString::from_str(self) + .unwrap() + .into_raw(), + )) } } impl<'a> ::windows::IntoParam<'a, PWSTR> for String { fn into_param(self) -> ::windows::Param<'a, PWSTR> { - ::windows::Param::Boxed(PWSTR(::std::boxed::Box::<[u16]>::into_raw( - self.encode_utf16() - .chain(::std::iter::once(0)) - .collect::>() - .into_boxed_slice(), - ) as _)) + ::windows::Param::Boxed(PWSTR( + ::windows::widestring::WideCString::from_str(self) + .unwrap() + .into_raw(), + )) } } #[repr(transparent)] diff --git a/src/lib.rs b/src/lib.rs index e3d07bea22..e270961b0d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -45,3 +45,6 @@ pub type RawPtr = *mut std::ffi::c_void; #[doc(hidden)] pub use const_sha1::ConstBuffer; + +#[doc(hidden)] +pub use widestring; diff --git a/src/runtime/ref_count.rs b/src/runtime/ref_count.rs index 0c4da90a75..d16f609dec 100644 --- a/src/runtime/ref_count.rs +++ b/src/runtime/ref_count.rs @@ -2,11 +2,11 @@ use std::sync::atomic::{fence, AtomicI32, Ordering}; /// A thread-safe reference count for use with COM/HSTRING implementations. #[repr(transparent)] -#[derive(Default)] +#[derive(Default, Debug)] pub struct RefCount(pub(crate) AtomicI32); impl RefCount { - /// Creates a new `RefCount` with an initial value of `1`. + /// Creates a new `RefCount` with an initial value of `count`. pub fn new(count: u32) -> Self { Self(AtomicI32::new(count as _)) }