diff options
author | Toni Uhlig <matzeton@googlemail.com> | 2023-09-12 12:10:47 +0200 |
---|---|---|
committer | Toni Uhlig <matzeton@googlemail.com> | 2023-09-12 12:10:47 +0200 |
commit | 85c7ff11bdaf7f6a5d230744f71c5b9b7bc9bf6f (patch) | |
tree | b915b402fb5f20ab114d9286f98be31a4987254b /CRT | |
parent | 6e8f68f653a832491d9bf05f06bfab81aea0a9cb (diff) |
Added ThreadArgs, Event, WorkItem.
* added/modified WorkQueue example
* enable additional EASTL features in user space
Signed-off-by: Toni Uhlig <matzeton@googlemail.com>
Diffstat (limited to 'CRT')
-rw-r--r-- | CRT/DriverThread.cpp | 138 | ||||
-rw-r--r-- | CRT/DriverThread.hpp | 96 |
2 files changed, 166 insertions, 68 deletions
diff --git a/CRT/DriverThread.cpp b/CRT/DriverThread.cpp index ecedc48..3a77dc8 100644 --- a/CRT/DriverThread.cpp +++ b/CRT/DriverThread.cpp @@ -1,5 +1,21 @@ #include <DriverThread.hpp> +class WorkQueueArgs : public DriverThread::ThreadArgs +{ + friend class WorkQueue; + +public: + WorkQueueArgs(DriverThread::WorkQueue * wq) : m_wq(wq){}; + WorkQueueArgs(const WorkQueueArgs &) = delete; + DriverThread::WorkQueue * getWorkQueue() + { + return m_wq; + } + +private: + DriverThread::WorkQueue * m_wq; +}; + // Thread DriverThread::Thread::Thread(void) @@ -13,13 +29,17 @@ DriverThread::Thread::~Thread(void) extern "C" void InterceptorThreadRoutine(PVOID threadContext) { + NTSTATUS threadReturn; DriverThread::Thread * self = (DriverThread::Thread *)threadContext; self->m_threadId = PsGetCurrentThreadId(); - PsTerminateSystemThread(self->m_routine(self->m_threadContext)); + threadReturn = self->m_routine(self->m_threadContext); + self->m_threadId = nullptr; + self->m_threadContext = nullptr; + PsTerminateSystemThread(threadReturn); } -NTSTATUS DriverThread::Thread::Start(threadRoutine_t routine, PVOID threadContext) +NTSTATUS DriverThread::Thread::Start(ThreadRoutine routine, eastl::shared_ptr<ThreadArgs> args) { HANDLE threadHandle; NTSTATUS status; @@ -31,7 +51,7 @@ NTSTATUS DriverThread::Thread::Start(threadRoutine_t routine, PVOID threadContex } m_routine = routine; - m_threadContext = threadContext; + m_threadContext = args; status = PsCreateSystemThread(&threadHandle, (ACCESS_MASK)0, NULL, (HANDLE)0, NULL, InterceptorThreadRoutine, this); if (!NT_SUCCESS(status)) @@ -72,11 +92,6 @@ NTSTATUS DriverThread::Thread::WaitForTermination(LONGLONG timeout) return status; } -HANDLE DriverThread::Thread::GetThreadId(void) -{ - return m_threadId; -} - // Spinlock DriverThread::Spinlock::Spinlock(void) @@ -117,6 +132,24 @@ LONG DriverThread::Semaphore::Release(LONG adjustment) return KeReleaseSemaphore(&m_semaphore, 0, adjustment, FALSE); } +// Event + +DriverThread::Event::Event() +{ + KeInitializeEvent(&m_event, NotificationEvent, FALSE); +} + +NTSTATUS DriverThread::Event::Wait(LONGLONG timeout) +{ + LARGE_INTEGER li_timeout = {.QuadPart = timeout}; + return KeWaitForSingleObject(&m_event, Executive, KernelMode, FALSE, (timeout == 0 ? NULL : &li_timeout)); +} + +NTSTATUS DriverThread::Event::Notify() +{ + return KeSetEvent(&m_event, 0, FALSE); +} + // Mutex DriverThread::Mutex::Mutex(void) @@ -151,11 +184,9 @@ DriverThread::LockGuard::~LockGuard(void) // WorkQueue -DriverThread::WorkQueue::WorkQueue(void) : m_worker() +DriverThread::WorkQueue::WorkQueue(void) + : m_mutex(), m_queue(), m_wakeEvent(), m_stopWorker(false), m_worker(), m_workerRoutine(nullptr) { - InitializeSListHead(&m_work); - KeInitializeEvent(&m_wakeEvent, SynchronizationEvent, FALSE); - m_stopWorker = FALSE; } DriverThread::WorkQueue::~WorkQueue(void) @@ -163,14 +194,15 @@ DriverThread::WorkQueue::~WorkQueue(void) Stop(); } -NTSTATUS DriverThread::WorkQueue::Start(workerRoutine_t workerRoutine) +NTSTATUS DriverThread::WorkQueue::Start(WorkerRoutine routine) { NTSTATUS status; { LockGuard lock(m_mutex); - m_workerRoutine = workerRoutine; - status = m_worker.Start(WorkerInterceptorRoutine, this); + m_workerRoutine = routine; + auto wqa = eastl::make_shared<WorkQueueArgs>(this); + status = m_worker.Start(WorkerInterceptorRoutine, wqa); } if (!NT_SUCCESS(status) && status != STATUS_UNSUCCESSFUL) @@ -181,73 +213,83 @@ NTSTATUS DriverThread::WorkQueue::Start(workerRoutine_t workerRoutine) return status; } -void DriverThread::WorkQueue::Stop(void) +void DriverThread::WorkQueue::Stop(bool wait) { LockGuard lock(m_mutex); - if (m_stopWorker == TRUE) + if (m_stopWorker == true) { return; } - m_stopWorker = TRUE; - KeSetEvent(&m_wakeEvent, 0, FALSE); + m_stopWorker = true; + m_wakeEvent.Notify(); + if (wait) + { + m_worker.WaitForTermination(); + } +} + +void DriverThread::WorkQueue::Enqueue(WorkItem & item) +{ + { + LockGuard lock(m_mutex); + m_queue.emplace_back(item); + } + m_wakeEvent.Notify(); } -void DriverThread::WorkQueue::Enqueue(WorkItem * item) +void DriverThread::WorkQueue::Enqueue(eastl::deque<WorkItem> & items) { - if (InterlockedPushEntrySList(&m_work, &item->QueueEntry) == NULL) { - // Work queue was empty. So, signal the work queue event in case the - // worker thread is waiting on the event for more operations. - KeSetEvent(&m_wakeEvent, 0, FALSE); + LockGuard lock(m_mutex); + m_queue.insert(m_queue.end(), items.begin(), items.end()); } + m_wakeEvent.Notify(); } -NTSTATUS DriverThread::WorkQueue::WorkerInterceptorRoutine(PVOID workerContext) +NTSTATUS DriverThread::WorkQueue::WorkerInterceptorRoutine(eastl::shared_ptr<ThreadArgs> args) { - DriverThread::WorkQueue * wq = (DriverThread::WorkQueue *)workerContext; - PSLIST_ENTRY listEntryRev, listEntry, next; + auto wqa = eastl::static_pointer_cast<WorkQueueArgs>(args); + WorkQueue * wq = wqa->getWorkQueue(); PAGED_CODE(); for (;;) { - // Flush all the queued operations into a local list - listEntryRev = InterlockedFlushSList(&wq->m_work); + eastl::deque<WorkItem> doQueue; + std::size_t nItems; - if (listEntryRev == NULL) { + LockGuard lock(wq->m_mutex); + nItems = wq->m_queue.size(); + } - // There's no work to do. If we are allowed to stop, then stop. - if (wq->m_stopWorker == TRUE) + if (nItems == 0) + { + if (wq->m_stopWorker == true) { break; } - // Otherwise, wait for more operations to be enqueued. - KeWaitForSingleObject(&wq->m_wakeEvent, Executive, KernelMode, FALSE, 0); + wq->m_wakeEvent.Wait(); continue; } - // Need to reverse the flushed list in order to preserve the FIFO order - listEntry = NULL; - while (listEntryRev != NULL) { - next = listEntryRev->Next; - listEntryRev->Next = listEntry; - listEntry = listEntryRev; - listEntryRev = next; + LockGuard lock(wq->m_mutex); + doQueue = wq->m_queue; + wq->m_queue.clear(); } - // Now process the correctly ordered list of operations one by one - while (listEntry) + while (doQueue.size() > 0) { - PSLIST_ENTRY arg = listEntry; - listEntry = listEntry->Next; - DriverThread::WorkItem * wi = CONTAINING_RECORD(arg, DriverThread::WorkItem, QueueEntry); - if (wq->m_workerRoutine(wi) != STATUS_SUCCESS) + WorkItem & item = doQueue.front(); + + if (wq->m_workerRoutine(item) != STATUS_SUCCESS) { - wq->m_stopWorker = TRUE; + wq->m_stopWorker = true; } + + doQueue.pop_front(); } } diff --git a/CRT/DriverThread.hpp b/CRT/DriverThread.hpp index d863a66..ac2eddd 100644 --- a/CRT/DriverThread.hpp +++ b/CRT/DriverThread.hpp @@ -3,14 +3,14 @@ #include <ntddk.h> +#include <EASTL/deque.h> +#include <EASTL/functional.h> +#include <EASTL/shared_ptr.h> + extern "C" void InterceptorThreadRoutine(PVOID threadContext); namespace DriverThread { -class WorkItem; -typedef NTSTATUS (*threadRoutine_t)(PVOID); -typedef NTSTATUS (*workerRoutine_t)(WorkItem * item); - class Mutex { public: @@ -36,14 +36,36 @@ private: Mutex m_Lock; }; +class ThreadArgs : public virtual eastl::enable_shared_from_this<ThreadArgs> +{ +public: + ThreadArgs(void) + { + } + ThreadArgs(const ThreadArgs &) = delete; + virtual ~ThreadArgs(void) + { + } +}; + +using ThreadRoutine = eastl::function<NTSTATUS(eastl::shared_ptr<ThreadArgs> args)>; + class Thread { public: Thread(void); + Thread(const Thread &) = delete; ~Thread(void); - NTSTATUS Start(threadRoutine_t routine, PVOID threadContext); + NTSTATUS Start(ThreadRoutine routine, eastl::shared_ptr<ThreadArgs> args); NTSTATUS WaitForTermination(LONGLONG timeout = 0); - HANDLE GetThreadId(void); + HANDLE GetThreadId(void) + { + return m_threadId; + } + bool isRunning(void) + { + return GetThreadId() != nullptr; + } private: friend void ::InterceptorThreadRoutine(PVOID threadContext); @@ -51,8 +73,8 @@ private: HANDLE m_threadId = nullptr; PETHREAD m_threadObject = nullptr; Mutex m_mutex; - threadRoutine_t m_routine; - PVOID m_threadContext; + ThreadRoutine m_routine; + eastl::shared_ptr<ThreadArgs> m_threadContext; }; class Spinlock @@ -79,31 +101,65 @@ private: KSEMAPHORE m_semaphore; }; -class WorkItem +class Event +{ +public: + Event(); + NTSTATUS Wait(LONGLONG timeout = 0); + NTSTATUS Notify(); + +private: + KEVENT m_event; +}; + +class WorkItem final { + friend class WorkQueue; + public: - SLIST_ENTRY QueueEntry; - PSLIST_ENTRY WorkListEntry; + WorkItem(const eastl::shared_ptr<void> & user) : m_user(std::move(user)) + { + } + virtual ~WorkItem(void) + { + } + template <class T> + eastl::shared_ptr<T> Get(void) + { + return eastl::static_pointer_cast<T>(m_user); + } + template <class T> + void Get(eastl::shared_ptr<T> & dest) + { + dest = eastl::static_pointer_cast<T>(m_user); + } + +private: + eastl::shared_ptr<void> m_user; }; -class WorkQueue +using WorkerRoutine = eastl::function<NTSTATUS(WorkItem & item)>; + +class WorkQueue final { public: WorkQueue(void); + WorkQueue(const WorkQueue &) = delete; ~WorkQueue(void); - NTSTATUS Start(workerRoutine_t workerRoutine); - void Stop(void); - void Enqueue(WorkItem * item); + NTSTATUS Start(WorkerRoutine routine); + void Stop(bool wait = true); + void Enqueue(WorkItem & item); + void Enqueue(eastl::deque<WorkItem> & items); private: Mutex m_mutex; - SLIST_HEADER m_work; - KEVENT m_wakeEvent; - BOOLEAN m_stopWorker; // Work LIST must be empty and StopWorker TRUE to be able to stop! + eastl::deque<WorkItem> m_queue; + Event m_wakeEvent; + bool m_stopWorker; // Work LIST must be empty and StopWorker TRUE to be able to stop! Thread m_worker; - workerRoutine_t m_workerRoutine; + WorkerRoutine m_workerRoutine; - static NTSTATUS WorkerInterceptorRoutine(PVOID workerContext); + static NTSTATUS WorkerInterceptorRoutine(eastl::shared_ptr<ThreadArgs> args); }; }; // namespace DriverThread |