aboutsummaryrefslogtreecommitdiff
path: root/ksocket/ksocket.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'ksocket/ksocket.cpp')
-rw-r--r--ksocket/ksocket.cpp350
1 files changed, 350 insertions, 0 deletions
diff --git a/ksocket/ksocket.cpp b/ksocket/ksocket.cpp
new file mode 100644
index 0000000..ca0625b
--- /dev/null
+++ b/ksocket/ksocket.cpp
@@ -0,0 +1,350 @@
+#include "ksocket.hpp"
+
+#include <EASTL/type_traits.h>
+#include <eastl_compat.hpp>
+
+#ifdef BUILD_USERMODE
+// clang-format off
+#include <winsock2.h>
+#include <windows.h>
+#include <ws2tcpip.h>
+// clang-format on
+#else
+#include <ksocket/berkeley.h>
+#include <ksocket/ksocket.h>
+#include <ksocket/wsk.h>
+#endif
+#include <ksocket/utils.h>
+
+struct KSocketImplCommon {
+ ADDRESS_FAMILY domain;
+ int type;
+ int proto;
+};
+
+#ifdef BUILD_USERMODE
+struct KSocketImpl {
+ ~KSocketImpl() { closesocket(s); }
+
+ SOCKET s;
+ KSocketImplCommon c;
+};
+#else
+struct KSocketImpl {
+ ~KSocketImpl() {
+ if (s >= 0) {
+ closesocket(s);
+ s = -1;
+ }
+ }
+
+ int s = -1;
+ KSocketImplCommon c;
+};
+#endif
+
+eastl::string KSocketAddress::to_string(bool with_port) const {
+ if (addr_used != 4) {
+ return "";
+ }
+
+ return ::to_string(addr.u8[0]) + "." + ::to_string(addr.u8[1]) + "." +
+ ::to_string(addr.u8[2]) + "." + ::to_string(addr.u8[3]) + ":" +
+ ::to_string(with_port);
+}
+
+void KSocketBuffer::insert_u8(KBuffer::iterator it, uint8_t value) {
+ buffer.insert(it, value);
+}
+
+void KSocketBuffer::insert_u16(KBuffer::iterator it, uint16_t value) {
+ uint16_t net_value;
+ uint8_t insert_value[2];
+
+ net_value = htons(value);
+ insert_value[0] = (net_value & 0x00FF) >> 0;
+ insert_value[1] = (net_value & 0xFF00) >> 8;
+ buffer.insert(it, insert_value, insert_value + eastl::size(insert_value));
+}
+
+void KSocketBuffer::insert_u32(KBuffer::iterator it, uint32_t value) {
+ uint32_t net_value;
+ uint8_t insert_value[4];
+
+ net_value = htonl(value);
+ insert_value[0] = (net_value & 0x000000FF) >> 0;
+ insert_value[1] = (net_value & 0x0000FF00) >> 8;
+ insert_value[2] = (net_value & 0x00FF0000) >> 16;
+ insert_value[3] = (net_value & 0xFF000000) >> 24;
+ buffer.insert(it, insert_value, insert_value + eastl::size(insert_value));
+}
+
+void KSocketBuffer::insert_u64(KBuffer::iterator it, uint64_t value) {
+ uint64_t net_value;
+ uint8_t insert_value[8];
+
+ net_value = htonll(value);
+ insert_value[0] = (net_value & 0x00000000000000FF) >> 0;
+ insert_value[1] = (net_value & 0x000000000000FF00) >> 8;
+ insert_value[2] = (net_value & 0x0000000000FF0000) >> 16;
+ insert_value[3] = (net_value & 0x00000000FF000000) >> 24;
+ insert_value[4] = (net_value & 0x000000FF00000000) >> 32;
+ insert_value[5] = (net_value & 0x0000FF0000000000) >> 40;
+ insert_value[6] = (net_value & 0x00FF000000000000) >> 48;
+ insert_value[7] = (net_value & 0xFF00000000000000) >> 56;
+ buffer.insert(it, insert_value, insert_value + eastl::size(insert_value));
+}
+
+void KSocketBuffer::insert_bytebuffer(KBuffer::iterator it,
+ const uint8_t bytebuffer[], size_t size) {
+ buffer.insert(it, bytebuffer, bytebuffer + size);
+}
+
+void KSocketBuffer::consume(size_t amount_bytes) {
+ if (amount_bytes == 0 || amount_bytes > buffer.size())
+ amount_bytes = buffer.size();
+
+ buffer.erase(buffer.begin(), buffer.begin() + amount_bytes);
+}
+
+eastl::string KSocketBuffer::toHex(eastl::string delim) {
+ eastl::string str;
+ char const *const hex = "0123456789ABCDEF";
+ char pout[3] = {};
+
+ for (const auto &input_byte : buffer) {
+ pout[0] = hex[(input_byte >> 4) & 0xF];
+ pout[1] = hex[input_byte & 0xF];
+ str += pout;
+ str += delim;
+ }
+
+ if (str.length() >= delim.length() && delim.length() > 0)
+ str.erase(str.length() - delim.length(), delim.length());
+
+ return str;
+}
+
+KSocket::~KSocket() {
+ if (m_socket != nullptr)
+ delete m_socket;
+}
+
+bool KSocket::setup(KSocketType sock_type, int proto) {
+ int domain, type;
+
+ m_lastError = KSE_SUCCESS;
+ if (m_socket != nullptr)
+ delete m_socket;
+ m_socket = new KSocketImpl();
+ if (m_socket == nullptr) {
+ m_lastError = KSE_SETUP_IMPL_NULL;
+ return false;
+ }
+
+ if (KSocket::socketTypeToTuple(sock_type, domain, type) != true) {
+ m_lastError = KSE_SETUP_INVALID_SOCKET_TYPE;
+ return false;
+ }
+ if (sock_type == KSocketType::KST_STREAM_CLIENT_IP6 ||
+ sock_type == KSocketType::KST_STREAM_SERVER_IP6 ||
+ sock_type == KSocketType::KST_DATAGRAM_IP6) {
+ m_lastError = KSE_SETUP_UNSUPPORTED_SOCKET_TYPE;
+ return false; // IPv6 is not supported for now
+ }
+ m_socketType = sock_type;
+
+#ifdef BUILD_USERMODE
+ m_socket->s = ::socket(domain, type, proto);
+#else
+ // DbgPrint("KSocketType: %d, domain: %d, type: %d, proto: %d", sock_type,
+ // domain, type, proto);
+ switch (sock_type) {
+ case KSocketType::KST_INVALID:
+ m_lastError = KSE_SETUP_INVALID_SOCKET_TYPE;
+ return false;
+ case KSocketType::KST_STREAM_CLIENT_IP4:
+ case KSocketType::KST_STREAM_CLIENT_IP6:
+ m_socket->s = ::socket_connection(domain, type, proto);
+ break;
+ case KSocketType::KST_STREAM_SERVER_IP4:
+ case KSocketType::KST_STREAM_SERVER_IP6:
+ m_socket->s = ::socket_listen(domain, type, proto);
+ break;
+ case KSocketType::KST_DATAGRAM_IP4:
+ case KSocketType::KST_DATAGRAM_IP6:
+ m_socket->s = ::socket_datagram(domain, type, proto);
+ break;
+ }
+#endif
+ m_socket->c.domain = static_cast<ADDRESS_FAMILY>(domain);
+ m_socket->c.type = type;
+ m_socket->c.proto = proto;
+#ifdef BUILD_USERMODE
+ return m_socket->s != INVALID_SOCKET;
+#else
+ return m_socket->s >= 0;
+#endif
+}
+
+bool KSocket::connect(eastl::string host, eastl::string port) {
+ struct addrinfo hints = {};
+ struct addrinfo *results;
+
+ m_lastError = KSE_SUCCESS;
+
+ if (m_socket == nullptr)
+ return false;
+
+ if (m_socketType != KSocketType::KST_STREAM_CLIENT_IP4 &&
+ m_socketType != KSocketType::KST_STREAM_CLIENT_IP6)
+ return false;
+
+ hints.ai_flags |= AI_CANONNAME;
+ hints.ai_family = m_socket->c.domain;
+ hints.ai_socktype = m_socket->c.type;
+ m_lastError = ::getaddrinfo(host.c_str(), port.c_str(), &hints, &results);
+ if (m_lastError != KSE_SUCCESS)
+ return false;
+
+ m_lastError =
+ ::connect(m_socket->s, results->ai_addr, (int)results->ai_addrlen);
+ freeaddrinfo(results);
+ return m_lastError == KSE_SUCCESS;
+}
+
+bool KSocket::bind(uint16_t port) {
+ struct sockaddr_in addr;
+
+ addr.sin_family = AF_INET;
+ addr.sin_addr.s_addr = INADDR_ANY;
+ addr.sin_port = htons(port);
+
+ m_lastError = ::bind(m_socket->s, (struct sockaddr *)&addr, sizeof(addr));
+
+ return m_lastError == KSE_SUCCESS;
+}
+
+bool KSocket::listen(int backlog) {
+ m_lastError = KSE_SUCCESS;
+
+ if (m_socketType != KSocketType::KST_STREAM_SERVER_IP4 &&
+ m_socketType != KSocketType::KST_STREAM_SERVER_IP6)
+ return false;
+
+ m_lastError = ::listen(m_socket->s, backlog);
+
+ return m_lastError == KSE_SUCCESS;
+}
+
+bool KSocket::accept(KAcceptThreadCallback thread_callback) {
+ KAcceptedSocket ka;
+ struct sockaddr addr;
+ socklen_t addrlen = sizeof(addr);
+
+ if (m_socket->c.domain != AF_INET) {
+ m_lastError = KSE_SETUP_UNSUPPORTED_SOCKET_TYPE;
+ return false;
+ }
+
+ addr.sa_family = m_socket->c.domain;
+ ka.m_socket = new KSocketImpl();
+ if (ka.m_socket == nullptr) {
+ m_lastError = KSE_SETUP_IMPL_NULL;
+ return false;
+ }
+ ka.m_socketType = m_socketType;
+ ka.m_socket->c = m_socket->c;
+ ka.m_socket->s = ::accept(m_socket->s, &addr, &addrlen);
+
+ if (m_socket->c.domain == AF_INET) {
+ struct sockaddr_in *addr_in = reinterpret_cast<struct sockaddr_in *>(&addr);
+
+ ka.m_remote.addr_used = 4;
+ ka.m_remote.addr.u32[0] = addr_in->sin_addr.s_addr;
+ ka.m_remote.port = addr_in->sin_port;
+ } else {
+ m_lastError = KSE_SETUP_UNSUPPORTED_SOCKET_TYPE;
+ return false;
+ }
+
+#ifdef BUILD_USERMODE
+ if (ka.m_socket->s == INVALID_SOCKET)
+#else
+ if (ka.m_socket->s < 0)
+#endif
+ {
+ m_lastError = KSE_ACCEPT_FAILED;
+ return false;
+ }
+
+ return thread_callback(ka);
+}
+
+bool KSocket::close() {
+ int rv = closesocket(m_socket->s);
+
+ if (rv == 0) {
+ m_socket->s = -1;
+ return true;
+ } else {
+ return false;
+ }
+}
+
+bool KSocket::send() {
+ m_lastError = KSE_SUCCESS;
+
+ if (m_sendBuffer.size() == 0)
+ return false;
+
+ m_lastError =
+ ::send(m_socket->s, reinterpret_cast<const char *>(m_sendBuffer.data()),
+ m_sendBuffer.size(), 0);
+ if (m_lastError > 0) {
+ m_sendBuffer.buffer.erase(m_sendBuffer.begin(),
+ m_sendBuffer.begin() + m_lastError);
+ return true;
+ }
+
+ return false;
+}
+
+bool KSocket::recv(size_t max_recv_size) {
+ const size_t current_size = m_recvBuffer.size();
+
+ m_recvBuffer.buffer.resize(current_size + max_recv_size);
+ m_lastError = ::recv(
+ m_socket->s, reinterpret_cast<char *>(m_recvBuffer.data() + current_size),
+ max_recv_size, 0);
+ if (m_lastError > 0) {
+ m_recvBuffer.buffer.resize(current_size + m_lastError);
+ return true;
+ }
+
+ return false;
+}
+
+bool KSocket::socketTypeToTuple(KSocketType sock_type, int &domain, int &type) {
+ switch (sock_type) {
+ case KSocketType::KST_INVALID:
+ break;
+ case KSocketType::KST_STREAM_CLIENT_IP4:
+ case KSocketType::KST_STREAM_SERVER_IP4:
+ domain = AF_INET;
+ type = SOCK_STREAM;
+ return true;
+ case KSocketType::KST_STREAM_CLIENT_IP6:
+ case KSocketType::KST_STREAM_SERVER_IP6:
+ domain = AF_INET6;
+ type = SOCK_STREAM;
+ return true;
+ case KSocketType::KST_DATAGRAM_IP4:
+ case KSocketType::KST_DATAGRAM_IP6:
+ domain = AF_INET;
+ type = SOCK_DGRAM;
+ return true;
+ }
+
+ return false;
+}