diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/forward.c | 10 | ||||
-rw-r--r-- | src/forward.h | 2 | ||||
-rw-r--r-- | src/server.c | 153 | ||||
-rw-r--r-- | src/server_ssh.c | 6 | ||||
-rw-r--r-- | src/socket.c | 75 | ||||
-rw-r--r-- | src/socket.h | 8 |
6 files changed, 191 insertions, 63 deletions
diff --git a/src/forward.c b/src/forward.c index 54b64fe..f12a6f8 100644 --- a/src/forward.c +++ b/src/forward.c @@ -36,7 +36,7 @@ int fwd_setup(forward_ctx *ctx, const char *host, const char *port) return 1; if (ctx->fwd_cbs.on_listen(ctx, host, port)) return 1; - if (socket_connect_in(&ctx->sock, fwd_addr)) { + if (socket_connect_in(&ctx->sock, &fwd_addr)) { E_STRERR("Connection to forward socket"); return 1; } @@ -63,3 +63,11 @@ int fwd_validate_ctx(const forward_ctx *ctx) return 0; } + +int fwd_connect(forward_ctx *ctx, psocket *fwd_client) +{ + assert(ctx); + socket_clone(&ctx->sock, fwd_client); + + return socket_reconnect_in(fwd_client); +} diff --git a/src/forward.h b/src/forward.h index 4d9c88f..386dc94 100644 --- a/src/forward.h +++ b/src/forward.h @@ -30,4 +30,6 @@ int fwd_setup(forward_ctx *ctx, const char *host, const char *port); int fwd_validate_ctx(const forward_ctx *ctx); +int fwd_connect(forward_ctx *ctx, psocket *fwd_client); + #endif diff --git a/src/server.c b/src/server.c index e1b93cc..31f7e45 100644 --- a/src/server.c +++ b/src/server.c @@ -18,11 +18,17 @@ typedef struct client_thread_args { const server_ctx *server_ctx; } client_thread_args; +typedef enum connection_state { + CON_OK, CON_IN_TERMINATED, CON_OUT_TERMINATED, + CON_IN_ERROR, CON_OUT_ERROR +} connection_state; + static int server_accept_client(server_ctx *ctx[], size_t siz, struct epoll_event *event); static void * client_mainloop_epoll(void *arg); -static int client_io_epoll(struct epoll_event *ev); +static connection_state +client_io_epoll(struct epoll_event *ev, int dest_fd); int server_init_ctx(server_ctx **ctx, forward_ctx *fwd_ctx) @@ -54,7 +60,7 @@ int server_setup(server_ctx *ctx, E_GAIERR(s, "Could not initialise server socket"); return 1; } - if (socket_bind_in(&ctx->sock, srv_addr)) { + if (socket_bind_in(&ctx->sock, &srv_addr)) { E_STRERR("Could not bind server socket"); return 1; } @@ -120,13 +126,14 @@ int server_mainloop_epoll(int epoll_fd, server_ctx *ctx[], size_t siz) assert(ctx); assert(siz > 0 && siz < POTD_MAXFD); + signal(SIGPIPE, SIG_IGN); sigemptyset(&eset); while (1) { int n, i; n = epoll_pwait(epoll_fd, events, POTD_MAXEVENTS, -1, &eset); if (n < 0) - return 1; + goto error; for (i = 0; i < n; ++i) { if ((events[i].events & EPOLLERR) || @@ -149,6 +156,9 @@ int server_mainloop_epoll(int epoll_fd, server_ctx *ctx[], size_t siz) free(events); return 0; +error: + free(events); + return 1; } static int server_accept_client(server_ctx *ctx[], @@ -188,7 +198,7 @@ static int server_accept_client(server_ctx *ctx[], return 1; error: - close(args->client_psock.fd); + socket_close(&args->client_psock); free(args); return 0; } @@ -201,10 +211,12 @@ static void * client_mainloop_epoll(void *arg) { client_thread_args *args; - int s, epoll_fd, active = 1; + int s, epoll_fd, dest_fd, active = 1; struct epoll_event event = {0,{0}}; struct epoll_event *events; sigset_t eset; + connection_state cs; + psocket fwd; assert(arg); args = (client_thread_args *) arg; @@ -218,6 +230,33 @@ client_mainloop_epoll(void *arg) goto finish; } + if (fwd_connect(args->server_ctx->fwd_ctx, &fwd)) { + E("Forward connection to %s:%s failed", + args->server_ctx->fwd_ctx->host_buf, + args->server_ctx->fwd_ctx->service_buf); + E_STRERR("Forward connect"); + goto finish; + } + N("Forwarding connection to %s:%s: %d", args->server_ctx->fwd_ctx->host_buf, + args->server_ctx->fwd_ctx->service_buf, fwd.fd); + + event.data.fd = fwd.fd; + event.events = EPOLLIN | EPOLLET; + s = epoll_ctl(epoll_fd, EPOLL_CTL_ADD, fwd.fd, &event); + if (s) { + E_STRERR("Forward epoll_ctl"); + goto finish; + } + + /* + * We got the client socket from our main thread, so fd flags like + * O_NONBLOCK are not inherited! + */ + s = socket_nonblock(&args->client_psock); + if (s) { + E_STRERR("socket_nonblock"); + goto finish; + } event.data.fd = args->client_psock.fd; event.events = EPOLLIN | EPOLLET; s = epoll_ctl(epoll_fd, EPOLL_CTL_ADD, args->client_psock.fd, &event); @@ -237,21 +276,55 @@ client_mainloop_epoll(void *arg) for (i = 0; i < n; ++i) { if ((events[i].events & EPOLLERR) || (events[i].events & EPOLLHUP) || - (!(events[i].events & EPOLLIN) && - !(events[i].events & EPOLLOUT))) + (!(events[i].events & EPOLLIN))) { E("Epoll for descriptor %d failed", events[i].data.fd); E_STRERR("epoll_pwait"); - close(events[i].data.fd); - continue; + active = 0; + break; } else { - if (client_io_epoll(&events[i])) { - N("Lost connection to %s:%s: %d", - args->host_buf, args->service_buf, - args->client_psock.fd); - active = 0; - break; + if (events[i].data.fd == fwd.fd) { + dest_fd = args->client_psock.fd; + } else if (events[i].data.fd == args->client_psock.fd) { + dest_fd = fwd.fd; + } else continue; + + cs = client_io_epoll(&events[i], dest_fd); + if (cs == CON_OK) + continue; + + switch (cs) { + case CON_OK: + break; + case CON_IN_ERROR: + N("Lost connection to %s:%s: %d", + args->host_buf, args->service_buf, + args->client_psock.fd); + active = 0; + break; + case CON_IN_TERMINATED: + N("Connection terminated: %s:%s: %d", + args->host_buf, args->service_buf, + args->client_psock.fd); + active = 0; + break; + case CON_OUT_ERROR: + N("Lost forward connection to %s:%s: %d", + args->server_ctx->fwd_ctx->host_buf, + args->server_ctx->fwd_ctx->service_buf, + fwd.fd); + active = 0; + break; + case CON_OUT_TERMINATED: + N("Forward connection terminated: %s:%s: %d", + args->server_ctx->fwd_ctx->host_buf, + args->server_ctx->fwd_ctx->service_buf, + fwd.fd); + active = 0; + break; } + if (!active) + break; } W2("I/O forwarder failed: [fd: %d , npoll: %d]", events[i].data.fd, n); } @@ -259,17 +332,20 @@ client_mainloop_epoll(void *arg) finish: close(epoll_fd); - close(args->client_psock.fd); + socket_close(&fwd); + socket_close(&args->client_psock); free(events); free(args); return NULL; } -static int client_io_epoll(struct epoll_event *ev) +static connection_state +client_io_epoll(struct epoll_event *ev, int dest_fd) { int data_avail = 1; int has_input; - int saved_errno, io_fail = 0; + int saved_errno; + connection_state rc = CON_OK; ssize_t siz; char buf[BUFSIZ+sizeof(long)]; @@ -280,44 +356,51 @@ static int client_io_epoll(struct epoll_event *ev) if (ev->events & EPOLLIN) { has_input = 1; + errno = 0; siz = read(ev->data.fd, &buf[0], BUFSIZ); saved_errno = errno; - } else if (ev->events & EPOLLOUT) { - W("%s", "Suffering from buffer bloat"); - continue; - } + } else break; + if (saved_errno == EAGAIN) + break; switch (siz) { case -1: E_STRERR("Client read"); - if (saved_errno != EAGAIN) - io_fail = 1; + rc = CON_IN_ERROR; break; case 0: - printf("DISCONNECT !!!\n"); - io_fail = 1; + rc = CON_IN_TERMINATED; break; default: buf[siz] = 0; + D2("Read %llu bytes from fd %d", siz, ev->data.fd); break; } - if (io_fail) + if (rc != CON_OK) break; if (has_input) { - printf("INPUT: ___%s___\n", buf); - if (strncmp(buf, "QUIT", 4) == 0) - io_fail = 1; - if (strncmp(buf, "TEST", 4) == 0) { - printf("------------\n"); - write(ev->data.fd, "BLABLABLA\n", 10); + siz = write(dest_fd, &buf[0], siz); + + switch (siz) { + case -1: + rc = CON_OUT_ERROR; + break; + case 0: + rc = CON_OUT_TERMINATED; + break; + default: + D2("Written %llu bytes from fd %d to fd %d", + siz, ev->data.fd, dest_fd); + break; } } - if (io_fail) + if (rc != CON_OK) break; } - return io_fail != 0; + D2("Connection state: %d", rc); + return rc; } diff --git a/src/server_ssh.c b/src/server_ssh.c index 8e6a224..a8afeec 100644 --- a/src/server_ssh.c +++ b/src/server_ssh.c @@ -124,17 +124,17 @@ static int gen_default_keys(void) int s = 0; if (gen_export_sshkey(SSH_KEYTYPE_RSA, 1024, "./ssh_host_rsa_key")) { - W("%s", "libssh RSA key generation failed, using fallback ssh-keygen"); + W("libssh %s key generation failed, using fallback ssh-keygen", "RSA"); remove("./ssh_host_rsa_key"); s |= system("ssh-keygen -t rsa -b 1024 -f ./ssh_host_rsa_key -N '' >/dev/null 2>/dev/null"); } if (gen_export_sshkey(SSH_KEYTYPE_DSS, 1024, "./ssh_host_dsa_key")) { - W("%s", "libssh DSA key generation failed, using fallback ssh-keygen"); + W("libssh %s key generation failed, using fallback ssh-keygen", "DSA"); remove("./ssh_host_dsa_key"); s |= system("ssh-keygen -t dsa -b 1024 -f ./ssh_host_dsa_key -N '' >/dev/null 2>/dev/null"); } if (gen_export_sshkey(SSH_KEYTYPE_ECDSA, 1024, "./ssh_host_ecdsa_key")) { - W("%s", "libssh ECDSA key generation failed, using fallback ssh-keygen"); + W("libssh %s key generation failed, using fallback ssh-keygen", "ECDSA"); remove("./ssh_host_ecdsa_key"); s |= system("ssh-keygen -t ecdsa -b 256 -f ./ssh_host_ecdsa_key -N '' >/dev/null 2>/dev/null"); } diff --git a/src/socket.c b/src/socket.c index 7b16254..7fe5cbc 100644 --- a/src/socket.c +++ b/src/socket.c @@ -10,7 +10,7 @@ #include "socket.h" -static int socket_nonblock(const psocket *psock) +int socket_nonblock(const psocket *psock) { int flags; @@ -26,6 +26,7 @@ static int socket_nonblock(const psocket *psock) int socket_init_in(const char *addr, const char *port, struct addrinfo **results) { + int s; struct addrinfo hints; assert(addr || port); /* getaddrinfo wants either node or service */ @@ -35,17 +36,24 @@ int socket_init_in(const char *addr, hints.ai_socktype = SOCK_STREAM; /* TCP */ hints.ai_flags = AI_PASSIVE; /* all interfaces */ - return getaddrinfo(addr, port, &hints, results); + s = getaddrinfo(addr, port, &hints, results); + if (s) { + freeaddrinfo(*results); + *results = NULL; + } + + return s; } -int socket_bind_in(psocket *psock, struct addrinfo *results) +int socket_bind_in(psocket *psock, struct addrinfo **results) { - int fd = -1, rv, reuse_enable = 1; + int s = 1, fd = -1, rv, reuse_enable = 1; struct addrinfo *rp = NULL; - assert(psock && results); + assert(psock && results && *results); + psock->fd = -1; - for (rp = results; rp != NULL; rp = rp->ai_next) { + for (rp = *results; rp != NULL; rp = rp->ai_next) { fd = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); if (fd < 0) continue; @@ -56,7 +64,7 @@ int socket_bind_in(psocket *psock, struct addrinfo *results) } if (!rp) - return -1; + goto finalise; setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &reuse_enable, sizeof(int)); psock->fd = fd; @@ -65,8 +73,13 @@ int socket_bind_in(psocket *psock, struct addrinfo *results) psock->family = rp->ai_family; psock->socktype = rp->ai_socktype; psock->protocol = rp->ai_protocol; - freeaddrinfo(results); - return socket_nonblock(psock); + s = socket_nonblock(psock); + +finalise: + freeaddrinfo(*results); + *results = NULL; + + return s; } int socket_listen_in(psocket *psock) @@ -81,27 +94,31 @@ int socket_accept_in(const psocket *psock, psocket *client_psock) int fd; assert(psock && client_psock); + client_psock->fd = -1; client_psock->addr_len = psock->addr_len; fd = accept(psock->fd, &client_psock->addr, &client_psock->addr_len); if (fd < 0) - return -1; - if (socket_nonblock(psock)) - return -2; + return 1; + if (socket_nonblock(psock)) { + close(fd); + return 1; + } client_psock->fd = fd; return 0; } -int socket_connect_in(psocket *psock, struct addrinfo *results) +int socket_connect_in(psocket *psock, struct addrinfo **results) { - int fd = -1, rv, reuse_enable = 1; + int s = 1, fd = -1, rv, reuse_enable = 1; struct addrinfo *rp = NULL; - assert(psock && results); + assert(psock && results && *results); + psock->fd = -1; - for (rp = results; rp != NULL; rp = rp->ai_next) { + for (rp = *results; rp != NULL; rp = rp->ai_next) { fd = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); if (fd < 0) continue; @@ -112,18 +129,22 @@ int socket_connect_in(psocket *psock, struct addrinfo *results) } if (!rp) - return -1; + goto finalise; setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &reuse_enable, sizeof(int)); psock->fd = fd; psock->addr_len = rp->ai_addrlen; - psock->addr = *rp->ai_addr; + psock->addr = *(rp->ai_addr); psock->family = rp->ai_family; psock->socktype = rp->ai_socktype; psock->protocol = rp->ai_protocol; - freeaddrinfo(results); + s = socket_nonblock(psock); - return socket_nonblock(psock); +finalise: + freeaddrinfo(*results); + *results = NULL; + + return s; } int socket_addrtostr_in(const psocket *psock, @@ -154,7 +175,7 @@ int socket_reconnect_in(psocket *psock) if (psock->fd < 0) return -2; rv = connect(psock->fd, &psock->addr, psock->addr_len); - if (!rv) { + if (rv) { socket_close(psock); return -3; } @@ -164,7 +185,7 @@ int socket_reconnect_in(psocket *psock) return -4; } - return 0; + return socket_nonblock(psock); } int socket_close(psocket *psock) @@ -172,8 +193,18 @@ int socket_close(psocket *psock) int rv; assert(psock); + if (psock->fd < 0) + return 0; rv = close(psock->fd); psock->fd = -1; return rv; } + +void socket_clone(const psocket *src, psocket *dst) +{ + assert(src && dst); + + memcpy(dst, src, sizeof(*dst)); + dst->fd = -1; +} diff --git a/src/socket.h b/src/socket.h index 861a8b3..d2bb160 100644 --- a/src/socket.h +++ b/src/socket.h @@ -15,16 +15,18 @@ typedef struct psocket { } psocket; +int socket_nonblock(const psocket *psock); + int socket_init_in(const char *addr, const char *port, struct addrinfo **results); -int socket_bind_in(psocket *psock, struct addrinfo *results); +int socket_bind_in(psocket *psock, struct addrinfo **results); int socket_listen_in(psocket *psock); int socket_accept_in(const psocket *psock, psocket *client_psock); -int socket_connect_in(psocket *psock, struct addrinfo *results); +int socket_connect_in(psocket *psock, struct addrinfo **results); int socket_addrtostr_in(const psocket *psock, char hbuf[NI_MAXHOST], char sbuf[NI_MAXSERV]); @@ -33,4 +35,6 @@ int socket_reconnect_in(psocket *psock); int socket_close(psocket *psock); +void socket_clone(const psocket *src, psocket *dst); + #endif |