aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--Makefile47
-rw-r--r--README.md12
-rw-r--r--common.hpp308
-rw-r--r--driver-protobuf-c-tcp.bat27
-rw-r--r--driver-protobuf-c-tcp.cpp139
-rw-r--r--driver-protobuf-c.cpp130
-rw-r--r--driver.cpp4
-rw-r--r--echo_srv.py16
-rw-r--r--userspace_client_protobuf.cpp152
9 files changed, 744 insertions, 91 deletions
diff --git a/Makefile b/Makefile
index 5af128a..35e3a38 100644
--- a/Makefile
+++ b/Makefile
@@ -7,24 +7,30 @@ include $(DPP_ROOT)/Makefile.inc
DRIVER0_NAME = driver
DRIVER0_OBJECTS = $(DRIVER0_NAME).o ksocket.o berkeley.o
DRIVER0_TARGET = $(DRIVER0_NAME).sys
-DRIVER0_CFLAGS = -I. -Wl,--exclude-all-symbols -DNDEBUG
DRIVER1_NAME = driver-protobuf-c
DRIVER1_OBJECTS = $(DRIVER1_NAME).o protobuf-c/protobuf-c.o protobuf-c/example.pb-c.o
DRIVER1_TARGET = $(DRIVER1_NAME).sys
-DRIVER1_CFLAGS = -I. -Iprotobuf-c -Wl,--exclude-all-symbols -DNDEBUG
-USERSPACE_NAME = userspace_client
-USERSPACE_OBJECTS = $(USERSPACE_NAME).o
-USERSPACE_TARGET = $(USERSPACE_NAME).exe
+DRIVER2_NAME = driver-protobuf-c-tcp
+DRIVER2_OBJECTS = $(DRIVER2_NAME).o ksocket.o berkeley.o protobuf-c/protobuf-c.o protobuf-c/example.pb-c.o
+DRIVER2_TARGET = $(DRIVER2_NAME).sys
+
+USERSPACE0_NAME = userspace_client
+USERSPACE0_OBJECTS = $(USERSPACE0_NAME).o
+USERSPACE0_TARGET = $(USERSPACE0_NAME).exe
+
+USERSPACE1_NAME = userspace_client_protobuf
+USERSPACE1_OBJECTS = $(USERSPACE1_NAME).o protobuf-c/protobuf-c.o protobuf-c/example.pb-c.o
+USERSPACE1_TARGET = $(USERSPACE1_NAME).exe
# mingw-w64-dpp related
CFLAGS_protobuf-c/protobuf-c.o = -Wno-unused-but-set-variable
-CUSTOM_CFLAGS = $(DRIVER0_CFLAGS)
+CUSTOM_CFLAGS = -I. -Wl,--exclude-all-symbols -DNDEBUG
DRIVER_LIBS += -lnetio
USER_LIBS += -lws2_32
-all: $(DRIVER0_TARGET) $(DRIVER1_TARGET) $(USERSPACE_TARGET)
+all: $(DRIVER0_TARGET) $(DRIVER1_TARGET) $(DRIVER2_TARGET) $(USERSPACE0_TARGET) $(USERSPACE1_TARGET)
%.o: %.cpp
$(call BUILD_CPP_OBJECT,$<,$@)
@@ -38,21 +44,32 @@ $(DRIVER0_TARGET): $(DRIVER0_OBJECTS)
$(DRIVER1_TARGET): $(DRIVER1_OBJECTS)
$(call LINK_CPP_KERNEL_TARGET,$(DRIVER1_OBJECTS),$@)
-$(USERSPACE_TARGET): $(USERSPACE_OBJECTS)
- $(call LINK_CPP_USER_TARGET,$(USERSPACE_OBJECTS),$@)
+$(DRIVER2_TARGET): $(DRIVER2_OBJECTS)
+ $(call LINK_CPP_KERNEL_TARGET,$(DRIVER2_OBJECTS),$@)
+
+$(USERSPACE0_TARGET): $(USERSPACE0_OBJECTS)
+ $(call LINK_CPP_USER_TARGET,$(USERSPACE0_OBJECTS),$@)
+
+$(USERSPACE1_TARGET): $(USERSPACE1_OBJECTS)
+ $(call LINK_CPP_USER_TARGET,$(USERSPACE1_OBJECTS),$@)
-install: $(DRIVER0_TARGET) $(DRIVER1_TARGET) $(USERSPACE_TARGET)
+install: $(DRIVER0_TARGET) $(DRIVER1_TARGET) $(DRIVER2_TARGET) $(USERSPACE0_TARGET) $(USERSPACE1_TARGET)
$(call INSTALL_EXEC_SIGN,$(DRIVER0_TARGET))
$(call INSTALL_EXEC_SIGN,$(DRIVER1_TARGET))
- $(call INSTALL_EXEC,$(USERSPACE_TARGET))
+ $(call INSTALL_EXEC_SIGN,$(DRIVER2_TARGET))
+ $(call INSTALL_EXEC,$(USERSPACE0_TARGET))
+ $(call INSTALL_EXEC,$(USERSPACE1_TARGET))
$(INSTALL) '$(DRIVER0_NAME).bat' '$(DESTDIR)/'
$(INSTALL) '$(DRIVER1_NAME).bat' '$(DESTDIR)/'
+ $(INSTALL) '$(DRIVER2_NAME).bat' '$(DESTDIR)/'
clean:
- rm -f $(DRIVER0_OBJECTS) $(DRIVER1_OBJECTS)
- rm -f $(DRIVER0_TARGET) $(DRIVER0_TARGET).map $(DRIVER1_TARGET) $(DRIVER1_TARGET).map
- rm -f $(USERSPACE_OBJECTS)
- rm -f $(USERSPACE_TARGET)
+ rm -f $(DRIVER0_OBJECTS) $(DRIVER1_OBJECTS) $(DRIVER2_OBJECTS)
+ rm -f $(DRIVER0_TARGET) $(DRIVER0_TARGET).map \
+ $(DRIVER1_TARGET) $(DRIVER1_TARGET).map \
+ $(DRIVER2_TARGET) $(DRIVER2_TARGET).map
+ rm -f $(USERSPACE0_OBJECTS) $(USERSPACE1_OBJECTS)
+ rm -f $(USERSPACE0_TARGET) $(USERSPACE1_TARGET)
.NOTPARALLEL: clean
.PHONY: all install clean
diff --git a/README.md b/README.md
index 29cb1e8..1bd630b 100644
--- a/README.md
+++ b/README.md
@@ -26,13 +26,15 @@ make DPP_ROOT="[path-to-mingw-w64-dpp-template-dir]" DESTDIR="[path-to-install-d
The directory `[path-to-install-dir]` should now contain three new files:
- * `driver.bat` / `driver-protobuf-c.bat`: setup the driver service, start it, stop it when it's done and delete it
- * `driver.sys`: example driver that uses kernel sockets
- * `userspace_client.exe`: example userspace application which communicates with the driver via TCP socket
+ * `driver.bat` / `driver-protobuf-c.bat` / `driver-protobuf-c-tcp.bat`: setup the driver service, start it, stop it when it's done and delete it
+ * `driver.sys`: example driver that uses kernel sockets (used together with `userspace_client.exe`)
+ * `userspace_client.exe`: example userspace application which communicates with the driver via TCP
* `driver-protobuf-c.sys`: example driver that make use of protobuf-c (local, no TCP/IP)
+ * `driver-protobuf-c-tcp.sys`: example driver that make use of protobuf-c via TCP/IP (used together with `userspace_client_protobuf.exe`)
+ * `userspace_client_protobuf.exe`: example userspace application which leverages protocol buffers to communicate with the driver via TCP
-Start `driver.bat` as `Administrator` and then `userspace_client.exe`.
+Start `*.bat` as `Administrator`.
-If everything works fine, there should be a text displayed in `userspace_client.exe` console window, received from the driver.
+If everything works fine, there should be a text displayed in `userspace_client.exe` / `userspace_client_protobuf.exe` console window, received from the driver.
For more debug output, it is recommended to use a debugger or log viewer like `dbgview`.
diff --git a/common.hpp b/common.hpp
new file mode 100644
index 0000000..0cbbe6b
--- /dev/null
+++ b/common.hpp
@@ -0,0 +1,308 @@
+#include "protobuf-c/example.pb-c.h"
+
+#include <EASTL/algorithm.h>
+#include <EASTL/array.h>
+#include <EASTL/initializer_list.h>
+#include <EASTL/string.h>
+#include <EASTL/vector.h>
+
+#ifndef SOCKET_ERROR
+#define SOCKET_ERROR -1
+#endif
+
+#define SEND_ALL(sock, socket_buffer, retval) \
+ if (socket_buffer.GetRemainingSendSize() > 0) { \
+ do { \
+ retval = send(sock, socket_buffer.GetStartS(), \
+ socket_buffer.GetRemainingSendSize(), 0); \
+ if (retval == SOCKET_ERROR || retval == 0) \
+ break; \
+ if (!socket_buffer.Consume(iResult)) \
+ break; \
+ } while (!socket_buffer.AllConsumed()); \
+ socket_buffer.Sweep(); \
+ }
+#define RECV_PDU_BEGIN(sock, socket_buffer, retval, pdu_type, pdu_len) \
+ do { \
+ if (socket_buffer.GetRemainingRecvSize() == 0) \
+ break; \
+ uint16_t pdu_type; \
+ uint32_t pdu_len; \
+ do { \
+ retval = recv(sock, socket_buffer.GetStartS(), \
+ socket_buffer.GetRemainingRecvSize(), 0); \
+ if (retval == SOCKET_ERROR || retval == 0) \
+ break; \
+ } while (!socket_buffer.GetPdu(retval, pdu_type, pdu_len));
+#define RECV_PDU_END(socket_buffer, pdu_len) \
+ socket_buffer.Consume(pdu_len); \
+ socket_buffer.Sweep(); \
+ } \
+ while (0)
+
+class BaseSerializer {
+public:
+ virtual uint16_t GetPduType(void) = 0;
+ virtual size_t GetSerializedSize(void) = 0;
+ virtual size_t Serialize(uint8_t *buf) = 0;
+};
+
+class BaseDeserializer {
+public:
+ virtual bool Deserialize(size_t pdu_len, uint8_t *buf) = 0;
+ virtual void DeserializeFree(void) = 0;
+};
+
+template <size_t SIZE> class SocketBuffer {
+public:
+ SocketBuffer(void) {}
+ ~SocketBuffer(void) {}
+ size_t GetRemainingSendSize(void) { return GetUsed(); }
+ size_t GetRemainingRecvSize(void) { return GetSize() - GetUsed(); }
+ uint8_t *GetStart(void) { return buffer + consumed; }
+ char *GetStartS(void) { return (char *)buffer + consumed; }
+ bool SizeCheck(size_t required_size) { return used + required_size < SIZE; };
+ template <typename T> bool GetPrimitve(size_t offset, T &out) {
+ if (offset + sizeof(out) > used)
+ return false;
+ out = *(T *)(GetStart() + offset);
+ return true;
+ }
+ template <typename T> bool AddPrimitve(T value) {
+ if (!SizeCheck(sizeof(value)))
+ return false;
+ *(T *)(GetEnd()) = value;
+ used += sizeof(value);
+ return true;
+ }
+ bool GetPdu(size_t received_size, uint16_t &pdu_type, uint32_t &pdu_len) {
+ if (received_size > SIZE - used)
+ return false; // You did something wrong!
+ used += received_size;
+ if (used < sizeof(pdu_type) + sizeof(pdu_len))
+ return false;
+ if (GetPrimitve<uint32_t>(0, pdu_len) == false)
+ return false;
+ if (GetPrimitve<uint16_t>(4, pdu_type) == false)
+ return false;
+
+ pdu_len = ntohl(pdu_len);
+ pdu_type = ntohs(pdu_type);
+
+ if (used < sizeof(pdu_type) + sizeof(pdu_len) + pdu_len)
+ return false;
+
+ consumed += sizeof(pdu_type) + sizeof(pdu_len);
+
+ return true;
+ }
+ bool AddPdu(BaseSerializer &bs) {
+ uint16_t pdu_type = bs.GetPduType();
+ uint32_t pdu_len = bs.GetSerializedSize();
+
+ if (!SizeCheck(sizeof(pdu_type) + sizeof(pdu_len) + pdu_len))
+ return false;
+
+ if (!AddPrimitve<uint32_t>(htonl(pdu_len)))
+ return false;
+ if (!AddPrimitve<uint16_t>(htons(pdu_type)))
+ return false;
+ if (bs.Serialize(GetEnd()) != pdu_len)
+ return false;
+ used += pdu_len;
+
+ return true;
+ }
+ bool AllConsumed(void) { return used == consumed; }
+ bool Consume(size_t consuming_size) {
+ if (consuming_size + consumed > used)
+ return false;
+ consumed += consuming_size;
+ return true;
+ }
+ void Sweep() {
+ if (consumed == 0)
+ return;
+ if (used != consumed)
+ memmove(buffer, buffer + consumed, used - consumed);
+ used -= consumed;
+ consumed = 0;
+ }
+
+private:
+ uint8_t buffer[SIZE];
+ size_t used = 0;
+ size_t consumed = 0;
+
+ size_t GetSize(void) { return SIZE; }
+ size_t GetUsed(void) { return used - consumed; }
+ uint8_t *GetEnd() { return buffer + used; }
+};
+
+class ProtobufCBinaryDataClass : public ProtobufCBinaryData {
+public:
+ ProtobufCBinaryDataClass(std::initializer_list<uint8_t> bytes) {
+ len = bytes.size();
+ data = new uint8_t[len];
+ eastl::copy(bytes.begin(), bytes.end(), data);
+ }
+ ~ProtobufCBinaryDataClass(void) { delete data; }
+};
+
+/* Could be also be done via protobuf, but I decided for speed/memory
+ * efficiency. */
+enum PduTypes {
+ PDU_SOMETHING_WITH_UINTS = 1,
+ PDU_SOMETHING_MORE = 2,
+ PDU_EVEN_MORE = 3
+};
+
+class SomethingWithUINTsSerializer : virtual public BaseSerializer {
+public:
+ SomethingWithUINTsSerializer(void) {}
+ ~SomethingWithUINTsSerializer(void) {}
+ void SetId(uint32_t id) {
+ swu.has_id = TRUE;
+ swu.id = id;
+ }
+ void SetIpAddress(uint32_t ip_address) {
+ swu.has_ip_address = TRUE;
+ swu.ip_address = ip_address;
+ }
+ void SetPortNum(uint32_t port_num) {
+ swu.has_port_num = TRUE;
+ swu.port_num = port_num;
+ }
+ uint16_t GetPduType(void) override { return PDU_SOMETHING_WITH_UINTS; }
+ size_t GetSerializedSize(void) override {
+ return something_with_uints__get_packed_size(&swu);
+ }
+ size_t Serialize(uint8_t *buf) override {
+ return something_with_uints__pack(&swu, buf);
+ }
+
+ SomethingWithUINTs swu = SOMETHING_WITH_UINTS__INIT;
+};
+
+class SomethingWithUINTsDeserializer : virtual public BaseDeserializer {
+public:
+ SomethingWithUINTsDeserializer(void) {}
+ ~SomethingWithUINTsDeserializer(void) { DeserializeFree(); }
+ bool Deserialize(size_t pdu_len, uint8_t *buf) override {
+ swu = something_with_uints__unpack(NULL, pdu_len, buf);
+ return swu != NULL;
+ }
+ void DeserializeFree(void) override {
+ if (swu != NULL)
+ something_with_uints__free_unpacked(swu, NULL);
+ }
+
+ SomethingWithUINTs *swu = NULL;
+};
+
+class SomethingMoreSerializer : virtual public BaseSerializer,
+ public SomethingWithUINTsSerializer {
+public:
+ SomethingMoreSerializer(void) { sm.uints = &swu; }
+ ~SomethingMoreSerializer(void) {}
+ void SetErrorCode(SomethingMore__Errors error_code) {
+ sm.has_error_code = TRUE;
+ sm.error_code = error_code;
+ }
+ uint16_t GetPduType(void) override { return PDU_SOMETHING_MORE; }
+ size_t GetSerializedSize(void) override {
+ return something_more__get_packed_size(&sm);
+ }
+ size_t Serialize(uint8_t *buf) override {
+ return something_more__pack(&sm, buf);
+ }
+
+ SomethingMore sm = SOMETHING_MORE__INIT;
+};
+
+class SomethingMoreDeserializer : virtual public BaseDeserializer {
+public:
+ SomethingMoreDeserializer(void) {}
+ ~SomethingMoreDeserializer(void) { DeserializeFree(); }
+ bool Deserialize(size_t pdu_len, uint8_t *buf) override {
+ sm = something_more__unpack(NULL, pdu_len, buf);
+ return sm != NULL;
+ }
+ void DeserializeFree(void) override {
+ if (sm != NULL)
+ something_more__free_unpacked(sm, NULL);
+ }
+
+ SomethingMore *sm = NULL;
+};
+
+class EvenMoreSerializer : virtual public BaseSerializer {
+public:
+ explicit EvenMoreSerializer(void) = default;
+ explicit EvenMoreSerializer(EvenMore__SomeEnum enum_value,
+ ProtobufCBinaryDataClass name,
+ ProtobufCBinaryDataClass value)
+ : uints() {
+ em.enum_value = enum_value;
+ em.name = name;
+ em.value = value;
+ }
+ ~EvenMoreSerializer(void) { free(s); }
+ void AddUints(SomethingWithUINTsSerializer *uints) {
+ this->uints.push_back(uints);
+ }
+ void SetS(eastl::string s) {
+ size_t l = s.size();
+ this->s = (char *)malloc(l + 1);
+ memcpy(this->s, s.c_str(), l);
+ this->s[l] = '\0';
+ }
+ uint16_t GetPduType(void) override { return PDU_EVEN_MORE; }
+ size_t GetSerializedSize(void) override {
+ em.s = s;
+ em.n_uints = this->uints.size();
+ if (em.n_uints > 0) {
+ SomethingWithUINTs *out[em.n_uints];
+ ConvertUintsVectorToCArray(out);
+ return even_more__get_packed_size(&em);
+ }
+ return even_more__get_packed_size(&em);
+ }
+ size_t Serialize(uint8_t *buf) override {
+ em.s = s;
+ em.n_uints = this->uints.size();
+ if (em.n_uints > 0) {
+ SomethingWithUINTs *out[em.n_uints];
+ ConvertUintsVectorToCArray(out);
+ return even_more__pack(&em, buf);
+ }
+ return even_more__pack(&em, buf);
+ }
+
+ EvenMore em = EVEN_MORE__INIT;
+ eastl::vector<SomethingWithUINTsSerializer *> uints;
+ char *s = NULL;
+
+ void ConvertUintsVectorToCArray(SomethingWithUINTs **out) {
+ for (size_t i = 0; i < em.n_uints; ++i) {
+ out[i] = &uints[i]->swu;
+ }
+ em.uints = out;
+ }
+};
+
+class EvenMoreDeserializer : virtual public BaseDeserializer {
+public:
+ EvenMoreDeserializer(void) {}
+ ~EvenMoreDeserializer(void) { DeserializeFree(); }
+ bool Deserialize(size_t pdu_len, uint8_t *buf) override {
+ em = even_more__unpack(NULL, pdu_len, buf);
+ return em != NULL;
+ }
+ void DeserializeFree(void) override {
+ if (em != NULL)
+ even_more__free_unpacked(em, NULL);
+ }
+
+ EvenMore *em = NULL;
+};
diff --git a/driver-protobuf-c-tcp.bat b/driver-protobuf-c-tcp.bat
new file mode 100644
index 0000000..9aa935b
--- /dev/null
+++ b/driver-protobuf-c-tcp.bat
@@ -0,0 +1,27 @@
+@echo off
+set SERVICE_NAME=protobuf-c-tcp
+set DRIVER="%~dp0\driver-protobuf-c-tcp.sys"
+
+net session >nul 2>&1
+if NOT %ERRORLEVEL% EQU 0 (
+ echo ERROR: This script requires Administrator privileges!
+ pause
+ exit /b 1
+)
+
+echo ---------------------------------------
+echo -- Service Name: %SERVICE_NAME%
+echo -- Driver......: %DRIVER%
+echo ---------------------------------------
+
+sc create %SERVICE_NAME% binPath= %DRIVER% type= kernel
+echo ---------------------------------------
+sc start %SERVICE_NAME%
+echo ---------------------------------------
+sc query %SERVICE_NAME%
+echo [PRESS A KEY TO STOP THE DRIVER]
+pause
+sc stop %SERVICE_NAME%
+sc delete %SERVICE_NAME%
+echo Done.
+timeout /t 3
diff --git a/driver-protobuf-c-tcp.cpp b/driver-protobuf-c-tcp.cpp
new file mode 100644
index 0000000..0808b88
--- /dev/null
+++ b/driver-protobuf-c-tcp.cpp
@@ -0,0 +1,139 @@
+#include "berkeley.h"
+#include "ksocket.h"
+#include "protobuf-c/example.pb-c.h"
+#include "wsk.h"
+
+#include "common.hpp"
+
+extern "C" {
+DRIVER_INITIALIZE DriverEntry;
+DRIVER_UNLOAD DriverUnload;
+
+#define DebuggerPrint(...) \
+ DbgPrintEx(DPFLTR_IHVDRIVER_ID, DPFLTR_ERROR_LEVEL, __VA_ARGS__);
+
+NTSTATUS
+NTAPI
+DriverEntry(_In_ PDRIVER_OBJECT DriverObject,
+ _In_ PUNICODE_STRING RegistryPath) {
+ UNREFERENCED_PARAMETER(DriverObject);
+ UNREFERENCED_PARAMETER(RegistryPath);
+
+ NTSTATUS Status;
+
+ DebuggerPrint("Hi.\n");
+ Status = KsInitialize();
+
+ if (!NT_SUCCESS(Status)) {
+ return Status;
+ }
+
+ int server_sockfd = socket_listen(AF_INET, SOCK_STREAM, 0);
+
+ struct sockaddr_in addr;
+ addr.sin_family = AF_INET;
+ addr.sin_addr.s_addr = INADDR_ANY;
+ addr.sin_port = htons(9095);
+
+ int result = bind(server_sockfd, (struct sockaddr *)&addr, sizeof(addr));
+ if (result != 0) {
+ DebuggerPrint("TCP server bind failed\n");
+ return STATUS_FAILED_DRIVER_ENTRY;
+ }
+
+ result = listen(server_sockfd, 1);
+ if (result != 0) {
+ DebuggerPrint("TCP server listen failed\n");
+ return STATUS_FAILED_DRIVER_ENTRY;
+ }
+
+ socklen_t addrlen = sizeof(addr);
+ int client_sockfd = accept(server_sockfd, (struct sockaddr *)&addr, &addrlen);
+ if (client_sockfd < 0) {
+ DebuggerPrint("TCP accept failed\n");
+ return STATUS_FAILED_DRIVER_ENTRY;
+ }
+
+ int iResult;
+ SocketBuffer<1024> sb_send, sb_recv;
+
+ do {
+ bool ok = false;
+ RECV_PDU_BEGIN(client_sockfd, sb_recv, iResult, pdu_type, pdu_len) {
+ DebuggerPrint("PDU type/len: %u/%u\n", pdu_type, pdu_len);
+ switch ((enum PduTypes)pdu_type) {
+ case PDU_SOMETHING_WITH_UINTS: {
+ SomethingWithUINTsDeserializer swud;
+ if ((ok = swud.Deserialize(pdu_len, sb_recv.GetStart())) == true) {
+ SomethingWithUINTsSerializer swus;
+ if (swud.swu->has_id == TRUE) {
+ DebuggerPrint("Id: 0x%X\n", swud.swu->id);
+ swus.SetId(swud.swu->id + 1);
+ }
+ ok = sb_send.AddPdu(swus);
+ }
+ break;
+ }
+ case PDU_SOMETHING_MORE: {
+ SomethingMoreDeserializer smd;
+ if ((ok = smd.Deserialize(pdu_len, sb_recv.GetStart())) == true) {
+ SomethingMoreSerializer sms;
+ if (smd.sm->has_error_code == TRUE) {
+ DebuggerPrint("Error Code: %u\n", smd.sm->error_code);
+ }
+ if (smd.sm->uints->has_id == TRUE) {
+ DebuggerPrint("Id: 0x%X\n", smd.sm->uints->id);
+ sms.SetId(smd.sm->uints->id + 1);
+ }
+ sms.SetErrorCode(SOMETHING_MORE__ERRORS__SUCCESS);
+ sms.SetIpAddress(0xCCCCCCCC);
+ sms.SetPortNum(0xDDDDDDDD);
+ ok = sb_send.AddPdu(sms);
+ }
+ break;
+ }
+ case PDU_EVEN_MORE: {
+ EvenMoreDeserializer emd;
+ if ((ok = emd.Deserialize(pdu_len, sb_recv.GetStart())) == true) {
+ DebuggerPrint("EnumValue: %d\n", emd.em->enum_value);
+ if (emd.em->s != NULL) {
+ DebuggerPrint("String: '%s'\n", emd.em->s);
+ }
+ EvenMoreSerializer ems;
+ SomethingWithUINTsSerializer swus;
+ swus.SetId(0xDEADDEAD);
+ ems.SetS("Hi userspace!");
+ ems.AddUints(&swus);
+ ok = sb_send.AddPdu(ems);
+ }
+ break;
+ }
+ }
+ }
+ RECV_PDU_END(sb_recv, pdu_len);
+
+ if (ok == true) {
+ SEND_ALL(client_sockfd, sb_send, iResult);
+ if (iResult == SOCKET_ERROR || iResult == 0) {
+ DebuggerPrint("send failed\n");
+ break;
+ }
+ } else {
+ DebuggerPrint("Serialization/Deserialization failed\n");
+ break;
+ }
+ } while (iResult != SOCKET_ERROR && iResult > 0);
+
+ DebuggerPrint("Client gone.\n") closesocket(client_sockfd);
+ closesocket(server_sockfd);
+ KsDestroy();
+
+ return STATUS_SUCCESS;
+}
+
+VOID DriverUnload(_In_ struct _DRIVER_OBJECT *DriverObject) {
+ UNREFERENCED_PARAMETER(DriverObject);
+
+ DebuggerPrint("Bye.\n");
+}
+}
diff --git a/driver-protobuf-c.cpp b/driver-protobuf-c.cpp
index 0ba11c7..762137c 100644
--- a/driver-protobuf-c.cpp
+++ b/driver-protobuf-c.cpp
@@ -1,10 +1,11 @@
-
-extern "C" {
#include "berkeley.h"
#include "ksocket.h"
#include "protobuf-c/example.pb-c.h"
#include "wsk.h"
+#include "common.hpp"
+
+extern "C" {
DRIVER_INITIALIZE DriverEntry;
DRIVER_UNLOAD DriverUnload;
@@ -19,91 +20,78 @@ DriverEntry(_In_ PDRIVER_OBJECT DriverObject,
UNREFERENCED_PARAMETER(RegistryPath);
size_t len = 0;
- uint8_t *buf = NULL;
{
- SomethingMore sm = SOMETHING_MORE__INIT;
- SomethingWithUINTs swu = SOMETHING_WITH_UINTS__INIT;
-
- sm.error_code = SOMETHING_MORE__ERRORS__SUCCESS;
- sm.uints = &swu;
- swu.has_id = TRUE;
- swu.id = 0x12345678;
- swu.has_ip_address = TRUE;
- swu.ip_address = 0xAAAAAAAA;
- swu.has_port_num = TRUE;
- swu.port_num = 0xBBBBBBBB;
-
- len = something_more__get_packed_size(&sm);
- buf = (uint8_t *)malloc(len);
- if (something_more__pack(&sm, buf) != len) {
- DebuggerPrint("Packing failed.\n");
- }
- }
+ uint8_t *buf = NULL;
+
+ {
+ SomethingMoreSerializer sms;
+ sms.SetErrorCode(SOMETHING_MORE__ERRORS__SUCCESS);
+ sms.SetId(0x12345678);
+ sms.SetIpAddress(0xAAAAAAAA);
+ sms.SetPortNum(0xBBBBBBBB);
+ len = sms.GetSerializedSize();
+ buf = (uint8_t *)malloc(len);
+ if (buf == NULL) {
+ return STATUS_UNSUCCESSFUL;
+ }
+ if (sms.Serialize(buf) != len) {
+ DebuggerPrint("Packing failed.\n");
+ free(buf);
+ return STATUS_UNSUCCESSFUL;
+ }
- DebuggerPrint("Packed Size: %zu\n", len);
+ DebuggerPrint("Packed Size: %zu\n", len);
- {
- SomethingMore *smth = something_more__unpack(NULL, len, buf);
- if (smth != NULL && smth->uints != NULL && smth->uints->id == 0x12345678 &&
- smth->uints->ip_address == 0xAAAAAAAA &&
- smth->uints->port_num == 0xBBBBBBBB) {
- DebuggerPrint("Success!\n");
+ SomethingMoreDeserializer smd;
+ if (smd.Deserialize(len, buf) == true && smd.sm->uints != NULL &&
+ smd.sm->uints->id == 0x12345678 &&
+ smd.sm->uints->ip_address == 0xAAAAAAAA &&
+ smd.sm->uints->port_num == 0xBBBBBBBB) {
+ DebuggerPrint("Success!\n");
+ }
}
- DebuggerPrint("id: %x, ip_address: %x, port_num: %x\n", smth->uints->id,
- smth->uints->ip_address, smth->uints->port_num);
- something_more__free_unpacked(smth, NULL);
- }
- free(buf);
+ free(buf);
+ }
{
- EvenMore em = EVEN_MORE__INIT;
- SomethingWithUINTs swu[] = {SOMETHING_WITH_UINTS__INIT,
- SOMETHING_WITH_UINTS__INIT,
- SOMETHING_WITH_UINTS__INIT};
- SomethingWithUINTs *swu_vec[] = {&swu[0], &swu[1], &swu[2]};
- uint8_t bin[] = {0xde, 0xad, 0xc0, 0xde};
- ProtobufCBinaryData pbin = {.len = sizeof(bin), .data = bin};
- char str[] = "This is a zero-terminated String!";
-
- em.enum_value = EVEN_MORE__SOME_ENUM__FIRST;
- swu[0].has_id = TRUE;
- swu[0].id = 0xdeadc0de;
- swu[1].has_ip_address = TRUE;
- swu[1].ip_address = 0xdeadbeef;
- swu[2].has_port_num = TRUE;
- swu[2].port_num = 0xcafecafe;
- em.n_uints = sizeof(swu) / sizeof(swu[0]);
- em.uints = swu_vec;
- em.name = pbin;
- em.value = pbin;
- em.s = str;
-
- len = even_more__get_packed_size(&em);
+ EvenMoreSerializer ems(EVEN_MORE__SOME_ENUM__FIRST,
+ {0xde, 0xad, 0xc0, 0xde}, {0xde, 0xad, 0xc0, 0xde});
+ SomethingWithUINTsSerializer sws[3];
+
+ sws[0].SetId(0xdeadc0de);
+ sws[1].SetIpAddress(0xdeadbeef);
+ sws[2].SetPortNum(0xcafecafe);
+
+ ems.SetS("This is a zero-terminated String!");
+ ems.AddUints(&sws[0]);
+ ems.AddUints(&sws[1]);
+ ems.AddUints(&sws[2]);
+
+ len = ems.GetSerializedSize();
uint8_t tmp[len];
- if (even_more__pack(&em, tmp) != len) {
+ if (ems.Serialize(tmp) != len) {
DebuggerPrint("Packing failed.\n");
}
DebuggerPrint("Packed Size: %zu\n", len);
- EvenMore *result = even_more__unpack(NULL, len, tmp);
- if (result != NULL && result->n_uints == 3 && result->uints != NULL &&
- result->name.len > 0 && result->name.data != NULL &&
- result->value.len > 0 && result->value.data != NULL &&
- result->s != NULL) {
- if (result->enum_value != EVEN_MORE__SOME_ENUM__FIRST ||
- result->uints[0]->has_id != TRUE ||
- result->uints[0]->id != 0xdeadc0de ||
- result->uints[1]->has_ip_address != TRUE ||
- result->uints[1]->ip_address != 0xdeadbeef ||
- result->uints[1]->has_port_num != TRUE ||
- result->uints[2]->port_num != 0xcafecafe) {
+ EvenMoreDeserializer emd;
+ if (emd.Deserialize(len, tmp) == true && emd.em->n_uints == 3 &&
+ emd.em->uints != NULL && emd.em->name.len > 0 &&
+ emd.em->name.data != NULL && emd.em->value.len > 0 &&
+ emd.em->value.data != NULL && emd.em->s != NULL) {
+ if (emd.em->enum_value == EVEN_MORE__SOME_ENUM__FIRST ||
+ emd.em->uints[0]->has_id == TRUE ||
+ emd.em->uints[0]->id == 0xdeadc0de ||
+ emd.em->uints[1]->has_ip_address == TRUE ||
+ emd.em->uints[1]->ip_address == 0xdeadbeef ||
+ emd.em->uints[2]->has_port_num == TRUE ||
+ emd.em->uints[2]->port_num == 0xcafecafe) {
DebuggerPrint("Success!\n");
}
- DebuggerPrint("Deserialized String: '%s'\n", result->s);
+ DebuggerPrint("Deserialized String: '%s'\n", emd.em->s);
}
- even_more__free_unpacked(result, NULL);
}
return STATUS_SUCCESS;
diff --git a/driver.cpp b/driver.cpp
index b7ecbc1..ceae8ff 100644
--- a/driver.cpp
+++ b/driver.cpp
@@ -117,6 +117,10 @@ DriverEntry(_In_ PDRIVER_OBJECT DriverObject,
socklen_t addrlen = sizeof(addr);
int client_sockfd =
accept(server_sockfd, (struct sockaddr *)&addr, &addrlen);
+ if (client_sockfd < 0) {
+ DebuggerPrint("TCP accept failed\n");
+ return STATUS_FAILED_DRIVER_ENTRY;
+ }
result = recv(client_sockfd, recv_buffer, sizeof(recv_buffer) - 1, 0);
if (result > 0) {
diff --git a/echo_srv.py b/echo_srv.py
new file mode 100644
index 0000000..b3855eb
--- /dev/null
+++ b/echo_srv.py
@@ -0,0 +1,16 @@
+import socket
+
+HOST = '127.0.0.1'
+PORT = 9095
+
+with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+ s.bind((HOST, PORT))
+ s.listen()
+ conn, addr = s.accept()
+ with conn:
+ while True:
+ data = conn.recv(1024)
+ print('recvd: {} bytes'.format(len(data)))
+ if len(data) == 0:
+ break;
+ conn.sendall(data)
diff --git a/userspace_client_protobuf.cpp b/userspace_client_protobuf.cpp
new file mode 100644
index 0000000..7ee674e
--- /dev/null
+++ b/userspace_client_protobuf.cpp
@@ -0,0 +1,152 @@
+#include <stdio.h>
+#include <stdlib.h> // Needed for _wtoi
+#include <winsock2.h>
+#include <ws2tcpip.h>
+
+#include "common.hpp"
+#include "protobuf-c/example.pb-c.h"
+
+int main(int argc, char **argv) {
+ WSADATA wsaData = {};
+ int iResult = WSAStartup(MAKEWORD(2, 2), &wsaData);
+
+ UNREFERENCED_PARAMETER(argc);
+ UNREFERENCED_PARAMETER(argv);
+
+ if (iResult != 0) {
+ wprintf(L"WSAStartup failed: %d\n", iResult);
+ return 1;
+ }
+
+ SOCKET sock = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
+
+ if (sock == INVALID_SOCKET) {
+ wprintf(L"socket function failed with error = %d\n", WSAGetLastError());
+ } else {
+ wprintf(L"socket function succeeded\n");
+ }
+
+ sockaddr_in clientService;
+ clientService.sin_family = AF_INET;
+ clientService.sin_addr.s_addr = inet_addr("127.0.0.1");
+ clientService.sin_port = htons(9095);
+
+ do {
+ iResult = connect(sock, (SOCKADDR *)&clientService, sizeof(clientService));
+ if (iResult == SOCKET_ERROR) {
+ wprintf(L"connect function failed with error: %ld\n", WSAGetLastError());
+ Sleep(1000);
+ }
+ } while (iResult == SOCKET_ERROR);
+
+ wprintf(L"Connected to server.\n");
+
+ uint32_t start_id = 0x12345678;
+ for (size_t i = 0; i < 256; ++i) {
+ SocketBuffer<1024> sb_send, sb_recv;
+ SomethingMoreSerializer sms;
+
+ sms.SetErrorCode(SOMETHING_MORE__ERRORS__SUCCESS);
+ sms.SetId(start_id++);
+ sms.SetIpAddress(0xAAAAAAAA);
+ sms.SetPortNum(0xBBBBBBBB);
+
+ if (!sb_send.AddPdu(sms)) {
+ wprintf(L"Serialization failed\n");
+ break;
+ }
+
+ SEND_ALL(sock, sb_send, iResult);
+ if (iResult == SOCKET_ERROR || iResult == 0) {
+ wprintf(L"send failed with error: %d\n", WSAGetLastError());
+ break;
+ }
+
+ RECV_PDU_BEGIN(sock, sb_recv, iResult, pdu_type, pdu_len) {
+ wprintf(L"PDU type/len: %u/%u\n", pdu_type, pdu_len);
+ switch ((enum PduTypes)pdu_type) {
+ case PDU_SOMETHING_WITH_UINTS: {
+ break;
+ }
+ case PDU_SOMETHING_MORE: {
+ SomethingMoreDeserializer smd;
+ if (smd.Deserialize(pdu_len, sb_recv.GetStart()) == true &&
+ smd.sm->uints != NULL && smd.sm->uints->has_id == TRUE &&
+ smd.sm->uints->has_ip_address == TRUE &&
+ smd.sm->uints->has_port_num == TRUE)
+ wprintf(L"Id: 0x%X, IpAddress: 0x%X, PortNum: 0x%X\n",
+ smd.sm->uints->id, smd.sm->uints->ip_address,
+ smd.sm->uints->port_num);
+ break;
+ }
+ case PDU_EVEN_MORE: {
+ break;
+ }
+ }
+ }
+ RECV_PDU_END(sb_recv, pdu_len);
+
+ ////////////////////////////////////////////////////////
+
+ EvenMoreSerializer ems(EVEN_MORE__SOME_ENUM__FIRST,
+ {0xde, 0xad, 0xc0, 0xde}, {0xde, 0xad, 0xc0, 0xde});
+ SomethingWithUINTsSerializer swus[3];
+
+ swus[0].SetId(0xdeadc0de);
+ swus[1].SetIpAddress(0xdeadbeef);
+ swus[2].SetPortNum(0xcafecafe);
+
+ ems.SetS("This is a zero-terminated String!");
+ ems.AddUints(&swus[0]);
+ ems.AddUints(&swus[1]);
+ ems.AddUints(&swus[2]);
+
+ if (!sb_send.AddPdu(ems)) {
+ wprintf(L"Serialization failed\n");
+ break;
+ }
+
+ SEND_ALL(sock, sb_send, iResult);
+ if (iResult == SOCKET_ERROR || iResult == 0) {
+ wprintf(L"send failed with error: %d\n", WSAGetLastError());
+ break;
+ }
+
+ RECV_PDU_BEGIN(sock, sb_recv, iResult, pdu_type, pdu_len) {
+ wprintf(L"PDU type/len: %u/%u\n", pdu_type, pdu_len);
+ switch ((enum PduTypes)pdu_type) {
+ case PDU_SOMETHING_WITH_UINTS: {
+ break;
+ }
+ case PDU_SOMETHING_MORE: {
+ break;
+ }
+ case PDU_EVEN_MORE: {
+ EvenMoreDeserializer emd;
+ if (emd.Deserialize(pdu_len, sb_recv.GetStart()) == true) {
+ wprintf(L"EnumValue: %d\n", emd.em->enum_value);
+ if (emd.em->s != NULL)
+ wprintf(L"String: '%s'\n", emd.em->s);
+ }
+ break;
+ }
+ }
+ }
+ RECV_PDU_END(sb_recv, pdu_len);
+ }
+ wprintf(L"Closing Connection ..\n");
+
+ iResult = closesocket(sock);
+ if (iResult == SOCKET_ERROR) {
+ wprintf(L"closesocket function failed with error: %ld\n",
+ WSAGetLastError());
+ WSACleanup();
+ return 1;
+ }
+
+ WSACleanup();
+
+ system("pause");
+
+ return 0;
+}