diff options
-rw-r--r-- | DriverThread.cpp | 70 | ||||
-rw-r--r-- | DriverThread.hpp | 45 | ||||
-rw-r--r-- | ddk-template-cplusplus.cpp | 31 |
3 files changed, 128 insertions, 18 deletions
diff --git a/DriverThread.cpp b/DriverThread.cpp index 2eb7c6d..93eda51 100644 --- a/DriverThread.cpp +++ b/DriverThread.cpp @@ -1,15 +1,33 @@ #include <DriverThread.hpp> +// Thread + DriverThread::Thread::Thread(void) { } -NTSTATUS DriverThread::Thread::Start(PKSTART_ROUTINE threadRoutine, PVOID threadContext) +DriverThread::Thread::~Thread(void) +{ + WaitForTermination(); +} + +extern "C" void InterceptorThreadRoutine(PVOID threadContext) +{ + DriverThread::Thread * self = (DriverThread::Thread *)threadContext; + + self->m_threadId = PsGetCurrentThreadId(); + PsTerminateSystemThread(self->m_routine(self->m_threadContext)); +} + +NTSTATUS DriverThread::Thread::Start(threadRoutine routine, PVOID threadContext) { HANDLE threadHandle; NTSTATUS status; - status = PsCreateSystemThread(&threadHandle, (ACCESS_MASK)0, NULL, (HANDLE)0, NULL, threadRoutine, threadContext); + LockGuard lock(m_mutex); + m_routine = routine; + m_threadContext = threadContext; + status = PsCreateSystemThread(&threadHandle, (ACCESS_MASK)0, NULL, (HANDLE)0, NULL, InterceptorThreadRoutine, this); if (!NT_SUCCESS(status)) { @@ -22,14 +40,28 @@ NTSTATUS DriverThread::Thread::Start(PKSTART_ROUTINE threadRoutine, PVOID thread NTSTATUS DriverThread::Thread::WaitForTermination(LONGLONG timeout) { + LockGuard lock(m_mutex); + if (m_threadObject == nullptr || PsGetCurrentThreadId() == m_threadId) + { + return STATUS_UNSUCCESSFUL; + } + LARGE_INTEGER li_timeout = {.QuadPart = timeout}; NTSTATUS status = KeWaitForSingleObject(m_threadObject, Executive, KernelMode, FALSE, (timeout == 0 ? NULL : &li_timeout)); ObDereferenceObject(m_threadObject); + m_threadObject = nullptr; return status; } +HANDLE DriverThread::Thread::GetThreadId(void) +{ + return m_threadId; +} + +// Spinlock + DriverThread::Spinlock::Spinlock(void) { KeInitializeSpinLock(&m_spinLock); @@ -45,6 +77,8 @@ void DriverThread::Spinlock::Release(KIRQL * const oldIrql) KeReleaseSpinLock(&m_spinLock, *oldIrql); } +// Semaphore + DriverThread::Semaphore::Semaphore(LONG initialValue, LONG maxValue) { KeInitializeSemaphore(&m_semaphore, initialValue, maxValue); @@ -60,3 +94,35 @@ LONG DriverThread::Semaphore::Release(LONG adjustment) { return KeReleaseSemaphore(&m_semaphore, 0, adjustment, FALSE); } + +// Mutex + +DriverThread::Mutex::Mutex(void) +{ +} + +DriverThread::Mutex::~Mutex(void) +{ +} + +void DriverThread::Mutex::Lock(void) +{ + while (m_interlock == 1 || InterlockedCompareExchange(&m_interlock, 1, 0) == 1) {} +} + +void DriverThread::Mutex::Unlock(void) +{ + m_interlock = 0; +} + +// LockGuard + +DriverThread::LockGuard::LockGuard(Mutex & m) : m_Lock(m) +{ + m_Lock.Lock(); +} + +DriverThread::LockGuard::~LockGuard(void) +{ + m_Lock.Unlock(); +} diff --git a/DriverThread.hpp b/DriverThread.hpp index 39f5a89..a15b3d1 100644 --- a/DriverThread.hpp +++ b/DriverThread.hpp @@ -3,20 +3,55 @@ #include <ntddk.h> -#define TERMINATE_MYSELF(ntstatus) PsTerminateSystemThread(ntstatus); +extern "C" void InterceptorThreadRoutine(PVOID threadContext); + +typedef NTSTATUS (*threadRoutine_t)(PVOID); namespace DriverThread { +class Mutex +{ +public: + Mutex(void); + ~Mutex(void); + +private: + void Lock(); + void Unlock(); + + volatile long int m_interlock; + + friend class LockGuard; +}; + +class LockGuard +{ +public: + LockGuard(Mutex & m); + ~LockGuard(void); + +private: + Mutex m_Lock; +}; + class Thread { public: - Thread(); - NTSTATUS Start(PKSTART_ROUTINE threadRoutine, PVOID threadContext); + Thread(void); + ~Thread(void); + NTSTATUS Start(threadRoutine_t routine, PVOID threadContext); NTSTATUS WaitForTermination(LONGLONG timeout = 0); + HANDLE GetThreadId(void); private: - PETHREAD m_threadObject; + friend void ::InterceptorThreadRoutine(PVOID threadContext); + + HANDLE m_threadId = nullptr; + PETHREAD m_threadObject = nullptr; + Mutex m_mutex; + threadRoutine_t m_routine; + PVOID m_threadContext; }; class Spinlock @@ -33,7 +68,7 @@ private: class Semaphore { public: - explicit Semaphore(LONG initialValue = 0, LONG maxValue = MAXLONG); + Semaphore(LONG initialValue = 0, LONG maxValue = MAXLONG); NTSTATUS Wait(LONGLONG timeout = 0); LONG Release(LONG adjustment = 1); diff --git a/ddk-template-cplusplus.cpp b/ddk-template-cplusplus.cpp index aafe492..b540a22 100644 --- a/ddk-template-cplusplus.cpp +++ b/ddk-template-cplusplus.cpp @@ -14,17 +14,25 @@ public: } }; -static void threadRoutine(PVOID threadContext) +struct threadContext +{ + DriverThread::Semaphore sem; + DriverThread::Thread dth; +}; + +static NTSTATUS threadRoutine(PVOID threadContext) { DbgPrint("ThreadRoutine %p, ThreadContext: %p\n", threadRoutine, threadContext); for (size_t i = 3; i > 0; --i) { DbgPrint("ThreadLoop: %zu\n", i); } - DbgPrint("Fin.\n"); - DriverThread::Semaphore * const sem = (DriverThread::Semaphore *)threadContext; - sem->Release(); - TERMINATE_MYSELF(STATUS_SUCCESS); + struct threadContext * const ctx = (struct threadContext *)threadContext; + DbgPrint("Fin. ThreadId: %p\n", ctx->dth.GetThreadId()); + ctx->sem.Release(); + DbgPrint("Thread WaitForTermination: 0x%X\n", ctx->dth.WaitForTermination()); // must return STATUS_UNSUCCESSFUL; + + return STATUS_SUCCESS; } static void test_cplusplus(void) @@ -32,12 +40,13 @@ static void test_cplusplus(void) TestSmth t; t.doSmth(); - DriverThread::Semaphore sem; - DriverThread::Thread dt; - dt.Start(threadRoutine, (PVOID)&sem); - sem.Wait(); - DbgPrint("Thread signaled semaphore.\n"); - dt.WaitForTermination(); + struct threadContext ctx; + ctx.dth.Start(threadRoutine, (PVOID)&ctx); + ctx.sem.Wait(); + DbgPrint("MainThread semaphore signaled.\n"); + ctx.dth.WaitForTermination(); + ctx.dth.WaitForTermination(); + DbgPrint("MainThread EOF\n"); } extern "C" |