From b159a2e2e4b346370ed1e2395a4c42eb35925313 Mon Sep 17 00:00:00 2001 From: Nick Grifka Date: Fri, 24 May 2024 14:51:24 -0700 Subject: [PATCH] Add thread wait with timeout API (#28) Adding this functionality to Windows platforms since the XDP test code waits on threads with a timeout. Added test case. --- inc/cxplat_posix.h | 2 +- inc/cxplat_winkernel.h | 17 ++++++++- inc/cxplat_winuser.h | 12 +++++- src/lib/cxplat_posix.c | 2 +- src/test/CxPlatTests.h | 8 +++- src/test/bin/cxplat_gtest.cpp | 11 ++++++ src/test/bin/winkernel/control.cpp | 5 +++ src/test/lib/ThreadTest.cpp | 60 ++++++++++++++++++++++++++---- 8 files changed, 105 insertions(+), 12 deletions(-) diff --git a/inc/cxplat_posix.h b/inc/cxplat_posix.h index 05592fe..203179c 100644 --- a/inc/cxplat_posix.h +++ b/inc/cxplat_posix.h @@ -500,7 +500,7 @@ CxPlatThreadDelete( ); void -CxPlatThreadWait( +CxPlatThreadWaitForever( _Inout_ CXPLAT_THREAD* Thread ); diff --git a/inc/cxplat_winkernel.h b/inc/cxplat_winkernel.h index a76a87e..5bfcc55 100644 --- a/inc/cxplat_winkernel.h +++ b/inc/cxplat_winkernel.h @@ -389,6 +389,19 @@ typedef struct _ETHREAD *CXPLAT_THREAD; #define CXPLAT_THREAD_RETURN(Status) PsTerminateSystemThread(Status) +inline +NTSTATUS +CxPlatInternalThreadWaitWithTimeout( + _In_ CXPLAT_THREAD* Thread, + _In_ uint32_t TimeoutMs + ) +{ + LARGE_INTEGER Timeout100Ns; + CXPLAT_DBG_ASSERT(TimeoutMs != UINT32_MAX); + Timeout100Ns.QuadPart = -1 * UInt32x32To64(TimeoutMs, 10000); + return KeWaitForSingleObject(*(Thread), Executive, KernelMode, FALSE, &Timeout100Ns); +} + inline CXPLAT_STATUS CxPlatThreadCreate( @@ -520,13 +533,15 @@ CxPlatThreadCreate( return Status; } #define CxPlatThreadDelete(Thread) ObDereferenceObject(*(Thread)) -#define CxPlatThreadWait(Thread) \ +#define CxPlatThreadWaitForever(Thread) \ KeWaitForSingleObject( \ *(Thread), \ Executive, \ KernelMode, \ FALSE, \ NULL) +#define CxPlatThreadWaitWithTimeout(Thread, TimeoutMs) \ + (STATUS_SUCCESS == CxPlatInternalThreadWaitWithTimeout(Thread, TimeoutMs)) typedef ULONG_PTR CXPLAT_THREAD_ID; #define CxPlatCurThreadID() ((CXPLAT_THREAD_ID)PsGetCurrentThreadId()) diff --git a/inc/cxplat_winuser.h b/inc/cxplat_winuser.h index 6c025fc..29986de 100644 --- a/inc/cxplat_winuser.h +++ b/inc/cxplat_winuser.h @@ -538,7 +538,17 @@ CxPlatThreadCreate( return CXPLAT_STATUS_SUCCESS; } #define CxPlatThreadDelete(Thread) CxPlatCloseHandle(*(Thread)) -#define CxPlatThreadWait(Thread) WaitForSingleObject(*(Thread), INFINITE) +#define CxPlatThreadWaitForever(Thread) WaitForSingleObject(*(Thread), INFINITE) +inline +BOOLEAN +CxPlatThreadWaitWithTimeout( + _In_ CXPLAT_THREAD* Thread, + _In_ uint32_t TimeoutMs + ) +{ + CXPLAT_DBG_ASSERT(TimeoutMs != UINT32_MAX); + return WAIT_OBJECT_0 == WaitForSingleObject(*Thread, TimeoutMs); +} typedef uint32_t CXPLAT_THREAD_ID; #define CxPlatCurThreadID() GetCurrentThreadId() diff --git a/src/lib/cxplat_posix.c b/src/lib/cxplat_posix.c index 663c9ae..f5c0843 100644 --- a/src/lib/cxplat_posix.c +++ b/src/lib/cxplat_posix.c @@ -482,7 +482,7 @@ CxPlatThreadDelete( } void -CxPlatThreadWait( +CxPlatThreadWaitForever( _Inout_ CXPLAT_THREAD* Thread ) { diff --git a/src/test/CxPlatTests.h b/src/test/CxPlatTests.h index 66dcfdd..71d123c 100644 --- a/src/test/CxPlatTests.h +++ b/src/test/CxPlatTests.h @@ -54,6 +54,9 @@ void CxPlatTestProcBasic(); // void CxPlatTestThreadBasic(); +#if defined(CX_PLATFORM_WINUSER) || defined(CX_PLATFORM_WINKERNEL) +void CxPlatTestThreadWaitTimeout(); +#endif // // Platform Specific Functions @@ -127,4 +130,7 @@ static const GUID CXPLAT_TEST_DEVICE_INSTANCE = #define IOCTL_CXPLAT_RUN_THREAD_BASIC \ CXPLAT_CTL_CODE(6, METHOD_BUFFERED, FILE_WRITE_DATA) -#define CXPLAT_MAX_IOCTL_FUNC_CODE 6 +#define IOCTL_CXPLAT_RUN_THREAD_WAIT_TIMEOUT \ + CXPLAT_CTL_CODE(7, METHOD_BUFFERED, FILE_WRITE_DATA) + +#define CXPLAT_MAX_IOCTL_FUNC_CODE 7 diff --git a/src/test/bin/cxplat_gtest.cpp b/src/test/bin/cxplat_gtest.cpp index 6beb1cf..fc66249 100644 --- a/src/test/bin/cxplat_gtest.cpp +++ b/src/test/bin/cxplat_gtest.cpp @@ -170,6 +170,17 @@ TEST(ThreadSuite, Basic) { } } +#if defined(CX_PLATFORM_WINUSER) || defined(CX_PLATFORM_WINKERNEL) +TEST(ThreadSuite, WithTimeout) { + TestLogger Logger("CxPlatTestThreadWaitTimeout"); + if (TestingKernelMode) { + ASSERT_TRUE(DriverClient.Run(IOCTL_CXPLAT_RUN_THREAD_WAIT_TIMEOUT)); + } else { + CxPlatTestThreadWaitTimeout(); + } +} +#endif + int main(int argc, char** argv) { for (int i = 0; i < argc; ++i) { if (strcmp("--kernel", argv[i]) == 0) { diff --git a/src/test/bin/winkernel/control.cpp b/src/test/bin/winkernel/control.cpp index 0186426..9c9015f 100644 --- a/src/test/bin/winkernel/control.cpp +++ b/src/test/bin/winkernel/control.cpp @@ -364,6 +364,7 @@ size_t CXPLAT_IOCTL_BUFFER_SIZES[] = 0, 0, 0, + 0, }; static_assert( @@ -498,6 +499,10 @@ CxPlatTestCtlEvtIoDeviceControl( CxPlatTestCtlRun(CxPlatTestThreadBasic()); break; + case IOCTL_CXPLAT_RUN_THREAD_WAIT_TIMEOUT: + CxPlatTestCtlRun(CxPlatTestThreadWaitTimeout()); + break; + default: Status = STATUS_NOT_IMPLEMENTED; break; diff --git a/src/test/lib/ThreadTest.cpp b/src/test/lib/ThreadTest.cpp index 14784ee..d51e0cb 100644 --- a/src/test/lib/ThreadTest.cpp +++ b/src/test/lib/ThreadTest.cpp @@ -11,13 +11,24 @@ #include "precomp.h" -#define INITIAL_CONTEXT_VALUE ((CXPLAT_THREAD_ID)-1) +typedef struct THREAD_CONTEXT { + CXPLAT_THREAD_ID ThreadId; + uint32_t DelayMs; +} THREAD_CONTEXT; -CXPLAT_THREAD_CALLBACK(ThreadFn, Context) +#define INITIAL_THREAD_ID_VALUE ((CXPLAT_THREAD_ID)-1) + +CXPLAT_THREAD_CALLBACK(ThreadFn, Ctx) { - TEST_EQUAL_GOTO(*(CXPLAT_THREAD_ID*)Context, INITIAL_CONTEXT_VALUE); + THREAD_CONTEXT* Context = (THREAD_CONTEXT*)Ctx; + + if (Context->DelayMs > 0) { + CxPlatSleep(Context->DelayMs); + } - *(CXPLAT_THREAD_ID*)Context = CxPlatCurThreadID(); + TEST_EQUAL_GOTO(Context->ThreadId, INITIAL_THREAD_ID_VALUE); + + *(CXPLAT_THREAD_ID*)Ctx = CxPlatCurThreadID(); Failure: @@ -28,7 +39,10 @@ void CxPlatTestThreadBasic() { CXPLAT_THREAD Thread; CXPLAT_THREAD_CONFIG ThreadConfig; - CXPLAT_THREAD_ID Context = INITIAL_CONTEXT_VALUE; + THREAD_CONTEXT Context; + + Context.ThreadId = INITIAL_THREAD_ID_VALUE; + Context.DelayMs = 0; ThreadConfig.Flags = 0; ThreadConfig.IdealProcessor = 0; @@ -38,9 +52,41 @@ void CxPlatTestThreadBasic() TEST_CXPLAT(CxPlatThreadCreate(&ThreadConfig, &Thread)); - CxPlatThreadWait(&Thread); + CxPlatThreadWaitForever(&Thread); + + TEST_TRUE_GOTO(Context.ThreadId != INITIAL_THREAD_ID_VALUE); + +Failure: + + CxPlatThreadDelete(&Thread); +} + +#if defined(CX_PLATFORM_WINUSER) || defined(CX_PLATFORM_WINKERNEL) +void CxPlatTestThreadWaitTimeout() +{ + CXPLAT_THREAD Thread; + CXPLAT_THREAD_CONFIG ThreadConfig; + THREAD_CONTEXT Context; + + Context.ThreadId = INITIAL_THREAD_ID_VALUE; + Context.DelayMs = 1000; + + ThreadConfig.Flags = 0; + ThreadConfig.IdealProcessor = 0; + ThreadConfig.Name = "CxPlatTestThreadWaitTimeout"; + ThreadConfig.Callback = ThreadFn; + ThreadConfig.Context = (void*)&Context; + + TEST_CXPLAT(CxPlatThreadCreate(&ThreadConfig, &Thread)); + + TEST_FALSE_GOTO(CxPlatThreadWaitWithTimeout(&Thread, 50)); - TEST_TRUE(Context != INITIAL_CONTEXT_VALUE); + TEST_TRUE_GOTO(CxPlatThreadWaitWithTimeout(&Thread, 5000)); + + TEST_TRUE_GOTO(Context.ThreadId != INITIAL_THREAD_ID_VALUE); + +Failure: CxPlatThreadDelete(&Thread); } +#endif