use {
    alloc::{boxed::Box, sync::Arc},
    core::{
        future::Future,
        pin::Pin,
        task::{Context, Poll, RawWaker, RawWakerVTable, Waker},
    },
    crossbeam_queue::SegQueue,
    slab::Slab,
};

/// Returns a future that suspends the current task exactly once.
///
/// The first poll wakes the task's own waker and reports `Poll::Pending`, so the task is
/// scheduled to be polled again; the following poll completes with `Poll::Ready(())`.
pub fn yield_now() -> impl Future<Output = ()> {
    struct YieldNow {
        yielded: bool,
    }

    impl Future for YieldNow {
        type Output = ();

        #[inline(always)]
        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            if !self.yielded {
                // Mark the yield first, then ask to be polled again before giving up control.
                self.yielded = true;
                cx.waker().wake_by_ref();
                return Poll::Pending;
            }
            Poll::Ready(())
        }
    }

    YieldNow { yielded: false }
}

pub struct Executor<F: Future<Output = ()> + Send> {
    tasks:      Slab<Task<F>>,
    task_queue: Arc<TaskQueue>,
}

impl<F: Future<Output = ()> + Send> Executor<F> {
    pub fn new(size: usize) -> Self {
        Self {
            tasks:      Slab::with_capacity(size),
            task_queue: Arc::new(TaskQueue::new()),
        }
    }

    #[inline]
    pub fn spawn(&mut self, future: F) {
        self.task_queue
            .queue
            .push(self.tasks.insert(Task::new(future)));
    }

    pub fn run(&mut self) {
        let mut task_batch = [0; 32];
        let mut batch_len = 0;

        loop {
            self.task_queue.batch_pop(&mut task_batch, &mut batch_len);

            if batch_len == 0 {
                if self.task_queue.is_empty() {
                    break;
                } else {
                    continue;
                }
            }

            for &id in &task_batch[..batch_len] {
                if let Some(task) = self.tasks.get_mut(id) {
                    let waker = task
                        .waker
                        .get_or_insert_with(|| TaskWaker::new(id, Arc::clone(&self.task_queue)));

                    let waker = unsafe { Waker::from_raw(TaskWaker::into_raw_waker(waker)) };
                    let mut cx = Context::from_waker(&waker);

                    if let Poll::Ready(()) = task.poll(&mut cx) {
                        self.tasks.remove(id);
                        self.task_queue.free_tasks.push(id);
                    }
                }
            }
        }
    }
}

/// One spawned future plus the waker the executor lazily attaches to it.
struct Task<F>
where
    F: Future<Output = ()> + Send,
{
    /// Boxed and pinned so the future keeps a stable address across polls.
    future: Pin<Box<F>>,
    /// `None` until the executor first needs a waker for this task.
    waker:  Option<TaskWaker>,
}

impl<F> Task<F>
where
    F: Future<Output = ()> + Send,
{
    /// Pins `future` on the heap; no waker is attached yet.
    #[inline(always)]
    pub fn new(future: F) -> Self {
        let future = Box::pin(future);
        Task { future, waker: None }
    }

    /// Polls the wrapped future once using the caller-supplied context.
    #[inline(always)]
    fn poll(&mut self, cx: &mut Context) -> Poll<()> {
        Future::poll(Pin::as_mut(&mut self.future), cx)
    }
}

/// Waker payload for one task: waking it pushes `id` back onto `task_queue`.
struct TaskWaker {
    id:         usize,
    task_queue: Arc<TaskQueue>,
}

impl TaskWaker {
    #[inline(always)]
    fn new(id: usize, task_queue: Arc<TaskQueue>) -> Self {
        Self { id, task_queue }
    }

    /// Re-queues this task's id so it is polled again.
    #[inline(always)]
    fn wake(&self) {
        self.task_queue.queue.push(self.id);
    }

    /// Builds a `RawWaker` that owns a fresh `Arc<TaskWaker>` copying `waker`'s id and queue.
    ///
    /// The result does not borrow `waker`. It holds a strong reference to its own heap
    /// allocation, so any clones a future stores past the current poll remain valid even if
    /// the `TaskWaker` it was copied from is later moved or dropped. The cost is one small
    /// allocation per call.
    fn into_raw_waker(waker: &TaskWaker) -> RawWaker {
        let owned = Arc::new(TaskWaker::new(waker.id, Arc::clone(&waker.task_queue)));
        RawWaker::new(Arc::into_raw(owned).cast(), &VTABLE)
    }
}

/// Vtable for wakers whose data pointer is an `Arc<TaskWaker>` produced by `Arc::into_raw`.
/// Each `RawWaker` owns exactly one strong count, released by `wake_raw` or `drop_raw`.
const VTABLE: RawWakerVTable = RawWakerVTable::new(clone_raw, wake_raw, wake_by_ref_raw, drop_raw);

/// # Safety
/// `ptr` must come from `Arc::<TaskWaker>::into_raw` and still own a strong count.
unsafe fn clone_raw(ptr: *const ()) -> RawWaker {
    // SAFETY: `ptr` is a live `Arc<TaskWaker>` pointer per the contract; the clone receives
    // its own strong count, which its eventual `wake_raw` or `drop_raw` releases.
    unsafe { Arc::increment_strong_count(ptr.cast::<TaskWaker>()) };
    RawWaker::new(ptr, &VTABLE)
}

/// # Safety
/// `ptr` must come from `Arc::<TaskWaker>::into_raw`; this call consumes its strong count.
unsafe fn wake_raw(ptr: *const ()) {
    // SAFETY: reclaims the strong count owned by this waker; the `Arc` is dropped on return.
    let waker = unsafe { Arc::from_raw(ptr.cast::<TaskWaker>()) };
    waker.wake();
}

/// # Safety
/// `ptr` must come from `Arc::<TaskWaker>::into_raw` and still own a strong count.
unsafe fn wake_by_ref_raw(ptr: *const ()) {
    // SAFETY: the waker's strong count keeps the allocation alive for this borrow.
    let waker = unsafe { &*ptr.cast::<TaskWaker>() };
    waker.wake();
}

/// # Safety
/// `ptr` must come from `Arc::<TaskWaker>::into_raw`; this call consumes its strong count.
unsafe fn drop_raw(ptr: *const ()) {
    // SAFETY: releases the strong count owned by the waker being dropped.
    drop(unsafe { Arc::from_raw(ptr.cast::<TaskWaker>()) });
}

/// Run queue shared between an executor and the wakers of its tasks.
struct TaskQueue {
    /// Ids of tasks waiting to be polled; the same id may appear more than once.
    queue:      SegQueue<usize>,
    /// NOTE(review): nothing in this file ever pops from this queue, so anything pushed here
    /// accumulates indefinitely — confirm whether it is still needed.
    free_tasks: SegQueue<usize>,
}

impl TaskQueue {
    fn new() -> Self {
        Self {
            queue:      SegQueue::new(),
            free_tasks: SegQueue::new(),
        }
    }

    /// Moves up to `output.len()` ids from the run queue into the front of `output`.
    ///
    /// `*len` is set to the number of ids written. Capacity is checked before each pop, so
    /// an empty `output` dequeues nothing instead of dropping an id and panicking on the index.
    #[inline(always)]
    fn batch_pop(&self, output: &mut [usize], len: &mut usize) {
        *len = 0;
        for slot in output.iter_mut() {
            match self.queue.pop() {
                Some(id) => {
                    *slot = id;
                    *len += 1;
                }
                None => break,
            }
        }
    }

    /// Returns `true` when no task ids are currently queued.
    #[inline(always)]
    fn is_empty(&self) -> bool {
        self.queue.is_empty()
    }
}