diff --git a/fuzz/fuzz_targets/roundtrip.rs b/fuzz/fuzz_targets/roundtrip.rs index 81467f6..66958f7 100644 --- a/fuzz/fuzz_targets/roundtrip.rs +++ b/fuzz/fuzz_targets/roundtrip.rs @@ -24,6 +24,7 @@ fuzz_target!(|module: wasm_smith::Module| { } } }; + parsed_module.expand_all_funcs().unwrap(); parsed_module.optimize(); let _ = parsed_module.to_wasm_bytes(); }); diff --git a/src/backend/mod.rs b/src/backend/mod.rs index dd6da27..aac78f8 100644 --- a/src/backend/mod.rs +++ b/src/backend/mod.rs @@ -585,7 +585,7 @@ pub fn compile(module: &Module<'_>) -> anyhow::Result> { for (func, func_decl) in module.funcs.entries().skip(num_func_imports) { match func_decl { FuncDecl::Import(_) => anyhow::bail!("Import comes after func with body: {}", func), - FuncDecl::Body(sig, _) => { + FuncDecl::Lazy(sig, _) | FuncDecl::Body(sig, _) => { funcs.function(sig.index() as u32); } } @@ -689,21 +689,43 @@ pub fn compile(module: &Module<'_>) -> anyhow::Result> { into_mod.section(&elem); let mut code = wasm_encoder::CodeSection::new(); + enum FuncOrRawBytes<'a> { + Func(wasm_encoder::Function), + Raw(&'a [u8]), + } + let bodies = module .funcs .entries() .skip(num_func_imports) .collect::>() .par_iter() - .map(|(func, func_decl)| -> Result { - let body = func_decl.body().unwrap(); - log::debug!("Compiling {}", func); - WasmFuncBackend::new(body)?.compile() + .map(|(func, func_decl)| -> Result { + match func_decl { + FuncDecl::Lazy(_, reader) => { + let data = &module.orig_bytes[reader.range()]; + Ok(FuncOrRawBytes::Raw(data)) + } + FuncDecl::Body(_, body) => { + log::debug!("Compiling {}", func); + WasmFuncBackend::new(body)? + .compile() + .map(|f| FuncOrRawBytes::Func(f)) + } + FuncDecl::Import(_) => unreachable!("Should have skipped imports"), + } }) - .collect::>>()?; + .collect::>>>()?; for body in bodies { - code.function(&body); + match body { + FuncOrRawBytes::Func(f) => { + code.function(&f); + } + FuncOrRawBytes::Raw(bytes) => { + code.raw(bytes); + } + } } into_mod.section(&code); diff --git a/src/bin/waffle-util.rs b/src/bin/waffle-util.rs index 52a7f9a..32a49a6 100644 --- a/src/bin/waffle-util.rs +++ b/src/bin/waffle-util.rs @@ -41,6 +41,19 @@ enum Command { }, } +fn apply_options(opts: &Options, module: &mut Module) -> Result<()> { + if opts.basic_opts || opts.max_ssa { + module.expand_all_funcs()?; + } + if opts.basic_opts { + module.per_func_body(|body| body.optimize()); + } + if opts.max_ssa { + module.per_func_body(|body| body.convert_to_max_ssa()); + } + Ok(()) +} + fn main() -> Result<()> { let opts = Options::from_args(); @@ -50,29 +63,19 @@ fn main() -> Result<()> { } let _ = logger.try_init(); - match opts.command { + match &opts.command { Command::PrintIR { wasm } => { let bytes = std::fs::read(wasm)?; debug!("Loaded {} bytes of Wasm data", bytes.len()); let mut module = Module::from_wasm_bytes(&bytes[..])?; - if opts.basic_opts { - module.optimize(); - } - if opts.max_ssa { - module.convert_to_max_ssa(); - } + apply_options(&opts, &mut module)?; println!("{}", module.display()); } Command::RoundTrip { input, output } => { let bytes = std::fs::read(input)?; debug!("Loaded {} bytes of Wasm data", bytes.len()); let mut module = Module::from_wasm_bytes(&bytes[..])?; - if opts.basic_opts { - module.optimize(); - } - if opts.max_ssa { - module.convert_to_max_ssa(); - } + apply_options(&opts, &mut module)?; let produced = module.to_wasm_bytes()?; std::fs::write(output, &produced[..])?; } diff --git a/src/frontend.rs b/src/frontend.rs index 8c9f262..b6c9754 100644 --- a/src/frontend.rs +++ b/src/frontend.rs @@ -146,10 +146,7 @@ fn handle_payload<'a>( *next_func += 1; let my_sig = module.funcs[func_idx].sig(); - let body = parse_body(module, my_sig, body)?; - - let existing_body = module.funcs[func_idx].body_mut().unwrap(); - *existing_body = body; + module.funcs[func_idx] = FuncDecl::Lazy(my_sig, body); } Payload::ExportSection(reader) => { for export in reader { @@ -269,10 +266,10 @@ fn handle_payload<'a>( Ok(()) } -fn parse_body<'a>( +pub(crate) fn parse_body<'a>( module: &'a Module, my_sig: Signature, - body: wasmparser::FunctionBody, + body: &mut wasmparser::FunctionBody, ) -> Result { let mut ret: FunctionBody = FunctionBody::default(); diff --git a/src/ir/display.rs b/src/ir/display.rs index c998105..eb0284d 100644 --- a/src/ir/display.rs +++ b/src/ir/display.rs @@ -190,14 +190,10 @@ impl<'a> Display for ModuleDisplay<'a> { for seg in &memory_data.segments { writeln!( f, - " {} offset {}: [{}]", + " {} offset {}: # {} bytes", memory, seg.offset, - seg.data - .iter() - .map(|&byte| format!("0x{:02x}", byte)) - .collect::>() - .join(", ") + seg.data.len() )?; } } @@ -217,6 +213,10 @@ impl<'a> Display for ModuleDisplay<'a> { writeln!(f, " {}: {} = # {}", func, sig, sig_strs.get(&sig).unwrap())?; writeln!(f, "{}", body.display(" "))?; } + FuncDecl::Lazy(sig, reader) => { + writeln!(f, " {}: {} = # {}", func, sig, sig_strs.get(&sig).unwrap())?; + writeln!(f, " # raw bytes (length {})", reader.range().len())?; + } FuncDecl::Import(sig) => { writeln!(f, " {}: {} # {}", func, sig, sig_strs.get(&sig).unwrap())?; } diff --git a/src/ir/func.rs b/src/ir/func.rs index f1bcdfa..7d6df34 100644 --- a/src/ir/func.rs +++ b/src/ir/func.rs @@ -1,21 +1,54 @@ use super::{Block, FunctionBodyDisplay, Local, Module, Signature, Type, Value, ValueDef}; use crate::cfg::CFGInfo; use crate::entity::{EntityRef, EntityVec, PerEntity}; +use crate::frontend::parse_body; +use anyhow::Result; #[derive(Clone, Debug)] -pub enum FuncDecl { +pub enum FuncDecl<'a> { Import(Signature), + Lazy(Signature, wasmparser::FunctionBody<'a>), Body(Signature, FunctionBody), } -impl FuncDecl { +impl<'a> FuncDecl<'a> { pub fn sig(&self) -> Signature { match self { FuncDecl::Import(sig) => *sig, + FuncDecl::Lazy(sig, ..) => *sig, FuncDecl::Body(sig, ..) => *sig, } } + pub fn parse(&mut self, module: &Module) -> Result<()> { + match self { + FuncDecl::Lazy(sig, body) => { + let body = parse_body(module, *sig, body)?; + *self = FuncDecl::Body(*sig, body); + Ok(()) + } + _ => Ok(()), + } + } + + pub fn optimize(&mut self) { + match self { + FuncDecl::Body(_, body) => { + body.optimize(); + } + _ => {} + } + } + + pub fn convert_to_max_ssa(&mut self) { + match self { + FuncDecl::Body(_, body) => { + body.convert_to_max_ssa(); + } + _ => {} + } + } + pub fn body(&self) -> Option<&FunctionBody> { match self { FuncDecl::Body(_, body) => Some(body), @@ -78,6 +111,18 @@ impl FunctionBody { } } + pub fn optimize(&mut self) { + let cfg = crate::cfg::CFGInfo::new(self); + crate::passes::basic_opt::gvn(self, &cfg); + crate::passes::resolve_aliases::run(self); + crate::passes::empty_blocks::run(self); + } + + pub fn convert_to_max_ssa(&mut self) { + let cfg = crate::cfg::CFGInfo::new(self); + crate::passes::maxssa::run(self, &cfg); + } + pub fn add_block(&mut self) -> Block { let id = self.blocks.push(BlockDef::default()); log::trace!("add_block: block {}", id); @@ -94,6 +139,23 @@ impl FunctionBody { log::trace!("add_edge: from {} to {}", from, to); } + pub fn recompute_edges(&mut self) { + for block in self.blocks.values_mut() { + block.preds.clear(); + block.succs.clear(); + block.pos_in_succ_pred.clear(); + block.pos_in_pred_succ.clear(); + } + + for block in 0..self.blocks.len() { + let block = Block::new(block); + let terminator = self.blocks[block].terminator.clone(); + terminator.visit_successors(|succ| { + self.add_edge(block, succ); + }); + } + } + pub fn add_value(&mut self, value: ValueDef) -> Value { log::trace!("add_value: def {:?}", value); let value = self.values.push(value); diff --git a/src/ir/module.rs b/src/ir/module.rs index ead764c..077248c 100644 --- a/src/ir/module.rs +++ b/src/ir/module.rs @@ -1,5 +1,5 @@ use super::{Func, FuncDecl, Global, Memory, ModuleDisplay, Signature, Table, Type}; -use crate::entity::EntityVec; +use crate::entity::{EntityRef, EntityVec}; use crate::ir::FunctionBody; use crate::{backend, frontend}; use anyhow::Result; @@ -7,7 +7,7 @@ use anyhow::Result; #[derive(Clone, Debug)] pub struct Module<'a> { pub orig_bytes: &'a [u8], - pub funcs: EntityVec, + pub funcs: EntityVec>, pub signatures: EntityVec, pub globals: EntityVec, pub tables: EntityVec, @@ -171,19 +171,22 @@ impl<'a> Module<'a> { } } - pub fn optimize(&mut self) { - self.per_func_body(|body| { - let cfg = crate::cfg::CFGInfo::new(body); - crate::passes::basic_opt::gvn(body, &cfg); - crate::passes::resolve_aliases::run(body); - }); + pub fn expand_func<'b>(&'b mut self, id: Func) -> Result<&'b mut FuncDecl<'a>> { + if let FuncDecl::Lazy(..) = self.funcs[id] { + // End the borrow. This is cheap (a slice copy). + let mut func = self.funcs[id].clone(); + func.parse(self)?; + self.funcs[id] = func; + } + Ok(&mut self.funcs[id]) } - pub fn convert_to_max_ssa(&mut self) { - self.per_func_body(|body| { - let cfg = crate::cfg::CFGInfo::new(body); - crate::passes::maxssa::run(body, &cfg); - }); + pub fn expand_all_funcs(&mut self) -> Result<()> { + for id in 0..self.funcs.len() { + let id = Func::new(id); + self.expand_func(id)?; + } + Ok(()) } pub fn display<'b>(&'b self) -> ModuleDisplay<'b> diff --git a/src/passes.rs b/src/passes.rs index 39512ca..c001335 100644 --- a/src/passes.rs +++ b/src/passes.rs @@ -2,5 +2,6 @@ pub mod basic_opt; pub mod dom_pass; +pub mod empty_blocks; pub mod maxssa; pub mod resolve_aliases; diff --git a/src/passes/empty_blocks.rs b/src/passes/empty_blocks.rs new file mode 100644 index 0000000..1a7556b --- /dev/null +++ b/src/passes/empty_blocks.rs @@ -0,0 +1,128 @@ +//! Pass to remove empty blocks. + +use crate::entity::EntityRef; +use crate::ir::{Block, BlockTarget, FunctionBody, Terminator, Value, ValueDef}; +use std::borrow::Cow; +use std::collections::HashSet; + +#[derive(Clone, Debug)] +struct Forwarding { + to: Block, + args: Vec, +} + +#[derive(Clone, Copy, Debug)] +enum ForwardingArg { + BlockParam(usize), + Value(Value), +} + +impl Forwarding { + fn compose(a: &Forwarding, b: &Forwarding) -> Forwarding { + // `b` should be the target of `a.to`, but we can't assert + // that here. The composed target is thus `b.to`. + let to = b.to; + + // For each arg in `b.args`, evaluate, replacing any + // `BlockParam` with the corresponding value from `a.args`. + let args = b + .args + .iter() + .map(|&arg| match arg { + ForwardingArg::BlockParam(idx) => a.args[idx].clone(), + ForwardingArg::Value(v) => ForwardingArg::Value(v), + }) + .collect::>(); + + Forwarding { to, args } + } +} + +fn block_to_forwarding(body: &FunctionBody, block: Block) -> Option { + // Must be empty except for terminator, and must have an + // unconditional-branch terminator. + if body.blocks[block].insts.len() > 0 { + return None; + } + let target = match &body.blocks[block].terminator { + &Terminator::Br { ref target } => target, + _ => return None, + }; + + // If conditions met, then gather ForwardingArgs. + let args = target + .args + .iter() + .map(|&arg| { + let arg = body.resolve_alias(arg); + match &body.values[arg] { + &ValueDef::BlockParam(param_block, index, _) if param_block == block => { + ForwardingArg::BlockParam(index) + } + _ => ForwardingArg::Value(arg), + } + }) + .collect::>(); + + Some(Forwarding { + to: target.block, + args, + }) +} + +fn rewrite_target(forwardings: &[Option], target: &BlockTarget) -> Option { + if !forwardings[target.block.index()].is_some() { + return None; + } + + let mut forwarding = Cow::Borrowed(forwardings[target.block.index()].as_ref().unwrap()); + let mut seen = HashSet::new(); + while forwardings[forwarding.to.index()].is_some() && seen.insert(forwarding.to.index()) { + forwarding = Cow::Owned(Forwarding::compose( + &forwarding, + forwardings[forwarding.to.index()].as_ref().unwrap(), + )); + } + + let args = forwarding + .args + .iter() + .map(|arg| match arg { + &ForwardingArg::Value(v) => v, + &ForwardingArg::BlockParam(idx) => target.args[idx], + }) + .collect::>(); + + Some(BlockTarget { + block: forwarding.to, + args, + }) +} + +pub fn run(body: &mut FunctionBody) { + // Identify empty blocks, and to where they should forward. + let forwardings = body + .blocks + .iter() + .map(|block| { + if block != body.entry { + block_to_forwarding(body, block) + } else { + None + } + }) + .collect::>(); + + // Rewrite every target according to a forwarding (or potentially + // a chain of composed forwardings). + for block_data in body.blocks.values_mut() { + block_data.terminator.update_targets(|target| { + if let Some(new_target) = rewrite_target(&forwardings[..], target) { + *target = new_target; + } + }); + } + + // Recompute preds/succs. + body.recompute_edges(); +}