#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <string.h>
#include <errno.h>
#include <signal.h>
#include <assert.h>
#include <sys/socket.h>
#include <sys/epoll.h>

#include "pevent.h"
#include "log.h"


/*
 * Initialise an event context, allocating it first when *ctx is NULL.
 *
 * On return the context is zero-filled and epoll_fd is -1, meaning no
 * epoll instance exists yet (event_setup() creates it). If the allocation
 * fails, *ctx is left NULL instead of being written through.
 */
void event_init(event_ctx **ctx)
{
    assert(ctx);
    if (!*ctx)
        *ctx = malloc(sizeof(**ctx));
    assert(*ctx);
    /* assert() is compiled out under NDEBUG: never memset through NULL. */
    if (!*ctx)
        return;

    memset(*ctx, 0, sizeof(**ctx));
    (*ctx)->epoll_fd = -1;
}

/*
 * Release an event context: close its epoll descriptor (only if one was
 * created) and free the context memory, then set *ctx to NULL.
 * Calling it with a NULL ctx, or with *ctx already NULL, does nothing.
 */
void event_free(event_ctx **ctx)
{
    if (!ctx || !*ctx)
        return;

    /* epoll_fd stays -1 until event_setup() succeeds; don't close(-1). */
    if ((*ctx)->epoll_fd >= 0)
        close((*ctx)->epoll_fd);
    free(*ctx);
    *ctx = NULL;
}

/*
 * Create the epoll instance for ctx unless it already has one.
 * Returns 0 when ctx->epoll_fd is a valid descriptor afterwards,
 * 1 when epoll_create1() failed.
 */
int event_setup(event_ctx *ctx)
{
    assert(ctx);

    if (ctx->epoll_fd >= 0)
        return 0;

    /* epoll_create1(0) is epoll_create() without the obsolete size hint. */
    ctx->epoll_fd = epoll_create1(0);
    return (ctx->epoll_fd >= 0) ? 0 : 1;
}

/*
 * Register sock->fd with ctx's epoll instance for edge-triggered
 * readability (EPOLLIN | EPOLLET). The fd itself is stored as the
 * event's user data. Returns 0 on success, 1 if epoll_ctl() fails.
 */
int event_add_sock(event_ctx *ctx, psocket *sock)
{
    struct epoll_event reg;

    assert(ctx && sock);

    memset(&reg, 0, sizeof(reg));
    reg.events = EPOLLIN | EPOLLET;
    reg.data.fd = sock->fd;

    return epoll_ctl(ctx->epoll_fd, EPOLL_CTL_ADD, sock->fd, &reg) != 0;
}

/*
 * Register a raw descriptor with ctx's epoll instance for edge-triggered
 * readability (EPOLLIN | EPOLLET); fd is stored as the event's user data.
 * Returns 0 on success, 1 if epoll_ctl() fails.
 */
int event_add_fd(event_ctx *ctx, int fd)
{
    struct epoll_event interest = {
        .events = EPOLLIN | EPOLLET,
        .data = { .fd = fd }
    };

    assert(ctx);

    return (epoll_ctl(ctx->epoll_fd, EPOLL_CTL_ADD, fd, &interest) == 0) ? 0 : 1;
}

/*
 * Dispatch loop: block in epoll_pwait() on ctx->epoll_fd and hand each
 * ready, readable descriptor to on_event until ctx->active is cleared.
 *
 * An empty signal mask is installed for the duration of each wait, so
 * signals blocked by the caller can be delivered while waiting; EINTR
 * just restarts the wait.
 *
 * The loop stops (ctx->active = 0) when:
 *   - epoll_pwait() fails with anything other than EINTR,
 *   - a descriptor reports EPOLLERR/EPOLLHUP/EPOLLRDHUP or is not readable,
 *   - ctx->active is cleared elsewhere (e.g. by the callback).
 * A zero return from on_event is only logged; it does not stop the loop.
 * ctx->current_event holds the index of the event being dispatched; it is
 * read by event_forward_connection().
 *
 * NOTE(review): ctx->active is always 0 once the loop exits, so this
 * always returns 1 - callers cannot tell a requested stop from a failure.
 * Confirm what callers expect before changing it.
 */
int event_loop(event_ctx *ctx, on_event_cb on_event, void *user_data)
{
    int n, i;
    struct epoll_event *ev;
    sigset_t eset;

    assert(ctx && on_event);
    sigemptyset(&eset);
    ctx->active = 1;

    while (ctx->active) {
        n = epoll_pwait(ctx->epoll_fd, ctx->events, POTD_MAXEVENTS, -1, &eset);
        if (n < 0) {
            /* errno is only meaningful once the call has reported failure. */
            if (errno == EINTR)
                continue;
            E_STRERR("Event epoll_pwait on fd %d", ctx->epoll_fd);
            ctx->active = 0;
            break;
        }

        for (i = 0; i < n; ++i) {
            ctx->current_event = i;
            ev = &ctx->events[i];

            if ((ev->events & (EPOLLERR | EPOLLHUP | EPOLLRDHUP)) ||
                !(ev->events & EPOLLIN))
            {
                /* NOTE(review): epoll does not set errno for these event
                 * flags, so the strerror part of this message is stale. */
                E_STRERR("Event epoll for descriptor %d", ev->data.fd);
                ctx->active = 0;
            } else if (!on_event(ctx, ev->data.fd, user_data)) {
                W2("Event callback failed: [fd: %d , npoll: %d]",
                    ev->data.fd, n);
            }

            if (!ctx->active)
                break;
        }
    }

    return ctx->active == 0;
}

/*
 * Write all len bytes of buf to fd, retrying on EINTR and on short writes.
 * Returns len on success, 0 if write() made no progress, or -1 on error
 * with errno left as set by write(). A non-blocking fd that would block
 * yields -1 (EAGAIN/EWOULDBLOCK); the caller reports that as an output
 * error instead of silently dropping the unwritten tail.
 */
static ssize_t forward_write_all(int fd, const char *buf, size_t len)
{
    size_t done = 0;
    ssize_t w;

    while (done < len) {
        w = write(fd, buf + done, len - done);
        if (w < 0) {
            if (errno == EINTR)
                continue;
            return -1;
        }
        if (w == 0)
            return 0;
        done += (size_t) w;
    }

    return (ssize_t) done;
}

/*
 * Forward everything currently readable on the descriptor of the event
 * being dispatched (ctx->events[ctx->current_event]) to dest_fd.
 *
 * The source is read until EAGAIN/EWOULDBLOCK, because the descriptors are
 * registered edge-triggered. Each chunk is first offered to on_data when
 * one is given. A non-zero return from on_data skips forwarding that chunk
 * only; the loop keeps reading.
 *
 * Returns CON_OK when the source was drained without incident. Otherwise it
 * returns CON_IN_ERROR / CON_IN_TERMINATED for read failure / EOF on the
 * source, or CON_OUT_ERROR / CON_OUT_TERMINATED for write failure / no
 * progress on dest_fd. In those cases both sockets are shut down
 * (SHUT_RDWR) before returning.
 */
forward_state
event_forward_connection(event_ctx *ctx, int dest_fd, on_data_cb on_data,
                         void *user_data)
{
    int saved_errno;
    forward_state rc = CON_OK;
    ssize_t siz;
    char buf[BUFSIZ];
    struct epoll_event *ev;

    assert(ctx);
    assert(ctx->current_event >= 0 &&
        ctx->current_event < POTD_MAXEVENTS);
    ev = &ctx->events[ctx->current_event];

    /* Only readable events carry data to forward. */
    while (ev->events & EPOLLIN) {
        errno = 0;
        siz = read(ev->data.fd, buf, sizeof(buf));
        saved_errno = errno;

        if (siz < 0) {
            if (saved_errno == EINTR)
                continue;
            /* Edge-triggered source is drained: wait for the next event. */
            if (saved_errno == EAGAIN || saved_errno == EWOULDBLOCK)
                break;
            E_STRERR("Client read from fd %d", ev->data.fd);
            rc = CON_IN_ERROR;
            break;
        }
        if (siz == 0) {
            rc = CON_IN_TERMINATED;
            break;
        }
        D2("Read %zd bytes from fd %d", siz, ev->data.fd);

        if (on_data &&
            on_data(ctx, ev->data.fd, dest_fd, buf, siz, user_data))
        {
            W2("On data callback failed, not forwarding from %d to %d",
                ev->data.fd, dest_fd);
            continue;
        }

        /* A single write() may be short; push out the whole chunk. */
        siz = forward_write_all(dest_fd, buf, (size_t) siz);
        if (siz < 0) {
            E_STRERR("Forward write to fd %d", dest_fd);
            rc = CON_OUT_ERROR;
            break;
        }
        if (siz == 0) {
            rc = CON_OUT_TERMINATED;
            break;
        }
        D2("Written %zd bytes from fd %d to fd %d",
            siz, ev->data.fd, dest_fd);
    }

    D2("Connection state: %d", rc);
    if (rc != CON_OK) {
        /* Tear down both directions on both ends of the forwarded pair. */
        if (shutdown(ev->data.fd, SHUT_RDWR))
            E_STRERR("Shutdown source socket fd %d", ev->data.fd);
        if (shutdown(dest_fd, SHUT_RDWR))
            E_STRERR("Shutdown dest socket fd %d", dest_fd);
    }
    return rc;
}