diff --git a/fuzz/fuzz_targets/roundtrip_roundtrip.rs b/fuzz/fuzz_targets/roundtrip_roundtrip.rs index 20cdb64..7beb722 100644 --- a/fuzz/fuzz_targets/roundtrip_roundtrip.rs +++ b/fuzz/fuzz_targets/roundtrip_roundtrip.rs @@ -9,6 +9,9 @@ fuzz_target!(|module: wasm_smith::Module| { let orig_bytes = module.to_bytes(); let parsed_module = Module::from_wasm_bytes(&orig_bytes[..]).unwrap(); let roundtrip_bytes = parsed_module.to_wasm_bytes(); + if let Ok(filename) = std::env::var("ROUNDTRIP_WASM_SAVE") { + std::fs::write(filename, &roundtrip_bytes[..]).unwrap(); + } let parsed_roundtrip_module = Module::from_wasm_bytes(&roundtrip_bytes[..]).unwrap(); let _ = parsed_roundtrip_module.to_wasm_bytes(); }); diff --git a/src/backend/final.rs b/src/backend/final.rs index 296d4a3..d3eeb60 100644 --- a/src/backend/final.rs +++ b/src/backend/final.rs @@ -26,6 +26,21 @@ impl<'a, FT: FuncTypeSink> WasmContext<'a, FT> { self.func_type_sink.add_signature(params, results) } + fn find_fallthrough_ty<'b>( + &mut self, + params: &[Type], + mut targets: impl Iterator, + ) -> u32 { + let fallthrough_rets = targets + .find_map(|target| match target { + SerializedBlockTarget::Fallthrough(tys, ..) => Some(&tys[..]), + _ => None, + }) + .unwrap_or(&[]); + + self.create_type(params.to_vec(), fallthrough_rets.to_vec()) + } + fn translate(&mut self, op: &SerializedOperator, locations: &Locations) { log::trace!("translate: {:?}", op); match op { @@ -69,9 +84,13 @@ impl<'a, FT: FuncTypeSink> WasmContext<'a, FT> { ref if_true, ref if_false, } => { - self.wasm.operators.push(wasm_encoder::Instruction::If( - wasm_encoder::BlockType::Empty, - )); + let fallthrough_ty = + self.find_fallthrough_ty(&[], [if_true, if_false].iter().map(|&targ| targ)); + self.wasm + .operators + .push(wasm_encoder::Instruction::If(BlockType::FunctionType( + fallthrough_ty, + ))); self.translate_target(1, if_true, locations); self.wasm.operators.push(wasm_encoder::Instruction::Else); self.translate_target(1, if_false, locations); @@ -81,10 +100,17 @@ impl<'a, FT: FuncTypeSink> WasmContext<'a, FT> { ref targets, ref default, } => { - let ty = self.create_type(vec![Type::I32], vec![]); - for _ in 0..(targets.len() + 2) { + let fallthrough_ty = self.find_fallthrough_ty( + &[Type::I32], + targets.iter().chain(std::iter::once(default)), + ); + self.wasm.operators.push(wasm_encoder::Instruction::Block( + wasm_encoder::BlockType::FunctionType(fallthrough_ty), + )); + let index_ty = self.create_type(vec![Type::I32], vec![]); + for _ in 0..(targets.len() + 1) { self.wasm.operators.push(wasm_encoder::Instruction::Block( - wasm_encoder::BlockType::FunctionType(ty), + wasm_encoder::BlockType::FunctionType(index_ty), )); } @@ -132,12 +158,17 @@ impl<'a, FT: FuncTypeSink> WasmContext<'a, FT> { ) { log::trace!("translate_target: {:?}", target); match target { - &SerializedBlockTarget::Fallthrough(ref ops) => { + &SerializedBlockTarget::Fallthrough(_, ref ops) => { for op in ops { self.translate(op, locations); } + if extra_blocks > 0 { + self.wasm + .operators + .push(wasm_encoder::Instruction::Br((extra_blocks - 1) as u32)); + } } - &SerializedBlockTarget::Branch(branch, ref ops) => { + &SerializedBlockTarget::Branch(branch, _, ref ops) => { for op in ops { self.translate(op, locations); } diff --git a/src/backend/serialize.rs b/src/backend/serialize.rs index a223d0a..c940cab 100644 --- a/src/backend/serialize.rs +++ b/src/backend/serialize.rs @@ -12,6 +12,7 @@ use crate::{ ValueDef, }; use fxhash::FxHashSet; +use wasmparser::Type; /// A Wasm function body with a serialized sequence of operators that /// mirror Wasm opcodes in every way *except* for locals corresponding @@ -24,21 +25,21 @@ pub struct SerializedBody { #[derive(Clone, Debug, PartialEq, Eq)] pub enum SerializedBlockTarget { - Fallthrough(Vec), - Branch(usize, Vec), + Fallthrough(Vec, Vec), + Branch(usize, Vec, Vec), } #[derive(Clone, Debug, PartialEq, Eq)] pub enum SerializedOperator { StartBlock { header: BlockId, - params: Vec<(wasmparser::Type, Value)>, - results: Vec, + params: Vec<(Type, Value)>, + results: Vec, }, StartLoop { header: BlockId, - params: Vec<(wasmparser::Type, Value)>, - results: Vec, + params: Vec<(Type, Value)>, + results: Vec, }, Br(SerializedBlockTarget), BrIf { @@ -119,8 +120,8 @@ impl SerializedBlockTarget { w: &mut W, ) { match self { - &SerializedBlockTarget::Branch(_, ref ops) - | &SerializedBlockTarget::Fallthrough(ref ops) => { + &SerializedBlockTarget::Branch(_, _, ref ops) + | &SerializedBlockTarget::Fallthrough(_, ref ops) => { for op in ops { op.visit_value_locals(r, w); } @@ -246,10 +247,15 @@ impl<'a> SerializedBodyContext<'a> { self.push_value(value, &mut rev_ops); } rev_ops.reverse(); + let tys: Vec = self.f.blocks[target.target] + .params + .iter() + .map(|&(ty, _)| ty) + .collect(); log::trace!(" -> ops: {:?}", rev_ops); match target.relative_branch { - Some(branch) => SerializedBlockTarget::Branch(branch, rev_ops), - None => SerializedBlockTarget::Fallthrough(rev_ops), + Some(branch) => SerializedBlockTarget::Branch(branch, tys, rev_ops), + None => SerializedBlockTarget::Fallthrough(tys, rev_ops), } }) .collect::>(); @@ -283,8 +289,6 @@ impl<'a> SerializedBodyContext<'a> { self.operators.extend(rev_ops); self.operators .push(SerializedOperator::BrTable { targets, default }); - self.operators - .push(SerializedOperator::Operator(Operator::Unreachable)); } &Terminator::Return { ref values, .. } => { let mut rev_ops = vec![];