aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorToni Uhlig <matzeton@googlemail.com>2023-09-15 11:21:31 +0200
committerToni Uhlig <matzeton@googlemail.com>2023-09-15 11:21:31 +0200
commit0cbfbe129934976359460fdbe69fb97632d81d24 (patch)
tree3d4685fc1f02244c80d4f5de2fa6c80fae94a425
parent37d1e657e5e79bc240ea036cfb8da377b1640490 (diff)
Added C++ (`ksocket/ksocket.hpp`) Socket wrapper classes.
* another flatbuffers example (WiP!) * Makefile improvements Signed-off-by: Toni Uhlig <matzeton@googlemail.com>
-rw-r--r--Makefile54
-rw-r--r--examples/apiwrapper.fbs6
-rw-r--r--examples/apiwrapper_builder.h55
-rw-r--r--examples/apiwrapper_reader.h54
-rw-r--r--examples/apiwrapper_verifier.h43
-rw-r--r--examples/driver-flatbuffers-tcp.bat27
-rw-r--r--examples/driver-flatbuffers-tcp.cpp108
-rw-r--r--examples/driver-protobuf-c-tcp.cpp1
-rw-r--r--examples/driver-protobuf-c.cpp1
-rw-r--r--examples/driver.cpp129
-rw-r--r--examples/userspace_client_flatbuffers.cpp82
-rw-r--r--examples/userspace_client_protobuf.cpp1
-rw-r--r--ksocket/berkeley.c8
-rw-r--r--ksocket/berkeley.h16
-rw-r--r--ksocket/helper.hpp2
-rw-r--r--ksocket/ksocket.cpp350
-rw-r--r--ksocket/ksocket.h13
-rw-r--r--ksocket/ksocket.hpp233
-rw-r--r--ksocket/utils.c17
-rw-r--r--ksocket/utils.h30
-rw-r--r--ksocket/wsk.h4
21 files changed, 1196 insertions, 38 deletions
diff --git a/Makefile b/Makefile
index e02b4f5..381cf04 100644
--- a/Makefile
+++ b/Makefile
@@ -12,43 +12,61 @@ FLATBUFFERS_CFLAGS = -Iflatbuffers-build/include -Wno-misleading-indentation
FLATBUFFERS_FLATC = flatbuffers-flatcc-build/bin/flatcc
DRIVER0_NAME = driver
-DRIVER0_OBJECTS = examples/$(DRIVER0_NAME).o ksocket/ksocket.o ksocket/berkeley.o
+DRIVER0_OBJECTS = examples/$(DRIVER0_NAME).opp ksocket/ksocket.o ksocket/berkeley.o ksocket/ksocket.opp ksocket/utils.o
DRIVER0_TARGET = $(DRIVER0_NAME).sys
DRIVER1_NAME = driver-protobuf-c
-DRIVER1_OBJECTS = examples/$(DRIVER1_NAME).o examples/example.pb-c.o $(PROTOBUF_OBJECT)
+DRIVER1_OBJECTS = examples/$(DRIVER1_NAME).opp ksocket/utils.o examples/example.pb-c.o $(PROTOBUF_OBJECT)
DRIVER1_TARGET = $(DRIVER1_NAME).sys
DRIVER2_NAME = driver-protobuf-c-tcp
-DRIVER2_OBJECTS = examples/$(DRIVER2_NAME).o ksocket/ksocket.o ksocket/berkeley.o examples/example.pb-c.o $(PROTOBUF_OBJECT)
+DRIVER2_OBJECTS = examples/$(DRIVER2_NAME).opp ksocket/ksocket.o ksocket/berkeley.o ksocket/utils.o examples/example.pb-c.o $(PROTOBUF_OBJECT)
DRIVER2_TARGET = $(DRIVER2_NAME).sys
DRIVER3_NAME = driver-flatbuffers
-DRIVER3_OBJECTS = examples/$(DRIVER3_NAME).o ksocket/ksocket.o ksocket/berkeley.o $(FLATBUFFERS_LIB)
+DRIVER3_OBJECTS = examples/$(DRIVER3_NAME).opp ksocket/ksocket.o ksocket/berkeley.o $(FLATBUFFERS_LIB)
DRIVER3_TARGET = $(DRIVER3_NAME).sys
+DRIVER4_NAME = driver-flatbuffers-tcp
+DRIVER4_OBJECTS = examples/$(DRIVER4_NAME).opp ksocket/ksocket.opp ksocket/ksocket.o ksocket/berkeley.o ksocket/utils.o $(FLATBUFFERS_LIB)
+DRIVER4_TARGET = $(DRIVER4_NAME).sys
+
USERSPACE0_NAME = userspace_client
-USERSPACE0_OBJECTS = examples/$(USERSPACE0_NAME).o
+USERSPACE0_OBJECTS = examples/$(USERSPACE0_NAME).opp
USERSPACE0_TARGET = $(USERSPACE0_NAME).exe
USERSPACE1_NAME = userspace_client_protobuf
-USERSPACE1_OBJECTS = examples/$(USERSPACE1_NAME).o examples/example.pb-c.o $(PROTOBUF_OBJECT)
+USERSPACE1_OBJECTS = examples/$(USERSPACE1_NAME).opp examples/example.pb-c.o $(PROTOBUF_OBJECT)
USERSPACE1_TARGET = $(USERSPACE1_NAME).exe
+USERSPACE2_NAME = userspace_client_flatbuffers
+USERSPACE2_OBJECTS = examples/$(USERSPACE2_NAME).opp ksocket/ksocket_user.opp ksocket/utils_user.o $(FLATBUFFERS_LIB)
+USERSPACE2_TARGET = $(USERSPACE2_NAME).exe
+
# mingw-w64-dpp related
+CFLAGS_examples/$(USERSPACE1_NAME).opp = -DBUILD_USERMODE=1
+CFLAGS_examples/$(USERSPACE2_NAME).opp = -DBUILD_USERMODE=1
+CFLAGS_ksocket/utils_user.o = -DBUILD_USERMODE=1
+CFLAGS_ksocket/ksocket_user.opp = -DBUILD_USERMODE=1
CFLAGS_$(PROTOBUF_OBJECT) = $(PROTOBUF_CFLAGS_PRIVATE)
CUSTOM_CFLAGS = -I. -Iexamples -Werror $(FLATBUFFERS_CFLAGS) -Wl,--exclude-all-symbols -DNDEBUG
DRIVER_LIBS += -lnetio
USER_LIBS += -lws2_32
-all: deps $(DRIVER0_TARGET) $(DRIVER1_TARGET) $(DRIVER2_TARGET) $(DRIVER3_TARGET) $(USERSPACE0_TARGET) $(USERSPACE1_TARGET)
+all: deps $(DRIVER0_TARGET) $(DRIVER1_TARGET) $(DRIVER2_TARGET) $(DRIVER3_TARGET) $(DRIVER4_TARGET) $(USERSPACE0_TARGET) $(USERSPACE1_TARGET) $(USERSPACE2_TARGET)
-%.o: %.cpp
+%.opp: %.cpp
$(call BUILD_CPP_OBJECT,$<,$@)
%.o: %.c
$(call BUILD_C_OBJECT,$<,$@)
+ksocket/utils_user.o: ksocket/utils.c
+ $(call BUILD_C_OBJECT,$<,$@)
+
+ksocket/ksocket_user.opp: ksocket/ksocket.cpp
+ $(call BUILD_CPP_OBJECT,$<,$@)
+
$(PROTOBUF_OBJECT): protobuf-c/protobuf-c.c
mkdir -p '$(dir $(PROTOBUF_OBJECT))'
$(call BUILD_C_OBJECT,$<,$@)
@@ -65,12 +83,18 @@ $(DRIVER2_TARGET): $(DRIVER2_OBJECTS)
$(DRIVER3_TARGET): $(FLATBUFFERS_LIB) $(DRIVER3_OBJECTS)
$(call LINK_CPP_KERNEL_TARGET,$(DRIVER3_OBJECTS),$@)
+$(DRIVER4_TARGET): $(FLATBUFFERS_LIB) $(DRIVER4_OBJECTS)
+ $(call LINK_CPP_KERNEL_TARGET,$(DRIVER4_OBJECTS),$@)
+
$(USERSPACE0_TARGET): $(USERSPACE0_OBJECTS)
$(call LINK_CPP_USER_TARGET,$(USERSPACE0_OBJECTS),$@)
$(USERSPACE1_TARGET): $(USERSPACE1_OBJECTS)
$(call LINK_CPP_USER_TARGET,$(USERSPACE1_OBJECTS),$@)
+$(USERSPACE2_TARGET): $(USERSPACE2_OBJECTS)
+ $(call LINK_CPP_USER_TARGET,$(USERSPACE2_OBJECTS),$@)
+
deps: $(FLATBUFFERS_LIB) $(FLATBUFFERS_FLATC)
$(FLATBUFFERS_LIB):
@@ -98,6 +122,7 @@ $(FLATBUFFERS_FLATC):
generate: $(FLATBUFFERS_FLATC)
@echo 'Generating flatbuffer files..'
$(FLATBUFFERS_FLATC) -a -o examples examples/monster.fbs
+ $(FLATBUFFERS_FLATC) -a -o examples examples/apiwrapper.fbs
@echo '=========================================='
@echo '= You need protobuf-c to make this work! ='
@echo '=========================================='
@@ -105,26 +130,31 @@ generate: $(FLATBUFFERS_FLATC)
@echo
protoc-c --c_out=. examples/example.proto
-install: $(DRIVER0_TARGET) $(DRIVER1_TARGET) $(DRIVER2_TARGET) $(DRIVER3_TARGET) $(USERSPACE0_TARGET) $(USERSPACE1_TARGET)
+install: $(DRIVER0_TARGET) $(DRIVER1_TARGET) $(DRIVER2_TARGET) $(DRIVER3_TARGET) $(DRIVER4_TARGET) $(USERSPACE0_TARGET) $(USERSPACE1_TARGET) $(USERSPACE2_TARGET)
$(call INSTALL_EXEC_SIGN,$(DRIVER0_TARGET))
$(call INSTALL_EXEC_SIGN,$(DRIVER1_TARGET))
$(call INSTALL_EXEC_SIGN,$(DRIVER2_TARGET))
$(call INSTALL_EXEC_SIGN,$(DRIVER3_TARGET))
+ $(call INSTALL_EXEC_SIGN,$(DRIVER4_TARGET))
$(call INSTALL_EXEC,$(USERSPACE0_TARGET))
$(call INSTALL_EXEC,$(USERSPACE1_TARGET))
+ $(call INSTALL_EXEC,$(USERSPACE2_TARGET))
$(INSTALL) 'examples/$(DRIVER0_NAME).bat' '$(DESTDIR)/'
$(INSTALL) 'examples/$(DRIVER1_NAME).bat' '$(DESTDIR)/'
$(INSTALL) 'examples/$(DRIVER2_NAME).bat' '$(DESTDIR)/'
$(INSTALL) 'examples/$(DRIVER3_NAME).bat' '$(DESTDIR)/'
+ $(INSTALL) 'examples/$(DRIVER4_NAME).bat' '$(DESTDIR)/'
clean:
+ rm -f examples/*.o examples/*.opp
rm -f $(DRIVER0_OBJECTS) $(DRIVER1_OBJECTS) $(DRIVER2_OBJECTS) $(DRIVER3_OBJECTS)
rm -f $(DRIVER0_TARGET) $(DRIVER0_TARGET).map \
$(DRIVER1_TARGET) $(DRIVER1_TARGET).map \
$(DRIVER2_TARGET) $(DRIVER2_TARGET).map \
- $(DRIVER3_TARGET) $(DRIVER3_TARGET).map
- rm -f $(USERSPACE0_OBJECTS) $(USERSPACE1_OBJECTS)
- rm -f $(USERSPACE0_TARGET) $(USERSPACE1_TARGET)
+ $(DRIVER3_TARGET) $(DRIVER3_TARGET).map \
+ $(DRIVER4_TARGET) $(DRIVER4_TARGET).map
+ rm -f $(USERSPACE0_OBJECTS) $(USERSPACE1_OBJECTS) $(USERSPACE2_OBJECTS)
+ rm -f $(USERSPACE0_TARGET) $(USERSPACE1_TARGET) $(USERSPACE2_TARGET)
distclean: clean
rm -rf flatbuffers-build flatbuffers-flatcc-build
diff --git a/examples/apiwrapper.fbs b/examples/apiwrapper.fbs
new file mode 100644
index 0000000..6de1330
--- /dev/null
+++ b/examples/apiwrapper.fbs
@@ -0,0 +1,6 @@
+table FunctionAddresses {
+ names:[string];
+ addrs:[uint64];
+}
+
+root_type FunctionAddresses;
diff --git a/examples/apiwrapper_builder.h b/examples/apiwrapper_builder.h
new file mode 100644
index 0000000..f0e0c1b
--- /dev/null
+++ b/examples/apiwrapper_builder.h
@@ -0,0 +1,55 @@
+#ifndef APIWRAPPER_BUILDER_H
+#define APIWRAPPER_BUILDER_H
+
+/* Generated by flatcc 0.6.2 FlatBuffers schema compiler for C by dvide.com */
+
+#ifndef APIWRAPPER_READER_H
+#include "apiwrapper_reader.h"
+#endif
+#ifndef FLATBUFFERS_COMMON_BUILDER_H
+#include "flatbuffers_common_builder.h"
+#endif
+#include "flatcc/flatcc_prologue.h"
+#ifndef flatbuffers_identifier
+#define flatbuffers_identifier 0
+#endif
+#ifndef flatbuffers_extension
+#define flatbuffers_extension "bin"
+#endif
+
+static const flatbuffers_voffset_t __FunctionAddresses_required[] = { 0 };
+typedef flatbuffers_ref_t FunctionAddresses_ref_t;
+static FunctionAddresses_ref_t FunctionAddresses_clone(flatbuffers_builder_t *B, FunctionAddresses_table_t t);
+__flatbuffers_build_table(flatbuffers_, FunctionAddresses, 2)
+
+#define __FunctionAddresses_formal_args , flatbuffers_string_vec_ref_t v0, flatbuffers_uint64_vec_ref_t v1
+#define __FunctionAddresses_call_args , v0, v1
+static inline FunctionAddresses_ref_t FunctionAddresses_create(flatbuffers_builder_t *B __FunctionAddresses_formal_args);
+__flatbuffers_build_table_prolog(flatbuffers_, FunctionAddresses, FunctionAddresses_file_identifier, FunctionAddresses_type_identifier)
+
+__flatbuffers_build_string_vector_field(0, flatbuffers_, FunctionAddresses_names, FunctionAddresses)
+__flatbuffers_build_vector_field(1, flatbuffers_, FunctionAddresses_addrs, flatbuffers_uint64, uint64_t, FunctionAddresses)
+
+static inline FunctionAddresses_ref_t FunctionAddresses_create(flatbuffers_builder_t *B __FunctionAddresses_formal_args)
+{
+ if (FunctionAddresses_start(B)
+ || FunctionAddresses_names_add(B, v0)
+ || FunctionAddresses_addrs_add(B, v1)) {
+ return 0;
+ }
+ return FunctionAddresses_end(B);
+}
+
+static FunctionAddresses_ref_t FunctionAddresses_clone(flatbuffers_builder_t *B, FunctionAddresses_table_t t)
+{
+ __flatbuffers_memoize_begin(B, t);
+ if (FunctionAddresses_start(B)
+ || FunctionAddresses_names_pick(B, t)
+ || FunctionAddresses_addrs_pick(B, t)) {
+ return 0;
+ }
+ __flatbuffers_memoize_end(B, t, FunctionAddresses_end(B));
+}
+
+#include "flatcc/flatcc_epilogue.h"
+#endif /* APIWRAPPER_BUILDER_H */
diff --git a/examples/apiwrapper_reader.h b/examples/apiwrapper_reader.h
new file mode 100644
index 0000000..f7e8253
--- /dev/null
+++ b/examples/apiwrapper_reader.h
@@ -0,0 +1,54 @@
+#ifndef APIWRAPPER_READER_H
+#define APIWRAPPER_READER_H
+
+/* Generated by flatcc 0.6.2 FlatBuffers schema compiler for C by dvide.com */
+
+#ifndef FLATBUFFERS_COMMON_READER_H
+#include "flatbuffers_common_reader.h"
+#endif
+#include "flatcc/flatcc_flatbuffers.h"
+#ifndef __alignas_is_defined
+#include <stdalign.h>
+#endif
+#include "flatcc/flatcc_prologue.h"
+#ifndef flatbuffers_identifier
+#define flatbuffers_identifier 0
+#endif
+#ifndef flatbuffers_extension
+#define flatbuffers_extension "bin"
+#endif
+
+
+typedef const struct FunctionAddresses_table *FunctionAddresses_table_t;
+typedef struct FunctionAddresses_table *FunctionAddresses_mutable_table_t;
+typedef const flatbuffers_uoffset_t *FunctionAddresses_vec_t;
+typedef flatbuffers_uoffset_t *FunctionAddresses_mutable_vec_t;
+#ifndef FunctionAddresses_file_identifier
+#define FunctionAddresses_file_identifier 0
+#endif
+/* deprecated, use FunctionAddresses_file_identifier */
+#ifndef FunctionAddresses_identifier
+#define FunctionAddresses_identifier 0
+#endif
+#define FunctionAddresses_type_hash ((flatbuffers_thash_t)0x73e9a2df)
+#define FunctionAddresses_type_identifier "\xdf\xa2\xe9\x73"
+#ifndef FunctionAddresses_file_extension
+#define FunctionAddresses_file_extension "bin"
+#endif
+
+
+
+struct FunctionAddresses_table { uint8_t unused__; };
+
+static inline size_t FunctionAddresses_vec_len(FunctionAddresses_vec_t vec)
+__flatbuffers_vec_len(vec)
+static inline FunctionAddresses_table_t FunctionAddresses_vec_at(FunctionAddresses_vec_t vec, size_t i)
+__flatbuffers_offset_vec_at(FunctionAddresses_table_t, vec, i, 0)
+__flatbuffers_table_as_root(FunctionAddresses)
+
+__flatbuffers_define_vector_field(0, FunctionAddresses, names, flatbuffers_string_vec_t, 0)
+__flatbuffers_define_vector_field(1, FunctionAddresses, addrs, flatbuffers_uint64_vec_t, 0)
+
+
+#include "flatcc/flatcc_epilogue.h"
+#endif /* APIWRAPPER_READER_H */
diff --git a/examples/apiwrapper_verifier.h b/examples/apiwrapper_verifier.h
new file mode 100644
index 0000000..fc88b41
--- /dev/null
+++ b/examples/apiwrapper_verifier.h
@@ -0,0 +1,43 @@
+#ifndef APIWRAPPER_VERIFIER_H
+#define APIWRAPPER_VERIFIER_H
+
+/* Generated by flatcc 0.6.2 FlatBuffers schema compiler for C by dvide.com */
+
+#ifndef APIWRAPPER_READER_H
+#include "apiwrapper_reader.h"
+#endif
+#include "flatcc/flatcc_verifier.h"
+#include "flatcc/flatcc_prologue.h"
+
+static int FunctionAddresses_verify_table(flatcc_table_verifier_descriptor_t *td);
+
+static int FunctionAddresses_verify_table(flatcc_table_verifier_descriptor_t *td)
+{
+ int ret;
+ if ((ret = flatcc_verify_string_vector_field(td, 0, 0) /* names */)) return ret;
+ if ((ret = flatcc_verify_vector_field(td, 1, 0, 8, 8, INT64_C(536870911)) /* addrs */)) return ret;
+ return flatcc_verify_ok;
+}
+
+static inline int FunctionAddresses_verify_as_root(const void *buf, size_t bufsiz)
+{
+ return flatcc_verify_table_as_root(buf, bufsiz, FunctionAddresses_identifier, &FunctionAddresses_verify_table);
+}
+
+static inline int FunctionAddresses_verify_as_typed_root(const void *buf, size_t bufsiz)
+{
+ return flatcc_verify_table_as_root(buf, bufsiz, FunctionAddresses_type_identifier, &FunctionAddresses_verify_table);
+}
+
+static inline int FunctionAddresses_verify_as_root_with_identifier(const void *buf, size_t bufsiz, const char *fid)
+{
+ return flatcc_verify_table_as_root(buf, bufsiz, fid, &FunctionAddresses_verify_table);
+}
+
+static inline int FunctionAddresses_verify_as_root_with_type_hash(const void *buf, size_t bufsiz, flatbuffers_thash_t thash)
+{
+ return flatcc_verify_table_as_typed_root(buf, bufsiz, thash, &FunctionAddresses_verify_table);
+}
+
+#include "flatcc/flatcc_epilogue.h"
+#endif /* APIWRAPPER_VERIFIER_H */
diff --git a/examples/driver-flatbuffers-tcp.bat b/examples/driver-flatbuffers-tcp.bat
new file mode 100644
index 0000000..00d5d8c
--- /dev/null
+++ b/examples/driver-flatbuffers-tcp.bat
@@ -0,0 +1,27 @@
+@echo off
+set SERVICE_NAME=flatbuffers-tcp
+set DRIVER="%~dp0\driver-flatbuffers-tcp.sys"
+
+net session >nul 2>&1
+if NOT %ERRORLEVEL% EQU 0 (
+ echo ERROR: This script requires Administrator privileges!
+ pause
+ exit /b 1
+)
+
+echo ---------------------------------------
+echo -- Service Name: %SERVICE_NAME%
+echo -- Driver......: %DRIVER%
+echo ---------------------------------------
+
+sc create %SERVICE_NAME% binPath= %DRIVER% type= kernel
+echo ---------------------------------------
+sc start %SERVICE_NAME%
+echo ---------------------------------------
+sc query %SERVICE_NAME%
+echo [PRESS A KEY TO STOP THE DRIVER]
+pause
+sc stop %SERVICE_NAME%
+sc delete %SERVICE_NAME%
+echo Done.
+timeout /t 3
diff --git a/examples/driver-flatbuffers-tcp.cpp b/examples/driver-flatbuffers-tcp.cpp
new file mode 100644
index 0000000..00b200c
--- /dev/null
+++ b/examples/driver-flatbuffers-tcp.cpp
@@ -0,0 +1,108 @@
+#include <ksocket/berkeley.h>
+#include <ksocket/helper.hpp>
+#include <ksocket/ksocket.hpp>
+#include <ksocket/ksocket.h>
+#include <ksocket/wsk.h>
+
+#include "apiwrapper_builder.h"
+#include "apiwrapper_reader.h"
+#include "apiwrapper_verifier.h"
+
+extern "C" {
+DRIVER_INITIALIZE DriverEntry;
+DRIVER_UNLOAD DriverUnload;
+
+#define DebuggerPrint(...) \
+ DbgPrintEx(DPFLTR_IHVDRIVER_ID, DPFLTR_ERROR_LEVEL, __VA_ARGS__);
+
+NTSTATUS
+NTAPI
+DriverEntry(_In_ PDRIVER_OBJECT DriverObject,
+ _In_ PUNICODE_STRING RegistryPath) {
+ UNREFERENCED_PARAMETER(DriverObject);
+ UNREFERENCED_PARAMETER(RegistryPath);
+
+ NTSTATUS Status;
+
+ KSocketBuffer buf;
+ buf.insert(buf.end(), static_cast<uint16_t>(0x1122));
+ buf.insert(buf.end(), static_cast<uint32_t>(0xFFFFFFFF));
+ buf.insert(buf.end(), "AAAAAAAA");
+ DebuggerPrint("HEX: %s\n", buf.toHex().c_str());
+
+ DebuggerPrint("Hi.\n");
+ Status = KsInitialize();
+
+ if (!NT_SUCCESS(Status)) {
+ return Status;
+ }
+
+ int server_sockfd = socket_listen(AF_INET, SOCK_STREAM, 0);
+
+ struct sockaddr_in addr;
+ addr.sin_family = AF_INET;
+ addr.sin_addr.s_addr = INADDR_ANY;
+ addr.sin_port = htons(9096);
+
+ int result = bind(server_sockfd, (struct sockaddr *)&addr, sizeof(addr));
+ if (result != 0) {
+ DebuggerPrint("TCP server bind failed\n");
+ return STATUS_FAILED_DRIVER_ENTRY;
+ }
+
+ result = listen(server_sockfd, 1);
+ if (result != 0) {
+ DebuggerPrint("TCP server listen failed\n");
+ return STATUS_FAILED_DRIVER_ENTRY;
+ }
+
+ socklen_t addrlen = sizeof(addr);
+ int client_sockfd = accept(server_sockfd, (struct sockaddr *)&addr, &addrlen);
+ if (client_sockfd < 0) {
+ DebuggerPrint("TCP accept failed\n");
+ return STATUS_FAILED_DRIVER_ENTRY;
+ }
+
+ int iResult;
+ SocketBuffer<1024> sb_send, sb_recv;
+
+ do {
+ RECV_PDU_BEGIN(client_sockfd, sb_recv, iResult, pdu_type, pdu_len) {
+ DebuggerPrint("PDU type/len: %u/%u\n", pdu_type, pdu_len);
+ if (pdu_type == 0) {
+ int ret = FunctionAddresses_verify_as_root(sb_recv.GetStart(), pdu_len);
+
+ if (ret == 0) {
+ FunctionAddresses_table_t fnaddr = FunctionAddresses_as_root(sb_recv.GetStart());
+
+ if (!fnaddr) {
+ DebuggerPrint("%s\n", "FunctionAddresses not available!");
+ } else {
+ flatbuffers_string_vec_t names = FunctionAddresses_names(fnaddr);
+ size_t name_size = flatbuffers_string_vec_len(names);
+
+ DebuggerPrint("Length of names vector: %zu\n", name_size);
+ }
+ } else {
+ DebuggerPrint("Flatbuffer verification failed with %d: %s\n", ret, flatcc_verify_error_string(ret));
+ }
+ } else {
+ DebuggerPrint("%s\n", "PDU type not supported!");
+ }
+ }
+ RECV_PDU_END(sb_recv, pdu_len);
+ } while (iResult != SOCKET_ERROR && iResult > 0);
+
+ DebuggerPrint("Client gone.\n") closesocket(client_sockfd);
+ closesocket(server_sockfd);
+ KsDestroy();
+
+ return STATUS_SUCCESS;
+}
+
+VOID DriverUnload(_In_ struct _DRIVER_OBJECT *DriverObject) {
+ UNREFERENCED_PARAMETER(DriverObject);
+
+ DebuggerPrint("Bye.\n");
+}
+}
diff --git a/examples/driver-protobuf-c-tcp.cpp b/examples/driver-protobuf-c-tcp.cpp
index fcf4465..a7071ed 100644
--- a/examples/driver-protobuf-c-tcp.cpp
+++ b/examples/driver-protobuf-c-tcp.cpp
@@ -1,6 +1,7 @@
#include <ksocket/berkeley.h>
#include <ksocket/helper.hpp>
#include <ksocket/ksocket.h>
+#include <ksocket/utils.h>
#include <ksocket/wsk.h>
#include "examples/common.hpp"
diff --git a/examples/driver-protobuf-c.cpp b/examples/driver-protobuf-c.cpp
index c3c8445..bbd0b52 100644
--- a/examples/driver-protobuf-c.cpp
+++ b/examples/driver-protobuf-c.cpp
@@ -1,6 +1,7 @@
#include <ksocket/berkeley.h>
#include <ksocket/helper.hpp>
#include <ksocket/ksocket.h>
+#include <ksocket/utils.h>
#include <ksocket/wsk.h>
#include "examples/common.hpp"
diff --git a/examples/driver.cpp b/examples/driver.cpp
index 86e42e9..228acc7 100644
--- a/examples/driver.cpp
+++ b/examples/driver.cpp
@@ -1,9 +1,11 @@
-
-extern "C" {
#include <ksocket/berkeley.h>
#include <ksocket/ksocket.h>
+#include <ksocket/utils.h>
#include <ksocket/wsk.h>
+#include <ksocket/ksocket.hpp>
+
+extern "C" {
DRIVER_INITIALIZE DriverEntry;
DRIVER_UNLOAD DriverUnload;
@@ -34,15 +36,15 @@ DriverEntry(_In_ PDRIVER_OBJECT DriverObject,
// Perform HTTP request to http://httpbin.org/uuid
//
+ const char send_buffer[] = "GET /uuid HTTP/1.1\r\n"
+ "Host: httpbin.org\r\n"
+ "Connection: close\r\n"
+ "\r\n";
+
{
int result;
UNREFERENCED_PARAMETER(result);
- char send_buffer[] = "GET /uuid HTTP/1.1\r\n"
- "Host: httpbin.org\r\n"
- "Connection: close\r\n"
- "\r\n";
-
char recv_buffer[1024] = {};
struct addrinfo hints = {};
@@ -66,7 +68,7 @@ DriverEntry(_In_ PDRIVER_OBJECT DriverObject,
return STATUS_FAILED_DRIVER_ENTRY;
}
- result = send(sockfd, send_buffer, sizeof(send_buffer), 0);
+ result = send(sockfd, send_buffer, sizeof(send_buffer) - 1, 0);
if (result <= 0) {
DebuggerPrint("TCP client send failed\n");
return STATUS_FAILED_DRIVER_ENTRY;
@@ -83,6 +85,48 @@ DriverEntry(_In_ PDRIVER_OBJECT DriverObject,
closesocket(sockfd);
}
+ {
+ KStreamClientIp4 tcp4_client = KStreamClientIp4();
+
+ if (!tcp4_client.setup()) {
+ DebuggerPrint("KStreamClientIp4 setup() failed: %d\n",
+ tcp4_client.getLastError());
+ return STATUS_FAILED_DRIVER_ENTRY;
+ }
+
+ if (!tcp4_client.connect("httpbin.org", "80")) {
+ DebuggerPrint("KStreamClientIp4 connect() failed: %d\n",
+ tcp4_client.getLastError());
+ return STATUS_FAILED_DRIVER_ENTRY;
+ }
+
+ DebuggerPrint("%s\n", "KStreamClientIp4 connected!");
+
+ tcp4_client.getSendBuffer().insert(tcp4_client.getSendBuffer().end(),
+ send_buffer, sizeof(send_buffer) - 1);
+ if (!tcp4_client.send()) {
+ DebuggerPrint("KStreamClientIp4 send() failed: %d\n",
+ tcp4_client.getLastError());
+ return STATUS_FAILED_DRIVER_ENTRY;
+ }
+
+ if (!tcp4_client.recv()) {
+ DebuggerPrint("KStreamClientIp4 recv() failed: %d\n",
+ tcp4_client.getLastError());
+ return STATUS_FAILED_DRIVER_ENTRY;
+ }
+
+ DebuggerPrint("KStreamClientIp4 data received:\n%s\n",
+ tcp4_client.getRecvBuffer().to_string().c_str());
+ DebuggerPrint("KStreamClientIp4 consuming %zu bytes\n",
+ tcp4_client.getRecvBuffer().size());
+ tcp4_client.getRecvBuffer().consume();
+ DebuggerPrint("KStreamClientIp4 receive buffer size: %zu\n",
+ tcp4_client.getRecvBuffer().size());
+
+ DebuggerPrint("%s\n", "KStreamClientIp4 finished.");
+ }
+
//
// TCP server.
// Listen on port 9095, wait for some message,
@@ -114,6 +158,10 @@ DriverEntry(_In_ PDRIVER_OBJECT DriverObject,
return STATUS_FAILED_DRIVER_ENTRY;
}
+ DebuggerPrint(
+ "%s\n",
+ "TCP server is waiting for the user to start userspace_client.exe");
+
socklen_t addrlen = sizeof(addr);
int client_sockfd =
accept(server_sockfd, (struct sockaddr *)&addr, &addrlen);
@@ -124,7 +172,7 @@ DriverEntry(_In_ PDRIVER_OBJECT DriverObject,
result = recv(client_sockfd, recv_buffer, sizeof(recv_buffer) - 1, 0);
if (result > 0) {
- DebuggerPrint("TCP server:\n%.*s\n", result, recv_buffer);
+ DebuggerPrint("TCP server received: \"%.*s\"\n", result, recv_buffer);
} else {
DebuggerPrint("TCP server recv failed\n");
}
@@ -139,6 +187,69 @@ DriverEntry(_In_ PDRIVER_OBJECT DriverObject,
closesocket(server_sockfd);
}
+ {
+ KStreamServerIp4 tcp4_server = KStreamServerIp4();
+
+ if (!tcp4_server.setup()) {
+ DebuggerPrint("KStreamServerIp4 setup() failed: %d\n",
+ tcp4_server.getLastError());
+ return STATUS_FAILED_DRIVER_ENTRY;
+ }
+
+ if (!tcp4_server.bind(9095)) {
+ DebuggerPrint("KStreamServerIp4 bind() failed: %d\n",
+ tcp4_server.getLastError());
+ return STATUS_FAILED_DRIVER_ENTRY;
+ }
+
+ if (!tcp4_server.listen()) {
+ DebuggerPrint("KStreamServerIp4 bind() failed: %d\n",
+ tcp4_server.getLastError());
+ return STATUS_FAILED_DRIVER_ENTRY;
+ }
+
+ DebuggerPrint("%s\n", "KStreamServerIp4 listening for incomining "
+ "connections (run userspace_client.exe again)..");
+
+ const auto &accept_fn = [](KAcceptedSocket &ka) {
+ const auto &remote = ka.getRemote();
+
+ if (remote.addr_used != 4) {
+ return false;
+ }
+ DebuggerPrint("KStreamServerIp4 client connected: %s\n",
+ remote.to_string().c_str());
+
+ if (!ka.recv()) {
+ DebuggerPrint("KStreamServerIp4 recv failed: %d\n", ka.getLastError());
+ return false;
+ }
+ DebuggerPrint("KStreamServerIp4 received %zu bytes: \"%s\"\n",
+ ka.getRecvBuffer().size(),
+ ka.getRecvBuffer().to_string().c_str());
+ ka.getRecvBuffer().consume();
+
+ ka.getSendBuffer().insert_string(ka.getSendBuffer().end(),
+ "KStreamServerIp4 says hello!");
+ if (!ka.send()) {
+ DebuggerPrint("KStreamServerIp4 send failed: %d\n", ka.getLastError());
+ return false;
+ }
+ ka.getSendBuffer().consume();
+
+ // Wait for the connection termination.
+ ka.recv();
+
+ return true;
+ };
+ if (!tcp4_server.accept(accept_fn)) {
+ DebuggerPrint("KStreamServerIp4 accept() failed: %d\n",
+ tcp4_server.getLastError());
+ return STATUS_FAILED_DRIVER_ENTRY;
+ }
+ DebuggerPrint("KStreamServerIp4 done\n");
+ }
+
KsDestroy();
return STATUS_SUCCESS;
diff --git a/examples/userspace_client_flatbuffers.cpp b/examples/userspace_client_flatbuffers.cpp
new file mode 100644
index 0000000..9f43698
--- /dev/null
+++ b/examples/userspace_client_flatbuffers.cpp
@@ -0,0 +1,82 @@
+#include <stdio.h>
+#include <stdlib.h> // Needed for _wtoi
+#include <winsock2.h>
+#include <ws2tcpip.h>
+
+#include <ksocket/ksocket.hpp>
+#include <ksocket/helper.hpp>
+
+#include "apiwrapper_builder.h"
+#include "apiwrapper_reader.h"
+#include "apiwrapper_verifier.h"
+
+int main(int argc, char **argv) {
+ WSADATA wsaData = {};
+ int iResult = WSAStartup(MAKEWORD(2, 2), &wsaData);
+
+ UNREFERENCED_PARAMETER(argc);
+ UNREFERENCED_PARAMETER(argv);
+
+ if (iResult != 0) {
+ wprintf(L"WSAStartup failed: %d\n", iResult);
+ return 1;
+ }
+
+ SOCKET sock = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
+
+ if (sock == INVALID_SOCKET) {
+ wprintf(L"socket function failed with error = %d\n", WSAGetLastError());
+ } else {
+ wprintf(L"socket function succeeded\n");
+ }
+
+ sockaddr_in clientService;
+ clientService.sin_family = AF_INET;
+ clientService.sin_addr.s_addr = inet_addr("127.0.0.1");
+ clientService.sin_port = htons(9096);
+
+ do {
+ iResult = connect(sock, (SOCKADDR *)&clientService, sizeof(clientService));
+ if (iResult == SOCKET_ERROR) {
+ wprintf(L"connect function failed with error: %ld\n", WSAGetLastError());
+ Sleep(1000);
+ }
+ } while (iResult == SOCKET_ERROR);
+
+ wprintf(L"Connected to server.\n");
+
+ flatcc_builder_t builder;
+ flatcc_builder_init(&builder);
+ for (size_t i = 0; i < 256; ++i) {
+ FunctionAddresses_start_as_root(&builder);
+ FunctionAddresses_names_add(&builder, flatbuffers_string_create_str(&builder, "A"));
+ FunctionAddresses_names_add(&builder, flatbuffers_string_create_str(&builder, "B"));
+ FunctionAddresses_names_add(&builder, flatbuffers_string_create_str(&builder, "C"));
+ FunctionAddresses_end_as_root(&builder);
+
+ KSocketBuffer buffer;
+ void *buf;
+ size_t siz;
+ buf = flatcc_builder_finalize_aligned_buffer(&builder, &siz);
+ (void)buf;
+ uint8_t a[] = {0x41,0x41,0x41};
+ buffer.insert_u16(buffer.begin(), 65535);
+ buffer.insert_bytebuffer(buffer.begin(), a, 3);
+ }
+
+ wprintf(L"Closing Connection ..\n");
+
+ iResult = closesocket(sock);
+ if (iResult == SOCKET_ERROR) {
+ wprintf(L"closesocket function failed with error: %ld\n",
+ WSAGetLastError());
+ WSACleanup();
+ return 1;
+ }
+
+ WSACleanup();
+
+ system("pause");
+
+ return 0;
+}
diff --git a/examples/userspace_client_protobuf.cpp b/examples/userspace_client_protobuf.cpp
index a387579..1eadc0c 100644
--- a/examples/userspace_client_protobuf.cpp
+++ b/examples/userspace_client_protobuf.cpp
@@ -4,6 +4,7 @@
#include <ws2tcpip.h>
#include <ksocket/helper.hpp>
+#include <ksocket/utils.h>
#include "examples/common.hpp"
#include "examples/example.pb-c.h"
diff --git a/ksocket/berkeley.c b/ksocket/berkeley.c
index e72f92e..73ccb79 100644
--- a/ksocket/berkeley.c
+++ b/ksocket/berkeley.c
@@ -297,14 +297,6 @@ VOID NTAPI KspUtilFreeAddrInfoEx(_In_ PADDRINFOEXW AddrInfo) {
// Public functions.
//////////////////////////////////////////////////////////////////////////
-uint32_t htonl(uint32_t hostlong) { return __builtin_bswap32(hostlong); }
-
-uint16_t htons(uint16_t hostshort) { return __builtin_bswap16(hostshort); }
-
-uint32_t ntohl(uint32_t netlong) { return __builtin_bswap32(netlong); }
-
-uint16_t ntohs(uint16_t netshort) { return __builtin_bswap16(netshort); }
-
int getaddrinfo(const char *node, const char *service,
const struct addrinfo *hints, struct addrinfo **res) {
NTSTATUS Status;
diff --git a/ksocket/berkeley.h b/ksocket/berkeley.h
index b879d3c..1475a89 100644
--- a/ksocket/berkeley.h
+++ b/ksocket/berkeley.h
@@ -1,6 +1,11 @@
-#pragma once
+#ifndef KSOCKET_BERKELEY_H
+#define KSOCKET_BERKELEY_H 1
+
+#ifdef BUILD_USERMODE
+#error "This file should only be included if building for kernel mode! Include <ksocket/ksocket.hpp> wrapper instead."
+#endif
+
#include <ntddk.h>
-#include <stdint.h>
#include <ksocket/wsk.h>
#define socket socket_connection
@@ -12,11 +17,6 @@ extern "C" {
typedef int socklen_t;
typedef intptr_t ssize_t;
-uint32_t htonl(uint32_t hostlong);
-uint16_t htons(uint16_t hostshort);
-uint32_t ntohl(uint32_t netlong);
-uint16_t ntohs(uint16_t netshort);
-
int getaddrinfo(const char *node, const char *service,
const struct addrinfo *hints, struct addrinfo **res);
void freeaddrinfo(struct addrinfo *res);
@@ -39,3 +39,5 @@ int closesocket(int sockfd);
#ifdef __cplusplus
}
#endif
+
+#endif
diff --git a/ksocket/helper.hpp b/ksocket/helper.hpp
index 153e549..9143b07 100644
--- a/ksocket/helper.hpp
+++ b/ksocket/helper.hpp
@@ -11,6 +11,8 @@
#include <EASTL/string.h>
#include <EASTL/vector.h>
+#include <ksocket/utils.h>
+
#ifndef SOCKET_ERROR
#define SOCKET_ERROR -1
#endif
diff --git a/ksocket/ksocket.cpp b/ksocket/ksocket.cpp
new file mode 100644
index 0000000..ca0625b
--- /dev/null
+++ b/ksocket/ksocket.cpp
@@ -0,0 +1,350 @@
+#include "ksocket.hpp"
+
+#include <EASTL/type_traits.h>
+#include <eastl_compat.hpp>
+
+#ifdef BUILD_USERMODE
+// clang-format off
+#include <winsock2.h>
+#include <windows.h>
+#include <ws2tcpip.h>
+// clang-format on
+#else
+#include <ksocket/berkeley.h>
+#include <ksocket/ksocket.h>
+#include <ksocket/wsk.h>
+#endif
+#include <ksocket/utils.h>
+
+struct KSocketImplCommon {
+ ADDRESS_FAMILY domain;
+ int type;
+ int proto;
+};
+
+#ifdef BUILD_USERMODE
+struct KSocketImpl {
+ ~KSocketImpl() { closesocket(s); }
+
+ SOCKET s;
+ KSocketImplCommon c;
+};
+#else
+struct KSocketImpl {
+ ~KSocketImpl() {
+ if (s >= 0) {
+ closesocket(s);
+ s = -1;
+ }
+ }
+
+ int s = -1;
+ KSocketImplCommon c;
+};
+#endif
+
+eastl::string KSocketAddress::to_string(bool with_port) const {
+ if (addr_used != 4) {
+ return "";
+ }
+
+ return ::to_string(addr.u8[0]) + "." + ::to_string(addr.u8[1]) + "." +
+ ::to_string(addr.u8[2]) + "." + ::to_string(addr.u8[3]) + ":" +
+ ::to_string(with_port);
+}
+
+void KSocketBuffer::insert_u8(KBuffer::iterator it, uint8_t value) {
+ buffer.insert(it, value);
+}
+
+void KSocketBuffer::insert_u16(KBuffer::iterator it, uint16_t value) {
+ uint16_t net_value;
+ uint8_t insert_value[2];
+
+ net_value = htons(value);
+ insert_value[0] = (net_value & 0x00FF) >> 0;
+ insert_value[1] = (net_value & 0xFF00) >> 8;
+ buffer.insert(it, insert_value, insert_value + eastl::size(insert_value));
+}
+
+void KSocketBuffer::insert_u32(KBuffer::iterator it, uint32_t value) {
+ uint32_t net_value;
+ uint8_t insert_value[4];
+
+ net_value = htonl(value);
+ insert_value[0] = (net_value & 0x000000FF) >> 0;
+ insert_value[1] = (net_value & 0x0000FF00) >> 8;
+ insert_value[2] = (net_value & 0x00FF0000) >> 16;
+ insert_value[3] = (net_value & 0xFF000000) >> 24;
+ buffer.insert(it, insert_value, insert_value + eastl::size(insert_value));
+}
+
+void KSocketBuffer::insert_u64(KBuffer::iterator it, uint64_t value) {
+ uint64_t net_value;
+ uint8_t insert_value[8];
+
+ net_value = htonll(value);
+ insert_value[0] = (net_value & 0x00000000000000FF) >> 0;
+ insert_value[1] = (net_value & 0x000000000000FF00) >> 8;
+ insert_value[2] = (net_value & 0x0000000000FF0000) >> 16;
+ insert_value[3] = (net_value & 0x00000000FF000000) >> 24;
+ insert_value[4] = (net_value & 0x000000FF00000000) >> 32;
+ insert_value[5] = (net_value & 0x0000FF0000000000) >> 40;
+ insert_value[6] = (net_value & 0x00FF000000000000) >> 48;
+ insert_value[7] = (net_value & 0xFF00000000000000) >> 56;
+ buffer.insert(it, insert_value, insert_value + eastl::size(insert_value));
+}
+
+void KSocketBuffer::insert_bytebuffer(KBuffer::iterator it,
+ const uint8_t bytebuffer[], size_t size) {
+ buffer.insert(it, bytebuffer, bytebuffer + size);
+}
+
+void KSocketBuffer::consume(size_t amount_bytes) {
+ if (amount_bytes == 0 || amount_bytes > buffer.size())
+ amount_bytes = buffer.size();
+
+ buffer.erase(buffer.begin(), buffer.begin() + amount_bytes);
+}
+
+eastl::string KSocketBuffer::toHex(eastl::string delim) {
+ eastl::string str;
+ char const *const hex = "0123456789ABCDEF";
+ char pout[3] = {};
+
+ for (const auto &input_byte : buffer) {
+ pout[0] = hex[(input_byte >> 4) & 0xF];
+ pout[1] = hex[input_byte & 0xF];
+ str += pout;
+ str += delim;
+ }
+
+ if (str.length() >= delim.length() && delim.length() > 0)
+ str.erase(str.length() - delim.length(), delim.length());
+
+ return str;
+}
+
+KSocket::~KSocket() {
+ if (m_socket != nullptr)
+ delete m_socket;
+}
+
+bool KSocket::setup(KSocketType sock_type, int proto) {
+ int domain, type;
+
+ m_lastError = KSE_SUCCESS;
+ if (m_socket != nullptr)
+ delete m_socket;
+ m_socket = new KSocketImpl();
+ if (m_socket == nullptr) {
+ m_lastError = KSE_SETUP_IMPL_NULL;
+ return false;
+ }
+
+ if (KSocket::socketTypeToTuple(sock_type, domain, type) != true) {
+ m_lastError = KSE_SETUP_INVALID_SOCKET_TYPE;
+ return false;
+ }
+ if (sock_type == KSocketType::KST_STREAM_CLIENT_IP6 ||
+ sock_type == KSocketType::KST_STREAM_SERVER_IP6 ||
+ sock_type == KSocketType::KST_DATAGRAM_IP6) {
+ m_lastError = KSE_SETUP_UNSUPPORTED_SOCKET_TYPE;
+ return false; // IPv6 is not supported for now
+ }
+ m_socketType = sock_type;
+
+#ifdef BUILD_USERMODE
+ m_socket->s = ::socket(domain, type, proto);
+#else
+ // DbgPrint("KSocketType: %d, domain: %d, type: %d, proto: %d", sock_type,
+ // domain, type, proto);
+ switch (sock_type) {
+ case KSocketType::KST_INVALID:
+ m_lastError = KSE_SETUP_INVALID_SOCKET_TYPE;
+ return false;
+ case KSocketType::KST_STREAM_CLIENT_IP4:
+ case KSocketType::KST_STREAM_CLIENT_IP6:
+ m_socket->s = ::socket_connection(domain, type, proto);
+ break;
+ case KSocketType::KST_STREAM_SERVER_IP4:
+ case KSocketType::KST_STREAM_SERVER_IP6:
+ m_socket->s = ::socket_listen(domain, type, proto);
+ break;
+ case KSocketType::KST_DATAGRAM_IP4:
+ case KSocketType::KST_DATAGRAM_IP6:
+ m_socket->s = ::socket_datagram(domain, type, proto);
+ break;
+ }
+#endif
+ m_socket->c.domain = static_cast<ADDRESS_FAMILY>(domain);
+ m_socket->c.type = type;
+ m_socket->c.proto = proto;
+#ifdef BUILD_USERMODE
+ return m_socket->s != INVALID_SOCKET;
+#else
+ return m_socket->s >= 0;
+#endif
+}
+
+bool KSocket::connect(eastl::string host, eastl::string port) {
+ struct addrinfo hints = {};
+ struct addrinfo *results;
+
+ m_lastError = KSE_SUCCESS;
+
+ if (m_socket == nullptr)
+ return false;
+
+ if (m_socketType != KSocketType::KST_STREAM_CLIENT_IP4 &&
+ m_socketType != KSocketType::KST_STREAM_CLIENT_IP6)
+ return false;
+
+ hints.ai_flags |= AI_CANONNAME;
+ hints.ai_family = m_socket->c.domain;
+ hints.ai_socktype = m_socket->c.type;
+ m_lastError = ::getaddrinfo(host.c_str(), port.c_str(), &hints, &results);
+ if (m_lastError != KSE_SUCCESS)
+ return false;
+
+ m_lastError =
+ ::connect(m_socket->s, results->ai_addr, (int)results->ai_addrlen);
+ freeaddrinfo(results);
+ return m_lastError == KSE_SUCCESS;
+}
+
+bool KSocket::bind(uint16_t port) {
+ struct sockaddr_in addr;
+
+ addr.sin_family = AF_INET;
+ addr.sin_addr.s_addr = INADDR_ANY;
+ addr.sin_port = htons(port);
+
+ m_lastError = ::bind(m_socket->s, (struct sockaddr *)&addr, sizeof(addr));
+
+ return m_lastError == KSE_SUCCESS;
+}
+
+bool KSocket::listen(int backlog) {
+ m_lastError = KSE_SUCCESS;
+
+ if (m_socketType != KSocketType::KST_STREAM_SERVER_IP4 &&
+ m_socketType != KSocketType::KST_STREAM_SERVER_IP6)
+ return false;
+
+ m_lastError = ::listen(m_socket->s, backlog);
+
+ return m_lastError == KSE_SUCCESS;
+}
+
+bool KSocket::accept(KAcceptThreadCallback thread_callback) {
+ KAcceptedSocket ka;
+ struct sockaddr addr;
+ socklen_t addrlen = sizeof(addr);
+
+ if (m_socket->c.domain != AF_INET) {
+ m_lastError = KSE_SETUP_UNSUPPORTED_SOCKET_TYPE;
+ return false;
+ }
+
+ addr.sa_family = m_socket->c.domain;
+ ka.m_socket = new KSocketImpl();
+ if (ka.m_socket == nullptr) {
+ m_lastError = KSE_SETUP_IMPL_NULL;
+ return false;
+ }
+ ka.m_socketType = m_socketType;
+ ka.m_socket->c = m_socket->c;
+ ka.m_socket->s = ::accept(m_socket->s, &addr, &addrlen);
+
+ if (m_socket->c.domain == AF_INET) {
+ struct sockaddr_in *addr_in = reinterpret_cast<struct sockaddr_in *>(&addr);
+
+ ka.m_remote.addr_used = 4;
+ ka.m_remote.addr.u32[0] = addr_in->sin_addr.s_addr;
+ ka.m_remote.port = addr_in->sin_port;
+ } else {
+ m_lastError = KSE_SETUP_UNSUPPORTED_SOCKET_TYPE;
+ return false;
+ }
+
+#ifdef BUILD_USERMODE
+ if (ka.m_socket->s == INVALID_SOCKET)
+#else
+ if (ka.m_socket->s < 0)
+#endif
+ {
+ m_lastError = KSE_ACCEPT_FAILED;
+ return false;
+ }
+
+ return thread_callback(ka);
+}
+
+bool KSocket::close() {
+ int rv = closesocket(m_socket->s);
+
+ if (rv == 0) {
+ m_socket->s = -1;
+ return true;
+ } else {
+ return false;
+ }
+}
+
+bool KSocket::send() {
+ m_lastError = KSE_SUCCESS;
+
+ if (m_sendBuffer.size() == 0)
+ return false;
+
+ m_lastError =
+ ::send(m_socket->s, reinterpret_cast<const char *>(m_sendBuffer.data()),
+ m_sendBuffer.size(), 0);
+ if (m_lastError > 0) {
+ m_sendBuffer.buffer.erase(m_sendBuffer.begin(),
+ m_sendBuffer.begin() + m_lastError);
+ return true;
+ }
+
+ return false;
+}
+
+bool KSocket::recv(size_t max_recv_size) {
+ const size_t current_size = m_recvBuffer.size();
+
+ m_recvBuffer.buffer.resize(current_size + max_recv_size);
+ m_lastError = ::recv(
+ m_socket->s, reinterpret_cast<char *>(m_recvBuffer.data() + current_size),
+ max_recv_size, 0);
+ if (m_lastError > 0) {
+ m_recvBuffer.buffer.resize(current_size + m_lastError);
+ return true;
+ }
+
+ return false;
+}
+
+bool KSocket::socketTypeToTuple(KSocketType sock_type, int &domain, int &type) {
+ switch (sock_type) {
+ case KSocketType::KST_INVALID:
+ break;
+ case KSocketType::KST_STREAM_CLIENT_IP4:
+ case KSocketType::KST_STREAM_SERVER_IP4:
+ domain = AF_INET;
+ type = SOCK_STREAM;
+ return true;
+ case KSocketType::KST_STREAM_CLIENT_IP6:
+ case KSocketType::KST_STREAM_SERVER_IP6:
+ domain = AF_INET6;
+ type = SOCK_STREAM;
+ return true;
+ case KSocketType::KST_DATAGRAM_IP4:
+ case KSocketType::KST_DATAGRAM_IP6:
+ domain = AF_INET;
+ type = SOCK_DGRAM;
+ return true;
+ }
+
+ return false;
+}
diff --git a/ksocket/ksocket.h b/ksocket/ksocket.h
index 0a6717e..e44035b 100644
--- a/ksocket/ksocket.h
+++ b/ksocket/ksocket.h
@@ -1,5 +1,12 @@
-#pragma once
-#include "wsk.h"
+#ifndef KSOCKET_H
+#define KSOCKET_H 1
+
+#ifdef BUILD_USERMODE
+#error \
+ "This file should only be included if building for kernel mode! Include <ksocket/ksocket.hpp> wrapper instead."
+#endif
+
+#include <ksocket/wsk.h>
#include <ntddk.h>
#define STATUS_UNSUPPORTED_WINDOWS_VERSION \
@@ -97,3 +104,5 @@ KsRecvFrom(_In_ PKSOCKET Socket, _In_ PVOID Buffer, _Inout_ PULONG Length,
#ifdef __cplusplus
}
#endif
+
+#endif
diff --git a/ksocket/ksocket.hpp b/ksocket/ksocket.hpp
new file mode 100644
index 0000000..abf03b6
--- /dev/null
+++ b/ksocket/ksocket.hpp
@@ -0,0 +1,233 @@
+#ifndef KSOCKET_HPP
+#define KSOCKET_HPP 1
+
+#include <EASTL/functional.h>
+#include <EASTL/string.h>
+#include <EASTL/vector.h>
+#include <cstdint>
+
+using KBuffer = eastl::vector<uint8_t>;
+
+struct KSocketImpl;
+
+enum {
+ KSE_SUCCESS = 0,
+ KSE_SETUP_IMPL_NULL = 1,
+ KSE_SETUP_INVALID_SOCKET_TYPE = 2,
+ KSE_SETUP_UNSUPPORTED_SOCKET_TYPE = 3,
+ KSE_ACCEPT_FAILED = 4,
+};
+
+enum class KSocketType {
+ KST_INVALID = 0,
+ KST_STREAM_CLIENT_IP4,
+ KST_STREAM_SERVER_IP4,
+ KST_STREAM_CLIENT_IP6,
+ KST_STREAM_SERVER_IP6,
+ KST_DATAGRAM_IP4,
+ KST_DATAGRAM_IP6
+};
+
+struct KSocketAddress {
+ eastl::string to_string(bool with_port = true) const;
+ uint8_t addr_used = 0;
+ union {
+ uint8_t u8[16];
+ uint16_t u16[8];
+ uint32_t u32[4];
+ uint64_t u64[2];
+ } addr;
+ uint16_t port = 0;
+};
+
+class KAcceptedSocket;
+using KAcceptThreadCallback = eastl::function<bool(KAcceptedSocket &accepted)>;
+
+struct KSocketBuffer {
+ void insert_i8(KBuffer::iterator it, int8_t value) {
+ insert_u8(it, static_cast<uint8_t>(value));
+ }
+ void insert_i16(KBuffer::iterator it, int16_t value) {
+ insert_u16(it, static_cast<uint16_t>(value));
+ }
+ void insert_i32(KBuffer::iterator it, int32_t value) {
+ insert_u32(it, static_cast<uint32_t>(value));
+ }
+ void insert_i64(KBuffer::iterator it, int64_t value) {
+ insert_u64(it, static_cast<uint64_t>(value));
+ }
+
+ void insert_u8(KBuffer::iterator it, uint8_t value);
+ void insert_u16(KBuffer::iterator it, uint16_t value);
+ void insert_u32(KBuffer::iterator it, uint32_t value);
+ void insert_u64(KBuffer::iterator it, uint64_t value);
+
+ void insert_string(KBuffer::iterator it, const eastl::string &value) {
+ insert_bytebuffer(it, reinterpret_cast<const uint8_t *>(value.c_str()),
+ value.length());
+ }
+ void insert_string(KBuffer::iterator it, const char buffer[], size_t size) {
+ insert_bytebuffer(it, reinterpret_cast<const uint8_t *>(buffer), size);
+ }
+ void insert_bytebuffer(KBuffer::iterator it, const void *bytebuffer,
+ size_t size) {
+ insert_bytebuffer(it, reinterpret_cast<const uint8_t *>(bytebuffer), size);
+ }
+ void insert_bytebuffer(KBuffer::iterator it, const uint8_t bytebuffer[],
+ size_t size);
+
+ void insert(KBuffer::iterator it, int8_t value) {
+ insert_u8(it, static_cast<uint8_t>(value));
+ }
+ void insert(KBuffer::iterator it, int16_t value) {
+ insert_u16(it, static_cast<uint16_t>(value));
+ }
+ void insert(KBuffer::iterator it, int32_t value) {
+ insert_u32(it, static_cast<uint32_t>(value));
+ }
+ void insert(KBuffer::iterator it, int64_t value) {
+ insert_u64(it, static_cast<uint64_t>(value));
+ }
+
+ void insert(KBuffer::iterator it, uint8_t value) { insert_u8(it, value); }
+ void insert(KBuffer::iterator it, uint16_t value) { insert_u16(it, value); }
+ void insert(KBuffer::iterator it, uint32_t value) { insert_u32(it, value); }
+ void insert(KBuffer::iterator it, uint64_t value) { insert_u64(it, value); }
+
+ void insert(KBuffer::iterator it, const eastl::string value) {
+ insert_string(it, eastl::move(value));
+ }
+ void insert(KBuffer::iterator it, const char buffer[], size_t size) {
+ insert_string(it, buffer, size);
+ }
+ void insert(KBuffer::iterator it, const char *value) {
+ insert_string(it, eastl::move(eastl::string(value)));
+ }
+ void insert(KBuffer::iterator it, const uint8_t bytebuffer[], size_t size) {
+ insert_bytebuffer(it, bytebuffer, size);
+ }
+
+ void consume(size_t amount_bytes = 0);
+
+ eastl::string to_string() {
+ return eastl::string(reinterpret_cast<const char *>(data()), size());
+ }
+ uint8_t *data() { return buffer.data(); }
+ size_t size() { return buffer.size(); }
+ KBuffer::iterator begin() { return buffer.begin(); }
+ KBuffer::iterator end() { return buffer.end(); }
+
+ eastl::string toHex(eastl::string delim = ":");
+
+ KBuffer buffer;
+};
+
+class KSocket {
+protected:
+ KSocket() : m_sendBuffer(), m_recvBuffer() {}
+ KSocket(const KSocket &) = delete;
+ ~KSocket();
+ bool setup(KSocketType sock_type, int proto = 0);
+
+ bool connect(eastl::string host, eastl::string port);
+ bool bind(uint16_t port);
+ bool listen(int backlog);
+ bool accept(KAcceptThreadCallback thread_callback);
+ bool close();
+
+ bool send();
+ bool recv(size_t max_recv_size);
+
+public:
+ int getLastError() const { return m_lastError; }
+
+ KSocketBuffer &getSendBuffer() { return m_sendBuffer; }
+ KSocketBuffer &getRecvBuffer() { return m_recvBuffer; }
+
+ static bool socketTypeToTuple(KSocketType sock_type, int &domain, int &type);
+
+private:
+ KSocketBuffer m_sendBuffer;
+ KSocketBuffer m_recvBuffer;
+ KSocketType m_socketType = KSocketType::KST_INVALID;
+ KSocketImpl *m_socket = nullptr;
+ int m_lastError = KSE_SUCCESS;
+};
+
+class KAcceptedSocket : public KSocket {
+ friend class KSocket;
+
+public:
+ KAcceptedSocket() : KSocket() {}
+ ~KAcceptedSocket() {}
+ bool setup(KSocketType sock_type, int proto = 0) = delete;
+
+ bool connect(eastl::string host, eastl::string port) = delete;
+ bool bind(uint16_t port) = delete;
+ bool listen(int backlog) = delete;
+ bool accept(KAcceptThreadCallback thread_callback) = delete;
+
+ bool send() { return KSocket::send(); }
+ bool recv(size_t max_recv_size = 65535) {
+ return KSocket::recv(max_recv_size);
+ }
+
+ int getLastError() { return KSocket::getLastError(); }
+
+ const KSocketAddress &getRemote() { return m_remote; }
+
+private:
+ KSocketAddress m_remote;
+};
+
+class KStreamClientIp4 : public KSocket {
+public:
+ KStreamClientIp4() : KSocket() {}
+ ~KStreamClientIp4() {}
+
+ bool setup() {
+ return KSocket::setup(KSocketType::KST_STREAM_CLIENT_IP4,
+ 6 /* IPPROTO_TCP */);
+ }
+
+ bool connect(eastl::string host, eastl::string port) {
+ return KSocket::connect(host, port);
+ }
+ bool bind(uint16_t) = delete;
+ bool listen(int) = delete;
+ bool accept(KAcceptThreadCallback) = delete;
+
+ bool send() { return KSocket::send(); }
+ bool recv(size_t max_recv_size = 65535) {
+ return KSocket::recv(max_recv_size);
+ }
+
+ int getLastError() { return KSocket::getLastError(); }
+};
+
+class KStreamServerIp4 : public KSocket {
+public:
+ KStreamServerIp4() : KSocket() {}
+ ~KStreamServerIp4() {}
+
+ bool setup() {
+ return KSocket::setup(KSocketType::KST_STREAM_SERVER_IP4,
+ 6 /* IPPROTO_TCP */);
+ }
+
+ bool connect(eastl::string host, eastl::string port) = delete;
+ bool bind(uint16_t port) { return KSocket::bind(port); }
+ bool listen(int backlog = 16) { return KSocket::listen(backlog); }
+ bool accept(KAcceptThreadCallback thread_callback) {
+ return KSocket::accept(thread_callback);
+ }
+
+ bool send() { return KSocket::send(); }
+ bool recv(size_t max_recv_size = 65535) {
+ return KSocket::recv(max_recv_size);
+ }
+
+ int getLastError() { return KSocket::getLastError(); }
+};
+
+#endif
diff --git a/ksocket/utils.c b/ksocket/utils.c
new file mode 100644
index 0000000..1fe26ea
--- /dev/null
+++ b/ksocket/utils.c
@@ -0,0 +1,17 @@
+#include "utils.h"
+
+uint64_t htonll(uint64_t hostlonglong) { return __builtin_bswap64(hostlonglong); }
+
+#ifndef BUILD_USERMODE
+uint32_t htonl(uint32_t hostlong) { return __builtin_bswap32(hostlong); }
+
+uint16_t htons(uint16_t hostshort) { return __builtin_bswap16(hostshort); }
+#endif
+
+uint64_t ntohll(uint64_t netlonglong) { return __builtin_bswap64(netlonglong); }
+
+#ifndef BUILD_USERMODE
+uint32_t ntohl(uint32_t netlong) { return __builtin_bswap32(netlong); }
+
+uint16_t ntohs(uint16_t netshort) { return __builtin_bswap16(netshort); }
+#endif
diff --git a/ksocket/utils.h b/ksocket/utils.h
new file mode 100644
index 0000000..e3c0474
--- /dev/null
+++ b/ksocket/utils.h
@@ -0,0 +1,30 @@
+#ifndef KSOCKET_UTILS_H
+#define KSOCKET_UTILS_H 1
+
+#include <stdint.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+uint64_t htonll(uint64_t hostlonglong);
+
+#ifndef BUILD_USERMODE
+uint32_t htonl(uint32_t hostlong);
+
+uint16_t htons(uint16_t hostshort);
+#endif
+
+uint64_t ntohll(uint64_t netlonglong);
+
+#ifndef BUILD_USERMODE
+uint32_t ntohl(uint32_t netlong);
+
+uint16_t ntohs(uint16_t netshort);
+#endif
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif
diff --git a/ksocket/wsk.h b/ksocket/wsk.h
index ccea103..25d7f06 100644
--- a/ksocket/wsk.h
+++ b/ksocket/wsk.h
@@ -1,6 +1,10 @@
#ifndef WSK_H
#define WSK_H 1
+#ifdef BUILD_USERMODE
+#error "This file should only be included if building for kernel mode! Include <ksocket/ksocket.hpp> wrapper instead."
+#endif
+
#include <ntddk.h>
#if !defined(__MINGW64__)