aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--DriverThread.cpp70
-rw-r--r--DriverThread.hpp45
-rw-r--r--ddk-template-cplusplus.cpp31
3 files changed, 128 insertions, 18 deletions
diff --git a/DriverThread.cpp b/DriverThread.cpp
index 2eb7c6d..93eda51 100644
--- a/DriverThread.cpp
+++ b/DriverThread.cpp
@@ -1,15 +1,33 @@
#include <DriverThread.hpp>
+// Thread
+
DriverThread::Thread::Thread(void)
{
}
-NTSTATUS DriverThread::Thread::Start(PKSTART_ROUTINE threadRoutine, PVOID threadContext)
+DriverThread::Thread::~Thread(void)
+{
+ WaitForTermination();
+}
+
+extern "C" void InterceptorThreadRoutine(PVOID threadContext)
+{
+ DriverThread::Thread * self = (DriverThread::Thread *)threadContext;
+
+ self->m_threadId = PsGetCurrentThreadId();
+ PsTerminateSystemThread(self->m_routine(self->m_threadContext));
+}
+
+NTSTATUS DriverThread::Thread::Start(threadRoutine routine, PVOID threadContext)
{
HANDLE threadHandle;
NTSTATUS status;
- status = PsCreateSystemThread(&threadHandle, (ACCESS_MASK)0, NULL, (HANDLE)0, NULL, threadRoutine, threadContext);
+ LockGuard lock(m_mutex);
+ m_routine = routine;
+ m_threadContext = threadContext;
+ status = PsCreateSystemThread(&threadHandle, (ACCESS_MASK)0, NULL, (HANDLE)0, NULL, InterceptorThreadRoutine, this);
if (!NT_SUCCESS(status))
{
@@ -22,14 +40,28 @@ NTSTATUS DriverThread::Thread::Start(PKSTART_ROUTINE threadRoutine, PVOID thread
NTSTATUS DriverThread::Thread::WaitForTermination(LONGLONG timeout)
{
+ LockGuard lock(m_mutex);
+ if (m_threadObject == nullptr || PsGetCurrentThreadId() == m_threadId)
+ {
+ return STATUS_UNSUCCESSFUL;
+ }
+
LARGE_INTEGER li_timeout = {.QuadPart = timeout};
NTSTATUS status =
KeWaitForSingleObject(m_threadObject, Executive, KernelMode, FALSE, (timeout == 0 ? NULL : &li_timeout));
ObDereferenceObject(m_threadObject);
+ m_threadObject = nullptr;
return status;
}
+HANDLE DriverThread::Thread::GetThreadId(void)
+{
+ return m_threadId;
+}
+
+// Spinlock
+
DriverThread::Spinlock::Spinlock(void)
{
KeInitializeSpinLock(&m_spinLock);
@@ -45,6 +77,8 @@ void DriverThread::Spinlock::Release(KIRQL * const oldIrql)
KeReleaseSpinLock(&m_spinLock, *oldIrql);
}
+// Semaphore
+
DriverThread::Semaphore::Semaphore(LONG initialValue, LONG maxValue)
{
KeInitializeSemaphore(&m_semaphore, initialValue, maxValue);
@@ -60,3 +94,35 @@ LONG DriverThread::Semaphore::Release(LONG adjustment)
{
return KeReleaseSemaphore(&m_semaphore, 0, adjustment, FALSE);
}
+
+// Mutex
+
+DriverThread::Mutex::Mutex(void)
+{
+}
+
+DriverThread::Mutex::~Mutex(void)
+{
+}
+
+void DriverThread::Mutex::Lock(void)
+{
+ while (m_interlock == 1 || InterlockedCompareExchange(&m_interlock, 1, 0) == 1) {}
+}
+
+void DriverThread::Mutex::Unlock(void)
+{
+ m_interlock = 0;
+}
+
+// LockGuard
+
+DriverThread::LockGuard::LockGuard(Mutex & m) : m_Lock(m)
+{
+ m_Lock.Lock();
+}
+
+DriverThread::LockGuard::~LockGuard(void)
+{
+ m_Lock.Unlock();
+}
diff --git a/DriverThread.hpp b/DriverThread.hpp
index 39f5a89..a15b3d1 100644
--- a/DriverThread.hpp
+++ b/DriverThread.hpp
@@ -3,20 +3,55 @@
#include <ntddk.h>
-#define TERMINATE_MYSELF(ntstatus) PsTerminateSystemThread(ntstatus);
+extern "C" void InterceptorThreadRoutine(PVOID threadContext);
+
+typedef NTSTATUS (*threadRoutine_t)(PVOID);
namespace DriverThread
{
+class Mutex
+{
+public:
+ Mutex(void);
+ ~Mutex(void);
+
+private:
+ void Lock();
+ void Unlock();
+
+ volatile long int m_interlock;
+
+ friend class LockGuard;
+};
+
+class LockGuard
+{
+public:
+ LockGuard(Mutex & m);
+ ~LockGuard(void);
+
+private:
+ Mutex m_Lock;
+};
+
class Thread
{
public:
- Thread();
- NTSTATUS Start(PKSTART_ROUTINE threadRoutine, PVOID threadContext);
+ Thread(void);
+ ~Thread(void);
+ NTSTATUS Start(threadRoutine_t routine, PVOID threadContext);
NTSTATUS WaitForTermination(LONGLONG timeout = 0);
+ HANDLE GetThreadId(void);
private:
- PETHREAD m_threadObject;
+ friend void ::InterceptorThreadRoutine(PVOID threadContext);
+
+ HANDLE m_threadId = nullptr;
+ PETHREAD m_threadObject = nullptr;
+ Mutex m_mutex;
+ threadRoutine_t m_routine;
+ PVOID m_threadContext;
};
class Spinlock
@@ -33,7 +68,7 @@ private:
class Semaphore
{
public:
- explicit Semaphore(LONG initialValue = 0, LONG maxValue = MAXLONG);
+ Semaphore(LONG initialValue = 0, LONG maxValue = MAXLONG);
NTSTATUS Wait(LONGLONG timeout = 0);
LONG Release(LONG adjustment = 1);
diff --git a/ddk-template-cplusplus.cpp b/ddk-template-cplusplus.cpp
index aafe492..b540a22 100644
--- a/ddk-template-cplusplus.cpp
+++ b/ddk-template-cplusplus.cpp
@@ -14,17 +14,25 @@ public:
}
};
-static void threadRoutine(PVOID threadContext)
+struct threadContext
+{
+ DriverThread::Semaphore sem;
+ DriverThread::Thread dth;
+};
+
+static NTSTATUS threadRoutine(PVOID threadContext)
{
DbgPrint("ThreadRoutine %p, ThreadContext: %p\n", threadRoutine, threadContext);
for (size_t i = 3; i > 0; --i)
{
DbgPrint("ThreadLoop: %zu\n", i);
}
- DbgPrint("Fin.\n");
- DriverThread::Semaphore * const sem = (DriverThread::Semaphore *)threadContext;
- sem->Release();
- TERMINATE_MYSELF(STATUS_SUCCESS);
+ struct threadContext * const ctx = (struct threadContext *)threadContext;
+ DbgPrint("Fin. ThreadId: %p\n", ctx->dth.GetThreadId());
+ ctx->sem.Release();
+ DbgPrint("Thread WaitForTermination: 0x%X\n", ctx->dth.WaitForTermination()); // must return STATUS_UNSUCCESSFUL;
+
+ return STATUS_SUCCESS;
}
static void test_cplusplus(void)
@@ -32,12 +40,13 @@ static void test_cplusplus(void)
TestSmth t;
t.doSmth();
- DriverThread::Semaphore sem;
- DriverThread::Thread dt;
- dt.Start(threadRoutine, (PVOID)&sem);
- sem.Wait();
- DbgPrint("Thread signaled semaphore.\n");
- dt.WaitForTermination();
+ struct threadContext ctx;
+ ctx.dth.Start(threadRoutine, (PVOID)&ctx);
+ ctx.sem.Wait();
+ DbgPrint("MainThread semaphore signaled.\n");
+ ctx.dth.WaitForTermination();
+ ctx.dth.WaitForTermination();
+ DbgPrint("MainThread EOF\n");
}
extern "C"