Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add HSTRING builder and registry support #3133

Merged
merged 4 commits into from
Jun 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions crates/libs/registry/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,7 @@ path = "../targets"
[dependencies.windows-result]
version = "0.1.1"
path = "../result"

[dependencies.windows-strings]
version = "0.1.0"
path = "../strings"
43 changes: 43 additions & 0 deletions crates/libs/registry/src/key.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,15 @@ impl Key {
unsafe { self.set_value(name, REG_SZ, value.as_ptr() as _, value.len() * 2) }
}

/// Sets the name and value in the registry key.
pub fn set_hstring<T: AsRef<str>>(
&self,
name: T,
value: &windows_strings::HSTRING,
) -> Result<()> {
unsafe { self.set_value(name, REG_SZ, value.as_ptr() as _, value.len() * 2) }
}
ChrisDenton marked this conversation as resolved.
Show resolved Hide resolved

/// Sets the name and value in the registry key.
pub fn set_multi_string<T: AsRef<str>>(&self, name: T, value: &[T]) -> Result<()> {
let mut packed = value.iter().fold(vec![0u16; 0], |mut packed, value| {
Expand Down Expand Up @@ -278,6 +287,40 @@ impl Key {
}
}

/// Gets the value for the name in the registry key.
pub fn get_hstring<T: AsRef<str>>(&self, name: T) -> Result<HSTRING> {
let name = pcwstr(name);
let mut ty = 0;
let mut len = 0;

let result = unsafe {
RegQueryValueExW(self.0, name.as_ptr(), null(), &mut ty, null_mut(), &mut len)
};

win32_error(result)?;

if !matches!(ty, REG_SZ | REG_EXPAND_SZ) {
kennykerr marked this conversation as resolved.
Show resolved Hide resolved
return Err(invalid_data());
}

let mut value = HStringBuilder::new(len as usize / 2)?;

let result = unsafe {
RegQueryValueExW(
self.0,
name.as_ptr(),
null(),
null_mut(),
value.as_mut_ptr() as _,
&mut len,
)
};

win32_error(result)?;
value.trim_end();
Ok(value.into())
}

/// Gets the value for the name in the registry key.
pub fn get_bytes<T: AsRef<str>>(&self, name: T) -> Result<Vec<u8>> {
let name = pcwstr(name);
Expand Down
3 changes: 3 additions & 0 deletions crates/libs/registry/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ pub use r#type::Type;
pub use windows_result::Result;
use windows_result::*;

pub use windows_strings::HSTRING;
use windows_strings::*;

/// The predefined `HKEY_CLASSES_ROOT` registry key.
pub const CLASSES_ROOT: &Key = &Key(HKEY_CLASSES_ROOT);

Expand Down
1 change: 0 additions & 1 deletion crates/libs/result/src/bindings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ pub type BOOL = i32;
pub type BSTR = *const u16;
pub const ERROR_INVALID_DATA: WIN32_ERROR = 13u32;
pub const ERROR_NO_UNICODE_TRANSLATION: WIN32_ERROR = 1113u32;
pub const E_INVALIDARG: HRESULT = 0x80070057_u32 as _;
pub const E_UNEXPECTED: HRESULT = 0x8000FFFF_u32 as _;
pub const FORMAT_MESSAGE_ALLOCATE_BUFFER: FORMAT_MESSAGE_OPTIONS = 256u32;
pub const FORMAT_MESSAGE_FROM_HMODULE: FORMAT_MESSAGE_OPTIONS = 2048u32;
Expand Down
4 changes: 2 additions & 2 deletions crates/libs/strings/.natvis
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
<?xml version="1.0" encoding="utf-8"?>
<AutoVisualizer xmlns="http://schemas.microsoft.com/vstudio/debugger/natvis/2010">
<Type Name="windows_strings::hstring::HSTRING">
<Intrinsic Name="header" Expression="(windows_strings::hstring::Header *)__0.tag" />
<Intrinsic Name="is_empty" Expression="__0.tag == 0" />
<Intrinsic Name="header" Expression="(windows_strings::hstring_header::HStringHeader *)__0" />
<Intrinsic Name="is_empty" Expression="__0 == 0" />
<DisplayString Condition="is_empty()">""</DisplayString>
<DisplayString>{header()->data,[header()->len]su}</DisplayString>

Expand Down
40 changes: 0 additions & 40 deletions crates/libs/strings/src/heap.rs

This file was deleted.

92 changes: 18 additions & 74 deletions crates/libs/strings/src/hstring.rs
Original file line number Diff line number Diff line change
@@ -1,27 +1,27 @@
use super::*;

/// A WinRT string ([HSTRING](https://docs.microsoft.com/en-us/windows/win32/winrt/hstring))
/// is reference-counted and immutable.
/// An ([HSTRING](https://docs.microsoft.com/en-us/windows/win32/winrt/hstring))
/// is a reference-counted and immutable UTF-16 string type.
#[repr(transparent)]
pub struct HSTRING(Option<core::ptr::NonNull<Header>>);
pub struct HSTRING(pub(crate) *mut HStringHeader);
kennykerr marked this conversation as resolved.
Show resolved Hide resolved

impl HSTRING {
/// Create an empty `HSTRING`.
///
/// This function does not allocate memory.
pub const fn new() -> Self {
Self(None)
Self(core::ptr::null_mut())
}

/// Returns `true` if the string is empty.
pub const fn is_empty(&self) -> bool {
pub fn is_empty(&self) -> bool {
// An empty HSTRING is represented by a null pointer.
self.0.is_none()
self.0.is_null()
}

/// Returns the length of the string. The length is measured in `u16`s (UTF-16 code units), not including the terminating null character.
pub fn len(&self) -> usize {
if let Some(header) = self.get_header() {
if let Some(header) = self.as_header() {
header.len as usize
} else {
0
Expand All @@ -35,7 +35,7 @@ impl HSTRING {

/// Returns a raw pointer to the `HSTRING` buffer.
pub fn as_ptr(&self) -> *const u16 {
if let Some(header) = self.get_header() {
if let Some(header) = self.as_header() {
header.data
} else {
const EMPTY: [u16; 1] = [0];
Expand Down Expand Up @@ -66,7 +66,7 @@ impl HSTRING {
return Ok(Self::new());
}

let ptr = Header::alloc(len.try_into()?)?;
let ptr = HStringHeader::alloc(len.try_into()?)?;

// Place each utf-16 character into the buffer and
// increase len as we go along.
Expand All @@ -79,11 +79,11 @@ impl HSTRING {

// Write a 0 byte to the end of the buffer.
(*ptr).data.offset((*ptr).len as isize).write(0);
Ok(Self(core::ptr::NonNull::new(ptr)))
Ok(Self(ptr))
}

fn get_header(&self) -> Option<&Header> {
self.0.map(|header| unsafe { header.as_ref() })
fn as_header(&self) -> Option<&HStringHeader> {
unsafe { self.0.as_ref() }
}
}

Expand All @@ -95,8 +95,8 @@ impl Default for HSTRING {

impl Clone for HSTRING {
fn clone(&self) -> Self {
if let Some(header) = self.get_header() {
Self(core::ptr::NonNull::new(header.duplicate().unwrap()))
if let Some(header) = self.as_header() {
Self(header.duplicate().unwrap())
} else {
Self::new()
}
Expand All @@ -105,17 +105,12 @@ impl Clone for HSTRING {

impl Drop for HSTRING {
fn drop(&mut self) {
if self.is_empty() {
return;
}

if let Some(header) = self.0.take() {
// REFERENCE_FLAG indicates a string backed by static or stack memory that is
if let Some(header) = self.as_header() {
// HSTRING_REFERENCE_FLAG indicates a string backed by static or stack memory that is
// thus not reference-counted and does not need to be freed.
unsafe {
let header = header.as_ref();
if header.flags & REFERENCE_FLAG == 0 && header.count.release() == 0 {
heap_free(header as *const _ as *mut _);
if header.flags & HSTRING_REFERENCE_FLAG == 0 && header.count.release() == 0 {
HStringHeader::free(self.0);
}
}
}
Expand Down Expand Up @@ -407,54 +402,3 @@ impl From<HSTRING> for std::ffi::OsString {
Self::from(&hstring)
}
}

const REFERENCE_FLAG: u32 = 1;

#[repr(C)]
struct Header {
kennykerr marked this conversation as resolved.
Show resolved Hide resolved
flags: u32,
len: u32,
_0: u32,
_1: u32,
data: *mut u16,
count: RefCount,
buffer_start: u16,
}

impl Header {
fn alloc(len: u32) -> Result<*mut Header> {
debug_assert!(len != 0);
// Allocate enough space for header and two bytes per character.
// The space for the terminating null character is already accounted for inside of `Header`.
let alloc_size = core::mem::size_of::<Header>() + 2 * len as usize;

let header = heap_alloc(alloc_size)? as *mut Header;

unsafe {
// Use `ptr::write` (since `header` is unintialized). `Header` is safe to be all zeros.
header.write(core::mem::MaybeUninit::<Header>::zeroed().assume_init());
(*header).len = len;
(*header).count = RefCount::new(1);
(*header).data = &mut (*header).buffer_start;
}

Ok(header)
}

fn duplicate(&self) -> Result<*mut Header> {
if self.flags & REFERENCE_FLAG == 0 {
// If this is not a "fast pass" string then simply increment the reference count.
self.count.add_ref();
Ok(self as *const Header as *mut Header)
} else {
// Otherwise, allocate a new string and copy the value into the new string.
let copy = Header::alloc(self.len)?;
// SAFETY: since we are duplicating the string it is safe to copy all data from self to the initialized `copy`.
// We copy `len + 1` characters since `len` does not account for the terminating null character.
unsafe {
core::ptr::copy_nonoverlapping(self.data, (*copy).data, self.len as usize + 1);
}
Ok(copy)
}
}
}
83 changes: 83 additions & 0 deletions crates/libs/strings/src/hstring_builder.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
use super::*;

/// An [HSTRING] builder that supports preallocating the `HSTRING` to avoid extra allocations and copies.
///
/// This is similar to the `WindowsPreallocateStringBuffer` function but implemented directly in Rust for efficiency.
/// It is implemented as a separate type since [HSTRING] values are immutable.
pub struct HStringBuilder(*mut HStringHeader);

impl HStringBuilder {
/// Creates a preallocated `HSTRING` value.
pub fn new(len: usize) -> Result<Self> {
Ok(Self(HStringHeader::alloc(len.try_into()?)?))
}

/// Shortens the string by removing any trailing 0 characters.
pub fn trim_end(&mut self) {
if let Some(header) = self.as_header_mut() {
while header.len > 0
&& unsafe { header.data.offset(header.len as isize - 1).read() == 0 }
{
header.len -= 1;
}

if header.len == 0 {
unsafe {
HStringHeader::free(self.0);
}
self.0 = core::ptr::null_mut();
}
}
}

fn as_header(&self) -> Option<&HStringHeader> {
unsafe { self.0.as_ref() }
}

fn as_header_mut(&mut self) -> Option<&mut HStringHeader> {
unsafe { self.0.as_mut() }
}
}

impl From<HStringBuilder> for HSTRING {
fn from(value: HStringBuilder) -> Self {
if let Some(header) = value.as_header() {
unsafe { header.data.offset(header.len as isize).write(0) };
let result = Self(value.0);
core::mem::forget(value);
result
} else {
Self::new()
}
}
}

impl core::ops::Deref for HStringBuilder {
type Target = [u16];

fn deref(&self) -> &[u16] {
if let Some(header) = self.as_header() {
unsafe { core::slice::from_raw_parts(header.data, header.len as usize) }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't this expose uninitialized memory? i.e. *HStringBuilder::new(123).unwrap() is a [u16] slice whose contents are uninitialized.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, thanks for the reminder: #3141

} else {
&[]
}
}
}

impl core::ops::DerefMut for HStringBuilder {
fn deref_mut(&mut self) -> &mut [u16] {
if let Some(header) = self.as_header() {
unsafe { core::slice::from_raw_parts_mut(header.data, header.len as usize) }
} else {
&mut []
}
}
}

impl Drop for HStringBuilder {
fn drop(&mut self) {
unsafe {
HStringHeader::free(self.0);
}
}
}
Loading