diff --git a/src/entity.rs b/src/entity.rs index aceadac..6189a39 100644 --- a/src/entity.rs +++ b/src/entity.rs @@ -73,6 +73,12 @@ impl std::default::Default for EntityVec From> for EntityVec { + fn from(vec: Vec) -> Self { + Self(vec, PhantomData) + } +} + impl EntityVec { pub fn push(&mut self, t: T) -> Idx { let idx = Idx::new(self.0.len()); @@ -117,6 +123,10 @@ impl EntityVec { pub fn get_mut(&mut self, idx: Idx) -> Option<&mut T> { self.0.get_mut(idx.index()) } + + pub fn into_vec(self) -> Vec { + self.0 + } } impl Index for EntityVec { diff --git a/src/frontend.rs b/src/frontend.rs index 145e66a..dd1763d 100644 --- a/src/frontend.rs +++ b/src/frontend.rs @@ -551,7 +551,7 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> { match &op { wasmparser::Operator::Unreachable => { if let Some(block) = self.cur_block { - self.body.end_block(block, Terminator::None); + self.body.end_block(block, Terminator::Unreachable); self.locals.finish_block(); } self.cur_block = None; diff --git a/src/ir.rs b/src/ir.rs index eded1d6..052d0b7 100644 --- a/src/ir.rs +++ b/src/ir.rs @@ -9,6 +9,7 @@ pub enum Type { F32, F64, V128, + FuncRef, } impl From for Type { fn from(ty: wasmparser::Type) -> Self { @@ -18,6 +19,7 @@ impl From for Type { wasmparser::Type::F32 => Type::F32, wasmparser::Type::F64 => Type::F64, wasmparser::Type::V128 => Type::V128, + wasmparser::Type::FuncRef => Type::FuncRef, _ => panic!("Unsupported type: {:?}", ty), } } @@ -31,6 +33,7 @@ impl std::fmt::Display for Type { Type::F32 => "f32", Type::F64 => "f64", Type::V128 => "v128", + Type::FuncRef => "funcref", }; write!(f, "{}", s) } @@ -43,7 +46,7 @@ entity!(Local, "local"); entity!(Global, "global"); entity!(Table, "table"); entity!(Memory, "memory"); -entity!(Value, "value"); +entity!(Value, "v"); mod module; pub use module::*; diff --git a/src/ir/display.rs b/src/ir/display.rs index 7a9ff6b..fe7edc5 100644 --- a/src/ir/display.rs +++ b/src/ir/display.rs @@ -1,7 +1,7 @@ //! Displaying IR. use super::{FuncDecl, FunctionBody, Module, ValueDef}; - +use std::collections::HashMap; use std::fmt::{Display, Formatter, Result as FmtResult}; pub struct FunctionBodyDisplay<'a>(pub(crate) &'a FunctionBody, pub(crate) &'a str); @@ -36,23 +36,32 @@ impl<'a> Display for FunctionBodyDisplay<'a> { .map(|(ty, val)| format!("{}: {}", val, ty)) .collect::>(); writeln!(f, "{} {}({}):", self.1, block_id, block_params.join(", "))?; - for &pred in &block.preds { - writeln!(f, "{} # pred: {}", self.1, pred)?; - } - for &succ in &block.succs { - writeln!(f, "{} # succ: {}", self.1, succ)?; - } + writeln!( + f, + "{} # preds: {}", + self.1, + block + .preds + .iter() + .map(|pred| format!("{}", pred)) + .collect::>() + .join(", ") + )?; + writeln!( + f, + "{} # succs: {}", + self.1, + block + .succs + .iter() + .map(|succ| format!("{}", succ)) + .collect::>() + .join(", ") + )?; for &inst in &block.insts { - let inst = self.0.resolve_alias(inst); match &self.0.values[inst] { ValueDef::Operator(op, args, tys) => { - let args = args - .iter() - .map(|&v| { - let v = self.0.resolve_alias(v); - format!("{}", v) - }) - .collect::>(); + let args = args.iter().map(|&v| format!("{}", v)).collect::>(); let tys = tys.iter().map(|&ty| format!("{}", ty)).collect::>(); writeln!( f, @@ -67,12 +76,16 @@ impl<'a> Display for FunctionBodyDisplay<'a> { ValueDef::PickOutput(val, idx, ty) => { writeln!(f, "{} {} = {}.{} # {}", self.1, inst, val, idx, ty)?; } + ValueDef::Alias(v) => { + writeln!(f, "{} {} <- {}", self.1, inst, v)?; + } _ => unreachable!(), } } + writeln!(f, "{} {}", self.1, block.terminator)?; } - writeln!(f, "}}")?; + writeln!(f, "{}}}", self.1)?; Ok(()) } @@ -83,14 +96,36 @@ pub struct ModuleDisplay<'a>(pub(crate) &'a Module<'a>); impl<'a> Display for ModuleDisplay<'a> { fn fmt(&self, f: &mut Formatter) -> FmtResult { writeln!(f, "module {{")?; + let mut sig_strs = HashMap::new(); + for (sig, sig_data) in self.0.signatures() { + let arg_tys = sig_data + .params + .iter() + .map(|&ty| format!("{}", ty)) + .collect::>(); + let ret_tys = sig_data + .returns + .iter() + .map(|&ty| format!("{}", ty)) + .collect::>(); + let sig_str = format!("{} -> {}", arg_tys.join(", "), ret_tys.join(", ")); + sig_strs.insert(sig, sig_str.clone()); + writeln!(f, " {}: {}", sig, sig_str)?; + } + for (global, global_ty) in self.0.globals() { + writeln!(f, " {}: {}", global, global_ty)?; + } + for (table, table_ty) in self.0.tables() { + writeln!(f, " {}: {}", table, table_ty)?; + } for (func, func_decl) in self.0.funcs() { match func_decl { FuncDecl::Body(sig, body) => { - writeln!(f, " {}: {} =", func, sig)?; + writeln!(f, " {}: {} = # {}", func, sig, sig_strs.get(&sig).unwrap())?; writeln!(f, "{}", body.display(" "))?; } FuncDecl::Import(sig) => { - writeln!(f, " {}: {}", func, sig)?; + writeln!(f, " {}: {} # {}", func, sig, sig_strs.get(&sig).unwrap())?; } } } diff --git a/src/ir/func.rs b/src/ir/func.rs index 36f7427..1708e42 100644 --- a/src/ir/func.rs +++ b/src/ir/func.rs @@ -1,4 +1,4 @@ -use super::{Block, FunctionBodyDisplay, Local, Signature, Value, ValueDef, Type}; +use super::{Block, FunctionBodyDisplay, Local, Signature, Type, Value, ValueDef}; use crate::entity::EntityVec; #[derive(Clone, Debug)] @@ -168,6 +168,17 @@ pub struct BlockTarget { pub args: Vec, } +impl std::fmt::Display for BlockTarget { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + let args = self + .args + .iter() + .map(|arg| format!("{}", arg)) + .collect::>(); + write!(f, "{}({})", self.block, args.join(", ")) + } +} + #[derive(Clone, Debug)] pub enum Terminator { Br { @@ -186,6 +197,7 @@ pub enum Terminator { Return { values: Vec, }, + Unreachable, None, } @@ -195,6 +207,46 @@ impl std::default::Default for Terminator { } } +impl std::fmt::Display for Terminator { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + Terminator::None => write!(f, "no_terminator")?, + Terminator::Br { target } => write!(f, "br {}", target)?, + Terminator::CondBr { + cond, + if_true, + if_false, + } => write!(f, "if {}, {}, {}", cond, if_true, if_false)?, + Terminator::Select { + value, + targets, + default, + } => write!( + f, + "select {}, [{}], {}", + value, + targets + .iter() + .map(|target| format!("{}", target)) + .collect::>() + .join(", "), + default + )?, + Terminator::Return { values } => write!( + f, + "return {}", + values + .iter() + .map(|val| format!("{}", val)) + .collect::>() + .join(", ") + )?, + Terminator::Unreachable => write!(f, "unreachable")?, + } + Ok(()) + } +} + impl Terminator { pub fn visit_targets(&self, mut f: F) { match self { @@ -219,6 +271,7 @@ impl Terminator { } } Terminator::None => {} + Terminator::Unreachable => {} } } @@ -245,6 +298,7 @@ impl Terminator { } } Terminator::None => {} + Terminator::Unreachable => {} } } diff --git a/src/ir/module.rs b/src/ir/module.rs index d52ec24..40846fa 100644 --- a/src/ir/module.rs +++ b/src/ir/module.rs @@ -65,12 +65,21 @@ impl<'a> Module<'a> { pub fn signature<'b>(&'b self, id: Signature) -> &'b SignatureData { &self.signatures[id] } + pub fn signatures<'b>(&'b self) -> impl Iterator { + self.signatures.entries() + } pub fn global_ty(&self, id: Global) -> Type { self.globals[id] } + pub fn globals<'b>(&'b self) -> impl Iterator + 'b { + self.globals.entries().map(|(id, ty)| (id, *ty)) + } pub fn table_ty(&self, id: Table) -> Type { self.tables[id] } + pub fn tables<'b>(&'b self) -> impl Iterator + 'b { + self.tables.entries().map(|(id, ty)| (id, *ty)) + } pub(crate) fn frontend_add_signature(&mut self, ty: SignatureData) { self.signatures.push(ty); @@ -86,7 +95,13 @@ impl<'a> Module<'a> { } pub fn from_wasm_bytes(bytes: &'a [u8]) -> Result { - frontend::wasm_to_ir(bytes) + let mut module = frontend::wasm_to_ir(bytes)?; + for func_decl in module.funcs.values_mut() { + if let Some(body) = func_decl.body_mut() { + crate::passes::rpo::reorder_into_rpo(body); + } + } + Ok(module) } pub fn to_wasm_bytes(&self) -> Result> { diff --git a/src/lib.rs b/src/lib.rs index 4eb29c3..43f7c2e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -12,6 +12,8 @@ mod frontend; mod ir; mod op_traits; mod ops; +mod passes; +pub use passes::rpo::reorder_into_rpo; pub use ir::*; pub use ops::Operator; diff --git a/src/passes.rs b/src/passes.rs new file mode 100644 index 0000000..20ce41e --- /dev/null +++ b/src/passes.rs @@ -0,0 +1,3 @@ +//! Passes. + +pub mod rpo; diff --git a/src/passes/rpo.rs b/src/passes/rpo.rs new file mode 100644 index 0000000..9533174 --- /dev/null +++ b/src/passes/rpo.rs @@ -0,0 +1,162 @@ +//! Reorder-into-RPO pass. +//! +//! The RPO sort order we choose is quite special: we want loop bodies +//! to be placed contiguously, without blocks that do not belong to +//! the loop in the middle. +//! +//! Consider the following CFG: +//! +//! ```plain +//! 1 +//! | +//! 2 <-. +//! / | | +//! | 3 --' +//! | | +//! `> 4 +//! | +//! 5 +//! ``` +//! +//! A normal RPO sort may produce 1, 2, 4, 5, 3 or 1, 2, 3, 4, 5 +//! depending on which child order it chooses from block 2. (If it +//! visits 3 first, it will emit it first in postorder hence it comes +//! last.) +//! +//! One way of ensuring we get the right order would be to compute the +//! loop nest and make note of loops when choosing children to visit, +//! but we really would rather not do that, since we may not otherwise +//! need it. +//! +//! Instead, we keep a "pending" list: as we have nodes on the stack +//! during postorder traversal, we keep a list of other children that +//! we will visit once we get back to a given level. If another node +//! is pending, and is a successor we are considering, we visit it +//! *first* in postorder, so it is last in RPO. This is a way to +//! ensure that (e.g.) block 4 above is visited first when considering +//! successors of block 2. + +use crate::entity; +use crate::entity::{EntityRef, EntityVec, PerEntity}; +use crate::ir::{Block, FunctionBody}; +use std::collections::{HashMap, HashSet}; + +entity!(RPOIndex, "rpo"); + +impl RPOIndex { + fn prev(self) -> RPOIndex { + RPOIndex::from(self.0.checked_sub(1).unwrap()) + } +} + +#[derive(Clone, Debug, Default)] +struct RPO { + order: EntityVec, + rev: PerEntity>, +} + +impl RPO { + fn compute(body: &FunctionBody) -> RPO { + let mut postorder = vec![]; + let mut visited = HashSet::new(); + let mut pending = vec![]; + let mut pending_idx = HashMap::new(); + visited.insert(body.entry); + Self::visit( + body, + body.entry, + &mut visited, + &mut pending, + &mut pending_idx, + &mut postorder, + ); + postorder.reverse(); + let order = EntityVec::from(postorder); + + let mut rev = PerEntity::default(); + for (rpo_index, &block) in order.entries() { + rev[block] = Some(rpo_index); + } + + RPO { order, rev } + } + + fn visit( + body: &FunctionBody, + block: Block, + visited: &mut HashSet, + pending: &mut Vec, + pending_idx: &mut HashMap, + postorder: &mut Vec, + ) { + // `pending` is a Vec, not a Set; we prioritize based on + // position (first in pending go first in postorder -> last in + // RPO). A case with nested loops to show why this matters: + // + // TODO example + + let pending_top = pending.len(); + pending.extend(body.blocks[block].succs.iter().copied()); + + // Sort new entries in `pending` by index at which they appear + // earlier. Those that don't appear in `pending` at all should + // be visited last (to appear in RPO first), so we want `None` + // values to sort first here (hence the "unwrap or MAX" + // idiom). Then those that appear earlier in `pending` should + // be visited earlier here to appear later in RPO, so they + // sort later. + pending[pending_top..] + .sort_by_key(|entry| pending_idx.get(entry).copied().unwrap_or(usize::MAX)); + + // Above we placed items in order they are to be visited; + // below we pop off the end, so we reverse here. + pending[pending_top..].reverse(); + + // Now update indices in `pending_idx`: insert entries for + // those seqs not yet present. + for i in pending_top..pending.len() { + pending_idx.entry(pending[i]).or_insert(i); + } + + for _ in 0..(pending.len() - pending_top) { + let succ = pending.pop().unwrap(); + if pending_idx.get(&succ) == Some(&pending.len()) { + pending_idx.remove(&succ); + } + + if visited.insert(succ) { + Self::visit(body, succ, visited, pending, pending_idx, postorder); + } + } + postorder.push(block); + } + + fn map_block(&self, block: Block) -> Block { + Block::new(self.rev[block].unwrap().index()) + } +} + +pub fn reorder_into_rpo(body: &mut FunctionBody) { + let rpo = RPO::compute(body); + // Remap entry block. + body.entry = rpo.map_block(body.entry); + // Reorder blocks. + let mut block_data = std::mem::take(&mut body.blocks).into_vec(); + let mut new_block_data = vec![]; + for block in rpo.order.values().copied() { + new_block_data.push(std::mem::take(&mut block_data[block.index()])); + } + body.blocks = EntityVec::from(new_block_data); + // Rewrite references in each terminator, pred and succ list. + for block in body.blocks.values_mut() { + block.terminator.update_targets(|target| { + target.block = rpo.map_block(target.block); + }); + for pred in &mut block.preds { + *pred = rpo.map_block(*pred); + } + for succ in &mut block.succs { + *succ = rpo.map_block(*succ); + } + } +}