From 685c6f0b20723379fede80547b11963236bf90e2 Mon Sep 17 00:00:00 2001
From: koniifer <koniifer@proton.me>
Date: Tue, 26 Nov 2024 19:32:19 +0000
Subject: [PATCH] task scheduler weirdness

---
 kernel/src/holeybytes/ecah.rs |   3 +-
 kernel/src/holeybytes/mod.rs  |   7 +-
 kernel/src/kmain.rs           |   2 +-
 kernel/src/lib.rs             |   7 +-
 kernel/src/task.rs            | 154 ++++++++++++++++------------------
 5 files changed, 85 insertions(+), 88 deletions(-)

diff --git a/kernel/src/holeybytes/ecah.rs b/kernel/src/holeybytes/ecah.rs
index 67576f47..a34f683a 100644
--- a/kernel/src/holeybytes/ecah.rs
+++ b/kernel/src/holeybytes/ecah.rs
@@ -33,7 +33,7 @@ unsafe fn x86_in<T: x86_64::instructions::port::PortRead>(address: u16) -> T {
 }
 
 #[inline(always)]
-pub fn handler(vm: &mut Vm) {
+pub fn handler(vm: &mut Vm, pid: &usize) {
     let ecall_number = vm.registers[2].cast::<u64>();
 
     match ecall_number {
@@ -209,7 +209,6 @@ pub fn handler(vm: &mut Vm) {
             let buffer_id = vm.registers[3].cast::<u64>();
             let map_ptr = vm.registers[4].cast::<u64>();
             let max_length = vm.registers[5].cast::<u64>();
-
             let mut buffs = IPC_BUFFERS.lock();
             let buff: &mut IpcBuffer = match buffs.get_mut(&buffer_id) {
                 Some(buff) => buff,
diff --git a/kernel/src/holeybytes/mod.rs b/kernel/src/holeybytes/mod.rs
index 2c4dbab4..ab68d04f 100644
--- a/kernel/src/holeybytes/mod.rs
+++ b/kernel/src/holeybytes/mod.rs
@@ -65,7 +65,12 @@ impl<'p> Future for ExecThread {
                 return Poll::Ready(Err(err));
             }
             Ok(VmRunOk::End) => return Poll::Ready(Ok(())),
-            Ok(VmRunOk::Ecall) => ecah::handler(&mut self.vm),
+            Ok(VmRunOk::Ecall) => ecah::handler(
+                &mut self.vm,
+                cx.ext()
+                    .downcast_ref()
+                    .expect("PID did not exist in Context"),
+            ),
             Ok(VmRunOk::Timer) => (),
             Ok(VmRunOk::Breakpoint) => {
                 log::error!(
diff --git a/kernel/src/kmain.rs b/kernel/src/kmain.rs
index a9f7bc61..928dbe57 100644
--- a/kernel/src/kmain.rs
+++ b/kernel/src/kmain.rs
@@ -122,7 +122,7 @@ pub fn kmain(_cmdline: &str, boot_modules: BootModules) -> ! {
             if cmd_len > 0 {
                 thr.set_arguments(cmd.as_ptr() as u64, cmd_len);
             }
-            executor.spawn(Box::pin(async move {
+            executor.spawn(Box::pin(async {
                 if let Err(e) = thr.await {
                     log::error!("{e:?}");
                 }
diff --git a/kernel/src/lib.rs b/kernel/src/lib.rs
index 52956ecd..cf68a83d 100644
--- a/kernel/src/lib.rs
+++ b/kernel/src/lib.rs
@@ -10,9 +10,11 @@
     abi_x86_interrupt,
     lazy_get,
     alloc_error_handler,
+    local_waker,
+    context_ext,
     ptr_sub_ptr,
     naked_functions,
-    pointer_is_aligned_to,
+    pointer_is_aligned_to
 )]
 #![allow(dead_code, internal_features, static_mut_refs)]
 extern crate alloc;
@@ -35,8 +37,7 @@ mod utils;
 // #[cfg(feature = "tests")]
 mod ktest;
 
-use alloc::string::ToString;
-use versioning::Version;
+use {alloc::string::ToString, versioning::Version};
 
 /// Kernel's version
 pub const VERSION: Version = Version {
diff --git a/kernel/src/task.rs b/kernel/src/task.rs
index 3326871b..01c6688c 100644
--- a/kernel/src/task.rs
+++ b/kernel/src/task.rs
@@ -3,7 +3,8 @@ use {
     core::{
         future::Future,
         pin::Pin,
-        task::{Context, Poll, RawWaker, RawWakerVTable, Waker},
+        sync::atomic::{AtomicBool, Ordering},
+        task::{Context, ContextBuilder, Poll, RawWaker, RawWakerVTable, Waker},
     },
     crossbeam_queue::SegQueue,
     slab::Slab,
@@ -14,7 +15,6 @@ pub fn yield_now() -> impl Future<Output = ()> {
     impl Future for YieldNow {
         type Output = ();
 
-        #[inline(always)]
         fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
             if self.0 {
                 Poll::Ready(())
@@ -29,53 +29,70 @@ pub fn yield_now() -> impl Future<Output = ()> {
     YieldNow(false)
 }
 
+pub trait Process: Future<Output = ()> + Send {}
+impl<T: Future<Output = ()> + Send> Process for T {}
+
 pub struct Executor {
     tasks:      Slab<Task>,
-    task_queue: Arc<TaskQueue>,
+    task_queue: Arc<SegQueue<usize>>,
 }
 
 impl Executor {
     pub fn new() -> Self {
         Self {
             tasks:      Slab::new(),
-            task_queue: Arc::new(TaskQueue::new()),
+            task_queue: Arc::new(SegQueue::new()),
         }
     }
 
-    #[inline]
-    pub fn spawn(&mut self, future: Pin<Box<dyn Future<Output = ()> + Send>>) -> usize {
+    pub fn spawn(&mut self, future: Pin<Box<dyn Process>>) -> usize {
         let id = self.tasks.insert(Task::new(future));
-        self.task_queue.queue.push(id);
+        self.task_queue.push(id);
+
         id
     }
 
+    pub fn pause(&mut self, id: usize) {
+        if let Some(task) = self.tasks.get(id) {
+            task.set_paused(true);
+        }
+    }
+
+    pub fn unpause(&mut self, id: usize) {
+        if let Some(task) = self.tasks.get(id) {
+            task.set_paused(false);
+            self.task_queue.push(id);
+        }
+    }
+
     pub fn run(&mut self) {
         let mut task_batch = [0; 32];
-        let mut batch_len = 0;
-
         loop {
-            self.task_queue.batch_pop(&mut task_batch, &mut batch_len);
+            let mut batch_len = 0;
 
-            if batch_len == 0 {
-                if self.task_queue.is_empty() {
+            while let Some(id) = self.task_queue.pop() {
+                task_batch[batch_len] = id;
+                batch_len += 1;
+                if batch_len == task_batch.len() {
                     break;
-                } else {
-                    continue;
                 }
             }
 
-            for &id in &task_batch[..batch_len] {
-                if let Some(task) = self.tasks.get_mut(id) {
-                    let waker = task
-                        .waker
-                        .get_or_insert_with(|| TaskWaker::new(id, Arc::clone(&self.task_queue)));
+            if batch_len == 0 {
+                break;
+            }
 
-                    let waker = unsafe { Waker::from_raw(TaskWaker::into_raw_waker(waker)) };
-                    let mut cx = Context::from_waker(&waker);
+            for &(mut id) in &task_batch[..batch_len] {
+                if let Some(task) = self.tasks.get_mut(id) {
+                    if task.is_paused() {
+                        continue;
+                    }
+
+                    let waker = create_waker(id, Arc::clone(&self.task_queue));
+                    let mut cx = ContextBuilder::from_waker(&waker).ext(&mut id).build();
 
                     if let Poll::Ready(()) = task.poll(&mut cx) {
                         self.tasks.remove(id);
-                        self.task_queue.free_tasks.push(id);
                     }
                 }
             }
@@ -84,95 +101,70 @@ impl Executor {
 }
 
 struct Task {
-    future: Pin<Box<dyn Future<Output = ()> + Send>>,
-    waker:  Option<TaskWaker>,
+    future: Pin<Box<dyn Process>>,
+    paused: AtomicBool,
 }
 
 impl Task {
-    #[inline(always)]
-    pub fn new(future: Pin<Box<dyn Future<Output = ()> + Send>>) -> Self {
+    fn new(future: Pin<Box<dyn Process>>) -> Self {
         Self {
             future,
-            waker: None,
+            paused: AtomicBool::new(false),
         }
     }
 
-    #[inline(always)]
     fn poll(&mut self, cx: &mut Context) -> Poll<()> {
         self.future.as_mut().poll(cx)
     }
+
+    fn is_paused(&self) -> bool {
+        self.paused.load(Ordering::Acquire)
+    }
+
+    fn set_paused(&self, paused: bool) {
+        self.paused.store(paused, Ordering::Release)
+    }
 }
 
+fn create_waker(task_id: usize, task_queue: Arc<SegQueue<usize>>) -> Waker {
+    let data = Box::new(TaskWaker {
+        task_id,
+        task_queue,
+    });
+    let raw_waker = RawWaker::new(Box::into_raw(data) as *const (), &VTABLE);
+    unsafe { Waker::from_raw(raw_waker) }
+}
+
+#[derive(Clone)]
 struct TaskWaker {
-    id: usize,
-    task_queue: Arc<TaskQueue>,
+    task_id:    usize,
+    task_queue: Arc<SegQueue<usize>>,
 }
 
 impl TaskWaker {
-    #[inline(always)]
-    fn new(id: usize, task_queue: Arc<TaskQueue>) -> Self {
-        Self { id, task_queue }
-    }
-
-    #[inline(always)]
     fn wake(&self) {
-        self.task_queue.queue.push(self.id);
-    }
-
-    fn into_raw_waker(waker: &TaskWaker) -> RawWaker {
-        let ptr = waker as *const TaskWaker;
-        RawWaker::new(ptr.cast(), &VTABLE)
+        self.task_queue.push(self.task_id);
     }
 }
 
 const VTABLE: RawWakerVTable = RawWakerVTable::new(clone_raw, wake_raw, wake_by_ref_raw, drop_raw);
 
 unsafe fn clone_raw(ptr: *const ()) -> RawWaker {
-    let waker = &*(ptr as *const TaskWaker);
-    TaskWaker::into_raw_waker(waker)
+    let task_waker = Box::from_raw(ptr as *mut TaskWaker);
+    let raw_waker = RawWaker::new(Box::into_raw(task_waker.clone()) as *const (), &VTABLE);
+    raw_waker
 }
 
 unsafe fn wake_raw(ptr: *const ()) {
-    let waker = &*(ptr as *const TaskWaker);
-    waker.wake();
+    let task_waker = Box::from_raw(ptr as *mut TaskWaker);
+    task_waker.wake();
 }
 
 unsafe fn wake_by_ref_raw(ptr: *const ()) {
-    let waker = &*(ptr as *const TaskWaker);
-    waker.wake();
+    let task_waker = &*(ptr as *const TaskWaker);
+    task_waker.wake();
 }
 
-unsafe fn drop_raw(_: *const ()) {}
-
-struct TaskQueue {
-    queue:      SegQueue<usize>,
-    next_task:  usize,
-    free_tasks: SegQueue<usize>,
-}
-
-impl TaskQueue {
-    fn new() -> Self {
-        Self {
-            queue:      SegQueue::new(),
-            next_task:  0,
-            free_tasks: SegQueue::new(),
-        }
-    }
-
-    #[inline(always)]
-    fn batch_pop(&self, output: &mut [usize], len: &mut usize) {
-        *len = 0;
-        while let Some(id) = self.queue.pop() {
-            output[*len] = id;
-            *len += 1;
-            if *len == output.len() {
-                break;
-            }
-        }
-    }
-
-    #[inline(always)]
-    fn is_empty(&self) -> bool {
-        self.queue.is_empty()
-    }
+unsafe fn drop_raw(ptr: *const ()) {
+    drop(Box::from_raw(ptr as *mut TaskWaker));
 }