ableos/kernel/src/task.rs

150 lines
3.6 KiB
Rust

//! Async task and executor
use alloc::{boxed::Box, collections::BTreeMap, sync::Arc, task::Wake};
use core::{
future::Future,
pin::Pin,
task::{Context, Poll, Waker},
};
use crossbeam_queue::SegQueue;
use slab::Slab;
use spin::RwLock;
/// Shared queue of ids of tasks that have been woken and are waiting to be polled
type TaskQueue = Arc<SegQueue<TaskId>>;
/// Shared queue of newly spawned tasks that the executor has not yet taken in
type SpawnQueue = Arc<SegQueue<Task>>;
/// Spawn queue of the running executor: set to `Some` by `Executor::run` for the
/// duration of the run and reset to `None` when it returns
static SPAWN_QUEUE: RwLock<Option<SpawnQueue>> = RwLock::new(None);
/// Hand a future over to the currently running executor
///
/// The future is wrapped in a task and pushed onto the executor's incoming queue.
///
/// # Panics
///
/// Panics if no executor is running, i.e. the global spawn queue is unset.
pub fn spawn(future: impl Future<Output = ()> + Send + 'static) {
    // The read guard stays alive until the task has been pushed.
    let guard = SPAWN_QUEUE.read();
    let Some(incoming) = guard.as_ref() else {
        panic!("no task executor is running");
    };
    incoming.push(Task::new(future));
}
/// Give up the current poll so other tasks get a turn
///
/// The returned future wakes its own waker and reports `Poll::Pending` the first
/// time it is polled; every later poll reports `Poll::Ready(())`.
pub fn yield_now() -> impl Future<Output = ()> {
    /// One-shot yield state
    struct YieldNow {
        /// Set once the future has been polled for the first time
        yielded: bool,
    }

    impl Future for YieldNow {
        type Output = ();

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            if self.yielded {
                return Poll::Ready(());
            }
            // First poll: remember it and ask to be polled again right away.
            self.yielded = true;
            cx.waker().wake_by_ref();
            Poll::Pending
        }
    }

    YieldNow { yielded: false }
}
/// Tasks executor
///
/// Keeps spawned tasks in a slab and polls them as their ids come off the wake
/// queue; see `Executor::run`.
#[derive(Default)]
pub struct Executor {
/// All spawned tasks' stash, indexed by the value wrapped in `TaskId`
tasks: Slab<Task>,
/// Awake tasks' queue: ids of tasks waiting to be polled
queue: TaskQueue,
/// Incoming tasks to enqueue; published through the global spawn queue while
/// the executor is running
incoming: SpawnQueue,
/// Wakers, one per task, created on the task's first poll and removed when the
/// task completes
wakers: BTreeMap<TaskId, Waker>,
}
impl Executor {
/// Spawn a task
pub fn spawn(&mut self, future: impl Future<Output = ()> + Send + 'static) {
self.queue
.push(TaskId(self.tasks.insert(Task::new(future))));
}
/// Spin poll loop until it runs out of tasks
pub fn run(&mut self) {
// Assign `self.incoming` to global spawn queue to spawn tasks
// from within
{
let mut spawner = SPAWN_QUEUE.write();
if spawner.is_some() {
panic!("task executor is already running");
}
*spawner = Some(Arc::clone(&self.incoming));
}
// Try to get incoming task, if none available, poll
// enqueued one
while let Some(id) = self
.incoming
.pop()
.map(|t| TaskId(self.tasks.insert(t)))
.or_else(|| self.queue.pop())
{
let Some(task) = self.tasks.get_mut(id.0) else {
panic!("attempted to get non-extant task with id {}", id.0)
};
let mut cx = Context::from_waker(self.wakers.entry(id).or_insert_with(|| {
Waker::from(Arc::new(TaskWaker {
id,
queue: Arc::clone(&self.queue),
}))
}));
match task.poll(&mut cx) {
Poll::Ready(()) => {
// Task done, unregister
self.tasks.remove(id.0);
self.wakers.remove(&id);
}
Poll::Pending => (),
}
}
*SPAWN_QUEUE.write() = None;
}
}
/// Identifier of a task: wraps the index returned by the executor's task slab
/// when the task is inserted
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
struct TaskId(usize);
/// Async task: a type-erased, heap-pinned future with no output
struct Task {
    /// The future driven by this task
    future: Pin<Box<dyn Future<Output = ()> + Send>>,
}

impl Task {
    /// Box and pin `future` so it can be stored and polled as a task
    fn new(future: impl Future<Output = ()> + Send + 'static) -> Self {
        let future: Pin<Box<dyn Future<Output = ()> + Send>> = Box::pin(future);
        Self { future }
    }

    /// Poll the wrapped future once with the given context
    fn poll(&mut self, cx: &mut Context) -> Poll<()> {
        Future::poll(self.future.as_mut(), cx)
    }
}
/// Waker that re-schedules a task by pushing its id onto a wake queue
struct TaskWaker {
    /// Id of the task to re-schedule
    id: TaskId,
    /// Queue the id is pushed onto on every wake
    queue: TaskQueue,
}

impl Wake for TaskWaker {
    /// Consuming wake; does the same as [`Wake::wake_by_ref`]
    fn wake(self: Arc<Self>) {
        Wake::wake_by_ref(&self);
    }

    /// Enqueue the task's id so it is polled again
    fn wake_by_ref(self: &Arc<Self>) {
        let Self { id, queue } = &**self;
        queue.push(*id);
    }
}