From 908ad937e16d6ecdc935d6bd21e006c95eda797d Mon Sep 17 00:00:00 2001 From: Graham Kelly Date: Tue, 9 Apr 2024 13:52:02 -0400 Subject: [PATCH] null refs --- src/backend/mod.rs | 4 + src/interp.rs | 410 +++++++++++++++++++++++---------------------- src/ir.rs | 8 +- src/ir/module.rs | 4 +- src/op_traits.rs | 8 +- src/ops.rs | 7 + 6 files changed, 235 insertions(+), 206 deletions(-) diff --git a/src/backend/mod.rs b/src/backend/mod.rs index ae07a45..7d1cdb2 100644 --- a/src/backend/mod.rs +++ b/src/backend/mod.rs @@ -945,6 +945,10 @@ impl<'a> WasmFuncBackend<'a> { Operator::RefFunc { func_index } => { Some(wasm_encoder::Instruction::RefFunc(func_index.index() as u32)) } + Operator::RefNull { ty } => { + let h: wasm_encoder::RefType = ty.clone().into(); + Some(wasm_encoder::Instruction::RefNull(h.heap_type)) + } }; if let Some(inst) = inst { diff --git a/src/interp.rs b/src/interp.rs index c3e72cc..3a44fa8 100644 --- a/src/interp.rs +++ b/src/interp.rs @@ -18,7 +18,8 @@ pub struct InterpContext { pub globals: PerEntity, pub fuel: u64, pub trace_handler: Option) -> bool + Send>>, - pub import_hander: Option InterpResult>> + pub import_hander: + Option InterpResult>>, } type MultiVal = SmallVec<[ConstVal; 2]>; @@ -93,115 +94,91 @@ impl InterpContext { pub fn call(&mut self, module: &Module<'_>, mut func: Func, args: &[ConstVal]) -> InterpResult { let mut args = args.to_vec(); - 'redo: loop{ - let body = match &module.funcs[func] { - FuncDecl::Lazy(..) => panic!("Un-expanded function"), - FuncDecl::Compiled(..) => panic!("Already-compiled function"), - FuncDecl::Import(..) => { - let import = &module.imports[func.index()]; - assert_eq!(import.kind, ImportKind::Func(func)); - return self.call_import(&import.name[..], &args); - } - FuncDecl::Body(_, _, body) => body, - FuncDecl::None => panic!("FuncDecl::None in call()"), - }; + 'redo: loop { + let body = match &module.funcs[func] { + FuncDecl::Lazy(..) => panic!("Un-expanded function"), + FuncDecl::Compiled(..) => panic!("Already-compiled function"), + FuncDecl::Import(..) => { + let import = &module.imports[func.index()]; + assert_eq!(import.kind, ImportKind::Func(func)); + return self.call_import(&import.name[..], &args); + } + FuncDecl::Body(_, _, body) => body, + FuncDecl::None => panic!("FuncDecl::None in call()"), + }; - log::trace!( - "Interp: entering func {}:\n{}\n", - func, - body.display_verbose("| ", Some(module)) - ); - log::trace!("args: {:?}", args); + log::trace!( + "Interp: entering func {}:\n{}\n", + func, + body.display_verbose("| ", Some(module)) + ); + log::trace!("args: {:?}", args); - let mut frame = InterpStackFrame { - func, - cur_block: body.entry, - values: HashMap::new(), - }; + let mut frame = InterpStackFrame { + func, + cur_block: body.entry, + values: HashMap::new(), + }; - for (&arg, &(_, blockparam)) in args.iter().zip(body.blocks[body.entry].params.iter()) { - log::trace!("Entry block param {} gets arg value {:?}", blockparam, arg); - frame.values.insert(blockparam, smallvec![arg]); - } - - loop { - self.fuel -= 1; - if self.fuel == 0 { - return InterpResult::OutOfFuel; + for (&arg, &(_, blockparam)) in args.iter().zip(body.blocks[body.entry].params.iter()) { + log::trace!("Entry block param {} gets arg value {:?}", blockparam, arg); + frame.values.insert(blockparam, smallvec![arg]); } - log::trace!("Interpreting block {}", frame.cur_block); - for (inst_idx, &inst) in body.blocks[frame.cur_block].insts.iter().enumerate() { - log::trace!("Evaluating inst {}", inst); - let result = match &body.values[inst] { - &ValueDef::Alias(_) => smallvec![], - &ValueDef::PickOutput(val, idx, _) => { - let val = body.resolve_alias(val); - smallvec![frame.values.get(&val).unwrap()[idx as usize]] - } - &ValueDef::Operator(Operator::Call { function_index }, args, _) => { - let args = body.arg_pool[args] - .iter() - .map(|&arg| { - let arg = body.resolve_alias(arg); - let multivalue = frame.values.get(&arg).unwrap(); - assert_eq!(multivalue.len(), 1); - multivalue[0] - }) - .collect::>(); - let result = self.call(module, function_index, &args[..]); - match result { - InterpResult::Ok(vals) => vals, - _ => return result, + loop { + self.fuel -= 1; + if self.fuel == 0 { + return InterpResult::OutOfFuel; + } + + log::trace!("Interpreting block {}", frame.cur_block); + for (inst_idx, &inst) in body.blocks[frame.cur_block].insts.iter().enumerate() { + log::trace!("Evaluating inst {}", inst); + let result = match &body.values[inst] { + &ValueDef::Alias(_) => smallvec![], + &ValueDef::PickOutput(val, idx, _) => { + let val = body.resolve_alias(val); + smallvec![frame.values.get(&val).unwrap()[idx as usize]] } - } - &ValueDef::Operator(Operator::CallIndirect { table_index, .. }, args, _) => { - let args = body.arg_pool[args] - .iter() - .map(|&arg| { - let arg = body.resolve_alias(arg); - let multivalue = frame.values.get(&arg).unwrap(); - assert_eq!(multivalue.len(), 1); - multivalue[0] - }) - .collect::>(); - let idx = args.last().unwrap().as_u32().unwrap() as usize; - let func = self.tables[table_index].elements[idx]; - let result = self.call(module, func, &args[..args.len() - 1]); - match result { - InterpResult::Ok(vals) => vals, - _ => return result, - } - } - &ValueDef::Operator(ref op, args, _) => { - let args = body.arg_pool[args] - .iter() - .map(|&arg| { - let arg = body.resolve_alias(arg); - let multivalue = frame - .values - .get(&arg) - .ok_or_else(|| format!("Unset SSA value: {}", arg)) - .unwrap(); - assert_eq!(multivalue.len(), 1); - multivalue[0] - }) - .collect::>(); - let result = match const_eval(op, &args[..], Some(self)) { - Some(result) => result, - None => { - log::trace!("const_eval failed on {:?} args {:?}", op, args); - return InterpResult::Trap( - frame.func, - frame.cur_block, - inst_idx as u32, - ); + &ValueDef::Operator(Operator::Call { function_index }, args, _) => { + let args = body.arg_pool[args] + .iter() + .map(|&arg| { + let arg = body.resolve_alias(arg); + let multivalue = frame.values.get(&arg).unwrap(); + assert_eq!(multivalue.len(), 1); + multivalue[0] + }) + .collect::>(); + let result = self.call(module, function_index, &args[..]); + match result { + InterpResult::Ok(vals) => vals, + _ => return result, } - }; - smallvec![result] - } - &ValueDef::Trace(id, args) => { - if let Some(handler) = self.trace_handler.as_ref() { + } + &ValueDef::Operator( + Operator::CallIndirect { table_index, .. }, + args, + _, + ) => { + let args = body.arg_pool[args] + .iter() + .map(|&arg| { + let arg = body.resolve_alias(arg); + let multivalue = frame.values.get(&arg).unwrap(); + assert_eq!(multivalue.len(), 1); + multivalue[0] + }) + .collect::>(); + let idx = args.last().unwrap().as_u32().unwrap() as usize; + let func = self.tables[table_index].elements[idx]; + let result = self.call(module, func, &args[..args.len() - 1]); + match result { + InterpResult::Ok(vals) => vals, + _ => return result, + } + } + &ValueDef::Operator(ref op, args, _) => { let args = body.arg_pool[args] .iter() .map(|&arg| { @@ -214,115 +191,148 @@ impl InterpContext { assert_eq!(multivalue.len(), 1); multivalue[0] }) - .collect::>(); - if !handler(id, args) { - return InterpResult::TraceHandlerQuit; - } + .collect::>(); + let result = match const_eval(op, &args[..], Some(self)) { + Some(result) => result, + None => { + log::trace!("const_eval failed on {:?} args {:?}", op, args); + return InterpResult::Trap( + frame.func, + frame.cur_block, + inst_idx as u32, + ); + } + }; + smallvec![result] } - smallvec![] - } - &ValueDef::None | &ValueDef::Placeholder(..) | &ValueDef::BlockParam(..) => { - unreachable!(); - } - }; + &ValueDef::Trace(id, args) => { + if let Some(handler) = self.trace_handler.as_ref() { + let args = body.arg_pool[args] + .iter() + .map(|&arg| { + let arg = body.resolve_alias(arg); + let multivalue = frame + .values + .get(&arg) + .ok_or_else(|| format!("Unset SSA value: {}", arg)) + .unwrap(); + assert_eq!(multivalue.len(), 1); + multivalue[0] + }) + .collect::>(); + if !handler(id, args) { + return InterpResult::TraceHandlerQuit; + } + } + smallvec![] + } + &ValueDef::None + | &ValueDef::Placeholder(..) + | &ValueDef::BlockParam(..) => { + unreachable!(); + } + }; - log::trace!("Inst {} gets result {:?}", inst, result); - frame.values.insert(inst, result); - } + log::trace!("Inst {} gets result {:?}", inst, result); + frame.values.insert(inst, result); + } - match &body.blocks[frame.cur_block].terminator { - &Terminator::ReturnCallIndirect { - sig, - table, - args: ref args2, - } => { - let args2 = args2 - .iter() - .map(|&arg| { - let arg = body.resolve_alias(arg); - let multivalue = frame.values.get(&arg).unwrap(); - assert_eq!(multivalue.len(), 1); - multivalue[0] - }) - .collect::>(); - let idx = args2.last().unwrap().as_u32().unwrap() as usize; - let fu = self.tables[table].elements[idx]; - func = fu; - args = args2[..args2.len()-1].to_vec(); - continue 'redo; - // let result = self.call(module, func, &args[..args.len() - 1]); - // return result; - } - &Terminator::ReturnCall { func: fu, args: ref args2 } => { - let args2 = args2 - .iter() - .map(|&arg| { - let arg = body.resolve_alias(arg); - let multivalue = frame.values.get(&arg).unwrap(); - assert_eq!(multivalue.len(), 1); - multivalue[0] - }) - .collect::>(); - func = fu; - args = args2; - continue 'redo; - } - &Terminator::None => { - return InterpResult::Trap(frame.func, frame.cur_block, u32::MAX) - } - &Terminator::Unreachable => { - return InterpResult::Trap(frame.func, frame.cur_block, u32::MAX) - } - &Terminator::Br { ref target } => { - frame.apply_target(body, target); - } - &Terminator::CondBr { - cond, - ref if_true, - ref if_false, - } => { - let cond = body.resolve_alias(cond); - let cond = frame.values.get(&cond).unwrap(); - let cond = cond[0].as_u32().unwrap() != 0; - if cond { - frame.apply_target(body, if_true); - } else { - frame.apply_target(body, if_false); + match &body.blocks[frame.cur_block].terminator { + &Terminator::ReturnCallIndirect { + sig, + table, + args: ref args2, + } => { + let args2 = args2 + .iter() + .map(|&arg| { + let arg = body.resolve_alias(arg); + let multivalue = frame.values.get(&arg).unwrap(); + assert_eq!(multivalue.len(), 1); + multivalue[0] + }) + .collect::>(); + let idx = args2.last().unwrap().as_u32().unwrap() as usize; + let fu = self.tables[table].elements[idx]; + func = fu; + args = args2[..args2.len() - 1].to_vec(); + continue 'redo; + // let result = self.call(module, func, &args[..args.len() - 1]); + // return result; } - } - &Terminator::Select { - value, - ref targets, - ref default, - } => { - let value = body.resolve_alias(value); - let value = frame.values.get(&value).unwrap(); - let value = value[0].as_u32().unwrap() as usize; - if value < targets.len() { - frame.apply_target(body, &targets[value]); - } else { - frame.apply_target(body, default); + &Terminator::ReturnCall { + func: fu, + args: ref args2, + } => { + let args2 = args2 + .iter() + .map(|&arg| { + let arg = body.resolve_alias(arg); + let multivalue = frame.values.get(&arg).unwrap(); + assert_eq!(multivalue.len(), 1); + multivalue[0] + }) + .collect::>(); + func = fu; + args = args2; + continue 'redo; + } + &Terminator::None => { + return InterpResult::Trap(frame.func, frame.cur_block, u32::MAX) + } + &Terminator::Unreachable => { + return InterpResult::Trap(frame.func, frame.cur_block, u32::MAX) + } + &Terminator::Br { ref target } => { + frame.apply_target(body, target); + } + &Terminator::CondBr { + cond, + ref if_true, + ref if_false, + } => { + let cond = body.resolve_alias(cond); + let cond = frame.values.get(&cond).unwrap(); + let cond = cond[0].as_u32().unwrap() != 0; + if cond { + frame.apply_target(body, if_true); + } else { + frame.apply_target(body, if_false); + } + } + &Terminator::Select { + value, + ref targets, + ref default, + } => { + let value = body.resolve_alias(value); + let value = frame.values.get(&value).unwrap(); + let value = value[0].as_u32().unwrap() as usize; + if value < targets.len() { + frame.apply_target(body, &targets[value]); + } else { + frame.apply_target(body, default); + } + } + &Terminator::Return { ref values } => { + let values = values + .iter() + .map(|&value| { + let value = body.resolve_alias(value); + frame.values.get(&value).unwrap()[0] + }) + .collect(); + log::trace!("returning from {}: {:?}", func, values); + return InterpResult::Ok(values); } - } - &Terminator::Return { ref values } => { - let values = values - .iter() - .map(|&value| { - let value = body.resolve_alias(value); - frame.values.get(&value).unwrap()[0] - }) - .collect(); - log::trace!("returning from {}: {:?}", func, values); - return InterpResult::Ok(values); } } } } - } fn call_import(&mut self, name: &str, args: &[ConstVal]) -> InterpResult { let mut r = self.import_hander.take().unwrap(); - let rs = r(self,name,args); + let rs = r(self, name, args); self.import_hander = Some(r); return rs; } diff --git a/src/ir.rs b/src/ir.rs index a8b8f84..e89941b 100644 --- a/src/ir.rs +++ b/src/ir.rs @@ -27,7 +27,7 @@ impl From for Type { } impl From for Type { fn from(ty: wasmparser::RefType) -> Self { - if ty.is_extern_ref(){ + if ty.is_extern_ref() { return Type::ExternRef; } match ty.type_index() { @@ -49,7 +49,7 @@ impl std::fmt::Display for Type { Type::F64 => write!(f, "f64"), Type::V128 => write!(f, "v128"), Type::FuncRef => write!(f, "funcref"), - Type::ExternRef => write!(f,"externref"), + Type::ExternRef => write!(f, "externref"), Type::TypedFuncRef(nullable, idx) => write!( f, "funcref({}, {})", @@ -68,7 +68,9 @@ impl From for wasm_encoder::ValType { Type::F32 => wasm_encoder::ValType::F32, Type::F64 => wasm_encoder::ValType::F64, Type::V128 => wasm_encoder::ValType::V128, - Type::FuncRef | Type::TypedFuncRef(..) | Type::ExternRef => wasm_encoder::ValType::Ref(ty.into()), + Type::FuncRef | Type::TypedFuncRef(..) | Type::ExternRef => { + wasm_encoder::ValType::Ref(ty.into()) + } } } } diff --git a/src/ir/module.rs b/src/ir/module.rs index 5d9d3d7..28c23c4 100644 --- a/src/ir/module.rs +++ b/src/ir/module.rs @@ -190,9 +190,9 @@ impl<'a> Module<'a> { } pub fn to_wasm_bytes(&self) -> Result> { - backend::compile(self).map(|a|a.finish()) + backend::compile(self).map(|a| a.finish()) } - pub fn to_encoded_module(&self) -> Result{ + pub fn to_encoded_module(&self) -> Result { backend::compile(self) } diff --git a/src/op_traits.rs b/src/op_traits.rs index a5ebd6f..b0cf2a4 100644 --- a/src/op_traits.rs +++ b/src/op_traits.rs @@ -485,7 +485,10 @@ pub fn op_inputs( params.push(Type::TypedFuncRef(true, sig_index.index() as u32)); Ok(params.into()) } - Operator::RefIsNull => Ok(vec![op_stack.context("in getting stack")?.last().unwrap().0].into()), + Operator::RefIsNull => { + Ok(vec![op_stack.context("in getting stack")?.last().unwrap().0].into()) + } + Operator::RefNull { ty } => Ok(Cow::Borrowed(&[])), Operator::RefFunc { .. } => Ok(Cow::Borrowed(&[])), Operator::MemoryCopy { .. } => Ok(Cow::Borrowed(&[Type::I32, Type::I32, Type::I32])), Operator::MemoryFill { .. } => Ok(Cow::Borrowed(&[Type::I32, Type::I32, Type::I32])), @@ -961,6 +964,7 @@ pub fn op_outputs( let ty = module.funcs[*func_index].sig(); Ok(vec![Type::TypedFuncRef(true, ty.index() as u32)].into()) } + Operator::RefNull { ty } => Ok(vec![ty.clone()].into()), } } @@ -1431,6 +1435,7 @@ impl Operator { Operator::CallRef { .. } => &[All], Operator::RefIsNull => &[], Operator::RefFunc { .. } => &[], + Operator::RefNull { ty } => &[], } } @@ -1927,6 +1932,7 @@ impl std::fmt::Display for Operator { Operator::CallRef { sig_index } => write!(f, "call_ref<{}>", sig_index)?, Operator::RefIsNull => write!(f, "ref_is_null")?, Operator::RefFunc { func_index } => write!(f, "ref_func<{}>", func_index)?, + Operator::RefNull { ty } => write!(f, "ref_null<{}>", ty)?, } Ok(()) diff --git a/src/ops.rs b/src/ops.rs index fa677ab..3da0a4b 100644 --- a/src/ops.rs +++ b/src/ops.rs @@ -1,6 +1,7 @@ //! Operators. use crate::{entity::EntityRef, Func, Global, Memory, Signature, Table, Type}; +use anyhow::Context; use std::convert::TryFrom; pub use wasmparser::{Ieee32, Ieee64}; @@ -635,6 +636,9 @@ pub enum Operator { sig_index: Signature, }, RefIsNull, + RefNull { + ty: Type, + }, RefFunc { func_index: Func, }, @@ -1289,6 +1293,9 @@ impl<'a, 'b> std::convert::TryFrom<&'b wasmparser::Operator<'a>> for Operator { &wasmparser::Operator::MemoryFill { mem } => Ok(Operator::MemoryFill { mem: Memory::from(mem), }), + &wasmparser::Operator::RefNull { hty } => Ok(Operator::RefNull { + ty: wasmparser::RefType::new(true, hty).unwrap().into(), + }), _ => Err(()), } }