#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <errno.h>
#include <unistd.h>
#include <stdarg.h>
#include <fcntl.h>
#include <signal.h>
#include <pwd.h>
#include <grp.h>
#include <sys/types.h>
#include <sys/sysmacros.h>
#include <sys/stat.h>
#include <sys/wait.h>
#include <sys/prctl.h>
#include <sys/mount.h>
#include <linux/limits.h>
#include <assert.h>

#include "utils.h"
#include "log.h"

/*
 * Number of bytes set_procname() clears and copies into arg0.
 * NOTE(review): identifiers beginning with an underscore followed by an
 * uppercase letter are reserved for the implementation, and POSIX
 * <limits.h> defines a macro with this exact name; a project-specific
 * name would avoid any clash.
 */
#define _POSIX_PATH_MAX 256

/* Process-name buffer rewritten by set_procname(); NULL until assigned
 * elsewhere (presumably argv[0] or a copy of it — TODO confirm). */
char *arg0 = NULL;
/* Cached /dev/null descriptor, opened on first use by
 * redirect_devnull_to(); -1 while not yet opened. */
static int null_fd = -1;

static void sighandler_child(int signo);


/*
 * Switch @fd to non-blocking mode.
 *
 * The current status flags are read first and O_NONBLOCK is OR'ed in, so
 * any other flags already set on the descriptor are preserved.
 *
 * Returns 0 on success, 1 if either fcntl() call fails.
 */
int set_fd_nonblock(int fd)
{
    int current = fcntl(fd, F_GETFL, 0);

    if (current < 0 || fcntl(fd, F_SETFL, current | O_NONBLOCK) == -1)
        return 1;
    return 0;
}

/*
 * SIGHUP handler for child processes (installed by set_child_sighandler(),
 * which also requests SIGHUP on parent death via PR_SET_PDEATHSIG).
 *
 * Every other signal number is ignored. On SIGHUP the process exits with
 * EXIT_SUCCESS only if it has been reparented to PID 1; otherwise the
 * signal is ignored.
 *
 * NOTE(review): N() and exit() are presumably not async-signal-safe
 * (stdio/atexit work inside a handler); consider setting a flag instead.
 */
static void sighandler_child(int signo)
{
    if (signo != SIGHUP)
        return;
    if (getppid() != 1)
        return;

    N("Master process %d died, exiting", getpgrp());
    exit(EXIT_SUCCESS);
}

/*
 * Configure signal handling for a forked child process:
 *  - SIGCHLD and SIGPIPE are ignored,
 *  - SIGHUP is handled by sighandler_child(),
 *  - the kernel is asked (PR_SET_PDEATHSIG) to deliver SIGHUP when the
 *    parent process dies.
 *
 * Every call is checked explicitly instead of inside assert(): assert()
 * expands to nothing under NDEBUG, which would skip the calls entirely.
 *
 * Returns 0 on success, 1 if any step fails.
 */
int set_child_sighandler(void)
{
    if (signal(SIGCHLD, SIG_IGN) == SIG_ERR)
        return 1;
    if (signal(SIGPIPE, SIG_IGN) == SIG_ERR)
        return 1;

    /*
     * Install the SIGHUP handler before arming the parent-death signal;
     * until the handler is in place SIGHUP still has its default action
     * (terminate), so a parent dying in between would bypass the handler.
     */
    if (signal(SIGHUP, sighandler_child) == SIG_ERR)
        return 1;
    if (prctl(PR_SET_PDEATHSIG, SIGHUP) != 0)
        return 1;

    return 0;
}

/*
 * Replace the process name stored in the global buffer arg0.
 *
 * Exactly _POSIX_PATH_MAX bytes of arg0 are written: @new_arg0, truncated to
 * _POSIX_PATH_MAX - 1 bytes so that a terminating NUL always fits,
 * followed by zero padding. (strncpy() with the full buffer size left the
 * result unterminated for names of _POSIX_PATH_MAX bytes or more.)
 *
 * NOTE(review): assumes arg0 points at a writable region of at least
 * _POSIX_PATH_MAX bytes, set up by the caller — confirm.
 */
void set_procname(const char *new_arg0)
{
    size_t len;

    assert(arg0);
    len = strnlen(new_arg0, _POSIX_PATH_MAX - 1);
    memset(arg0, 0, _POSIX_PATH_MAX);
    memcpy(arg0, new_arg0, len);
}

/*
 * Fork into the background, or (with @stay_foreground set) fork once and
 * keep the original process waiting for the child.
 *
 * Daemon mode (@stay_foreground == 0):
 *  - the original process exits with EXIT_SUCCESS right after fork();
 *  - the child calls setsid() and forks again, and the intermediate process
 *    exits;
 *  - the survivor calls setpgrp(), clears the umask and changes to "/";
 *  - it ends with stdin/stdout/stderr on /dev/null and every other
 *    descriptor closed.
 *
 * Foreground mode: the original process waits for the child and then exits
 * with EXIT_SUCCESS. The child clears the umask and closes every descriptor
 * except 0, 1 and 2.
 *
 * Returns 0 in the process that carries on. Returns a negative value in the
 * calling process if the first fork() fails. Any later failure is logged
 * and terminates the process with EXIT_FAILURE.
 *
 * Descriptor handling is checked explicitly rather than inside assert(),
 * which would compile the calls away under NDEBUG.
 */
pid_t daemonize(int stay_foreground)
{
    int status = -1;
    pid_t pid;

    /* Fork off the parent process */
    pid = fork();
    if (pid < 0)
        return pid;

    if (pid > 0) {
        if (!stay_foreground)
            exit(EXIT_SUCCESS);
        /* Wait for this specific child, retrying if a signal interrupts. */
        while (waitpid(pid, &status, 0) < 0 && errno == EINTR)
            ;
        exit(EXIT_SUCCESS);
    }

    if (!stay_foreground) {
        /* The child process becomes session leader */
        if (setsid() < 0) {
            E_STRERR("%s", "setsid");
            exit(EXIT_FAILURE);
        }

        /* Fork again so the surviving process is not the session leader */
        pid = fork();
        if (pid < 0) {
            E_STRERR("%s", "fork");
            exit(EXIT_FAILURE);
        }
        if (pid > 0)
            exit(EXIT_SUCCESS);

        if (setpgrp()) {
            E_STRERR("%s", "setpgrp");
            exit(EXIT_FAILURE);
        }
    }

    /* Set new file permissions */
    umask(0);

    if (!stay_foreground) {
        if (chdir("/")) {
            E_STRERR("%s", "chdir");
            exit(EXIT_FAILURE);
        }
        /*
         * Close stdin and everything above stderr. stderr stays usable for
         * the error messages below. redirect_devnull_to() then opens
         * /dev/null (on the now-free descriptor 0 when it opens it here for
         * the first time) and dup2()s it over 0, 1 and 2.
         */
        if (close_fds_except(1, 2, -1) != 0) {
            E("%s", "Close inherited file descriptors");
            exit(EXIT_FAILURE);
        }
        if (redirect_devnull_to(0, 1, 2, -1) != 0) {
            E("%s", "Redirect stdio to /dev/null");
            exit(EXIT_FAILURE);
        }
    } else if (close_fds_except(0, 1, 2, -1) != 0) {
        E("%s", "Close inherited file descriptors");
        exit(EXIT_FAILURE);
    }

    return pid;
}

/*
 * Close every descriptor from sysconf(_SC_OPEN_MAX) - 1 down to 0, except
 * the ones listed.
 *
 * The exception list starts with @fds and continues through the variadic
 * arguments up to the first negative value. For example,
 * close_fds_except(0, 1, 2, -1) keeps stdin/stdout/stderr. A negative @fds
 * is an empty list: every descriptor is closed, including 0-2, and no
 * variadic argument is read.
 *
 * Returns 0 on success. Returns 1 if the descriptor limit is unavailable or
 * the exception list cannot be allocated; nothing is closed in that case.
 * Errors from close() are ignored, since most slots are not open.
 */
int close_fds_except(int fds, ...)
{
    int fd;
    long max_fd;
    size_t i, except_count = 0;
    int *except = NULL;
    va_list ap;

    max_fd = sysconf(_SC_OPEN_MAX) - 1;
    if (max_fd <= 0)
        return 1;

    if (fds >= 0) {
        /* First pass: count the list so the array is sized exactly,
         * instead of reserving one slot per possible descriptor. */
        except_count = 1;
        va_start(ap, fds);
        while (va_arg(ap, int) >= 0)
            except_count++;
        va_end(ap);

        except = malloc(except_count * sizeof *except);
        if (!except)
            return 1;

        /* Second pass: store the list. */
        except[0] = fds;
        va_start(ap, fds);
        for (i = 1; i < except_count; ++i)
            except[i] = va_arg(ap, int);
        va_end(ap);
    }

    for (fd = (int) max_fd; fd >= 0; --fd) {
        int keep = 0;

        for (i = 0; i < except_count && !keep; ++i)
            keep = (except[i] == fd);
        if (!keep)
            close(fd);
    }

    free(except);
    return 0;
}

/*
 * Point each listed descriptor at /dev/null with dup2().
 *
 * The list starts with @fds and continues through the variadic arguments up
 * to the first negative value. For example, redirect_devnull_to(0, 1, 2, -1)
 * redirects stdin, stdout and stderr. A lone -1 is an empty list, and no
 * variadic argument is read in that case.
 *
 * /dev/null is opened once (O_RDWR) and the descriptor is cached in null_fd
 * for later calls.
 *
 * Returns the number of dup2() failures (0 when everything was redirected).
 * Returns -1 if /dev/null cannot be opened or @fds is below -1.
 */
int redirect_devnull_to(int fds, ...)
{
    int fd, rc = 0;
    va_list ap;

    if (null_fd < 0)
        null_fd = open("/dev/null", O_RDWR);
    if (null_fd < 0)
        return -1;
    if (fds < -1)
        return -1;
    if (fds == -1)
        return 0;

    va_start(ap, fds);
    /* @fds itself is the first descriptor of the list. */
    for (fd = fds; fd >= 0; fd = va_arg(ap, int)) {
        if (dup2(null_fd, fd) < 0)
            rc++;
    }
    va_end(ap);

    return rc;
}

/*
 * Switch the real and effective IDs of the process to @user.
 *
 * The group is @group when given, otherwise the user's primary group from
 * the passwd entry. The gid is changed before the uid, as is usual when
 * dropping privileges.
 *
 * Returns 0 on success, 1 if the user/group lookup or either set call
 * fails.
 *
 * NOTE(review): supplementary groups are not reset here (no setgroups()/
 * initgroups() call); confirm that inherited groups are acceptable.
 */
int change_user_group(const char *user, const char *group)
{
    const struct passwd *pw = getpwnam(user);
    gid_t target_gid;

    if (pw == NULL)
        return 1;

    if (group != NULL) {
        const struct group *gr = getgrnam(group);

        if (gr == NULL)
            return 1;
        target_gid = gr->gr_gid;
    } else {
        target_gid = pw->pw_gid;
    }

    if (setregid(target_gid, target_gid) != 0 ||
        setreuid(pw->pw_uid, pw->pw_uid) != 0)
        return 1;

    return 0;
}

/*
 * Make @newroot the root directory of the process.
 *
 * The sequence is chdir(@newroot), then chroot("."), then chdir("/"), so
 * the working directory ends up inside the new root. The first failing step
 * is logged and stops the sequence.
 *
 * Returns 0 on success, 1 on failure.
 */
int safe_chroot(const char *newroot)
{
    if (chdir(newroot) != 0) {
        E_STRERR("Change directory to '%s'", newroot);
        return 1;
    }

    if (chroot(".") != 0) {
        E_STRERR("Change root directory to '%s'", ".");
        return 1;
    }

    if (chdir("/") != 0) {
        E_STRERR("Change directory inside new root to '%s'", "/");
        return 1;
    }

    return 0;
}

/*
 * Decide whether @path is a mountpoint by comparing it with "@path/..".
 *
 * @path counts as a mountpoint when:
 *  - "@path/.." is on a different device (something is mounted on @path), or
 *  - "@path/.." is the very same inode (@path is the root of its tree,
 *    e.g. "/").
 * Bind mounts within a single filesystem are not detected by this check.
 *
 * Returns 1 for a mountpoint, 0 otherwise, and -1 (with a warning) if
 * either stat() fails or @path is PATH_MAX bytes or longer.
 */
int dir_is_mountpoint(const char *path)
{
    struct stat current = {0}, parent = {0};
    size_t plen = strnlen(path, PATH_MAX);
    char parent_path[plen + 4]; /* path + "/.." + NUL */

    /* strnlen() capped the length: the copy below would be truncated. */
    if (plen >= PATH_MAX) {
        errno = ENAMETOOLONG;
        goto error;
    }

    if (stat(path, &current))
        goto error;

    snprintf(parent_path, sizeof parent_path, "%s/..", path);
    if (stat(parent_path, &parent))
        goto error;

    return current.st_dev != parent.st_dev ||
           current.st_ino == parent.st_ino;
error:
    W_STRERR("Mountpoint check for '%s'", path);
    return -1;
}

/*
 * Fallback for mount_root() when making "/" private fails.
 *
 * Returns silently when stat("/") succeeds and the root inode number is not
 * 2. In every other case it logs an error and exits with EXIT_FAILURE.
 *
 * NOTE(review): 2 is the root inode number on ext2/3/4; other filesystems
 * use different numbers, so confirm this heuristic for the target setup.
 */
void chk_chroot(void)
{
    struct stat root_st = {0};

    if (stat("/", &root_st) == 0 && root_st.st_ino != 2)
        return;

    E("%s", "Can not mount filesystem as slave");
    exit(EXIT_FAILURE);
}

/*
 * Remount "/" with MS_PRIVATE|MS_REC (applied recursively to the whole
 * mount tree). On failure, chk_chroot() decides whether to continue or
 * exit.
 */
void mount_root(void)
{
    if (mount("none", "/", "none", MS_PRIVATE|MS_REC, NULL) != 0)
        chk_chroot();
}

/*
 * Mount a small tmpfs on @mount_path to hold device nodes.
 *
 * Options: "size=4k,mode=755,gid=0" with MS_NOSUID, MS_NOEXEC and
 * MS_STRICTATIME. The filesystem type is plain tmpfs, not devtmpfs; the
 * error message names tmpfs so the log matches what was attempted.
 *
 * Returns 0 on success, 1 on failure (logged).
 */
int mount_dev(const char *mount_path)
{
    int s;

    s = mount("tmpfs", mount_path, "tmpfs",
              MS_NOSUID|MS_STRICTATIME|
              MS_NOEXEC|MS_REC,
              "size=4k,mode=755,gid=0");
    if (s) {
        E_STRERR("Mount tmpfs filesystem to %s", mount_path);
        return 1;
    }

    return 0;
}

/*
 * Mount a devpts filesystem on @mount_path.
 *
 * Options: "newinstance,gid=5,mode=620,ptmxmode=0666".
 *
 * Returns 0 on success, 1 on failure (logged).
 */
int mount_pts(const char *mount_path)
{
    const int failed = mount("devpts", mount_path, "devpts", MS_MGC_VAL,
                             "newinstance,gid=5,mode=620,ptmxmode=0666");

    if (!failed)
        return 0;

    E_STRERR("Mount devpts filesystem to %s", mount_path);
    return 1;
}

/*
 * Mount a fresh proc filesystem on @mount_path with MS_NOSUID, MS_NOEXEC
 * and MS_NODEV.
 *
 * Anything already mounted there is unmounted first. That step is
 * best-effort and its result is ignored.
 *
 * Returns 0 on success, 1 on failure (logged).
 */
int mount_proc(const char *mount_path)
{
    const unsigned long flags = MS_NOSUID|MS_NOEXEC|MS_NODEV|MS_REC;

    umount(mount_path);

    if (mount("proc", mount_path, "proc", flags, NULL) == 0)
        return 0;

    E_STRERR("Mount proc filesystem to %s", mount_path);
    return 1;
}

/*
 * Create the device node "<mount_path>/<device_file>" with mknod(),
 * replacing any entry that already exists at that path.
 *
 * @mode     file type (e.g. S_IFCHR) plus any extra permission bits
 * @add_mode when non-zero, rw-rw-r-- is OR'ed into @mode
 * @dev      device number of the node
 *
 * Returns 0 on success, 1 on failure (logged).
 */
int create_device_file_checked(const char *mount_path, const char *device_file,
                               mode_t mode, int add_mode, dev_t dev)
{
    int s;
    mode_t defmode = S_IRUSR|S_IWUSR|
                     S_IRGRP|S_IWGRP|
                     S_IROTH;
    size_t plen = strnlen(mount_path, PATH_MAX);
    size_t dlen = strnlen(device_file, PATH_MAX);
    struct stat devbuf = {0};
    char devpath[plen+dlen+2];

    /* strnlen() caps at PATH_MAX, so an over-long input would be cut off. */
    s = snprintf(devpath, sizeof devpath, "%s/%s", mount_path, device_file);
    if (s < 0 || (size_t) s >= sizeof devpath) {
        E("Device path too long: '%s/%s'", mount_path, device_file);
        return 1;
    }

    /*
     * Decide on stat()'s return value, not on a possibly stale errno.
     * An existing entry (e.g. from an earlier run) would make mknod() fail
     * with EEXIST, so remove it first.
     */
    if (stat(devpath, &devbuf) == 0) {
        if (remove(devpath)) {
            E_STRERR("Remove existing device file '%s'", devpath);
            return 1;
        }
    } else if (errno != ENOENT) {
        E_STRERR("Stat device file '%s'", devpath);
        return 1;
    }

    D2("Create device file: %s", devpath);
    if (!add_mode)
        defmode = 0;
    s = mknod(devpath, defmode|mode, dev);
    if (s) {
        E_STRERR("Device creation '%s'", devpath);
        return 1;
    }

    return 0;
}

/*
 * Create the character device nodes "ptmx" (5,2) and "tty" (5,0) under
 * @mount_path, in that order, with the default permissions added.
 *
 * Every node is attempted even if an earlier one fails. Returns the OR of
 * the individual results (0 when all succeeded).
 */
int create_device_files(const char *mount_path)
{
    static const struct {
        const char *name;
        unsigned int major;
        unsigned int minor;
    } nodes[] = {
        { "ptmx", 5, 2 },
        { "tty",  5, 0 },
    };
    int result = 0;
    size_t n;

    for (n = 0; n < sizeof nodes / sizeof nodes[0]; ++n)
        result |= create_device_file_checked(mount_path, nodes[n].name,
                                             S_IFCHR, 1,
                                             makedev(nodes[n].major,
                                                     nodes[n].minor));

    return result;
}

/*
 * Write a single-line ID mapping "map[0] map[1] map[2]\n" to a user
 * namespace map file.
 *
 * @pid           target process; a negative value means /proc/self
 * @map           the three numbers written, in order
 * @update_uidmap non-zero selects uid_map, zero selects gid_map
 *
 * The line is written with one write() call, and anything short of the
 * full length counts as failure. The descriptor is always closed.
 *
 * Returns 0 on success, 1 on failure.
 */
int update_guid_map(pid_t pid, unsigned int map[3], int update_uidmap)
{
    int s, fd;
    ssize_t written;
    const char *const file = update_uidmap ? "uid_map" : "gid_map";
    char buf[64];

    if (pid < 0)
        s = snprintf(buf, sizeof buf, "/proc/self/%s", file);
    else
        s = snprintf(buf, sizeof buf, "/proc/%d/%s", (int) pid, file);
    if (s <= 0 || (size_t) s >= sizeof buf)
        return 1;

    fd = open(buf, O_WRONLY);
    if (fd < 0)
        return 1;

    s = snprintf(buf, sizeof buf, "%u %u %u\n", map[0], map[1], map[2]);
    if (s <= 0 || (size_t) s >= sizeof buf) {
        close(fd);
        return 1;
    }

    written = write(fd, buf, (size_t) s);
    close(fd);

    return written == (ssize_t) s ? 0 : 1;
}

/*
 * Write "allow" (@allow != 0) or "deny" to /proc/self/setgroups.
 *
 * Only the characters of the keyword are written. The previous
 * `sizeof str_allow` was the size of a pointer: it wrote 8 bytes, read past
 * the end of the string literal, and produced a write the kernel rejects.
 * The descriptor is always closed.
 *
 * Returns 0 when the whole keyword was written, 1 otherwise.
 */
int update_setgroups_self(int allow)
{
    const char *const path_self = "/proc/self/setgroups";
    const char *const value = allow ? "allow" : "deny";
    const size_t len = strlen(value);
    ssize_t written;
    int fd;

    fd = open(path_self, O_WRONLY);
    if (fd < 0)
        return 1;

    written = write(fd, value, len);
    close(fd);

    return written == (ssize_t) len ? 0 : 1;
}