Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CxPlatAsync to simplify creating threads with cxplat #49

Merged
merged 10 commits into from
Jul 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 90 additions & 0 deletions inc/cxplat.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
/*++

Copyright (c) Microsoft Corporation.
Licensed under the MIT License.

Abstract:

C++ header only cxplat wrappers

--*/
csujedihy marked this conversation as resolved.
Show resolved Hide resolved

#ifdef _WIN32
#pragma once
#endif

#ifndef CXPLATCPP_H
#define CXPLATCPP_H

#include "cxplat.h"
#include "cxplat_sal_stub.h"

typedef void* CxPlatCallback(
_Inout_ void* Context
);

class CxPlatAsync {
private:
struct CxPlatAsyncContext {
void* UserContext;
CxPlatCallback *UserCallback;
void* ReturnValue;
};

static CXPLAT_THREAD_CALLBACK(CxPlatAsyncWrapperCallback, Context)
{
struct CxPlatAsyncContext* AsyncContext = (struct CxPlatAsyncContext*)Context;
AsyncContext->ReturnValue = AsyncContext->UserCallback(AsyncContext->UserContext);
CXPLAT_THREAD_RETURN(0);
}

CXPLAT_THREAD Thread {0};
CXPLAT_THREAD_CONFIG ThreadConfig {0};
struct CxPlatAsyncContext AsyncContext {0};
bool Initialized = false;
bool ThreadCompleted = false;
public:
CxPlatAsync(CxPlatCallback Callback, void* UserContext = nullptr) noexcept {
AsyncContext.UserContext = UserContext;
AsyncContext.UserCallback = Callback;
AsyncContext.ReturnValue = nullptr;
nibanks marked this conversation as resolved.
Show resolved Hide resolved

ThreadConfig.Name = "CxPlatAsync";
ThreadConfig.Callback = CxPlatAsyncWrapperCallback;
ThreadConfig.Context = &AsyncContext;
if (CxPlatThreadCreate(&ThreadConfig, &Thread) != 0) {
Initialized = false;
return;
}
Initialized = true;
}
~CxPlatAsync() noexcept {
if (Initialized) {
if (!ThreadCompleted) {
CxPlatThreadWaitForever(&Thread);
}
CxPlatThreadDelete(&Thread);
}
}

void Wait() noexcept {
if (Initialized) {
CxPlatThreadWaitForever(&Thread);
ThreadCompleted = true;
}
}

#if defined(CX_PLATFORM_WINUSER) || defined(CX_PLATFORM_WINKERNEL)
nibanks marked this conversation as resolved.
Show resolved Hide resolved
void WaitFor(uint32_t TimeoutMs) noexcept {
if (Initialized) {
CxPlatThreadWaitWithTimeout(&Thread, TimeoutMs);
}
}
#endif

void* Get() noexcept {
return AsyncContext.ReturnValue;
}
};

#endif
6 changes: 5 additions & 1 deletion src/test/CxPlatTests.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ void CxPlatTestProcBasic();
//

void CxPlatTestThreadBasic();
void CxPlatTestThreadAsync();
#if defined(CX_PLATFORM_WINUSER) || defined(CX_PLATFORM_WINKERNEL)
void CxPlatTestThreadWaitTimeout();
#endif
Expand Down Expand Up @@ -141,4 +142,7 @@ static const GUID CXPLAT_TEST_DEVICE_INSTANCE =
#define IOCTL_CXPLAT_RUN_VECTOR_BASIC \
CXPLAT_CTL_CODE(8, METHOD_BUFFERED, FILE_WRITE_DATA)

#define CXPLAT_MAX_IOCTL_FUNC_CODE 8
#define IOCTL_CXPLAT_RUN_THREAD_ASYNC \
CXPLAT_CTL_CODE(9, METHOD_BUFFERED, FILE_WRITE_DATA)

#define CXPLAT_MAX_IOCTL_FUNC_CODE 9
9 changes: 9 additions & 0 deletions src/test/bin/cxplat_gtest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,15 @@ TEST(ThreadSuite, Basic) {
}
}

TEST(ThreadSuite, Async) {
TestLogger Logger("CxPlatTestThreadAsync");
if (TestingKernelMode) {
ASSERT_TRUE(DriverClient.Run(IOCTL_CXPLAT_RUN_THREAD_ASYNC));
} else {
CxPlatTestThreadAsync();
}
}

#if defined(CX_PLATFORM_WINUSER) || defined(CX_PLATFORM_WINKERNEL)
TEST(ThreadSuite, WithTimeout) {
TestLogger Logger("CxPlatTestThreadWaitTimeout");
Expand Down
5 changes: 5 additions & 0 deletions src/test/bin/winkernel/control.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,7 @@ size_t CXPLAT_IOCTL_BUFFER_SIZES[] =
0,
0,
0,
0,
};

static_assert(
Expand Down Expand Up @@ -500,6 +501,10 @@ CxPlatTestCtlEvtIoDeviceControl(
CxPlatTestCtlRun(CxPlatTestThreadBasic());
break;

case IOCTL_CXPLAT_RUN_THREAD_ASYNC:
CxPlatTestCtlRun(CxPlatTestThreadAsync());
break;

case IOCTL_CXPLAT_RUN_THREAD_WAIT_TIMEOUT:
CxPlatTestCtlRun(CxPlatTestThreadWaitTimeout());
break;
Expand Down
49 changes: 49 additions & 0 deletions src/test/lib/ThreadTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,55 @@ void CxPlatTestThreadBasic()
CxPlatThreadDelete(&Thread);
}

void CxPlatTestThreadAsync()
{
{
CxPlatAsync Async([](void*) -> void* {
return nullptr;
});
}

{
struct TempCtx {
uint32_t Value;
} TempCtx = { 0 };
CxPlatAsync Async([](void* Ctx) -> void* {
struct TempCtx* TempCtx = (struct TempCtx*)Ctx;
TempCtx->Value = 123;
return nullptr;
}, &TempCtx);
Async.Wait();
TEST_EQUAL(123, TempCtx.Value);
}

{
CXPLAT_THREAD_ID ThreadId = INITIAL_THREAD_ID_VALUE;
CxPlatAsync Async([](void* Ctx) -> void* {
CXPLAT_THREAD_ID* ThreadId = (CXPLAT_THREAD_ID*)Ctx;
*ThreadId = CxPlatCurThreadID();
return (void*)(intptr_t)(*ThreadId);
}, &ThreadId);

Async.Wait();
TEST_EQUAL((CXPLAT_THREAD_ID)((intptr_t)Async.Get()), ThreadId);
TEST_NOT_EQUAL(INITIAL_THREAD_ID_VALUE, ThreadId);
}

#if defined(CX_PLATFORM_WINUSER) || defined(CX_PLATFORM_WINKERNEL)
{
CxPlatAsync Async([](void*) -> void* {
CxPlatSleep(2000);
return (void*)(intptr_t)(0xdeadbeaf);
});

Async.WaitFor(50);
TEST_EQUAL(Async.Get(), nullptr);
Async.Wait();
TEST_NOT_EQUAL(Async.Get(), nullptr);
}
#endif
}

#if defined(CX_PLATFORM_WINUSER) || defined(CX_PLATFORM_WINKERNEL)
void CxPlatTestThreadWaitTimeout()
{
Expand Down
1 change: 1 addition & 0 deletions src/test/lib/precomp.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
#include "TestAbstractionLayer.h"

#include "cxplatvector.h"
#include "cxplat.hpp"

#if defined(_ARM64_) || defined(_ARM64EC_)
#pragma optimize("", off)
Expand Down