diff --git a/src/backend/mod.rs b/src/backend/mod.rs index 5704836..6af0930 100644 --- a/src/backend/mod.rs +++ b/src/backend/mod.rs @@ -238,7 +238,6 @@ impl<'a> WasmFuncBackend<'a> { ty: sig_index.index() as u32, table: table_index.index() as u32, }), - Operator::Return => Some(wasm_encoder::Instruction::Return), Operator::Select => Some(wasm_encoder::Instruction::Select), Operator::TypedSelect { ty } => Some(wasm_encoder::Instruction::TypedSelect( wasm_encoder::ValType::from(*ty), diff --git a/src/frontend.rs b/src/frontend.rs index d7d4221..023ea81 100644 --- a/src/frontend.rs +++ b/src/frontend.rs @@ -301,7 +301,7 @@ fn parse_body<'a>( let entry = Block::new(0); builder.body.entry = entry; builder.locals.seal_block_preds(entry, &mut builder.body); - builder.locals.start_block(entry); + builder.locals.start_block(entry, false); for (arg_idx, &arg_ty) in module.signature(my_sig).params.iter().enumerate() { let local_idx = Local::new(arg_idx); @@ -321,10 +321,14 @@ fn parse_body<'a>( let ops = body.get_operators_reader()?; for op in ops.into_iter() { let op = op?; - builder.handle_op(op)?; + if builder.reachable { + builder.handle_op(op)?; + } else { + builder.handle_op_unreachable(op)?; + } } - if builder.cur_block.is_some() { + if builder.reachable { builder.handle_op(wasmparser::Operator::Return)?; } @@ -346,7 +350,7 @@ struct LocalTracker { /// Types of locals, as declared. types: FxHashMap, /// The current block. - cur_block: Option, + cur_block: Block, /// Is the given block sealed? block_sealed: FxHashSet, /// The local-to-value mapping at the start of a block. @@ -363,19 +367,18 @@ impl LocalTracker { assert!(!was_present); } - pub fn start_block(&mut self, block: Block) { - self.finish_block(); + pub fn start_block(&mut self, block: Block, was_reachable: bool) { + self.finish_block(was_reachable); log::trace!("start_block: block {}", block); - self.cur_block = Some(block); + self.cur_block = block; } - pub fn finish_block(&mut self) { - log::trace!("finish_block: block {:?}", self.cur_block); - if let Some(block) = self.cur_block { + pub fn finish_block(&mut self, reachable: bool) { + log::trace!("finish_block: block {}", self.cur_block); + if reachable { let mapping = std::mem::take(&mut self.in_cur_block); - self.block_end.insert(block, mapping); + self.block_end.insert(self.cur_block, mapping); } - self.cur_block = None; } pub fn seal_block_preds(&mut self, block: Block, body: &mut FunctionBody) { @@ -404,7 +407,7 @@ impl LocalTracker { log::trace!("get_in_block: at_block {} local {}", at_block, local); let ty = body.locals[local]; - if self.cur_block == Some(at_block) { + if self.cur_block == at_block { if let Some(&value) = self.in_cur_block.get(&local) { log::trace!(" -> {:?}", value); return value; @@ -462,11 +465,7 @@ impl LocalTracker { } pub fn get(&mut self, body: &mut FunctionBody, local: Local) -> Value { - if let Some(block) = self.cur_block { - self.get_in_block(body, block, local) - } else { - Value::invalid() - } + self.get_in_block(body, self.cur_block, local) } fn create_default_value( @@ -592,7 +591,8 @@ struct FunctionBodyBuilder<'a, 'b> { my_sig: Signature, body: &'b mut FunctionBody, locals: LocalTracker, - cur_block: Option, + cur_block: Block, + reachable: bool, ctrl_stack: Vec, op_stack: Vec<(Type, Value)>, } @@ -604,6 +604,7 @@ enum Frame { out: Block, params: Vec, results: Vec, + out_reachable: bool, }, Loop { start_depth: usize, @@ -619,12 +620,14 @@ enum Frame { param_values: Vec<(Type, Value)>, params: Vec, results: Vec, + head_reachable: bool, }, Else { start_depth: usize, out: Block, params: Vec, results: Vec, + merge_reachable: bool, }, } @@ -681,6 +684,13 @@ impl Frame { | Frame::Else { results, .. } => &results[..], } } + + fn set_reachable(&mut self) { + match self { + Frame::Block { out_reachable, .. } => *out_reachable = true, + _ => {} + } + } } impl<'a, 'b> FunctionBodyBuilder<'a, 'b> { @@ -692,7 +702,8 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> { body, ctrl_stack: vec![], op_stack: vec![], - cur_block: Some(Block::new(0)), + cur_block: Block::new(0), + reachable: true, locals: LocalTracker::default(), }; @@ -705,6 +716,7 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> { out, params: vec![], results, + out_reachable: false, }); ret } @@ -741,13 +753,16 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> { trace!("op_stack = {:?}", self.op_stack); trace!("ctrl_stack = {:?}", self.ctrl_stack); trace!("locals = {:?}", self.locals); + + debug_assert!(self.reachable); + + if self.handle_ctrl_op(op.clone())? { + return Ok(()); + } + match &op { wasmparser::Operator::Unreachable => { - if let Some(block) = self.cur_block { - self.body.end_block(block, Terminator::Unreachable); - self.locals.finish_block(); - } - self.cur_block = None; + self.emit_unreachable(); } wasmparser::Operator::LocalGet { local_index } => { @@ -760,17 +775,13 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> { wasmparser::Operator::LocalSet { local_index } => { let local_index = Local::from(*local_index); let (_, value) = self.op_stack.pop().unwrap(); - if self.cur_block.is_some() { - self.locals.set(local_index, value); - } + self.locals.set(local_index, value); } wasmparser::Operator::LocalTee { local_index } => { let local_index = Local::from(*local_index); let (_ty, value) = *self.op_stack.last().unwrap(); - if self.cur_block.is_some() { - self.locals.set(local_index, value); - } + self.locals.set(local_index, value); } wasmparser::Operator::Call { .. } @@ -957,11 +968,104 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> { let _ = self.pop_1(); } + wasmparser::Operator::Br { relative_depth } + | wasmparser::Operator::BrIf { relative_depth } => { + let cond = match &op { + wasmparser::Operator::Br { .. } => None, + wasmparser::Operator::BrIf { .. } => Some(self.pop_1()), + _ => unreachable!(), + }; + // Get the frame we're branching to. + let frame = self.relative_frame(*relative_depth); + frame.set_reachable(); + let frame = frame.clone(); + log::trace!("Br/BrIf: dest frame {:?}", frame); + // Finally, generate the branch itself. + match cond { + None => { + // Get the args off the stack unconditionally. + let args = self.pop_n(frame.br_args().len()); + self.emit_branch(frame.br_target(), &args[..]); + self.reachable = false; + } + Some(cond) => { + let cont = self.body.add_block(); + // Get the args off the stack but leave for the fallthrough. + let args = self.op_stack[self.op_stack.len() - frame.br_args().len()..] + .iter() + .map(|(_ty, value)| *value) + .collect::>(); + self.emit_cond_branch(cond, frame.br_target(), &args[..], cont, &[]); + self.locals.seal_block_preds(cont, &mut self.body); + self.cur_block = cont; + self.locals.start_block(cont, self.reachable); + } + } + } + + wasmparser::Operator::BrTable { targets } => { + // Get the selector index. + let index = self.pop_1(); + // Get the signature of the default frame; this tells + // us the signature of all frames (since wasmparser + // validates the input for us). Pop that many args. + let default_frame = self.relative_frame(targets.default()); + let default_term_target = default_frame.br_target(); + let arg_len = default_frame.br_args().len(); + let args = self.pop_n(arg_len); + // Generate a branch terminator with the same args for + // every branch target. + let mut term_targets = vec![]; + for target in targets.targets() { + let target = target?; + let frame = self.relative_frame(target); + frame.set_reachable(); + assert_eq!(frame.br_args().len(), args.len()); + let block = frame.br_target(); + term_targets.push(block); + } + self.emit_br_table(index, default_term_target, &term_targets[..], &args[..]); + } + + wasmparser::Operator::Return => { + let retvals = self.pop_n(self.module.signature(self.my_sig).returns.len()); + self.emit_ret(&retvals[..]); + } + + _ => bail!(FrontendError::UnsupportedFeature(format!( + "Unsupported operator: {:?}", + op + ))), + } + + Ok(()) + } + + fn handle_op_unreachable(&mut self, op: wasmparser::Operator<'a>) -> Result<()> { + trace!("handle_op_unreachable: {:?}", op); + trace!("op_stack = {:?}", self.op_stack); + trace!("ctrl_stack = {:?}", self.ctrl_stack); + + debug_assert!(!self.reachable); + + self.handle_ctrl_op(op)?; + + Ok(()) + } + + fn handle_ctrl_op(&mut self, op: wasmparser::Operator<'a>) -> Result { + match &op { wasmparser::Operator::End => { let frame = self.ctrl_stack.pop(); match &frame { None => { - self.emit(Operator::Return)?; + if self.reachable { + let retvals = + self.pop_n(self.module.signature(self.my_sig).returns.len()); + self.emit_ret(&retvals[..]); + } else { + self.emit_unreachable(); + } } Some(Frame::Block { start_depth, @@ -977,9 +1081,9 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> { }) => { // Generate a branch to the out-block with // blockparams for the results. - if let Some(cur_block) = self.cur_block { + if self.reachable { let result_values = - self.block_results(&results[..], *start_depth, cur_block); + self.block_results(&results[..], *start_depth, self.cur_block); self.emit_branch(*out, &result_values[..]); } self.op_stack.truncate(*start_depth); @@ -991,8 +1095,14 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> { if let Some(Frame::Loop { header, .. }) = &frame { self.locals.seal_block_preds(*header, &mut self.body); } - self.cur_block = Some(*out); - self.locals.start_block(*out); + // Set `cur_block` only if currently set (otherwise, unreachable!) + self.cur_block = *out; + self.locals.start_block(*out, self.reachable); + self.reachable = self.reachable + || match &frame { + Some(Frame::Block { out_reachable, .. }) => *out_reachable, + _ => false, + }; self.push_block_params(results.len()); } Some(Frame::If { @@ -1001,51 +1111,59 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> { el, ref param_values, ref results, + head_reachable, .. }) => { // Generate a branch to the out-block with // blockparams for the results. - if let Some(cur_block) = self.cur_block { + if self.reachable { let result_values = - self.block_results(&results[..], *start_depth, cur_block); + self.block_results(&results[..], *start_depth, self.cur_block); self.emit_branch(*out, &result_values[..]); } self.op_stack.truncate(*start_depth); - // No `else`, so we need to generate a trivial - // branch in the else-block. If the if-block-type - // has results, they must be exactly the params. - let else_result_values = param_values; - assert_eq!(else_result_values.len(), results.len()); - let else_result_values = else_result_values - .iter() - .map(|(_ty, value)| *value) - .collect::>(); - self.locals.start_block(*el); - self.cur_block = Some(*el); - self.emit_branch(*out, &else_result_values[..]); - assert_eq!(self.op_stack.len(), *start_depth); - self.cur_block = Some(*out); + if *head_reachable { + // No `else`, so we need to generate a trivial + // branch in the else-block. If the if-block-type + // has results, they must be exactly the params. + let else_result_values = param_values; + assert_eq!(else_result_values.len(), results.len()); + let else_result_values = else_result_values + .iter() + .map(|(_ty, value)| *value) + .collect::>(); + self.locals.start_block(*el, self.reachable); + self.cur_block = *el; + self.emit_branch(*out, &else_result_values[..]); + assert_eq!(self.op_stack.len(), *start_depth); + } + self.cur_block = *out; + let was_reachable = self.reachable; + self.reachable = *head_reachable || self.reachable; self.locals.seal_block_preds(*out, &mut self.body); - self.locals.start_block(*out); + self.locals.start_block(*out, was_reachable); self.push_block_params(results.len()); } Some(Frame::Else { out, ref results, start_depth, + merge_reachable, .. }) => { // Generate a branch to the out-block with // blockparams for the results. - if let Some(cur_block) = self.cur_block { + if self.reachable { let result_values = - self.block_results(&results[..], *start_depth, cur_block); + self.block_results(&results[..], *start_depth, self.cur_block); self.emit_branch(*out, &result_values[..]); } self.op_stack.truncate(*start_depth); - self.cur_block = Some(*out); + self.cur_block = *out; + let was_reachable = self.reachable; + self.reachable = *merge_reachable || self.reachable; self.locals.seal_block_preds(*out, &mut self.body); - self.locals.start_block(*out); + self.locals.start_block(*out, was_reachable); self.push_block_params(results.len()); } } @@ -1061,6 +1179,7 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> { out, params, results, + out_reachable: false, }); } @@ -1071,8 +1190,8 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> { let initial_args = self.pop_n(params.len()); let start_depth = self.op_stack.len(); self.emit_branch(header, &initial_args[..]); - self.cur_block = Some(header); - self.locals.start_block(header); + self.cur_block = header; + self.locals.start_block(header, self.reachable); self.push_block_params(params.len()); let out = self.body.add_block(); self.add_block_params(out, &results[..]); @@ -1101,12 +1220,13 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> { param_values, params, results, + head_reachable: self.reachable, }); self.emit_cond_branch(cond, if_true, &[], if_false, &[]); self.locals.seal_block_preds(if_true, &mut self.body); self.locals.seal_block_preds(if_false, &mut self.body); - self.cur_block = Some(if_true); - self.locals.start_block(if_true); + self.cur_block = if_true; + self.locals.start_block(if_true, self.reachable); } wasmparser::Operator::Else => { @@ -1117,10 +1237,12 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> { param_values, params, results, + head_reachable, } = self.ctrl_stack.pop().unwrap() { - if let Some(cur_block) = self.cur_block { - let if_results = self.block_results(&results[..], start_depth, cur_block); + if self.reachable { + let if_results = + self.block_results(&results[..], start_depth, self.cur_block); self.emit_branch(out, &if_results[..]); } self.op_stack.truncate(start_depth); @@ -1130,9 +1252,11 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> { out, params, results, + merge_reachable: head_reachable, }); - self.cur_block = Some(el); - self.locals.start_block(el); + self.cur_block = el; + self.locals.start_block(el, self.reachable); + self.reachable = head_reachable; } else { bail!(FrontendError::Internal(format!( "Else without If on top of frame stack" @@ -1140,73 +1264,10 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> { } } - wasmparser::Operator::Br { relative_depth } - | wasmparser::Operator::BrIf { relative_depth } => { - let cond = match &op { - wasmparser::Operator::Br { .. } => None, - wasmparser::Operator::BrIf { .. } => Some(self.pop_1()), - _ => unreachable!(), - }; - // Get the frame we're branching to. - let frame = self.relative_frame(*relative_depth).clone(); - log::trace!("Br/BrIf: dest frame {:?}", frame); - // Finally, generate the branch itself. - match cond { - None => { - // Get the args off the stack unconditionally. - let args = self.pop_n(frame.br_args().len()); - self.emit_branch(frame.br_target(), &args[..]); - } - Some(cond) => { - let cont = self.body.add_block(); - // Get the args off the stack but leave for the fallthrough. - let args = self.op_stack[self.op_stack.len() - frame.br_args().len()..] - .iter() - .map(|(_ty, value)| *value) - .collect::>(); - self.emit_cond_branch(cond, frame.br_target(), &args[..], cont, &[]); - self.locals.seal_block_preds(cont, &mut self.body); - self.cur_block = Some(cont); - self.locals.start_block(cont); - } - } - } - - wasmparser::Operator::BrTable { targets } => { - // Get the selector index. - let index = self.pop_1(); - // Get the signature of the default frame; this tells - // us the signature of all frames (since wasmparser - // validates the input for us). Pop that many args. - let default_frame = self.relative_frame(targets.default()); - let default_term_target = default_frame.br_target(); - let arg_len = default_frame.br_args().len(); - let args = self.pop_n(arg_len); - // Generate a branch terminator with the same args for - // every branch target. - let mut term_targets = vec![]; - for target in targets.targets() { - let target = target?; - let frame = self.relative_frame(target); - assert_eq!(frame.br_args().len(), args.len()); - let block = frame.br_target(); - term_targets.push(block); - } - self.emit_br_table(index, default_term_target, &term_targets[..], &args[..]); - } - - wasmparser::Operator::Return => { - let retvals = self.pop_n(self.module.signature(self.my_sig).returns.len()); - self.emit_ret(&retvals[..]); - } - - _ => bail!(FrontendError::UnsupportedFeature(format!( - "Unsupported operator: {:?}", - op - ))), + _ => return Ok(false), } - Ok(()) + Ok(true) } fn add_block_params(&mut self, block: Block, tys: &[Type]) { @@ -1230,8 +1291,9 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> { } } - fn relative_frame(&self, relative_depth: u32) -> &Frame { - &self.ctrl_stack[self.ctrl_stack.len() - 1 - relative_depth as usize] + fn relative_frame(&mut self, relative_depth: u32) -> &mut Frame { + let index = self.ctrl_stack.len() - 1 - relative_depth as usize; + &mut self.ctrl_stack[index] } fn emit_branch(&mut self, target: Block, args: &[Value]) { @@ -1241,15 +1303,16 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> { target, args ); - if let Some(block) = self.cur_block { + if self.reachable { let args = args.to_vec(); let target = BlockTarget { block: target, args, }; - self.body.end_block(block, Terminator::Br { target }); - self.cur_block = None; - self.locals.finish_block(); + self.body + .end_block(self.cur_block, Terminator::Br { target }); + self.reachable = false; + self.locals.finish_block(true); } } @@ -1269,11 +1332,11 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> { if_false, if_false_args ); - if let Some(block) = self.cur_block { + if self.reachable { let if_true_args = if_true_args.to_vec(); let if_false_args = if_false_args.to_vec(); self.body.end_block( - block, + self.cur_block, Terminator::CondBr { cond, if_true: BlockTarget { @@ -1286,8 +1349,8 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> { }, }, ); - self.cur_block = None; - self.locals.finish_block(); + self.reachable = false; + self.locals.finish_block(true); } } @@ -1306,7 +1369,7 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> { indexed_targets, args, ); - if let Some(block) = self.cur_block { + if self.reachable { let args = args.to_vec(); let targets = indexed_targets .iter() @@ -1323,24 +1386,33 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> { }; self.body.end_block( - block, + self.cur_block, Terminator::Select { value: index, targets, default, }, ); - self.cur_block = None; - self.locals.finish_block(); + self.locals.finish_block(true); + self.reachable = false; } } fn emit_ret(&mut self, values: &[Value]) { - if let Some(block) = self.cur_block { + if self.reachable { let values = values.to_vec(); - self.body.end_block(block, Terminator::Return { values }); - self.cur_block = None; - self.locals.finish_block(); + self.body + .end_block(self.cur_block, Terminator::Return { values }); + self.reachable = false; + self.locals.finish_block(true); + } + } + + fn emit_unreachable(&mut self) { + if self.reachable { + self.body.end_block(self.cur_block, Terminator::Unreachable); + self.reachable = false; + self.locals.finish_block(true); } } @@ -1350,16 +1422,15 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> { self.cur_block, num_params ); - let block = self.cur_block.unwrap(); for i in 0..num_params { - let (ty, value) = self.body.blocks[block].params[i]; + let (ty, value) = self.body.blocks[self.cur_block].params[i]; log::trace!(" -> push {:?} ty {:?}", value, ty); self.op_stack.push((ty, value)); } } fn emit(&mut self, op: Operator) -> Result<()> { - let inputs = op_inputs(self.module, self.my_sig, &self.op_stack[..], &op)?; + let inputs = op_inputs(self.module, &self.op_stack[..], &op)?; let outputs = op_outputs(self.module, &self.op_stack[..], &op)?; log::trace!( @@ -1386,8 +1457,8 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> { .add_value(ValueDef::Operator(op, input_operands, outputs.to_vec())); log::trace!(" -> value: {:?}", value); - if let Some(block) = self.cur_block { - self.body.append_to_block(block, value); + if self.reachable { + self.body.append_to_block(self.cur_block, value); } if n_outputs == 1 { @@ -1398,8 +1469,8 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> { let pick = self .body .add_value(ValueDef::PickOutput(value, i, output_ty)); - if let Some(block) = self.cur_block { - self.body.append_to_block(block, pick); + if self.reachable { + self.body.append_to_block(self.cur_block, pick); } self.op_stack.push((output_ty, pick)); log::trace!(" -> pick {}: {:?} ty {:?}", i, pick, output_ty); diff --git a/src/op_traits.rs b/src/op_traits.rs index 7330644..12ba0ae 100644 --- a/src/op_traits.rs +++ b/src/op_traits.rs @@ -1,13 +1,12 @@ //! Metadata on operators. -use crate::ir::{Module, Signature, Type, Value}; +use crate::ir::{Module, Type, Value}; use crate::Operator; use anyhow::Result; use std::borrow::Cow; pub fn op_inputs( module: &Module, - my_sig: Signature, op_stack: &[(Type, Value)], op: &Operator, ) -> Result> { @@ -23,7 +22,6 @@ pub fn op_inputs( params.push(Type::I32); Ok(params.into()) } - &Operator::Return => Ok(Vec::from(module.signature(my_sig).returns.clone()).into()), &Operator::Select => { let val_ty = op_stack[op_stack.len() - 2].0; @@ -241,7 +239,6 @@ pub fn op_outputs( &Operator::CallIndirect { sig_index, .. } => { Ok(Vec::from(module.signature(sig_index).returns.clone()).into()) } - &Operator::Return => Ok(Cow::Borrowed(&[])), &Operator::Select => { let val_ty = op_stack[op_stack.len() - 2].0; @@ -442,7 +439,6 @@ pub enum SideEffect { WriteTable, ReadLocal, WriteLocal, - Return, All, } @@ -456,7 +452,6 @@ impl Operator { &Operator::Call { .. } => &[All], &Operator::CallIndirect { .. } => &[All], - &Operator::Return => &[Return], &Operator::Select => &[], &Operator::TypedSelect { .. } => &[], @@ -659,7 +654,6 @@ impl std::fmt::Display for Operator { sig_index, table_index, } => write!(f, "call_indirect<{}, {}>", sig_index, table_index)?, - &Operator::Return => write!(f, "return")?, &Operator::Select => write!(f, "select")?, &Operator::TypedSelect { ty } => write!(f, "typed_select<{}>", ty)?, diff --git a/src/ops.rs b/src/ops.rs index c63984c..295380d 100644 --- a/src/ops.rs +++ b/src/ops.rs @@ -33,7 +33,6 @@ pub enum Operator { sig_index: Signature, table_index: Table, }, - Return, Select, TypedSelect { ty: Type, @@ -316,7 +315,6 @@ impl<'a, 'b> std::convert::TryFrom<&'b wasmparser::Operator<'a>> for Operator { sig_index: Signature::from(type_index), table_index: Table::from(table_index), }), - &wasmparser::Operator::Return => Ok(Operator::Return), &wasmparser::Operator::LocalSet { .. } => Err(()), &wasmparser::Operator::LocalTee { .. } => Err(()), &wasmparser::Operator::LocalGet { .. } => Err(()),