aboutsummaryrefslogtreecommitdiff
path: root/DriverThread.cpp
diff options
context:
space:
mode:
authorToni Uhlig <matzeton@googlemail.com>2021-04-20 17:04:24 +0200
committerToni Uhlig <matzeton@googlemail.com>2021-04-20 17:06:22 +0200
commitabc7a2f0b862f192c562d62053fc210b778cedb1 (patch)
tree352840c79f95b1c2dd01a5017222614975cc33e8 /DriverThread.cpp
parent3d51ea5b54a55c5417236ed00212d1e3d5134dd2 (diff)
Added MT support for ring0 drivers.
Signed-off-by: Toni Uhlig <matzeton@googlemail.com>
Diffstat (limited to 'DriverThread.cpp')
-rw-r--r--DriverThread.cpp70
1 files changed, 68 insertions, 2 deletions
diff --git a/DriverThread.cpp b/DriverThread.cpp
index 2eb7c6d..93eda51 100644
--- a/DriverThread.cpp
+++ b/DriverThread.cpp
@@ -1,15 +1,33 @@
#include <DriverThread.hpp>
+// Thread
+
DriverThread::Thread::Thread(void)
{
}
-NTSTATUS DriverThread::Thread::Start(PKSTART_ROUTINE threadRoutine, PVOID threadContext)
+DriverThread::Thread::~Thread(void)
+{
+ WaitForTermination();
+}
+
+extern "C" void InterceptorThreadRoutine(PVOID threadContext)
+{
+ DriverThread::Thread * self = (DriverThread::Thread *)threadContext;
+
+ self->m_threadId = PsGetCurrentThreadId();
+ PsTerminateSystemThread(self->m_routine(self->m_threadContext));
+}
+
+NTSTATUS DriverThread::Thread::Start(threadRoutine routine, PVOID threadContext)
{
HANDLE threadHandle;
NTSTATUS status;
- status = PsCreateSystemThread(&threadHandle, (ACCESS_MASK)0, NULL, (HANDLE)0, NULL, threadRoutine, threadContext);
+ LockGuard lock(m_mutex);
+ m_routine = routine;
+ m_threadContext = threadContext;
+ status = PsCreateSystemThread(&threadHandle, (ACCESS_MASK)0, NULL, (HANDLE)0, NULL, InterceptorThreadRoutine, this);
if (!NT_SUCCESS(status))
{
@@ -22,14 +40,28 @@ NTSTATUS DriverThread::Thread::Start(PKSTART_ROUTINE threadRoutine, PVOID thread
NTSTATUS DriverThread::Thread::WaitForTermination(LONGLONG timeout)
{
+ LockGuard lock(m_mutex);
+ if (m_threadObject == nullptr || PsGetCurrentThreadId() == m_threadId)
+ {
+ return STATUS_UNSUCCESSFUL;
+ }
+
LARGE_INTEGER li_timeout = {.QuadPart = timeout};
NTSTATUS status =
KeWaitForSingleObject(m_threadObject, Executive, KernelMode, FALSE, (timeout == 0 ? NULL : &li_timeout));
ObDereferenceObject(m_threadObject);
+ m_threadObject = nullptr;
return status;
}
+HANDLE DriverThread::Thread::GetThreadId(void)
+{
+ return m_threadId;
+}
+
+// Spinlock
+
DriverThread::Spinlock::Spinlock(void)
{
KeInitializeSpinLock(&m_spinLock);
@@ -45,6 +77,8 @@ void DriverThread::Spinlock::Release(KIRQL * const oldIrql)
KeReleaseSpinLock(&m_spinLock, *oldIrql);
}
+// Semaphore
+
DriverThread::Semaphore::Semaphore(LONG initialValue, LONG maxValue)
{
KeInitializeSemaphore(&m_semaphore, initialValue, maxValue);
@@ -60,3 +94,35 @@ LONG DriverThread::Semaphore::Release(LONG adjustment)
{
return KeReleaseSemaphore(&m_semaphore, 0, adjustment, FALSE);
}
+
+// Mutex
+
+DriverThread::Mutex::Mutex(void)
+{
+}
+
+DriverThread::Mutex::~Mutex(void)
+{
+}
+
+void DriverThread::Mutex::Lock(void)
+{
+ while (m_interlock == 1 || InterlockedCompareExchange(&m_interlock, 1, 0) == 1) {}
+}
+
+void DriverThread::Mutex::Unlock(void)
+{
+ m_interlock = 0;
+}
+
+// LockGuard
+
+DriverThread::LockGuard::LockGuard(Mutex & m) : m_Lock(m)
+{
+ m_Lock.Lock();
+}
+
+DriverThread::LockGuard::~LockGuard(void)
+{
+ m_Lock.Unlock();
+}