diff options
author | Toni Uhlig <matzeton@googlemail.com> | 2023-09-12 12:10:47 +0200 |
---|---|---|
committer | Toni Uhlig <matzeton@googlemail.com> | 2023-09-12 12:10:47 +0200 |
commit | 85c7ff11bdaf7f6a5d230744f71c5b9b7bc9bf6f (patch) | |
tree | b915b402fb5f20ab114d9286f98be31a4987254b | |
parent | 6e8f68f653a832491d9bf05f06bfab81aea0a9cb (diff) |
Added ThreadArgs, Event, WorkItem.
* added/modified WorkQueue example
* enable additional EASTL features in user space
Signed-off-by: Toni Uhlig <matzeton@googlemail.com>
-rw-r--r-- | CRT/DriverThread.cpp | 138 | ||||
-rw-r--r-- | CRT/DriverThread.hpp | 96 | ||||
-rw-r--r-- | Makefile.deps | 17 | ||||
-rw-r--r-- | Makefile.native.inc | 11 | ||||
-rw-r--r-- | examples/dpp-example-cplusplus-EASTL.cpp | 1 | ||||
-rw-r--r-- | examples/dpp-example-cplusplus.cpp | 93 |
6 files changed, 252 insertions, 104 deletions
diff --git a/CRT/DriverThread.cpp b/CRT/DriverThread.cpp index ecedc48..3a77dc8 100644 --- a/CRT/DriverThread.cpp +++ b/CRT/DriverThread.cpp @@ -1,5 +1,21 @@ #include <DriverThread.hpp> +class WorkQueueArgs : public DriverThread::ThreadArgs +{ + friend class WorkQueue; + +public: + WorkQueueArgs(DriverThread::WorkQueue * wq) : m_wq(wq){}; + WorkQueueArgs(const WorkQueueArgs &) = delete; + DriverThread::WorkQueue * getWorkQueue() + { + return m_wq; + } + +private: + DriverThread::WorkQueue * m_wq; +}; + // Thread DriverThread::Thread::Thread(void) @@ -13,13 +29,17 @@ DriverThread::Thread::~Thread(void) extern "C" void InterceptorThreadRoutine(PVOID threadContext) { + NTSTATUS threadReturn; DriverThread::Thread * self = (DriverThread::Thread *)threadContext; self->m_threadId = PsGetCurrentThreadId(); - PsTerminateSystemThread(self->m_routine(self->m_threadContext)); + threadReturn = self->m_routine(self->m_threadContext); + self->m_threadId = nullptr; + self->m_threadContext = nullptr; + PsTerminateSystemThread(threadReturn); } -NTSTATUS DriverThread::Thread::Start(threadRoutine_t routine, PVOID threadContext) +NTSTATUS DriverThread::Thread::Start(ThreadRoutine routine, eastl::shared_ptr<ThreadArgs> args) { HANDLE threadHandle; NTSTATUS status; @@ -31,7 +51,7 @@ NTSTATUS DriverThread::Thread::Start(threadRoutine_t routine, PVOID threadContex } m_routine = routine; - m_threadContext = threadContext; + m_threadContext = args; status = PsCreateSystemThread(&threadHandle, (ACCESS_MASK)0, NULL, (HANDLE)0, NULL, InterceptorThreadRoutine, this); if (!NT_SUCCESS(status)) @@ -72,11 +92,6 @@ NTSTATUS DriverThread::Thread::WaitForTermination(LONGLONG timeout) return status; } -HANDLE DriverThread::Thread::GetThreadId(void) -{ - return m_threadId; -} - // Spinlock DriverThread::Spinlock::Spinlock(void) @@ -117,6 +132,24 @@ LONG DriverThread::Semaphore::Release(LONG adjustment) return KeReleaseSemaphore(&m_semaphore, 0, adjustment, FALSE); } +// Event + +DriverThread::Event::Event() +{ + KeInitializeEvent(&m_event, NotificationEvent, FALSE); +} + +NTSTATUS DriverThread::Event::Wait(LONGLONG timeout) +{ + LARGE_INTEGER li_timeout = {.QuadPart = timeout}; + return KeWaitForSingleObject(&m_event, Executive, KernelMode, FALSE, (timeout == 0 ? NULL : &li_timeout)); +} + +NTSTATUS DriverThread::Event::Notify() +{ + return KeSetEvent(&m_event, 0, FALSE); +} + // Mutex DriverThread::Mutex::Mutex(void) @@ -151,11 +184,9 @@ DriverThread::LockGuard::~LockGuard(void) // WorkQueue -DriverThread::WorkQueue::WorkQueue(void) : m_worker() +DriverThread::WorkQueue::WorkQueue(void) + : m_mutex(), m_queue(), m_wakeEvent(), m_stopWorker(false), m_worker(), m_workerRoutine(nullptr) { - InitializeSListHead(&m_work); - KeInitializeEvent(&m_wakeEvent, SynchronizationEvent, FALSE); - m_stopWorker = FALSE; } DriverThread::WorkQueue::~WorkQueue(void) @@ -163,14 +194,15 @@ DriverThread::WorkQueue::~WorkQueue(void) Stop(); } -NTSTATUS DriverThread::WorkQueue::Start(workerRoutine_t workerRoutine) +NTSTATUS DriverThread::WorkQueue::Start(WorkerRoutine routine) { NTSTATUS status; { LockGuard lock(m_mutex); - m_workerRoutine = workerRoutine; - status = m_worker.Start(WorkerInterceptorRoutine, this); + m_workerRoutine = routine; + auto wqa = eastl::make_shared<WorkQueueArgs>(this); + status = m_worker.Start(WorkerInterceptorRoutine, wqa); } if (!NT_SUCCESS(status) && status != STATUS_UNSUCCESSFUL) @@ -181,73 +213,83 @@ NTSTATUS DriverThread::WorkQueue::Start(workerRoutine_t workerRoutine) return status; } -void DriverThread::WorkQueue::Stop(void) +void DriverThread::WorkQueue::Stop(bool wait) { LockGuard lock(m_mutex); - if (m_stopWorker == TRUE) + if (m_stopWorker == true) { return; } - m_stopWorker = TRUE; - KeSetEvent(&m_wakeEvent, 0, FALSE); + m_stopWorker = true; + m_wakeEvent.Notify(); + if (wait) + { + m_worker.WaitForTermination(); + } +} + +void DriverThread::WorkQueue::Enqueue(WorkItem & item) +{ + { + LockGuard lock(m_mutex); + m_queue.emplace_back(item); + } + m_wakeEvent.Notify(); } -void DriverThread::WorkQueue::Enqueue(WorkItem * item) +void DriverThread::WorkQueue::Enqueue(eastl::deque<WorkItem> & items) { - if (InterlockedPushEntrySList(&m_work, &item->QueueEntry) == NULL) { - // Work queue was empty. So, signal the work queue event in case the - // worker thread is waiting on the event for more operations. - KeSetEvent(&m_wakeEvent, 0, FALSE); + LockGuard lock(m_mutex); + m_queue.insert(m_queue.end(), items.begin(), items.end()); } + m_wakeEvent.Notify(); } -NTSTATUS DriverThread::WorkQueue::WorkerInterceptorRoutine(PVOID workerContext) +NTSTATUS DriverThread::WorkQueue::WorkerInterceptorRoutine(eastl::shared_ptr<ThreadArgs> args) { - DriverThread::WorkQueue * wq = (DriverThread::WorkQueue *)workerContext; - PSLIST_ENTRY listEntryRev, listEntry, next; + auto wqa = eastl::static_pointer_cast<WorkQueueArgs>(args); + WorkQueue * wq = wqa->getWorkQueue(); PAGED_CODE(); for (;;) { - // Flush all the queued operations into a local list - listEntryRev = InterlockedFlushSList(&wq->m_work); + eastl::deque<WorkItem> doQueue; + std::size_t nItems; - if (listEntryRev == NULL) { + LockGuard lock(wq->m_mutex); + nItems = wq->m_queue.size(); + } - // There's no work to do. If we are allowed to stop, then stop. - if (wq->m_stopWorker == TRUE) + if (nItems == 0) + { + if (wq->m_stopWorker == true) { break; } - // Otherwise, wait for more operations to be enqueued. - KeWaitForSingleObject(&wq->m_wakeEvent, Executive, KernelMode, FALSE, 0); + wq->m_wakeEvent.Wait(); continue; } - // Need to reverse the flushed list in order to preserve the FIFO order - listEntry = NULL; - while (listEntryRev != NULL) { - next = listEntryRev->Next; - listEntryRev->Next = listEntry; - listEntry = listEntryRev; - listEntryRev = next; + LockGuard lock(wq->m_mutex); + doQueue = wq->m_queue; + wq->m_queue.clear(); } - // Now process the correctly ordered list of operations one by one - while (listEntry) + while (doQueue.size() > 0) { - PSLIST_ENTRY arg = listEntry; - listEntry = listEntry->Next; - DriverThread::WorkItem * wi = CONTAINING_RECORD(arg, DriverThread::WorkItem, QueueEntry); - if (wq->m_workerRoutine(wi) != STATUS_SUCCESS) + WorkItem & item = doQueue.front(); + + if (wq->m_workerRoutine(item) != STATUS_SUCCESS) { - wq->m_stopWorker = TRUE; + wq->m_stopWorker = true; } + + doQueue.pop_front(); } } diff --git a/CRT/DriverThread.hpp b/CRT/DriverThread.hpp index d863a66..ac2eddd 100644 --- a/CRT/DriverThread.hpp +++ b/CRT/DriverThread.hpp @@ -3,14 +3,14 @@ #include <ntddk.h> +#include <EASTL/deque.h> +#include <EASTL/functional.h> +#include <EASTL/shared_ptr.h> + extern "C" void InterceptorThreadRoutine(PVOID threadContext); namespace DriverThread { -class WorkItem; -typedef NTSTATUS (*threadRoutine_t)(PVOID); -typedef NTSTATUS (*workerRoutine_t)(WorkItem * item); - class Mutex { public: @@ -36,14 +36,36 @@ private: Mutex m_Lock; }; +class ThreadArgs : public virtual eastl::enable_shared_from_this<ThreadArgs> +{ +public: + ThreadArgs(void) + { + } + ThreadArgs(const ThreadArgs &) = delete; + virtual ~ThreadArgs(void) + { + } +}; + +using ThreadRoutine = eastl::function<NTSTATUS(eastl::shared_ptr<ThreadArgs> args)>; + class Thread { public: Thread(void); + Thread(const Thread &) = delete; ~Thread(void); - NTSTATUS Start(threadRoutine_t routine, PVOID threadContext); + NTSTATUS Start(ThreadRoutine routine, eastl::shared_ptr<ThreadArgs> args); NTSTATUS WaitForTermination(LONGLONG timeout = 0); - HANDLE GetThreadId(void); + HANDLE GetThreadId(void) + { + return m_threadId; + } + bool isRunning(void) + { + return GetThreadId() != nullptr; + } private: friend void ::InterceptorThreadRoutine(PVOID threadContext); @@ -51,8 +73,8 @@ private: HANDLE m_threadId = nullptr; PETHREAD m_threadObject = nullptr; Mutex m_mutex; - threadRoutine_t m_routine; - PVOID m_threadContext; + ThreadRoutine m_routine; + eastl::shared_ptr<ThreadArgs> m_threadContext; }; class Spinlock @@ -79,31 +101,65 @@ private: KSEMAPHORE m_semaphore; }; -class WorkItem +class Event +{ +public: + Event(); + NTSTATUS Wait(LONGLONG timeout = 0); + NTSTATUS Notify(); + +private: + KEVENT m_event; +}; + +class WorkItem final { + friend class WorkQueue; + public: - SLIST_ENTRY QueueEntry; - PSLIST_ENTRY WorkListEntry; + WorkItem(const eastl::shared_ptr<void> & user) : m_user(std::move(user)) + { + } + virtual ~WorkItem(void) + { + } + template <class T> + eastl::shared_ptr<T> Get(void) + { + return eastl::static_pointer_cast<T>(m_user); + } + template <class T> + void Get(eastl::shared_ptr<T> & dest) + { + dest = eastl::static_pointer_cast<T>(m_user); + } + +private: + eastl::shared_ptr<void> m_user; }; -class WorkQueue +using WorkerRoutine = eastl::function<NTSTATUS(WorkItem & item)>; + +class WorkQueue final { public: WorkQueue(void); + WorkQueue(const WorkQueue &) = delete; ~WorkQueue(void); - NTSTATUS Start(workerRoutine_t workerRoutine); - void Stop(void); - void Enqueue(WorkItem * item); + NTSTATUS Start(WorkerRoutine routine); + void Stop(bool wait = true); + void Enqueue(WorkItem & item); + void Enqueue(eastl::deque<WorkItem> & items); private: Mutex m_mutex; - SLIST_HEADER m_work; - KEVENT m_wakeEvent; - BOOLEAN m_stopWorker; // Work LIST must be empty and StopWorker TRUE to be able to stop! + eastl::deque<WorkItem> m_queue; + Event m_wakeEvent; + bool m_stopWorker; // Work LIST must be empty and StopWorker TRUE to be able to stop! Thread m_worker; - workerRoutine_t m_workerRoutine; + WorkerRoutine m_workerRoutine; - static NTSTATUS WorkerInterceptorRoutine(PVOID workerContext); + static NTSTATUS WorkerInterceptorRoutine(eastl::shared_ptr<ThreadArgs> args); }; }; // namespace DriverThread diff --git a/Makefile.deps b/Makefile.deps index 0010a52..cfe3393 100644 --- a/Makefile.deps +++ b/Makefile.deps @@ -20,10 +20,8 @@ EASTL_DEPS := $(wildcard $(DPP_ROOT)/EASTL/source/*.cpp) $(wildcard $(DPP_ROOT)/ all: deps -$(LIBCRT_BUILD_DIR): - $(Q)mkdir -p '$(LIBCRT_BUILD_DIR)' - -$(LIBCRT_STATIC_LIB): $(LIBCRT_BUILD_DIR) $(LIBCRT_OBJECTS) +$(LIBCRT_STATIC_LIB): $(LIBCRT_OBJECTS) + $(Q)test -d '$(LIBCRT_BUILD_DIR)' || mkdir -p '$(LIBCRT_BUILD_DIR)' ifneq ($(Q),@) $(Q)$(AR) -rsv '$@' $(LIBCRT_OBJECTS) else @@ -31,7 +29,8 @@ else endif @echo 'AR $@' -$(LIBCXXRT_STATIC_LIB): $(LIBCRT_BUILD_DIR) $(LIBCXXRT_OBJECTS) +$(LIBCXXRT_STATIC_LIB): $(LIBCXXRT_OBJECTS) + $(Q)test -d '$(LIBCRT_BUILD_DIR)' || mkdir -p '$(LIBCRT_BUILD_DIR)' ifneq ($(Q),@) $(Q)$(AR) -rsv '$@' $(LIBCXXRT_OBJECTS) else @@ -39,7 +38,8 @@ else endif @echo 'AR $@' -$(LIBUSERCRT_STATIC_LIB): $(LIBCRT_BUILD_DIR) $(LIBUSERCRT_OBJECTS) +$(LIBUSERCRT_STATIC_LIB): $(LIBUSERCRT_OBJECTS) + $(Q)test -d '$(LIBCRT_BUILD_DIR)' || mkdir -p '$(LIBCRT_BUILD_DIR)' ifneq ($(Q),@) $(Q)$(AR) -rsv '$@' $(LIBUSERCRT_OBJECTS) else @@ -137,7 +137,7 @@ deps-build: \ deps: deps-print-local-notice deps-build $(EASTL_STATIC_LIB): $(CXX) $(EASTL_DEPS) - mkdir -p $(EASTL_BUILDDIR) + $(Q)test -d '$(EASTL_BUILDDIR)' || mkdir -p $(EASTL_BUILDDIR) cd $(EASTL_BUILDDIR) && \ $(CMAKE) ../EASTL \ -DCMAKE_CXX_COMPILER="$(realpath $(CXX))" \ @@ -185,11 +185,10 @@ deps-build: \ deps: deps-build $(EASTL_STATIC_LIB): $(CXX) $(EASTL_DEPS) - mkdir -p $(EASTL_BUILDDIR) + $(Q)test -d '$(EASTL_BUILDDIR)' || mkdir -p $(EASTL_BUILDDIR) cd $(EASTL_BUILDDIR) && \ $(CMAKE) ../EASTL \ -DCMAKE_CXX_COMPILER="$(realpath $(CXX))" \ - -DCMAKE_SYSTEM_NAME="Linux" \ -DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=ONLY \ -DCMAKE_CXX_FLAGS='$(CFLAGS) $(CXXFLAGS) $(EASTL_CXXFLAGS)' && \ $(MAKE) $(CMAKE_Q) diff --git a/Makefile.native.inc b/Makefile.native.inc index 78b0f1c..106c9e3 100644 --- a/Makefile.native.inc +++ b/Makefile.native.inc @@ -26,13 +26,12 @@ CFLAGS := -Wall -Wextra -Wno-sign-compare -Wno-strict-aliasing -Wno-c++20-compat ifneq ($(WERROR),) CFLAGS += -Werror endif -CXXFLAGS := -fno-exceptions -fno-rtti -fuse-cxa-atexit +CXXFLAGS := -fuse-cxa-atexit EASTL_CXXFLAGS := -I$(DPP_ROOT)/EASTL/include -I$(DPP_ROOT)/EASTL/test/packages/EABase/include/Common \ - -DEASTL_THREAD_SUPPORT_AVAILABLE=0 \ - -DEASTL_EXCEPTIONS_ENABLED=0 \ - -DEASTL_ASSERT_ENABLED=0 \ - -DEA_COMPILER_NO_EXCEPTIONS=1 \ - -DEA_COMPILER_MANAGED_CPP=1 \ + -DEASTL_THREAD_SUPPORT_AVAILABLE=1 \ + -DEASTL_EXCEPTIONS_ENABLED=1 \ + -DEASTL_ASSERT_ENABLED=1 \ + -DEA_PLATFORM_POSIX=1 \ -Wno-unknown-pragmas \ -Wno-deprecated-copy USER_LDFLAGS := -Wl,--gc-sections diff --git a/examples/dpp-example-cplusplus-EASTL.cpp b/examples/dpp-example-cplusplus-EASTL.cpp index 618bc68..c28806e 100644 --- a/examples/dpp-example-cplusplus-EASTL.cpp +++ b/examples/dpp-example-cplusplus-EASTL.cpp @@ -3,6 +3,7 @@ #endif #include <cstdint> +#include <stdexcept> #include <EASTL/functional.h> #include <EASTL/hash_map.h> diff --git a/examples/dpp-example-cplusplus.cpp b/examples/dpp-example-cplusplus.cpp index cbdd7c6..980922d 100644 --- a/examples/dpp-example-cplusplus.cpp +++ b/examples/dpp-example-cplusplus.cpp @@ -56,29 +56,39 @@ private: unsigned int some_value = 0; }; -class MyWorkItem : public DriverThread::WorkItem +class MyWorkItem { public: - UINT32 counter; + MyWorkItem() + { + DbgPrint("MyWorkItem ctor\n"); + } + ~MyWorkItem() + { + DbgPrint("MyWorkItem dtor\n"); + } + + UINT32 counter = 0, another_counter = 0; }; -static DriverThread::WorkQueue work_queue; +static DriverThread::WorkQueue global_work_queue; static DerivedWithCDtor some_static(0xDEADC0DE); -struct threadContext +class threadContext : public DriverThread::ThreadArgs { +public: DriverThread::Semaphore sem; DriverThread::Thread dth; }; -static NTSTATUS threadRoutine(PVOID threadContext) +static NTSTATUS threadRoutine(eastl::shared_ptr<DriverThread::ThreadArgs> args) { - DbgPrint("ThreadRoutine %p, ThreadContext: %p\n", threadRoutine, threadContext); + DbgPrint("ThreadRoutine %p, ThreadContext: %p\n", threadRoutine, args); + auto ctx = eastl::static_pointer_cast<threadContext>(args); for (size_t i = 3; i > 0; --i) { - DbgPrint("ThreadLoop: %zu\n", i); + DbgPrint("ThreadLoop: %zu (isRunning: %u)\n", i, ctx->dth.isRunning()); } - struct threadContext * const ctx = (struct threadContext *)threadContext; DbgPrint("Fin. ThreadId: %p\n", ctx->dth.GetThreadId()); ctx->sem.Release(); DbgPrint("Thread WaitForTermination: 0x%X\n", ctx->dth.WaitForTermination()); // must return STATUS_UNSUCCESSFUL; @@ -93,34 +103,75 @@ static void test_cplusplus(void) Derived d; d.doSmth(); - struct threadContext ctx; - ctx.dth.Start(threadRoutine, (PVOID)&ctx); - ctx.sem.Wait(); + auto ctx = eastl::make_shared<threadContext>(); + ctx->dth.Start(threadRoutine, ctx); + ctx->sem.Wait(); DbgPrint("MainThread semaphore signaled.\n"); - ctx.dth.WaitForTermination(); - ctx.dth.WaitForTermination(); + ctx->dth.WaitForTermination(); + ctx->dth.WaitForTermination(); DbgPrint("MainThread EOF\n"); - MyWorkItem * wi = new MyWorkItem(); - wi->counter = 3; - work_queue.Enqueue(wi); + DriverThread::WorkQueue work_queue; + DbgPrint("WorkQueue test.\n"); + { + DriverThread::WorkItem wi(eastl::make_shared<MyWorkItem>()); + + auto user = wi.Get<MyWorkItem>(); + user->counter = 3; + user->another_counter = 1; + work_queue.Enqueue(wi); + global_work_queue.Enqueue(wi); + } + { + eastl::deque<DriverThread::WorkItem> items; + + for (size_t i = 1; i < 3; ++i) + { + DriverThread::WorkItem wi(eastl::make_shared<MyWorkItem>()); + + auto user = wi.Get<MyWorkItem>(); + user->counter = 3 + i; + user->another_counter = 1 + i; + items.emplace_back(wi); + } + + work_queue.Enqueue(items); + global_work_queue.Enqueue(items); + } + work_queue.Start( - [](DriverThread::WorkItem * item) + [](DriverThread::WorkItem & item) { - MyWorkItem * wi = - reinterpret_cast<MyWorkItem *>(CONTAINING_RECORD(item, DriverThread::WorkItem, QueueEntry)); + DbgPrint("Worker callback.\n"); + eastl::shared_ptr<MyWorkItem> wi; + item.Get<MyWorkItem>(wi); while (wi->counter-- > 0) { DbgPrint("WorkItem Counter: %u\n", wi->counter); } - DbgPrint("Worker finished.\n"); - delete item; + DbgPrint("Worker finished.\n"); return STATUS_SUCCESS; }); work_queue.Stop(); + global_work_queue.Start( + [](DriverThread::WorkItem & item) + { + DbgPrint("Global Worker callback.\n"); + + eastl::shared_ptr<MyWorkItem> wi; + item.Get<MyWorkItem>(wi); + while (wi->another_counter-- > 0) + { + DbgPrint("WorkItem Another Counter: %u\n", wi->another_counter); + } + + DbgPrint("Global Worker finished.\n"); + return STATUS_SUCCESS; + }); + some_static.doSmth(); } |