Add maximal-SSA mode.

This commit is contained in:
Chris Fallin 2022-12-02 11:58:04 -08:00
parent 5bdb4a1737
commit b6ce3abc1d
7 changed files with 212 additions and 13 deletions

View file

@ -134,7 +134,8 @@ fuzz_target!(|module: wasm_smith::ConfiguredModule<Config>| {
} }
}; };
let parsed_module = Module::from_wasm_bytes(&orig_bytes[..]).unwrap(); let mut parsed_module = Module::from_wasm_bytes(&orig_bytes[..]).unwrap();
parsed_module.optimize();
let roundtrip_bytes = parsed_module.to_wasm_bytes().unwrap(); let roundtrip_bytes = parsed_module.to_wasm_bytes().unwrap();
if let Ok(filename) = std::env::var("FUZZ_DUMP_WASM") { if let Ok(filename) = std::env::var("FUZZ_DUMP_WASM") {

View file

@ -7,7 +7,7 @@ fuzz_target!(|module: wasm_smith::Module| {
let _ = env_logger::try_init(); let _ = env_logger::try_init();
log::debug!("original module: {:?}", module); log::debug!("original module: {:?}", module);
let orig_bytes = module.to_bytes(); let orig_bytes = module.to_bytes();
let parsed_module = match Module::from_wasm_bytes(&orig_bytes[..]) { let mut parsed_module = match Module::from_wasm_bytes(&orig_bytes[..]) {
Ok(m) => m, Ok(m) => m,
Err(e) => { Err(e) => {
match e.downcast::<FrontendError>() { match e.downcast::<FrontendError>() {
@ -24,5 +24,6 @@ fuzz_target!(|module: wasm_smith::Module| {
} }
} }
}; };
parsed_module.optimize();
let _ = parsed_module.to_wasm_bytes(); let _ = parsed_module.to_wasm_bytes();
}); });

View file

@ -12,6 +12,15 @@ struct Options {
#[structopt(short, long)] #[structopt(short, long)]
debug: bool, debug: bool,
#[structopt(
help = "Do basic optimizations: GVN and const-prop",
long = "basic-opts"
)]
basic_opts: bool,
#[structopt(help = "Transform to maximal SSA", long = "max-ssa")]
max_ssa: bool,
#[structopt(subcommand)] #[structopt(subcommand)]
command: Command, command: Command,
} }
@ -45,13 +54,25 @@ fn main() -> Result<()> {
Command::PrintIR { wasm } => { Command::PrintIR { wasm } => {
let bytes = std::fs::read(wasm)?; let bytes = std::fs::read(wasm)?;
debug!("Loaded {} bytes of Wasm data", bytes.len()); debug!("Loaded {} bytes of Wasm data", bytes.len());
let module = Module::from_wasm_bytes(&bytes[..])?; let mut module = Module::from_wasm_bytes(&bytes[..])?;
if opts.basic_opts {
module.optimize();
}
if opts.max_ssa {
module.convert_to_max_ssa();
}
println!("{}", module.display()); println!("{}", module.display());
} }
Command::RoundTrip { input, output } => { Command::RoundTrip { input, output } => {
let bytes = std::fs::read(input)?; let bytes = std::fs::read(input)?;
debug!("Loaded {} bytes of Wasm data", bytes.len()); debug!("Loaded {} bytes of Wasm data", bytes.len());
let module = Module::from_wasm_bytes(&bytes[..])?; let mut module = Module::from_wasm_bytes(&bytes[..])?;
if opts.basic_opts {
module.optimize();
}
if opts.max_ssa {
module.convert_to_max_ssa();
}
let produced = module.to_wasm_bytes()?; let produced = module.to_wasm_bytes()?;
std::fs::write(output, &produced[..])?; std::fs::write(output, &produced[..])?;
} }

View file

@ -80,6 +80,9 @@ impl CFGInfo {
let mut def_block: PerEntity<Value, Block> = PerEntity::default(); let mut def_block: PerEntity<Value, Block> = PerEntity::default();
for (block, block_def) in f.blocks.entries() { for (block, block_def) in f.blocks.entries() {
for &(_, param) in &block_def.params {
def_block[param] = block;
}
for &value in &block_def.insts { for &value in &block_def.insts {
def_block[value] = block; def_block[value] = block;
} }

View file

@ -221,21 +221,36 @@ impl<'a> Module<'a> {
} }
pub fn from_wasm_bytes(bytes: &'a [u8]) -> Result<Self> { pub fn from_wasm_bytes(bytes: &'a [u8]) -> Result<Self> {
let mut module = frontend::wasm_to_ir(bytes)?; frontend::wasm_to_ir(bytes)
for func_decl in module.funcs.values_mut() {
if let Some(body) = func_decl.body_mut() {
let cfg = crate::cfg::CFGInfo::new(body);
crate::passes::basic_opt::gvn(body, &cfg);
crate::passes::resolve_aliases::run(body);
}
}
Ok(module)
} }
pub fn to_wasm_bytes(&self) -> Result<Vec<u8>> { pub fn to_wasm_bytes(&self) -> Result<Vec<u8>> {
backend::compile(self) backend::compile(self)
} }
pub fn per_func_body<F: Fn(&mut FunctionBody)>(&mut self, f: F) {
for func_decl in self.funcs.values_mut() {
if let Some(body) = func_decl.body_mut() {
f(body);
}
}
}
pub fn optimize(&mut self) {
self.per_func_body(|body| {
let cfg = crate::cfg::CFGInfo::new(body);
crate::passes::basic_opt::gvn(body, &cfg);
crate::passes::resolve_aliases::run(body);
});
}
pub fn convert_to_max_ssa(&mut self) {
self.per_func_body(|body| {
let cfg = crate::cfg::CFGInfo::new(body);
crate::passes::maxssa::run(body, &cfg);
});
}
pub fn display<'b>(&'b self) -> ModuleDisplay<'b> pub fn display<'b>(&'b self) -> ModuleDisplay<'b>
where where
'b: 'a, 'b: 'a,

View file

@ -3,3 +3,4 @@
pub mod basic_opt; pub mod basic_opt;
pub mod dom_pass; pub mod dom_pass;
pub mod resolve_aliases; pub mod resolve_aliases;
pub mod maxssa;

157
src/passes/maxssa.rs Normal file
View file

@ -0,0 +1,157 @@
//! Conversion pass that creates "maximal SSA": only local uses (no
//! uses of defs in other blocks), with all values explicitly passed
//! through blockparams. This makes some other transforms easier
//! because it removes the need to worry about adding blockparams when
//! mutating the CFG (all possible blockparams are already there!).
use crate::cfg::CFGInfo;
use crate::entity::PerEntity;
use crate::ir::{Block, FunctionBody, Value, ValueDef};
use std::collections::{BTreeSet, HashMap};
pub fn run(body: &mut FunctionBody, cfg: &CFGInfo) {
MaxSSAPass::new().run(body, cfg);
}
struct MaxSSAPass {
/// Additional block args that must be passed to each block, in
/// order. Value numbers are *original* values.
new_args: PerEntity<Block, Vec<Value>>,
/// For each block, a value map: from original value to local copy
/// of value.
value_map: HashMap<(Block, Value), Value>,
}
impl MaxSSAPass {
fn new() -> Self {
Self {
new_args: PerEntity::default(),
value_map: HashMap::new(),
}
}
fn run(mut self, body: &mut FunctionBody, cfg: &CFGInfo) {
for block in body.blocks.iter() {
self.visit(body, cfg, block);
}
self.update(body);
}
fn visit(&mut self, body: &mut FunctionBody, cfg: &CFGInfo, block: Block) {
// For each use in the block, process the use. Collect all
// uses first to deduplicate and allow more efficient
// processing (and to appease the borrow checker).
let mut uses = BTreeSet::default();
for &inst in &body.blocks[block].insts {
match &body.values[inst] {
&ValueDef::Operator(_, ref args, _) => {
for &arg in args {
let arg = body.resolve_alias(arg);
uses.insert(arg);
}
}
&ValueDef::PickOutput(value, ..) => {
let value = body.resolve_alias(value);
uses.insert(value);
}
_ => {}
}
}
body.blocks[block].terminator.visit_uses(|u| {
let u = body.resolve_alias(u);
uses.insert(u);
});
for u in uses {
self.visit_use(body, cfg, block, u);
}
}
fn visit_use(&mut self, body: &mut FunctionBody, cfg: &CFGInfo, block: Block, value: Value) {
if self.value_map.contains_key(&(block, value)) {
return;
}
if cfg.def_block[value] == block {
return;
}
self.new_args[block].push(value);
// Create a blockparam.
let ty = body.values[value].ty().unwrap();
let blockparam = body.add_blockparam(block, ty);
self.value_map.insert((block, value), blockparam);
// Recursively visit preds and use the value there, to ensure
// they have the value available as well.
for i in 0..body.blocks[block].preds.len() {
// Don't borrow for whole loop while iterating (`body` is
// taken as mut by recursion, but we don't add preds).
let pred = body.blocks[block].preds[i];
self.visit_use(body, cfg, pred, value);
}
}
fn update_preds(&mut self, body: &mut FunctionBody, block: Block) {
for i in 0..body.blocks[block].preds.len() {
let pred = body.blocks[block].preds[i];
let pred_succ_idx = body.blocks[block].pos_in_pred_succ[i];
body.blocks[pred]
.terminator
.update_target(pred_succ_idx, |target| {
for &new_arg in &self.new_args[block] {
let actual_value = self
.value_map
.get(&(pred, new_arg))
.copied()
.unwrap_or(new_arg);
target.args.push(actual_value);
}
});
}
}
fn update_uses(&mut self, body: &mut FunctionBody, block: Block) {
let resolve = |body: &FunctionBody, value: Value| {
let value = body.resolve_alias(value);
self.value_map
.get(&(block, value))
.copied()
.unwrap_or(value)
};
for i in 0..body.blocks[block].insts.len() {
let inst = body.blocks[block].insts[i];
let mut def = std::mem::take(&mut body.values[inst]);
match &mut def {
ValueDef::Operator(_, args, _) => {
for arg in args {
*arg = resolve(body, *arg);
}
}
ValueDef::PickOutput(value, ..) => {
*value = resolve(body, *value);
}
ValueDef::Alias(_) => {
// Nullify the alias: should no longer be needed.
def = ValueDef::None;
}
_ => {}
}
body.values[inst] = def;
}
let mut term = std::mem::take(&mut body.blocks[block].terminator);
term.update_uses(|u| {
*u = resolve(body, *u);
});
body.blocks[block].terminator = term;
}
fn update(&mut self, body: &mut FunctionBody) {
for block in body.blocks.iter() {
if self.new_args[block].len() > 0 {
self.update_preds(body, block);
}
self.update_uses(body, block);
}
}
}