#include <assert.h>
#include <errno.h>
#include <pthread.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <sys/epoll.h>

#include "server.h"
#include "socket.h"
#include "log.h"

/*
 * Per-connection state handed from the accept path to a client thread.
 * The thread function receives it as its argument and frees it when done.
 */
typedef struct client_thread_args {
    pthread_t self;                 /* thread ID written by pthread_create() */
    psocket client_psock;           /* the accepted client socket */
    char host_buf[NI_MAXHOST];      /* peer host, filled by socket_addrtostr_in() */
    char service_buf[NI_MAXSERV];   /* peer service, filled by socket_addrtostr_in() */
    const server_ctx *server_ctx;   /* listener context that accepted this client */
} client_thread_args;

/*
 * Accept one connection on the listener in ctx[0..siz) whose descriptor
 * matches event->data.fd; returns 1 if a client thread was started, else 0.
 */
static int server_accept_client(const server_ctx ctx[], size_t siz,
                                struct epoll_event *event);

/* Thread entry point that serves, and finally releases, one accepted client. */
static void *client_mainloop_epoll(void *arg);


/*
 * Zero-initialise a server context and let init_fn populate it.
 *
 * ctx:     caller-provided storage, or NULL to have a context heap-allocated
 *          here.
 * init_fn: callback that fills in the context; a zero return means failure.
 *
 * Returns the initialised context on success, or NULL if the allocation or
 * init_fn fails. A context allocated by this function is freed again when
 * init_fn fails, so the failure path does not leak it. Caller-provided
 * storage is never freed.
 */
server_ctx *
server_init_ctx(server_ctx *ctx, init_cb init_fn)
{
    int allocated = 0;

    assert(init_fn);

    if (!ctx) {
        ctx = malloc(sizeof(*ctx));
        /* assert() vanishes under NDEBUG; report OOM instead of memset(NULL) */
        if (!ctx)
            return NULL;
        allocated = 1;
    }

    memset(ctx, 0, sizeof(*ctx));
    if (!init_fn(ctx)) {
        if (allocated)
            free(ctx);
        return NULL;
    }

    return ctx;
}

/*
 * Check that every mandatory server callback is set.
 *
 * Debug builds still abort on a missing callback via assert(). assert()
 * compiles away under NDEBUG, so the same checks are repeated at run time.
 * Release builds therefore report -1 instead of unconditionally returning 0.
 *
 * Returns 0 when ctx is non-NULL and all callbacks are present, -1 otherwise.
 */
int server_validate_ctx(const server_ctx *ctx)
{
    assert(ctx);
    assert(ctx->server_cbs.on_connect && ctx->server_cbs.on_disconnect
        && ctx->server_cbs.mainloop);
    assert(ctx->server_cbs.on_free && ctx->server_cbs.on_listen
        && ctx->server_cbs.on_shutdown);

    if (!ctx)
        return -1;
    if (!ctx->server_cbs.on_connect || !ctx->server_cbs.on_disconnect ||
        !ctx->server_cbs.mainloop || !ctx->server_cbs.on_free ||
        !ctx->server_cbs.on_listen || !ctx->server_cbs.on_shutdown)
        return -1;

    return 0;
}

/*
 * Create an epoll instance and register every listening socket in
 * ctx[0..siz) for edge-triggered read readiness. Each context's printable
 * address is stored in its host_buf/service_buf and logged.
 *
 * Returns the epoll descriptor on success, or:
 *   -1  epoll_create1() failed
 *   -2  a socket address could not be converted to a string
 *   -3  epoll_ctl() failed to register a socket
 * The epoll descriptor is closed on every error path.
 */
int server_setup_epoll(server_ctx ctx[], size_t siz)
{
    struct epoll_event ev;
    int s, fd;

    assert(ctx);
    assert(siz > 0 && siz < POTD_MAXFD);

    fd = epoll_create1(0); /* flags == 0 -> obsolete size arg is dropped */
    if (fd < 0)
        return -1;

    for (size_t i = 0; i < siz; ++i) {
        memset(&ev, 0, sizeof(ev));
        ev.data.fd = ctx[i].sock.fd;
        ev.events = EPOLLIN | EPOLLET;

        s = socket_addrtostr_in(&ctx[i].sock,
                                ctx[i].host_buf, ctx[i].service_buf);
        if (s) {
            E_GAIERR(s, "Convert socket address to string");
            close(fd); /* previously leaked on this path */
            return -2;
        }
        N("Redirector service listening on %s:%s",
          ctx[i].host_buf, ctx[i].service_buf);

        s = epoll_ctl(fd, EPOLL_CTL_ADD, ctx[i].sock.fd, &ev);
        if (s) {
            close(fd);
            return -3;
        }
    }

    return fd;
}

/*
 * Wait on epoll_fd for activity on the listening sockets in ctx[0..siz) and
 * accept incoming clients. A descriptor reporting an error, a hang-up, or no
 * read readiness is logged and closed.
 *
 * Does not return during normal operation. Returns 1 if the event buffer
 * cannot be allocated or epoll_wait() fails with anything other than EINTR.
 */
int server_mainloop_epoll(int epoll_fd, const server_ctx ctx[], size_t siz)
{
    struct epoll_event *events;

    assert(ctx);
    assert(siz > 0 && siz < POTD_MAXFD);

    /*
     * Per-call buffer. A function-static buffer would be shared by concurrent
     * callers and could never be released safely.
     */
    events = calloc(POTD_MAXEVENTS, sizeof(*events));
    if (!events)
        return 1;

    while (1) {
        int n = epoll_wait(epoll_fd, events, POTD_MAXEVENTS, -1);

        if (n < 0) {
            /* A signal interrupting the wait must not stop the server. */
            if (errno == EINTR)
                continue;
            free(events);
            return 1;
        }

        for (int i = 0; i < n; ++i) {
            if ((events[i].events & (EPOLLERR | EPOLLHUP)) ||
                !(events[i].events & EPOLLIN))
            {
                E("Epoll for descriptor %d failed", events[i].data.fd);
                /* NOTE(review): event flags do not set errno; this may print a stale value */
                E_STRERR("epoll_wait");
                close(events[i].data.fd);
                continue;
            }

            /*
             * NOTE(review): listeners set up by server_setup_epoll() are
             * edge-triggered (EPOLLET), but only one connection is accepted
             * per event. Further pending connections may wait for the next
             * edge. Confirm that socket_accept_in() is non-blocking before
             * draining the backlog in a loop here.
             */
            if (server_accept_client(ctx, siz, &events[i]))
                continue; /* new client connection, accept succeeded */

            W2("Server accept client failed: [fd: %d , npoll: %d]", events[i].data.fd, n);
        }
    }
}

/*
 * Accept one pending connection on the listener in ctx[0..siz) whose
 * descriptor matches event->data.fd, and start a client thread for it.
 *
 * On success the heap-allocated client_thread_args is handed to the new
 * thread (client_mainloop_epoll), which frees it. Returns 1 when a client
 * thread was started. Returns 0 when no listener matches or any step fails.
 * Failure paths release the argument block and, if one was accepted, the
 * client socket.
 */
static int server_accept_client(const server_ctx ctx[],
                                size_t siz, struct epoll_event *event)
{
    const server_ctx *srv = NULL;
    client_thread_args *args;
    int s;

    for (size_t i = 0; i < siz; ++i) {
        if (ctx[i].sock.fd == event->data.fd) {
            srv = &ctx[i];
            break;
        }
    }
    if (!srv)
        return 0;

    args = calloc(1, sizeof(*args));
    if (!args) {
        E_STRERR("Could not allocate client thread arguments");
        return 0;
    }

    if (socket_accept_in(&srv->sock, &args->client_psock)) {
        E_STRERR("Could not accept client connection");
        free(args); /* previously leaked on this path */
        return 0;
    }

    args->server_ctx = srv;
    s = socket_addrtostr_in(&args->client_psock,
                            args->host_buf, args->service_buf);
    if (s) {
        E_GAIERR(s, "Convert socket address to string");
        goto error;
    }
    N("New connection from %s:%s to %s:%s",
      args->host_buf, args->service_buf,
      srv->host_buf, srv->service_buf);

    /*
     * args must not be touched after a successful pthread_create(): the
     * thread owns and frees it.
     */
    s = pthread_create(&args->self, NULL, client_mainloop_epoll, args);
    if (s) {
        /* pthread_create() returns its error code; it does not set errno. */
        errno = s;
        E_STRERR("Thread creation");
        goto error;
    }

    return 1;

error:
    close(args->client_psock.fd);
    free(args);
    return 0;
}

/*
 * Thread entry point for one accepted client.
 *
 * arg is a heap-allocated client_thread_args; this function takes ownership
 * of it. It releases the argument block, the client socket, and its own
 * epoll descriptor before returning.
 *
 * NOTE(review): no per-client event loop is implemented yet. The client
 * socket is registered with a fresh epoll instance and then immediately
 * closed.
 */
static void *
client_mainloop_epoll(void *arg)
{
    client_thread_args *args;
    struct epoll_event event;
    struct epoll_event *events = NULL;
    int epoll_fd = -1;

    assert(arg);
    args = (client_thread_args *) arg;

    /*
     * Use pthread_self(): args->self is written by pthread_create() in the
     * creating thread and may not be set yet when this thread starts.
     */
    pthread_detach(pthread_self());

    /* Event buffer reserved for the (not yet implemented) client loop. */
    events = calloc(POTD_MAXEVENTS, sizeof(*events));
    if (!events)
        goto finish;

    epoll_fd = epoll_create1(0);
    if (epoll_fd < 0)
        goto finish;

    /*
     * Clear first, then fill in. The former order zeroed the fields just
     * assigned and registered an empty event mask.
     */
    memset(&event, 0, sizeof(event));
    event.data.fd = args->client_psock.fd;
    event.events = EPOLLIN | EPOLLOUT | EPOLLET;
    if (epoll_ctl(epoll_fd, EPOLL_CTL_ADD, args->client_psock.fd, &event))
        goto finish;

finish:
    if (epoll_fd >= 0)
        close(epoll_fd);
    close(args->client_psock.fd);
    free(events);
    free(args);
    return NULL;
}