diff options
author | Toni Uhlig <matzeton@googlemail.com> | 2021-04-20 17:04:24 +0200 |
---|---|---|
committer | Toni Uhlig <matzeton@googlemail.com> | 2021-04-20 17:06:22 +0200 |
commit | abc7a2f0b862f192c562d62053fc210b778cedb1 (patch) | |
tree | 352840c79f95b1c2dd01a5017222614975cc33e8 /DriverThread.cpp | |
parent | 3d51ea5b54a55c5417236ed00212d1e3d5134dd2 (diff) |
Added MT support for ring0 drivers.
Signed-off-by: Toni Uhlig <matzeton@googlemail.com>
Diffstat (limited to 'DriverThread.cpp')
-rw-r--r-- | DriverThread.cpp | 70 |
1 files changed, 68 insertions, 2 deletions
diff --git a/DriverThread.cpp b/DriverThread.cpp index 2eb7c6d..93eda51 100644 --- a/DriverThread.cpp +++ b/DriverThread.cpp @@ -1,15 +1,33 @@ #include <DriverThread.hpp> +// Thread + DriverThread::Thread::Thread(void) { } -NTSTATUS DriverThread::Thread::Start(PKSTART_ROUTINE threadRoutine, PVOID threadContext) +DriverThread::Thread::~Thread(void) +{ + WaitForTermination(); +} + +extern "C" void InterceptorThreadRoutine(PVOID threadContext) +{ + DriverThread::Thread * self = (DriverThread::Thread *)threadContext; + + self->m_threadId = PsGetCurrentThreadId(); + PsTerminateSystemThread(self->m_routine(self->m_threadContext)); +} + +NTSTATUS DriverThread::Thread::Start(threadRoutine routine, PVOID threadContext) { HANDLE threadHandle; NTSTATUS status; - status = PsCreateSystemThread(&threadHandle, (ACCESS_MASK)0, NULL, (HANDLE)0, NULL, threadRoutine, threadContext); + LockGuard lock(m_mutex); + m_routine = routine; + m_threadContext = threadContext; + status = PsCreateSystemThread(&threadHandle, (ACCESS_MASK)0, NULL, (HANDLE)0, NULL, InterceptorThreadRoutine, this); if (!NT_SUCCESS(status)) { @@ -22,14 +40,28 @@ NTSTATUS DriverThread::Thread::Start(PKSTART_ROUTINE threadRoutine, PVOID thread NTSTATUS DriverThread::Thread::WaitForTermination(LONGLONG timeout) { + LockGuard lock(m_mutex); + if (m_threadObject == nullptr || PsGetCurrentThreadId() == m_threadId) + { + return STATUS_UNSUCCESSFUL; + } + LARGE_INTEGER li_timeout = {.QuadPart = timeout}; NTSTATUS status = KeWaitForSingleObject(m_threadObject, Executive, KernelMode, FALSE, (timeout == 0 ? NULL : &li_timeout)); ObDereferenceObject(m_threadObject); + m_threadObject = nullptr; return status; } +HANDLE DriverThread::Thread::GetThreadId(void) +{ + return m_threadId; +} + +// Spinlock + DriverThread::Spinlock::Spinlock(void) { KeInitializeSpinLock(&m_spinLock); @@ -45,6 +77,8 @@ void DriverThread::Spinlock::Release(KIRQL * const oldIrql) KeReleaseSpinLock(&m_spinLock, *oldIrql); } +// Semaphore + DriverThread::Semaphore::Semaphore(LONG initialValue, LONG maxValue) { KeInitializeSemaphore(&m_semaphore, initialValue, maxValue); @@ -60,3 +94,35 @@ LONG DriverThread::Semaphore::Release(LONG adjustment) { return KeReleaseSemaphore(&m_semaphore, 0, adjustment, FALSE); } + +// Mutex + +DriverThread::Mutex::Mutex(void) +{ +} + +DriverThread::Mutex::~Mutex(void) +{ +} + +void DriverThread::Mutex::Lock(void) +{ + while (m_interlock == 1 || InterlockedCompareExchange(&m_interlock, 1, 0) == 1) {} +} + +void DriverThread::Mutex::Unlock(void) +{ + m_interlock = 0; +} + +// LockGuard + +DriverThread::LockGuard::LockGuard(Mutex & m) : m_Lock(m) +{ + m_Lock.Lock(); +} + +DriverThread::LockGuard::~LockGuard(void) +{ + m_Lock.Unlock(); +} |