Basic support for SSA-ifying locals.

This commit is contained in:
Chris Fallin 2021-11-13 22:25:27 -08:00
parent c1d4e0c6b9
commit 8a7a65f94c
5 changed files with 135 additions and 128 deletions

View file

@ -7,7 +7,7 @@ license = "Apache-2.0 WITH LLVM-exception"
edition = "2018"
[dependencies]
wasmparser = "0.81"
wasmparser = { git = 'https://github.com/cfallin/wasm-tools', rev = '03a81b6a6ed4d5d9730fa71bf65636a3b1538ce7' }
wasm-encoder = "0.3"
anyhow = "1.0"
structopt = "0.3"

View file

@ -5,9 +5,11 @@
use crate::ir::*;
use crate::op_traits::{op_inputs, op_outputs};
use anyhow::{bail, Result};
use fxhash::FxHashMap;
use log::trace;
use wasmparser::{
ImportSectionEntryType, Operator, Parser, Payload, Type, TypeDef, TypeOrFuncType,
Ieee32, Ieee64, ImportSectionEntryType, Operator, Parser, Payload, Type, TypeDef,
TypeOrFuncType,
};
pub fn wasm_to_ir(bytes: &[u8]) -> Result<Module<'_>> {
@ -120,6 +122,18 @@ fn parse_body<'a, 'b>(
);
let mut builder = FunctionBodyBuilder::new(module, my_sig, &mut ret);
for (arg_idx, &arg_ty) in module.signatures[my_sig].params.iter().enumerate() {
let local_idx = arg_idx as LocalId;
let value = builder.body.values.len() as ValueId;
builder.body.values.push(ValueDef {
kind: ValueKind::Arg(arg_idx),
ty: arg_ty,
});
trace!("defining local {} to value {}", local_idx, value);
builder.locals.insert(local_idx, (arg_ty, value));
}
let ops = body.get_operators_reader()?;
for op in ops.into_iter() {
let op = op?;
@ -135,6 +149,8 @@ fn parse_body<'a, 'b>(
Ok(ret)
}
type LocalId = u32;
#[derive(Debug)]
struct FunctionBodyBuilder<'a, 'b> {
module: &'b Module<'a>,
@ -143,6 +159,8 @@ struct FunctionBodyBuilder<'a, 'b> {
cur_block: Option<BlockId>,
ctrl_stack: Vec<Frame>,
op_stack: Vec<(Type, ValueId)>,
locals: FxHashMap<LocalId, (Type, ValueId)>,
block_param_locals: FxHashMap<BlockId, Vec<LocalId>>,
}
#[derive(Clone, Debug)]
@ -223,6 +241,8 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> {
ctrl_stack: vec![],
op_stack: vec![],
cur_block: Some(0),
locals: FxHashMap::default(),
block_param_locals: FxHashMap::default(),
};
// Push initial implicit Block.
@ -260,6 +280,7 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> {
trace!("handle_op: {:?}", op);
trace!("op_stack = {:?}", self.op_stack);
trace!("ctrl_stack = {:?}", self.ctrl_stack);
trace!("locals = {:?}", self.locals);
match &op {
Operator::Unreachable => {
if let Some(block) = self.cur_block {
@ -268,10 +289,52 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> {
self.cur_block = None;
}
Operator::LocalGet { .. }
| Operator::LocalSet { .. }
| Operator::LocalTee { .. }
| Operator::Call { .. }
Operator::LocalGet { local_index } => {
let ty = self.body.locals[*local_index as usize];
let value = self
.locals
.get(local_index)
.map(|(_ty, value)| *value)
.unwrap_or_else(|| {
if let Some(block) = self.cur_block {
let inst = self.body.blocks[block].insts.len() as InstId;
self.emit(match ty {
Type::I32 => Operator::I32Const { value: 0 },
Type::I64 => Operator::I64Const { value: 0 },
Type::F32 => Operator::F32Const {
value: Ieee32::from_bits(0),
},
Type::F64 => Operator::F64Const {
value: Ieee64::from_bits(0),
},
_ => panic!("Unknown type for default value for local: {:?}", ty),
})
.unwrap();
self.op_stack.pop();
let value = self.body.values.len() as ValueId;
self.body.values.push(ValueDef {
ty,
kind: ValueKind::Inst(block, inst, 0),
});
value
} else {
NO_VALUE
}
});
self.op_stack.push((ty, value));
}
Operator::LocalSet { local_index } => {
let (ty, value) = self.op_stack.pop().unwrap();
self.locals.insert(*local_index, (ty, value));
}
Operator::LocalTee { local_index } => {
let value = *self.op_stack.last().unwrap();
self.locals.insert(*local_index, value);
}
Operator::Call { .. }
| Operator::CallIndirect { .. }
| Operator::Select
| Operator::TypedSelect { .. }
@ -707,9 +770,27 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> {
&self.ctrl_stack[self.ctrl_stack.len() - 1 - relative_depth as usize]
}
fn fill_block_params_with_locals(&mut self, target: BlockId, args: &mut Vec<Operand<'a>>) {
if !self.block_param_locals.contains_key(&target) {
let mut keys: Vec<LocalId> = self.locals.keys().cloned().collect();
keys.sort();
for &local_id in &keys {
let ty = self.body.locals[local_id as usize];
self.body.blocks[target].params.push(ty);
}
self.block_param_locals.insert(target, keys);
}
let block_param_locals = self.block_param_locals.get(&target).unwrap();
for local in block_param_locals {
let local_value = self.locals.get(local).unwrap();
args.push(Operand::value(local_value.1));
}
}
fn emit_branch(&mut self, target: BlockId, args: &[ValueId]) {
if let Some(block) = self.cur_block {
let args = args.iter().map(|&val| Operand::value(val)).collect();
let mut args: Vec<Operand<'a>> = args.iter().map(|&val| Operand::value(val)).collect();
self.fill_block_params_with_locals(target, &mut args);
let target = BlockTarget {
block: target,
args,
@ -727,14 +808,16 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> {
if_false_args: &[ValueId],
) {
if let Some(block) = self.cur_block {
let if_true_args = if_true_args
let mut if_true_args = if_true_args
.iter()
.map(|&val| Operand::value(val))
.collect();
let if_false_args = if_false_args
let mut if_false_args = if_false_args
.iter()
.map(|&val| Operand::value(val))
.collect();
self.fill_block_params_with_locals(if_true, &mut if_true_args);
self.fill_block_params_with_locals(if_false, &mut if_false_args);
self.body.blocks[block].terminator = Terminator::CondBr {
cond: Operand::value(cond),
if_true: BlockTarget {
@ -760,14 +843,18 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> {
let args: Vec<Operand<'a>> = args.iter().map(|&arg| Operand::value(arg)).collect();
let targets = indexed_targets
.iter()
.map(|&block| BlockTarget {
block,
args: args.clone(),
.map(|&block| {
let mut args = args.clone();
self.fill_block_params_with_locals(block, &mut args);
BlockTarget { block, args }
})
.collect();
let mut default_args = args;
self.fill_block_params_with_locals(default_target, &mut default_args);
let default = BlockTarget {
block: default_target,
args: args.clone(),
args: default_args,
};
self.body.blocks[block].terminator = Terminator::Select {
value: Operand::value(index),
@ -785,15 +872,39 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> {
}
fn push_block_params(&mut self) {
let tys = &self.body.blocks[self.cur_block.unwrap()].params[..];
let block = self.cur_block.unwrap();
let tys = &self.body.blocks[block].params[..];
let num_local_params = self
.block_param_locals
.get(&block)
.map(|l| l.len())
.unwrap_or(0);
let wasm_stack_val_tys = &tys[0..(tys.len() - num_local_params)];
for (i, &ty) in tys.iter().enumerate() {
let mut block_param_num = 0;
for &ty in wasm_stack_val_tys.iter() {
let value_id = self.body.values.len() as ValueId;
self.body.values.push(ValueDef {
kind: ValueKind::BlockParam(self.cur_block.unwrap(), i),
kind: ValueKind::BlockParam(block, block_param_num),
ty,
});
self.op_stack.push((ty, value_id));
block_param_num += 1;
}
if let Some(block_param_locals) = self.block_param_locals.get(&block) {
for (&ty, &local_id) in tys[tys.len() - num_local_params..]
.iter()
.zip(block_param_locals.iter())
{
let value_id = self.body.values.len() as ValueId;
self.body.values.push(ValueDef {
kind: ValueKind::BlockParam(block, block_param_num),
ty,
});
block_param_num += 1;
self.locals.insert(local_id, (ty, value_id));
}
}
}
@ -819,11 +930,11 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> {
input_operands.reverse();
let mut output_operands = vec![];
for output_ty in outputs.into_iter() {
for (i, output_ty) in outputs.into_iter().enumerate() {
let val = self.body.values.len() as ValueId;
output_operands.push(val);
self.body.values.push(ValueDef {
kind: ValueKind::Inst(block, inst),
kind: ValueKind::Inst(block, inst, i),
ty: output_ty,
});
self.op_stack.push((output_ty, val));

View file

@ -1,6 +1,6 @@
//! Intermediate representation for Wasm.
use crate::{frontend, localssa::LocalSSATransform};
use crate::frontend;
use anyhow::Result;
use wasmparser::{FuncType, Operator, Type};
@ -50,8 +50,9 @@ pub struct ValueDef {
#[derive(Clone, Debug)]
pub enum ValueKind {
Arg(usize),
BlockParam(BlockId, usize),
Inst(BlockId, InstId),
Inst(BlockId, InstId, usize),
}
#[derive(Clone, Debug, Default)]
@ -124,18 +125,7 @@ impl<'a> std::default::Default for Terminator<'a> {
impl<'a> Module<'a> {
pub fn from_wasm_bytes(bytes: &'a [u8]) -> Result<Self> {
let mut module = frontend::wasm_to_ir(bytes)?;
for func in &mut module.funcs {
match func {
&mut FuncDecl::Body(_, ref mut body) => {
let ssa_transform = LocalSSATransform::new(&body);
// TODO
}
_ => {}
}
}
Ok(module)
frontend::wasm_to_ir(bytes)
}
}

View file

@ -1,5 +1,7 @@
//! WAFFLE Wasm analysis framework.
#![allow(dead_code)]
// Re-export wasmparser and wasmencoder for easier use of the right
// version by our embedders.
pub use wasm_encoder;
@ -8,7 +10,6 @@ pub use wasmparser;
mod dataflow;
mod frontend;
mod ir;
mod localssa;
mod op_traits;
pub use ir::*;

View file

@ -1,95 +0,0 @@
//! Local-to-SSA conversion.
use crate::{
dataflow::{AnalysisFunction, AnalysisValue, ForwardDataflow, Lattice},
FunctionBody, Operand,
};
use crate::{BlockId, InstId, ValueId};
use wasmparser::Operator;
// We do a really simple thing for now:
// - Compute "is-there-more-than-one-reaching-definition" property for
// every var at each use site. We do this by tracking a lattice: no
// defs known (top), exactly one def known, multiple defs known
// (bottom).
// - For every var for which there is more than one
// reaching-definition at any use, insert a blockparam on every
// block where the var is live-in.
// - Annotate SSA values as belonging to a disjoint-set (only one live
// at a time) to assist lowering back into Wasm (so they can share a
// local).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum ReachingDefsLattice {
Unknown,
OneDef(ValueId),
ManyDefs,
}
impl std::default::Default for ReachingDefsLattice {
fn default() -> Self {
Self::Unknown
}
}
impl Lattice for ReachingDefsLattice {
fn top() -> Self {
Self::Unknown
}
fn bottom() -> Self {
Self::ManyDefs
}
fn meet(a: &Self, b: &Self) -> Self {
match (a, b) {
(a, Self::Unknown) => *a,
(Self::Unknown, b) => *b,
(Self::OneDef(a), Self::OneDef(b)) if a == b => Self::OneDef(*a),
(Self::OneDef(_), Self::OneDef(_)) => Self::ManyDefs,
(Self::ManyDefs, _) | (_, Self::ManyDefs) => Self::ManyDefs,
}
}
}
struct LocalReachingDefsAnalysis;
impl AnalysisFunction for LocalReachingDefsAnalysis {
type K = u32; // Local index
type L = ReachingDefsLattice;
fn instruction(
&self,
input: &mut AnalysisValue<Self::K, Self::L>,
func: &FunctionBody,
block: BlockId,
inst: InstId,
) -> bool {
let inst = &func.blocks[block].insts[inst];
match &inst.operator {
&Operator::LocalSet { local_index } | &Operator::LocalTee { local_index } => {
if let Operand::Value(value) = inst.inputs[0] {
let value = ReachingDefsLattice::OneDef(value);
let old = input.values.insert(local_index, value);
Some(value) != old
} else {
false
}
}
_ => false,
}
}
}
pub struct LocalSSATransform {
local_analysis: ForwardDataflow<LocalReachingDefsAnalysis>,
}
impl LocalSSATransform {
pub fn new(func: &FunctionBody) -> Self {
LocalSSATransform {
local_analysis: ForwardDataflow::new(&LocalReachingDefsAnalysis, func),
}
}
}