diff --git a/src/cfg/mod.rs b/src/cfg/mod.rs index 24602d2..43d0a41 100644 --- a/src/cfg/mod.rs +++ b/src/cfg/mod.rs @@ -31,6 +31,8 @@ pub struct CFGInfo { pub def_block: PerEntity, /// Preds for a given block. pub preds: PerEntity>, + /// A given block's position in each predecessor's successor list. + pub pred_pos: PerEntity>, } #[derive(Clone, Debug, Default)] @@ -61,22 +63,19 @@ impl CFGInfo { pub fn new(f: &FunctionBody) -> CFGInfo { let mut return_blocks = vec![]; let mut preds: PerEntity> = PerEntity::default(); + let mut pred_pos: PerEntity> = PerEntity::default(); for (block_id, block) in f.blocks.entries() { if let Terminator::Return { .. } = &block.terminator { return_blocks.push(block_id); } + let mut target_idx = 0; block.terminator.visit_targets(|target| { preds[target.block].push(block_id); + pred_pos[target.block].push(target_idx); + target_idx += 1; }); } - // Dedup preds. - for block in f.blocks.iter() { - let preds = &mut preds[block]; - preds.sort_unstable(); - preds.dedup(); - } - let postorder = postorder::calculate(f.entry, |block| &f.blocks[block].succs[..]); let domtree = @@ -127,6 +126,7 @@ impl CFGInfo { domtree_children, def_block, preds, + pred_pos, } } diff --git a/src/interp.rs b/src/interp.rs index 3a44fa8..b4b0f49 100644 --- a/src/interp.rs +++ b/src/interp.rs @@ -398,6 +398,15 @@ impl ConstVal { _ => None, } } + + pub fn meet(a: Option, b: Option) -> Option { + match (a, b) { + (None, None) => None, + (Some(a), None) | (None, Some(a)) => Some(a), + (Some(a), Some(b)) if a == b => Some(a), + _ => Some(ConstVal::None), + } + } } pub fn const_eval( diff --git a/src/ir/func.rs b/src/ir/func.rs index 651961f..832c33e 100644 --- a/src/ir/func.rs +++ b/src/ir/func.rs @@ -674,21 +674,13 @@ impl Terminator { } } - pub fn visit_target(&self, index: usize, mut f: F) { + pub fn visit_target R>(&self, index: usize, mut f: F) -> R { match (index, self) { (0, Terminator::Br { ref target, .. }) => f(target), - (0, Terminator::CondBr { ref if_true, .. }) => { - f(if_true); - } - (1, Terminator::CondBr { ref if_false, .. }) => { - f(if_false); - } - (0, Terminator::Select { ref default, .. }) => { - f(default); - } - (i, Terminator::Select { ref targets, .. }) if i <= targets.len() => { - f(&targets[i - 1]); - } + (0, Terminator::CondBr { ref if_true, .. }) => f(if_true), + (1, Terminator::CondBr { ref if_false, .. }) => f(if_false), + (0, Terminator::Select { ref default, .. }) => f(default), + (i, Terminator::Select { ref targets, .. }) if i <= targets.len() => f(&targets[i - 1]), _ => panic!("out of bounds"), } } diff --git a/src/passes/basic_opt.rs b/src/passes/basic_opt.rs index 78fe981..89aa23c 100644 --- a/src/passes/basic_opt.rs +++ b/src/passes/basic_opt.rs @@ -7,6 +7,7 @@ use crate::passes::dom_pass::{dom_pass, DomtreePass}; use crate::pool::ListRef; use crate::scoped_map::ScopedMap; use crate::Operator; +use smallvec::{smallvec, SmallVec}; pub fn gvn(body: &mut FunctionBody, cfg: &CFGInfo) { dom_pass::( @@ -14,16 +15,18 @@ pub fn gvn(body: &mut FunctionBody, cfg: &CFGInfo) { cfg, &mut GVNPass { map: ScopedMap::default(), + cfg, }, ); } #[derive(Debug)] -struct GVNPass { +struct GVNPass<'a> { map: ScopedMap, + cfg: &'a CFGInfo, } -impl DomtreePass for GVNPass { +impl<'a> DomtreePass for GVNPass<'a> { fn enter(&mut self, block: Block, body: &mut FunctionBody) { self.map.push_level(); self.optimize(block, body); @@ -41,8 +44,104 @@ fn value_is_pure(value: Value, body: &FunctionBody) -> bool { } } -impl GVNPass { +fn value_is_const(value: Value, body: &FunctionBody) -> ConstVal { + match body.values[value] { + ValueDef::Operator(Operator::I32Const { value }, _, _) => ConstVal::I32(value), + ValueDef::Operator(Operator::I64Const { value }, _, _) => ConstVal::I64(value), + ValueDef::Operator(Operator::F32Const { value }, _, _) => ConstVal::F32(value), + ValueDef::Operator(Operator::F64Const { value }, _, _) => ConstVal::F64(value), + _ => ConstVal::None, + } +} + +fn const_op(val: ConstVal) -> Operator { + match val { + ConstVal::I32(value) => Operator::I32Const { value }, + ConstVal::I64(value) => Operator::I64Const { value }, + ConstVal::F32(value) => Operator::F32Const { value }, + ConstVal::F64(value) => Operator::F64Const { value }, + _ => unreachable!(), + } +} + +fn remove_all_from_vec(v: &mut Vec, indices: &[usize]) { + let mut out = 0; + let mut indices_i = 0; + for i in 0..v.len() { + let keep = indices_i == indices.len() || indices[indices_i] != i; + if keep { + if out < i { + v[out] = v[i].clone(); + } + out += 1; + } else { + indices_i += 1; + } + } + + v.truncate(out); +} + +impl<'a> GVNPass<'a> { fn optimize(&mut self, block: Block, body: &mut FunctionBody) { + if block != body.entry { + // Pass over blockparams, checking all inputs. If all inputs + // resolve to the same SSA value, remove the blockparam and + // make it an alias of that value. If all inputs resolve to + // the same constant value, remove the blockparam and insert a + // new copy of that constant. + let mut blockparams_to_remove: SmallVec<[usize; 4]> = smallvec![]; + let mut const_insts_to_insert: SmallVec<[Value; 4]> = smallvec![]; + for (i, &(ty, blockparam)) in body.blocks[block].params.iter().enumerate() { + let mut inputs: SmallVec<[Value; 4]> = smallvec![]; + let mut const_val = None; + for (&pred, &pos) in self.cfg.preds[block] + .iter() + .zip(self.cfg.pred_pos[block].iter()) + { + let input = body.blocks[pred] + .terminator + .visit_target(pos, |target| target.args[i]); + inputs.push(input); + const_val = ConstVal::meet(const_val, Some(value_is_const(input, body))); + } + let const_val = const_val.unwrap(); + + assert!(inputs.len() > 0); + if inputs.iter().all(|x| *x == inputs[0]) { + // All inputs are the same value; remove the + // blockparam and rewrite it as an alias of the one + // single value. + body.values[blockparam] = ValueDef::Alias(inputs[0]); + blockparams_to_remove.push(i); + } else if const_val != ConstVal::None { + // All inputs are the same constant; remove the + // blockparam and rewrite it as a new constant + // operator. + let ty = body.type_pool.single(ty); + body.values[blockparam] = + ValueDef::Operator(const_op(const_val), ListRef::default(), ty); + const_insts_to_insert.push(blockparam); + blockparams_to_remove.push(i); + } + } + + for inst in const_insts_to_insert { + body.blocks[block].insts.insert(0, inst); + } + + remove_all_from_vec(&mut body.blocks[block].params, &blockparams_to_remove[..]); + for (&pred, &pos) in self.cfg.preds[block] + .iter() + .zip(self.cfg.pred_pos[block].iter()) + { + body.blocks[pred].terminator.update_target(pos, |target| { + remove_all_from_vec(&mut target.args, &blockparams_to_remove[..]) + }); + } + } + + // Pass over instructions, updating in place. let mut i = 0; while i < body.blocks[block].insts.len() { let inst = body.blocks[block].insts[i]; @@ -50,6 +149,7 @@ impl GVNPass { if value_is_pure(inst, body) { let mut value = body.values[inst].clone(); + // Resolve aliases in the arg lists. match &mut value { &mut ValueDef::Operator(_, args, _) | &mut ValueDef::Trace(_, args) => { for i in 0..args.len() { @@ -65,24 +165,11 @@ impl GVNPass { _ => {} } + // Try to constant-propagate. if let ValueDef::Operator(op, args, ..) = &value { let arg_values = body.arg_pool[*args] .iter() - .map(|&arg| match body.values[arg] { - ValueDef::Operator(Operator::I32Const { value }, _, _) => { - ConstVal::I32(value) - } - ValueDef::Operator(Operator::I64Const { value }, _, _) => { - ConstVal::I64(value) - } - ValueDef::Operator(Operator::F32Const { value }, _, _) => { - ConstVal::F32(value) - } - ValueDef::Operator(Operator::F64Const { value }, _, _) => { - ConstVal::F64(value) - } - _ => ConstVal::None, - }) + .map(|&arg| value_is_const(arg, body)) .collect::>(); let const_val = const_eval(op, &arg_values[..], None); match const_val { @@ -122,6 +209,8 @@ impl GVNPass { } } + // GVN: look for already-existing copies of this + // value. if let Some(value) = self.map.get(&value) { body.set_alias(inst, *value); i -= 1;