/*
Copyright (c) 2018 Contributors as noted in the AUTHORS file

This file is part of 0MQ.

0MQ is free software; you can redistribute it and/or modify it under
the terms of the GNU Lesser General Public License as published by
the Free Software Foundation; either version 3 of the License, or
(at your option) any later version.

0MQ is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU Lesser General Public License for more details.

You should have received a copy of the GNU Lesser General Public License
along with this program.  If not, see <http://www.gnu.org/licenses/>.
*/

#include "../tests/testutil.hpp"

#include <poller.hpp>
#include <i_poll_events.hpp>
#include <ip.hpp>

#include <unity.h>

#ifndef _WIN32
#define closesocket close
#endif

// Unity per-test fixture hook; these tests need no shared setup.
void setUp ()
{
    // intentionally empty
}
// Unity per-test fixture hook; each test cleans up its own resources.
void tearDown ()
{
    // intentionally empty
}

// Smoke test: a poller can be constructed and destroyed without ever being
// started or given any fds.
void test_create ()
{
    zmq::thread_ctx_t ctx;
    zmq::poller_t unstarted_poller (ctx);
}

#if 0
// Disabled: starting a poller that has nothing registered with it currently
// triggers an assertion.
// TODO decide whether starting an empty poller should be a valid use case.
void test_start_empty ()
{
    zmq::thread_ctx_t thread_ctx;
    zmq::poller_t poller (thread_ctx);
    poller.start ();
    msleep (SETTLE_TIME);
}
#endif

struct test_events_t : zmq::i_poll_events
{
    test_events_t (zmq::fd_t fd_, zmq::poller_t &poller_) :
        fd (fd_),
        poller (poller_)
    {
    }

    virtual void in_event ()
    {
        poller.rm_fd (handle);
        handle = (zmq::poller_t::handle_t) NULL;

        // this must only be incremented after rm_fd
        in_events.add (1);
    }


    virtual void out_event ()
    {
        // TODO
    }


    virtual void timer_event (int id_)
    {
        LIBZMQ_UNUSED (id_);
        poller.rm_fd (handle);
        handle = (zmq::poller_t::handle_t) NULL;

        // this must only be incremented after rm_fd
        timer_events.add (1);
    }

    void set_handle (zmq::poller_t::handle_t handle_) { handle = handle_; }

    zmq::atomic_counter_t in_events, timer_events;

  private:
    zmq::fd_t fd;
    zmq::poller_t &poller;
    zmq::poller_t::handle_t handle;
};

void wait_in_events (test_events_t &events)
{
    void *watch = zmq_stopwatch_start ();
    while (events.in_events.get () < 1) {
#ifdef ZMQ_BUILD_DRAFT
        TEST_ASSERT_LESS_OR_EQUAL_MESSAGE (SETTLE_TIME,
                                           zmq_stopwatch_intermediate (watch),
                                           "Timeout waiting for in event");
#endif
    }
    zmq_stopwatch_stop (watch);
}

void wait_timer_events (test_events_t &events)
{
    void *watch = zmq_stopwatch_start ();
    while (events.timer_events.get () < 1) {
#ifdef ZMQ_BUILD_DRAFT
        TEST_ASSERT_LESS_OR_EQUAL_MESSAGE (SETTLE_TIME,
                                           zmq_stopwatch_intermediate (watch),
                                           "Timeout waiting for timer event");
#endif
    }
    zmq_stopwatch_stop (watch);
}

// Obtains a read/write fd pair from zmq::make_fdpair and switches both ends to
// non-blocking mode. The current test fails if the pair cannot be created.
// Under eventfd both ends may be the same descriptor; close_fdpair accounts
// for that.
void create_nonblocking_fdpair (zmq::fd_t *r, zmq::fd_t *w)
{
    const int status = zmq::make_fdpair (r, w);
    TEST_ASSERT_EQUAL_INT (0, status);

    // Neither end may be the "no descriptor" sentinel.
    TEST_ASSERT_NOT_EQUAL (zmq::retired_fd, *r);
    TEST_ASSERT_NOT_EQUAL (zmq::retired_fd, *w);

    zmq::unblock_socket (*r);
    zmq::unblock_socket (*w);
}

// Makes the read end of the pair readable by writing one message to `w`.
// With eventfd this is an 8-byte (uint64_t) counter increment; otherwise a
// short payload is sent over the socket pair.
// Uses Unity assertions, like the rest of this file, so the check is not
// compiled out under NDEBUG and the result variables are never left unused.
void send_signal (zmq::fd_t w)
{
#if defined ZMQ_HAVE_EVENTFD
    const uint64_t inc = 1;
    const ssize_t sz = write (w, &inc, sizeof (inc));
    TEST_ASSERT_EQUAL_INT (static_cast<int> (sizeof (inc)),
                           static_cast<int> (sz));
#else
    char msg[] = "test";
    const int rc = send (w, msg, sizeof (msg), 0);
    // Cast avoids a signed/unsigned comparison between int and size_t.
    TEST_ASSERT_EQUAL_INT (static_cast<int> (sizeof (msg)), rc);
#endif
}

// Closes the descriptors produced by create_nonblocking_fdpair, failing the
// test if a close fails. Under ZMQ_HAVE_EVENTFD only `w` is closed; `r` is
// presumably the same eventfd descriptor.
void close_fdpair (zmq::fd_t w, zmq::fd_t r)
{
#if defined ZMQ_HAVE_EVENTFD
    LIBZMQ_UNUSED (r);
    const int write_close_status = closesocket (w);
    TEST_ASSERT_EQUAL_INT (0, write_close_status);
#else
    const int write_close_status = closesocket (w);
    TEST_ASSERT_EQUAL_INT (0, write_close_status);
    const int read_close_status = closesocket (r);
    TEST_ASSERT_EQUAL_INT (0, read_close_status);
#endif
}

void test_add_fd_and_start_and_receive_data ()
{
    zmq::thread_ctx_t thread_ctx;
    zmq::poller_t poller (thread_ctx);

    zmq::fd_t r, w;
    create_nonblocking_fdpair (&r, &w);

    test_events_t events (r, poller);

    zmq::poller_t::handle_t handle = poller.add_fd (r, &events);
    events.set_handle (handle);
    poller.set_pollin (handle);
    poller.start ();

    send_signal (w);

    wait_in_events (events);

    // required cleanup
    close_fdpair (w, r);
}

void test_add_fd_and_remove_by_timer ()
{
    zmq::fd_t r, w;
    create_nonblocking_fdpair (&r, &w);

    zmq::thread_ctx_t thread_ctx;
    zmq::poller_t poller (thread_ctx);

    test_events_t events (r, poller);

    zmq::poller_t::handle_t handle = poller.add_fd (r, &events);
    events.set_handle (handle);

    poller.add_timer (50, &events, 0);
    poller.start ();

    wait_timer_events (events);

    // required cleanup
    close_fdpair (w, r);
}

// Test runner: initializes networking and the test environment, runs every
// poller test, then shuts networking down before reporting Unity's result.
int main ()
{
    UNITY_BEGIN ();

    // Networking must be initialized before any fd pair is created.
    zmq::initialize_network ();
    setup_test_environment ();

    RUN_TEST (test_create);
    RUN_TEST (test_add_fd_and_start_and_receive_data);
    RUN_TEST (test_add_fd_and_remove_by_timer);

    zmq::shutdown_network ();

    // UNITY_END reports the summary and yields the number of failures.
    return UNITY_END ();
}