diff --git a/crates/libs/bindgen/src/rust/writer.rs b/crates/libs/bindgen/src/rust/writer.rs index 3da4429bcc..a925515119 100644 --- a/crates/libs/bindgen/src/rust/writer.rs +++ b/crates/libs/bindgen/src/rust/writer.rs @@ -697,6 +697,34 @@ impl Writer { self.GetResults() } } + #features + impl<#constraints> windows_core::AsyncOperation for #ident { + type Output = #return_type; + fn is_complete(&self) -> windows_core::Result { + Ok(self.Status()? != #namespace AsyncStatus::Started) + } + fn set_completed(&self, f: impl Fn() + Send + 'static) -> windows_core::Result<()> { + self.SetCompleted(&#namespace #handler::new(move |_sender, _args| { + f(); + Ok(()) + })) + } + fn get_results(&self) -> windows_core::Result { + self.GetResults() + } + fn cancel(&self) { + let _ = self.Cancel(); + } + } + #features + impl<#constraints> std::future::IntoFuture for #ident { + type Output = windows_core::Result<#return_type>; + type IntoFuture = windows_core::FutureWrapper<#ident>; + + fn into_future(self) -> Self::IntoFuture { + windows_core::FutureWrapper::new(self) + } + } } } } diff --git a/crates/libs/core/src/future.rs b/crates/libs/core/src/future.rs new file mode 100644 index 0000000000..bbbc77a6b9 --- /dev/null +++ b/crates/libs/core/src/future.rs @@ -0,0 +1,74 @@ +#![cfg(feature = "std")] + +use std::{ + future::Future, + pin::Pin, + sync::{Arc, Mutex}, + task::{Poll, Waker}, +}; + +/// Wraps an `IAsyncOperation`, `IAsyncOperationWithProgress`, `IAsyncAction`, or `IAsyncActionWithProgress`. +/// Impls for this trait are generated automatically by windows-bindgen. +pub trait AsyncOperation { + /// The type produced when the operation finishes. + type Output; + /// Returns whether the operation is finished, in which case `self.get_results()` can be used to get the returned data. + /// Wraps `self.Status() != AsyncStatus::Started`. + fn is_complete(&self) -> crate::Result; + /// Register a callback that will be called once the operation is finished. + /// This can only be called once. + /// Wraps `self.SetCompleted(f)`. + fn set_completed(&self, f: impl Fn() + Send + 'static) -> crate::Result<()>; + /// Get the result value from a completed operation. + /// Wraps `self.GetResults()`. + fn get_results(&self) -> crate::Result; + /// Attempts to cancel the operation. Any error is ignored. + /// Wraps `self.Cancel()`. + fn cancel(&self); +} + +/// A wrapper around an `AsyncOperation` that implements `std::future::Future`. +/// This is used by generated `IntoFuture` impls. It shouldn't be necessary to use this type manually. +pub struct FutureWrapper { + inner: T, + waker: Option>>, +} + +impl FutureWrapper { + /// Creates a `FutureWrapper`, which implements `std::future::Future`. + pub fn new(inner: T) -> Self { + Self { inner, waker: None } + } +} + +impl Unpin for FutureWrapper {} + +impl Future for FutureWrapper { + type Output = crate::Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { + if self.inner.is_complete()? { + Poll::Ready(self.inner.get_results()) + } else { + if let Some(saved_waker) = &self.waker { + // Update the saved waker, in case the future has been transferred to a different executor. + // (e.g. if using `select`.) + let mut saved_waker = saved_waker.lock().unwrap(); + saved_waker.clone_from(cx.waker()); + } else { + let saved_waker = Arc::new(Mutex::new(cx.waker().clone())); + self.waker = Some(saved_waker.clone()); + self.inner.set_completed(move || { + saved_waker.lock().unwrap().wake_by_ref(); + })?; + } + Poll::Pending + } + } +} + +impl Drop for FutureWrapper { + fn drop(&mut self) { + self.inner.cancel(); + } +} diff --git a/crates/libs/core/src/lib.rs b/crates/libs/core/src/lib.rs index d8a2b77f87..1a5e8ae03c 100644 --- a/crates/libs/core/src/lib.rs +++ b/crates/libs/core/src/lib.rs @@ -24,6 +24,7 @@ pub mod imp; mod as_impl; mod com_object; +mod future; mod guid; mod inspectable; mod interface; @@ -41,6 +42,7 @@ mod weak; pub use as_impl::*; pub use com_object::*; +pub use future::*; pub use guid::*; pub use inspectable::*; pub use interface::*; diff --git a/crates/libs/windows/src/Windows/Devices/Sms/mod.rs b/crates/libs/windows/src/Windows/Devices/Sms/mod.rs index 64d785eda9..8e7f3daa67 100644 --- a/crates/libs/windows/src/Windows/Devices/Sms/mod.rs +++ b/crates/libs/windows/src/Windows/Devices/Sms/mod.rs @@ -1036,6 +1036,33 @@ impl DeleteSmsMessageOperation { } } #[cfg(feature = "deprecated")] +impl windows_core::AsyncOperation for DeleteSmsMessageOperation { + type Output = (); + fn is_complete(&self) -> windows_core::Result { + Ok(self.Status()? != super::super::Foundation::AsyncStatus::Started) + } + fn set_completed(&self, f: impl Fn() + Send + 'static) -> windows_core::Result<()> { + self.SetCompleted(&super::super::Foundation::AsyncActionCompletedHandler::new(move |_sender, _args| { + f(); + Ok(()) + })) + } + fn get_results(&self) -> windows_core::Result { + self.GetResults() + } + fn cancel(&self) { + let _ = self.Cancel(); + } +} +#[cfg(feature = "deprecated")] +impl std::future::IntoFuture for DeleteSmsMessageOperation { + type Output = windows_core::Result<()>; + type IntoFuture = windows_core::FutureWrapper; + fn into_future(self) -> Self::IntoFuture { + windows_core::FutureWrapper::new(self) + } +} +#[cfg(feature = "deprecated")] #[repr(transparent)] #[derive(PartialEq, Eq, Debug, Clone)] pub struct DeleteSmsMessagesOperation(windows_core::IUnknown); @@ -1122,6 +1149,33 @@ impl DeleteSmsMessagesOperation { } } #[cfg(feature = "deprecated")] +impl windows_core::AsyncOperation for DeleteSmsMessagesOperation { + type Output = (); + fn is_complete(&self) -> windows_core::Result { + Ok(self.Status()? != super::super::Foundation::AsyncStatus::Started) + } + fn set_completed(&self, f: impl Fn() + Send + 'static) -> windows_core::Result<()> { + self.SetCompleted(&super::super::Foundation::AsyncActionCompletedHandler::new(move |_sender, _args| { + f(); + Ok(()) + })) + } + fn get_results(&self) -> windows_core::Result { + self.GetResults() + } + fn cancel(&self) { + let _ = self.Cancel(); + } +} +#[cfg(feature = "deprecated")] +impl std::future::IntoFuture for DeleteSmsMessagesOperation { + type Output = windows_core::Result<()>; + type IntoFuture = windows_core::FutureWrapper; + fn into_future(self) -> Self::IntoFuture { + windows_core::FutureWrapper::new(self) + } +} +#[cfg(feature = "deprecated")] #[repr(transparent)] #[derive(PartialEq, Eq, Debug, Clone)] pub struct GetSmsDeviceOperation(windows_core::IUnknown); @@ -1211,6 +1265,33 @@ impl GetSmsDeviceOperation { } } #[cfg(feature = "deprecated")] +impl windows_core::AsyncOperation for GetSmsDeviceOperation { + type Output = SmsDevice; + fn is_complete(&self) -> windows_core::Result { + Ok(self.Status()? != super::super::Foundation::AsyncStatus::Started) + } + fn set_completed(&self, f: impl Fn() + Send + 'static) -> windows_core::Result<()> { + self.SetCompleted(&super::super::Foundation::AsyncOperationCompletedHandler::new(move |_sender, _args| { + f(); + Ok(()) + })) + } + fn get_results(&self) -> windows_core::Result { + self.GetResults() + } + fn cancel(&self) { + let _ = self.Cancel(); + } +} +#[cfg(feature = "deprecated")] +impl std::future::IntoFuture for GetSmsDeviceOperation { + type Output = windows_core::Result; + type IntoFuture = windows_core::FutureWrapper; + fn into_future(self) -> Self::IntoFuture { + windows_core::FutureWrapper::new(self) + } +} +#[cfg(feature = "deprecated")] #[repr(transparent)] #[derive(PartialEq, Eq, Debug, Clone)] pub struct GetSmsMessageOperation(windows_core::IUnknown); @@ -1299,6 +1380,33 @@ impl GetSmsMessageOperation { self.GetResults() } } +#[cfg(feature = "deprecated")] +impl windows_core::AsyncOperation for GetSmsMessageOperation { + type Output = ISmsMessage; + fn is_complete(&self) -> windows_core::Result { + Ok(self.Status()? != super::super::Foundation::AsyncStatus::Started) + } + fn set_completed(&self, f: impl Fn() + Send + 'static) -> windows_core::Result<()> { + self.SetCompleted(&super::super::Foundation::AsyncOperationCompletedHandler::new(move |_sender, _args| { + f(); + Ok(()) + })) + } + fn get_results(&self) -> windows_core::Result { + self.GetResults() + } + fn cancel(&self) { + let _ = self.Cancel(); + } +} +#[cfg(feature = "deprecated")] +impl std::future::IntoFuture for GetSmsMessageOperation { + type Output = windows_core::Result; + type IntoFuture = windows_core::FutureWrapper; + fn into_future(self) -> Self::IntoFuture { + windows_core::FutureWrapper::new(self) + } +} #[cfg(all(feature = "Foundation_Collections", feature = "deprecated"))] #[repr(transparent)] #[derive(PartialEq, Eq, Debug, Clone)] @@ -1407,6 +1515,33 @@ impl GetSmsMessagesOperation { self.GetResults() } } +#[cfg(all(feature = "Foundation_Collections", feature = "deprecated"))] +impl windows_core::AsyncOperation for GetSmsMessagesOperation { + type Output = super::super::Foundation::Collections::IVectorView; + fn is_complete(&self) -> windows_core::Result { + Ok(self.Status()? != super::super::Foundation::AsyncStatus::Started) + } + fn set_completed(&self, f: impl Fn() + Send + 'static) -> windows_core::Result<()> { + self.SetCompleted(&super::super::Foundation::AsyncOperationWithProgressCompletedHandler::new(move |_sender, _args| { + f(); + Ok(()) + })) + } + fn get_results(&self) -> windows_core::Result { + self.GetResults() + } + fn cancel(&self) { + let _ = self.Cancel(); + } +} +#[cfg(all(feature = "Foundation_Collections", feature = "deprecated"))] +impl std::future::IntoFuture for GetSmsMessagesOperation { + type Output = windows_core::Result>; + type IntoFuture = windows_core::FutureWrapper; + fn into_future(self) -> Self::IntoFuture { + windows_core::FutureWrapper::new(self) + } +} #[cfg(feature = "deprecated")] #[repr(transparent)] #[derive(PartialEq, Eq, Debug, Clone)] @@ -1493,6 +1628,33 @@ impl SendSmsMessageOperation { self.GetResults() } } +#[cfg(feature = "deprecated")] +impl windows_core::AsyncOperation for SendSmsMessageOperation { + type Output = (); + fn is_complete(&self) -> windows_core::Result { + Ok(self.Status()? != super::super::Foundation::AsyncStatus::Started) + } + fn set_completed(&self, f: impl Fn() + Send + 'static) -> windows_core::Result<()> { + self.SetCompleted(&super::super::Foundation::AsyncActionCompletedHandler::new(move |_sender, _args| { + f(); + Ok(()) + })) + } + fn get_results(&self) -> windows_core::Result { + self.GetResults() + } + fn cancel(&self) { + let _ = self.Cancel(); + } +} +#[cfg(feature = "deprecated")] +impl std::future::IntoFuture for SendSmsMessageOperation { + type Output = windows_core::Result<()>; + type IntoFuture = windows_core::FutureWrapper; + fn into_future(self) -> Self::IntoFuture { + windows_core::FutureWrapper::new(self) + } +} #[repr(transparent)] #[derive(PartialEq, Eq, Debug, Clone)] pub struct SmsAppMessage(windows_core::IUnknown); diff --git a/crates/libs/windows/src/Windows/Foundation/mod.rs b/crates/libs/windows/src/Windows/Foundation/mod.rs index a414916d61..a08ea00950 100644 --- a/crates/libs/windows/src/Windows/Foundation/mod.rs +++ b/crates/libs/windows/src/Windows/Foundation/mod.rs @@ -78,6 +78,31 @@ impl IAsyncAction { self.GetResults() } } +impl windows_core::AsyncOperation for IAsyncAction { + type Output = (); + fn is_complete(&self) -> windows_core::Result { + Ok(self.Status()? != AsyncStatus::Started) + } + fn set_completed(&self, f: impl Fn() + Send + 'static) -> windows_core::Result<()> { + self.SetCompleted(&AsyncActionCompletedHandler::new(move |_sender, _args| { + f(); + Ok(()) + })) + } + fn get_results(&self) -> windows_core::Result { + self.GetResults() + } + fn cancel(&self) { + let _ = self.Cancel(); + } +} +impl std::future::IntoFuture for IAsyncAction { + type Output = windows_core::Result<()>; + type IntoFuture = windows_core::FutureWrapper; + fn into_future(self) -> Self::IntoFuture { + windows_core::FutureWrapper::new(self) + } +} unsafe impl Send for IAsyncAction {} unsafe impl Sync for IAsyncAction {} impl windows_core::RuntimeType for IAsyncAction { @@ -183,6 +208,31 @@ impl IAsyncActionWithProgress windows_core::AsyncOperation for IAsyncActionWithProgress { + type Output = (); + fn is_complete(&self) -> windows_core::Result { + Ok(self.Status()? != AsyncStatus::Started) + } + fn set_completed(&self, f: impl Fn() + Send + 'static) -> windows_core::Result<()> { + self.SetCompleted(&AsyncActionWithProgressCompletedHandler::new(move |_sender, _args| { + f(); + Ok(()) + })) + } + fn get_results(&self) -> windows_core::Result { + self.GetResults() + } + fn cancel(&self) { + let _ = self.Cancel(); + } +} +impl std::future::IntoFuture for IAsyncActionWithProgress { + type Output = windows_core::Result<()>; + type IntoFuture = windows_core::FutureWrapper>; + fn into_future(self) -> Self::IntoFuture { + windows_core::FutureWrapper::new(self) + } +} unsafe impl Send for IAsyncActionWithProgress {} unsafe impl Sync for IAsyncActionWithProgress {} impl windows_core::RuntimeType for IAsyncActionWithProgress { @@ -338,6 +388,31 @@ impl IAsyncOperation { self.GetResults() } } +impl windows_core::AsyncOperation for IAsyncOperation { + type Output = TResult; + fn is_complete(&self) -> windows_core::Result { + Ok(self.Status()? != AsyncStatus::Started) + } + fn set_completed(&self, f: impl Fn() + Send + 'static) -> windows_core::Result<()> { + self.SetCompleted(&AsyncOperationCompletedHandler::new(move |_sender, _args| { + f(); + Ok(()) + })) + } + fn get_results(&self) -> windows_core::Result { + self.GetResults() + } + fn cancel(&self) { + let _ = self.Cancel(); + } +} +impl std::future::IntoFuture for IAsyncOperation { + type Output = windows_core::Result; + type IntoFuture = windows_core::FutureWrapper>; + fn into_future(self) -> Self::IntoFuture { + windows_core::FutureWrapper::new(self) + } +} unsafe impl Send for IAsyncOperation {} unsafe impl Sync for IAsyncOperation {} impl windows_core::RuntimeType for IAsyncOperation { @@ -455,6 +530,31 @@ impl windows_core::AsyncOperation for IAsyncOperationWithProgress { + type Output = TResult; + fn is_complete(&self) -> windows_core::Result { + Ok(self.Status()? != AsyncStatus::Started) + } + fn set_completed(&self, f: impl Fn() + Send + 'static) -> windows_core::Result<()> { + self.SetCompleted(&AsyncOperationWithProgressCompletedHandler::new(move |_sender, _args| { + f(); + Ok(()) + })) + } + fn get_results(&self) -> windows_core::Result { + self.GetResults() + } + fn cancel(&self) { + let _ = self.Cancel(); + } +} +impl std::future::IntoFuture for IAsyncOperationWithProgress { + type Output = windows_core::Result; + type IntoFuture = windows_core::FutureWrapper>; + fn into_future(self) -> Self::IntoFuture { + windows_core::FutureWrapper::new(self) + } +} unsafe impl Send for IAsyncOperationWithProgress {} unsafe impl Sync for IAsyncOperationWithProgress {} impl windows_core::RuntimeType for IAsyncOperationWithProgress { diff --git a/crates/libs/windows/src/Windows/Security/Authentication/OnlineId/mod.rs b/crates/libs/windows/src/Windows/Security/Authentication/OnlineId/mod.rs index e82467d540..7fd0df683f 100644 --- a/crates/libs/windows/src/Windows/Security/Authentication/OnlineId/mod.rs +++ b/crates/libs/windows/src/Windows/Security/Authentication/OnlineId/mod.rs @@ -504,6 +504,31 @@ impl SignOutUserOperation { self.GetResults() } } +impl windows_core::AsyncOperation for SignOutUserOperation { + type Output = (); + fn is_complete(&self) -> windows_core::Result { + Ok(self.Status()? != super::super::super::Foundation::AsyncStatus::Started) + } + fn set_completed(&self, f: impl Fn() + Send + 'static) -> windows_core::Result<()> { + self.SetCompleted(&super::super::super::Foundation::AsyncActionCompletedHandler::new(move |_sender, _args| { + f(); + Ok(()) + })) + } + fn get_results(&self) -> windows_core::Result { + self.GetResults() + } + fn cancel(&self) { + let _ = self.Cancel(); + } +} +impl std::future::IntoFuture for SignOutUserOperation { + type Output = windows_core::Result<()>; + type IntoFuture = windows_core::FutureWrapper; + fn into_future(self) -> Self::IntoFuture { + windows_core::FutureWrapper::new(self) + } +} unsafe impl Send for SignOutUserOperation {} unsafe impl Sync for SignOutUserOperation {} #[repr(transparent)] @@ -587,6 +612,31 @@ impl UserAuthenticationOperation { self.GetResults() } } +impl windows_core::AsyncOperation for UserAuthenticationOperation { + type Output = UserIdentity; + fn is_complete(&self) -> windows_core::Result { + Ok(self.Status()? != super::super::super::Foundation::AsyncStatus::Started) + } + fn set_completed(&self, f: impl Fn() + Send + 'static) -> windows_core::Result<()> { + self.SetCompleted(&super::super::super::Foundation::AsyncOperationCompletedHandler::new(move |_sender, _args| { + f(); + Ok(()) + })) + } + fn get_results(&self) -> windows_core::Result { + self.GetResults() + } + fn cancel(&self) { + let _ = self.Cancel(); + } +} +impl std::future::IntoFuture for UserAuthenticationOperation { + type Output = windows_core::Result; + type IntoFuture = windows_core::FutureWrapper; + fn into_future(self) -> Self::IntoFuture { + windows_core::FutureWrapper::new(self) + } +} unsafe impl Send for UserAuthenticationOperation {} unsafe impl Sync for UserAuthenticationOperation {} #[repr(transparent)] diff --git a/crates/libs/windows/src/Windows/Storage/Streams/mod.rs b/crates/libs/windows/src/Windows/Storage/Streams/mod.rs index 51059b48e8..3fa547fb79 100644 --- a/crates/libs/windows/src/Windows/Storage/Streams/mod.rs +++ b/crates/libs/windows/src/Windows/Storage/Streams/mod.rs @@ -1325,6 +1325,31 @@ impl DataReaderLoadOperation { self.GetResults() } } +impl windows_core::AsyncOperation for DataReaderLoadOperation { + type Output = u32; + fn is_complete(&self) -> windows_core::Result { + Ok(self.Status()? != super::super::Foundation::AsyncStatus::Started) + } + fn set_completed(&self, f: impl Fn() + Send + 'static) -> windows_core::Result<()> { + self.SetCompleted(&super::super::Foundation::AsyncOperationCompletedHandler::new(move |_sender, _args| { + f(); + Ok(()) + })) + } + fn get_results(&self) -> windows_core::Result { + self.GetResults() + } + fn cancel(&self) { + let _ = self.Cancel(); + } +} +impl std::future::IntoFuture for DataReaderLoadOperation { + type Output = windows_core::Result; + type IntoFuture = windows_core::FutureWrapper; + fn into_future(self) -> Self::IntoFuture { + windows_core::FutureWrapper::new(self) + } +} unsafe impl Send for DataReaderLoadOperation {} unsafe impl Sync for DataReaderLoadOperation {} #[repr(transparent)] @@ -1593,6 +1618,31 @@ impl DataWriterStoreOperation { self.GetResults() } } +impl windows_core::AsyncOperation for DataWriterStoreOperation { + type Output = u32; + fn is_complete(&self) -> windows_core::Result { + Ok(self.Status()? != super::super::Foundation::AsyncStatus::Started) + } + fn set_completed(&self, f: impl Fn() + Send + 'static) -> windows_core::Result<()> { + self.SetCompleted(&super::super::Foundation::AsyncOperationCompletedHandler::new(move |_sender, _args| { + f(); + Ok(()) + })) + } + fn get_results(&self) -> windows_core::Result { + self.GetResults() + } + fn cancel(&self) { + let _ = self.Cancel(); + } +} +impl std::future::IntoFuture for DataWriterStoreOperation { + type Output = windows_core::Result; + type IntoFuture = windows_core::FutureWrapper; + fn into_future(self) -> Self::IntoFuture { + windows_core::FutureWrapper::new(self) + } +} unsafe impl Send for DataWriterStoreOperation {} unsafe impl Sync for DataWriterStoreOperation {} #[repr(transparent)] diff --git a/crates/samples/windows/ocr/Cargo.toml b/crates/samples/windows/ocr/Cargo.toml index 377ff0e03a..e75742439e 100644 --- a/crates/samples/windows/ocr/Cargo.toml +++ b/crates/samples/windows/ocr/Cargo.toml @@ -4,6 +4,9 @@ version = "0.0.0" edition = "2021" publish = false +[dependencies] +futures = "0.3.5" + [dependencies.windows] path = "../../../libs/windows" features = [ diff --git a/crates/samples/windows/ocr/src/main.rs b/crates/samples/windows/ocr/src/main.rs index 4dc85d8d06..7672ece94d 100644 --- a/crates/samples/windows/ocr/src/main.rs +++ b/crates/samples/windows/ocr/src/main.rs @@ -6,18 +6,22 @@ use windows::{ }; fn main() -> Result<()> { + futures::executor::block_on(main_async()) +} + +async fn main_async() -> Result<()> { let mut message = std::env::current_dir().unwrap(); message.push("message.png"); let file = - StorageFile::GetFileFromPathAsync(&HSTRING::from(message.to_str().unwrap()))?.get()?; - let stream = file.OpenAsync(FileAccessMode::Read)?.get()?; + StorageFile::GetFileFromPathAsync(&HSTRING::from(message.to_str().unwrap()))?.await?; + let stream = file.OpenAsync(FileAccessMode::Read)?.await?; - let decode = BitmapDecoder::CreateAsync(&stream)?.get()?; - let bitmap = decode.GetSoftwareBitmapAsync()?.get()?; + let decode = BitmapDecoder::CreateAsync(&stream)?.await?; + let bitmap = decode.GetSoftwareBitmapAsync()?.await?; let engine = OcrEngine::TryCreateFromUserProfileLanguages()?; - let result = engine.RecognizeAsync(&bitmap)?.get()?; + let result = engine.RecognizeAsync(&bitmap)?.await?; println!("{}", result.Text()?); Ok(()) diff --git a/crates/tests/winrt/Cargo.toml b/crates/tests/winrt/Cargo.toml index 10c6bf801f..45916940ed 100644 --- a/crates/tests/winrt/Cargo.toml +++ b/crates/tests/winrt/Cargo.toml @@ -23,10 +23,12 @@ features = [ "Foundation_Numerics", "Storage_Streams", "System", + "System_Threading", "UI_Composition", "Win32_System_Com", "Win32_System_WinRT", ] [dev-dependencies] +futures = "0.3" helpers = { package = "test_helpers", path = "../helpers" } diff --git a/crates/tests/winrt/tests/async.rs b/crates/tests/winrt/tests/async.rs index 8ce9a224ae..08323a6409 100644 --- a/crates/tests/winrt/tests/async.rs +++ b/crates/tests/winrt/tests/async.rs @@ -1,7 +1,13 @@ +use std::future::IntoFuture; + +use futures::{executor::LocalPool, future, task::SpawnExt}; +use windows::{ + Storage::Streams::*, + System::Threading::{ThreadPool, WorkItemHandler}, +}; + #[test] fn async_get() -> windows::core::Result<()> { - use windows::Storage::Streams::*; - let stream = &InMemoryRandomAccessStream::new()?; let writer = DataWriter::CreateDataWriter(stream)?; @@ -23,3 +29,71 @@ fn async_get() -> windows::core::Result<()> { Ok(()) } + +async fn async_await() -> windows::core::Result<()> { + let stream = &InMemoryRandomAccessStream::new()?; + + let writer = DataWriter::CreateDataWriter(stream)?; + writer.WriteByte(1)?; + writer.WriteByte(2)?; + writer.WriteByte(3)?; + writer.StoreAsync()?.await?; + + stream.Seek(0)?; + let reader = DataReader::CreateDataReader(stream)?; + reader.LoadAsync(3)?.await?; + + let mut bytes: [u8; 3] = [0; 3]; + reader.ReadBytes(&mut bytes)?; + + assert!(bytes[0] == 1); + assert!(bytes[1] == 2); + assert!(bytes[2] == 3); + + Ok(()) +} + +#[test] +fn test_async_await() -> windows::core::Result<()> { + futures::executor::block_on(async_await()) +} + +#[test] +fn test_async_updates_waker() -> windows::core::Result<()> { + let mut pool = LocalPool::new(); + + let (tx, rx) = std::sync::mpsc::channel::<()>(); + + let winrt_future = ThreadPool::RunAsync(&WorkItemHandler::new(move |_| { + rx.recv().unwrap(); + Ok(()) + }))? + .into_future(); + + let task = pool + .spawner() + .spawn_with_handle(async move { + // Poll the future once on a LocalPool task + match future::select(winrt_future, future::ready(())).await { + future::Either::Left(_) => panic!("threadpool action can't finish yet"), + future::Either::Right(((), future)) => future, + } + }) + .unwrap(); + let winrt_future = pool.run_until(task); + + pool.spawner() + .spawn(async move { + // Now run the future to completion on a *different* LocalPool task. + // This will hang unless winrt_future properly updates its saved waker to the new task. + let (result, ()) = future::join(winrt_future, async { + tx.send(()).unwrap(); + }) + .await; + result.unwrap(); + }) + .unwrap(); + pool.run(); + + Ok(()) +}