diff --git a/fuzz/fuzz_targets/opt_diff.rs b/fuzz/fuzz_targets/opt_diff.rs index d910e17..6610053 100644 --- a/fuzz/fuzz_targets/opt_diff.rs +++ b/fuzz/fuzz_targets/opt_diff.rs @@ -49,7 +49,7 @@ fuzz_target!( let mut opt_module = parsed_module.clone(); opt_module.per_func_body(|body| body.optimize()); - opt_module.per_func_body(|body| body.convert_to_max_ssa()); + opt_module.per_func_body(|body| body.convert_to_max_ssa(None)); let mut opt_ctx = InterpContext::new(&opt_module).unwrap(); // Allow a little leeway for opts to not actually optimize. diff --git a/src/bin/waffle-util.rs b/src/bin/waffle-util.rs index 2f13230..ae52cf3 100644 --- a/src/bin/waffle-util.rs +++ b/src/bin/waffle-util.rs @@ -67,7 +67,7 @@ fn apply_options(opts: &Options, module: &mut Module) -> Result<()> { module.per_func_body(|body| body.optimize()); } if opts.max_ssa { - module.per_func_body(|body| body.convert_to_max_ssa()); + module.per_func_body(|body| body.convert_to_max_ssa(None)); } Ok(()) } diff --git a/src/ir/func.rs b/src/ir/func.rs index d4694ff..a06c2a4 100644 --- a/src/ir/func.rs +++ b/src/ir/func.rs @@ -4,6 +4,7 @@ use crate::entity::{EntityRef, EntityVec, PerEntity}; use crate::frontend::parse_body; use crate::ir::SourceLoc; use anyhow::Result; +use std::collections::HashSet; /// A declaration of a function: there is one `FuncDecl` per `Func` /// index. @@ -50,10 +51,10 @@ impl<'a> FuncDecl<'a> { } } - pub fn convert_to_max_ssa(&mut self) { + pub fn convert_to_max_ssa(&mut self, cut_blocks: Option>) { match self { FuncDecl::Body(_, _, body) => { - body.convert_to_max_ssa(); + body.convert_to_max_ssa(cut_blocks); } _ => {} } @@ -150,9 +151,9 @@ impl FunctionBody { crate::passes::empty_blocks::run(self); } - pub fn convert_to_max_ssa(&mut self) { + pub fn convert_to_max_ssa(&mut self, cut_blocks: Option>) { let cfg = crate::cfg::CFGInfo::new(self); - crate::passes::maxssa::run(self, &cfg); + crate::passes::maxssa::run(self, cut_blocks, &cfg); } pub fn add_block(&mut self) -> Block { diff --git a/src/passes/maxssa.rs b/src/passes/maxssa.rs index 4987470..bed6caf 100644 --- a/src/passes/maxssa.rs +++ b/src/passes/maxssa.rs @@ -7,13 +7,16 @@ use crate::cfg::CFGInfo; use crate::entity::PerEntity; use crate::ir::{Block, FunctionBody, Value, ValueDef}; -use std::collections::{BTreeSet, HashMap}; +use std::collections::{BTreeSet, HashMap, HashSet}; -pub fn run(body: &mut FunctionBody, cfg: &CFGInfo) { - MaxSSAPass::new().run(body, cfg); +pub fn run(body: &mut FunctionBody, cut_blocks: Option>, cfg: &CFGInfo) { + MaxSSAPass::new(cut_blocks).run(body, cfg); } struct MaxSSAPass { + /// Blocks at which all live values must cross through blockparams + /// (or if None, then all blocks). + cut_blocks: Option>, /// Additional block args that must be passed to each block, in /// order. Value numbers are *original* values. new_args: PerEntity>, @@ -23,8 +26,9 @@ struct MaxSSAPass { } impl MaxSSAPass { - fn new() -> Self { + fn new(cut_blocks: Option>) -> Self { Self { + cut_blocks, new_args: PerEntity::default(), value_map: HashMap::new(), } @@ -76,7 +80,7 @@ impl MaxSSAPass { } self.new_args[block].push(value); - // Create a blockparam. + // Create a placeholder value. let ty = body.values[value].ty().unwrap(); let blockparam = body.add_blockparam(block, ty); self.value_map.insert((block, value), blockparam); @@ -89,24 +93,44 @@ impl MaxSSAPass { let pred = body.blocks[block].preds[i]; self.visit_use(body, cfg, pred, value); } + + // If all preds have the same value, and this is not a + // cut-block, rewrite the blockparam to an alias instead. + if !self.is_cut_block(block) { + if let Some(pred_value) = iter_all_same( + body.blocks[block] + .preds + .iter() + .map(|&pred| *self.value_map.get(&(pred, value)).unwrap_or(&value)) + .filter(|&val| val != blockparam), + ) { + body.blocks[block].params.pop(); + self.new_args[block].pop(); + body.values[blockparam] = ValueDef::Alias(pred_value); + self.value_map.insert((block, value), pred_value); + } + } } - fn update_preds(&mut self, body: &mut FunctionBody, block: Block) { - for i in 0..body.blocks[block].preds.len() { - let pred = body.blocks[block].preds[i]; - let pred_succ_idx = body.blocks[block].pos_in_pred_succ[i]; - body.blocks[pred] - .terminator - .update_target(pred_succ_idx, |target| { - for &new_arg in &self.new_args[block] { - let actual_value = self - .value_map - .get(&(pred, new_arg)) - .copied() - .unwrap_or(new_arg); - target.args.push(actual_value); - } - }); + fn is_cut_block(&self, block: Block) -> bool { + self.cut_blocks + .as_ref() + .map(|cut_blocks| cut_blocks.contains(&block)) + .unwrap_or(true) + } + + fn update_branch_args(&mut self, body: &mut FunctionBody) { + for (block, blockdata) in body.blocks.entries_mut() { + blockdata.terminator.update_targets(|target| { + for &new_arg in &self.new_args[target.block] { + let actual_value = self + .value_map + .get(&(block, new_arg)) + .copied() + .unwrap_or(new_arg); + target.args.push(actual_value); + } + }); } } @@ -147,11 +171,19 @@ impl MaxSSAPass { } fn update(&mut self, body: &mut FunctionBody) { + self.update_branch_args(body); for block in body.blocks.iter() { - if self.new_args[block].len() > 0 { - self.update_preds(body, block); - } self.update_uses(body, block); } } } + +fn iter_all_same>(iter: I) -> Option { + let mut item = None; + for val in iter { + if *item.get_or_insert(val) != val { + return None; + } + } + item +}