diff options
Diffstat (limited to 'ksocket/ksocket.cpp')
-rw-r--r-- | ksocket/ksocket.cpp | 48 |
1 files changed, 44 insertions, 4 deletions
diff --git a/ksocket/ksocket.cpp b/ksocket/ksocket.cpp index ca0625b..7d8be3c 100644 --- a/ksocket/ksocket.cpp +++ b/ksocket/ksocket.cpp @@ -24,9 +24,14 @@ struct KSocketImplCommon { #ifdef BUILD_USERMODE struct KSocketImpl { - ~KSocketImpl() { closesocket(s); } + ~KSocketImpl() { + if (s != INVALID_SOCKET) { + closesocket(s); + s = INVALID_SOCKET; + } + } - SOCKET s; + SOCKET s = INVALID_SOCKET; KSocketImplCommon c; }; #else @@ -193,7 +198,7 @@ bool KSocket::connect(eastl::string host, eastl::string port) { m_lastError = KSE_SUCCESS; - if (m_socket == nullptr) + if (!sanityCheck()) return false; if (m_socketType != KSocketType::KST_STREAM_CLIENT_IP4 && @@ -216,6 +221,9 @@ bool KSocket::connect(eastl::string host, eastl::string port) { bool KSocket::bind(uint16_t port) { struct sockaddr_in addr; + if (!sanityCheck()) + return false; + addr.sin_family = AF_INET; addr.sin_addr.s_addr = INADDR_ANY; addr.sin_port = htons(port); @@ -228,6 +236,9 @@ bool KSocket::bind(uint16_t port) { bool KSocket::listen(int backlog) { m_lastError = KSE_SUCCESS; + if (!sanityCheck()) + return false; + if (m_socketType != KSocketType::KST_STREAM_SERVER_IP4 && m_socketType != KSocketType::KST_STREAM_SERVER_IP6) return false; @@ -242,6 +253,9 @@ bool KSocket::accept(KAcceptThreadCallback thread_callback) { struct sockaddr addr; socklen_t addrlen = sizeof(addr); + if (!sanityCheck()) + return false; + if (m_socket->c.domain != AF_INET) { m_lastError = KSE_SETUP_UNSUPPORTED_SOCKET_TYPE; return false; @@ -282,7 +296,12 @@ bool KSocket::accept(KAcceptThreadCallback thread_callback) { } bool KSocket::close() { - int rv = closesocket(m_socket->s); + int rv; + + if (!sanityCheck()) + return false; + + rv = closesocket(m_socket->s); if (rv == 0) { m_socket->s = -1; @@ -295,6 +314,9 @@ bool KSocket::close() { bool KSocket::send() { m_lastError = KSE_SUCCESS; + if (!sanityCheck()) + return false; + if (m_sendBuffer.size() == 0) return false; @@ -313,6 +335,9 @@ bool KSocket::send() { bool KSocket::recv(size_t max_recv_size) { const size_t current_size = m_recvBuffer.size(); + if (!sanityCheck()) + return false; + m_recvBuffer.buffer.resize(current_size + max_recv_size); m_lastError = ::recv( m_socket->s, reinterpret_cast<char *>(m_recvBuffer.data() + current_size), @@ -325,6 +350,21 @@ bool KSocket::recv(size_t max_recv_size) { return false; } +bool KSocket::sanityCheck() { + if (m_socket != nullptr && +#ifdef BUILD_USERMODE + m_socket->s != INVALID_SOCKET +#else + m_socket->s >= 0 +#endif + ) { + return true; + } + + m_lastError = KSE_INVALID_SOCKET; + return false; +} + bool KSocket::socketTypeToTuple(KSocketType sock_type, int &domain, int &type) { switch (sock_type) { case KSocketType::KST_INVALID: |