#include "ksocket.hpp"

#include <EASTL/type_traits.h>
#include <eastl_compat.hpp>

#ifdef BUILD_USERMODE
// clang-format off
#include <winsock2.h>
#include <windows.h>
#include <ws2tcpip.h>
// clang-format on
#else
#include <ksocket/berkeley.h>
#include <ksocket/ksocket.h>
#include <ksocket/wsk.h>
#endif

#include <ksocket/utils.h>

/// Address family, socket type and protocol a socket was created with.
struct KSocketImplCommon {
  ADDRESS_FAMILY domain;
  int type;
  int proto;
};

/// Owner of the native socket handle; the handle is closed on destruction.
///
/// The handle type and its "invalid" sentinel differ between the user-mode
/// (Winsock `SOCKET` / `INVALID_SOCKET`) and the kernel-mode (`int` / negative)
/// builds; the helpers below hide that difference from the rest of the file.
struct KSocketImpl {
#ifdef BUILD_USERMODE
  using Handle = SOCKET;
  static Handle invalidHandle() { return INVALID_SOCKET; }
  static bool isValidHandle(Handle handle) { return handle != INVALID_SOCKET; }
#else
  using Handle = int;
  static Handle invalidHandle() { return -1; }
  static bool isValidHandle(Handle handle) { return handle >= 0; }
#endif

  KSocketImpl() = default;
  // The destructor closes the handle, so a copy would close it twice.
  KSocketImpl(const KSocketImpl &) = delete;
  KSocketImpl &operator=(const KSocketImpl &) = delete;

  ~KSocketImpl() {
    if (valid()) {
      closesocket(s);
      s = invalidHandle();
    }
  }

  /// @return true while `s` refers to a handle that has not been invalidated.
  bool valid() const { return isValidHandle(s); }

  Handle s = invalidHandle();
  KSocketImplCommon c = {};
};

namespace {

/// Writes `value` into `out` most-significant byte first (network byte order),
/// independent of the host's endianness.
template <typename T, size_t N>
void encode_big_endian(T value, uint8_t (&out)[N]) {
  static_assert(sizeof(T) == N, "output array must match the value size");
  for (size_t i = 0; i < N; ++i) {
    out[i] = static_cast<uint8_t>((value >> (8 * (N - 1 - i))) & 0xFF);
  }
}

} // namespace

/// Formats an IPv4 address as dotted decimal, optionally followed by ":port".
///
/// @param with_port append ":<port>" to the result when true.
/// @return the formatted address, or an empty string if the stored address is
///         not IPv4 (`addr_used != 4`).
eastl::string KSocketAddress::to_string(bool with_port) const {
  if (addr_used != 4) {
    return "";
  }

  eastl::string text = ::to_string(addr.u8[0]) + "." + ::to_string(addr.u8[1]) +
                       "." + ::to_string(addr.u8[2]) + "." +
                       ::to_string(addr.u8[3]);
  if (with_port) {
    // KSocket::accept() stores sockaddr_in::sin_port verbatim, i.e. in network
    // byte order, so convert to host order before printing.
    text += ":";
    text += ::to_string(ntohs(port));
  }
  return text;
}

/// Inserts a single byte before `it`.
void KSocketBuffer::insert_u8(KBuffer::iterator it, uint8_t value) {
  buffer.insert(it, value);
}

/// Inserts `value` before `it` as 2 bytes in network (big-endian) order.
void KSocketBuffer::insert_u16(KBuffer::iterator it, uint16_t value) {
  uint8_t bytes[2];
  encode_big_endian(value, bytes);
  buffer.insert(it, bytes, bytes + eastl::size(bytes));
}

/// Inserts `value` before `it` as 4 bytes in network (big-endian) order.
void KSocketBuffer::insert_u32(KBuffer::iterator it, uint32_t value) {
  uint8_t bytes[4];
  encode_big_endian(value, bytes);
  buffer.insert(it, bytes, bytes + eastl::size(bytes));
}

/// Inserts `value` before `it` as 8 bytes in network (big-endian) order.
void KSocketBuffer::insert_u64(KBuffer::iterator it, uint64_t value) {
  uint8_t bytes[8];
  encode_big_endian(value, bytes);
  buffer.insert(it, bytes, bytes + eastl::size(bytes));
}

/// Inserts `size` raw bytes from `bytebuffer` before `it`.
void KSocketBuffer::insert_bytebuffer(KBuffer::iterator it,
                                      const uint8_t bytebuffer[], size_t size) {
  buffer.insert(it, bytebuffer, bytebuffer + size);
}

/// Removes bytes from the front of the buffer.
///
/// @param amount_bytes number of bytes to drop; 0 or any value larger than the
///        buffer size drops the whole buffer.
void KSocketBuffer::consume(size_t amount_bytes) {
  if (amount_bytes == 0 || amount_bytes > buffer.size())
    amount_bytes = buffer.size();

  buffer.erase(buffer.begin(), buffer.begin() + amount_bytes);
}

/// Renders the buffer as upper-case hex, two digits per byte.
///
/// @param delim text placed between consecutive bytes (never trailing).
/// @return the hex dump; empty for an empty buffer.
eastl::string KSocketBuffer::toHex(eastl::string delim) {
  char const *const hex = "0123456789ABCDEF";
  eastl::string str;
  char pout[3] = {}; // two digits + terminating NUL
  bool first = true;

  for (const auto &input_byte : buffer) {
    if (!first)
      str += delim;
    first = false;

    pout[0] = hex[(input_byte >> 4) & 0xF];
    pout[1] = hex[input_byte & 0xF];
    str += pout;
  }

  return str;
}

/// Releases the socket implementation (which closes the native handle).
KSocket::~KSocket() {
  delete m_socket;
  m_socket = nullptr;
}

/// (Re-)creates the underlying socket for the given socket type.
///
/// Any previously held socket implementation is destroyed first. IPv6 socket
/// types are rejected with KSE_SETUP_UNSUPPORTED_SOCKET_TYPE.
///
/// @param sock_type kind of socket to create.
/// @param proto protocol passed through to the native socket call.
/// @return true if a valid native socket was created; otherwise false with
///         m_lastError describing the failure.
bool KSocket::setup(KSocketType sock_type, int proto) {
  int domain = 0;
  int type = 0;

  m_lastError = KSE_SUCCESS;

  delete m_socket;
  m_socket = new KSocketImpl();
  if (m_socket == nullptr) {
    m_lastError = KSE_SETUP_IMPL_NULL;
    return false;
  }

  if (KSocket::socketTypeToTuple(sock_type, domain, type) != true) {
    m_lastError = KSE_SETUP_INVALID_SOCKET_TYPE;
    return false;
  }

  if (sock_type == KSocketType::KST_STREAM_CLIENT_IP6 ||
      sock_type == KSocketType::KST_STREAM_SERVER_IP6 ||
      sock_type == KSocketType::KST_DATAGRAM_IP6) {
    m_lastError = KSE_SETUP_UNSUPPORTED_SOCKET_TYPE;
    return false; // IPv6 is not supported for now
  }

  m_socketType = sock_type;

#ifdef BUILD_USERMODE
  m_socket->s = ::socket(domain, type, proto);
#else
  // DbgPrint("KSocketType: %d, domain: %d, type: %d, proto: %d", sock_type,
  // domain, type, proto);
  switch (sock_type) {
  case KSocketType::KST_INVALID:
    m_lastError = KSE_SETUP_INVALID_SOCKET_TYPE;
    return false;
  case KSocketType::KST_STREAM_CLIENT_IP4:
  case KSocketType::KST_STREAM_CLIENT_IP6:
    m_socket->s = ::socket_connection(domain, type, proto);
    break;
  case KSocketType::KST_STREAM_SERVER_IP4:
  case KSocketType::KST_STREAM_SERVER_IP6:
    m_socket->s = ::socket_listen(domain, type, proto);
    break;
  case KSocketType::KST_DATAGRAM_IP4:
  case KSocketType::KST_DATAGRAM_IP6:
    m_socket->s = ::socket_datagram(domain, type, proto);
    break;
  }
#endif

  m_socket->c.domain = static_cast<ADDRESS_FAMILY>(domain);
  m_socket->c.type = type;
  m_socket->c.proto = proto;

  if (!m_socket->valid()) {
    // Do not report failure while m_lastError still says "success".
    m_lastError = KSE_INVALID_SOCKET;
    return false;
  }
  return true;
}

/// Resolves `host`/`port` and connects a stream client socket to it.
///
/// On resolver or connect failure m_lastError holds the raw return value of
/// ::getaddrinfo / ::connect.
///
/// @return true if the connection was established.
bool KSocket::connect(eastl::string host, eastl::string port) {
  struct addrinfo hints = {};
  struct addrinfo *results = nullptr;

  m_lastError = KSE_SUCCESS;

  if (!sanityCheck())
    return false;

  if (m_socketType != KSocketType::KST_STREAM_CLIENT_IP4 &&
      m_socketType != KSocketType::KST_STREAM_CLIENT_IP6) {
    m_lastError = KSE_SETUP_INVALID_SOCKET_TYPE;
    return false;
  }

  hints.ai_flags |= AI_CANONNAME;
  hints.ai_family = m_socket->c.domain;
  hints.ai_socktype = m_socket->c.type;

  m_lastError = ::getaddrinfo(host.c_str(), port.c_str(), &hints, &results);
  if (m_lastError != KSE_SUCCESS)
    return false;

  m_lastError = ::connect(m_socket->s, results->ai_addr,
                          static_cast<int>(results->ai_addrlen));
  freeaddrinfo(results);

  return m_lastError == KSE_SUCCESS;
}

/// Binds the socket to `port` on all local IPv4 interfaces (INADDR_ANY).
///
/// @param port port number in host byte order.
/// @return true on success; on failure m_lastError holds ::bind's return value.
bool KSocket::bind(uint16_t port) {
  // Zero-initialise so the sin_zero padding does not carry stack garbage.
  struct sockaddr_in addr = {};

  m_lastError = KSE_SUCCESS;

  if (!sanityCheck())
    return false;

  addr.sin_family = AF_INET;
  addr.sin_addr.s_addr = INADDR_ANY;
  addr.sin_port = htons(port);

  m_lastError = ::bind(m_socket->s, reinterpret_cast<struct sockaddr *>(&addr),
                       sizeof(addr));

  return m_lastError == KSE_SUCCESS;
}

/// Puts a stream server socket into listening state.
///
/// @param backlog passed through to ::listen.
/// @return true on success; on failure m_lastError holds ::listen's return
///         value, or KSE_SETUP_INVALID_SOCKET_TYPE for a non-server socket.
bool KSocket::listen(int backlog) {
  m_lastError = KSE_SUCCESS;

  if (!sanityCheck())
    return false;

  if (m_socketType != KSocketType::KST_STREAM_SERVER_IP4 &&
      m_socketType != KSocketType::KST_STREAM_SERVER_IP6) {
    m_lastError = KSE_SETUP_INVALID_SOCKET_TYPE;
    return false;
  }

  m_lastError = ::listen(m_socket->s, backlog);

  return m_lastError == KSE_SUCCESS;
}

/// Accepts one incoming IPv4 connection and hands it to `thread_callback`.
///
/// The accepted socket's remote address and port are stored exactly as
/// delivered in sockaddr_in (network byte order).
///
/// @return the callback's result, or false if accepting failed (m_lastError is
///         set accordingly).
bool KSocket::accept(KAcceptThreadCallback thread_callback) {
  KAcceptedSocket ka;
  struct sockaddr addr = {};
  socklen_t addrlen = sizeof(addr);

  m_lastError = KSE_SUCCESS;

  if (!sanityCheck())
    return false;

  if (m_socket->c.domain != AF_INET) {
    m_lastError = KSE_SETUP_UNSUPPORTED_SOCKET_TYPE;
    return false;
  }
  addr.sa_family = m_socket->c.domain;

  ka.m_socket = new KSocketImpl();
  if (ka.m_socket == nullptr) {
    m_lastError = KSE_SETUP_IMPL_NULL;
    return false;
  }

  ka.m_socketType = m_socketType;
  ka.m_socket->c = m_socket->c;
  ka.m_socket->s = ::accept(m_socket->s, &addr, &addrlen);

  // Check for failure before touching `addr`: it is only meaningful when
  // ::accept succeeded.
  if (!ka.m_socket->valid()) {
    m_lastError = KSE_ACCEPT_FAILED;
    return false;
  }

  // The domain was verified to be AF_INET above, so `addr` is a sockaddr_in.
  const struct sockaddr_in *addr_in =
      reinterpret_cast<const struct sockaddr_in *>(&addr);
  ka.m_remote.addr_used = 4;
  ka.m_remote.addr.u32[0] = addr_in->sin_addr.s_addr;
  ka.m_remote.port = addr_in->sin_port;

  return thread_callback(ka);
}

/// Closes the native socket handle.
///
/// @return true if the handle was closed (and is now marked invalid), false if
///         there was no usable socket or closesocket() reported an error.
bool KSocket::close() {
  if (!sanityCheck())
    return false;

  if (closesocket(m_socket->s) != 0)
    return false;

  // Use the build-specific sentinel rather than a literal -1.
  m_socket->s = KSocketImpl::invalidHandle();
  return true;
}

/// Sends as much of the send buffer as the socket accepts in one call.
///
/// Bytes that were sent are removed from the front of the send buffer.
/// NOTE: m_lastError receives the raw return value of ::send, i.e. the number
/// of bytes sent on success.
///
/// @return true if at least one byte was sent; false if the buffer was empty,
///         the socket is unusable or ::send did not return a positive count.
bool KSocket::send() {
  m_lastError = KSE_SUCCESS;

  if (!sanityCheck())
    return false;

  if (m_sendBuffer.size() == 0)
    return false;

  m_lastError =
      ::send(m_socket->s, reinterpret_cast<const char *>(m_sendBuffer.data()),
             m_sendBuffer.size(), 0);
  if (m_lastError > 0) {
    m_sendBuffer.buffer.erase(m_sendBuffer.begin(),
                              m_sendBuffer.begin() + m_lastError);
    return true;
  }

  return false;
}

/// Receives up to `max_recv_size` bytes and appends them to the receive buffer.
///
/// NOTE: m_lastError receives the raw return value of ::recv, i.e. the number
/// of bytes received on success.
///
/// @return true if at least one byte was received; otherwise false and the
///         receive buffer is left exactly as it was before the call.
bool KSocket::recv(size_t max_recv_size) {
  if (!sanityCheck())
    return false;

  const size_t current_size = m_recvBuffer.size();

  // Make room for the incoming data directly behind the existing content.
  m_recvBuffer.buffer.resize(current_size + max_recv_size);
  m_lastError = ::recv(
      m_socket->s, reinterpret_cast<char *>(m_recvBuffer.data() + current_size),
      max_recv_size, 0);

  if (m_lastError > 0) {
    m_recvBuffer.buffer.resize(current_size + m_lastError);
    return true;
  }

  // Nothing was received: drop the scratch space again so callers do not see
  // `max_recv_size` bogus bytes appended to the buffer.
  m_recvBuffer.buffer.resize(current_size);
  return false;
}

/// Verifies that a socket implementation with a valid handle exists.
///
/// @return true if the socket is usable; otherwise false and m_lastError is
///         set to KSE_INVALID_SOCKET.
bool KSocket::sanityCheck() {
  const bool usable = m_socket != nullptr && m_socket->valid();

  if (!usable)
    m_lastError = KSE_INVALID_SOCKET;

  return usable;
}

/// Maps a KSocketType to the (address family, socket type) pair for socket().
///
/// @param sock_type socket type to translate.
/// @param domain receives AF_INET or AF_INET6.
/// @param type receives SOCK_STREAM or SOCK_DGRAM.
/// @return true if `sock_type` was translated; false (outputs untouched) for
///         KST_INVALID or an unknown value.
bool KSocket::socketTypeToTuple(KSocketType sock_type, int &domain, int &type) {
  switch (sock_type) {
  case KSocketType::KST_INVALID:
    break;
  case KSocketType::KST_STREAM_CLIENT_IP4:
  case KSocketType::KST_STREAM_SERVER_IP4:
    domain = AF_INET;
    type = SOCK_STREAM;
    return true;
  case KSocketType::KST_STREAM_CLIENT_IP6:
  case KSocketType::KST_STREAM_SERVER_IP6:
    domain = AF_INET6;
    type = SOCK_STREAM;
    return true;
  case KSocketType::KST_DATAGRAM_IP4:
    domain = AF_INET;
    type = SOCK_DGRAM;
    return true;
  case KSocketType::KST_DATAGRAM_IP6:
    domain = AF_INET6;
    type = SOCK_DGRAM;
    return true;
  }

  return false;
}