diff --git a/Cargo.toml b/Cargo.toml index c7d06dd..e637744 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,7 +7,7 @@ license = "Apache-2.0 WITH LLVM-exception" edition = "2018" [dependencies] -wasmparser = "0.81" +wasmparser = { git = 'https://github.com/cfallin/wasm-tools', rev = '03a81b6a6ed4d5d9730fa71bf65636a3b1538ce7' } wasm-encoder = "0.3" anyhow = "1.0" structopt = "0.3" diff --git a/src/frontend.rs b/src/frontend.rs index bafe1bd..6ddcf47 100644 --- a/src/frontend.rs +++ b/src/frontend.rs @@ -5,9 +5,11 @@ use crate::ir::*; use crate::op_traits::{op_inputs, op_outputs}; use anyhow::{bail, Result}; +use fxhash::FxHashMap; use log::trace; use wasmparser::{ - ImportSectionEntryType, Operator, Parser, Payload, Type, TypeDef, TypeOrFuncType, + Ieee32, Ieee64, ImportSectionEntryType, Operator, Parser, Payload, Type, TypeDef, + TypeOrFuncType, }; pub fn wasm_to_ir(bytes: &[u8]) -> Result> { @@ -120,6 +122,18 @@ fn parse_body<'a, 'b>( ); let mut builder = FunctionBodyBuilder::new(module, my_sig, &mut ret); + + for (arg_idx, &arg_ty) in module.signatures[my_sig].params.iter().enumerate() { + let local_idx = arg_idx as LocalId; + let value = builder.body.values.len() as ValueId; + builder.body.values.push(ValueDef { + kind: ValueKind::Arg(arg_idx), + ty: arg_ty, + }); + trace!("defining local {} to value {}", local_idx, value); + builder.locals.insert(local_idx, (arg_ty, value)); + } + let ops = body.get_operators_reader()?; for op in ops.into_iter() { let op = op?; @@ -135,6 +149,8 @@ fn parse_body<'a, 'b>( Ok(ret) } +type LocalId = u32; + #[derive(Debug)] struct FunctionBodyBuilder<'a, 'b> { module: &'b Module<'a>, @@ -143,6 +159,8 @@ struct FunctionBodyBuilder<'a, 'b> { cur_block: Option, ctrl_stack: Vec, op_stack: Vec<(Type, ValueId)>, + locals: FxHashMap, + block_param_locals: FxHashMap>, } #[derive(Clone, Debug)] @@ -223,6 +241,8 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> { ctrl_stack: vec![], op_stack: vec![], cur_block: Some(0), + locals: FxHashMap::default(), + block_param_locals: FxHashMap::default(), }; // Push initial implicit Block. @@ -260,6 +280,7 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> { trace!("handle_op: {:?}", op); trace!("op_stack = {:?}", self.op_stack); trace!("ctrl_stack = {:?}", self.ctrl_stack); + trace!("locals = {:?}", self.locals); match &op { Operator::Unreachable => { if let Some(block) = self.cur_block { @@ -268,10 +289,52 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> { self.cur_block = None; } - Operator::LocalGet { .. } - | Operator::LocalSet { .. } - | Operator::LocalTee { .. } - | Operator::Call { .. } + Operator::LocalGet { local_index } => { + let ty = self.body.locals[*local_index as usize]; + let value = self + .locals + .get(local_index) + .map(|(_ty, value)| *value) + .unwrap_or_else(|| { + if let Some(block) = self.cur_block { + let inst = self.body.blocks[block].insts.len() as InstId; + self.emit(match ty { + Type::I32 => Operator::I32Const { value: 0 }, + Type::I64 => Operator::I64Const { value: 0 }, + Type::F32 => Operator::F32Const { + value: Ieee32::from_bits(0), + }, + Type::F64 => Operator::F64Const { + value: Ieee64::from_bits(0), + }, + _ => panic!("Unknown type for default value for local: {:?}", ty), + }) + .unwrap(); + self.op_stack.pop(); + let value = self.body.values.len() as ValueId; + self.body.values.push(ValueDef { + ty, + kind: ValueKind::Inst(block, inst, 0), + }); + value + } else { + NO_VALUE + } + }); + self.op_stack.push((ty, value)); + } + + Operator::LocalSet { local_index } => { + let (ty, value) = self.op_stack.pop().unwrap(); + self.locals.insert(*local_index, (ty, value)); + } + + Operator::LocalTee { local_index } => { + let value = *self.op_stack.last().unwrap(); + self.locals.insert(*local_index, value); + } + + Operator::Call { .. } | Operator::CallIndirect { .. } | Operator::Select | Operator::TypedSelect { .. } @@ -707,9 +770,27 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> { &self.ctrl_stack[self.ctrl_stack.len() - 1 - relative_depth as usize] } + fn fill_block_params_with_locals(&mut self, target: BlockId, args: &mut Vec>) { + if !self.block_param_locals.contains_key(&target) { + let mut keys: Vec = self.locals.keys().cloned().collect(); + keys.sort(); + for &local_id in &keys { + let ty = self.body.locals[local_id as usize]; + self.body.blocks[target].params.push(ty); + } + self.block_param_locals.insert(target, keys); + } + let block_param_locals = self.block_param_locals.get(&target).unwrap(); + for local in block_param_locals { + let local_value = self.locals.get(local).unwrap(); + args.push(Operand::value(local_value.1)); + } + } + fn emit_branch(&mut self, target: BlockId, args: &[ValueId]) { if let Some(block) = self.cur_block { - let args = args.iter().map(|&val| Operand::value(val)).collect(); + let mut args: Vec> = args.iter().map(|&val| Operand::value(val)).collect(); + self.fill_block_params_with_locals(target, &mut args); let target = BlockTarget { block: target, args, @@ -727,14 +808,16 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> { if_false_args: &[ValueId], ) { if let Some(block) = self.cur_block { - let if_true_args = if_true_args + let mut if_true_args = if_true_args .iter() .map(|&val| Operand::value(val)) .collect(); - let if_false_args = if_false_args + let mut if_false_args = if_false_args .iter() .map(|&val| Operand::value(val)) .collect(); + self.fill_block_params_with_locals(if_true, &mut if_true_args); + self.fill_block_params_with_locals(if_false, &mut if_false_args); self.body.blocks[block].terminator = Terminator::CondBr { cond: Operand::value(cond), if_true: BlockTarget { @@ -760,14 +843,18 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> { let args: Vec> = args.iter().map(|&arg| Operand::value(arg)).collect(); let targets = indexed_targets .iter() - .map(|&block| BlockTarget { - block, - args: args.clone(), + .map(|&block| { + let mut args = args.clone(); + self.fill_block_params_with_locals(block, &mut args); + BlockTarget { block, args } }) .collect(); + + let mut default_args = args; + self.fill_block_params_with_locals(default_target, &mut default_args); let default = BlockTarget { block: default_target, - args: args.clone(), + args: default_args, }; self.body.blocks[block].terminator = Terminator::Select { value: Operand::value(index), @@ -785,15 +872,39 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> { } fn push_block_params(&mut self) { - let tys = &self.body.blocks[self.cur_block.unwrap()].params[..]; + let block = self.cur_block.unwrap(); + let tys = &self.body.blocks[block].params[..]; + let num_local_params = self + .block_param_locals + .get(&block) + .map(|l| l.len()) + .unwrap_or(0); + let wasm_stack_val_tys = &tys[0..(tys.len() - num_local_params)]; - for (i, &ty) in tys.iter().enumerate() { + let mut block_param_num = 0; + for &ty in wasm_stack_val_tys.iter() { let value_id = self.body.values.len() as ValueId; self.body.values.push(ValueDef { - kind: ValueKind::BlockParam(self.cur_block.unwrap(), i), + kind: ValueKind::BlockParam(block, block_param_num), ty, }); self.op_stack.push((ty, value_id)); + block_param_num += 1; + } + + if let Some(block_param_locals) = self.block_param_locals.get(&block) { + for (&ty, &local_id) in tys[tys.len() - num_local_params..] + .iter() + .zip(block_param_locals.iter()) + { + let value_id = self.body.values.len() as ValueId; + self.body.values.push(ValueDef { + kind: ValueKind::BlockParam(block, block_param_num), + ty, + }); + block_param_num += 1; + self.locals.insert(local_id, (ty, value_id)); + } } } @@ -819,11 +930,11 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> { input_operands.reverse(); let mut output_operands = vec![]; - for output_ty in outputs.into_iter() { + for (i, output_ty) in outputs.into_iter().enumerate() { let val = self.body.values.len() as ValueId; output_operands.push(val); self.body.values.push(ValueDef { - kind: ValueKind::Inst(block, inst), + kind: ValueKind::Inst(block, inst, i), ty: output_ty, }); self.op_stack.push((output_ty, val)); diff --git a/src/ir.rs b/src/ir.rs index f6b0fee..9ed0d82 100644 --- a/src/ir.rs +++ b/src/ir.rs @@ -1,6 +1,6 @@ //! Intermediate representation for Wasm. -use crate::{frontend, localssa::LocalSSATransform}; +use crate::frontend; use anyhow::Result; use wasmparser::{FuncType, Operator, Type}; @@ -50,8 +50,9 @@ pub struct ValueDef { #[derive(Clone, Debug)] pub enum ValueKind { + Arg(usize), BlockParam(BlockId, usize), - Inst(BlockId, InstId), + Inst(BlockId, InstId, usize), } #[derive(Clone, Debug, Default)] @@ -124,18 +125,7 @@ impl<'a> std::default::Default for Terminator<'a> { impl<'a> Module<'a> { pub fn from_wasm_bytes(bytes: &'a [u8]) -> Result { - let mut module = frontend::wasm_to_ir(bytes)?; - for func in &mut module.funcs { - match func { - &mut FuncDecl::Body(_, ref mut body) => { - let ssa_transform = LocalSSATransform::new(&body); - // TODO - } - _ => {} - } - } - - Ok(module) + frontend::wasm_to_ir(bytes) } } diff --git a/src/lib.rs b/src/lib.rs index 2509269..1752de9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,7 @@ //! WAFFLE Wasm analysis framework. +#![allow(dead_code)] + // Re-export wasmparser and wasmencoder for easier use of the right // version by our embedders. pub use wasm_encoder; @@ -8,7 +10,6 @@ pub use wasmparser; mod dataflow; mod frontend; mod ir; -mod localssa; mod op_traits; pub use ir::*; diff --git a/src/localssa.rs b/src/localssa.rs deleted file mode 100644 index afa2984..0000000 --- a/src/localssa.rs +++ /dev/null @@ -1,95 +0,0 @@ -//! Local-to-SSA conversion. - -use crate::{ - dataflow::{AnalysisFunction, AnalysisValue, ForwardDataflow, Lattice}, - FunctionBody, Operand, -}; -use crate::{BlockId, InstId, ValueId}; -use wasmparser::Operator; - -// We do a really simple thing for now: -// - Compute "is-there-more-than-one-reaching-definition" property for -// every var at each use site. We do this by tracking a lattice: no -// defs known (top), exactly one def known, multiple defs known -// (bottom). -// - For every var for which there is more than one -// reaching-definition at any use, insert a blockparam on every -// block where the var is live-in. -// - Annotate SSA values as belonging to a disjoint-set (only one live -// at a time) to assist lowering back into Wasm (so they can share a -// local). - -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -enum ReachingDefsLattice { - Unknown, - OneDef(ValueId), - ManyDefs, -} - -impl std::default::Default for ReachingDefsLattice { - fn default() -> Self { - Self::Unknown - } -} - -impl Lattice for ReachingDefsLattice { - fn top() -> Self { - Self::Unknown - } - - fn bottom() -> Self { - Self::ManyDefs - } - - fn meet(a: &Self, b: &Self) -> Self { - match (a, b) { - (a, Self::Unknown) => *a, - (Self::Unknown, b) => *b, - - (Self::OneDef(a), Self::OneDef(b)) if a == b => Self::OneDef(*a), - (Self::OneDef(_), Self::OneDef(_)) => Self::ManyDefs, - - (Self::ManyDefs, _) | (_, Self::ManyDefs) => Self::ManyDefs, - } - } -} - -struct LocalReachingDefsAnalysis; -impl AnalysisFunction for LocalReachingDefsAnalysis { - type K = u32; // Local index - type L = ReachingDefsLattice; - - fn instruction( - &self, - input: &mut AnalysisValue, - func: &FunctionBody, - block: BlockId, - inst: InstId, - ) -> bool { - let inst = &func.blocks[block].insts[inst]; - match &inst.operator { - &Operator::LocalSet { local_index } | &Operator::LocalTee { local_index } => { - if let Operand::Value(value) = inst.inputs[0] { - let value = ReachingDefsLattice::OneDef(value); - let old = input.values.insert(local_index, value); - Some(value) != old - } else { - false - } - } - _ => false, - } - } -} - -pub struct LocalSSATransform { - local_analysis: ForwardDataflow, -} - -impl LocalSSATransform { - pub fn new(func: &FunctionBody) -> Self { - LocalSSATransform { - local_analysis: ForwardDataflow::new(&LocalReachingDefsAnalysis, func), - } - } -}