ableos/kernel/src/task.rs

217 lines
5.7 KiB
Rust

use {
alloc::{
boxed::Box,
collections::{BTreeMap, BTreeSet},
sync::Arc,
},
core::{
future::Future,
pin::Pin,
sync::atomic::{AtomicBool, Ordering},
task::{Context, ContextBuilder, Poll, RawWaker, RawWakerVTable, Waker},
},
crossbeam_queue::SegQueue,
slab::Slab,
};
/// Returns a future that gives up the CPU once before completing.
///
/// The first poll wakes the current task immediately (so it is re-queued)
/// and reports `Pending`; every later poll reports `Ready(())`.
pub fn yield_now() -> impl Future<Output = ()> {
    /// One-shot yield state: `yielded` flips to `true` after the first poll.
    struct YieldNow {
        yielded: bool,
    }

    impl Future for YieldNow {
        type Output = ();

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            if !self.yielded {
                self.yielded = true;
                // Ask to be polled again right away, then step aside.
                cx.waker().wake_by_ref();
                return Poll::Pending;
            }
            Poll::Ready(())
        }
    }

    YieldNow { yielded: false }
}
/// A unit of work the executor can run: any `Send` future that produces `()`.
pub trait Process: Send + Future<Output = ()> {}

/// Every `Send` future with `()` output is automatically a `Process`.
impl<F> Process for F where F: Future<Output = ()> + Send {}
pub struct Executor {
tasks: Slab<Task>,
task_queue: Arc<SegQueue<usize>>,
interrupt_lookup: [Option<usize>; u8::MAX as usize],
buffer_lookup: BTreeMap<usize, BTreeSet<usize>>,
}
impl Executor {
pub fn new() -> Self {
Self {
tasks: Slab::new(),
task_queue: Arc::new(SegQueue::new()),
interrupt_lookup: [None; u8::MAX as usize],
buffer_lookup: BTreeMap::new(),
}
}
pub fn spawn(&mut self, future: Pin<Box<dyn Process>>) -> usize {
let id = self.tasks.insert(Task::new(future));
self.task_queue.push(id);
id
}
pub fn pause(&self, id: usize) {
if let Some(task) = self.tasks.get(id) {
task.set_paused(true);
}
}
pub fn unpause(&self, id: usize) {
if let Some(task) = self.tasks.get(id) {
task.set_paused(false);
self.task_queue.push(id);
}
}
pub fn interrupt_subscribe(&mut self, pid: usize, interrupt_type: u8) {
self.pause(pid);
self.interrupt_lookup[interrupt_type as usize] = Some(pid);
}
pub fn buffer_subscribe(&mut self, pid: usize, buffer_id: usize) {
self.pause(pid);
if let Some(buf) = self.buffer_lookup.get_mut(&buffer_id) {
buf.insert(pid);
} else {
self.buffer_lookup.insert(buffer_id, BTreeSet::from([pid]));
}
}
pub fn run(&mut self) {
let mut task_batch = [0; 32];
loop {
let mut batch_len = 0;
while let Some(id) = self.task_queue.pop() {
task_batch[batch_len] = id;
batch_len += 1;
if batch_len == task_batch.len() {
break;
}
}
if batch_len == 0 {
// break;
continue;
}
for &(mut id) in &task_batch[..batch_len] {
if let Some(task) = self.tasks.get_mut(id) {
if task.is_paused() {
continue;
}
let waker = create_waker(id, Arc::clone(&self.task_queue));
let mut cx = ContextBuilder::from_waker(&waker).ext(&mut id).build();
if let Poll::Ready(()) = task.poll(&mut cx) {
self.tasks.remove(id);
self.interrupt_lookup.map(|pid| {
if let Some(pid) = pid {
if pid == id {
return None;
}
}
return pid;
});
self.buffer_lookup.iter_mut().for_each(|(_, pid_set)| {
pid_set.remove(&id);
});
}
}
}
}
}
pub fn send_interrupt(&self, interrupt: u8) {
let id = self.interrupt_lookup[interrupt as usize];
if let Some(id) = id {
self.unpause(id);
}
}
pub fn send_buffer(&self, id: usize) {
if let Some(buf) = self.buffer_lookup.get(&id) {
buf.iter().for_each(|pid| self.unpause(*pid));
}
}
}
/// A spawned process plus its scheduling flag.
struct Task {
    /// The process body, boxed and pinned so it can be polled in place.
    future: Pin<Box<dyn Process>>,
    /// Whether the task is currently parked.
    paused: AtomicBool,
}

impl Task {
    /// Wraps `future` in a task that starts out unpaused.
    fn new(future: Pin<Box<dyn Process>>) -> Self {
        let paused = AtomicBool::new(false);
        Task { future, paused }
    }

    /// Polls the underlying process once with `cx`.
    fn poll(&mut self, cx: &mut Context) -> Poll<()> {
        Future::poll(self.future.as_mut(), cx)
    }

    /// Reads the paused flag (acquire ordering).
    fn is_paused(&self) -> bool {
        let flag = &self.paused;
        flag.load(Ordering::Acquire)
    }

    /// Writes the paused flag (release ordering).
    fn set_paused(&self, paused: bool) {
        let flag = &self.paused;
        flag.store(paused, Ordering::Release)
    }
}
/// Builds a [`Waker`] that pushes `task_id` back onto `task_queue` when woken.
///
/// The waker owns a heap-allocated `TaskWaker`. The `VTABLE` functions clone,
/// wake, and free that allocation.
fn create_waker(task_id: usize, task_queue: Arc<SegQueue<usize>>) -> Waker {
    let state = Box::new(TaskWaker {
        task_id,
        task_queue,
    });
    let data = Box::into_raw(state) as *const ();
    let raw = RawWaker::new(data, &VTABLE);
    // SAFETY: `data` is a freshly leaked `Box<TaskWaker>`, which is exactly what
    // every `VTABLE` entry expects its data pointer to be.
    unsafe { Waker::from_raw(raw) }
}
/// State behind every executor waker: which task to reschedule and the run
/// queue to push it onto.
#[derive(Clone)]
struct TaskWaker {
    task_id: usize,
    task_queue: Arc<SegQueue<usize>>,
}

impl TaskWaker {
    /// Re-queues the task so the executor polls it again.
    fn wake(&self) {
        let Self {
            task_id,
            task_queue,
        } = self;
        task_queue.push(*task_id);
    }
}
/// Dispatch table shared by every executor waker. Each entry treats the data
/// pointer as pointing to a heap-allocated `TaskWaker`.
const VTABLE: RawWakerVTable = RawWakerVTable::new(
    clone_raw,
    wake_raw,
    wake_by_ref_raw,
    drop_raw,
);
unsafe fn clone_raw(ptr: *const ()) -> RawWaker {
let task_waker = Box::from_raw(ptr as *mut TaskWaker);
let raw_waker = RawWaker::new(Box::into_raw(task_waker.clone()) as *const (), &VTABLE);
raw_waker
}
unsafe fn wake_raw(ptr: *const ()) {
let task_waker = Box::from_raw(ptr as *mut TaskWaker);
task_waker.wake();
}
/// `RawWakerVTable::wake_by_ref` entry: re-queues the task without consuming
/// the waker.
///
/// # Safety
/// `ptr` must point to a live `TaskWaker`.
unsafe fn wake_by_ref_raw(ptr: *const ()) {
    // SAFETY: only a shared borrow is taken; ownership stays with the waker.
    unsafe { (*ptr.cast::<TaskWaker>()).wake() }
}
/// `RawWakerVTable::drop` entry: frees the waker's `TaskWaker` allocation.
///
/// # Safety
/// `ptr` must be an owned `Box<TaskWaker>` pointer that is not used again.
unsafe fn drop_raw(ptr: *const ()) {
    // SAFETY: the waker is being dropped, so this is the final owner of the box.
    let reclaimed = unsafe { Box::from_raw(ptr.cast_mut().cast::<TaskWaker>()) };
    drop(reclaimed);
}