diff options
Diffstat (limited to 'src/protocol_ssh.c')
-rw-r--r-- | src/protocol_ssh.c | 505 |
1 files changed, 505 insertions, 0 deletions
diff --git a/src/protocol_ssh.c b/src/protocol_ssh.c new file mode 100644 index 0000000..ea2aa5b --- /dev/null +++ b/src/protocol_ssh.c @@ -0,0 +1,505 @@ +#include <stdio.h> +#include <stdlib.h> +#include <assert.h> +#include <sys/types.h> +#include <sys/wait.h> +#include <signal.h> +#include <pthread.h> +#include <poll.h> +#include <pty.h> +#include <utmp.h> +#include <libssh/callbacks.h> +#include <libssh/server.h> + +#include "protocol_ssh.h" +#include "protocol.h" +#include "log.h" + +#if LIBSSH_VERSION_MAJOR != 0 || LIBSSH_VERSION_MINOR < 7 || \ + LIBSSH_VERSION_MICRO < 3 +#pragma message "Unsupported libssh version < 0.7.3" +#endif + +typedef struct ssh_data { + pthread_t self; + ssh_bind sshbind; +} ssh_data; + +struct protocol_cbs potd_ssh_callbacks = { + .on_listen = ssh_on_listen, + .on_shutdown = ssh_on_shutdown +}; + +static int set_default_keys(ssh_bind sshbind, int rsa_already_set, + int dsa_already_set, int ecdsa_already_set); +static int gen_default_keys(void); +static int gen_export_sshkey(enum ssh_keytypes_e type, int length, const char *path); +static void ssh_log_cb(int priority, const char *function, const char *buffer, void *userdata); +static void * +ssh_thread_mainloop(void *arg); +static int authenticate(ssh_session session); +static int auth_password(const char *user, const char *password); +static int client_mainloop(ssh_channel chan); +static int copy_fd_to_chan(socket_t fd, int revents, void *userdata); +static int copy_chan_to_fd(ssh_session session, ssh_channel channel, void *data, + uint32_t len, int is_stderr, void *userdata); +static void chan_close(ssh_session session, ssh_channel channel, void *userdata); + +struct ssh_channel_callbacks_struct ssh_channel_cb = { + .channel_data_function = copy_chan_to_fd, + .channel_eof_function = chan_close, + .channel_close_function = chan_close, + .userdata = NULL +}; + + +int ssh_init_cb(protocol_ctx *ctx) +{ + N("libssh version: %s", ssh_version(0)); + if (ssh_version(SSH_VERSION_INT(LIBSSH_VERSION_MAJOR, + LIBSSH_VERSION_MINOR, + LIBSSH_VERSION_MICRO)) == NULL) + { + W("This software was compiled/linked for libssh %d.%d.%d," + " which you aren't currently using.", + LIBSSH_VERSION_MAJOR, LIBSSH_VERSION_MINOR, LIBSSH_VERSION_MICRO); + } + if (ssh_version(SSH_VERSION_INT(0,7,3)) == NULL) + { + W("%s", "Unsupported libssh version < 0.7.3"); + } + if (ssh_version(SSH_VERSION_INT(0,7,4)) != NULL || + ssh_version(SSH_VERSION_INT(0,7,90)) != NULL) + { + W("%s", + "libssh versions > 0.7.3 may suffer " + "from problems with the pki key generation/export"); + } + ctx->cbs = potd_ssh_callbacks; + + if (ssh_init()) + return 1; + + ssh_data *d = (ssh_data *) calloc(1, sizeof(*d)); + assert(d); + d->sshbind = ssh_bind_new(); + ctx->src.data = d; + + ssh_set_log_callback(ssh_log_cb); + ssh_set_log_level(SSH_LOG_FUNCTIONS); + + if (!d->sshbind) + return 1; + if (gen_default_keys()) + return 1; + if (set_default_keys(d->sshbind, 0, 0, 0)) + return 1; + + return 0; +} + +int ssh_on_listen(protocol_ctx *ctx, const char *host, + const char *port) +{ + int s; + ssh_data *d = (ssh_data *) ctx->src.data; + + if (ssh_bind_options_set(d->sshbind, SSH_BIND_OPTIONS_BINDADDR, + host)) + return 1; + if (ssh_bind_options_set(d->sshbind, SSH_BIND_OPTIONS_BINDPORT_STR, + port)) + return 1; + + s = ssh_bind_listen(d->sshbind); + if (s < 0) { + E_STRERR("Error listening to SSH socket: %s", ssh_get_error(d->sshbind)); + return s; + } + N("SSH bind and listen on %s:%s fd %d", host, port, + ssh_bind_get_fd(d->sshbind)); + + s = pthread_create(&d->self, NULL, ssh_thread_mainloop, d); + if (s) { + E_STRERR("SSH Thread creation on %s:%s fd %d", + host, port, ssh_bind_get_fd(d->sshbind)); + } + + return s; +} + +int ssh_on_shutdown(protocol_ctx *ctx) +{ + return 0; +} + +static int set_default_keys(ssh_bind sshbind, int rsa_already_set, + int dsa_already_set, int ecdsa_already_set) +{ + if (!rsa_already_set) { + if (ssh_bind_options_set(sshbind, SSH_BIND_OPTIONS_RSAKEY, + "./ssh_host_rsa_key")) { + E2("Faled to set RSA key: %s", ssh_get_error(sshbind)); + return 1; + } + } + if (!dsa_already_set) { + if (ssh_bind_options_set(sshbind, SSH_BIND_OPTIONS_DSAKEY, + "./ssh_host_dsa_key")) { + E2("Failed to set DSA key: %s", ssh_get_error(sshbind)); + return 1; + } + } + if (!ecdsa_already_set) { + if (ssh_bind_options_set(sshbind, SSH_BIND_OPTIONS_ECDSAKEY, + "./ssh_host_ecdsa_key")) { + E2("Failed to set ECDSA key: %s", ssh_get_error(sshbind)); + return 1; + } + } + return 0; +} + +static int gen_default_keys(void) +{ + int s = 0; + + if (gen_export_sshkey(SSH_KEYTYPE_RSA, 1024, "./ssh_host_rsa_key")) { + W("libssh %s key generation failed, using fallback ssh-keygen", "RSA"); + remove("./ssh_host_rsa_key"); + s |= system("ssh-keygen -t rsa -b 1024 -f ./ssh_host_rsa_key -N '' >/dev/null 2>/dev/null"); + } + if (gen_export_sshkey(SSH_KEYTYPE_DSS, 1024, "./ssh_host_dsa_key")) { + W("libssh %s key generation failed, using fallback ssh-keygen", "DSA"); + remove("./ssh_host_dsa_key"); + s |= system("ssh-keygen -t dsa -b 1024 -f ./ssh_host_dsa_key -N '' >/dev/null 2>/dev/null"); + } + if (gen_export_sshkey(SSH_KEYTYPE_ECDSA, 1024, "./ssh_host_ecdsa_key")) { + W("libssh %s key generation failed, using fallback ssh-keygen", "ECDSA"); + remove("./ssh_host_ecdsa_key"); + s |= system("ssh-keygen -t ecdsa -b 256 -f ./ssh_host_ecdsa_key -N '' >/dev/null 2>/dev/null"); + } + + return s != 0; +} + +static int gen_export_sshkey(enum ssh_keytypes_e type, int length, const char *path) +{ + ssh_key priv_key; + const char *type_str = NULL; + int s; + + assert(path); + assert(length == 1024 || length == 2048 || + length == 4096); + + switch (type) { + case SSH_KEYTYPE_DSS: + type_str = "DSS"; + break; + case SSH_KEYTYPE_RSA: + type_str = "RSA"; + break; + case SSH_KEYTYPE_ECDSA: + type_str = "ECDSA"; + break; + default: + W2("Unknown SSH key type: %d", type); + return 1; + } + N2("Generating %s key with length %d bits and save it on disk: %s", + type_str, length, path); + s = ssh_pki_generate(type, length, &priv_key); + if (s != SSH_OK) { + E2("Generating %s key failed: %d", type_str, s); + return 1; + } + s = ssh_pki_export_privkey_file(priv_key, "", NULL, + NULL, path); + ssh_key_free(priv_key); + + if (s != SSH_OK) { + W2("SSH private key export to file failed: %d", s); + return 1; + } + + return 0; +} + +static void ssh_log_cb(int priority, const char *function, + const char *buffer, void *userdata) +{ + switch (priority) { + case 0: + W("libssh: %s", buffer); + break; + case 1: + N("libssh: %s", buffer); + break; + default: + D("libssh: %s", buffer); + break; + } +} + +static void * +ssh_thread_mainloop(void *arg) +{ + int s, auth = 0, shell = 0; + ssh_session ses; + ssh_message message; + ssh_channel chan = NULL; + + assert(arg); + ssh_data *d = (ssh_data *) arg; + pthread_detach(d->self); + + while (1) { + ses = ssh_new(); + assert(ses); + + s = ssh_bind_accept(d->sshbind, ses); + if (s == SSH_ERROR) { + W("SSH error while accepting a connection: %s", + ssh_get_error(ses)); + goto failed; + } + + if (ssh_handle_key_exchange(ses)) { + W("SSH key exchange failed: %s", ssh_get_error(ses)); + goto failed; + } + + /* proceed to authentication */ + auth = authenticate(ses); + if (!auth) { + W("SSH authentication error: %s", ssh_get_error(ses)); + goto failed; + } + + /* wait for a channel session */ + do { + message = ssh_message_get(ses); + if (message) { + if (ssh_message_type(message) == SSH_REQUEST_CHANNEL_OPEN && + ssh_message_subtype(message) == SSH_CHANNEL_SESSION) + { + chan = ssh_message_channel_request_open_reply_accept(message); + ssh_message_free(message); + break; + } else { + ssh_message_reply_default(message); + ssh_message_free(message); + } + } else { + break; + } + } while (!chan); + + if (!chan) { + W("SSH client did not ask for a channel session: %s", + ssh_get_error(ses)); + goto failed; + } + + /* wait for a shell */ + do { + message = ssh_message_get(ses); + if (message != NULL) { + if (ssh_message_type(message) == SSH_REQUEST_CHANNEL) { + if (ssh_message_subtype(message) == SSH_CHANNEL_REQUEST_SHELL) { + shell = 1; + ssh_message_channel_request_reply_success(message); + ssh_message_free(message); + break; + } else if (ssh_message_subtype(message) == SSH_CHANNEL_REQUEST_PTY) { + ssh_message_channel_request_reply_success(message); + ssh_message_free(message); + continue; + } + } + ssh_message_reply_default(message); + ssh_message_free(message); + } else { + break; + } + } while (!shell); + + if (!shell) { + W("SSH client had no shell requested: %s", ssh_get_error(ses)); + goto failed; + } + + N("%s", "Dropping user into shell"); + client_mainloop(chan); + +failed: + ssh_disconnect(ses); + ssh_free(ses); + } + + return NULL; +} + +static int authenticate(ssh_session session) +{ + ssh_message message; + + do { + message = ssh_message_get(session); + if (!message) + break; + + switch (ssh_message_type(message)) { + + case SSH_REQUEST_AUTH: + switch (ssh_message_subtype(message)) { + case SSH_AUTH_METHOD_PASSWORD: + N("SSH: user '%s' wants to auth with pass '%s'", + ssh_message_auth_user(message), + ssh_message_auth_password(message)); + if (auth_password(ssh_message_auth_user(message), + ssh_message_auth_password(message))) + { + ssh_message_auth_reply_success(message,0); + ssh_message_free(message); + return 1; + } + ssh_message_auth_set_methods(message, + SSH_AUTH_METHOD_PASSWORD | + SSH_AUTH_METHOD_INTERACTIVE); + /* not authenticated, send default message */ + ssh_message_reply_default(message); + break; + + case SSH_AUTH_METHOD_NONE: + default: + N("SSH: User '%s' wants to auth with unknown auth '%d'", + ssh_message_auth_user(message), + ssh_message_subtype(message)); + ssh_message_auth_set_methods(message, + SSH_AUTH_METHOD_PASSWORD | + SSH_AUTH_METHOD_INTERACTIVE); + ssh_message_reply_default(message); + break; + } + break; + + default: + ssh_message_auth_set_methods(message, + SSH_AUTH_METHOD_PASSWORD | + SSH_AUTH_METHOD_INTERACTIVE); + ssh_message_reply_default(message); + } + ssh_message_free(message); + } while (1); + + return 0; +} + +static int auth_password(const char *user, const char *password) +{ +/* + if(strcmp(user, SSHD_USER)) + return 0; + if(strcmp(password, SSHD_PASSWORD)) + return 0; +*/ + return 1; /* authenticated */ +} + +static int client_mainloop(ssh_channel chan) +{ + ssh_session session = ssh_channel_get_session(chan); + socket_t fd; + struct termios *term = NULL; + struct winsize *win = NULL; + pid_t childpid; + ssh_event event; + short events; + + childpid = forkpty(&fd, NULL, term, win); + if (childpid == 0) { + execl("/bin/bash", "/bin/bash", (char *)NULL); + abort(); + } + + ssh_channel_cb.userdata = &fd; + ssh_callbacks_init(&ssh_channel_cb); + ssh_set_channel_callbacks(chan, &ssh_channel_cb); + + events = POLLIN | POLLPRI | POLLERR | POLLHUP | POLLNVAL; + event = ssh_event_new(); + + if (event == NULL) { + W("%s", "Couldn't get a event"); + return 1; + } + if (ssh_event_add_fd(event, fd, events, copy_fd_to_chan, chan) != SSH_OK) { + W("%s", "Couldn't add an fd to the event"); + return 1; + } + if (ssh_event_add_session(event, session) != SSH_OK) { + W("%s", "Couldn't add the session to the event"); + return 1; + } + + do { + ssh_event_dopoll(event, 1000); + } while (!ssh_channel_is_closed(chan)); + + ssh_event_remove_fd(event, fd); + ssh_event_remove_session(event, session); + ssh_event_free(event); + return 0; +} + +static int copy_fd_to_chan(socket_t fd, int revents, void *userdata) +{ + ssh_channel chan = (ssh_channel)userdata; + char buf[BUFSIZ]; + int sz = 0; + + if(!chan) { + close(fd); + return -1; + } + if(revents & POLLIN) { + sz = read(fd, buf, BUFSIZ); + if(sz > 0) { + ssh_channel_write(chan, buf, sz); + } + } + if(revents & POLLHUP) { + ssh_channel_close(chan); + sz = -1; + } + return sz; +} + +static int copy_chan_to_fd(ssh_session session, + ssh_channel channel, + void *data, + uint32_t len, + int is_stderr, + void *userdata) +{ + int fd = *(int*)userdata; + int sz; + (void)session; + (void)channel; + (void)is_stderr; + + sz = write(fd, data, len); + return sz; +} + +static void chan_close(ssh_session session, ssh_channel channel, + void *userdata) +{ + int fd = *(int*)userdata; + (void)session; + (void)channel; + + close(fd); +} |