diff options
-rw-r--r-- | Makefile | 47 | ||||
-rw-r--r-- | README.md | 12 | ||||
-rw-r--r-- | common.hpp | 308 | ||||
-rw-r--r-- | driver-protobuf-c-tcp.bat | 27 | ||||
-rw-r--r-- | driver-protobuf-c-tcp.cpp | 139 | ||||
-rw-r--r-- | driver-protobuf-c.cpp | 130 | ||||
-rw-r--r-- | driver.cpp | 4 | ||||
-rw-r--r-- | echo_srv.py | 16 | ||||
-rw-r--r-- | userspace_client_protobuf.cpp | 152 |
9 files changed, 744 insertions, 91 deletions
@@ -7,24 +7,30 @@ include $(DPP_ROOT)/Makefile.inc DRIVER0_NAME = driver DRIVER0_OBJECTS = $(DRIVER0_NAME).o ksocket.o berkeley.o DRIVER0_TARGET = $(DRIVER0_NAME).sys -DRIVER0_CFLAGS = -I. -Wl,--exclude-all-symbols -DNDEBUG DRIVER1_NAME = driver-protobuf-c DRIVER1_OBJECTS = $(DRIVER1_NAME).o protobuf-c/protobuf-c.o protobuf-c/example.pb-c.o DRIVER1_TARGET = $(DRIVER1_NAME).sys -DRIVER1_CFLAGS = -I. -Iprotobuf-c -Wl,--exclude-all-symbols -DNDEBUG -USERSPACE_NAME = userspace_client -USERSPACE_OBJECTS = $(USERSPACE_NAME).o -USERSPACE_TARGET = $(USERSPACE_NAME).exe +DRIVER2_NAME = driver-protobuf-c-tcp +DRIVER2_OBJECTS = $(DRIVER2_NAME).o ksocket.o berkeley.o protobuf-c/protobuf-c.o protobuf-c/example.pb-c.o +DRIVER2_TARGET = $(DRIVER2_NAME).sys + +USERSPACE0_NAME = userspace_client +USERSPACE0_OBJECTS = $(USERSPACE0_NAME).o +USERSPACE0_TARGET = $(USERSPACE0_NAME).exe + +USERSPACE1_NAME = userspace_client_protobuf +USERSPACE1_OBJECTS = $(USERSPACE1_NAME).o protobuf-c/protobuf-c.o protobuf-c/example.pb-c.o +USERSPACE1_TARGET = $(USERSPACE1_NAME).exe # mingw-w64-dpp related CFLAGS_protobuf-c/protobuf-c.o = -Wno-unused-but-set-variable -CUSTOM_CFLAGS = $(DRIVER0_CFLAGS) +CUSTOM_CFLAGS = -I. -Wl,--exclude-all-symbols -DNDEBUG DRIVER_LIBS += -lnetio USER_LIBS += -lws2_32 -all: $(DRIVER0_TARGET) $(DRIVER1_TARGET) $(USERSPACE_TARGET) +all: $(DRIVER0_TARGET) $(DRIVER1_TARGET) $(DRIVER2_TARGET) $(USERSPACE0_TARGET) $(USERSPACE1_TARGET) %.o: %.cpp $(call BUILD_CPP_OBJECT,$<,$@) @@ -38,21 +44,32 @@ $(DRIVER0_TARGET): $(DRIVER0_OBJECTS) $(DRIVER1_TARGET): $(DRIVER1_OBJECTS) $(call LINK_CPP_KERNEL_TARGET,$(DRIVER1_OBJECTS),$@) -$(USERSPACE_TARGET): $(USERSPACE_OBJECTS) - $(call LINK_CPP_USER_TARGET,$(USERSPACE_OBJECTS),$@) +$(DRIVER2_TARGET): $(DRIVER2_OBJECTS) + $(call LINK_CPP_KERNEL_TARGET,$(DRIVER2_OBJECTS),$@) + +$(USERSPACE0_TARGET): $(USERSPACE0_OBJECTS) + $(call LINK_CPP_USER_TARGET,$(USERSPACE0_OBJECTS),$@) + +$(USERSPACE1_TARGET): $(USERSPACE1_OBJECTS) + $(call LINK_CPP_USER_TARGET,$(USERSPACE1_OBJECTS),$@) -install: $(DRIVER0_TARGET) $(DRIVER1_TARGET) $(USERSPACE_TARGET) +install: $(DRIVER0_TARGET) $(DRIVER1_TARGET) $(DRIVER2_TARGET) $(USERSPACE0_TARGET) $(USERSPACE1_TARGET) $(call INSTALL_EXEC_SIGN,$(DRIVER0_TARGET)) $(call INSTALL_EXEC_SIGN,$(DRIVER1_TARGET)) - $(call INSTALL_EXEC,$(USERSPACE_TARGET)) + $(call INSTALL_EXEC_SIGN,$(DRIVER2_TARGET)) + $(call INSTALL_EXEC,$(USERSPACE0_TARGET)) + $(call INSTALL_EXEC,$(USERSPACE1_TARGET)) $(INSTALL) '$(DRIVER0_NAME).bat' '$(DESTDIR)/' $(INSTALL) '$(DRIVER1_NAME).bat' '$(DESTDIR)/' + $(INSTALL) '$(DRIVER2_NAME).bat' '$(DESTDIR)/' clean: - rm -f $(DRIVER0_OBJECTS) $(DRIVER1_OBJECTS) - rm -f $(DRIVER0_TARGET) $(DRIVER0_TARGET).map $(DRIVER1_TARGET) $(DRIVER1_TARGET).map - rm -f $(USERSPACE_OBJECTS) - rm -f $(USERSPACE_TARGET) + rm -f $(DRIVER0_OBJECTS) $(DRIVER1_OBJECTS) $(DRIVER2_OBJECTS) + rm -f $(DRIVER0_TARGET) $(DRIVER0_TARGET).map \ + $(DRIVER1_TARGET) $(DRIVER1_TARGET).map \ + $(DRIVER2_TARGET) $(DRIVER2_TARGET).map + rm -f $(USERSPACE0_OBJECTS) $(USERSPACE1_OBJECTS) + rm -f $(USERSPACE0_TARGET) $(USERSPACE1_TARGET) .NOTPARALLEL: clean .PHONY: all install clean @@ -26,13 +26,15 @@ make DPP_ROOT="[path-to-mingw-w64-dpp-template-dir]" DESTDIR="[path-to-install-d The directory `[path-to-install-dir]` should now contain three new files: - * `driver.bat` / `driver-protobuf-c.bat`: setup the driver service, start it, stop it when it's done and delete it - * `driver.sys`: example driver that uses kernel sockets - * `userspace_client.exe`: example userspace application which communicates with the driver via TCP socket + * `driver.bat` / `driver-protobuf-c.bat` / `driver-protobuf-c-tcp.bat`: setup the driver service, start it, stop it when it's done and delete it + * `driver.sys`: example driver that uses kernel sockets (used together with `userspace_client.exe`) + * `userspace_client.exe`: example userspace application which communicates with the driver via TCP * `driver-protobuf-c.sys`: example driver that make use of protobuf-c (local, no TCP/IP) + * `driver-protobuf-c-tcp.sys`: example driver that make use of protobuf-c via TCP/IP (used together with `userspace_client_protobuf.exe`) + * `userspace_client_protobuf.exe`: example userspace application which leverages protocol buffers to communicate with the driver via TCP -Start `driver.bat` as `Administrator` and then `userspace_client.exe`. +Start `*.bat` as `Administrator`. -If everything works fine, there should be a text displayed in `userspace_client.exe` console window, received from the driver. +If everything works fine, there should be a text displayed in `userspace_client.exe` / `userspace_client_protobuf.exe` console window, received from the driver. For more debug output, it is recommended to use a debugger or log viewer like `dbgview`. diff --git a/common.hpp b/common.hpp new file mode 100644 index 0000000..0cbbe6b --- /dev/null +++ b/common.hpp @@ -0,0 +1,308 @@ +#include "protobuf-c/example.pb-c.h" + +#include <EASTL/algorithm.h> +#include <EASTL/array.h> +#include <EASTL/initializer_list.h> +#include <EASTL/string.h> +#include <EASTL/vector.h> + +#ifndef SOCKET_ERROR +#define SOCKET_ERROR -1 +#endif + +#define SEND_ALL(sock, socket_buffer, retval) \ + if (socket_buffer.GetRemainingSendSize() > 0) { \ + do { \ + retval = send(sock, socket_buffer.GetStartS(), \ + socket_buffer.GetRemainingSendSize(), 0); \ + if (retval == SOCKET_ERROR || retval == 0) \ + break; \ + if (!socket_buffer.Consume(iResult)) \ + break; \ + } while (!socket_buffer.AllConsumed()); \ + socket_buffer.Sweep(); \ + } +#define RECV_PDU_BEGIN(sock, socket_buffer, retval, pdu_type, pdu_len) \ + do { \ + if (socket_buffer.GetRemainingRecvSize() == 0) \ + break; \ + uint16_t pdu_type; \ + uint32_t pdu_len; \ + do { \ + retval = recv(sock, socket_buffer.GetStartS(), \ + socket_buffer.GetRemainingRecvSize(), 0); \ + if (retval == SOCKET_ERROR || retval == 0) \ + break; \ + } while (!socket_buffer.GetPdu(retval, pdu_type, pdu_len)); +#define RECV_PDU_END(socket_buffer, pdu_len) \ + socket_buffer.Consume(pdu_len); \ + socket_buffer.Sweep(); \ + } \ + while (0) + +class BaseSerializer { +public: + virtual uint16_t GetPduType(void) = 0; + virtual size_t GetSerializedSize(void) = 0; + virtual size_t Serialize(uint8_t *buf) = 0; +}; + +class BaseDeserializer { +public: + virtual bool Deserialize(size_t pdu_len, uint8_t *buf) = 0; + virtual void DeserializeFree(void) = 0; +}; + +template <size_t SIZE> class SocketBuffer { +public: + SocketBuffer(void) {} + ~SocketBuffer(void) {} + size_t GetRemainingSendSize(void) { return GetUsed(); } + size_t GetRemainingRecvSize(void) { return GetSize() - GetUsed(); } + uint8_t *GetStart(void) { return buffer + consumed; } + char *GetStartS(void) { return (char *)buffer + consumed; } + bool SizeCheck(size_t required_size) { return used + required_size < SIZE; }; + template <typename T> bool GetPrimitve(size_t offset, T &out) { + if (offset + sizeof(out) > used) + return false; + out = *(T *)(GetStart() + offset); + return true; + } + template <typename T> bool AddPrimitve(T value) { + if (!SizeCheck(sizeof(value))) + return false; + *(T *)(GetEnd()) = value; + used += sizeof(value); + return true; + } + bool GetPdu(size_t received_size, uint16_t &pdu_type, uint32_t &pdu_len) { + if (received_size > SIZE - used) + return false; // You did something wrong! + used += received_size; + if (used < sizeof(pdu_type) + sizeof(pdu_len)) + return false; + if (GetPrimitve<uint32_t>(0, pdu_len) == false) + return false; + if (GetPrimitve<uint16_t>(4, pdu_type) == false) + return false; + + pdu_len = ntohl(pdu_len); + pdu_type = ntohs(pdu_type); + + if (used < sizeof(pdu_type) + sizeof(pdu_len) + pdu_len) + return false; + + consumed += sizeof(pdu_type) + sizeof(pdu_len); + + return true; + } + bool AddPdu(BaseSerializer &bs) { + uint16_t pdu_type = bs.GetPduType(); + uint32_t pdu_len = bs.GetSerializedSize(); + + if (!SizeCheck(sizeof(pdu_type) + sizeof(pdu_len) + pdu_len)) + return false; + + if (!AddPrimitve<uint32_t>(htonl(pdu_len))) + return false; + if (!AddPrimitve<uint16_t>(htons(pdu_type))) + return false; + if (bs.Serialize(GetEnd()) != pdu_len) + return false; + used += pdu_len; + + return true; + } + bool AllConsumed(void) { return used == consumed; } + bool Consume(size_t consuming_size) { + if (consuming_size + consumed > used) + return false; + consumed += consuming_size; + return true; + } + void Sweep() { + if (consumed == 0) + return; + if (used != consumed) + memmove(buffer, buffer + consumed, used - consumed); + used -= consumed; + consumed = 0; + } + +private: + uint8_t buffer[SIZE]; + size_t used = 0; + size_t consumed = 0; + + size_t GetSize(void) { return SIZE; } + size_t GetUsed(void) { return used - consumed; } + uint8_t *GetEnd() { return buffer + used; } +}; + +class ProtobufCBinaryDataClass : public ProtobufCBinaryData { +public: + ProtobufCBinaryDataClass(std::initializer_list<uint8_t> bytes) { + len = bytes.size(); + data = new uint8_t[len]; + eastl::copy(bytes.begin(), bytes.end(), data); + } + ~ProtobufCBinaryDataClass(void) { delete data; } +}; + +/* Could be also be done via protobuf, but I decided for speed/memory + * efficiency. */ +enum PduTypes { + PDU_SOMETHING_WITH_UINTS = 1, + PDU_SOMETHING_MORE = 2, + PDU_EVEN_MORE = 3 +}; + +class SomethingWithUINTsSerializer : virtual public BaseSerializer { +public: + SomethingWithUINTsSerializer(void) {} + ~SomethingWithUINTsSerializer(void) {} + void SetId(uint32_t id) { + swu.has_id = TRUE; + swu.id = id; + } + void SetIpAddress(uint32_t ip_address) { + swu.has_ip_address = TRUE; + swu.ip_address = ip_address; + } + void SetPortNum(uint32_t port_num) { + swu.has_port_num = TRUE; + swu.port_num = port_num; + } + uint16_t GetPduType(void) override { return PDU_SOMETHING_WITH_UINTS; } + size_t GetSerializedSize(void) override { + return something_with_uints__get_packed_size(&swu); + } + size_t Serialize(uint8_t *buf) override { + return something_with_uints__pack(&swu, buf); + } + + SomethingWithUINTs swu = SOMETHING_WITH_UINTS__INIT; +}; + +class SomethingWithUINTsDeserializer : virtual public BaseDeserializer { +public: + SomethingWithUINTsDeserializer(void) {} + ~SomethingWithUINTsDeserializer(void) { DeserializeFree(); } + bool Deserialize(size_t pdu_len, uint8_t *buf) override { + swu = something_with_uints__unpack(NULL, pdu_len, buf); + return swu != NULL; + } + void DeserializeFree(void) override { + if (swu != NULL) + something_with_uints__free_unpacked(swu, NULL); + } + + SomethingWithUINTs *swu = NULL; +}; + +class SomethingMoreSerializer : virtual public BaseSerializer, + public SomethingWithUINTsSerializer { +public: + SomethingMoreSerializer(void) { sm.uints = &swu; } + ~SomethingMoreSerializer(void) {} + void SetErrorCode(SomethingMore__Errors error_code) { + sm.has_error_code = TRUE; + sm.error_code = error_code; + } + uint16_t GetPduType(void) override { return PDU_SOMETHING_MORE; } + size_t GetSerializedSize(void) override { + return something_more__get_packed_size(&sm); + } + size_t Serialize(uint8_t *buf) override { + return something_more__pack(&sm, buf); + } + + SomethingMore sm = SOMETHING_MORE__INIT; +}; + +class SomethingMoreDeserializer : virtual public BaseDeserializer { +public: + SomethingMoreDeserializer(void) {} + ~SomethingMoreDeserializer(void) { DeserializeFree(); } + bool Deserialize(size_t pdu_len, uint8_t *buf) override { + sm = something_more__unpack(NULL, pdu_len, buf); + return sm != NULL; + } + void DeserializeFree(void) override { + if (sm != NULL) + something_more__free_unpacked(sm, NULL); + } + + SomethingMore *sm = NULL; +}; + +class EvenMoreSerializer : virtual public BaseSerializer { +public: + explicit EvenMoreSerializer(void) = default; + explicit EvenMoreSerializer(EvenMore__SomeEnum enum_value, + ProtobufCBinaryDataClass name, + ProtobufCBinaryDataClass value) + : uints() { + em.enum_value = enum_value; + em.name = name; + em.value = value; + } + ~EvenMoreSerializer(void) { free(s); } + void AddUints(SomethingWithUINTsSerializer *uints) { + this->uints.push_back(uints); + } + void SetS(eastl::string s) { + size_t l = s.size(); + this->s = (char *)malloc(l + 1); + memcpy(this->s, s.c_str(), l); + this->s[l] = '\0'; + } + uint16_t GetPduType(void) override { return PDU_EVEN_MORE; } + size_t GetSerializedSize(void) override { + em.s = s; + em.n_uints = this->uints.size(); + if (em.n_uints > 0) { + SomethingWithUINTs *out[em.n_uints]; + ConvertUintsVectorToCArray(out); + return even_more__get_packed_size(&em); + } + return even_more__get_packed_size(&em); + } + size_t Serialize(uint8_t *buf) override { + em.s = s; + em.n_uints = this->uints.size(); + if (em.n_uints > 0) { + SomethingWithUINTs *out[em.n_uints]; + ConvertUintsVectorToCArray(out); + return even_more__pack(&em, buf); + } + return even_more__pack(&em, buf); + } + + EvenMore em = EVEN_MORE__INIT; + eastl::vector<SomethingWithUINTsSerializer *> uints; + char *s = NULL; + + void ConvertUintsVectorToCArray(SomethingWithUINTs **out) { + for (size_t i = 0; i < em.n_uints; ++i) { + out[i] = &uints[i]->swu; + } + em.uints = out; + } +}; + +class EvenMoreDeserializer : virtual public BaseDeserializer { +public: + EvenMoreDeserializer(void) {} + ~EvenMoreDeserializer(void) { DeserializeFree(); } + bool Deserialize(size_t pdu_len, uint8_t *buf) override { + em = even_more__unpack(NULL, pdu_len, buf); + return em != NULL; + } + void DeserializeFree(void) override { + if (em != NULL) + even_more__free_unpacked(em, NULL); + } + + EvenMore *em = NULL; +}; diff --git a/driver-protobuf-c-tcp.bat b/driver-protobuf-c-tcp.bat new file mode 100644 index 0000000..9aa935b --- /dev/null +++ b/driver-protobuf-c-tcp.bat @@ -0,0 +1,27 @@ +@echo off +set SERVICE_NAME=protobuf-c-tcp +set DRIVER="%~dp0\driver-protobuf-c-tcp.sys" + +net session >nul 2>&1 +if NOT %ERRORLEVEL% EQU 0 ( + echo ERROR: This script requires Administrator privileges! + pause + exit /b 1 +) + +echo --------------------------------------- +echo -- Service Name: %SERVICE_NAME% +echo -- Driver......: %DRIVER% +echo --------------------------------------- + +sc create %SERVICE_NAME% binPath= %DRIVER% type= kernel +echo --------------------------------------- +sc start %SERVICE_NAME% +echo --------------------------------------- +sc query %SERVICE_NAME% +echo [PRESS A KEY TO STOP THE DRIVER] +pause +sc stop %SERVICE_NAME% +sc delete %SERVICE_NAME% +echo Done. +timeout /t 3 diff --git a/driver-protobuf-c-tcp.cpp b/driver-protobuf-c-tcp.cpp new file mode 100644 index 0000000..0808b88 --- /dev/null +++ b/driver-protobuf-c-tcp.cpp @@ -0,0 +1,139 @@ +#include "berkeley.h" +#include "ksocket.h" +#include "protobuf-c/example.pb-c.h" +#include "wsk.h" + +#include "common.hpp" + +extern "C" { +DRIVER_INITIALIZE DriverEntry; +DRIVER_UNLOAD DriverUnload; + +#define DebuggerPrint(...) \ + DbgPrintEx(DPFLTR_IHVDRIVER_ID, DPFLTR_ERROR_LEVEL, __VA_ARGS__); + +NTSTATUS +NTAPI +DriverEntry(_In_ PDRIVER_OBJECT DriverObject, + _In_ PUNICODE_STRING RegistryPath) { + UNREFERENCED_PARAMETER(DriverObject); + UNREFERENCED_PARAMETER(RegistryPath); + + NTSTATUS Status; + + DebuggerPrint("Hi.\n"); + Status = KsInitialize(); + + if (!NT_SUCCESS(Status)) { + return Status; + } + + int server_sockfd = socket_listen(AF_INET, SOCK_STREAM, 0); + + struct sockaddr_in addr; + addr.sin_family = AF_INET; + addr.sin_addr.s_addr = INADDR_ANY; + addr.sin_port = htons(9095); + + int result = bind(server_sockfd, (struct sockaddr *)&addr, sizeof(addr)); + if (result != 0) { + DebuggerPrint("TCP server bind failed\n"); + return STATUS_FAILED_DRIVER_ENTRY; + } + + result = listen(server_sockfd, 1); + if (result != 0) { + DebuggerPrint("TCP server listen failed\n"); + return STATUS_FAILED_DRIVER_ENTRY; + } + + socklen_t addrlen = sizeof(addr); + int client_sockfd = accept(server_sockfd, (struct sockaddr *)&addr, &addrlen); + if (client_sockfd < 0) { + DebuggerPrint("TCP accept failed\n"); + return STATUS_FAILED_DRIVER_ENTRY; + } + + int iResult; + SocketBuffer<1024> sb_send, sb_recv; + + do { + bool ok = false; + RECV_PDU_BEGIN(client_sockfd, sb_recv, iResult, pdu_type, pdu_len) { + DebuggerPrint("PDU type/len: %u/%u\n", pdu_type, pdu_len); + switch ((enum PduTypes)pdu_type) { + case PDU_SOMETHING_WITH_UINTS: { + SomethingWithUINTsDeserializer swud; + if ((ok = swud.Deserialize(pdu_len, sb_recv.GetStart())) == true) { + SomethingWithUINTsSerializer swus; + if (swud.swu->has_id == TRUE) { + DebuggerPrint("Id: 0x%X\n", swud.swu->id); + swus.SetId(swud.swu->id + 1); + } + ok = sb_send.AddPdu(swus); + } + break; + } + case PDU_SOMETHING_MORE: { + SomethingMoreDeserializer smd; + if ((ok = smd.Deserialize(pdu_len, sb_recv.GetStart())) == true) { + SomethingMoreSerializer sms; + if (smd.sm->has_error_code == TRUE) { + DebuggerPrint("Error Code: %u\n", smd.sm->error_code); + } + if (smd.sm->uints->has_id == TRUE) { + DebuggerPrint("Id: 0x%X\n", smd.sm->uints->id); + sms.SetId(smd.sm->uints->id + 1); + } + sms.SetErrorCode(SOMETHING_MORE__ERRORS__SUCCESS); + sms.SetIpAddress(0xCCCCCCCC); + sms.SetPortNum(0xDDDDDDDD); + ok = sb_send.AddPdu(sms); + } + break; + } + case PDU_EVEN_MORE: { + EvenMoreDeserializer emd; + if ((ok = emd.Deserialize(pdu_len, sb_recv.GetStart())) == true) { + DebuggerPrint("EnumValue: %d\n", emd.em->enum_value); + if (emd.em->s != NULL) { + DebuggerPrint("String: '%s'\n", emd.em->s); + } + EvenMoreSerializer ems; + SomethingWithUINTsSerializer swus; + swus.SetId(0xDEADDEAD); + ems.SetS("Hi userspace!"); + ems.AddUints(&swus); + ok = sb_send.AddPdu(ems); + } + break; + } + } + } + RECV_PDU_END(sb_recv, pdu_len); + + if (ok == true) { + SEND_ALL(client_sockfd, sb_send, iResult); + if (iResult == SOCKET_ERROR || iResult == 0) { + DebuggerPrint("send failed\n"); + break; + } + } else { + DebuggerPrint("Serialization/Deserialization failed\n"); + break; + } + } while (iResult != SOCKET_ERROR && iResult > 0); + + DebuggerPrint("Client gone.\n") closesocket(client_sockfd); + closesocket(server_sockfd); + KsDestroy(); + + return STATUS_SUCCESS; +} + +VOID DriverUnload(_In_ struct _DRIVER_OBJECT *DriverObject) { + UNREFERENCED_PARAMETER(DriverObject); + + DebuggerPrint("Bye.\n"); +} +} diff --git a/driver-protobuf-c.cpp b/driver-protobuf-c.cpp index 0ba11c7..762137c 100644 --- a/driver-protobuf-c.cpp +++ b/driver-protobuf-c.cpp @@ -1,10 +1,11 @@ - -extern "C" { #include "berkeley.h" #include "ksocket.h" #include "protobuf-c/example.pb-c.h" #include "wsk.h" +#include "common.hpp" + +extern "C" { DRIVER_INITIALIZE DriverEntry; DRIVER_UNLOAD DriverUnload; @@ -19,91 +20,78 @@ DriverEntry(_In_ PDRIVER_OBJECT DriverObject, UNREFERENCED_PARAMETER(RegistryPath); size_t len = 0; - uint8_t *buf = NULL; { - SomethingMore sm = SOMETHING_MORE__INIT; - SomethingWithUINTs swu = SOMETHING_WITH_UINTS__INIT; - - sm.error_code = SOMETHING_MORE__ERRORS__SUCCESS; - sm.uints = &swu; - swu.has_id = TRUE; - swu.id = 0x12345678; - swu.has_ip_address = TRUE; - swu.ip_address = 0xAAAAAAAA; - swu.has_port_num = TRUE; - swu.port_num = 0xBBBBBBBB; - - len = something_more__get_packed_size(&sm); - buf = (uint8_t *)malloc(len); - if (something_more__pack(&sm, buf) != len) { - DebuggerPrint("Packing failed.\n"); - } - } + uint8_t *buf = NULL; + + { + SomethingMoreSerializer sms; + sms.SetErrorCode(SOMETHING_MORE__ERRORS__SUCCESS); + sms.SetId(0x12345678); + sms.SetIpAddress(0xAAAAAAAA); + sms.SetPortNum(0xBBBBBBBB); + len = sms.GetSerializedSize(); + buf = (uint8_t *)malloc(len); + if (buf == NULL) { + return STATUS_UNSUCCESSFUL; + } + if (sms.Serialize(buf) != len) { + DebuggerPrint("Packing failed.\n"); + free(buf); + return STATUS_UNSUCCESSFUL; + } - DebuggerPrint("Packed Size: %zu\n", len); + DebuggerPrint("Packed Size: %zu\n", len); - { - SomethingMore *smth = something_more__unpack(NULL, len, buf); - if (smth != NULL && smth->uints != NULL && smth->uints->id == 0x12345678 && - smth->uints->ip_address == 0xAAAAAAAA && - smth->uints->port_num == 0xBBBBBBBB) { - DebuggerPrint("Success!\n"); + SomethingMoreDeserializer smd; + if (smd.Deserialize(len, buf) == true && smd.sm->uints != NULL && + smd.sm->uints->id == 0x12345678 && + smd.sm->uints->ip_address == 0xAAAAAAAA && + smd.sm->uints->port_num == 0xBBBBBBBB) { + DebuggerPrint("Success!\n"); + } } - DebuggerPrint("id: %x, ip_address: %x, port_num: %x\n", smth->uints->id, - smth->uints->ip_address, smth->uints->port_num); - something_more__free_unpacked(smth, NULL); - } - free(buf); + free(buf); + } { - EvenMore em = EVEN_MORE__INIT; - SomethingWithUINTs swu[] = {SOMETHING_WITH_UINTS__INIT, - SOMETHING_WITH_UINTS__INIT, - SOMETHING_WITH_UINTS__INIT}; - SomethingWithUINTs *swu_vec[] = {&swu[0], &swu[1], &swu[2]}; - uint8_t bin[] = {0xde, 0xad, 0xc0, 0xde}; - ProtobufCBinaryData pbin = {.len = sizeof(bin), .data = bin}; - char str[] = "This is a zero-terminated String!"; - - em.enum_value = EVEN_MORE__SOME_ENUM__FIRST; - swu[0].has_id = TRUE; - swu[0].id = 0xdeadc0de; - swu[1].has_ip_address = TRUE; - swu[1].ip_address = 0xdeadbeef; - swu[2].has_port_num = TRUE; - swu[2].port_num = 0xcafecafe; - em.n_uints = sizeof(swu) / sizeof(swu[0]); - em.uints = swu_vec; - em.name = pbin; - em.value = pbin; - em.s = str; - - len = even_more__get_packed_size(&em); + EvenMoreSerializer ems(EVEN_MORE__SOME_ENUM__FIRST, + {0xde, 0xad, 0xc0, 0xde}, {0xde, 0xad, 0xc0, 0xde}); + SomethingWithUINTsSerializer sws[3]; + + sws[0].SetId(0xdeadc0de); + sws[1].SetIpAddress(0xdeadbeef); + sws[2].SetPortNum(0xcafecafe); + + ems.SetS("This is a zero-terminated String!"); + ems.AddUints(&sws[0]); + ems.AddUints(&sws[1]); + ems.AddUints(&sws[2]); + + len = ems.GetSerializedSize(); uint8_t tmp[len]; - if (even_more__pack(&em, tmp) != len) { + if (ems.Serialize(tmp) != len) { DebuggerPrint("Packing failed.\n"); } DebuggerPrint("Packed Size: %zu\n", len); - EvenMore *result = even_more__unpack(NULL, len, tmp); - if (result != NULL && result->n_uints == 3 && result->uints != NULL && - result->name.len > 0 && result->name.data != NULL && - result->value.len > 0 && result->value.data != NULL && - result->s != NULL) { - if (result->enum_value != EVEN_MORE__SOME_ENUM__FIRST || - result->uints[0]->has_id != TRUE || - result->uints[0]->id != 0xdeadc0de || - result->uints[1]->has_ip_address != TRUE || - result->uints[1]->ip_address != 0xdeadbeef || - result->uints[1]->has_port_num != TRUE || - result->uints[2]->port_num != 0xcafecafe) { + EvenMoreDeserializer emd; + if (emd.Deserialize(len, tmp) == true && emd.em->n_uints == 3 && + emd.em->uints != NULL && emd.em->name.len > 0 && + emd.em->name.data != NULL && emd.em->value.len > 0 && + emd.em->value.data != NULL && emd.em->s != NULL) { + if (emd.em->enum_value == EVEN_MORE__SOME_ENUM__FIRST || + emd.em->uints[0]->has_id == TRUE || + emd.em->uints[0]->id == 0xdeadc0de || + emd.em->uints[1]->has_ip_address == TRUE || + emd.em->uints[1]->ip_address == 0xdeadbeef || + emd.em->uints[2]->has_port_num == TRUE || + emd.em->uints[2]->port_num == 0xcafecafe) { DebuggerPrint("Success!\n"); } - DebuggerPrint("Deserialized String: '%s'\n", result->s); + DebuggerPrint("Deserialized String: '%s'\n", emd.em->s); } - even_more__free_unpacked(result, NULL); } return STATUS_SUCCESS; @@ -117,6 +117,10 @@ DriverEntry(_In_ PDRIVER_OBJECT DriverObject, socklen_t addrlen = sizeof(addr); int client_sockfd = accept(server_sockfd, (struct sockaddr *)&addr, &addrlen); + if (client_sockfd < 0) { + DebuggerPrint("TCP accept failed\n"); + return STATUS_FAILED_DRIVER_ENTRY; + } result = recv(client_sockfd, recv_buffer, sizeof(recv_buffer) - 1, 0); if (result > 0) { diff --git a/echo_srv.py b/echo_srv.py new file mode 100644 index 0000000..b3855eb --- /dev/null +++ b/echo_srv.py @@ -0,0 +1,16 @@ +import socket + +HOST = '127.0.0.1' +PORT = 9095 + +with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind((HOST, PORT)) + s.listen() + conn, addr = s.accept() + with conn: + while True: + data = conn.recv(1024) + print('recvd: {} bytes'.format(len(data))) + if len(data) == 0: + break; + conn.sendall(data) diff --git a/userspace_client_protobuf.cpp b/userspace_client_protobuf.cpp new file mode 100644 index 0000000..7ee674e --- /dev/null +++ b/userspace_client_protobuf.cpp @@ -0,0 +1,152 @@ +#include <stdio.h> +#include <stdlib.h> // Needed for _wtoi +#include <winsock2.h> +#include <ws2tcpip.h> + +#include "common.hpp" +#include "protobuf-c/example.pb-c.h" + +int main(int argc, char **argv) { + WSADATA wsaData = {}; + int iResult = WSAStartup(MAKEWORD(2, 2), &wsaData); + + UNREFERENCED_PARAMETER(argc); + UNREFERENCED_PARAMETER(argv); + + if (iResult != 0) { + wprintf(L"WSAStartup failed: %d\n", iResult); + return 1; + } + + SOCKET sock = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); + + if (sock == INVALID_SOCKET) { + wprintf(L"socket function failed with error = %d\n", WSAGetLastError()); + } else { + wprintf(L"socket function succeeded\n"); + } + + sockaddr_in clientService; + clientService.sin_family = AF_INET; + clientService.sin_addr.s_addr = inet_addr("127.0.0.1"); + clientService.sin_port = htons(9095); + + do { + iResult = connect(sock, (SOCKADDR *)&clientService, sizeof(clientService)); + if (iResult == SOCKET_ERROR) { + wprintf(L"connect function failed with error: %ld\n", WSAGetLastError()); + Sleep(1000); + } + } while (iResult == SOCKET_ERROR); + + wprintf(L"Connected to server.\n"); + + uint32_t start_id = 0x12345678; + for (size_t i = 0; i < 256; ++i) { + SocketBuffer<1024> sb_send, sb_recv; + SomethingMoreSerializer sms; + + sms.SetErrorCode(SOMETHING_MORE__ERRORS__SUCCESS); + sms.SetId(start_id++); + sms.SetIpAddress(0xAAAAAAAA); + sms.SetPortNum(0xBBBBBBBB); + + if (!sb_send.AddPdu(sms)) { + wprintf(L"Serialization failed\n"); + break; + } + + SEND_ALL(sock, sb_send, iResult); + if (iResult == SOCKET_ERROR || iResult == 0) { + wprintf(L"send failed with error: %d\n", WSAGetLastError()); + break; + } + + RECV_PDU_BEGIN(sock, sb_recv, iResult, pdu_type, pdu_len) { + wprintf(L"PDU type/len: %u/%u\n", pdu_type, pdu_len); + switch ((enum PduTypes)pdu_type) { + case PDU_SOMETHING_WITH_UINTS: { + break; + } + case PDU_SOMETHING_MORE: { + SomethingMoreDeserializer smd; + if (smd.Deserialize(pdu_len, sb_recv.GetStart()) == true && + smd.sm->uints != NULL && smd.sm->uints->has_id == TRUE && + smd.sm->uints->has_ip_address == TRUE && + smd.sm->uints->has_port_num == TRUE) + wprintf(L"Id: 0x%X, IpAddress: 0x%X, PortNum: 0x%X\n", + smd.sm->uints->id, smd.sm->uints->ip_address, + smd.sm->uints->port_num); + break; + } + case PDU_EVEN_MORE: { + break; + } + } + } + RECV_PDU_END(sb_recv, pdu_len); + + //////////////////////////////////////////////////////// + + EvenMoreSerializer ems(EVEN_MORE__SOME_ENUM__FIRST, + {0xde, 0xad, 0xc0, 0xde}, {0xde, 0xad, 0xc0, 0xde}); + SomethingWithUINTsSerializer swus[3]; + + swus[0].SetId(0xdeadc0de); + swus[1].SetIpAddress(0xdeadbeef); + swus[2].SetPortNum(0xcafecafe); + + ems.SetS("This is a zero-terminated String!"); + ems.AddUints(&swus[0]); + ems.AddUints(&swus[1]); + ems.AddUints(&swus[2]); + + if (!sb_send.AddPdu(ems)) { + wprintf(L"Serialization failed\n"); + break; + } + + SEND_ALL(sock, sb_send, iResult); + if (iResult == SOCKET_ERROR || iResult == 0) { + wprintf(L"send failed with error: %d\n", WSAGetLastError()); + break; + } + + RECV_PDU_BEGIN(sock, sb_recv, iResult, pdu_type, pdu_len) { + wprintf(L"PDU type/len: %u/%u\n", pdu_type, pdu_len); + switch ((enum PduTypes)pdu_type) { + case PDU_SOMETHING_WITH_UINTS: { + break; + } + case PDU_SOMETHING_MORE: { + break; + } + case PDU_EVEN_MORE: { + EvenMoreDeserializer emd; + if (emd.Deserialize(pdu_len, sb_recv.GetStart()) == true) { + wprintf(L"EnumValue: %d\n", emd.em->enum_value); + if (emd.em->s != NULL) + wprintf(L"String: '%s'\n", emd.em->s); + } + break; + } + } + } + RECV_PDU_END(sb_recv, pdu_len); + } + wprintf(L"Closing Connection ..\n"); + + iResult = closesocket(sock); + if (iResult == SOCKET_ERROR) { + wprintf(L"closesocket function failed with error: %ld\n", + WSAGetLastError()); + WSACleanup(); + return 1; + } + + WSACleanup(); + + system("pause"); + + return 0; +} |