diff options
Diffstat (limited to 'src/server.c')
-rw-r--r-- | src/server.c | 153 |
1 files changed, 118 insertions, 35 deletions
diff --git a/src/server.c b/src/server.c index e1b93cc..31f7e45 100644 --- a/src/server.c +++ b/src/server.c @@ -18,11 +18,17 @@ typedef struct client_thread_args { const server_ctx *server_ctx; } client_thread_args; +typedef enum connection_state { + CON_OK, CON_IN_TERMINATED, CON_OUT_TERMINATED, + CON_IN_ERROR, CON_OUT_ERROR +} connection_state; + static int server_accept_client(server_ctx *ctx[], size_t siz, struct epoll_event *event); static void * client_mainloop_epoll(void *arg); -static int client_io_epoll(struct epoll_event *ev); +static connection_state +client_io_epoll(struct epoll_event *ev, int dest_fd); int server_init_ctx(server_ctx **ctx, forward_ctx *fwd_ctx) @@ -54,7 +60,7 @@ int server_setup(server_ctx *ctx, E_GAIERR(s, "Could not initialise server socket"); return 1; } - if (socket_bind_in(&ctx->sock, srv_addr)) { + if (socket_bind_in(&ctx->sock, &srv_addr)) { E_STRERR("Could not bind server socket"); return 1; } @@ -120,13 +126,14 @@ int server_mainloop_epoll(int epoll_fd, server_ctx *ctx[], size_t siz) assert(ctx); assert(siz > 0 && siz < POTD_MAXFD); + signal(SIGPIPE, SIG_IGN); sigemptyset(&eset); while (1) { int n, i; n = epoll_pwait(epoll_fd, events, POTD_MAXEVENTS, -1, &eset); if (n < 0) - return 1; + goto error; for (i = 0; i < n; ++i) { if ((events[i].events & EPOLLERR) || @@ -149,6 +156,9 @@ int server_mainloop_epoll(int epoll_fd, server_ctx *ctx[], size_t siz) free(events); return 0; +error: + free(events); + return 1; } static int server_accept_client(server_ctx *ctx[], @@ -188,7 +198,7 @@ static int server_accept_client(server_ctx *ctx[], return 1; error: - close(args->client_psock.fd); + socket_close(&args->client_psock); free(args); return 0; } @@ -201,10 +211,12 @@ static void * client_mainloop_epoll(void *arg) { client_thread_args *args; - int s, epoll_fd, active = 1; + int s, epoll_fd, dest_fd, active = 1; struct epoll_event event = {0,{0}}; struct epoll_event *events; sigset_t eset; + connection_state cs; + psocket fwd; assert(arg); args = (client_thread_args *) arg; @@ -218,6 +230,33 @@ client_mainloop_epoll(void *arg) goto finish; } + if (fwd_connect(args->server_ctx->fwd_ctx, &fwd)) { + E("Forward connection to %s:%s failed", + args->server_ctx->fwd_ctx->host_buf, + args->server_ctx->fwd_ctx->service_buf); + E_STRERR("Forward connect"); + goto finish; + } + N("Forwarding connection to %s:%s: %d", args->server_ctx->fwd_ctx->host_buf, + args->server_ctx->fwd_ctx->service_buf, fwd.fd); + + event.data.fd = fwd.fd; + event.events = EPOLLIN | EPOLLET; + s = epoll_ctl(epoll_fd, EPOLL_CTL_ADD, fwd.fd, &event); + if (s) { + E_STRERR("Forward epoll_ctl"); + goto finish; + } + + /* + * We got the client socket from our main thread, so fd flags like + * O_NONBLOCK are not inherited! + */ + s = socket_nonblock(&args->client_psock); + if (s) { + E_STRERR("socket_nonblock"); + goto finish; + } event.data.fd = args->client_psock.fd; event.events = EPOLLIN | EPOLLET; s = epoll_ctl(epoll_fd, EPOLL_CTL_ADD, args->client_psock.fd, &event); @@ -237,21 +276,55 @@ client_mainloop_epoll(void *arg) for (i = 0; i < n; ++i) { if ((events[i].events & EPOLLERR) || (events[i].events & EPOLLHUP) || - (!(events[i].events & EPOLLIN) && - !(events[i].events & EPOLLOUT))) + (!(events[i].events & EPOLLIN))) { E("Epoll for descriptor %d failed", events[i].data.fd); E_STRERR("epoll_pwait"); - close(events[i].data.fd); - continue; + active = 0; + break; } else { - if (client_io_epoll(&events[i])) { - N("Lost connection to %s:%s: %d", - args->host_buf, args->service_buf, - args->client_psock.fd); - active = 0; - break; + if (events[i].data.fd == fwd.fd) { + dest_fd = args->client_psock.fd; + } else if (events[i].data.fd == args->client_psock.fd) { + dest_fd = fwd.fd; + } else continue; + + cs = client_io_epoll(&events[i], dest_fd); + if (cs == CON_OK) + continue; + + switch (cs) { + case CON_OK: + break; + case CON_IN_ERROR: + N("Lost connection to %s:%s: %d", + args->host_buf, args->service_buf, + args->client_psock.fd); + active = 0; + break; + case CON_IN_TERMINATED: + N("Connection terminated: %s:%s: %d", + args->host_buf, args->service_buf, + args->client_psock.fd); + active = 0; + break; + case CON_OUT_ERROR: + N("Lost forward connection to %s:%s: %d", + args->server_ctx->fwd_ctx->host_buf, + args->server_ctx->fwd_ctx->service_buf, + fwd.fd); + active = 0; + break; + case CON_OUT_TERMINATED: + N("Forward connection terminated: %s:%s: %d", + args->server_ctx->fwd_ctx->host_buf, + args->server_ctx->fwd_ctx->service_buf, + fwd.fd); + active = 0; + break; } + if (!active) + break; } W2("I/O forwarder failed: [fd: %d , npoll: %d]", events[i].data.fd, n); } @@ -259,17 +332,20 @@ client_mainloop_epoll(void *arg) finish: close(epoll_fd); - close(args->client_psock.fd); + socket_close(&fwd); + socket_close(&args->client_psock); free(events); free(args); return NULL; } -static int client_io_epoll(struct epoll_event *ev) +static connection_state +client_io_epoll(struct epoll_event *ev, int dest_fd) { int data_avail = 1; int has_input; - int saved_errno, io_fail = 0; + int saved_errno; + connection_state rc = CON_OK; ssize_t siz; char buf[BUFSIZ+sizeof(long)]; @@ -280,44 +356,51 @@ static int client_io_epoll(struct epoll_event *ev) if (ev->events & EPOLLIN) { has_input = 1; + errno = 0; siz = read(ev->data.fd, &buf[0], BUFSIZ); saved_errno = errno; - } else if (ev->events & EPOLLOUT) { - W("%s", "Suffering from buffer bloat"); - continue; - } + } else break; + if (saved_errno == EAGAIN) + break; switch (siz) { case -1: E_STRERR("Client read"); - if (saved_errno != EAGAIN) - io_fail = 1; + rc = CON_IN_ERROR; break; case 0: - printf("DISCONNECT !!!\n"); - io_fail = 1; + rc = CON_IN_TERMINATED; break; default: buf[siz] = 0; + D2("Read %llu bytes from fd %d", siz, ev->data.fd); break; } - if (io_fail) + if (rc != CON_OK) break; if (has_input) { - printf("INPUT: ___%s___\n", buf); - if (strncmp(buf, "QUIT", 4) == 0) - io_fail = 1; - if (strncmp(buf, "TEST", 4) == 0) { - printf("------------\n"); - write(ev->data.fd, "BLABLABLA\n", 10); + siz = write(dest_fd, &buf[0], siz); + + switch (siz) { + case -1: + rc = CON_OUT_ERROR; + break; + case 0: + rc = CON_OUT_TERMINATED; + break; + default: + D2("Written %llu bytes from fd %d to fd %d", + siz, ev->data.fd, dest_fd); + break; } } - if (io_fail) + if (rc != CON_OK) break; } - return io_fail != 0; + D2("Connection state: %d", rc); + return rc; } |