/*
 * Server-side socket communication functions
 *
 * Copyright (C) 1998 Alexandre Julliard
 */

#include <assert.h>
#include <errno.h>
#include <fcntl.h>
#include <stdio.h>
#include <stdlib.h>
#include <stdarg.h>
#include <string.h>
#include <sys/time.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/uio.h>
#include <unistd.h>

#include "config.h"
#include "server.h"
#include "object.h"

/* Some versions of glibc don't define this */
#ifndef SCM_RIGHTS
#define SCM_RIGHTS 1
#endif

/* client state: where a client currently is in the request/reply cycle */
enum state
{
    RUNNING = 0,   /* running normally */
    SENDING = 1,   /* sending us a request */
    WAITING = 2,   /* waiting for us to reply */
    READING = 3    /* reading our reply */
};

/* client structure: per-connection state for one client socket */
struct client
{
    enum state           state;      /* client state (see enum state) */
    struct select_user   select;     /* select user; select.fd is the client socket */
    unsigned int         seq;        /* current sequence number: expected in head.seq, bumped after each message */
    struct header        head;       /* current msg header */
    char                *data;       /* current msg data (payload following the header), malloc'ed; NULL if none */
    int                  count;      /* bytes sent/received so far, header bytes included */
    int                  pass_fd;    /* fd to pass to and from the client, -1 if none */
    struct thread       *self;       /* client thread (opaque pointer) */
    struct timeout_user *timeout;    /* current timeout (opaque pointer), NULL if none */
};


/* exit code passed to remove_client (which forwards it to call_kill_handler) */
#define OUT_OF_MEMORY  -1   /* a message buffer could not be allocated */
#define BROKEN_PIPE    -2   /* socket error, or the peer closed the connection */
#define PROTOCOL_ERROR -3   /* the client broke the request protocol (bad seq/length) */


/* signal a client protocol error
 *
 * Writes "Protocol error:<socket fd>: " followed by the printf-style message
 * to stderr.  This only reports; the caller decides what to do with the
 * client afterwards.
 */
static void protocol_error( struct client *client, const char *err, ... )
{
    va_list ap;
    int sock = client->select.fd;

    fprintf( stderr, "Protocol error:%d: ", sock );
    va_start( ap, err );
    vfprintf( stderr, err, ap );
    va_end( ap );
}

/* send a message to a client that is ready to receive something
 *
 * Called on WRITE_EVENT while the client is READING our reply.  Sends as much
 * of the pending message (client->head followed by client->data) as the
 * socket accepts, resuming at offset client->count.  If client->pass_fd is
 * set, it is attached to this chunk as ancillary data and closed once sent.
 * When the whole message has gone out, the client goes back to RUNNING, its
 * sequence number is bumped and we listen for its next request again.
 *
 * On a fatal socket error the client is removed (and freed), so the caller
 * must not touch it after this returns.
 */
static void do_write( struct client *client, int client_fd )
{
    struct iovec vec[2];
#ifndef HAVE_MSGHDR_ACCRIGHTS
    struct cmsg_fd cmsg;
#endif
    struct msghdr msghdr;
    int ret;

    /* make sure we have something to send */
    assert( client->count < client->head.len );
    /* make sure the client is listening */
    assert( client->state == READING );

    msghdr.msg_name = NULL;
    msghdr.msg_namelen = 0;
    msghdr.msg_iov = vec;

    if (client->count < sizeof(client->head))
    {
        /* still inside the header: send its remainder plus the whole payload */
        vec[0].iov_base = (char *)&client->head + client->count;
        vec[0].iov_len  = sizeof(client->head) - client->count;
        vec[1].iov_base = client->data;
        vec[1].iov_len  = client->head.len - sizeof(client->head);
        msghdr.msg_iovlen = 2;
    }
    else
    {
        /* header already sent: send the remainder of the payload */
        vec[0].iov_base = client->data + client->count - sizeof(client->head);
        vec[0].iov_len  = client->head.len - client->count;
        msghdr.msg_iovlen = 1;
    }

#ifdef HAVE_MSGHDR_ACCRIGHTS
    if (client->pass_fd != -1)  /* we have an fd to send */
    {
        msghdr.msg_accrights = (void *)&client->pass_fd;
        msghdr.msg_accrightslen = sizeof(client->pass_fd);
    }
    else
    {
        msghdr.msg_accrights = NULL;
        msghdr.msg_accrightslen = 0;
    }
#else  /* HAVE_MSGHDR_ACCRIGHTS */
    if (client->pass_fd != -1)  /* we have an fd to send */
    {
        cmsg.len = sizeof(cmsg);
        cmsg.level = SOL_SOCKET;
        cmsg.type = SCM_RIGHTS;
        cmsg.fd = client->pass_fd;
        msghdr.msg_control = &cmsg;
        msghdr.msg_controllen = sizeof(cmsg);
    }
    else
    {
        msghdr.msg_control = NULL;
        msghdr.msg_controllen = 0;
    }
    msghdr.msg_flags = 0;
#endif  /* HAVE_MSGHDR_ACCRIGHTS */

    ret = sendmsg( client_fd, &msghdr, 0 );
    if (ret == -1)
    {
        /* The socket is O_NONBLOCK: an interrupted or would-block send
           transferred nothing (not even the fd), so keep the reply and
           try again on the next WRITE_EVENT instead of dropping the client. */
        if (errno == EINTR || errno == EAGAIN) return;
        if (errno != EPIPE) perror("sendmsg");
        remove_client( client, BROKEN_PIPE );
        return;
    }
    if (client->pass_fd != -1)  /* We sent the fd, now we can close it */
    {
        close( client->pass_fd );
        client->pass_fd = -1;
    }
    if ((client->count += ret) < client->head.len) return;

    /* we have finished with this message */
    free( client->data );
    client->data  = NULL;
    client->count = 0;
    client->state = RUNNING;
    client->seq++;
    set_select_events( &client->select, READ_EVENT );
}


/* read a message from a client that has something to say */
static void do_read( struct client *client, int client_fd )
{
    struct iovec vec;
    int pass_fd = -1;
    int ret;

#ifdef HAVE_MSGHDR_ACCRIGHTS
    struct msghdr msghdr;

    msghdr.msg_accrights    = (void *)&pass_fd;
    msghdr.msg_accrightslen = sizeof(int);
#else  /* HAVE_MSGHDR_ACCRIGHTS */
    struct msghdr msghdr;
    struct cmsg_fd cmsg;

    cmsg.len   = sizeof(cmsg);
    cmsg.level = SOL_SOCKET;
    cmsg.type  = SCM_RIGHTS;
    cmsg.fd    = -1;
    msghdr.msg_control    = &cmsg;
    msghdr.msg_controllen = sizeof(cmsg);
    msghdr.msg_flags      = 0;
#endif  /* HAVE_MSGHDR_ACCRIGHTS */

    msghdr.msg_name    = NULL;
    msghdr.msg_namelen = 0;
    msghdr.msg_iov     = &vec;
    msghdr.msg_iovlen  = 1;

    if (client->count < sizeof(client->head))
    {
        vec.iov_base = (char *)&client->head + client->count;
        vec.iov_len  = sizeof(client->head) - client->count;
    }
    else
    {
        if (!client->data &&
            !(client->data = malloc(client->head.len-sizeof(client->head))))
        {
            remove_client( client, OUT_OF_MEMORY );
            return;
        }
        vec.iov_base = client->data + client->count - sizeof(client->head);
        vec.iov_len  = client->head.len - client->count;
    }

    ret = recvmsg( client_fd, &msghdr, 0 );
    if (ret == -1)
    {
        perror("recvmsg");
        remove_client( client, BROKEN_PIPE );
        return;
    }
#ifndef HAVE_MSGHDR_ACCRIGHTS
    pass_fd = cmsg.fd;
#endif
    if (pass_fd != -1)
    {
        /* can only receive one fd per message */
        if (client->pass_fd != -1) close( client->pass_fd );
        client->pass_fd = pass_fd;
    }
    else if (!ret)  /* closed pipe */
    {
        remove_client( client, BROKEN_PIPE );
        return;
    }

    if (client->state == RUNNING) client->state = SENDING;
    assert( client->state == SENDING );

    client->count += ret;

    /* received the complete header yet? */
    if (client->count < sizeof(client->head)) return;

    /* sanity checks */
    if (client->head.seq != client->seq)
    {
        protocol_error( client, "bad sequence %08x instead of %08x\n",
                        client->head.seq, client->seq );
        remove_client( client, PROTOCOL_ERROR );
        return;
    }
    if ((client->head.len < sizeof(client->head)) ||
        (client->head.len > MAX_MSG_LENGTH + sizeof(client->head)))
    {
        protocol_error( client, "bad header length %08x\n",
                        client->head.len );
        remove_client( client, PROTOCOL_ERROR );
        return;
    }

    /* received the whole message? */
    if (client->count == client->head.len)
    {
        /* done reading the data, call the callback function */

        int len = client->head.len - sizeof(client->head);
        char *data = client->data;
        int passed_fd = client->pass_fd;
        enum request type = client->head.type;

        /* clear the info now, as the client may be deleted by the callback */
        client->head.len  = 0;
        client->head.type = 0;
        client->count     = 0;
        client->data      = NULL;
        client->pass_fd   = -1;
        client->state     = WAITING;
        client->seq++;

        call_req_handler( client->self, type, data, len, passed_fd );
        if (passed_fd != -1) close( passed_fd );
        if (data) free( data );
    }
}

/* handle a client event
 *
 * event:   READ_EVENT / WRITE_EVENT bits reported for the client socket
 * private: the struct client registered in add_client()
 *
 * At most one handler runs per call.  Both do_write() and do_read() may call
 * remove_client(), which frees the client, so using it again after the first
 * handler returns would be a use-after-free.  WRITE_EVENT is only requested
 * while a reply is pending (send_reply_v), and do_write() re-enables
 * READ_EVENT once the reply is fully sent.  Any pending input is therefore
 * handled on a later pass of the select loop (assumes that loop re-reports a
 * still-readable fd — TODO confirm).
 */
static void client_event( int event, void *private )
{
    struct client *client = (struct client *)private;

    if (event & WRITE_EVENT)
    {
        do_write( client, client->select.fd );
        return;  /* client may have been freed by remove_client() */
    }
    if (event & READ_EVENT) do_read( client, client->select.fd );
}


/*******************************************************************/
/* server-side exported functions                                  */

/* add a client
 *
 * fd:   connected socket to the client; it is switched to non-blocking mode
 *       and closed by remove_client()
 * self: client thread, handed back to call_req_handler/call_kill_handler
 *
 * Returns the new client, registered for READ_EVENT and in state RUNNING,
 * or NULL if allocation failed.
 */
struct client *add_client( int fd, struct thread *self )
{
    int flags;
    struct client *client = mem_alloc( sizeof(*client) );
    if (!client) return NULL;

    /* Only touch the flags if we could read them: on failure F_GETFL returns
       -1, and F_SETFL(-1 | O_NONBLOCK) would try to switch on every settable
       status flag (O_APPEND, O_ASYNC, ...), not just O_NONBLOCK. */
    flags = fcntl( fd, F_GETFL, 0 );
    if (flags != -1) fcntl( fd, F_SETFL, flags | O_NONBLOCK );
    else perror( "fcntl" );

    client->state                = RUNNING;
    client->select.fd            = fd;
    client->select.func          = client_event;
    client->select.private       = client;
    client->seq                  = 0;
    client->head.len             = 0;
    client->head.type            = 0;
    client->count                = 0;
    client->data                 = NULL;
    client->self                 = self;
    client->timeout              = NULL;
    client->pass_fd              = -1;
    register_select_user( &client->select );
    set_select_events( &client->select, READ_EVENT );
    return client;
}

/* remove a client */
void remove_client( struct client *client, int exit_code )
{
    assert( client );

    call_kill_handler( client->self, exit_code );

    if (client->timeout) remove_timeout_user( client->timeout );
    unregister_select_user( &client->select );
    close( client->select.fd );

    /* Purge messages */
    if (client->data) free( client->data );
    if (client->pass_fd != -1) close( client->pass_fd );
    free( client );
}

/* send a reply to a client
 *
 * client:  client in WAITING state with no pending data buffer
 * type:    reply type stored in the message header
 * pass_fd: fd to pass along with the reply, or -1; ownership passes to the
 *          client structure on success (do_write closes it once sent)
 * vec:     payload pieces, gathered into one buffer in order
 * veclen:  number of entries in vec
 *
 * Queues the reply and switches the client to READING / WRITE_EVENT; the
 * actual transmission happens in do_write().  Returns 0 on success, -1 if
 * the payload buffer could not be allocated (client left unchanged).
 */
int send_reply_v( struct client *client, int type, int pass_fd,
                  struct iovec *vec, int veclen )
{
    int i;
    unsigned int len;
    char *p;

    assert( client );
    assert( client->state == WAITING );
    assert( !client->data );

    if (debug_level) trace_reply( client->self, type, pass_fd, vec, veclen );

    for (i = len = 0; i < veclen; i++) len += vec[i].iov_len;
    assert( len < MAX_MSG_LENGTH );

    if (len && !(client->data = malloc( len ))) return -1;
    client->count     = 0;
    client->head.len  = len + sizeof(client->head);
    client->head.type = type;
    client->head.seq  = client->seq;
    client->pass_fd   = pass_fd;

    /* For an empty reply client->data stays NULL.  memcpy() with a NULL
       argument (and NULL + 0 pointer arithmetic) is undefined even for zero
       bytes, so only gather when there is a payload, and skip empty pieces
       whose iov_base may be NULL. */
    if (len)
    {
        for (i = 0, p = client->data; i < veclen; i++)
        {
            if (!vec[i].iov_len) continue;
            memcpy( p, vec[i].iov_base, vec[i].iov_len );
            p += vec[i].iov_len;
        }
    }

    client->state = READING;
    set_select_events( &client->select, WRITE_EVENT );
    return 0;
}
