diff --git a/Cargo.toml b/Cargo.toml
index b29e6c8..c7d06dd 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -12,4 +12,5 @@ wasm-encoder = "0.3"
 anyhow = "1.0"
 structopt = "0.3"
 log = "0.4"
-env_logger = "0.9"
\ No newline at end of file
+env_logger = "0.9"
+fxhash = "0.2"
diff --git a/src/dataflow.rs b/src/dataflow.rs
new file mode 100644
index 0000000..35f8c2a
--- /dev/null
+++ b/src/dataflow.rs
@@ -0,0 +1,200 @@
+//! Dataflow analysis.
+
+use fxhash::{FxHashMap, FxHashSet};
+use std::collections::VecDeque;
+use std::{fmt::Debug, hash::Hash};
+
+use crate::{BlockId, FunctionBody, InstId};
+
+/// A lattice of abstract values that are combined with [`Lattice::meet`].
+pub trait Lattice: Clone + Debug + PartialEq + Eq {
+    /// The top element.
+    fn top() -> Self;
+    /// The bottom element.
+    fn bottom() -> Self;
+    /// Combines two lattice values into one.
+    fn meet(a: &Self, b: &Self) -> Self;
+}
+
+/// Marker trait for types usable as keys of an [`AnalysisValue`] map.
+pub trait AnalysisKey: Clone + Debug + PartialEq + Eq + Hash {}
+
+impl AnalysisKey for u32 {}
+
+/// The analysis state at one program point: a map from keys to lattice
+/// values.
+#[derive(Clone, Debug)]
+pub struct AnalysisValue<K: AnalysisKey, L: Lattice> {
+    pub values: FxHashMap<K, L>,
+}
+
+impl<K: AnalysisKey, L: Lattice> std::default::Default for AnalysisValue<K, L> {
+    fn default() -> Self {
+        Self {
+            values: FxHashMap::default(),
+        }
+    }
+}
+
+impl<K: AnalysisKey, L: Lattice> AnalysisValue<K, L> {
+    /// Meets `other` into `self` key by key and returns whether `self`
+    /// changed.
+    ///
+    /// Keys present in both maps are combined with [`Lattice::meet`]. Under
+    /// [`MapMeetMode::Union`] keys found on one side only are kept (copied
+    /// over from `other` if necessary); under
+    /// [`MapMeetMode::Intersection`] they are dropped from `self`.
+    fn meet_with(&mut self, other: &Self, meet_mode: MapMeetMode) -> bool {
+        let mut changed = false;
+        self.values.retain(|key, value| match other.values.get(key) {
+            Some(other_value) => {
+                let met = L::meet(value, other_value);
+                if met != *value {
+                    changed = true;
+                    *value = met;
+                }
+                true
+            }
+            // Only `self` has this key: it survives a union but not an
+            // intersection.
+            None if meet_mode == MapMeetMode::Intersection => {
+                changed = true;
+                false
+            }
+            None => true,
+        });
+        if meet_mode == MapMeetMode::Union {
+            for (key, value) in &other.values {
+                if !self.values.contains_key(key) {
+                    self.values.insert(key.clone(), value.clone());
+                    changed = true;
+                }
+            }
+        }
+
+        changed
+    }
+}
+
+/// Transfer functions and configuration for a forward dataflow analysis.
+pub trait AnalysisFunction {
+    type K: AnalysisKey;
+    type L: Lattice;
+
+    /// Applies the effect of instruction `_inst` of `_block` to `_input` in
+    /// place, returning whether `_input` was modified. The default does
+    /// nothing.
+    fn instruction(
+        &self,
+        _input: &mut AnalysisValue<Self::K, Self::L>,
+        _func: &FunctionBody,
+        _block: BlockId,
+        _inst: InstId,
+    ) -> bool {
+        false
+    }
+
+    /// Applies the effect of following successor edge number `_index` of
+    /// `_block`'s terminator (leading to `_next`) to `_input` in place,
+    /// returning whether `_input` was modified. The default does nothing.
+    fn terminator(
+        &self,
+        _input: &mut AnalysisValue<Self::K, Self::L>,
+        _func: &FunctionBody,
+        _block: BlockId,
+        _index: usize,
+        _next: BlockId,
+    ) -> bool {
+        false
+    }
+
+    /// How per-block input maps are merged where control flow joins.
+    /// Defaults to [`MapMeetMode::Union`].
+    fn meet_mode(&self) -> MapMeetMode {
+        MapMeetMode::Union
+    }
+}
+
+/// Treatment of keys that occur in only one of two maps being met.
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+pub enum MapMeetMode {
+    Union,
+    Intersection,
+}
+
+/// Forward dataflow solver: computes the analysis state on entry to every
+/// block using a worklist.
+#[derive(Clone, Debug, Default)]
+pub struct ForwardDataflow<F: AnalysisFunction> {
+    block_in: Vec<AnalysisValue<F::K, F::L>>,
+    /// Whether a state has been propagated into each block yet.
+    reached: Vec<bool>,
+    workqueue: VecDeque<BlockId>,
+    workqueue_set: FxHashSet<BlockId>,
+}
+
+impl<F: AnalysisFunction> ForwardDataflow<F> {
+    /// Runs `analysis` over `func` until the worklist drains, starting from
+    /// block 0 with an empty input map.
+    pub fn new(analysis: &F, func: &FunctionBody) -> Self {
+        let num_blocks = func.blocks.len();
+        let mut ret = ForwardDataflow {
+            block_in: vec![AnalysisValue::default(); num_blocks],
+            reached: vec![false; num_blocks],
+            workqueue: VecDeque::new(),
+            workqueue_set: FxHashSet::default(),
+        };
+        // A body without blocks has no entry block to seed; queueing block
+        // 0 anyway would index out of bounds.
+        if num_blocks > 0 {
+            ret.reached[0] = true;
+            ret.workqueue.push_back(0);
+            ret.workqueue_set.insert(0);
+        }
+        ret.compute(analysis, func);
+        ret
+    }
+
+    fn compute(&mut self, analysis: &F, func: &FunctionBody) {
+        while let Some(block) = self.workqueue.pop_front() {
+            self.workqueue_set.remove(&block);
+            self.update_block(analysis, func, block);
+        }
+    }
+
+    /// Recomputes the out-state of `block` and pushes it along every
+    /// successor edge, queueing each successor whose in-state changed.
+    fn update_block(&mut self, analysis: &F, func: &FunctionBody, block: BlockId) {
+        // The `bool` results of the transfer functions are deliberately
+        // ignored: a block that leaves its input untouched must still
+        // forward it, so propagation depends only on whether the
+        // successor's in-state changes.
+        let mut value = self.block_in[block].clone();
+        for i in 0..func.blocks[block].insts.len() {
+            analysis.instruction(&mut value, func, block, i);
+        }
+
+        let meet_mode = analysis.meet_mode();
+        for (i, succ) in func.blocks[block]
+            .terminator
+            .successors()
+            .into_iter()
+            .enumerate()
+        {
+            let mut edge_value = value.clone();
+            analysis.terminator(&mut edge_value, func, block, i, succ);
+            let changed = if self.reached[succ] {
+                self.block_in[succ].meet_with(&edge_value, meet_mode)
+            } else {
+                // First arrival: adopt the incoming state as-is and always
+                // visit the block, even when that state is empty.
+                self.reached[succ] = true;
+                self.block_in[succ] = edge_value;
+                true
+            };
+            if changed && self.workqueue_set.insert(succ) {
+                self.workqueue.push_back(succ);
+            }
+        }
+    }
+}
diff --git a/src/ir.rs b/src/ir.rs
index 24bf3db..f6b0fee 100644
--- a/src/ir.rs
+++ b/src/ir.rs
@@ -1,6 +1,6 @@
 //! Intermediate representation for Wasm.
 
-use crate::frontend;
+use crate::{frontend, localssa::LocalSSATransform};
 use anyhow::Result;
 use wasmparser::{FuncType, Operator, Type};
 
@@ -124,6 +124,37 @@ impl<'a> std::default::Default for Terminator<'a> {
 
 impl<'a> Module<'a> {
     pub fn from_wasm_bytes(bytes: &'a [u8]) -> Result<Self> {
-        frontend::wasm_to_ir(bytes)
+        let mut module = frontend::wasm_to_ir(bytes)?;
+        for func in &mut module.funcs {
+            if let FuncDecl::Body(_, body) = func {
+                // TODO: the transform is only constructed here; nothing
+                // is done with its result yet.
+                let _ssa_transform = LocalSSATransform::new(body);
+            }
+        }
+
+        Ok(module)
+    }
+}
+
+impl<'a> Terminator<'a> {
+    /// Returns the blocks this terminator can transfer control to, one
+    /// entry per outgoing edge (a `Select` lists its targets first and
+    /// its default last).
+    pub fn successors(&self) -> Vec<BlockId> {
+        match self {
+            Terminator::Return { .. } | Terminator::None => Vec::new(),
+            Terminator::Br { target, .. } => vec![target.block],
+            Terminator::CondBr {
+                if_true, if_false, ..
+            } => vec![if_true.block, if_false.block],
+            Terminator::Select {
+                targets, default, ..
+            } => targets
+                .iter()
+                .map(|target| target.block)
+                .chain(std::iter::once(default.block))
+                .collect(),
+        }
     }
 }
diff --git a/src/lib.rs b/src/lib.rs
index 3541116..2509269 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -5,6 +5,7 @@
 pub use wasm_encoder;
 pub use wasmparser;
 
+mod dataflow;
 mod frontend;
 mod ir;
 mod localssa;
diff --git a/src/localssa.rs b/src/localssa.rs
index 86a3d87..afa2984 100644
--- a/src/localssa.rs
+++ b/src/localssa.rs
@@ -1,2 +1,108 @@
 //! Local-to-SSA conversion.
 
+use crate::{
+    dataflow::{AnalysisFunction, AnalysisValue, ForwardDataflow, Lattice},
+    FunctionBody, Operand,
+};
+use crate::{BlockId, InstId, ValueId};
+use wasmparser::Operator;
+
+// We do a really simple thing for now:
+// - Compute "is-there-more-than-one-reaching-definition" property for
+//   every var at each use site. We do this by tracking a lattice: no
+//   defs known (top), exactly one def known, multiple defs known
+//   (bottom).
+// - For every var for which there is more than one
+//   reaching-definition at any use, insert a blockparam on every
+//   block where the var is live-in.
+// - Annotate SSA values as belonging to a disjoint-set (only one live
+//   at a time) to assist lowering back into Wasm (so they can share a
+//   local).
+
+/// Reaching-definition state of a single local.
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+enum ReachingDefsLattice {
+    /// No definition is known to reach this point (top).
+    Unknown,
+    /// Exactly one definition, the given value, reaches this point.
+    OneDef(ValueId),
+    /// More than one distinct definition reaches this point (bottom).
+    ManyDefs,
+}
+
+impl std::default::Default for ReachingDefsLattice {
+    fn default() -> Self {
+        Self::Unknown
+    }
+}
+
+impl Lattice for ReachingDefsLattice {
+    fn top() -> Self {
+        Self::Unknown
+    }
+
+    fn bottom() -> Self {
+        Self::ManyDefs
+    }
+
+    fn meet(a: &Self, b: &Self) -> Self {
+        match (a, b) {
+            // `Unknown` is the identity of the meet.
+            (a, Self::Unknown) => *a,
+            (Self::Unknown, b) => *b,
+
+            // Two definitions agree only if they are the same value.
+            (Self::OneDef(a), Self::OneDef(b)) if a == b => Self::OneDef(*a),
+            (Self::OneDef(_), Self::OneDef(_)) => Self::ManyDefs,
+
+            (Self::ManyDefs, _) | (_, Self::ManyDefs) => Self::ManyDefs,
+        }
+    }
+}
+
+/// Forward analysis that tracks, per Wasm local, which definitions reach
+/// each block.
+struct LocalReachingDefsAnalysis;
+impl AnalysisFunction for LocalReachingDefsAnalysis {
+    type K = u32; // Local index
+    type L = ReachingDefsLattice;
+
+    /// A `local.set`/`local.tee` whose operand is an SSA value makes that
+    /// value the sole reaching definition of the written local. Returns
+    /// whether the map entry changed.
+    fn instruction(
+        &self,
+        input: &mut AnalysisValue<Self::K, Self::L>,
+        func: &FunctionBody,
+        block: BlockId,
+        inst: InstId,
+    ) -> bool {
+        let inst = &func.blocks[block].insts[inst];
+        match &inst.operator {
+            &Operator::LocalSet { local_index } | &Operator::LocalTee { local_index } => {
+                if let Operand::Value(value) = inst.inputs[0] {
+                    let value = ReachingDefsLattice::OneDef(value);
+                    let old = input.values.insert(local_index, value);
+                    Some(value) != old
+                } else {
+                    false
+                }
+            }
+            _ => false,
+        }
+    }
+}
+
+/// Holds the per-function analysis results driving local-to-SSA conversion.
+pub struct LocalSSATransform {
+    local_analysis: ForwardDataflow<LocalReachingDefsAnalysis>,
+}
+
+impl LocalSSATransform {
+    /// Runs the reaching-definitions analysis over `func`.
+    pub fn new(func: &FunctionBody) -> Self {
+        LocalSSATransform {
+            local_analysis: ForwardDataflow::new(&LocalReachingDefsAnalysis, func),
+        }
+    }
+}