diff --git a/inc/cxplat.hpp b/inc/cxplat.hpp new file mode 100644 index 0000000..5347bdf --- /dev/null +++ b/inc/cxplat.hpp @@ -0,0 +1,90 @@ +/*++ + + Copyright (c) Microsoft Corporation. + Licensed under the MIT License. + +Abstract: + + C++ header only cxplat wrappers + +--*/ + +#ifdef _WIN32 +#pragma once +#endif + +#ifndef CXPLATCPP_H +#define CXPLATCPP_H + +#include "cxplat.h" +#include "cxplat_sal_stub.h" + +typedef void* CxPlatCallback( + _Inout_ void* Context +); + +class CxPlatAsync { +private: + struct CxPlatAsyncContext { + void* UserContext; + CxPlatCallback *UserCallback; + void* ReturnValue; + }; + + static CXPLAT_THREAD_CALLBACK(CxPlatAsyncWrapperCallback, Context) + { + struct CxPlatAsyncContext* AsyncContext = (struct CxPlatAsyncContext*)Context; + AsyncContext->ReturnValue = AsyncContext->UserCallback(AsyncContext->UserContext); + CXPLAT_THREAD_RETURN(0); + } + + CXPLAT_THREAD Thread {0}; + CXPLAT_THREAD_CONFIG ThreadConfig {0}; + struct CxPlatAsyncContext AsyncContext {0}; + bool Initialized = false; + bool ThreadCompleted = false; +public: + CxPlatAsync(CxPlatCallback Callback, void* UserContext = nullptr) noexcept { + AsyncContext.UserContext = UserContext; + AsyncContext.UserCallback = Callback; + AsyncContext.ReturnValue = nullptr; + + ThreadConfig.Name = "CxPlatAsync"; + ThreadConfig.Callback = CxPlatAsyncWrapperCallback; + ThreadConfig.Context = &AsyncContext; + if (CxPlatThreadCreate(&ThreadConfig, &Thread) != 0) { + Initialized = false; + return; + } + Initialized = true; + } + ~CxPlatAsync() noexcept { + if (Initialized) { + if (!ThreadCompleted) { + CxPlatThreadWaitForever(&Thread); + } + CxPlatThreadDelete(&Thread); + } + } + + void Wait() noexcept { + if (Initialized) { + CxPlatThreadWaitForever(&Thread); + ThreadCompleted = true; + } + } + +#if defined(CX_PLATFORM_WINUSER) || defined(CX_PLATFORM_WINKERNEL) + void WaitFor(uint32_t TimeoutMs) noexcept { + if (Initialized) { + CxPlatThreadWaitWithTimeout(&Thread, TimeoutMs); + } + } +#endif + + void* Get() noexcept { + return AsyncContext.ReturnValue; + } +}; + +#endif diff --git a/src/test/CxPlatTests.h b/src/test/CxPlatTests.h index 7c1ff50..b63f1c0 100644 --- a/src/test/CxPlatTests.h +++ b/src/test/CxPlatTests.h @@ -54,6 +54,7 @@ void CxPlatTestProcBasic(); // void CxPlatTestThreadBasic(); +void CxPlatTestThreadAsync(); #if defined(CX_PLATFORM_WINUSER) || defined(CX_PLATFORM_WINKERNEL) void CxPlatTestThreadWaitTimeout(); #endif @@ -141,4 +142,7 @@ static const GUID CXPLAT_TEST_DEVICE_INSTANCE = #define IOCTL_CXPLAT_RUN_VECTOR_BASIC \ CXPLAT_CTL_CODE(8, METHOD_BUFFERED, FILE_WRITE_DATA) -#define CXPLAT_MAX_IOCTL_FUNC_CODE 8 +#define IOCTL_CXPLAT_RUN_THREAD_ASYNC \ + CXPLAT_CTL_CODE(9, METHOD_BUFFERED, FILE_WRITE_DATA) + +#define CXPLAT_MAX_IOCTL_FUNC_CODE 9 diff --git a/src/test/bin/cxplat_gtest.cpp b/src/test/bin/cxplat_gtest.cpp index badec4f..64c2973 100644 --- a/src/test/bin/cxplat_gtest.cpp +++ b/src/test/bin/cxplat_gtest.cpp @@ -170,6 +170,15 @@ TEST(ThreadSuite, Basic) { } } +TEST(ThreadSuite, Async) { + TestLogger Logger("CxPlatTestThreadAsync"); + if (TestingKernelMode) { + ASSERT_TRUE(DriverClient.Run(IOCTL_CXPLAT_RUN_THREAD_ASYNC)); + } else { + CxPlatTestThreadAsync(); + } +} + #if defined(CX_PLATFORM_WINUSER) || defined(CX_PLATFORM_WINKERNEL) TEST(ThreadSuite, WithTimeout) { TestLogger Logger("CxPlatTestThreadWaitTimeout"); diff --git a/src/test/bin/winkernel/control.cpp b/src/test/bin/winkernel/control.cpp index cc6da2c..a7a4ce6 100644 --- a/src/test/bin/winkernel/control.cpp +++ b/src/test/bin/winkernel/control.cpp @@ -366,6 +366,7 @@ size_t CXPLAT_IOCTL_BUFFER_SIZES[] = 0, 0, 0, + 0, }; static_assert( @@ -500,6 +501,10 @@ CxPlatTestCtlEvtIoDeviceControl( CxPlatTestCtlRun(CxPlatTestThreadBasic()); break; + case IOCTL_CXPLAT_RUN_THREAD_ASYNC: + CxPlatTestCtlRun(CxPlatTestThreadAsync()); + break; + case IOCTL_CXPLAT_RUN_THREAD_WAIT_TIMEOUT: CxPlatTestCtlRun(CxPlatTestThreadWaitTimeout()); break; diff --git a/src/test/lib/ThreadTest.cpp b/src/test/lib/ThreadTest.cpp index d51e0cb..a9734df 100644 --- a/src/test/lib/ThreadTest.cpp +++ b/src/test/lib/ThreadTest.cpp @@ -61,6 +61,55 @@ void CxPlatTestThreadBasic() CxPlatThreadDelete(&Thread); } +void CxPlatTestThreadAsync() +{ + { + CxPlatAsync Async([](void*) -> void* { + return nullptr; + }); + } + + { + struct TempCtx { + uint32_t Value; + } TempCtx = { 0 }; + CxPlatAsync Async([](void* Ctx) -> void* { + struct TempCtx* TempCtx = (struct TempCtx*)Ctx; + TempCtx->Value = 123; + return nullptr; + }, &TempCtx); + Async.Wait(); + TEST_EQUAL(123, TempCtx.Value); + } + + { + CXPLAT_THREAD_ID ThreadId = INITIAL_THREAD_ID_VALUE; + CxPlatAsync Async([](void* Ctx) -> void* { + CXPLAT_THREAD_ID* ThreadId = (CXPLAT_THREAD_ID*)Ctx; + *ThreadId = CxPlatCurThreadID(); + return (void*)(intptr_t)(*ThreadId); + }, &ThreadId); + + Async.Wait(); + TEST_EQUAL((CXPLAT_THREAD_ID)((intptr_t)Async.Get()), ThreadId); + TEST_NOT_EQUAL(INITIAL_THREAD_ID_VALUE, ThreadId); + } + +#if defined(CX_PLATFORM_WINUSER) || defined(CX_PLATFORM_WINKERNEL) + { + CxPlatAsync Async([](void*) -> void* { + CxPlatSleep(2000); + return (void*)(intptr_t)(0xdeadbeaf); + }); + + Async.WaitFor(50); + TEST_EQUAL(Async.Get(), nullptr); + Async.Wait(); + TEST_NOT_EQUAL(Async.Get(), nullptr); + } +#endif +} + #if defined(CX_PLATFORM_WINUSER) || defined(CX_PLATFORM_WINKERNEL) void CxPlatTestThreadWaitTimeout() { diff --git a/src/test/lib/precomp.h b/src/test/lib/precomp.h index 1fb7648..6151454 100644 --- a/src/test/lib/precomp.h +++ b/src/test/lib/precomp.h @@ -45,6 +45,7 @@ #include "TestAbstractionLayer.h" #include "cxplatvector.h" +#include "cxplat.hpp" #if defined(_ARM64_) || defined(_ARM64EC_) #pragma optimize("", off)