#include #include #include #include #include #include #include #include #include "server.h" #include "socket.h" #include "utils.h" #include "log.h" typedef struct client_thread_args { pthread_t self; psocket client_psock; char host_buf[NI_MAXHOST], service_buf[NI_MAXSERV]; const server_ctx *server_ctx; } client_thread_args; typedef enum connection_state { CON_OK, CON_IN_TERMINATED, CON_OUT_TERMINATED, CON_IN_ERROR, CON_OUT_ERROR } connection_state; static int server_accept_client(server_ctx *ctx[], size_t siz, struct epoll_event *event); static void * client_mainloop_epoll(void *arg); static connection_state client_io_epoll(struct epoll_event *ev, int dest_fd); static int server_mainloop_epoll(int epoll_fd, server_ctx *ctx[], size_t siz); void server_init_ctx(server_ctx **ctx, forward_ctx *fwd_ctx) { assert(ctx && fwd_ctx); if (!*ctx) *ctx = (server_ctx *) malloc(sizeof(**ctx)); assert(*ctx); memset(*ctx, 0, sizeof(**ctx)); (*ctx)->fwd_ctx = fwd_ctx; } int server_setup(server_ctx *ctx, const char *listen_addr, const char *listen_port) { int s; struct addrinfo *srv_addr = NULL; assert(ctx); assert(listen_addr || listen_port); D2("Try to listen on %s:%s", (listen_addr ? listen_addr : "*"), listen_port); s = socket_init_in(listen_addr, listen_port, &srv_addr); if (s) { E_GAIERR(s, "Could not initialise server socket"); return 1; } if (socket_bind_in(&ctx->sock, &srv_addr)) { E_STRERR("Could not bind server socket"); return 1; } if (socket_listen_in(&ctx->sock)) { E_STRERR("Could not listen on server socket"); return 1; } return 0; } int server_validate_ctx(const server_ctx *ctx) { assert(ctx && ctx->fwd_ctx); assert(ctx->sock.fd >= 0 && ctx->sock.addr_len > 0); return 0; } int server_setup_epoll(server_ctx *ctx[], size_t siz) { int s, fd = epoll_create1(0); /* flags == 0 -> obsolete size arg is dropped */ struct epoll_event ev; assert(ctx); assert(siz > 0 && siz < POTD_MAXFD); if (fd < 0) return -1; for (size_t i = 0; i < siz; ++i) { memset(&ev, 0, sizeof(ev)); ev.data.fd = ctx[i]->sock.fd; ev.events = EPOLLIN | EPOLLET; s = socket_addrtostr_in(&ctx[i]->sock, ctx[i]->host_buf, ctx[i]->service_buf); if (s) { E_GAIERR(s, "Convert socket address to string"); return -2; } N("Redirector service listening on %s:%s: %d", ctx[i]->host_buf, ctx[i]->service_buf, ev.data.fd); s = epoll_ctl(fd, EPOLL_CTL_ADD, ctx[i]->sock.fd, &ev); if (s) { close(fd); return -3; } } return fd; } pid_t server_daemonize(int epoll_fd, server_ctx *ctx[], size_t siz) { pid_t p; int s; size_t i; assert(ctx); assert(siz > 0); for (i = 0; i < siz; ++i) { assert(ctx[i]); s = socket_addrtostr_in(&ctx[i]->sock, ctx[i]->host_buf, ctx[i]->service_buf); if (s) { E_GAIERR(s, "Could not initialise server daemon socket"); return 1; } } p = fork(); switch (p) { case -1: W_STRERR("Server daemonsize"); return -1; case 0: N("%s", "Server daemon mainloop"); server_mainloop_epoll(epoll_fd, ctx, siz); break; } D2("Server daemon pid: %d", p); return p; } static int server_mainloop_epoll(int epoll_fd, server_ctx *ctx[], size_t siz) { static struct epoll_event *events = NULL; sigset_t eset; if (!events) events = (struct epoll_event *) calloc(POTD_MAXEVENTS, sizeof(*events)); assert(events); assert(ctx); assert(siz > 0 && siz < POTD_MAXFD); set_procname("[potd] server"); assert( set_child_sighandler() == 0 ); sigemptyset(&eset); while (1) { int n, i; n = epoll_pwait(epoll_fd, events, POTD_MAXEVENTS, -1, &eset); if (n < 0) goto error; for (i = 0; i < n; ++i) { if ((events[i].events & EPOLLERR) || (events[i].events & EPOLLHUP) || (!(events[i].events & EPOLLIN))) { E("Epoll for descriptor %d failed", events[i].data.fd); E_STRERR("epoll_wait"); close(events[i].data.fd); continue; } else { if (server_accept_client(ctx, siz, &events[i])) { /* new client connection, accept succeeded */ continue; } W2("Server accept client failed: [fd: %d , npoll: %d]", events[i].data.fd, n); } } } return 0; error: return 1; } static int server_accept_client(server_ctx *ctx[], size_t siz, struct epoll_event *event) { size_t i; int s; client_thread_args *args; for (i = 0; i < siz; ++i) { if (ctx[i]->sock.fd == event->data.fd) { args = (client_thread_args *) calloc(1, sizeof(*args)); assert(args); if (socket_accept_in(&ctx[i]->sock, &args->client_psock)) { E_STRERR("Could not accept client connection"); goto error; } args->server_ctx = ctx[i]; s = socket_addrtostr_in(&args->client_psock, args->host_buf, args->service_buf); if (s) { E_GAIERR(s, "Convert socket address to string"); goto error; } N2("New connection from %s:%s to %s:%s: %d", args->host_buf, args->service_buf, ctx[i]->host_buf, ctx[i]->service_buf, args->client_psock.fd); if (pthread_create(&args->self, NULL, client_mainloop_epoll, args)) { E_STRERR("Thread creation"); goto error; } return 1; error: socket_close(&args->client_psock); free(args); return 0; } } return 0; } static void * client_mainloop_epoll(void *arg) { client_thread_args *args; int s, epoll_fd, dest_fd, active = 1; struct epoll_event event = {0,{0}}; struct epoll_event *events; sigset_t eset; connection_state cs; psocket fwd; assert(arg); args = (client_thread_args *) arg; pthread_detach(args->self); events = (struct epoll_event *) calloc(POTD_MAXEVENTS, sizeof(*events)); assert(events); epoll_fd = epoll_create1(0); if (epoll_fd < 0) { E_STRERR("Client epoll_create1"); goto finish; } if (fwd_connect(args->server_ctx->fwd_ctx, &fwd)) { E("Forward connection to %s:%s failed", args->server_ctx->fwd_ctx->host_buf, args->server_ctx->fwd_ctx->service_buf); E_STRERR("Forward connect"); goto finish; } N("Forwarding connection to %s:%s: %d", args->server_ctx->fwd_ctx->host_buf, args->server_ctx->fwd_ctx->service_buf, fwd.fd); event.data.fd = fwd.fd; event.events = EPOLLIN | EPOLLET; s = epoll_ctl(epoll_fd, EPOLL_CTL_ADD, fwd.fd, &event); if (s) { E_STRERR("Forward epoll_ctl"); goto finish; } /* * We got the client socket from our main thread, so fd flags like * O_NONBLOCK are not inherited! */ s = socket_nonblock(&args->client_psock); if (s) { E_STRERR("socket_nonblock"); goto finish; } event.data.fd = args->client_psock.fd; event.events = EPOLLIN | EPOLLET; s = epoll_ctl(epoll_fd, EPOLL_CTL_ADD, args->client_psock.fd, &event); if (s) { E_STRERR("Client epoll_ctl"); goto finish; } sigemptyset(&eset); while (active) { int n, i; n = epoll_pwait(epoll_fd, events, POTD_MAXEVENTS, -1, &eset); if (n < 0) break; for (i = 0; i < n; ++i) { if ((events[i].events & EPOLLERR) || (events[i].events & EPOLLHUP) || (!(events[i].events & EPOLLIN))) { E("Epoll for descriptor %d failed", events[i].data.fd); E_STRERR("epoll_pwait"); active = 0; break; } else { if (events[i].data.fd == fwd.fd) { dest_fd = args->client_psock.fd; } else if (events[i].data.fd == args->client_psock.fd) { dest_fd = fwd.fd; } else continue; cs = client_io_epoll(&events[i], dest_fd); if (cs == CON_OK) continue; switch (cs) { case CON_OK: break; case CON_IN_ERROR: N("Lost connection to %s:%s: %d", args->host_buf, args->service_buf, args->client_psock.fd); active = 0; break; case CON_IN_TERMINATED: N("Connection terminated: %s:%s: %d", args->host_buf, args->service_buf, args->client_psock.fd); active = 0; break; case CON_OUT_ERROR: N("Lost forward connection to %s:%s: %d", args->server_ctx->fwd_ctx->host_buf, args->server_ctx->fwd_ctx->service_buf, fwd.fd); active = 0; break; case CON_OUT_TERMINATED: N("Forward connection terminated: %s:%s: %d", args->server_ctx->fwd_ctx->host_buf, args->server_ctx->fwd_ctx->service_buf, fwd.fd); active = 0; break; } if (!active) break; } W2("I/O forwarder failed: [fd: %d , npoll: %d]", events[i].data.fd, n); } } finish: close(epoll_fd); socket_close(&fwd); socket_close(&args->client_psock); free(events); free(args); return NULL; } static connection_state client_io_epoll(struct epoll_event *ev, int dest_fd) { int data_avail = 1; int has_input; int saved_errno; connection_state rc = CON_OK; ssize_t siz; char buf[BUFSIZ+sizeof(long)]; while (data_avail) { has_input = 0; saved_errno = 0; siz = -1; if (ev->events & EPOLLIN) { has_input = 1; errno = 0; siz = read(ev->data.fd, &buf[0], BUFSIZ); saved_errno = errno; } else break; if (saved_errno == EAGAIN) break; switch (siz) { case -1: E_STRERR("Client read"); rc = CON_IN_ERROR; break; case 0: rc = CON_IN_TERMINATED; break; default: buf[siz] = 0; D2("Read %lu bytes from fd %d", siz, ev->data.fd); break; } if (rc != CON_OK) break; if (has_input) { siz = write(dest_fd, &buf[0], siz); switch (siz) { case -1: rc = CON_OUT_ERROR; break; case 0: rc = CON_OUT_TERMINATED; break; default: D2("Written %lu bytes from fd %d to fd %d", siz, ev->data.fd, dest_fd); break; } } if (rc != CON_OK) break; } D2("Connection state: %d", rc); return rc; }