diff options
author | Toni Uhlig <matzeton@googlemail.com> | 2023-09-15 11:21:31 +0200 |
---|---|---|
committer | Toni Uhlig <matzeton@googlemail.com> | 2023-09-15 11:21:31 +0200 |
commit | 0cbfbe129934976359460fdbe69fb97632d81d24 (patch) | |
tree | 3d4685fc1f02244c80d4f5de2fa6c80fae94a425 | |
parent | 37d1e657e5e79bc240ea036cfb8da377b1640490 (diff) |
Added C++ (`ksocket/ksocket.hpp`) Socket wrapper classes.
* another flatbuffers example (WiP!)
* Makefile improvements
Signed-off-by: Toni Uhlig <matzeton@googlemail.com>
-rw-r--r-- | Makefile | 54 | ||||
-rw-r--r-- | examples/apiwrapper.fbs | 6 | ||||
-rw-r--r-- | examples/apiwrapper_builder.h | 55 | ||||
-rw-r--r-- | examples/apiwrapper_reader.h | 54 | ||||
-rw-r--r-- | examples/apiwrapper_verifier.h | 43 | ||||
-rw-r--r-- | examples/driver-flatbuffers-tcp.bat | 27 | ||||
-rw-r--r-- | examples/driver-flatbuffers-tcp.cpp | 108 | ||||
-rw-r--r-- | examples/driver-protobuf-c-tcp.cpp | 1 | ||||
-rw-r--r-- | examples/driver-protobuf-c.cpp | 1 | ||||
-rw-r--r-- | examples/driver.cpp | 129 | ||||
-rw-r--r-- | examples/userspace_client_flatbuffers.cpp | 82 | ||||
-rw-r--r-- | examples/userspace_client_protobuf.cpp | 1 | ||||
-rw-r--r-- | ksocket/berkeley.c | 8 | ||||
-rw-r--r-- | ksocket/berkeley.h | 16 | ||||
-rw-r--r-- | ksocket/helper.hpp | 2 | ||||
-rw-r--r-- | ksocket/ksocket.cpp | 350 | ||||
-rw-r--r-- | ksocket/ksocket.h | 13 | ||||
-rw-r--r-- | ksocket/ksocket.hpp | 233 | ||||
-rw-r--r-- | ksocket/utils.c | 17 | ||||
-rw-r--r-- | ksocket/utils.h | 30 | ||||
-rw-r--r-- | ksocket/wsk.h | 4 |
21 files changed, 1196 insertions, 38 deletions
@@ -12,43 +12,61 @@ FLATBUFFERS_CFLAGS = -Iflatbuffers-build/include -Wno-misleading-indentation FLATBUFFERS_FLATC = flatbuffers-flatcc-build/bin/flatcc DRIVER0_NAME = driver -DRIVER0_OBJECTS = examples/$(DRIVER0_NAME).o ksocket/ksocket.o ksocket/berkeley.o +DRIVER0_OBJECTS = examples/$(DRIVER0_NAME).opp ksocket/ksocket.o ksocket/berkeley.o ksocket/ksocket.opp ksocket/utils.o DRIVER0_TARGET = $(DRIVER0_NAME).sys DRIVER1_NAME = driver-protobuf-c -DRIVER1_OBJECTS = examples/$(DRIVER1_NAME).o examples/example.pb-c.o $(PROTOBUF_OBJECT) +DRIVER1_OBJECTS = examples/$(DRIVER1_NAME).opp ksocket/utils.o examples/example.pb-c.o $(PROTOBUF_OBJECT) DRIVER1_TARGET = $(DRIVER1_NAME).sys DRIVER2_NAME = driver-protobuf-c-tcp -DRIVER2_OBJECTS = examples/$(DRIVER2_NAME).o ksocket/ksocket.o ksocket/berkeley.o examples/example.pb-c.o $(PROTOBUF_OBJECT) +DRIVER2_OBJECTS = examples/$(DRIVER2_NAME).opp ksocket/ksocket.o ksocket/berkeley.o ksocket/utils.o examples/example.pb-c.o $(PROTOBUF_OBJECT) DRIVER2_TARGET = $(DRIVER2_NAME).sys DRIVER3_NAME = driver-flatbuffers -DRIVER3_OBJECTS = examples/$(DRIVER3_NAME).o ksocket/ksocket.o ksocket/berkeley.o $(FLATBUFFERS_LIB) +DRIVER3_OBJECTS = examples/$(DRIVER3_NAME).opp ksocket/ksocket.o ksocket/berkeley.o $(FLATBUFFERS_LIB) DRIVER3_TARGET = $(DRIVER3_NAME).sys +DRIVER4_NAME = driver-flatbuffers-tcp +DRIVER4_OBJECTS = examples/$(DRIVER4_NAME).opp ksocket/ksocket.opp ksocket/ksocket.o ksocket/berkeley.o ksocket/utils.o $(FLATBUFFERS_LIB) +DRIVER4_TARGET = $(DRIVER4_NAME).sys + USERSPACE0_NAME = userspace_client -USERSPACE0_OBJECTS = examples/$(USERSPACE0_NAME).o +USERSPACE0_OBJECTS = examples/$(USERSPACE0_NAME).opp USERSPACE0_TARGET = $(USERSPACE0_NAME).exe USERSPACE1_NAME = userspace_client_protobuf -USERSPACE1_OBJECTS = examples/$(USERSPACE1_NAME).o examples/example.pb-c.o $(PROTOBUF_OBJECT) +USERSPACE1_OBJECTS = examples/$(USERSPACE1_NAME).opp examples/example.pb-c.o $(PROTOBUF_OBJECT) USERSPACE1_TARGET = $(USERSPACE1_NAME).exe +USERSPACE2_NAME = userspace_client_flatbuffers +USERSPACE2_OBJECTS = examples/$(USERSPACE2_NAME).opp ksocket/ksocket_user.opp ksocket/utils_user.o $(FLATBUFFERS_LIB) +USERSPACE2_TARGET = $(USERSPACE2_NAME).exe + # mingw-w64-dpp related +CFLAGS_examples/$(USERSPACE1_NAME).opp = -DBUILD_USERMODE=1 +CFLAGS_examples/$(USERSPACE2_NAME).opp = -DBUILD_USERMODE=1 +CFLAGS_ksocket/utils_user.o = -DBUILD_USERMODE=1 +CFLAGS_ksocket/ksocket_user.opp = -DBUILD_USERMODE=1 CFLAGS_$(PROTOBUF_OBJECT) = $(PROTOBUF_CFLAGS_PRIVATE) CUSTOM_CFLAGS = -I. -Iexamples -Werror $(FLATBUFFERS_CFLAGS) -Wl,--exclude-all-symbols -DNDEBUG DRIVER_LIBS += -lnetio USER_LIBS += -lws2_32 -all: deps $(DRIVER0_TARGET) $(DRIVER1_TARGET) $(DRIVER2_TARGET) $(DRIVER3_TARGET) $(USERSPACE0_TARGET) $(USERSPACE1_TARGET) +all: deps $(DRIVER0_TARGET) $(DRIVER1_TARGET) $(DRIVER2_TARGET) $(DRIVER3_TARGET) $(DRIVER4_TARGET) $(USERSPACE0_TARGET) $(USERSPACE1_TARGET) $(USERSPACE2_TARGET) -%.o: %.cpp +%.opp: %.cpp $(call BUILD_CPP_OBJECT,$<,$@) %.o: %.c $(call BUILD_C_OBJECT,$<,$@) +ksocket/utils_user.o: ksocket/utils.c + $(call BUILD_C_OBJECT,$<,$@) + +ksocket/ksocket_user.opp: ksocket/ksocket.cpp + $(call BUILD_CPP_OBJECT,$<,$@) + $(PROTOBUF_OBJECT): protobuf-c/protobuf-c.c mkdir -p '$(dir $(PROTOBUF_OBJECT))' $(call BUILD_C_OBJECT,$<,$@) @@ -65,12 +83,18 @@ $(DRIVER2_TARGET): $(DRIVER2_OBJECTS) $(DRIVER3_TARGET): $(FLATBUFFERS_LIB) $(DRIVER3_OBJECTS) $(call LINK_CPP_KERNEL_TARGET,$(DRIVER3_OBJECTS),$@) +$(DRIVER4_TARGET): $(FLATBUFFERS_LIB) $(DRIVER4_OBJECTS) + $(call LINK_CPP_KERNEL_TARGET,$(DRIVER4_OBJECTS),$@) + $(USERSPACE0_TARGET): $(USERSPACE0_OBJECTS) $(call LINK_CPP_USER_TARGET,$(USERSPACE0_OBJECTS),$@) $(USERSPACE1_TARGET): $(USERSPACE1_OBJECTS) $(call LINK_CPP_USER_TARGET,$(USERSPACE1_OBJECTS),$@) +$(USERSPACE2_TARGET): $(USERSPACE2_OBJECTS) + $(call LINK_CPP_USER_TARGET,$(USERSPACE2_OBJECTS),$@) + deps: $(FLATBUFFERS_LIB) $(FLATBUFFERS_FLATC) $(FLATBUFFERS_LIB): @@ -98,6 +122,7 @@ $(FLATBUFFERS_FLATC): generate: $(FLATBUFFERS_FLATC) @echo 'Generating flatbuffer files..' $(FLATBUFFERS_FLATC) -a -o examples examples/monster.fbs + $(FLATBUFFERS_FLATC) -a -o examples examples/apiwrapper.fbs @echo '==========================================' @echo '= You need protobuf-c to make this work! =' @echo '==========================================' @@ -105,26 +130,31 @@ generate: $(FLATBUFFERS_FLATC) @echo protoc-c --c_out=. examples/example.proto -install: $(DRIVER0_TARGET) $(DRIVER1_TARGET) $(DRIVER2_TARGET) $(DRIVER3_TARGET) $(USERSPACE0_TARGET) $(USERSPACE1_TARGET) +install: $(DRIVER0_TARGET) $(DRIVER1_TARGET) $(DRIVER2_TARGET) $(DRIVER3_TARGET) $(DRIVER4_TARGET) $(USERSPACE0_TARGET) $(USERSPACE1_TARGET) $(USERSPACE2_TARGET) $(call INSTALL_EXEC_SIGN,$(DRIVER0_TARGET)) $(call INSTALL_EXEC_SIGN,$(DRIVER1_TARGET)) $(call INSTALL_EXEC_SIGN,$(DRIVER2_TARGET)) $(call INSTALL_EXEC_SIGN,$(DRIVER3_TARGET)) + $(call INSTALL_EXEC_SIGN,$(DRIVER4_TARGET)) $(call INSTALL_EXEC,$(USERSPACE0_TARGET)) $(call INSTALL_EXEC,$(USERSPACE1_TARGET)) + $(call INSTALL_EXEC,$(USERSPACE2_TARGET)) $(INSTALL) 'examples/$(DRIVER0_NAME).bat' '$(DESTDIR)/' $(INSTALL) 'examples/$(DRIVER1_NAME).bat' '$(DESTDIR)/' $(INSTALL) 'examples/$(DRIVER2_NAME).bat' '$(DESTDIR)/' $(INSTALL) 'examples/$(DRIVER3_NAME).bat' '$(DESTDIR)/' + $(INSTALL) 'examples/$(DRIVER4_NAME).bat' '$(DESTDIR)/' clean: + rm -f examples/*.o examples/*.opp rm -f $(DRIVER0_OBJECTS) $(DRIVER1_OBJECTS) $(DRIVER2_OBJECTS) $(DRIVER3_OBJECTS) rm -f $(DRIVER0_TARGET) $(DRIVER0_TARGET).map \ $(DRIVER1_TARGET) $(DRIVER1_TARGET).map \ $(DRIVER2_TARGET) $(DRIVER2_TARGET).map \ - $(DRIVER3_TARGET) $(DRIVER3_TARGET).map - rm -f $(USERSPACE0_OBJECTS) $(USERSPACE1_OBJECTS) - rm -f $(USERSPACE0_TARGET) $(USERSPACE1_TARGET) + $(DRIVER3_TARGET) $(DRIVER3_TARGET).map \ + $(DRIVER4_TARGET) $(DRIVER4_TARGET).map + rm -f $(USERSPACE0_OBJECTS) $(USERSPACE1_OBJECTS) $(USERSPACE2_OBJECTS) + rm -f $(USERSPACE0_TARGET) $(USERSPACE1_TARGET) $(USERSPACE2_TARGET) distclean: clean rm -rf flatbuffers-build flatbuffers-flatcc-build diff --git a/examples/apiwrapper.fbs b/examples/apiwrapper.fbs new file mode 100644 index 0000000..6de1330 --- /dev/null +++ b/examples/apiwrapper.fbs @@ -0,0 +1,6 @@ +table FunctionAddresses { + names:[string]; + addrs:[uint64]; +} + +root_type FunctionAddresses; diff --git a/examples/apiwrapper_builder.h b/examples/apiwrapper_builder.h new file mode 100644 index 0000000..f0e0c1b --- /dev/null +++ b/examples/apiwrapper_builder.h @@ -0,0 +1,55 @@ +#ifndef APIWRAPPER_BUILDER_H +#define APIWRAPPER_BUILDER_H + +/* Generated by flatcc 0.6.2 FlatBuffers schema compiler for C by dvide.com */ + +#ifndef APIWRAPPER_READER_H +#include "apiwrapper_reader.h" +#endif +#ifndef FLATBUFFERS_COMMON_BUILDER_H +#include "flatbuffers_common_builder.h" +#endif +#include "flatcc/flatcc_prologue.h" +#ifndef flatbuffers_identifier +#define flatbuffers_identifier 0 +#endif +#ifndef flatbuffers_extension +#define flatbuffers_extension "bin" +#endif + +static const flatbuffers_voffset_t __FunctionAddresses_required[] = { 0 }; +typedef flatbuffers_ref_t FunctionAddresses_ref_t; +static FunctionAddresses_ref_t FunctionAddresses_clone(flatbuffers_builder_t *B, FunctionAddresses_table_t t); +__flatbuffers_build_table(flatbuffers_, FunctionAddresses, 2) + +#define __FunctionAddresses_formal_args , flatbuffers_string_vec_ref_t v0, flatbuffers_uint64_vec_ref_t v1 +#define __FunctionAddresses_call_args , v0, v1 +static inline FunctionAddresses_ref_t FunctionAddresses_create(flatbuffers_builder_t *B __FunctionAddresses_formal_args); +__flatbuffers_build_table_prolog(flatbuffers_, FunctionAddresses, FunctionAddresses_file_identifier, FunctionAddresses_type_identifier) + +__flatbuffers_build_string_vector_field(0, flatbuffers_, FunctionAddresses_names, FunctionAddresses) +__flatbuffers_build_vector_field(1, flatbuffers_, FunctionAddresses_addrs, flatbuffers_uint64, uint64_t, FunctionAddresses) + +static inline FunctionAddresses_ref_t FunctionAddresses_create(flatbuffers_builder_t *B __FunctionAddresses_formal_args) +{ + if (FunctionAddresses_start(B) + || FunctionAddresses_names_add(B, v0) + || FunctionAddresses_addrs_add(B, v1)) { + return 0; + } + return FunctionAddresses_end(B); +} + +static FunctionAddresses_ref_t FunctionAddresses_clone(flatbuffers_builder_t *B, FunctionAddresses_table_t t) +{ + __flatbuffers_memoize_begin(B, t); + if (FunctionAddresses_start(B) + || FunctionAddresses_names_pick(B, t) + || FunctionAddresses_addrs_pick(B, t)) { + return 0; + } + __flatbuffers_memoize_end(B, t, FunctionAddresses_end(B)); +} + +#include "flatcc/flatcc_epilogue.h" +#endif /* APIWRAPPER_BUILDER_H */ diff --git a/examples/apiwrapper_reader.h b/examples/apiwrapper_reader.h new file mode 100644 index 0000000..f7e8253 --- /dev/null +++ b/examples/apiwrapper_reader.h @@ -0,0 +1,54 @@ +#ifndef APIWRAPPER_READER_H +#define APIWRAPPER_READER_H + +/* Generated by flatcc 0.6.2 FlatBuffers schema compiler for C by dvide.com */ + +#ifndef FLATBUFFERS_COMMON_READER_H +#include "flatbuffers_common_reader.h" +#endif +#include "flatcc/flatcc_flatbuffers.h" +#ifndef __alignas_is_defined +#include <stdalign.h> +#endif +#include "flatcc/flatcc_prologue.h" +#ifndef flatbuffers_identifier +#define flatbuffers_identifier 0 +#endif +#ifndef flatbuffers_extension +#define flatbuffers_extension "bin" +#endif + + +typedef const struct FunctionAddresses_table *FunctionAddresses_table_t; +typedef struct FunctionAddresses_table *FunctionAddresses_mutable_table_t; +typedef const flatbuffers_uoffset_t *FunctionAddresses_vec_t; +typedef flatbuffers_uoffset_t *FunctionAddresses_mutable_vec_t; +#ifndef FunctionAddresses_file_identifier +#define FunctionAddresses_file_identifier 0 +#endif +/* deprecated, use FunctionAddresses_file_identifier */ +#ifndef FunctionAddresses_identifier +#define FunctionAddresses_identifier 0 +#endif +#define FunctionAddresses_type_hash ((flatbuffers_thash_t)0x73e9a2df) +#define FunctionAddresses_type_identifier "\xdf\xa2\xe9\x73" +#ifndef FunctionAddresses_file_extension +#define FunctionAddresses_file_extension "bin" +#endif + + + +struct FunctionAddresses_table { uint8_t unused__; }; + +static inline size_t FunctionAddresses_vec_len(FunctionAddresses_vec_t vec) +__flatbuffers_vec_len(vec) +static inline FunctionAddresses_table_t FunctionAddresses_vec_at(FunctionAddresses_vec_t vec, size_t i) +__flatbuffers_offset_vec_at(FunctionAddresses_table_t, vec, i, 0) +__flatbuffers_table_as_root(FunctionAddresses) + +__flatbuffers_define_vector_field(0, FunctionAddresses, names, flatbuffers_string_vec_t, 0) +__flatbuffers_define_vector_field(1, FunctionAddresses, addrs, flatbuffers_uint64_vec_t, 0) + + +#include "flatcc/flatcc_epilogue.h" +#endif /* APIWRAPPER_READER_H */ diff --git a/examples/apiwrapper_verifier.h b/examples/apiwrapper_verifier.h new file mode 100644 index 0000000..fc88b41 --- /dev/null +++ b/examples/apiwrapper_verifier.h @@ -0,0 +1,43 @@ +#ifndef APIWRAPPER_VERIFIER_H +#define APIWRAPPER_VERIFIER_H + +/* Generated by flatcc 0.6.2 FlatBuffers schema compiler for C by dvide.com */ + +#ifndef APIWRAPPER_READER_H +#include "apiwrapper_reader.h" +#endif +#include "flatcc/flatcc_verifier.h" +#include "flatcc/flatcc_prologue.h" + +static int FunctionAddresses_verify_table(flatcc_table_verifier_descriptor_t *td); + +static int FunctionAddresses_verify_table(flatcc_table_verifier_descriptor_t *td) +{ + int ret; + if ((ret = flatcc_verify_string_vector_field(td, 0, 0) /* names */)) return ret; + if ((ret = flatcc_verify_vector_field(td, 1, 0, 8, 8, INT64_C(536870911)) /* addrs */)) return ret; + return flatcc_verify_ok; +} + +static inline int FunctionAddresses_verify_as_root(const void *buf, size_t bufsiz) +{ + return flatcc_verify_table_as_root(buf, bufsiz, FunctionAddresses_identifier, &FunctionAddresses_verify_table); +} + +static inline int FunctionAddresses_verify_as_typed_root(const void *buf, size_t bufsiz) +{ + return flatcc_verify_table_as_root(buf, bufsiz, FunctionAddresses_type_identifier, &FunctionAddresses_verify_table); +} + +static inline int FunctionAddresses_verify_as_root_with_identifier(const void *buf, size_t bufsiz, const char *fid) +{ + return flatcc_verify_table_as_root(buf, bufsiz, fid, &FunctionAddresses_verify_table); +} + +static inline int FunctionAddresses_verify_as_root_with_type_hash(const void *buf, size_t bufsiz, flatbuffers_thash_t thash) +{ + return flatcc_verify_table_as_typed_root(buf, bufsiz, thash, &FunctionAddresses_verify_table); +} + +#include "flatcc/flatcc_epilogue.h" +#endif /* APIWRAPPER_VERIFIER_H */ diff --git a/examples/driver-flatbuffers-tcp.bat b/examples/driver-flatbuffers-tcp.bat new file mode 100644 index 0000000..00d5d8c --- /dev/null +++ b/examples/driver-flatbuffers-tcp.bat @@ -0,0 +1,27 @@ +@echo off +set SERVICE_NAME=flatbuffers-tcp +set DRIVER="%~dp0\driver-flatbuffers-tcp.sys" + +net session >nul 2>&1 +if NOT %ERRORLEVEL% EQU 0 ( + echo ERROR: This script requires Administrator privileges! + pause + exit /b 1 +) + +echo --------------------------------------- +echo -- Service Name: %SERVICE_NAME% +echo -- Driver......: %DRIVER% +echo --------------------------------------- + +sc create %SERVICE_NAME% binPath= %DRIVER% type= kernel +echo --------------------------------------- +sc start %SERVICE_NAME% +echo --------------------------------------- +sc query %SERVICE_NAME% +echo [PRESS A KEY TO STOP THE DRIVER] +pause +sc stop %SERVICE_NAME% +sc delete %SERVICE_NAME% +echo Done. +timeout /t 3 diff --git a/examples/driver-flatbuffers-tcp.cpp b/examples/driver-flatbuffers-tcp.cpp new file mode 100644 index 0000000..00b200c --- /dev/null +++ b/examples/driver-flatbuffers-tcp.cpp @@ -0,0 +1,108 @@ +#include <ksocket/berkeley.h> +#include <ksocket/helper.hpp> +#include <ksocket/ksocket.hpp> +#include <ksocket/ksocket.h> +#include <ksocket/wsk.h> + +#include "apiwrapper_builder.h" +#include "apiwrapper_reader.h" +#include "apiwrapper_verifier.h" + +extern "C" { +DRIVER_INITIALIZE DriverEntry; +DRIVER_UNLOAD DriverUnload; + +#define DebuggerPrint(...) \ + DbgPrintEx(DPFLTR_IHVDRIVER_ID, DPFLTR_ERROR_LEVEL, __VA_ARGS__); + +NTSTATUS +NTAPI +DriverEntry(_In_ PDRIVER_OBJECT DriverObject, + _In_ PUNICODE_STRING RegistryPath) { + UNREFERENCED_PARAMETER(DriverObject); + UNREFERENCED_PARAMETER(RegistryPath); + + NTSTATUS Status; + + KSocketBuffer buf; + buf.insert(buf.end(), static_cast<uint16_t>(0x1122)); + buf.insert(buf.end(), static_cast<uint32_t>(0xFFFFFFFF)); + buf.insert(buf.end(), "AAAAAAAA"); + DebuggerPrint("HEX: %s\n", buf.toHex().c_str()); + + DebuggerPrint("Hi.\n"); + Status = KsInitialize(); + + if (!NT_SUCCESS(Status)) { + return Status; + } + + int server_sockfd = socket_listen(AF_INET, SOCK_STREAM, 0); + + struct sockaddr_in addr; + addr.sin_family = AF_INET; + addr.sin_addr.s_addr = INADDR_ANY; + addr.sin_port = htons(9096); + + int result = bind(server_sockfd, (struct sockaddr *)&addr, sizeof(addr)); + if (result != 0) { + DebuggerPrint("TCP server bind failed\n"); + return STATUS_FAILED_DRIVER_ENTRY; + } + + result = listen(server_sockfd, 1); + if (result != 0) { + DebuggerPrint("TCP server listen failed\n"); + return STATUS_FAILED_DRIVER_ENTRY; + } + + socklen_t addrlen = sizeof(addr); + int client_sockfd = accept(server_sockfd, (struct sockaddr *)&addr, &addrlen); + if (client_sockfd < 0) { + DebuggerPrint("TCP accept failed\n"); + return STATUS_FAILED_DRIVER_ENTRY; + } + + int iResult; + SocketBuffer<1024> sb_send, sb_recv; + + do { + RECV_PDU_BEGIN(client_sockfd, sb_recv, iResult, pdu_type, pdu_len) { + DebuggerPrint("PDU type/len: %u/%u\n", pdu_type, pdu_len); + if (pdu_type == 0) { + int ret = FunctionAddresses_verify_as_root(sb_recv.GetStart(), pdu_len); + + if (ret == 0) { + FunctionAddresses_table_t fnaddr = FunctionAddresses_as_root(sb_recv.GetStart()); + + if (!fnaddr) { + DebuggerPrint("%s\n", "FunctionAddresses not available!"); + } else { + flatbuffers_string_vec_t names = FunctionAddresses_names(fnaddr); + size_t name_size = flatbuffers_string_vec_len(names); + + DebuggerPrint("Length of names vector: %zu\n", name_size); + } + } else { + DebuggerPrint("Flatbuffer verification failed with %d: %s\n", ret, flatcc_verify_error_string(ret)); + } + } else { + DebuggerPrint("%s\n", "PDU type not supported!"); + } + } + RECV_PDU_END(sb_recv, pdu_len); + } while (iResult != SOCKET_ERROR && iResult > 0); + + DebuggerPrint("Client gone.\n") closesocket(client_sockfd); + closesocket(server_sockfd); + KsDestroy(); + + return STATUS_SUCCESS; +} + +VOID DriverUnload(_In_ struct _DRIVER_OBJECT *DriverObject) { + UNREFERENCED_PARAMETER(DriverObject); + + DebuggerPrint("Bye.\n"); +} +} diff --git a/examples/driver-protobuf-c-tcp.cpp b/examples/driver-protobuf-c-tcp.cpp index fcf4465..a7071ed 100644 --- a/examples/driver-protobuf-c-tcp.cpp +++ b/examples/driver-protobuf-c-tcp.cpp @@ -1,6 +1,7 @@ #include <ksocket/berkeley.h> #include <ksocket/helper.hpp> #include <ksocket/ksocket.h> +#include <ksocket/utils.h> #include <ksocket/wsk.h> #include "examples/common.hpp" diff --git a/examples/driver-protobuf-c.cpp b/examples/driver-protobuf-c.cpp index c3c8445..bbd0b52 100644 --- a/examples/driver-protobuf-c.cpp +++ b/examples/driver-protobuf-c.cpp @@ -1,6 +1,7 @@ #include <ksocket/berkeley.h> #include <ksocket/helper.hpp> #include <ksocket/ksocket.h> +#include <ksocket/utils.h> #include <ksocket/wsk.h> #include "examples/common.hpp" diff --git a/examples/driver.cpp b/examples/driver.cpp index 86e42e9..228acc7 100644 --- a/examples/driver.cpp +++ b/examples/driver.cpp @@ -1,9 +1,11 @@ - -extern "C" { #include <ksocket/berkeley.h> #include <ksocket/ksocket.h> +#include <ksocket/utils.h> #include <ksocket/wsk.h> +#include <ksocket/ksocket.hpp> + +extern "C" { DRIVER_INITIALIZE DriverEntry; DRIVER_UNLOAD DriverUnload; @@ -34,15 +36,15 @@ DriverEntry(_In_ PDRIVER_OBJECT DriverObject, // Perform HTTP request to http://httpbin.org/uuid // + const char send_buffer[] = "GET /uuid HTTP/1.1\r\n" + "Host: httpbin.org\r\n" + "Connection: close\r\n" + "\r\n"; + { int result; UNREFERENCED_PARAMETER(result); - char send_buffer[] = "GET /uuid HTTP/1.1\r\n" - "Host: httpbin.org\r\n" - "Connection: close\r\n" - "\r\n"; - char recv_buffer[1024] = {}; struct addrinfo hints = {}; @@ -66,7 +68,7 @@ DriverEntry(_In_ PDRIVER_OBJECT DriverObject, return STATUS_FAILED_DRIVER_ENTRY; } - result = send(sockfd, send_buffer, sizeof(send_buffer), 0); + result = send(sockfd, send_buffer, sizeof(send_buffer) - 1, 0); if (result <= 0) { DebuggerPrint("TCP client send failed\n"); return STATUS_FAILED_DRIVER_ENTRY; @@ -83,6 +85,48 @@ DriverEntry(_In_ PDRIVER_OBJECT DriverObject, closesocket(sockfd); } + { + KStreamClientIp4 tcp4_client = KStreamClientIp4(); + + if (!tcp4_client.setup()) { + DebuggerPrint("KStreamClientIp4 setup() failed: %d\n", + tcp4_client.getLastError()); + return STATUS_FAILED_DRIVER_ENTRY; + } + + if (!tcp4_client.connect("httpbin.org", "80")) { + DebuggerPrint("KStreamClientIp4 connect() failed: %d\n", + tcp4_client.getLastError()); + return STATUS_FAILED_DRIVER_ENTRY; + } + + DebuggerPrint("%s\n", "KStreamClientIp4 connected!"); + + tcp4_client.getSendBuffer().insert(tcp4_client.getSendBuffer().end(), + send_buffer, sizeof(send_buffer) - 1); + if (!tcp4_client.send()) { + DebuggerPrint("KStreamClientIp4 send() failed: %d\n", + tcp4_client.getLastError()); + return STATUS_FAILED_DRIVER_ENTRY; + } + + if (!tcp4_client.recv()) { + DebuggerPrint("KStreamClientIp4 recv() failed: %d\n", + tcp4_client.getLastError()); + return STATUS_FAILED_DRIVER_ENTRY; + } + + DebuggerPrint("KStreamClientIp4 data received:\n%s\n", + tcp4_client.getRecvBuffer().to_string().c_str()); + DebuggerPrint("KStreamClientIp4 consuming %zu bytes\n", + tcp4_client.getRecvBuffer().size()); + tcp4_client.getRecvBuffer().consume(); + DebuggerPrint("KStreamClientIp4 receive buffer size: %zu\n", + tcp4_client.getRecvBuffer().size()); + + DebuggerPrint("%s\n", "KStreamClientIp4 finished."); + } + // // TCP server. // Listen on port 9095, wait for some message, @@ -114,6 +158,10 @@ DriverEntry(_In_ PDRIVER_OBJECT DriverObject, return STATUS_FAILED_DRIVER_ENTRY; } + DebuggerPrint( + "%s\n", + "TCP server is waiting for the user to start userspace_client.exe"); + socklen_t addrlen = sizeof(addr); int client_sockfd = accept(server_sockfd, (struct sockaddr *)&addr, &addrlen); @@ -124,7 +172,7 @@ DriverEntry(_In_ PDRIVER_OBJECT DriverObject, result = recv(client_sockfd, recv_buffer, sizeof(recv_buffer) - 1, 0); if (result > 0) { - DebuggerPrint("TCP server:\n%.*s\n", result, recv_buffer); + DebuggerPrint("TCP server received: \"%.*s\"\n", result, recv_buffer); } else { DebuggerPrint("TCP server recv failed\n"); } @@ -139,6 +187,69 @@ DriverEntry(_In_ PDRIVER_OBJECT DriverObject, closesocket(server_sockfd); } + { + KStreamServerIp4 tcp4_server = KStreamServerIp4(); + + if (!tcp4_server.setup()) { + DebuggerPrint("KStreamServerIp4 setup() failed: %d\n", + tcp4_server.getLastError()); + return STATUS_FAILED_DRIVER_ENTRY; + } + + if (!tcp4_server.bind(9095)) { + DebuggerPrint("KStreamServerIp4 bind() failed: %d\n", + tcp4_server.getLastError()); + return STATUS_FAILED_DRIVER_ENTRY; + } + + if (!tcp4_server.listen()) { + DebuggerPrint("KStreamServerIp4 bind() failed: %d\n", + tcp4_server.getLastError()); + return STATUS_FAILED_DRIVER_ENTRY; + } + + DebuggerPrint("%s\n", "KStreamServerIp4 listening for incomining " + "connections (run userspace_client.exe again).."); + + const auto &accept_fn = [](KAcceptedSocket &ka) { + const auto &remote = ka.getRemote(); + + if (remote.addr_used != 4) { + return false; + } + DebuggerPrint("KStreamServerIp4 client connected: %s\n", + remote.to_string().c_str()); + + if (!ka.recv()) { + DebuggerPrint("KStreamServerIp4 recv failed: %d\n", ka.getLastError()); + return false; + } + DebuggerPrint("KStreamServerIp4 received %zu bytes: \"%s\"\n", + ka.getRecvBuffer().size(), + ka.getRecvBuffer().to_string().c_str()); + ka.getRecvBuffer().consume(); + + ka.getSendBuffer().insert_string(ka.getSendBuffer().end(), + "KStreamServerIp4 says hello!"); + if (!ka.send()) { + DebuggerPrint("KStreamServerIp4 send failed: %d\n", ka.getLastError()); + return false; + } + ka.getSendBuffer().consume(); + + // Wait for the connection termination. + ka.recv(); + + return true; + }; + if (!tcp4_server.accept(accept_fn)) { + DebuggerPrint("KStreamServerIp4 accept() failed: %d\n", + tcp4_server.getLastError()); + return STATUS_FAILED_DRIVER_ENTRY; + } + DebuggerPrint("KStreamServerIp4 done\n"); + } + KsDestroy(); return STATUS_SUCCESS; diff --git a/examples/userspace_client_flatbuffers.cpp b/examples/userspace_client_flatbuffers.cpp new file mode 100644 index 0000000..9f43698 --- /dev/null +++ b/examples/userspace_client_flatbuffers.cpp @@ -0,0 +1,82 @@ +#include <stdio.h> +#include <stdlib.h> // Needed for _wtoi +#include <winsock2.h> +#include <ws2tcpip.h> + +#include <ksocket/ksocket.hpp> +#include <ksocket/helper.hpp> + +#include "apiwrapper_builder.h" +#include "apiwrapper_reader.h" +#include "apiwrapper_verifier.h" + +int main(int argc, char **argv) { + WSADATA wsaData = {}; + int iResult = WSAStartup(MAKEWORD(2, 2), &wsaData); + + UNREFERENCED_PARAMETER(argc); + UNREFERENCED_PARAMETER(argv); + + if (iResult != 0) { + wprintf(L"WSAStartup failed: %d\n", iResult); + return 1; + } + + SOCKET sock = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); + + if (sock == INVALID_SOCKET) { + wprintf(L"socket function failed with error = %d\n", WSAGetLastError()); + } else { + wprintf(L"socket function succeeded\n"); + } + + sockaddr_in clientService; + clientService.sin_family = AF_INET; + clientService.sin_addr.s_addr = inet_addr("127.0.0.1"); + clientService.sin_port = htons(9096); + + do { + iResult = connect(sock, (SOCKADDR *)&clientService, sizeof(clientService)); + if (iResult == SOCKET_ERROR) { + wprintf(L"connect function failed with error: %ld\n", WSAGetLastError()); + Sleep(1000); + } + } while (iResult == SOCKET_ERROR); + + wprintf(L"Connected to server.\n"); + + flatcc_builder_t builder; + flatcc_builder_init(&builder); + for (size_t i = 0; i < 256; ++i) { + FunctionAddresses_start_as_root(&builder); + FunctionAddresses_names_add(&builder, flatbuffers_string_create_str(&builder, "A")); + FunctionAddresses_names_add(&builder, flatbuffers_string_create_str(&builder, "B")); + FunctionAddresses_names_add(&builder, flatbuffers_string_create_str(&builder, "C")); + FunctionAddresses_end_as_root(&builder); + + KSocketBuffer buffer; + void *buf; + size_t siz; + buf = flatcc_builder_finalize_aligned_buffer(&builder, &siz); + (void)buf; + uint8_t a[] = {0x41,0x41,0x41}; + buffer.insert_u16(buffer.begin(), 65535); + buffer.insert_bytebuffer(buffer.begin(), a, 3); + } + + wprintf(L"Closing Connection ..\n"); + + iResult = closesocket(sock); + if (iResult == SOCKET_ERROR) { + wprintf(L"closesocket function failed with error: %ld\n", + WSAGetLastError()); + WSACleanup(); + return 1; + } + + WSACleanup(); + + system("pause"); + + return 0; +} diff --git a/examples/userspace_client_protobuf.cpp b/examples/userspace_client_protobuf.cpp index a387579..1eadc0c 100644 --- a/examples/userspace_client_protobuf.cpp +++ b/examples/userspace_client_protobuf.cpp @@ -4,6 +4,7 @@ #include <ws2tcpip.h> #include <ksocket/helper.hpp> +#include <ksocket/utils.h> #include "examples/common.hpp" #include "examples/example.pb-c.h" diff --git a/ksocket/berkeley.c b/ksocket/berkeley.c index e72f92e..73ccb79 100644 --- a/ksocket/berkeley.c +++ b/ksocket/berkeley.c @@ -297,14 +297,6 @@ VOID NTAPI KspUtilFreeAddrInfoEx(_In_ PADDRINFOEXW AddrInfo) { // Public functions. ////////////////////////////////////////////////////////////////////////// -uint32_t htonl(uint32_t hostlong) { return __builtin_bswap32(hostlong); } - -uint16_t htons(uint16_t hostshort) { return __builtin_bswap16(hostshort); } - -uint32_t ntohl(uint32_t netlong) { return __builtin_bswap32(netlong); } - -uint16_t ntohs(uint16_t netshort) { return __builtin_bswap16(netshort); } - int getaddrinfo(const char *node, const char *service, const struct addrinfo *hints, struct addrinfo **res) { NTSTATUS Status; diff --git a/ksocket/berkeley.h b/ksocket/berkeley.h index b879d3c..1475a89 100644 --- a/ksocket/berkeley.h +++ b/ksocket/berkeley.h @@ -1,6 +1,11 @@ -#pragma once +#ifndef KSOCKET_BERKELEY_H +#define KSOCKET_BERKELEY_H 1 + +#ifdef BUILD_USERMODE +#error "This file should only be included if building for kernel mode! Include <ksocket/ksocket.hpp> wrapper instead." +#endif + #include <ntddk.h> -#include <stdint.h> #include <ksocket/wsk.h> #define socket socket_connection @@ -12,11 +17,6 @@ extern "C" { typedef int socklen_t; typedef intptr_t ssize_t; -uint32_t htonl(uint32_t hostlong); -uint16_t htons(uint16_t hostshort); -uint32_t ntohl(uint32_t netlong); -uint16_t ntohs(uint16_t netshort); - int getaddrinfo(const char *node, const char *service, const struct addrinfo *hints, struct addrinfo **res); void freeaddrinfo(struct addrinfo *res); @@ -39,3 +39,5 @@ int closesocket(int sockfd); #ifdef __cplusplus } #endif + +#endif diff --git a/ksocket/helper.hpp b/ksocket/helper.hpp index 153e549..9143b07 100644 --- a/ksocket/helper.hpp +++ b/ksocket/helper.hpp @@ -11,6 +11,8 @@ #include <EASTL/string.h> #include <EASTL/vector.h> +#include <ksocket/utils.h> + #ifndef SOCKET_ERROR #define SOCKET_ERROR -1 #endif diff --git a/ksocket/ksocket.cpp b/ksocket/ksocket.cpp new file mode 100644 index 0000000..ca0625b --- /dev/null +++ b/ksocket/ksocket.cpp @@ -0,0 +1,350 @@ +#include "ksocket.hpp" + +#include <EASTL/type_traits.h> +#include <eastl_compat.hpp> + +#ifdef BUILD_USERMODE +// clang-format off +#include <winsock2.h> +#include <windows.h> +#include <ws2tcpip.h> +// clang-format on +#else +#include <ksocket/berkeley.h> +#include <ksocket/ksocket.h> +#include <ksocket/wsk.h> +#endif +#include <ksocket/utils.h> + +struct KSocketImplCommon { + ADDRESS_FAMILY domain; + int type; + int proto; +}; + +#ifdef BUILD_USERMODE +struct KSocketImpl { + ~KSocketImpl() { closesocket(s); } + + SOCKET s; + KSocketImplCommon c; +}; +#else +struct KSocketImpl { + ~KSocketImpl() { + if (s >= 0) { + closesocket(s); + s = -1; + } + } + + int s = -1; + KSocketImplCommon c; +}; +#endif + +eastl::string KSocketAddress::to_string(bool with_port) const { + if (addr_used != 4) { + return ""; + } + + return ::to_string(addr.u8[0]) + "." + ::to_string(addr.u8[1]) + "." + + ::to_string(addr.u8[2]) + "." + ::to_string(addr.u8[3]) + ":" + + ::to_string(with_port); +} + +void KSocketBuffer::insert_u8(KBuffer::iterator it, uint8_t value) { + buffer.insert(it, value); +} + +void KSocketBuffer::insert_u16(KBuffer::iterator it, uint16_t value) { + uint16_t net_value; + uint8_t insert_value[2]; + + net_value = htons(value); + insert_value[0] = (net_value & 0x00FF) >> 0; + insert_value[1] = (net_value & 0xFF00) >> 8; + buffer.insert(it, insert_value, insert_value + eastl::size(insert_value)); +} + +void KSocketBuffer::insert_u32(KBuffer::iterator it, uint32_t value) { + uint32_t net_value; + uint8_t insert_value[4]; + + net_value = htonl(value); + insert_value[0] = (net_value & 0x000000FF) >> 0; + insert_value[1] = (net_value & 0x0000FF00) >> 8; + insert_value[2] = (net_value & 0x00FF0000) >> 16; + insert_value[3] = (net_value & 0xFF000000) >> 24; + buffer.insert(it, insert_value, insert_value + eastl::size(insert_value)); +} + +void KSocketBuffer::insert_u64(KBuffer::iterator it, uint64_t value) { + uint64_t net_value; + uint8_t insert_value[8]; + + net_value = htonll(value); + insert_value[0] = (net_value & 0x00000000000000FF) >> 0; + insert_value[1] = (net_value & 0x000000000000FF00) >> 8; + insert_value[2] = (net_value & 0x0000000000FF0000) >> 16; + insert_value[3] = (net_value & 0x00000000FF000000) >> 24; + insert_value[4] = (net_value & 0x000000FF00000000) >> 32; + insert_value[5] = (net_value & 0x0000FF0000000000) >> 40; + insert_value[6] = (net_value & 0x00FF000000000000) >> 48; + insert_value[7] = (net_value & 0xFF00000000000000) >> 56; + buffer.insert(it, insert_value, insert_value + eastl::size(insert_value)); +} + +void KSocketBuffer::insert_bytebuffer(KBuffer::iterator it, + const uint8_t bytebuffer[], size_t size) { + buffer.insert(it, bytebuffer, bytebuffer + size); +} + +void KSocketBuffer::consume(size_t amount_bytes) { + if (amount_bytes == 0 || amount_bytes > buffer.size()) + amount_bytes = buffer.size(); + + buffer.erase(buffer.begin(), buffer.begin() + amount_bytes); +} + +eastl::string KSocketBuffer::toHex(eastl::string delim) { + eastl::string str; + char const *const hex = "0123456789ABCDEF"; + char pout[3] = {}; + + for (const auto &input_byte : buffer) { + pout[0] = hex[(input_byte >> 4) & 0xF]; + pout[1] = hex[input_byte & 0xF]; + str += pout; + str += delim; + } + + if (str.length() >= delim.length() && delim.length() > 0) + str.erase(str.length() - delim.length(), delim.length()); + + return str; +} + +KSocket::~KSocket() { + if (m_socket != nullptr) + delete m_socket; +} + +bool KSocket::setup(KSocketType sock_type, int proto) { + int domain, type; + + m_lastError = KSE_SUCCESS; + if (m_socket != nullptr) + delete m_socket; + m_socket = new KSocketImpl(); + if (m_socket == nullptr) { + m_lastError = KSE_SETUP_IMPL_NULL; + return false; + } + + if (KSocket::socketTypeToTuple(sock_type, domain, type) != true) { + m_lastError = KSE_SETUP_INVALID_SOCKET_TYPE; + return false; + } + if (sock_type == KSocketType::KST_STREAM_CLIENT_IP6 || + sock_type == KSocketType::KST_STREAM_SERVER_IP6 || + sock_type == KSocketType::KST_DATAGRAM_IP6) { + m_lastError = KSE_SETUP_UNSUPPORTED_SOCKET_TYPE; + return false; // IPv6 is not supported for now + } + m_socketType = sock_type; + +#ifdef BUILD_USERMODE + m_socket->s = ::socket(domain, type, proto); +#else + // DbgPrint("KSocketType: %d, domain: %d, type: %d, proto: %d", sock_type, + // domain, type, proto); + switch (sock_type) { + case KSocketType::KST_INVALID: + m_lastError = KSE_SETUP_INVALID_SOCKET_TYPE; + return false; + case KSocketType::KST_STREAM_CLIENT_IP4: + case KSocketType::KST_STREAM_CLIENT_IP6: + m_socket->s = ::socket_connection(domain, type, proto); + break; + case KSocketType::KST_STREAM_SERVER_IP4: + case KSocketType::KST_STREAM_SERVER_IP6: + m_socket->s = ::socket_listen(domain, type, proto); + break; + case KSocketType::KST_DATAGRAM_IP4: + case KSocketType::KST_DATAGRAM_IP6: + m_socket->s = ::socket_datagram(domain, type, proto); + break; + } +#endif + m_socket->c.domain = static_cast<ADDRESS_FAMILY>(domain); + m_socket->c.type = type; + m_socket->c.proto = proto; +#ifdef BUILD_USERMODE + return m_socket->s != INVALID_SOCKET; +#else + return m_socket->s >= 0; +#endif +} + +bool KSocket::connect(eastl::string host, eastl::string port) { + struct addrinfo hints = {}; + struct addrinfo *results; + + m_lastError = KSE_SUCCESS; + + if (m_socket == nullptr) + return false; + + if (m_socketType != KSocketType::KST_STREAM_CLIENT_IP4 && + m_socketType != KSocketType::KST_STREAM_CLIENT_IP6) + return false; + + hints.ai_flags |= AI_CANONNAME; + hints.ai_family = m_socket->c.domain; + hints.ai_socktype = m_socket->c.type; + m_lastError = ::getaddrinfo(host.c_str(), port.c_str(), &hints, &results); + if (m_lastError != KSE_SUCCESS) + return false; + + m_lastError = + ::connect(m_socket->s, results->ai_addr, (int)results->ai_addrlen); + freeaddrinfo(results); + return m_lastError == KSE_SUCCESS; +} + +bool KSocket::bind(uint16_t port) { + struct sockaddr_in addr; + + addr.sin_family = AF_INET; + addr.sin_addr.s_addr = INADDR_ANY; + addr.sin_port = htons(port); + + m_lastError = ::bind(m_socket->s, (struct sockaddr *)&addr, sizeof(addr)); + + return m_lastError == KSE_SUCCESS; +} + +bool KSocket::listen(int backlog) { + m_lastError = KSE_SUCCESS; + + if (m_socketType != KSocketType::KST_STREAM_SERVER_IP4 && + m_socketType != KSocketType::KST_STREAM_SERVER_IP6) + return false; + + m_lastError = ::listen(m_socket->s, backlog); + + return m_lastError == KSE_SUCCESS; +} + +bool KSocket::accept(KAcceptThreadCallback thread_callback) { + KAcceptedSocket ka; + struct sockaddr addr; + socklen_t addrlen = sizeof(addr); + + if (m_socket->c.domain != AF_INET) { + m_lastError = KSE_SETUP_UNSUPPORTED_SOCKET_TYPE; + return false; + } + + addr.sa_family = m_socket->c.domain; + ka.m_socket = new KSocketImpl(); + if (ka.m_socket == nullptr) { + m_lastError = KSE_SETUP_IMPL_NULL; + return false; + } + ka.m_socketType = m_socketType; + ka.m_socket->c = m_socket->c; + ka.m_socket->s = ::accept(m_socket->s, &addr, &addrlen); + + if (m_socket->c.domain == AF_INET) { + struct sockaddr_in *addr_in = reinterpret_cast<struct sockaddr_in *>(&addr); + + ka.m_remote.addr_used = 4; + ka.m_remote.addr.u32[0] = addr_in->sin_addr.s_addr; + ka.m_remote.port = addr_in->sin_port; + } else { + m_lastError = KSE_SETUP_UNSUPPORTED_SOCKET_TYPE; + return false; + } + +#ifdef BUILD_USERMODE + if (ka.m_socket->s == INVALID_SOCKET) +#else + if (ka.m_socket->s < 0) +#endif + { + m_lastError = KSE_ACCEPT_FAILED; + return false; + } + + return thread_callback(ka); +} + +bool KSocket::close() { + int rv = closesocket(m_socket->s); + + if (rv == 0) { + m_socket->s = -1; + return true; + } else { + return false; + } +} + +bool KSocket::send() { + m_lastError = KSE_SUCCESS; + + if (m_sendBuffer.size() == 0) + return false; + + m_lastError = + ::send(m_socket->s, reinterpret_cast<const char *>(m_sendBuffer.data()), + m_sendBuffer.size(), 0); + if (m_lastError > 0) { + m_sendBuffer.buffer.erase(m_sendBuffer.begin(), + m_sendBuffer.begin() + m_lastError); + return true; + } + + return false; +} + +bool KSocket::recv(size_t max_recv_size) { + const size_t current_size = m_recvBuffer.size(); + + m_recvBuffer.buffer.resize(current_size + max_recv_size); + m_lastError = ::recv( + m_socket->s, reinterpret_cast<char *>(m_recvBuffer.data() + current_size), + max_recv_size, 0); + if (m_lastError > 0) { + m_recvBuffer.buffer.resize(current_size + m_lastError); + return true; + } + + return false; +} + +bool KSocket::socketTypeToTuple(KSocketType sock_type, int &domain, int &type) { + switch (sock_type) { + case KSocketType::KST_INVALID: + break; + case KSocketType::KST_STREAM_CLIENT_IP4: + case KSocketType::KST_STREAM_SERVER_IP4: + domain = AF_INET; + type = SOCK_STREAM; + return true; + case KSocketType::KST_STREAM_CLIENT_IP6: + case KSocketType::KST_STREAM_SERVER_IP6: + domain = AF_INET6; + type = SOCK_STREAM; + return true; + case KSocketType::KST_DATAGRAM_IP4: + case KSocketType::KST_DATAGRAM_IP6: + domain = AF_INET; + type = SOCK_DGRAM; + return true; + } + + return false; +} diff --git a/ksocket/ksocket.h b/ksocket/ksocket.h index 0a6717e..e44035b 100644 --- a/ksocket/ksocket.h +++ b/ksocket/ksocket.h @@ -1,5 +1,12 @@ -#pragma once -#include "wsk.h" +#ifndef KSOCKET_H +#define KSOCKET_H 1 + +#ifdef BUILD_USERMODE +#error \ + "This file should only be included if building for kernel mode! Include <ksocket/ksocket.hpp> wrapper instead." +#endif + +#include <ksocket/wsk.h> #include <ntddk.h> #define STATUS_UNSUPPORTED_WINDOWS_VERSION \ @@ -97,3 +104,5 @@ KsRecvFrom(_In_ PKSOCKET Socket, _In_ PVOID Buffer, _Inout_ PULONG Length, #ifdef __cplusplus } #endif + +#endif diff --git a/ksocket/ksocket.hpp b/ksocket/ksocket.hpp new file mode 100644 index 0000000..abf03b6 --- /dev/null +++ b/ksocket/ksocket.hpp @@ -0,0 +1,233 @@ +#ifndef KSOCKET_HPP +#define KSOCKET_HPP 1 + +#include <EASTL/functional.h> +#include <EASTL/string.h> +#include <EASTL/vector.h> +#include <cstdint> + +using KBuffer = eastl::vector<uint8_t>; + +struct KSocketImpl; + +enum { + KSE_SUCCESS = 0, + KSE_SETUP_IMPL_NULL = 1, + KSE_SETUP_INVALID_SOCKET_TYPE = 2, + KSE_SETUP_UNSUPPORTED_SOCKET_TYPE = 3, + KSE_ACCEPT_FAILED = 4, +}; + +enum class KSocketType { + KST_INVALID = 0, + KST_STREAM_CLIENT_IP4, + KST_STREAM_SERVER_IP4, + KST_STREAM_CLIENT_IP6, + KST_STREAM_SERVER_IP6, + KST_DATAGRAM_IP4, + KST_DATAGRAM_IP6 +}; + +struct KSocketAddress { + eastl::string to_string(bool with_port = true) const; + uint8_t addr_used = 0; + union { + uint8_t u8[16]; + uint16_t u16[8]; + uint32_t u32[4]; + uint64_t u64[2]; + } addr; + uint16_t port = 0; +}; + +class KAcceptedSocket; +using KAcceptThreadCallback = eastl::function<bool(KAcceptedSocket &accepted)>; + +struct KSocketBuffer { + void insert_i8(KBuffer::iterator it, int8_t value) { + insert_u8(it, static_cast<uint8_t>(value)); + } + void insert_i16(KBuffer::iterator it, int16_t value) { + insert_u16(it, static_cast<uint16_t>(value)); + } + void insert_i32(KBuffer::iterator it, int32_t value) { + insert_u32(it, static_cast<uint32_t>(value)); + } + void insert_i64(KBuffer::iterator it, int64_t value) { + insert_u64(it, static_cast<uint64_t>(value)); + } + + void insert_u8(KBuffer::iterator it, uint8_t value); + void insert_u16(KBuffer::iterator it, uint16_t value); + void insert_u32(KBuffer::iterator it, uint32_t value); + void insert_u64(KBuffer::iterator it, uint64_t value); + + void insert_string(KBuffer::iterator it, const eastl::string &value) { + insert_bytebuffer(it, reinterpret_cast<const uint8_t *>(value.c_str()), + value.length()); + } + void insert_string(KBuffer::iterator it, const char buffer[], size_t size) { + insert_bytebuffer(it, reinterpret_cast<const uint8_t *>(buffer), size); + } + void insert_bytebuffer(KBuffer::iterator it, const void *bytebuffer, + size_t size) { + insert_bytebuffer(it, reinterpret_cast<const uint8_t *>(bytebuffer), size); + } + void insert_bytebuffer(KBuffer::iterator it, const uint8_t bytebuffer[], + size_t size); + + void insert(KBuffer::iterator it, int8_t value) { + insert_u8(it, static_cast<uint8_t>(value)); + } + void insert(KBuffer::iterator it, int16_t value) { + insert_u16(it, static_cast<uint16_t>(value)); + } + void insert(KBuffer::iterator it, int32_t value) { + insert_u32(it, static_cast<uint32_t>(value)); + } + void insert(KBuffer::iterator it, int64_t value) { + insert_u64(it, static_cast<uint64_t>(value)); + } + + void insert(KBuffer::iterator it, uint8_t value) { insert_u8(it, value); } + void insert(KBuffer::iterator it, uint16_t value) { insert_u16(it, value); } + void insert(KBuffer::iterator it, uint32_t value) { insert_u32(it, value); } + void insert(KBuffer::iterator it, uint64_t value) { insert_u64(it, value); } + + void insert(KBuffer::iterator it, const eastl::string value) { + insert_string(it, eastl::move(value)); + } + void insert(KBuffer::iterator it, const char buffer[], size_t size) { + insert_string(it, buffer, size); + } + void insert(KBuffer::iterator it, const char *value) { + insert_string(it, eastl::move(eastl::string(value))); + } + void insert(KBuffer::iterator it, const uint8_t bytebuffer[], size_t size) { + insert_bytebuffer(it, bytebuffer, size); + } + + void consume(size_t amount_bytes = 0); + + eastl::string to_string() { + return eastl::string(reinterpret_cast<const char *>(data()), size()); + } + uint8_t *data() { return buffer.data(); } + size_t size() { return buffer.size(); } + KBuffer::iterator begin() { return buffer.begin(); } + KBuffer::iterator end() { return buffer.end(); } + + eastl::string toHex(eastl::string delim = ":"); + + KBuffer buffer; +}; + +class KSocket { +protected: + KSocket() : m_sendBuffer(), m_recvBuffer() {} + KSocket(const KSocket &) = delete; + ~KSocket(); + bool setup(KSocketType sock_type, int proto = 0); + + bool connect(eastl::string host, eastl::string port); + bool bind(uint16_t port); + bool listen(int backlog); + bool accept(KAcceptThreadCallback thread_callback); + bool close(); + + bool send(); + bool recv(size_t max_recv_size); + +public: + int getLastError() const { return m_lastError; } + + KSocketBuffer &getSendBuffer() { return m_sendBuffer; } + KSocketBuffer &getRecvBuffer() { return m_recvBuffer; } + + static bool socketTypeToTuple(KSocketType sock_type, int &domain, int &type); + +private: + KSocketBuffer m_sendBuffer; + KSocketBuffer m_recvBuffer; + KSocketType m_socketType = KSocketType::KST_INVALID; + KSocketImpl *m_socket = nullptr; + int m_lastError = KSE_SUCCESS; +}; + +class KAcceptedSocket : public KSocket { + friend class KSocket; + +public: + KAcceptedSocket() : KSocket() {} + ~KAcceptedSocket() {} + bool setup(KSocketType sock_type, int proto = 0) = delete; + + bool connect(eastl::string host, eastl::string port) = delete; + bool bind(uint16_t port) = delete; + bool listen(int backlog) = delete; + bool accept(KAcceptThreadCallback thread_callback) = delete; + + bool send() { return KSocket::send(); } + bool recv(size_t max_recv_size = 65535) { + return KSocket::recv(max_recv_size); + } + + int getLastError() { return KSocket::getLastError(); } + + const KSocketAddress &getRemote() { return m_remote; } + +private: + KSocketAddress m_remote; +}; + +class KStreamClientIp4 : public KSocket { +public: + KStreamClientIp4() : KSocket() {} + ~KStreamClientIp4() {} + + bool setup() { + return KSocket::setup(KSocketType::KST_STREAM_CLIENT_IP4, + 6 /* IPPROTO_TCP */); + } + + bool connect(eastl::string host, eastl::string port) { + return KSocket::connect(host, port); + } + bool bind(uint16_t) = delete; + bool listen(int) = delete; + bool accept(KAcceptThreadCallback) = delete; + + bool send() { return KSocket::send(); } + bool recv(size_t max_recv_size = 65535) { + return KSocket::recv(max_recv_size); + } + + int getLastError() { return KSocket::getLastError(); } +}; + +class KStreamServerIp4 : public KSocket { +public: + KStreamServerIp4() : KSocket() {} + ~KStreamServerIp4() {} + + bool setup() { + return KSocket::setup(KSocketType::KST_STREAM_SERVER_IP4, + 6 /* IPPROTO_TCP */); + } + + bool connect(eastl::string host, eastl::string port) = delete; + bool bind(uint16_t port) { return KSocket::bind(port); } + bool listen(int backlog = 16) { return KSocket::listen(backlog); } + bool accept(KAcceptThreadCallback thread_callback) { + return KSocket::accept(thread_callback); + } + + bool send() { return KSocket::send(); } + bool recv(size_t max_recv_size = 65535) { + return KSocket::recv(max_recv_size); + } + + int getLastError() { return KSocket::getLastError(); } +}; + +#endif diff --git a/ksocket/utils.c b/ksocket/utils.c new file mode 100644 index 0000000..1fe26ea --- /dev/null +++ b/ksocket/utils.c @@ -0,0 +1,17 @@ +#include "utils.h" + +uint64_t htonll(uint64_t hostlonglong) { return __builtin_bswap64(hostlonglong); } + +#ifndef BUILD_USERMODE +uint32_t htonl(uint32_t hostlong) { return __builtin_bswap32(hostlong); } + +uint16_t htons(uint16_t hostshort) { return __builtin_bswap16(hostshort); } +#endif + +uint64_t ntohll(uint64_t netlonglong) { return __builtin_bswap64(netlonglong); } + +#ifndef BUILD_USERMODE +uint32_t ntohl(uint32_t netlong) { return __builtin_bswap32(netlong); } + +uint16_t ntohs(uint16_t netshort) { return __builtin_bswap16(netshort); } +#endif diff --git a/ksocket/utils.h b/ksocket/utils.h new file mode 100644 index 0000000..e3c0474 --- /dev/null +++ b/ksocket/utils.h @@ -0,0 +1,30 @@ +#ifndef KSOCKET_UTILS_H +#define KSOCKET_UTILS_H 1 + +#include <stdint.h> + +#ifdef __cplusplus +extern "C" { +#endif + +uint64_t htonll(uint64_t hostlonglong); + +#ifndef BUILD_USERMODE +uint32_t htonl(uint32_t hostlong); + +uint16_t htons(uint16_t hostshort); +#endif + +uint64_t ntohll(uint64_t netlonglong); + +#ifndef BUILD_USERMODE +uint32_t ntohl(uint32_t netlong); + +uint16_t ntohs(uint16_t netshort); +#endif + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/ksocket/wsk.h b/ksocket/wsk.h index ccea103..25d7f06 100644 --- a/ksocket/wsk.h +++ b/ksocket/wsk.h @@ -1,6 +1,10 @@ #ifndef WSK_H #define WSK_H 1 +#ifdef BUILD_USERMODE +#error "This file should only be included if building for kernel mode! Include <ksocket/ksocket.hpp> wrapper instead." +#endif + #include <ntddk.h> #if !defined(__MINGW64__) |