diff options
Diffstat (limited to 'CRT/DriverThread.cpp')
-rw-r--r-- | CRT/DriverThread.cpp | 297 |
1 files changed, 297 insertions, 0 deletions
diff --git a/CRT/DriverThread.cpp b/CRT/DriverThread.cpp new file mode 100644 index 0000000..3a77dc8 --- /dev/null +++ b/CRT/DriverThread.cpp @@ -0,0 +1,297 @@ +#include <DriverThread.hpp> + +class WorkQueueArgs : public DriverThread::ThreadArgs +{ + friend class WorkQueue; + +public: + WorkQueueArgs(DriverThread::WorkQueue * wq) : m_wq(wq){}; + WorkQueueArgs(const WorkQueueArgs &) = delete; + DriverThread::WorkQueue * getWorkQueue() + { + return m_wq; + } + +private: + DriverThread::WorkQueue * m_wq; +}; + +// Thread + +DriverThread::Thread::Thread(void) +{ +} + +DriverThread::Thread::~Thread(void) +{ + WaitForTermination(); +} + +extern "C" void InterceptorThreadRoutine(PVOID threadContext) +{ + NTSTATUS threadReturn; + DriverThread::Thread * self = (DriverThread::Thread *)threadContext; + + self->m_threadId = PsGetCurrentThreadId(); + threadReturn = self->m_routine(self->m_threadContext); + self->m_threadId = nullptr; + self->m_threadContext = nullptr; + PsTerminateSystemThread(threadReturn); +} + +NTSTATUS DriverThread::Thread::Start(ThreadRoutine routine, eastl::shared_ptr<ThreadArgs> args) +{ + HANDLE threadHandle; + NTSTATUS status; + + LockGuard lock(m_mutex); + if (m_threadObject != nullptr) + { + return STATUS_UNSUCCESSFUL; + } + + m_routine = routine; + m_threadContext = args; + status = PsCreateSystemThread(&threadHandle, (ACCESS_MASK)0, NULL, (HANDLE)0, NULL, InterceptorThreadRoutine, this); + + if (!NT_SUCCESS(status)) + { + return status; + } + + status = + ObReferenceObjectByHandle(threadHandle, THREAD_ALL_ACCESS, NULL, KernelMode, (PVOID *)&m_threadObject, NULL); + + if (!NT_SUCCESS(status)) + { + return status; + } + + return ZwClose(threadHandle); +} + +NTSTATUS DriverThread::Thread::WaitForTermination(LONGLONG timeout) +{ + if (PsGetCurrentThreadId() == m_threadId) + { + return STATUS_UNSUCCESSFUL; + } + + LockGuard lock(m_mutex); + if (m_threadObject == nullptr) + { + return STATUS_UNSUCCESSFUL; + } + + LARGE_INTEGER li_timeout = {.QuadPart = timeout}; + NTSTATUS status = + KeWaitForSingleObject(m_threadObject, Executive, KernelMode, FALSE, (timeout == 0 ? NULL : &li_timeout)); + + ObDereferenceObject(m_threadObject); + m_threadObject = nullptr; + return status; +} + +// Spinlock + +DriverThread::Spinlock::Spinlock(void) +{ + KeInitializeSpinLock(&m_spinLock); +} + +NTSTATUS DriverThread::Spinlock::Acquire(void) +{ + return KeAcquireSpinLock(&m_spinLock, &m_oldIrql); +} + +void DriverThread::Spinlock::Release(void) +{ + KeReleaseSpinLock(&m_spinLock, m_oldIrql); +} + +KIRQL DriverThread::Spinlock::GetOldIrql(void) +{ + return m_oldIrql; +} + +// Semaphore + +DriverThread::Semaphore::Semaphore(LONG initialValue, LONG maxValue) +{ + KeInitializeSemaphore(&m_semaphore, initialValue, maxValue); +} + +NTSTATUS DriverThread::Semaphore::Wait(LONGLONG timeout) +{ + LARGE_INTEGER li_timeout = {.QuadPart = timeout}; + return KeWaitForSingleObject(&m_semaphore, Executive, KernelMode, FALSE, (timeout == 0 ? NULL : &li_timeout)); +} + +LONG DriverThread::Semaphore::Release(LONG adjustment) +{ + return KeReleaseSemaphore(&m_semaphore, 0, adjustment, FALSE); +} + +// Event + +DriverThread::Event::Event() +{ + KeInitializeEvent(&m_event, NotificationEvent, FALSE); +} + +NTSTATUS DriverThread::Event::Wait(LONGLONG timeout) +{ + LARGE_INTEGER li_timeout = {.QuadPart = timeout}; + return KeWaitForSingleObject(&m_event, Executive, KernelMode, FALSE, (timeout == 0 ? NULL : &li_timeout)); +} + +NTSTATUS DriverThread::Event::Notify() +{ + return KeSetEvent(&m_event, 0, FALSE); +} + +// Mutex + +DriverThread::Mutex::Mutex(void) +{ +} + +DriverThread::Mutex::~Mutex(void) +{ +} + +void DriverThread::Mutex::Lock(void) +{ + while (m_interlock == 1 || InterlockedCompareExchange(&m_interlock, 1, 0) == 1) {} +} + +void DriverThread::Mutex::Unlock(void) +{ + m_interlock = 0; +} + +// LockGuard + +DriverThread::LockGuard::LockGuard(Mutex & m) : m_Lock(m) +{ + m_Lock.Lock(); +} + +DriverThread::LockGuard::~LockGuard(void) +{ + m_Lock.Unlock(); +} + +// WorkQueue + +DriverThread::WorkQueue::WorkQueue(void) + : m_mutex(), m_queue(), m_wakeEvent(), m_stopWorker(false), m_worker(), m_workerRoutine(nullptr) +{ +} + +DriverThread::WorkQueue::~WorkQueue(void) +{ + Stop(); +} + +NTSTATUS DriverThread::WorkQueue::Start(WorkerRoutine routine) +{ + NTSTATUS status; + + { + LockGuard lock(m_mutex); + m_workerRoutine = routine; + auto wqa = eastl::make_shared<WorkQueueArgs>(this); + status = m_worker.Start(WorkerInterceptorRoutine, wqa); + } + + if (!NT_SUCCESS(status) && status != STATUS_UNSUCCESSFUL) + { + Stop(); + } + + return status; +} + +void DriverThread::WorkQueue::Stop(bool wait) +{ + LockGuard lock(m_mutex); + if (m_stopWorker == true) + { + return; + } + m_stopWorker = true; + m_wakeEvent.Notify(); + if (wait) + { + m_worker.WaitForTermination(); + } +} + +void DriverThread::WorkQueue::Enqueue(WorkItem & item) +{ + { + LockGuard lock(m_mutex); + m_queue.emplace_back(item); + } + m_wakeEvent.Notify(); +} + +void DriverThread::WorkQueue::Enqueue(eastl::deque<WorkItem> & items) +{ + { + LockGuard lock(m_mutex); + m_queue.insert(m_queue.end(), items.begin(), items.end()); + } + m_wakeEvent.Notify(); +} + +NTSTATUS DriverThread::WorkQueue::WorkerInterceptorRoutine(eastl::shared_ptr<ThreadArgs> args) +{ + auto wqa = eastl::static_pointer_cast<WorkQueueArgs>(args); + WorkQueue * wq = wqa->getWorkQueue(); + + PAGED_CODE(); + + for (;;) + { + eastl::deque<WorkItem> doQueue; + std::size_t nItems; + + { + LockGuard lock(wq->m_mutex); + nItems = wq->m_queue.size(); + } + + if (nItems == 0) + { + if (wq->m_stopWorker == true) + { + break; + } + + wq->m_wakeEvent.Wait(); + continue; + } + + { + LockGuard lock(wq->m_mutex); + doQueue = wq->m_queue; + wq->m_queue.clear(); + } + + while (doQueue.size() > 0) + { + WorkItem & item = doQueue.front(); + + if (wq->m_workerRoutine(item) != STATUS_SUCCESS) + { + wq->m_stopWorker = true; + } + + doQueue.pop_front(); + } + } + + return STATUS_SUCCESS; +} |