diff --git a/Cargo.lock b/Cargo.lock index 1d304991..d6861c6a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -318,6 +318,7 @@ dependencies = [ "linked_list_allocator", "log", "slab", + "spin 0.9.4", "versioning", ] diff --git a/kernel/Cargo.toml b/kernel/Cargo.toml index d2c0d1cd..7d56e3de 100644 --- a/kernel/Cargo.toml +++ b/kernel/Cargo.toml @@ -7,6 +7,7 @@ version = "0.1.2" linked_list_allocator = "0.9" log = "0.4.14" slab = { version = "0.4", default-features = false } +spin = "0.9" [dependencies.crossbeam-queue] version = "0.3" diff --git a/kernel/src/task.rs b/kernel/src/task.rs index 39f26d15..5bf4eb1e 100644 --- a/kernel/src/task.rs +++ b/kernel/src/task.rs @@ -8,8 +8,40 @@ use core::{ }; use crossbeam_queue::SegQueue; use slab::Slab; +use spin::RwLock; type TaskQueue = Arc>; +type SpawnQueue = Arc>; + +static SPAWN_QUEUE: RwLock> = RwLock::new(None); + +/// Spawn a new task +pub fn spawn(future: impl Future + Send + 'static) { + match &*SPAWN_QUEUE.read() { + Some(s) => s.push(Task::new(future)), + None => panic!("no task executor is running"), + } +} + +/// Forcibly yield a task +pub fn yield_now() -> impl Future { + struct YieldNow(bool); + impl Future for YieldNow { + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + if self.0 { + Poll::Ready(()) + } else { + self.0 = true; + cx.waker().wake_by_ref(); + Poll::Pending + } + } + } + + YieldNow(false) +} /// Tasks executor #[derive(Default)] @@ -20,41 +52,63 @@ pub struct Executor { /// Awake tasks' queue queue: TaskQueue, + /// Incoming tasks to enqueue + incoming: SpawnQueue, + /// Wakers wakers: BTreeMap, } impl Executor { - /// Spawn a future - pub fn spawn(&mut self, future: impl Future + 'static) { + /// Spawn a task + pub fn spawn(&mut self, future: impl Future + Send + 'static) { self.queue .push(TaskId(self.tasks.insert(Task::new(future)))); } - /// Run tasks - pub fn run(&mut self) -> ! { - loop { - while let Some(id) = self.queue.pop() { - let task = match self.tasks.get_mut(id.0) { - Some(t) => t, - None => continue, - }; + /// Spin poll loop until it runs out of tasks + pub fn run(&mut self) { + // Assign `self.incoming` to global spawn queue to spawn tasks + // from within + { + let mut spawner = SPAWN_QUEUE.write(); + if spawner.is_some() { + panic!("task executor is already running"); + } - let mut cx = Context::from_waker( - self.wakers - .entry(id) - .or_insert_with(|| TaskWaker::new(id, Arc::clone(&self.queue))), - ); + *spawner = Some(Arc::clone(&self.incoming)); + } - match task.poll(&mut cx) { - Poll::Ready(()) => { - self.tasks.remove(id.0); - self.wakers.remove(&id); - } - Poll::Pending => (), + // Try to get incoming task, if none available, poll + // enqueued one + while let Some(id) = self + .incoming + .pop() + .map(|t| TaskId(self.tasks.insert(t))) + .or_else(|| self.queue.pop()) + { + let task = match self.tasks.get_mut(id.0) { + Some(t) => t, + None => panic!("attempted to get non-extant task with id {}", id.0), + }; + + let mut cx = Context::from_waker( + self.wakers + .entry(id) + .or_insert_with(|| TaskWaker::new(id, Arc::clone(&self.queue))), + ); + + match task.poll(&mut cx) { + Poll::Ready(()) => { + // Task done, unregister + self.tasks.remove(id.0); + self.wakers.remove(&id); } + Poll::Pending => (), } } + + *SPAWN_QUEUE.write() = None; } } @@ -63,12 +117,12 @@ struct TaskId(usize); /// Async task struct Task { - future: Pin>>, + future: Pin + Send>>, } impl Task { /// Create a new task from a future - fn new(future: impl Future + 'static) -> Self { + fn new(future: impl Future + Send + 'static) -> Self { Self { future: Box::pin(future), }