From 9c84c7d44d548e3025b0e0210538da6481eafee2 Mon Sep 17 00:00:00 2001 From: Chris Fallin Date: Sat, 25 Feb 2023 16:57:27 -0800 Subject: [PATCH] Rewrite localifier (regalloc). --- fuzz/fuzz_targets/differential.rs | 2 +- src/backend/localify.rs | 285 ++++++++++++++++++------------ src/bin/waffle-util.rs | 2 +- src/passes.rs | 8 + src/passes/basic_opt.rs | 8 +- 5 files changed, 187 insertions(+), 118 deletions(-) diff --git a/fuzz/fuzz_targets/differential.rs b/fuzz/fuzz_targets/differential.rs index b72af67..c10311e 100644 --- a/fuzz/fuzz_targets/differential.rs +++ b/fuzz/fuzz_targets/differential.rs @@ -38,7 +38,7 @@ fuzz_target!( let mut parsed_module = Module::from_wasm_bytes(&orig_bytes[..], &FrontendOptions::default()).unwrap(); parsed_module.expand_all_funcs().unwrap(); - parsed_module.per_func_body(|body| body.optimize()); + parsed_module.per_func_body(|body| body.optimize(&mut waffle::passes::Fuel::infinite())); let roundtrip_bytes = parsed_module.to_wasm_bytes().unwrap(); if let Ok(filename) = std::env::var("FUZZ_DUMP_WASM") { diff --git a/src/backend/localify.rs b/src/backend/localify.rs index 20b4596..f6e8c4e 100644 --- a/src/backend/localify.rs +++ b/src/backend/localify.rs @@ -6,7 +6,8 @@ use crate::cfg::CFGInfo; use crate::entity::{EntityVec, PerEntity}; use crate::ir::{Block, FunctionBody, Local, Type, Value, ValueDef}; use smallvec::{smallvec, SmallVec}; -use std::collections::{hash_map::Entry, HashMap}; +use std::collections::{HashMap, HashSet}; +use std::ops::Range; #[derive(Clone, Debug, Default)] pub struct Localifier { @@ -26,14 +27,89 @@ struct Context<'a> { trees: &'a Trees, results: Localifier, + /// Precise liveness for each block: live Values at the end. + block_end_live: PerEntity>, + /// Liveranges for each Value, in an arbitrary index space /// (concretely, the span of first to last instruction visit step /// index in an RPO walk over the function body). - ranges: HashMap>, + ranges: HashMap>, /// Number of points. 
points: usize, } +trait Visitor { + fn visit_use(&mut self, _: Value) {} + fn visit_def(&mut self, _: Value) {} + fn post_inst(&mut self, _: Value) {} + fn pre_inst(&mut self, _: Value) {} + fn post_term(&mut self) {} + fn pre_term(&mut self) {} + fn post_params(&mut self) {} + fn pre_params(&mut self) {} +} + +struct BlockVisitor<'a, V: Visitor> { + body: &'a FunctionBody, + trees: &'a Trees, + visitor: V, +} +impl<'a, V: Visitor> BlockVisitor<'a, V> { + fn new(body: &'a FunctionBody, trees: &'a Trees, visitor: V) -> Self { + Self { + body, + trees, + visitor, + } + } + fn visit_block(&mut self, block: Block) { + self.visitor.post_term(); + self.body.blocks[block].terminator.visit_uses(|u| { + self.visit_use(u); + }); + self.visitor.pre_term(); + + for &inst in self.body.blocks[block].insts.iter().rev() { + self.visitor.post_inst(inst); + self.visit_inst(inst, /* root = */ true); + self.visitor.pre_inst(inst); + } + + self.visitor.post_params(); + for &(_, param) in &self.body.blocks[block].params { + self.visitor.visit_def(param); + } + self.visitor.pre_params(); + } + fn visit_inst(&mut self, value: Value, root: bool) { + // If this is an instruction... + if let ValueDef::Operator(_, ref args, _) = &self.body.values[value] { + // If root, we need to process the def. + if root { + self.visitor.visit_def(value); + } + // Handle uses. + for &arg in args { + self.visit_use(arg); + } + } + // Otherwise, it may be an alias (but resolved above) or + // PickOutput, which we "see through" in visit_use of + // consumers. + } + fn visit_use(&mut self, value: Value) { + let value = self.body.resolve_alias(value); + if self.trees.owner.contains_key(&value) { + // If this is a treeified value, then don't process the use, + // but process the instruction directly here. + self.visit_inst(value, /* root = */ false); + } else { + // Otherwise, this is a proper use. 
+ self.visitor.visit_use(value); + } + } +} + impl<'a> Context<'a> { fn new(body: &'a FunctionBody, cfg: &'a CFGInfo, trees: &'a Trees) -> Self { let mut results = Localifier::default(); @@ -49,122 +125,106 @@ impl<'a> Context<'a> { cfg, trees, results, + block_end_live: PerEntity::default(), ranges: HashMap::default(), points: 0, } } + fn compute_liveness(&mut self) { + struct LivenessVisitor { + live: HashSet, + } + impl Visitor for LivenessVisitor { + fn visit_use(&mut self, value: Value) { + self.live.insert(value); + } + fn visit_def(&mut self, value: Value) { + self.live.remove(&value); + } + } + + let mut workqueue: Vec = self.cfg.rpo.values().cloned().collect(); + let mut workqueue_set: HashSet = workqueue.iter().cloned().collect(); + while let Some(block) = workqueue.pop() { + let live = self.block_end_live[block].clone(); + let mut visitor = BlockVisitor::new(self.body, self.trees, LivenessVisitor { live }); + visitor.visit_block(block); + let live = visitor.visitor.live; + + for &pred in &self.body.blocks[block].preds { + let pred_live = &mut self.block_end_live[pred]; + let mut changed = false; + for &value in &live { + if pred_live.insert(value) { + changed = true; + } + } + if changed && workqueue_set.insert(pred) { + workqueue.push(pred); + } + } + } + } + fn find_ranges(&mut self) { let mut point = 0; - let mut live: HashMap = HashMap::default(); - let mut block_starts: HashMap = HashMap::default(); + struct LiveRangeVisitor<'b> { + point: &'b mut usize, + live: HashMap, + ranges: &'b mut HashMap>, + } + impl<'b> Visitor for LiveRangeVisitor<'b> { + fn pre_params(&mut self) { + *self.point += 1; + } + fn pre_inst(&mut self, _: Value) { + *self.point += 1; + } + fn pre_term(&mut self) { + *self.point += 1; + } + fn visit_use(&mut self, value: Value) { + self.live.entry(value).or_insert(*self.point); + } + fn visit_def(&mut self, value: Value) { + let range = if let Some(start) = self.live.remove(&value) { + start..(*self.point + 1) + } else { + 
+ *self.point..(*self.point + 1) + }; + let existing_range = self.ranges.entry(value).or_insert(range.clone()); + existing_range.start = std::cmp::min(existing_range.start, range.start); + existing_range.end = std::cmp::max(existing_range.end, range.end); + } + } + for &block in self.cfg.rpo.values().rev() { - block_starts.insert(block, point); - - self.body.blocks[block].terminator.visit_uses(|u| { - self.handle_use(&mut live, &mut point, u); - }); - point += 1; - - for &inst in self.body.blocks[block].insts.iter().rev() { - self.handle_inst(&mut live, &mut point, inst, /* root = */ true); - point += 1; + let visitor = LiveRangeVisitor { + live: HashMap::default(), + point: &mut point, + ranges: &mut self.ranges, + }; + let mut visitor = BlockVisitor::new(&self.body, &self.trees, visitor); + // Live-outs to succ blocks: in this block-local + // handling, model them as uses at the end of the block. + for &livein in &self.block_end_live[block] { + visitor.visitor.visit_use(livein); } - - for &(_, param) in &self.body.blocks[block].params { - self.handle_def(&mut live, &mut point, param); - } - point += 1; - - // If there were any in-edges from blocks numbered earlier - // in postorder ("loop backedges"), extend the start of - // the backward-range on all live values at this point to - // the origin of the edge. (In forward program order, - // extend the *end* of the liverange down to the end of - // the loop.) - // - // Note that we do this *after* inserting our own start - // above, so we handle self-loops properly. - for &pred in &self.body.blocks[block].preds { - if let Some(&start) = block_starts.get(&pred) { - for live_start in live.values_mut() { - *live_start = std::cmp::min(*live_start, start); - } - } + // Visit all insts. + visitor.visit_block(block); + // Live-ins from pred blocks: anything still live has a + // virtual def at top of block. 
+ let still_live = visitor.visitor.live.keys().cloned().collect::>(); + for live in still_live { + visitor.visitor.visit_def(live); } } self.points = point; } - fn handle_def(&mut self, live: &mut HashMap, point: &mut usize, value: Value) { - // If the value was not live, make it so just for this - // point. Otherwise, end the liverange. - log::trace!("localify: point {}: live {:?}: def {}", point, live, value); - match live.entry(value) { - Entry::Vacant(_) => { - log::trace!(" -> was dead; use {}..{}", *point, *point + 1); - self.ranges.insert(value, *point..(*point + 1)); - } - Entry::Occupied(o) => { - let start = o.remove(); - log::trace!(" -> was live; use {}..{}", start, *point + 1); - self.ranges.insert(value, start..(*point + 1)); - } - } - } - - fn handle_use(&mut self, live: &mut HashMap, point: &mut usize, value: Value) { - let value = self.body.resolve_alias(value); - log::trace!("localify: point {}: live {:?}: use {}", point, live, value); - if self.trees.owner.contains_key(&value) { - log::trace!(" -> treeified, going to inst"); - // If this is a treeified value, then don't process the use, - // but process the instruction directly here. - self.handle_inst(live, point, value, /* root = */ false); - } else { - // Otherwise, update liveranges: make value live at this - // point if not live already. - live.entry(value).or_insert(*point); - } - } - - fn handle_inst( - &mut self, - live: &mut HashMap, - point: &mut usize, - value: Value, - root: bool, - ) { - log::trace!( - "localify: point {}: live {:?}: handling inst {} root {}", - point, - live, - value, - root - ); - - // If this is an instruction... - if let ValueDef::Operator(_, ref args, _) = &self.body.values[value] { - // If root, we need to process the def. - if root { - *point += 1; - log::trace!(" -> def {}", value); - self.handle_def(live, point, value); - } - *point += 1; - // Handle uses. 
- for &arg in args { - log::trace!(" -> arg {}", arg); - self.handle_use(live, point, arg); - } - } - // Otherwise, it may be an alias (but resolved above) or - // PickOutput, which we "see through" in handle_use of - // consumers. - } - fn allocate(&mut self) { // Sort values by ranges' starting points, then value to break ties. let mut ranges: Vec<(Value, std::ops::Range)> = @@ -181,16 +241,6 @@ impl<'a> Context<'a> { let mut freelist: HashMap> = HashMap::new(); for i in 0..self.points { - // Process ends. (Ends are exclusive, so we do them - // first; another range can grab the local at the same - // point index in this same iteration.) - if let Some(expiring) = expiring.remove(&i) { - for (ty, local) in expiring { - log::trace!(" -> expiring {} of type {} back to freelist", local, ty); - freelist.entry(ty).or_insert_with(|| vec![]).push(local); - } - } - // Process starts. while range_idx < ranges.len() && ranges[range_idx].1.start == i { let (value, range) = ranges[range_idx].clone(); @@ -228,10 +278,21 @@ impl<'a> Context<'a> { } self.results.values[value] = allocs; } + + // Process ends. (We do them after starts, so a local + // expiring at this point index returns to the freelist only + // after ranges starting at the same point were processed.) 
+ if let Some(expiring) = expiring.remove(&i) { + for (ty, local) in expiring { + log::trace!(" -> expiring {} of type {} back to freelist", local, ty); + freelist.entry(ty).or_insert_with(|| vec![]).push(local); + } + } } } fn compute(mut self) -> Localifier { + self.compute_liveness(); self.find_ranges(); self.allocate(); self.results diff --git a/src/bin/waffle-util.rs b/src/bin/waffle-util.rs index 5f1c7e3..fd142a4 100644 --- a/src/bin/waffle-util.rs +++ b/src/bin/waffle-util.rs @@ -57,7 +57,7 @@ enum Command { fn apply_options(opts: &Options, module: &mut Module) -> Result<()> { module.expand_all_funcs()?; if opts.basic_opts { - module.per_func_body(|body| body.optimize()); + module.per_func_body(|body| body.optimize(&mut waffle::passes::Fuel::infinite())); } if opts.max_ssa { module.per_func_body(|body| body.convert_to_max_ssa()); diff --git a/src/passes.rs b/src/passes.rs index 11dc04c..c5e47fe 100644 --- a/src/passes.rs +++ b/src/passes.rs @@ -14,6 +14,9 @@ pub struct Fuel { } impl Fuel { pub fn consume(&mut self) -> bool { + if self.remaining == u64::MAX { + return true; + } if self.remaining == 0 { false } else { @@ -21,4 +24,9 @@ impl Fuel { true } } + pub fn infinite() -> Fuel { + Fuel { + remaining: u64::MAX, + } + } } diff --git a/src/passes/basic_opt.rs b/src/passes/basic_opt.rs index 828fb0f..971bfb7 100644 --- a/src/passes/basic_opt.rs +++ b/src/passes/basic_opt.rs @@ -47,10 +47,6 @@ impl<'a> GVNPass<'a> { fn optimize(&mut self, block: Block, body: &mut FunctionBody) { let mut i = 0; while i < body.blocks[block].insts.len() { - if !self.fuel.consume() { - return; - } - let inst = body.blocks[block].insts[i]; i += 1; if value_is_pure(inst, body) { @@ -115,6 +111,10 @@ impl<'a> GVNPass<'a> { } if let Some(value) = self.map.get(&value) { + if !self.fuel.consume() { + return; + } + body.set_alias(inst, *value); i -= 1; body.blocks[block].insts.remove(i);