diff options
Diffstat (limited to 'ksocket/ksocket.cpp')
-rw-r--r-- | ksocket/ksocket.cpp | 350 |
1 files changed, 350 insertions, 0 deletions
diff --git a/ksocket/ksocket.cpp b/ksocket/ksocket.cpp new file mode 100644 index 0000000..ca0625b --- /dev/null +++ b/ksocket/ksocket.cpp @@ -0,0 +1,350 @@ +#include "ksocket.hpp" + +#include <EASTL/type_traits.h> +#include <eastl_compat.hpp> + +#ifdef BUILD_USERMODE +// clang-format off +#include <winsock2.h> +#include <windows.h> +#include <ws2tcpip.h> +// clang-format on +#else +#include <ksocket/berkeley.h> +#include <ksocket/ksocket.h> +#include <ksocket/wsk.h> +#endif +#include <ksocket/utils.h> + +struct KSocketImplCommon { + ADDRESS_FAMILY domain; + int type; + int proto; +}; + +#ifdef BUILD_USERMODE +struct KSocketImpl { + ~KSocketImpl() { closesocket(s); } + + SOCKET s; + KSocketImplCommon c; +}; +#else +struct KSocketImpl { + ~KSocketImpl() { + if (s >= 0) { + closesocket(s); + s = -1; + } + } + + int s = -1; + KSocketImplCommon c; +}; +#endif + +eastl::string KSocketAddress::to_string(bool with_port) const { + if (addr_used != 4) { + return ""; + } + + return ::to_string(addr.u8[0]) + "." + ::to_string(addr.u8[1]) + "." + + ::to_string(addr.u8[2]) + "." + ::to_string(addr.u8[3]) + ":" + + ::to_string(with_port); +} + +void KSocketBuffer::insert_u8(KBuffer::iterator it, uint8_t value) { + buffer.insert(it, value); +} + +void KSocketBuffer::insert_u16(KBuffer::iterator it, uint16_t value) { + uint16_t net_value; + uint8_t insert_value[2]; + + net_value = htons(value); + insert_value[0] = (net_value & 0x00FF) >> 0; + insert_value[1] = (net_value & 0xFF00) >> 8; + buffer.insert(it, insert_value, insert_value + eastl::size(insert_value)); +} + +void KSocketBuffer::insert_u32(KBuffer::iterator it, uint32_t value) { + uint32_t net_value; + uint8_t insert_value[4]; + + net_value = htonl(value); + insert_value[0] = (net_value & 0x000000FF) >> 0; + insert_value[1] = (net_value & 0x0000FF00) >> 8; + insert_value[2] = (net_value & 0x00FF0000) >> 16; + insert_value[3] = (net_value & 0xFF000000) >> 24; + buffer.insert(it, insert_value, insert_value + eastl::size(insert_value)); +} + +void KSocketBuffer::insert_u64(KBuffer::iterator it, uint64_t value) { + uint64_t net_value; + uint8_t insert_value[8]; + + net_value = htonll(value); + insert_value[0] = (net_value & 0x00000000000000FF) >> 0; + insert_value[1] = (net_value & 0x000000000000FF00) >> 8; + insert_value[2] = (net_value & 0x0000000000FF0000) >> 16; + insert_value[3] = (net_value & 0x00000000FF000000) >> 24; + insert_value[4] = (net_value & 0x000000FF00000000) >> 32; + insert_value[5] = (net_value & 0x0000FF0000000000) >> 40; + insert_value[6] = (net_value & 0x00FF000000000000) >> 48; + insert_value[7] = (net_value & 0xFF00000000000000) >> 56; + buffer.insert(it, insert_value, insert_value + eastl::size(insert_value)); +} + +void KSocketBuffer::insert_bytebuffer(KBuffer::iterator it, + const uint8_t bytebuffer[], size_t size) { + buffer.insert(it, bytebuffer, bytebuffer + size); +} + +void KSocketBuffer::consume(size_t amount_bytes) { + if (amount_bytes == 0 || amount_bytes > buffer.size()) + amount_bytes = buffer.size(); + + buffer.erase(buffer.begin(), buffer.begin() + amount_bytes); +} + +eastl::string KSocketBuffer::toHex(eastl::string delim) { + eastl::string str; + char const *const hex = "0123456789ABCDEF"; + char pout[3] = {}; + + for (const auto &input_byte : buffer) { + pout[0] = hex[(input_byte >> 4) & 0xF]; + pout[1] = hex[input_byte & 0xF]; + str += pout; + str += delim; + } + + if (str.length() >= delim.length() && delim.length() > 0) + str.erase(str.length() - delim.length(), delim.length()); + + return str; +} + +KSocket::~KSocket() { + if (m_socket != nullptr) + delete m_socket; +} + +bool KSocket::setup(KSocketType sock_type, int proto) { + int domain, type; + + m_lastError = KSE_SUCCESS; + if (m_socket != nullptr) + delete m_socket; + m_socket = new KSocketImpl(); + if (m_socket == nullptr) { + m_lastError = KSE_SETUP_IMPL_NULL; + return false; + } + + if (KSocket::socketTypeToTuple(sock_type, domain, type) != true) { + m_lastError = KSE_SETUP_INVALID_SOCKET_TYPE; + return false; + } + if (sock_type == KSocketType::KST_STREAM_CLIENT_IP6 || + sock_type == KSocketType::KST_STREAM_SERVER_IP6 || + sock_type == KSocketType::KST_DATAGRAM_IP6) { + m_lastError = KSE_SETUP_UNSUPPORTED_SOCKET_TYPE; + return false; // IPv6 is not supported for now + } + m_socketType = sock_type; + +#ifdef BUILD_USERMODE + m_socket->s = ::socket(domain, type, proto); +#else + // DbgPrint("KSocketType: %d, domain: %d, type: %d, proto: %d", sock_type, + // domain, type, proto); + switch (sock_type) { + case KSocketType::KST_INVALID: + m_lastError = KSE_SETUP_INVALID_SOCKET_TYPE; + return false; + case KSocketType::KST_STREAM_CLIENT_IP4: + case KSocketType::KST_STREAM_CLIENT_IP6: + m_socket->s = ::socket_connection(domain, type, proto); + break; + case KSocketType::KST_STREAM_SERVER_IP4: + case KSocketType::KST_STREAM_SERVER_IP6: + m_socket->s = ::socket_listen(domain, type, proto); + break; + case KSocketType::KST_DATAGRAM_IP4: + case KSocketType::KST_DATAGRAM_IP6: + m_socket->s = ::socket_datagram(domain, type, proto); + break; + } +#endif + m_socket->c.domain = static_cast<ADDRESS_FAMILY>(domain); + m_socket->c.type = type; + m_socket->c.proto = proto; +#ifdef BUILD_USERMODE + return m_socket->s != INVALID_SOCKET; +#else + return m_socket->s >= 0; +#endif +} + +bool KSocket::connect(eastl::string host, eastl::string port) { + struct addrinfo hints = {}; + struct addrinfo *results; + + m_lastError = KSE_SUCCESS; + + if (m_socket == nullptr) + return false; + + if (m_socketType != KSocketType::KST_STREAM_CLIENT_IP4 && + m_socketType != KSocketType::KST_STREAM_CLIENT_IP6) + return false; + + hints.ai_flags |= AI_CANONNAME; + hints.ai_family = m_socket->c.domain; + hints.ai_socktype = m_socket->c.type; + m_lastError = ::getaddrinfo(host.c_str(), port.c_str(), &hints, &results); + if (m_lastError != KSE_SUCCESS) + return false; + + m_lastError = + ::connect(m_socket->s, results->ai_addr, (int)results->ai_addrlen); + freeaddrinfo(results); + return m_lastError == KSE_SUCCESS; +} + +bool KSocket::bind(uint16_t port) { + struct sockaddr_in addr; + + addr.sin_family = AF_INET; + addr.sin_addr.s_addr = INADDR_ANY; + addr.sin_port = htons(port); + + m_lastError = ::bind(m_socket->s, (struct sockaddr *)&addr, sizeof(addr)); + + return m_lastError == KSE_SUCCESS; +} + +bool KSocket::listen(int backlog) { + m_lastError = KSE_SUCCESS; + + if (m_socketType != KSocketType::KST_STREAM_SERVER_IP4 && + m_socketType != KSocketType::KST_STREAM_SERVER_IP6) + return false; + + m_lastError = ::listen(m_socket->s, backlog); + + return m_lastError == KSE_SUCCESS; +} + +bool KSocket::accept(KAcceptThreadCallback thread_callback) { + KAcceptedSocket ka; + struct sockaddr addr; + socklen_t addrlen = sizeof(addr); + + if (m_socket->c.domain != AF_INET) { + m_lastError = KSE_SETUP_UNSUPPORTED_SOCKET_TYPE; + return false; + } + + addr.sa_family = m_socket->c.domain; + ka.m_socket = new KSocketImpl(); + if (ka.m_socket == nullptr) { + m_lastError = KSE_SETUP_IMPL_NULL; + return false; + } + ka.m_socketType = m_socketType; + ka.m_socket->c = m_socket->c; + ka.m_socket->s = ::accept(m_socket->s, &addr, &addrlen); + + if (m_socket->c.domain == AF_INET) { + struct sockaddr_in *addr_in = reinterpret_cast<struct sockaddr_in *>(&addr); + + ka.m_remote.addr_used = 4; + ka.m_remote.addr.u32[0] = addr_in->sin_addr.s_addr; + ka.m_remote.port = addr_in->sin_port; + } else { + m_lastError = KSE_SETUP_UNSUPPORTED_SOCKET_TYPE; + return false; + } + +#ifdef BUILD_USERMODE + if (ka.m_socket->s == INVALID_SOCKET) +#else + if (ka.m_socket->s < 0) +#endif + { + m_lastError = KSE_ACCEPT_FAILED; + return false; + } + + return thread_callback(ka); +} + +bool KSocket::close() { + int rv = closesocket(m_socket->s); + + if (rv == 0) { + m_socket->s = -1; + return true; + } else { + return false; + } +} + +bool KSocket::send() { + m_lastError = KSE_SUCCESS; + + if (m_sendBuffer.size() == 0) + return false; + + m_lastError = + ::send(m_socket->s, reinterpret_cast<const char *>(m_sendBuffer.data()), + m_sendBuffer.size(), 0); + if (m_lastError > 0) { + m_sendBuffer.buffer.erase(m_sendBuffer.begin(), + m_sendBuffer.begin() + m_lastError); + return true; + } + + return false; +} + +bool KSocket::recv(size_t max_recv_size) { + const size_t current_size = m_recvBuffer.size(); + + m_recvBuffer.buffer.resize(current_size + max_recv_size); + m_lastError = ::recv( + m_socket->s, reinterpret_cast<char *>(m_recvBuffer.data() + current_size), + max_recv_size, 0); + if (m_lastError > 0) { + m_recvBuffer.buffer.resize(current_size + m_lastError); + return true; + } + + return false; +} + +bool KSocket::socketTypeToTuple(KSocketType sock_type, int &domain, int &type) { + switch (sock_type) { + case KSocketType::KST_INVALID: + break; + case KSocketType::KST_STREAM_CLIENT_IP4: + case KSocketType::KST_STREAM_SERVER_IP4: + domain = AF_INET; + type = SOCK_STREAM; + return true; + case KSocketType::KST_STREAM_CLIENT_IP6: + case KSocketType::KST_STREAM_SERVER_IP6: + domain = AF_INET6; + type = SOCK_STREAM; + return true; + case KSocketType::KST_DATAGRAM_IP4: + case KSocketType::KST_DATAGRAM_IP6: + domain = AF_INET; + type = SOCK_DGRAM; + return true; + } + + return false; +} |