1
0
Fork 0
forked from AbleOS/ableos
ableos-idl/kernel/src/task.rs
2024-09-13 22:41:31 +01:00

179 lines
4.2 KiB
Rust

use {
alloc::{boxed::Box, sync::Arc},
core::{
future::Future,
pin::Pin,
task::{Context, Poll, RawWaker, RawWakerVTable, Waker},
},
crossbeam_queue::SegQueue,
slab::Slab,
};
/// Returns a future that hands control back to the executor exactly once.
///
/// The first poll asks the current waker to schedule the task again and
/// reports `Poll::Pending`; every later poll reports `Poll::Ready(())`.
pub fn yield_now() -> impl Future<Output = ()> {
    /// State for the one-shot yield: whether the single `Pending` was already returned.
    struct YieldNow {
        yielded: bool,
    }

    impl Future for YieldNow {
        type Output = ();

        #[inline(always)]
        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            // Mark the yield as done and look at the previous state in one step.
            if core::mem::replace(&mut self.yielded, true) {
                return Poll::Ready(());
            }
            // Re-queue ourselves before suspending so the executor polls us again.
            cx.waker().wake_by_ref();
            Poll::Pending
        }
    }

    YieldNow { yielded: false }
}
/// Executor for futures of a single concrete type `F`.
///
/// Tasks live in a slab and are referred to by their slab key; the keys of
/// tasks that should be polled travel through the shared `task_queue`.
pub struct Executor<F>
where
    F: Future<Output = ()> + Send,
{
    /// Spawned tasks, addressed by slab key.
    tasks: Slab<Task<F>>,
    /// Queue of task keys awaiting a poll; shared with the tasks' wakers.
    task_queue: Arc<TaskQueue>,
}
impl<F: Future<Output = ()> + Send> Executor<F> {
    /// Creates an executor whose task slab is pre-sized for `size` tasks.
    pub fn new(size: usize) -> Self {
        Self {
            tasks: Slab::with_capacity(size),
            task_queue: Arc::new(TaskQueue::new()),
        }
    }

    /// Stores `future` as a new task and queues its slab key for a first poll.
    #[inline]
    pub fn spawn(&mut self, future: F) {
        let id = self.tasks.insert(Task::new(future));
        self.task_queue.queue.push(id);
    }

    /// Polls queued tasks, in batches of up to 32 ids, until the queue is
    /// observed empty.
    ///
    /// Finished tasks are removed from the slab. The method returns as soon
    /// as no id is queued, even if unfinished tasks are still stored and
    /// waiting for a wake-up from elsewhere; call `run` again after such a
    /// wake-up to continue them.
    pub fn run(&mut self) {
        let mut task_batch = [0; 32];
        let mut batch_len = 0;
        loop {
            self.task_queue.batch_pop(&mut task_batch, &mut batch_len);
            if batch_len == 0 {
                if self.task_queue.is_empty() {
                    break;
                }
                // An id arrived between the failed pop and the emptiness check.
                continue;
            }
            for &id in &task_batch[..batch_len] {
                // A task that was woken more than once may already have
                // finished and been removed; such stale ids are skipped.
                if let Some(task) = self.tasks.get_mut(id) {
                    let shared = task.waker.get_or_insert_with(|| {
                        Arc::new(TaskWaker::new(id, Arc::clone(&self.task_queue)))
                    });
                    // SAFETY: `into_raw_waker` returns a `RawWaker` that owns
                    // one strong reference to the `Arc<TaskWaker>`, and
                    // `VTABLE` clones/releases that reference as the
                    // `RawWaker` contract requires. `TaskWaker` only contains
                    // a `usize` and an `Arc<TaskQueue>`.
                    let waker = unsafe { Waker::from_raw(TaskWaker::into_raw_waker(shared)) };
                    let mut cx = Context::from_waker(&waker);
                    if let Poll::Ready(()) = task.poll(&mut cx) {
                        // The slab recycles the key on its own, so the id is
                        // not recorded anywhere else: nothing ever drained
                        // the old `free_tasks` queue, which therefore grew by
                        // one entry per finished task.
                        self.tasks.remove(id);
                    }
                }
            }
        }
    }
}

/// A spawned future together with its lazily created waker state.
struct Task<F: Future<Output = ()> + Send> {
    /// The future, pinned on the heap so it can be polled in place.
    future: Pin<Box<F>>,
    /// Created on first poll. Reference-counted so that wakers cloned by the
    /// future stay valid after the task is removed or the slab reallocates.
    waker: Option<Arc<TaskWaker>>,
}

impl<F: Future<Output = ()> + Send> Task<F> {
    /// Boxes and pins `future`; the waker is created on first poll.
    #[inline(always)]
    pub fn new(future: F) -> Self {
        Self {
            future: Box::pin(future),
            waker: None,
        }
    }

    /// Polls the wrapped future once with the given context.
    #[inline(always)]
    fn poll(&mut self, cx: &mut Context) -> Poll<()> {
        self.future.as_mut().poll(cx)
    }
}

/// Wake-up handle for one task: pushes the task's id back onto the queue.
struct TaskWaker {
    /// Slab key of the task to re-queue.
    id: usize,
    /// Queue the id is pushed to on wake.
    task_queue: Arc<TaskQueue>,
}

impl TaskWaker {
    /// Creates the wake-up handle for task `id` on `task_queue`.
    #[inline(always)]
    fn new(id: usize, task_queue: Arc<TaskQueue>) -> Self {
        Self { id, task_queue }
    }

    /// Queues the task for another poll.
    #[inline(always)]
    fn wake(&self) {
        self.task_queue.queue.push(self.id);
    }

    /// Builds a `RawWaker` that owns a new strong reference to `waker`.
    ///
    /// The reference is released by `drop_raw` or `wake_raw`.
    fn into_raw_waker(waker: &Arc<TaskWaker>) -> RawWaker {
        let ptr = Arc::into_raw(Arc::clone(waker));
        RawWaker::new(ptr.cast(), &VTABLE)
    }
}

/// Vtable for wakers whose data pointer came from `Arc::<TaskWaker>::into_raw`.
const VTABLE: RawWakerVTable = RawWakerVTable::new(clone_raw, wake_raw, wake_by_ref_raw, drop_raw);

/// Clones the waker by taking another strong reference.
///
/// # Safety
///
/// `ptr` must come from `Arc::<TaskWaker>::into_raw` and its reference must
/// not have been released yet.
unsafe fn clone_raw(ptr: *const ()) -> RawWaker {
    // SAFETY: guaranteed by the caller contract above; the extra count is
    // owned by the returned `RawWaker`.
    Arc::increment_strong_count(ptr as *const TaskWaker);
    RawWaker::new(ptr, &VTABLE)
}

/// Wakes the task and releases this waker's strong reference.
///
/// # Safety
///
/// Same requirements as `clone_raw`; `ptr` must not be used afterwards.
unsafe fn wake_raw(ptr: *const ()) {
    // SAFETY: reclaims the reference owned by the consumed waker; it is
    // dropped when `waker` goes out of scope.
    let waker = Arc::from_raw(ptr as *const TaskWaker);
    waker.wake();
}

/// Wakes the task without touching the reference count.
///
/// # Safety
///
/// Same requirements as `clone_raw`.
unsafe fn wake_by_ref_raw(ptr: *const ()) {
    // SAFETY: the waker still owns its reference, so the `TaskWaker` is alive.
    let waker = &*(ptr as *const TaskWaker);
    waker.wake();
}

/// Releases this waker's strong reference.
///
/// # Safety
///
/// Same requirements as `clone_raw`; `ptr` must not be used afterwards.
unsafe fn drop_raw(ptr: *const ()) {
    // SAFETY: reclaims the reference owned by the dropped waker.
    drop(Arc::from_raw(ptr as *const TaskWaker));
}
/// Shared queues of task ids used to schedule polls.
struct TaskQueue {
    /// Ids of tasks waiting to be polled.
    queue: SegQueue<usize>,
    /// Initialised to 0; nothing in this file reads or updates it.
    next_task: usize,
    /// Secondary id queue; nothing in this file pops from it.
    free_tasks: SegQueue<usize>,
}

impl TaskQueue {
    /// Creates a queue set with no ids queued.
    fn new() -> Self {
        Self {
            queue: SegQueue::new(),
            next_task: 0,
            free_tasks: SegQueue::new(),
        }
    }

    /// Moves up to `output.len()` queued ids into the front of `output`.
    ///
    /// `len` is reset to 0 and then set to the number of ids written. An id
    /// is popped only when a free slot exists for it, so an empty `output`
    /// leaves the queue untouched instead of popping an id and panicking on
    /// the out-of-bounds write.
    #[inline(always)]
    fn batch_pop(&self, output: &mut [usize], len: &mut usize) {
        *len = 0;
        for slot in output.iter_mut() {
            match self.queue.pop() {
                Some(id) => {
                    *slot = id;
                    *len += 1;
                }
                None => break,
            }
        }
    }

    /// Reports whether no task id is currently waiting to be polled.
    #[inline(always)]
    fn is_empty(&self) -> bool {
        self.queue.is_empty()
    }
}