From b6ce3abc1d5b97c93e67be75a44ab5969e53f98f Mon Sep 17 00:00:00 2001 From: Chris Fallin Date: Fri, 2 Dec 2022 11:58:04 -0800 Subject: [PATCH] Add maximal-SSA mode. --- fuzz/fuzz_targets/differential.rs | 3 +- fuzz/fuzz_targets/roundtrip.rs | 3 +- src/bin/waffle-util.rs | 25 ++++- src/cfg/mod.rs | 3 + src/ir/module.rs | 33 +++++-- src/passes.rs | 1 + src/passes/maxssa.rs | 157 ++++++++++++++++++++++++++++++ 7 files changed, 212 insertions(+), 13 deletions(-) create mode 100644 src/passes/maxssa.rs diff --git a/fuzz/fuzz_targets/differential.rs b/fuzz/fuzz_targets/differential.rs index 88241bc..425f9b9 100644 --- a/fuzz/fuzz_targets/differential.rs +++ b/fuzz/fuzz_targets/differential.rs @@ -134,7 +134,8 @@ fuzz_target!(|module: wasm_smith::ConfiguredModule| { } }; - let parsed_module = Module::from_wasm_bytes(&orig_bytes[..]).unwrap(); + let mut parsed_module = Module::from_wasm_bytes(&orig_bytes[..]).unwrap(); + parsed_module.optimize(); let roundtrip_bytes = parsed_module.to_wasm_bytes().unwrap(); if let Ok(filename) = std::env::var("FUZZ_DUMP_WASM") { diff --git a/fuzz/fuzz_targets/roundtrip.rs b/fuzz/fuzz_targets/roundtrip.rs index b0502af..81467f6 100644 --- a/fuzz/fuzz_targets/roundtrip.rs +++ b/fuzz/fuzz_targets/roundtrip.rs @@ -7,7 +7,7 @@ fuzz_target!(|module: wasm_smith::Module| { let _ = env_logger::try_init(); log::debug!("original module: {:?}", module); let orig_bytes = module.to_bytes(); - let parsed_module = match Module::from_wasm_bytes(&orig_bytes[..]) { + let mut parsed_module = match Module::from_wasm_bytes(&orig_bytes[..]) { Ok(m) => m, Err(e) => { match e.downcast::() { @@ -24,5 +24,6 @@ fuzz_target!(|module: wasm_smith::Module| { } } }; + parsed_module.optimize(); let _ = parsed_module.to_wasm_bytes(); }); diff --git a/src/bin/waffle-util.rs b/src/bin/waffle-util.rs index f95010f..52a7f9a 100644 --- a/src/bin/waffle-util.rs +++ b/src/bin/waffle-util.rs @@ -12,6 +12,15 @@ struct Options { #[structopt(short, long)] debug: bool, + #[structopt( + help = "Do basic optimizations: GVN and const-prop", + long = "basic-opts" + )] + basic_opts: bool, + + #[structopt(help = "Transform to maximal SSA", long = "max-ssa")] + max_ssa: bool, + #[structopt(subcommand)] command: Command, } @@ -45,13 +54,25 @@ fn main() -> Result<()> { Command::PrintIR { wasm } => { let bytes = std::fs::read(wasm)?; debug!("Loaded {} bytes of Wasm data", bytes.len()); - let module = Module::from_wasm_bytes(&bytes[..])?; + let mut module = Module::from_wasm_bytes(&bytes[..])?; + if opts.basic_opts { + module.optimize(); + } + if opts.max_ssa { + module.convert_to_max_ssa(); + } println!("{}", module.display()); } Command::RoundTrip { input, output } => { let bytes = std::fs::read(input)?; debug!("Loaded {} bytes of Wasm data", bytes.len()); - let module = Module::from_wasm_bytes(&bytes[..])?; + let mut module = Module::from_wasm_bytes(&bytes[..])?; + if opts.basic_opts { + module.optimize(); + } + if opts.max_ssa { + module.convert_to_max_ssa(); + } let produced = module.to_wasm_bytes()?; std::fs::write(output, &produced[..])?; } diff --git a/src/cfg/mod.rs b/src/cfg/mod.rs index cef0a21..3be9a9d 100644 --- a/src/cfg/mod.rs +++ b/src/cfg/mod.rs @@ -80,6 +80,9 @@ impl CFGInfo { let mut def_block: PerEntity = PerEntity::default(); for (block, block_def) in f.blocks.entries() { + for &(_, param) in &block_def.params { + def_block[param] = block; + } for &value in &block_def.insts { def_block[value] = block; } diff --git a/src/ir/module.rs b/src/ir/module.rs index 5556bf2..ad430b7 100644 --- a/src/ir/module.rs +++ b/src/ir/module.rs @@ -221,21 +221,36 @@ impl<'a> Module<'a> { } pub fn from_wasm_bytes(bytes: &'a [u8]) -> Result { - let mut module = frontend::wasm_to_ir(bytes)?; - for func_decl in module.funcs.values_mut() { - if let Some(body) = func_decl.body_mut() { - let cfg = crate::cfg::CFGInfo::new(body); - crate::passes::basic_opt::gvn(body, &cfg); - crate::passes::resolve_aliases::run(body); - } - } - Ok(module) + frontend::wasm_to_ir(bytes) } pub fn to_wasm_bytes(&self) -> Result> { backend::compile(self) } + pub fn per_func_body(&mut self, f: F) { + for func_decl in self.funcs.values_mut() { + if let Some(body) = func_decl.body_mut() { + f(body); + } + } + } + + pub fn optimize(&mut self) { + self.per_func_body(|body| { + let cfg = crate::cfg::CFGInfo::new(body); + crate::passes::basic_opt::gvn(body, &cfg); + crate::passes::resolve_aliases::run(body); + }); + } + + pub fn convert_to_max_ssa(&mut self) { + self.per_func_body(|body| { + let cfg = crate::cfg::CFGInfo::new(body); + crate::passes::maxssa::run(body, &cfg); + }); + } + pub fn display<'b>(&'b self) -> ModuleDisplay<'b> where 'b: 'a, diff --git a/src/passes.rs b/src/passes.rs index 0f55dcf..96f923f 100644 --- a/src/passes.rs +++ b/src/passes.rs @@ -3,3 +3,4 @@ pub mod basic_opt; pub mod dom_pass; pub mod resolve_aliases; +pub mod maxssa; diff --git a/src/passes/maxssa.rs b/src/passes/maxssa.rs new file mode 100644 index 0000000..4987470 --- /dev/null +++ b/src/passes/maxssa.rs @@ -0,0 +1,157 @@ +//! Conversion pass that creates "maximal SSA": only local uses (no +//! uses of defs in other blocks), with all values explicitly passed +//! through blockparams. This makes some other transforms easier +//! because it removes the need to worry about adding blockparams when +//! mutating the CFG (all possible blockparams are already there!). + +use crate::cfg::CFGInfo; +use crate::entity::PerEntity; +use crate::ir::{Block, FunctionBody, Value, ValueDef}; +use std::collections::{BTreeSet, HashMap}; + +pub fn run(body: &mut FunctionBody, cfg: &CFGInfo) { + MaxSSAPass::new().run(body, cfg); +} + +struct MaxSSAPass { + /// Additional block args that must be passed to each block, in + /// order. Value numbers are *original* values. + new_args: PerEntity>, + /// For each block, a value map: from original value to local copy + /// of value. + value_map: HashMap<(Block, Value), Value>, +} + +impl MaxSSAPass { + fn new() -> Self { + Self { + new_args: PerEntity::default(), + value_map: HashMap::new(), + } + } + + fn run(mut self, body: &mut FunctionBody, cfg: &CFGInfo) { + for block in body.blocks.iter() { + self.visit(body, cfg, block); + } + self.update(body); + } + + fn visit(&mut self, body: &mut FunctionBody, cfg: &CFGInfo, block: Block) { + // For each use in the block, process the use. Collect all + // uses first to deduplicate and allow more efficient + // processing (and to appease the borrow checker). + let mut uses = BTreeSet::default(); + for &inst in &body.blocks[block].insts { + match &body.values[inst] { + &ValueDef::Operator(_, ref args, _) => { + for &arg in args { + let arg = body.resolve_alias(arg); + uses.insert(arg); + } + } + &ValueDef::PickOutput(value, ..) => { + let value = body.resolve_alias(value); + uses.insert(value); + } + _ => {} + } + } + body.blocks[block].terminator.visit_uses(|u| { + let u = body.resolve_alias(u); + uses.insert(u); + }); + + for u in uses { + self.visit_use(body, cfg, block, u); + } + } + + fn visit_use(&mut self, body: &mut FunctionBody, cfg: &CFGInfo, block: Block, value: Value) { + if self.value_map.contains_key(&(block, value)) { + return; + } + if cfg.def_block[value] == block { + return; + } + self.new_args[block].push(value); + + // Create a blockparam. + let ty = body.values[value].ty().unwrap(); + let blockparam = body.add_blockparam(block, ty); + self.value_map.insert((block, value), blockparam); + + // Recursively visit preds and use the value there, to ensure + // they have the value available as well. + for i in 0..body.blocks[block].preds.len() { + // Don't borrow for whole loop while iterating (`body` is + // taken as mut by recursion, but we don't add preds). + let pred = body.blocks[block].preds[i]; + self.visit_use(body, cfg, pred, value); + } + } + + fn update_preds(&mut self, body: &mut FunctionBody, block: Block) { + for i in 0..body.blocks[block].preds.len() { + let pred = body.blocks[block].preds[i]; + let pred_succ_idx = body.blocks[block].pos_in_pred_succ[i]; + body.blocks[pred] + .terminator + .update_target(pred_succ_idx, |target| { + for &new_arg in &self.new_args[block] { + let actual_value = self + .value_map + .get(&(pred, new_arg)) + .copied() + .unwrap_or(new_arg); + target.args.push(actual_value); + } + }); + } + } + + fn update_uses(&mut self, body: &mut FunctionBody, block: Block) { + let resolve = |body: &FunctionBody, value: Value| { + let value = body.resolve_alias(value); + self.value_map + .get(&(block, value)) + .copied() + .unwrap_or(value) + }; + + for i in 0..body.blocks[block].insts.len() { + let inst = body.blocks[block].insts[i]; + let mut def = std::mem::take(&mut body.values[inst]); + match &mut def { + ValueDef::Operator(_, args, _) => { + for arg in args { + *arg = resolve(body, *arg); + } + } + ValueDef::PickOutput(value, ..) => { + *value = resolve(body, *value); + } + ValueDef::Alias(_) => { + // Nullify the alias: should no longer be needed. + def = ValueDef::None; + } + _ => {} + } + body.values[inst] = def; + } + let mut term = std::mem::take(&mut body.blocks[block].terminator); + term.update_uses(|u| { + *u = resolve(body, *u); + }); + body.blocks[block].terminator = term; + } + + fn update(&mut self, body: &mut FunctionBody) { + for block in body.blocks.iter() { + if self.new_args[block].len() > 0 { + self.update_preds(body, block); + } + self.update_uses(body, block); + } + } +}