use { alloc::{ boxed::Box, collections::{BTreeMap, BTreeSet}, sync::Arc, }, core::{ future::Future, pin::Pin, sync::atomic::{AtomicBool, Ordering}, task::{Context, ContextBuilder, Poll, RawWaker, RawWakerVTable, Waker}, }, crossbeam_queue::SegQueue, slab::Slab, }; pub fn yield_now() -> impl Future { struct YieldNow(bool); impl Future for YieldNow { type Output = (); fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { if self.0 { Poll::Ready(()) } else { self.0 = true; cx.waker().wake_by_ref(); Poll::Pending } } } YieldNow(false) } pub trait Process: Future + Send {} impl + Send> Process for T {} pub struct Executor { tasks: Slab, task_queue: Arc>, interrupt_lookup: [Option; u8::MAX as usize], buffer_lookup: BTreeMap>, } impl Executor { pub fn new() -> Self { Self { tasks: Slab::new(), task_queue: Arc::new(SegQueue::new()), interrupt_lookup: [None; u8::MAX as usize], buffer_lookup: BTreeMap::new(), } } pub fn spawn(&mut self, future: Pin>) -> usize { let id = self.tasks.insert(Task::new(future)); self.task_queue.push(id); id } pub fn pause(&self, id: usize) { if let Some(task) = self.tasks.get(id) { task.set_paused(true); } } pub fn unpause(&self, id: usize) { if let Some(task) = self.tasks.get(id) { task.set_paused(false); self.task_queue.push(id); } } pub fn interrupt_subscribe(&mut self, pid: usize, interrupt_type: u8) { self.pause(pid); self.interrupt_lookup[interrupt_type as usize] = Some(pid); } pub fn buffer_subscribe(&mut self, pid: usize, buffer_id: usize) { self.pause(pid); if let Some(buf) = self.buffer_lookup.get_mut(&buffer_id) { buf.insert(pid); } else { self.buffer_lookup.insert(buffer_id, BTreeSet::from([pid])); } } pub fn run(&mut self) { let mut task_batch = [0; 32]; loop { let mut batch_len = 0; while let Some(id) = self.task_queue.pop() { task_batch[batch_len] = id; batch_len += 1; if batch_len == task_batch.len() { break; } } if batch_len == 0 { // break; continue; } for &(mut id) in &task_batch[..batch_len] { if let Some(task) = self.tasks.get_mut(id) { if task.is_paused() { continue; } let waker = create_waker(id, Arc::clone(&self.task_queue)); let mut cx = ContextBuilder::from_waker(&waker).ext(&mut id).build(); if let Poll::Ready(()) = task.poll(&mut cx) { self.tasks.remove(id); self.interrupt_lookup.map(|pid| { if let Some(pid) = pid { if pid == id { return None; } } return pid; }); self.buffer_lookup.iter_mut().for_each(|(_, pid_set)| { pid_set.remove(&id); }); } } } } } pub fn send_interrupt(&self, interrupt: u8) { let id = self.interrupt_lookup[interrupt as usize]; if let Some(id) = id { self.unpause(id); } } pub fn send_buffer(&self, id: usize) { if let Some(buf) = self.buffer_lookup.get(&id) { buf.iter().for_each(|pid| self.unpause(*pid)); } } } struct Task { future: Pin>, paused: AtomicBool, } impl Task { fn new(future: Pin>) -> Self { Self { future, paused: AtomicBool::new(false), } } fn poll(&mut self, cx: &mut Context) -> Poll<()> { self.future.as_mut().poll(cx) } fn is_paused(&self) -> bool { self.paused.load(Ordering::Acquire) } fn set_paused(&self, paused: bool) { self.paused.store(paused, Ordering::Release) } } fn create_waker(task_id: usize, task_queue: Arc>) -> Waker { let data = Box::new(TaskWaker { task_id, task_queue, }); let raw_waker = RawWaker::new(Box::into_raw(data) as *const (), &VTABLE); unsafe { Waker::from_raw(raw_waker) } } #[derive(Clone)] struct TaskWaker { task_id: usize, task_queue: Arc>, } impl TaskWaker { fn wake(&self) { self.task_queue.push(self.task_id); } } const VTABLE: RawWakerVTable = RawWakerVTable::new(clone_raw, wake_raw, wake_by_ref_raw, drop_raw); unsafe fn clone_raw(ptr: *const ()) -> RawWaker { let task_waker = Box::from_raw(ptr as *mut TaskWaker); let raw_waker = RawWaker::new(Box::into_raw(task_waker.clone()) as *const (), &VTABLE); raw_waker } unsafe fn wake_raw(ptr: *const ()) { let task_waker = Box::from_raw(ptr as *mut TaskWaker); task_waker.wake(); } unsafe fn wake_by_ref_raw(ptr: *const ()) { let task_waker = &*(ptr as *const TaskWaker); task_waker.wake(); } unsafe fn drop_raw(ptr: *const ()) { drop(Box::from_raw(ptr as *mut TaskWaker)); }