aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/forward.c10
-rw-r--r--src/forward.h2
-rw-r--r--src/server.c153
-rw-r--r--src/server_ssh.c6
-rw-r--r--src/socket.c75
-rw-r--r--src/socket.h8
6 files changed, 191 insertions, 63 deletions
diff --git a/src/forward.c b/src/forward.c
index 54b64fe..f12a6f8 100644
--- a/src/forward.c
+++ b/src/forward.c
@@ -36,7 +36,7 @@ int fwd_setup(forward_ctx *ctx, const char *host, const char *port)
return 1;
if (ctx->fwd_cbs.on_listen(ctx, host, port))
return 1;
- if (socket_connect_in(&ctx->sock, fwd_addr)) {
+ if (socket_connect_in(&ctx->sock, &fwd_addr)) {
E_STRERR("Connection to forward socket");
return 1;
}
@@ -63,3 +63,11 @@ int fwd_validate_ctx(const forward_ctx *ctx)
return 0;
}
+
+int fwd_connect(forward_ctx *ctx, psocket *fwd_client)
+{
+ assert(ctx);
+ socket_clone(&ctx->sock, fwd_client);
+
+ return socket_reconnect_in(fwd_client);
+}
diff --git a/src/forward.h b/src/forward.h
index 4d9c88f..386dc94 100644
--- a/src/forward.h
+++ b/src/forward.h
@@ -30,4 +30,6 @@ int fwd_setup(forward_ctx *ctx, const char *host, const char *port);
int fwd_validate_ctx(const forward_ctx *ctx);
+int fwd_connect(forward_ctx *ctx, psocket *fwd_client);
+
#endif
diff --git a/src/server.c b/src/server.c
index e1b93cc..31f7e45 100644
--- a/src/server.c
+++ b/src/server.c
@@ -18,11 +18,17 @@ typedef struct client_thread_args {
const server_ctx *server_ctx;
} client_thread_args;
+typedef enum connection_state {
+ CON_OK, CON_IN_TERMINATED, CON_OUT_TERMINATED,
+ CON_IN_ERROR, CON_OUT_ERROR
+} connection_state;
+
static int server_accept_client(server_ctx *ctx[],
size_t siz, struct epoll_event *event);
static void *
client_mainloop_epoll(void *arg);
-static int client_io_epoll(struct epoll_event *ev);
+static connection_state
+client_io_epoll(struct epoll_event *ev, int dest_fd);
int server_init_ctx(server_ctx **ctx, forward_ctx *fwd_ctx)
@@ -54,7 +60,7 @@ int server_setup(server_ctx *ctx,
E_GAIERR(s, "Could not initialise server socket");
return 1;
}
- if (socket_bind_in(&ctx->sock, srv_addr)) {
+ if (socket_bind_in(&ctx->sock, &srv_addr)) {
E_STRERR("Could not bind server socket");
return 1;
}
@@ -120,13 +126,14 @@ int server_mainloop_epoll(int epoll_fd, server_ctx *ctx[], size_t siz)
assert(ctx);
assert(siz > 0 && siz < POTD_MAXFD);
+ signal(SIGPIPE, SIG_IGN);
sigemptyset(&eset);
while (1) {
int n, i;
n = epoll_pwait(epoll_fd, events, POTD_MAXEVENTS, -1, &eset);
if (n < 0)
- return 1;
+ goto error;
for (i = 0; i < n; ++i) {
if ((events[i].events & EPOLLERR) ||
@@ -149,6 +156,9 @@ int server_mainloop_epoll(int epoll_fd, server_ctx *ctx[], size_t siz)
free(events);
return 0;
+error:
+ free(events);
+ return 1;
}
static int server_accept_client(server_ctx *ctx[],
@@ -188,7 +198,7 @@ static int server_accept_client(server_ctx *ctx[],
return 1;
error:
- close(args->client_psock.fd);
+ socket_close(&args->client_psock);
free(args);
return 0;
}
@@ -201,10 +211,12 @@ static void *
client_mainloop_epoll(void *arg)
{
client_thread_args *args;
- int s, epoll_fd, active = 1;
+ int s, epoll_fd, dest_fd, active = 1;
struct epoll_event event = {0,{0}};
struct epoll_event *events;
sigset_t eset;
+ connection_state cs;
+ psocket fwd;
assert(arg);
args = (client_thread_args *) arg;
@@ -218,6 +230,33 @@ client_mainloop_epoll(void *arg)
goto finish;
}
+ if (fwd_connect(args->server_ctx->fwd_ctx, &fwd)) {
+ E("Forward connection to %s:%s failed",
+ args->server_ctx->fwd_ctx->host_buf,
+ args->server_ctx->fwd_ctx->service_buf);
+ E_STRERR("Forward connect");
+ goto finish;
+ }
+ N("Forwarding connection to %s:%s: %d", args->server_ctx->fwd_ctx->host_buf,
+ args->server_ctx->fwd_ctx->service_buf, fwd.fd);
+
+ event.data.fd = fwd.fd;
+ event.events = EPOLLIN | EPOLLET;
+ s = epoll_ctl(epoll_fd, EPOLL_CTL_ADD, fwd.fd, &event);
+ if (s) {
+ E_STRERR("Forward epoll_ctl");
+ goto finish;
+ }
+
+ /*
+ * We got the client socket from our main thread, so fd flags like
+ * O_NONBLOCK are not inherited!
+ */
+ s = socket_nonblock(&args->client_psock);
+ if (s) {
+ E_STRERR("socket_nonblock");
+ goto finish;
+ }
event.data.fd = args->client_psock.fd;
event.events = EPOLLIN | EPOLLET;
s = epoll_ctl(epoll_fd, EPOLL_CTL_ADD, args->client_psock.fd, &event);
@@ -237,21 +276,55 @@ client_mainloop_epoll(void *arg)
for (i = 0; i < n; ++i) {
if ((events[i].events & EPOLLERR) ||
(events[i].events & EPOLLHUP) ||
- (!(events[i].events & EPOLLIN) &&
- !(events[i].events & EPOLLOUT)))
+ (!(events[i].events & EPOLLIN)))
{
E("Epoll for descriptor %d failed", events[i].data.fd);
E_STRERR("epoll_pwait");
- close(events[i].data.fd);
- continue;
+ active = 0;
+ break;
} else {
- if (client_io_epoll(&events[i])) {
- N("Lost connection to %s:%s: %d",
- args->host_buf, args->service_buf,
- args->client_psock.fd);
- active = 0;
- break;
+ if (events[i].data.fd == fwd.fd) {
+ dest_fd = args->client_psock.fd;
+ } else if (events[i].data.fd == args->client_psock.fd) {
+ dest_fd = fwd.fd;
+ } else continue;
+
+ cs = client_io_epoll(&events[i], dest_fd);
+ if (cs == CON_OK)
+ continue;
+
+ switch (cs) {
+ case CON_OK:
+ break;
+ case CON_IN_ERROR:
+ N("Lost connection to %s:%s: %d",
+ args->host_buf, args->service_buf,
+ args->client_psock.fd);
+ active = 0;
+ break;
+ case CON_IN_TERMINATED:
+ N("Connection terminated: %s:%s: %d",
+ args->host_buf, args->service_buf,
+ args->client_psock.fd);
+ active = 0;
+ break;
+ case CON_OUT_ERROR:
+ N("Lost forward connection to %s:%s: %d",
+ args->server_ctx->fwd_ctx->host_buf,
+ args->server_ctx->fwd_ctx->service_buf,
+ fwd.fd);
+ active = 0;
+ break;
+ case CON_OUT_TERMINATED:
+ N("Forward connection terminated: %s:%s: %d",
+ args->server_ctx->fwd_ctx->host_buf,
+ args->server_ctx->fwd_ctx->service_buf,
+ fwd.fd);
+ active = 0;
+ break;
}
+ if (!active)
+ break;
}
W2("I/O forwarder failed: [fd: %d , npoll: %d]", events[i].data.fd, n);
}
@@ -259,17 +332,20 @@ client_mainloop_epoll(void *arg)
finish:
close(epoll_fd);
- close(args->client_psock.fd);
+ socket_close(&fwd);
+ socket_close(&args->client_psock);
free(events);
free(args);
return NULL;
}
-static int client_io_epoll(struct epoll_event *ev)
+static connection_state
+client_io_epoll(struct epoll_event *ev, int dest_fd)
{
int data_avail = 1;
int has_input;
- int saved_errno, io_fail = 0;
+ int saved_errno;
+ connection_state rc = CON_OK;
ssize_t siz;
char buf[BUFSIZ+sizeof(long)];
@@ -280,44 +356,51 @@ static int client_io_epoll(struct epoll_event *ev)
if (ev->events & EPOLLIN) {
has_input = 1;
+ errno = 0;
siz = read(ev->data.fd, &buf[0], BUFSIZ);
saved_errno = errno;
- } else if (ev->events & EPOLLOUT) {
- W("%s", "Suffering from buffer bloat");
- continue;
- }
+ } else break;
+ if (saved_errno == EAGAIN)
+ break;
switch (siz) {
case -1:
E_STRERR("Client read");
- if (saved_errno != EAGAIN)
- io_fail = 1;
+ rc = CON_IN_ERROR;
break;
case 0:
- printf("DISCONNECT !!!\n");
- io_fail = 1;
+ rc = CON_IN_TERMINATED;
break;
default:
buf[siz] = 0;
+ D2("Read %llu bytes from fd %d", siz, ev->data.fd);
break;
}
- if (io_fail)
+ if (rc != CON_OK)
break;
if (has_input) {
- printf("INPUT: ___%s___\n", buf);
- if (strncmp(buf, "QUIT", 4) == 0)
- io_fail = 1;
- if (strncmp(buf, "TEST", 4) == 0) {
- printf("------------\n");
- write(ev->data.fd, "BLABLABLA\n", 10);
+ siz = write(dest_fd, &buf[0], siz);
+
+ switch (siz) {
+ case -1:
+ rc = CON_OUT_ERROR;
+ break;
+ case 0:
+ rc = CON_OUT_TERMINATED;
+ break;
+ default:
+ D2("Written %llu bytes from fd %d to fd %d",
+ siz, ev->data.fd, dest_fd);
+ break;
}
}
- if (io_fail)
+ if (rc != CON_OK)
break;
}
- return io_fail != 0;
+ D2("Connection state: %d", rc);
+ return rc;
}
diff --git a/src/server_ssh.c b/src/server_ssh.c
index 8e6a224..a8afeec 100644
--- a/src/server_ssh.c
+++ b/src/server_ssh.c
@@ -124,17 +124,17 @@ static int gen_default_keys(void)
int s = 0;
if (gen_export_sshkey(SSH_KEYTYPE_RSA, 1024, "./ssh_host_rsa_key")) {
- W("%s", "libssh RSA key generation failed, using fallback ssh-keygen");
+ W("libssh %s key generation failed, using fallback ssh-keygen", "RSA");
remove("./ssh_host_rsa_key");
s |= system("ssh-keygen -t rsa -b 1024 -f ./ssh_host_rsa_key -N '' >/dev/null 2>/dev/null");
}
if (gen_export_sshkey(SSH_KEYTYPE_DSS, 1024, "./ssh_host_dsa_key")) {
- W("%s", "libssh DSA key generation failed, using fallback ssh-keygen");
+ W("libssh %s key generation failed, using fallback ssh-keygen", "DSA");
remove("./ssh_host_dsa_key");
s |= system("ssh-keygen -t dsa -b 1024 -f ./ssh_host_dsa_key -N '' >/dev/null 2>/dev/null");
}
if (gen_export_sshkey(SSH_KEYTYPE_ECDSA, 1024, "./ssh_host_ecdsa_key")) {
- W("%s", "libssh ECDSA key generation failed, using fallback ssh-keygen");
+ W("libssh %s key generation failed, using fallback ssh-keygen", "ECDSA");
remove("./ssh_host_ecdsa_key");
s |= system("ssh-keygen -t ecdsa -b 256 -f ./ssh_host_ecdsa_key -N '' >/dev/null 2>/dev/null");
}
diff --git a/src/socket.c b/src/socket.c
index 7b16254..7fe5cbc 100644
--- a/src/socket.c
+++ b/src/socket.c
@@ -10,7 +10,7 @@
#include "socket.h"
-static int socket_nonblock(const psocket *psock)
+int socket_nonblock(const psocket *psock)
{
int flags;
@@ -26,6 +26,7 @@ static int socket_nonblock(const psocket *psock)
int socket_init_in(const char *addr,
const char *port, struct addrinfo **results)
{
+ int s;
struct addrinfo hints;
assert(addr || port); /* getaddrinfo wants either node or service */
@@ -35,17 +36,24 @@ int socket_init_in(const char *addr,
hints.ai_socktype = SOCK_STREAM; /* TCP */
hints.ai_flags = AI_PASSIVE; /* all interfaces */
- return getaddrinfo(addr, port, &hints, results);
+ s = getaddrinfo(addr, port, &hints, results);
+ if (s) {
+ freeaddrinfo(*results);
+ *results = NULL;
+ }
+
+ return s;
}
-int socket_bind_in(psocket *psock, struct addrinfo *results)
+int socket_bind_in(psocket *psock, struct addrinfo **results)
{
- int fd = -1, rv, reuse_enable = 1;
+ int s = 1, fd = -1, rv, reuse_enable = 1;
struct addrinfo *rp = NULL;
- assert(psock && results);
+ assert(psock && results && *results);
+ psock->fd = -1;
- for (rp = results; rp != NULL; rp = rp->ai_next) {
+ for (rp = *results; rp != NULL; rp = rp->ai_next) {
fd = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol);
if (fd < 0)
continue;
@@ -56,7 +64,7 @@ int socket_bind_in(psocket *psock, struct addrinfo *results)
}
if (!rp)
- return -1;
+ goto finalise;
setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &reuse_enable, sizeof(int));
psock->fd = fd;
@@ -65,8 +73,13 @@ int socket_bind_in(psocket *psock, struct addrinfo *results)
psock->family = rp->ai_family;
psock->socktype = rp->ai_socktype;
psock->protocol = rp->ai_protocol;
- freeaddrinfo(results);
- return socket_nonblock(psock);
+ s = socket_nonblock(psock);
+
+finalise:
+ freeaddrinfo(*results);
+ *results = NULL;
+
+ return s;
}
int socket_listen_in(psocket *psock)
@@ -81,27 +94,31 @@ int socket_accept_in(const psocket *psock, psocket *client_psock)
int fd;
assert(psock && client_psock);
+ client_psock->fd = -1;
client_psock->addr_len = psock->addr_len;
fd = accept(psock->fd, &client_psock->addr,
&client_psock->addr_len);
if (fd < 0)
- return -1;
- if (socket_nonblock(psock))
- return -2;
+ return 1;
+ if (socket_nonblock(psock)) {
+ close(fd);
+ return 1;
+ }
client_psock->fd = fd;
return 0;
}
-int socket_connect_in(psocket *psock, struct addrinfo *results)
+int socket_connect_in(psocket *psock, struct addrinfo **results)
{
- int fd = -1, rv, reuse_enable = 1;
+ int s = 1, fd = -1, rv, reuse_enable = 1;
struct addrinfo *rp = NULL;
- assert(psock && results);
+ assert(psock && results && *results);
+ psock->fd = -1;
- for (rp = results; rp != NULL; rp = rp->ai_next) {
+ for (rp = *results; rp != NULL; rp = rp->ai_next) {
fd = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol);
if (fd < 0)
continue;
@@ -112,18 +129,22 @@ int socket_connect_in(psocket *psock, struct addrinfo *results)
}
if (!rp)
- return -1;
+ goto finalise;
setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &reuse_enable, sizeof(int));
psock->fd = fd;
psock->addr_len = rp->ai_addrlen;
- psock->addr = *rp->ai_addr;
+ psock->addr = *(rp->ai_addr);
psock->family = rp->ai_family;
psock->socktype = rp->ai_socktype;
psock->protocol = rp->ai_protocol;
- freeaddrinfo(results);
+ s = socket_nonblock(psock);
- return socket_nonblock(psock);
+finalise:
+ freeaddrinfo(*results);
+ *results = NULL;
+
+ return s;
}
int socket_addrtostr_in(const psocket *psock,
@@ -154,7 +175,7 @@ int socket_reconnect_in(psocket *psock)
if (psock->fd < 0)
return -2;
rv = connect(psock->fd, &psock->addr, psock->addr_len);
- if (!rv) {
+ if (rv) {
socket_close(psock);
return -3;
}
@@ -164,7 +185,7 @@ int socket_reconnect_in(psocket *psock)
return -4;
}
- return 0;
+ return socket_nonblock(psock);
}
int socket_close(psocket *psock)
@@ -172,8 +193,18 @@ int socket_close(psocket *psock)
int rv;
assert(psock);
+ if (psock->fd < 0)
+ return 0;
rv = close(psock->fd);
psock->fd = -1;
return rv;
}
+
+void socket_clone(const psocket *src, psocket *dst)
+{
+ assert(src && dst);
+
+ memcpy(dst, src, sizeof(*dst));
+ dst->fd = -1;
+}
diff --git a/src/socket.h b/src/socket.h
index 861a8b3..d2bb160 100644
--- a/src/socket.h
+++ b/src/socket.h
@@ -15,16 +15,18 @@ typedef struct psocket {
} psocket;
+int socket_nonblock(const psocket *psock);
+
int socket_init_in(const char *addr,
const char *port, struct addrinfo **results);
-int socket_bind_in(psocket *psock, struct addrinfo *results);
+int socket_bind_in(psocket *psock, struct addrinfo **results);
int socket_listen_in(psocket *psock);
int socket_accept_in(const psocket *psock, psocket *client_psock);
-int socket_connect_in(psocket *psock, struct addrinfo *results);
+int socket_connect_in(psocket *psock, struct addrinfo **results);
int socket_addrtostr_in(const psocket *psock,
char hbuf[NI_MAXHOST], char sbuf[NI_MAXSERV]);
@@ -33,4 +35,6 @@ int socket_reconnect_in(psocket *psock);
int socket_close(psocket *psock);
+void socket_clone(const psocket *src, psocket *dst);
+
#endif