aboutsummaryrefslogtreecommitdiff
path: root/src/protocol_ssh.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/protocol_ssh.c')
-rw-r--r--src/protocol_ssh.c143
1 files changed, 118 insertions, 25 deletions
diff --git a/src/protocol_ssh.c b/src/protocol_ssh.c
index baa1548..f78210c 100644
--- a/src/protocol_ssh.c
+++ b/src/protocol_ssh.c
@@ -8,9 +8,12 @@
#include <sys/stat.h>
#include <sys/types.h>
#include <sys/wait.h>
+#include <sys/mman.h>
#include <signal.h>
#include <poll.h>
#include <pwd.h>
+#include <pthread.h>
+#include <limits.h>
#include <linux/limits.h>
#include <libssh/callbacks.h>
#include <libssh/server.h>
@@ -28,6 +31,11 @@
LIBSSH_VERSION_MICRO < 3
#pragma message "Unsupported libssh version < 0.7.3"
#endif
+#define USER_LEN LOGIN_NAME_MAX
+#define PASS_LEN 80
+#define CACHE_MAX 32
+#define CACHE_TIME (60 * 20) /* max cache time 20 minutes */
+#define LOGIN_SUCCESS_PROB ((double)1/4) /* successful login probability */
static int version_logged = 0;
static const char rsa_key_suf[] = "ssh_host_rsa_key";
@@ -49,15 +57,25 @@ struct protocol_cbs potd_ssh_callbacks = {
.on_shutdown = ssh_on_shutdown
};
+typedef struct ssh_login_cache {
+ char user[USER_LEN];
+ char pass[PASS_LEN];
+ time_t last_used;
+ pthread_mutex_t cache_mtx;
+} ssh_login_cache;
+
static int set_default_keys(ssh_bind sshbind, int rsa_already_set,
int dsa_already_set, int ecdsa_already_set);
static int gen_default_keys(void);
-static int gen_export_sshkey(enum ssh_keytypes_e type, int length, const char *path);
-static void ssh_log_cb(int priority, const char *function, const char *buffer, void *userdata);
+static int gen_export_sshkey(enum ssh_keytypes_e type, int length,
+ const char *path);
+static void ssh_log_cb(int priority, const char *function, const char *buffer,
+ void *userdata);
static void ssh_mainloop(ssh_data *arg)
__attribute__((noreturn));
-static int authenticate(ssh_session session);
-static int auth_password(const char *user, const char *password);
+static int authenticate(ssh_session session, ssh_login_cache *cache);
+static int auth_password(const char *user, const char *pass,
+ ssh_login_cache *cache);
static int client_mainloop(ssh_client *arg);
static int copy_fd_to_chan(socket_t fd, int revents, void *userdata);
static int copy_chan_to_fd(ssh_session session, ssh_channel channel, void *data,
@@ -375,6 +393,9 @@ static void ssh_log_cb(int priority, const char *function,
static void ssh_mainloop(ssh_data *arg)
{
+ pthread_mutexattr_t shared;
+ ssh_login_cache *cache = NULL;
+ size_t i;
int s, auth = 0, shell = 0, is_child;
ssh_session ses;
ssh_message message;
@@ -384,6 +405,16 @@ static void ssh_mainloop(ssh_data *arg)
assert(arg);
set_procname("[potd] ssh");
assert( set_child_sighandler() == 0 );
+ cache = (ssh_login_cache *) mmap(NULL, sizeof(*cache) * CACHE_MAX,
+ PROT_READ|PROT_WRITE,
+ MAP_SHARED|MAP_ANONYMOUS, -1, 0);
+ assert( cache );
+ memset(cache, 0, sizeof(*cache) * CACHE_MAX);
+ pthread_mutexattr_init(&shared);
+ pthread_mutexattr_setpshared(&shared, PTHREAD_PROCESS_SHARED);
+ for (i = 0; i < CACHE_MAX; ++i) {
+ pthread_mutex_init(&(cache[i].cache_mtx), &shared);
+ }
while (1) {
ses = ssh_new();
@@ -420,7 +451,7 @@ static void ssh_mainloop(ssh_data *arg)
}
/* proceed to authentication */
- auth = authenticate(ses);
+ auth = authenticate(ses, cache);
if (!auth) {
W("SSH authentication error: %s", ssh_get_error(ses));
goto failed;
@@ -488,13 +519,14 @@ static void ssh_mainloop(ssh_data *arg)
ssh_bind_get_fd(arg->sshbind));
failed:
+ munmap(cache, sizeof(*cache) * CACHE_MAX);
ssh_disconnect(ses);
ssh_free(ses);
exit(EXIT_SUCCESS);
}
}
-static int authenticate(ssh_session session)
+static int authenticate(ssh_session session, ssh_login_cache *cache)
{
ssh_message message;
@@ -512,7 +544,7 @@ static int authenticate(ssh_session session)
ssh_message_auth_user(message),
ssh_message_auth_password(message));
if (auth_password(ssh_message_auth_user(message),
- ssh_message_auth_password(message)))
+ ssh_message_auth_password(message), cache))
{
ssh_message_auth_reply_success(message,0);
ssh_message_free(message);
@@ -550,18 +582,78 @@ static int authenticate(ssh_session session)
return 0;
}
-static int auth_password(const char *user, const char *password)
+static int auth_password(const char *user, const char *pass,
+ ssh_login_cache *cache)
{
- (void) user;
- (void) password;
-
-/*
- if(strcmp(user, SSHD_USER))
- return 0;
- if(strcmp(password, SSHD_PASSWORD))
- return 0;
-*/
- return 1; /* authenticated */
+ int got_auth = 0, cached = 0;
+ size_t i;
+ double d;
+ time_t o, t = time(NULL);
+ struct tm *tmp;
+ char time_str[64] = {0};
+
+ for (i = 0; i < CACHE_MAX; ++i) {
+ pthread_mutex_lock(&cache[i].cache_mtx);
+ if (cache[i].user[0] && cache[i].pass[0]) {
+ o = cache[i].last_used;
+ cache[i].last_used = t;
+
+ if (strncmp(user, cache[i].user, USER_LEN) == 0 &&
+ strnlen(user, USER_LEN) == strnlen(cache[i].user, USER_LEN) &&
+ strncmp(pass, cache[i].pass, PASS_LEN) == 0 &&
+ strnlen(pass, PASS_LEN) == strnlen(cache[i].pass, PASS_LEN))
+ {
+ tmp = localtime(&o);
+ if (!strftime(time_str, sizeof time_str, "%H:%M:%S", tmp))
+ snprintf(time_str, sizeof time_str, "%s", "UNKNOWN_TIME");
+ N("Got cached user/pass '%s'/'%s' from %s",
+ user, pass, time_str);
+ got_auth = 1;
+ }
+
+ d = difftime(t, o);
+ if (d > CACHE_TIME) {
+ D("Delete cached user/pass '%s'/'%s' (timeout)",
+ cache[i].user, cache[i].pass);
+ cache[i].user[0] = 0;
+ cache[i].pass[0] = 0;
+ }
+ }
+ pthread_mutex_unlock(&cache[i].cache_mtx);
+
+ if (got_auth)
+ break;
+ }
+
+ /* not auth'd but we have still some randomness */
+ if (!got_auth) {
+ srandom(time(NULL));
+ d = (double)(random() % RAND_MAX);
+ d /= (double)RAND_MAX;
+ if (d <= LOGIN_SUCCESS_PROB) {
+ N("Randomness won for user/pass '%s'/'%s': %.02f < %.02f",
+ user, pass, d, LOGIN_SUCCESS_PROB);
+ got_auth = 1;
+
+ for (i = 0; i < CACHE_MAX; ++i) {
+ pthread_mutex_lock(&cache[i].cache_mtx);
+ if (!cache[i].user[0] && !cache[i].pass[0]) {
+ D("Caching user/pass '%s'/'%s'",
+ user, pass);
+ snprintf(cache[i].user, sizeof cache[i].user, "%s", user);
+ snprintf(cache[i].pass, sizeof cache[i].pass, "%s", pass);
+ cache[i].last_used = t;
+ cached = 1;
+ }
+ pthread_mutex_unlock(&cache[i].cache_mtx);
+
+ if (cached)
+ break;
+ }
+ }
+ }
+
+ return got_auth;
}
static int client_mainloop(ssh_client *data)
@@ -641,11 +733,11 @@ static int copy_chan_to_fd(ssh_session session,
int is_stderr,
void *userdata)
{
- int fd = *(int*)userdata;
+ int fd = *(int*) userdata;
int sz;
- (void)session;
- (void)channel;
- (void)is_stderr;
+
+ (void) session;
+ (void) is_stderr;
sz = write(fd, data, len);
if (sz <= 0)
@@ -657,9 +749,10 @@ static int copy_chan_to_fd(ssh_session session,
static void chan_close(ssh_session session, ssh_channel channel,
void *userdata)
{
- int fd = *(int*)userdata;
- (void)session;
- (void)channel;
+ int fd = *(int*) userdata;
+
+ (void) session;
+ (void) channel;
close(fd);
}