From 87e06c1242ba6fe2deec5c4689af87ebf107df33 Mon Sep 17 00:00:00 2001 From: Graham Kelly Date: Mon, 1 Jul 2024 07:37:03 -0400 Subject: [PATCH] mutable module interp --- src/bin/waffle-util.rs | 4 +- src/interp.rs | 110 ++++++++++++++++++++++++++++++++++------- 2 files changed, 93 insertions(+), 21 deletions(-) diff --git a/src/bin/waffle-util.rs b/src/bin/waffle-util.rs index 403e32f..d85b5e5 100644 --- a/src/bin/waffle-util.rs +++ b/src/bin/waffle-util.rs @@ -124,7 +124,7 @@ fn main() -> Result<()> { let mut ctx = InterpContext::new(&module)?; debug!("Calling start function"); if let Some(start) = module.start_func { - ctx.call(&module, start, &[]).ok().unwrap(); + ctx.call(&mut module, start, &[]).ok().unwrap(); } // Find a function called `_start`, if any. if let Some(waffle::Export { @@ -133,7 +133,7 @@ fn main() -> Result<()> { }) = module.exports.iter().find(|e| &e.name == "_start") { debug!("Calling _start"); - ctx.call(&module, *func, &[]).ok().unwrap(); + ctx.call(&mut module, *func, &[]).ok().unwrap(); } } } diff --git a/src/interp.rs b/src/interp.rs index 91a4500..ca6692d 100644 --- a/src/interp.rs +++ b/src/interp.rs @@ -6,6 +6,7 @@ use crate::ops::Operator; use smallvec::{smallvec, SmallVec}; use std::collections::HashMap; +use std::sync::Arc; mod wasi; @@ -18,8 +19,7 @@ pub struct InterpContext { pub globals: PerEntity, pub fuel: u64, pub trace_handler: Option) -> bool + Send>>, - pub import_hander: - Option InterpResult>>, + pub import_hander: Arc, &str, &[ConstVal]) -> InterpResult>, } type MultiVal = SmallVec<[ConstVal; 2]>; @@ -88,11 +88,11 @@ impl InterpContext { globals, fuel: u64::MAX, trace_handler: None, - import_hander: None, + import_hander: Arc::new(|_, _, _,_| InterpResult::TraceHandlerQuit), }) } - pub fn call(&mut self, module: &Module<'_>, mut func: Func, args: &[ConstVal]) -> InterpResult { + pub fn call(&mut self, module: &mut Module<'_>, mut func: Func, args: &[ConstVal]) -> InterpResult { let mut args = args.to_vec(); 'redo: loop { let body = match &module.funcs[func] { @@ -101,9 +101,9 @@ impl InterpContext { FuncDecl::Import(..) => { let import = &module.imports[func.index()]; assert_eq!(import.kind, ImportKind::Func(func)); - return self.call_import(&import.name[..], &args); + return self.call_import(module,&import.name[..].to_owned(), &args); } - FuncDecl::Body(_, _, body) => body, + FuncDecl::Body(_, _, body) => body.clone(), FuncDecl::None => panic!("FuncDecl::None in call()"), }; @@ -178,6 +178,25 @@ impl InterpContext { _ => return result, } } + &ValueDef::Operator(Operator::CallRef { .. }, args, _) => { + let args = body.arg_pool[args] + .iter() + .map(|&arg| { + let arg = body.resolve_alias(arg); + let multivalue = frame.values.get(&arg).unwrap(); + assert_eq!(multivalue.len(), 1); + multivalue[0] + }) + .collect::>(); + let ConstVal::Ref(Some(func)) = args.last().unwrap() else { + return InterpResult::TraceHandlerQuit; + }; + let result = self.call(module, *func, &args[..args.len() - 1]); + match result { + InterpResult::Ok(vals) => vals, + _ => return result, + } + } &ValueDef::Operator(ref op, args, _) => { let args = body.arg_pool[args] .iter() @@ -284,7 +303,7 @@ impl InterpContext { return InterpResult::Trap(frame.func, frame.cur_block, u32::MAX) } &Terminator::Br { ref target } => { - frame.apply_target(body, target); + frame.apply_target(&body, target); } &Terminator::CondBr { cond, @@ -295,9 +314,9 @@ impl InterpContext { let cond = frame.values.get(&cond).unwrap(); let cond = cond[0].as_u32().unwrap() != 0; if cond { - frame.apply_target(body, if_true); + frame.apply_target(&body, if_true); } else { - frame.apply_target(body, if_false); + frame.apply_target(&body, if_false); } } &Terminator::Select { @@ -309,9 +328,9 @@ impl InterpContext { let value = frame.values.get(&value).unwrap(); let value = value[0].as_u32().unwrap() as usize; if value < targets.len() { - frame.apply_target(body, &targets[value]); + frame.apply_target(&body, &targets[value]); } else { - frame.apply_target(body, default); + frame.apply_target(&body, default); } } &Terminator::Return { ref values } => { @@ -325,16 +344,35 @@ impl InterpContext { log::trace!("returning from {}: {:?}", func, values); return InterpResult::Ok(values); } - Terminator::ReturnCallRef { sig, args } => todo!(), + Terminator::ReturnCallRef { + sig, + args: ref args2, + } => { + let args2 = args2 + .iter() + .map(|&arg| { + let arg = body.resolve_alias(arg); + let multivalue = frame.values.get(&arg).unwrap(); + assert_eq!(multivalue.len(), 1); + multivalue[0] + }) + .collect::>(); + let ConstVal::Ref(Some(fu)) = args.last().unwrap() else { + return InterpResult::TraceHandlerQuit; + }; + func = *fu; + args = args2[..args2.len() - 1].to_vec(); + continue 'redo; + } } } } } - fn call_import(&mut self, name: &str, args: &[ConstVal]) -> InterpResult { - let mut r = self.import_hander.take().unwrap(); - let rs = r(self, name, args); - self.import_hander = Some(r); + fn call_import(&mut self,module: &mut Module<'_>, name: &str, args: &[ConstVal]) -> InterpResult { + let mut r = self.import_hander.clone(); + let rs = r(self, module,name, args); + // self.import_hander = Some(r); return rs; } } @@ -388,6 +426,7 @@ pub enum ConstVal { I64(u64), F32(u32), F64(u64), + Ref(Option), #[default] None, } @@ -995,9 +1034,32 @@ pub fn const_eval( ConstVal::None }), - (Operator::TableGet { .. }, _) - | (Operator::TableSet { .. }, _) - | (Operator::TableGrow { .. }, _) => None, + (Operator::TableGet { table_index }, [ConstVal::I32(i)]) => ctx.and_then(|global| { + Some(ConstVal::Ref( + global.tables[*table_index] + .elements + .get(*i as usize) + .and_then(|x| { + if *x == Func::invalid() { + None + } else { + Some(*x) + } + }), + )) + }), + (Operator::TableSet { table_index }, [ConstVal::I32(i), ConstVal::Ref(r)]) => { + ctx.and_then(|global| { + global.tables[*table_index].elements[*i as usize] = r.unwrap_or_default(); + Some(ConstVal::I32(0)) + }) + } + (Operator::TableGrow { table_index }, [ConstVal::I32(i)]) => ctx.and_then(|global| { + global.tables[*table_index] + .elements + .extend((0..*i).map(|a| Func::default())); + Some(ConstVal::I32(0)) + }), (Operator::TableSize { table_index }, []) => { ctx.map(|global| ConstVal::I32(global.tables[*table_index].elements.len() as u32)) @@ -1232,6 +1294,16 @@ pub fn const_eval( write_u64(&mut global.memories[memory.memory], addr, *data); Some(ConstVal::None) }), + (Operator::RefFunc { func_index }, []) => Some(ConstVal::Ref(Some(*func_index))), + ( + Operator::RefNull { + ty: Type::FuncRef | Type::TypedFuncRef { .. }, + }, + [], + ) => Some(ConstVal::Ref(None)), + (Operator::RefIsNull, [ConstVal::Ref(r)]) => { + Some(ConstVal::I32(if r.is_none() { 1 } else { 0 })) + } (_, args) if args.iter().any(|&arg| arg == ConstVal::None) => None, _ => None, }