aboutsummaryrefslogtreecommitdiff
path: root/src/server.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/server.c')
-rw-r--r--src/server.c153
1 files changed, 118 insertions, 35 deletions
diff --git a/src/server.c b/src/server.c
index e1b93cc..31f7e45 100644
--- a/src/server.c
+++ b/src/server.c
@@ -18,11 +18,17 @@ typedef struct client_thread_args {
const server_ctx *server_ctx;
} client_thread_args;
+typedef enum connection_state {
+ CON_OK, CON_IN_TERMINATED, CON_OUT_TERMINATED,
+ CON_IN_ERROR, CON_OUT_ERROR
+} connection_state;
+
static int server_accept_client(server_ctx *ctx[],
size_t siz, struct epoll_event *event);
static void *
client_mainloop_epoll(void *arg);
-static int client_io_epoll(struct epoll_event *ev);
+static connection_state
+client_io_epoll(struct epoll_event *ev, int dest_fd);
int server_init_ctx(server_ctx **ctx, forward_ctx *fwd_ctx)
@@ -54,7 +60,7 @@ int server_setup(server_ctx *ctx,
E_GAIERR(s, "Could not initialise server socket");
return 1;
}
- if (socket_bind_in(&ctx->sock, srv_addr)) {
+ if (socket_bind_in(&ctx->sock, &srv_addr)) {
E_STRERR("Could not bind server socket");
return 1;
}
@@ -120,13 +126,14 @@ int server_mainloop_epoll(int epoll_fd, server_ctx *ctx[], size_t siz)
assert(ctx);
assert(siz > 0 && siz < POTD_MAXFD);
+ signal(SIGPIPE, SIG_IGN);
sigemptyset(&eset);
while (1) {
int n, i;
n = epoll_pwait(epoll_fd, events, POTD_MAXEVENTS, -1, &eset);
if (n < 0)
- return 1;
+ goto error;
for (i = 0; i < n; ++i) {
if ((events[i].events & EPOLLERR) ||
@@ -149,6 +156,9 @@ int server_mainloop_epoll(int epoll_fd, server_ctx *ctx[], size_t siz)
free(events);
return 0;
+error:
+ free(events);
+ return 1;
}
static int server_accept_client(server_ctx *ctx[],
@@ -188,7 +198,7 @@ static int server_accept_client(server_ctx *ctx[],
return 1;
error:
- close(args->client_psock.fd);
+ socket_close(&args->client_psock);
free(args);
return 0;
}
@@ -201,10 +211,12 @@ static void *
client_mainloop_epoll(void *arg)
{
client_thread_args *args;
- int s, epoll_fd, active = 1;
+ int s, epoll_fd, dest_fd, active = 1;
struct epoll_event event = {0,{0}};
struct epoll_event *events;
sigset_t eset;
+ connection_state cs;
+ psocket fwd;
assert(arg);
args = (client_thread_args *) arg;
@@ -218,6 +230,33 @@ client_mainloop_epoll(void *arg)
goto finish;
}
+ if (fwd_connect(args->server_ctx->fwd_ctx, &fwd)) {
+ E("Forward connection to %s:%s failed",
+ args->server_ctx->fwd_ctx->host_buf,
+ args->server_ctx->fwd_ctx->service_buf);
+ E_STRERR("Forward connect");
+ goto finish;
+ }
+ N("Forwarding connection to %s:%s: %d", args->server_ctx->fwd_ctx->host_buf,
+ args->server_ctx->fwd_ctx->service_buf, fwd.fd);
+
+ event.data.fd = fwd.fd;
+ event.events = EPOLLIN | EPOLLET;
+ s = epoll_ctl(epoll_fd, EPOLL_CTL_ADD, fwd.fd, &event);
+ if (s) {
+ E_STRERR("Forward epoll_ctl");
+ goto finish;
+ }
+
+ /*
+ * We got the client socket from our main thread, so fd flags like
+ * O_NONBLOCK are not inherited!
+ */
+ s = socket_nonblock(&args->client_psock);
+ if (s) {
+ E_STRERR("socket_nonblock");
+ goto finish;
+ }
event.data.fd = args->client_psock.fd;
event.events = EPOLLIN | EPOLLET;
s = epoll_ctl(epoll_fd, EPOLL_CTL_ADD, args->client_psock.fd, &event);
@@ -237,21 +276,55 @@ client_mainloop_epoll(void *arg)
for (i = 0; i < n; ++i) {
if ((events[i].events & EPOLLERR) ||
(events[i].events & EPOLLHUP) ||
- (!(events[i].events & EPOLLIN) &&
- !(events[i].events & EPOLLOUT)))
+ (!(events[i].events & EPOLLIN)))
{
E("Epoll for descriptor %d failed", events[i].data.fd);
E_STRERR("epoll_pwait");
- close(events[i].data.fd);
- continue;
+ active = 0;
+ break;
} else {
- if (client_io_epoll(&events[i])) {
- N("Lost connection to %s:%s: %d",
- args->host_buf, args->service_buf,
- args->client_psock.fd);
- active = 0;
- break;
+ if (events[i].data.fd == fwd.fd) {
+ dest_fd = args->client_psock.fd;
+ } else if (events[i].data.fd == args->client_psock.fd) {
+ dest_fd = fwd.fd;
+ } else continue;
+
+ cs = client_io_epoll(&events[i], dest_fd);
+ if (cs == CON_OK)
+ continue;
+
+ switch (cs) {
+ case CON_OK:
+ break;
+ case CON_IN_ERROR:
+ N("Lost connection to %s:%s: %d",
+ args->host_buf, args->service_buf,
+ args->client_psock.fd);
+ active = 0;
+ break;
+ case CON_IN_TERMINATED:
+ N("Connection terminated: %s:%s: %d",
+ args->host_buf, args->service_buf,
+ args->client_psock.fd);
+ active = 0;
+ break;
+ case CON_OUT_ERROR:
+ N("Lost forward connection to %s:%s: %d",
+ args->server_ctx->fwd_ctx->host_buf,
+ args->server_ctx->fwd_ctx->service_buf,
+ fwd.fd);
+ active = 0;
+ break;
+ case CON_OUT_TERMINATED:
+ N("Forward connection terminated: %s:%s: %d",
+ args->server_ctx->fwd_ctx->host_buf,
+ args->server_ctx->fwd_ctx->service_buf,
+ fwd.fd);
+ active = 0;
+ break;
}
+ if (!active)
+ break;
}
W2("I/O forwarder failed: [fd: %d , npoll: %d]", events[i].data.fd, n);
}
@@ -259,17 +332,20 @@ client_mainloop_epoll(void *arg)
finish:
close(epoll_fd);
- close(args->client_psock.fd);
+ socket_close(&fwd);
+ socket_close(&args->client_psock);
free(events);
free(args);
return NULL;
}
-static int client_io_epoll(struct epoll_event *ev)
+static connection_state
+client_io_epoll(struct epoll_event *ev, int dest_fd)
{
int data_avail = 1;
int has_input;
- int saved_errno, io_fail = 0;
+ int saved_errno;
+ connection_state rc = CON_OK;
ssize_t siz;
char buf[BUFSIZ+sizeof(long)];
@@ -280,44 +356,51 @@ static int client_io_epoll(struct epoll_event *ev)
if (ev->events & EPOLLIN) {
has_input = 1;
+ errno = 0;
siz = read(ev->data.fd, &buf[0], BUFSIZ);
saved_errno = errno;
- } else if (ev->events & EPOLLOUT) {
- W("%s", "Suffering from buffer bloat");
- continue;
- }
+ } else break;
+ if (saved_errno == EAGAIN)
+ break;
switch (siz) {
case -1:
E_STRERR("Client read");
- if (saved_errno != EAGAIN)
- io_fail = 1;
+ rc = CON_IN_ERROR;
break;
case 0:
- printf("DISCONNECT !!!\n");
- io_fail = 1;
+ rc = CON_IN_TERMINATED;
break;
default:
buf[siz] = 0;
+ D2("Read %llu bytes from fd %d", siz, ev->data.fd);
break;
}
- if (io_fail)
+ if (rc != CON_OK)
break;
if (has_input) {
- printf("INPUT: ___%s___\n", buf);
- if (strncmp(buf, "QUIT", 4) == 0)
- io_fail = 1;
- if (strncmp(buf, "TEST", 4) == 0) {
- printf("------------\n");
- write(ev->data.fd, "BLABLABLA\n", 10);
+ siz = write(dest_fd, &buf[0], siz);
+
+ switch (siz) {
+ case -1:
+ rc = CON_OUT_ERROR;
+ break;
+ case 0:
+ rc = CON_OUT_TERMINATED;
+ break;
+ default:
+ D2("Written %llu bytes from fd %d to fd %d",
+ siz, ev->data.fd, dest_fd);
+ break;
}
}
- if (io_fail)
+ if (rc != CON_OK)
break;
}
- return io_fail != 0;
+ D2("Connection state: %d", rc);
+ return rc;
}