aboutsummaryrefslogtreecommitdiff
path: root/DriverThread.cpp
blob: 93eda51e496d34d1cc445ea8190df0ae50b977bb (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
#include <DriverThread.hpp>

// Thread

// Constructs a Thread object that does not track any kernel thread yet.
//
// m_threadObject and m_threadId are cleared explicitly because
// WaitForTermination() -- which the destructor also runs -- decides whether
// there is anything to wait for by comparing m_threadObject with nullptr and
// m_threadId with the calling thread's id.
// NOTE(review): the class declaration is not visible from this file; if it
// already provides default member initializers these assignments are merely
// redundant -- confirm against DriverThread.hpp.
DriverThread::Thread::Thread(void)
{
    m_threadObject = nullptr;
    m_threadId = nullptr;
}

// Destructor: calls WaitForTermination() with its default timeout argument so
// that a still-tracked thread is waited for before the object goes away.
// The returned status is intentionally discarded; a destructor has no way to
// report it.
DriverThread::Thread::~Thread(void)
{
    static_cast<void>(this->WaitForTermination());
}

// Entry point handed to PsCreateSystemThread() by Thread::Start().
//
// threadContext is the Thread instance that created this thread (Start()
// passes `this`). The routine records the new thread's id in the instance,
// runs the user supplied routine with the user supplied context and finally
// terminates the system thread with the value that routine returned.
extern "C" void InterceptorThreadRoutine(PVOID threadContext)
{
    DriverThread::Thread * const owner = static_cast<DriverThread::Thread *>(threadContext);

    // Publish our id first; WaitForTermination() compares against it.
    owner->m_threadId = PsGetCurrentThreadId();

    const NTSTATUS exitStatus = owner->m_routine(owner->m_threadContext);
    PsTerminateSystemThread(exitStatus);
}

// Creates a system thread that runs `routine(threadContext)`.
//
// The routine and its context are stored in the object before the thread is
// created, because InterceptorThreadRoutine() reads them from the new thread.
// On success a referenced pointer to the thread object is kept in
// m_threadObject so that WaitForTermination() can wait on it later; the
// creation handle itself is always closed before returning.
//
// Parameters:
//   routine       - function executed by the new thread.
//   threadContext - opaque pointer forwarded to `routine`.
//
// Returns:
//   - the failure status of PsCreateSystemThread() if no thread was created;
//   - the failure status of ObReferenceObjectByHandle() if the thread was
//     created but could not be referenced (m_threadObject is then nullptr,
//     so the thread cannot be waited for through this object);
//   - otherwise the status of ZwClose() on the creation handle.
//
// NOTE(review): calling Start() again while a previous thread is still
// tracked overwrites m_threadObject without releasing the old reference.
NTSTATUS DriverThread::Thread::Start(threadRoutine routine, PVOID threadContext)
{
    HANDLE threadHandle = nullptr;
    OBJECT_ATTRIBUTES threadAttributes;

    LockGuard lock(m_mutex);
    m_routine = routine;
    m_threadContext = threadContext;

    // Request a kernel handle so the handle is not reachable from the handle
    // table of whatever user process happens to be current.
    InitializeObjectAttributes(&threadAttributes, NULL, OBJ_KERNEL_HANDLE, NULL, NULL);

    const NTSTATUS createStatus = PsCreateSystemThread(&threadHandle, (ACCESS_MASK)0, &threadAttributes, (HANDLE)0,
                                                       NULL, InterceptorThreadRoutine, this);
    if (!NT_SUCCESS(createStatus))
    {
        return createStatus;
    }

    // Convert the handle into a referenced object pointer; the handle is not
    // needed afterwards and is closed regardless of the outcome.
    const NTSTATUS referenceStatus = ObReferenceObjectByHandle(threadHandle, THREAD_ALL_ACCESS, NULL, KernelMode,
                                                               reinterpret_cast<PVOID *>(&m_threadObject), NULL);
    const NTSTATUS closeStatus = ZwClose(threadHandle);

    if (!NT_SUCCESS(referenceStatus))
    {
        // Never leave a pointer behind that WaitForTermination() might use.
        m_threadObject = nullptr;
        return referenceStatus;
    }

    return closeStatus;
}

// Waits until the thread created by Start() has terminated.
//
// Parameters:
//   timeout - 0 waits without a time limit (a NULL timeout is passed to
//             KeWaitForSingleObject()). Any other value is forwarded as the
//             LARGE_INTEGER timeout of KeWaitForSingleObject(), which
//             documents negative values as relative and positive values as
//             absolute times in 100 ns units.
//
// Returns:
//   - STATUS_UNSUCCESSFUL if no thread object is tracked or if the caller is
//     the tracked thread itself (a thread cannot wait for its own exit);
//   - STATUS_TIMEOUT if the thread was still running when the timeout
//     expired. The thread reference is kept in that case so the wait can be
//     repeated later;
//   - otherwise the status of KeWaitForSingleObject(); the thread reference
//     is released and the object no longer tracks a thread.
//
// The object's mutex is held for the whole duration of the wait.
NTSTATUS DriverThread::Thread::WaitForTermination(LONGLONG timeout)
{
    LockGuard lock(m_mutex);
    if (m_threadObject == nullptr || PsGetCurrentThreadId() == m_threadId)
    {
        return STATUS_UNSUCCESSFUL;
    }

    LARGE_INTEGER li_timeout;
    li_timeout.QuadPart = timeout;

    const NTSTATUS status =
        KeWaitForSingleObject(m_threadObject, Executive, KernelMode, FALSE, (timeout == 0 ? NULL : &li_timeout));

    if (status == STATUS_TIMEOUT)
    {
        // The thread is still alive: keep the reference so that a later call
        // (including the one made by the destructor) can still wait for it.
        return status;
    }

    ObDereferenceObject(m_threadObject);
    m_threadObject = nullptr;
    return status;
}

// Returns the thread id stored in m_threadId. The value is written by
// InterceptorThreadRoutine() from inside the created thread, so it is only
// meaningful once that thread has started running.
HANDLE DriverThread::Thread::GetThreadId(void)
{
    const HANDLE storedId = this->m_threadId;
    return storedId;
}

// Spinlock

// Constructs the wrapper and initializes the embedded kernel spin lock.
DriverThread::Spinlock::Spinlock(void)
{
    KeInitializeSpinLock(&this->m_spinLock);
}

// Acquires the spin lock.
//
// Parameters:
//   oldIrql - receives the IRQL that was current before the lock was taken;
//             the same pointer must later be handed to Release().
//
// Returns:
//   STATUS_INVALID_PARAMETER if oldIrql is null, STATUS_SUCCESS otherwise.
//
// KeAcquireSpinLock() is documented as returning nothing, so its "result"
// must not be forwarded as an NTSTATUS: where it is implemented as a macro
// that would leak the saved IRQL value as the status, and where it is a real
// void function the statement would not compile at all.
NTSTATUS DriverThread::Spinlock::Acquire(KIRQL * const oldIrql)
{
    if (oldIrql == nullptr)
    {
        return STATUS_INVALID_PARAMETER;
    }

    KeAcquireSpinLock(&m_spinLock, oldIrql);
    return STATUS_SUCCESS;
}

// Releases the spin lock and restores the IRQL stored in *oldIrql, i.e. the
// value that Acquire() wrote there. oldIrql is dereferenced unconditionally.
void DriverThread::Spinlock::Release(KIRQL * const oldIrql)
{
    const KIRQL irqlToRestore = *oldIrql;
    KeReleaseSpinLock(&m_spinLock, irqlToRestore);
}

// Semaphore

// Constructs the wrapper and initializes the embedded kernel semaphore.
//
// Parameters:
//   initialValue - count the semaphore starts with.
//   maxValue     - limit passed to KeInitializeSemaphore().
DriverThread::Semaphore::Semaphore(LONG initialValue, LONG maxValue)
{
    KeInitializeSemaphore(&this->m_semaphore, initialValue, maxValue);
}

// Waits on the semaphore (non-alertable, kernel mode, Executive wait reason).
//
// Parameters:
//   timeout - 0 waits without a time limit; any other value is forwarded
//             unchanged as the LARGE_INTEGER timeout of
//             KeWaitForSingleObject().
//
// Returns the status reported by KeWaitForSingleObject().
NTSTATUS DriverThread::Semaphore::Wait(LONGLONG timeout)
{
    // A zero timeout selects the "no timeout" form of the wait.
    if (timeout == 0)
    {
        return KeWaitForSingleObject(&m_semaphore, Executive, KernelMode, FALSE, NULL);
    }

    LARGE_INTEGER waitInterval;
    waitInterval.QuadPart = timeout;
    return KeWaitForSingleObject(&m_semaphore, Executive, KernelMode, FALSE, &waitInterval);
}

// Releases the semaphore by `adjustment` counts.
//
// The call uses a priority increment of 0 and passes FALSE for the Wait
// argument. Returns whatever KeReleaseSemaphore() returns.
LONG DriverThread::Semaphore::Release(LONG adjustment)
{
    const LONG releaseResult = KeReleaseSemaphore(&m_semaphore, 0, adjustment, FALSE);
    return releaseResult;
}

// Mutex

// Constructs the mutex in the unlocked state.
//
// Lock() treats m_interlock == 0 as "free" and reads the flag before anything
// else in this file writes it, so the constructor sets it explicitly.
// NOTE(review): the class declaration is not visible from this file; if it
// already provides a default member initializer this assignment is merely
// redundant -- confirm against DriverThread.hpp.
DriverThread::Mutex::Mutex(void)
{
    m_interlock = 0;
}

// Nothing to clean up: the destructor performs no work of its own.
DriverThread::Mutex::~Mutex(void) = default;

// Busy-waits until the lock flag has been claimed by this caller.
//
// Each iteration first reads m_interlock; only when it is not 1 does it
// attempt an InterlockedCompareExchange(0 -> 1). The loop ends as soon as
// that exchange reports a previous value other than 1.
void DriverThread::Mutex::Lock(void)
{
    for (;;)
    {
        // Short-circuit order matters: the interlocked operation is attempted
        // only after the plain read did not see the flag set to 1.
        if (m_interlock != 1 && InterlockedCompareExchange(&m_interlock, 1, 0) != 1)
        {
            return;
        }
    }
}

void DriverThread::Mutex::Unlock(void)
{
    m_interlock = 0;
}

// LockGuard

// Scope guard, acquire side: binds m_Lock to the given mutex and locks it
// immediately. The matching unlock happens in the destructor.
// NOTE(review): this protects the caller's mutex only if m_Lock is declared
// as a reference (Mutex &); the declaration is not visible here -- confirm.
DriverThread::LockGuard::LockGuard(Mutex & m) : m_Lock(m)
{
    this->m_Lock.Lock();
}

// Scope guard, release side: unlocks the mutex that the constructor locked.
DriverThread::LockGuard::~LockGuard(void)
{
    this->m_Lock.Unlock();
}