From e6024c4ffd107644c52ea4a58fa9fa83f633a440 Mon Sep 17 00:00:00 2001 From: Chris Fallin Date: Sat, 13 Nov 2021 14:13:31 -0800 Subject: [PATCH] Implement br_table support. --- src/frontend.rs | 195 ++++++++++++++++++++++++++++++------------------ 1 file changed, 124 insertions(+), 71 deletions(-) diff --git a/src/frontend.rs b/src/frontend.rs index 94442e8..a2f1e43 100644 --- a/src/frontend.rs +++ b/src/frontend.rs @@ -102,12 +102,12 @@ struct FunctionBodyBuilder<'a, 'b> { module: &'b Module<'a>, my_sig: SignatureId, body: &'b mut FunctionBody<'a>, - cur_block: BlockId, + cur_block: Option, ctrl_stack: Vec, op_stack: Vec, } -#[derive(Debug)] +#[derive(Clone, Debug)] enum Frame { Block { start_depth: usize, @@ -175,7 +175,7 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> { body, ctrl_stack: vec![], op_stack: vec![], - cur_block: 0, + cur_block: Some(0), } } @@ -212,7 +212,7 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> { let result_values = self.op_stack.split_off(results.len()); self.emit_branch(out, &result_values[..]); assert_eq!(self.op_stack.len(), start_depth); - self.cur_block = out; + self.cur_block = Some(out); self.push_block_params(&results[..]); } Some(Frame::If { @@ -234,7 +234,7 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> { assert_eq!(else_result_values.len(), results.len()); self.emit_branch(el, &else_result_values[..]); assert_eq!(self.op_stack.len(), start_depth); - self.cur_block = out; + self.cur_block = Some(out); self.push_block_params(&results[..]); } Some(Frame::Else { @@ -248,7 +248,7 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> { let result_values = self.op_stack.split_off(results.len()); assert_eq!(self.op_stack.len(), start_depth); self.emit_branch(out, &result_values[..]); - self.cur_block = out; + self.cur_block = Some(out); self.push_block_params(&results[..]); } }, @@ -273,7 +273,7 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> { let initial_args = self.op_stack.split_off(params.len()); let start_depth = self.op_stack.len(); self.emit_branch(header, &initial_args[..]); - self.cur_block = header; + self.cur_block = Some(header); self.push_block_params(¶ms[..]); let out = self.create_block(); self.ctrl_stack.push(Frame::Loop { @@ -302,7 +302,7 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> { params, results, }); - self.cur_block = if_true; + self.cur_block = Some(if_true); self.emit_cond_branch(cond, if_true, &[], if_false, &[]); } @@ -325,7 +325,7 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> { params, results, }); - self.cur_block = el; + self.cur_block = Some(el); } else { bail!("Else without If on top of frame stack"); } @@ -337,28 +337,46 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> { Operator::BrIf { .. } => Some(self.op_stack.pop().unwrap()), _ => unreachable!(), }; - // Pop skipped-over frames. - let _ = self.ctrl_stack.split_off(*relative_depth as usize); // Get the frame we're branching to. - let frame = self.ctrl_stack.pop().unwrap(); + let frame = self.relative_frame(*relative_depth).clone(); // Get the args off the stack. let args = self.op_stack.split_off(frame.br_args().len()); - // Truncate the result stack down to the expected height. - self.op_stack.truncate(frame.start_depth()); // Finally, generate the branch itself. match cond { None => { self.emit_branch(frame.br_target(), &args[..]); + self.cur_block = None; } Some(cond) => { let cont = self.create_block(); self.emit_cond_branch(cond, frame.br_target(), &args[..], cont, &[]); - self.cur_block = cont; + self.cur_block = Some(cont); } } } - Operator::BrTable { .. } => {} + Operator::BrTable { table } => { + // Get the selector index. + let index = self.op_stack.pop().unwrap(); + // Get the signature of the default frame; this tells + // us the signature of all frames (since wasmparser + // validates the input for us). Pop that many args. + let default_frame = self.relative_frame(table.default()); + let default_term_target = default_frame.br_target(); + let arg_len = default_frame.br_args().len(); + let args = self.op_stack.split_off(arg_len); + // Generate a branch terminator with the same args for + // every branch target. + let mut term_targets = vec![]; + for target in table.targets() { + let target = target?; + let frame = self.relative_frame(target); + assert_eq!(frame.br_args().len(), args.len()); + let block = frame.br_target(); + term_targets.push(block); + } + self.emit_br_table(index, default_term_target, &term_targets[..], &args[..]); + } _ => bail!("Unsupported operator: {:?}", op), } @@ -385,14 +403,19 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> { } } + fn relative_frame(&self, relative_depth: u32) -> &Frame { + &self.ctrl_stack[self.ctrl_stack.len() - 1 - relative_depth as usize] + } + fn emit_branch(&mut self, target: BlockId, args: &[ValueId]) { - let block = self.cur_block; - let args = args.iter().map(|&val| Operand::Value(val)).collect(); - let target = BlockTarget { - block: target, - args, - }; - self.body.blocks[block].terminator = Terminator::Br { target }; + if let Some(block) = self.cur_block { + let args = args.iter().map(|&val| Operand::Value(val)).collect(); + let target = BlockTarget { + block: target, + args, + }; + self.body.blocks[block].terminator = Terminator::Br { target }; + } } fn emit_cond_branch( @@ -403,34 +426,63 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> { if_false: BlockId, if_false_args: &[ValueId], ) { - let block = self.cur_block; - let if_true_args = if_true_args - .iter() - .map(|&val| Operand::Value(val)) - .collect(); - let if_false_args = if_false_args - .iter() - .map(|&val| Operand::Value(val)) - .collect(); - self.body.blocks[block].terminator = Terminator::CondBr { - cond: Operand::Value(cond), - if_true: BlockTarget { - block: if_true, - args: if_true_args, - }, - if_false: BlockTarget { - block: if_false, - args: if_false_args, - }, - }; + if let Some(block) = self.cur_block { + let if_true_args = if_true_args + .iter() + .map(|&val| Operand::Value(val)) + .collect(); + let if_false_args = if_false_args + .iter() + .map(|&val| Operand::Value(val)) + .collect(); + self.body.blocks[block].terminator = Terminator::CondBr { + cond: Operand::Value(cond), + if_true: BlockTarget { + block: if_true, + args: if_true_args, + }, + if_false: BlockTarget { + block: if_false, + args: if_false_args, + }, + }; + } + } + + fn emit_br_table( + &mut self, + index: ValueId, + default_target: BlockId, + indexed_targets: &[BlockId], + args: &[ValueId], + ) { + if let Some(block) = self.cur_block { + let args: Vec> = args.iter().map(|&arg| Operand::Value(arg)).collect(); + let targets = indexed_targets + .iter() + .map(|&block| BlockTarget { + block, + args: args.clone(), + }) + .collect(); + let default = BlockTarget { + block: default_target, + args: args.clone(), + }; + self.body.blocks[block].terminator = Terminator::Select { + value: Operand::Value(index), + targets, + default, + }; + } } fn push_block_params(&mut self, tys: &[Type]) { - assert_eq!(tys, self.body.blocks[self.cur_block].params); + assert_eq!(tys, self.body.blocks[self.cur_block.unwrap()].params); for (i, &ty) in tys.iter().enumerate() { let value_id = self.body.values.len() as ValueId; self.body.values.push(ValueDef { - kind: ValueKind::BlockParam(self.cur_block, i), + kind: ValueKind::BlockParam(self.cur_block.unwrap(), i), ty, }); self.op_stack.push(value_id); @@ -438,37 +490,38 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> { } fn emit(&mut self, op: Operator<'a>) -> Result<()> { - let block = self.cur_block; - let inst = self.body.blocks[block].insts.len() as InstId; + if let Some(block) = self.cur_block { + let inst = self.body.blocks[block].insts.len() as InstId; - let mut inputs = vec![]; - for input in op_inputs(self.module, self.my_sig, &self.body.locals[..], &op)? - .into_iter() - .rev() - { - let stack_top = self.op_stack.pop().unwrap(); - assert_eq!(self.body.values[stack_top].ty, input); - inputs.push(Operand::Value(stack_top)); - } - inputs.reverse(); + let mut inputs = vec![]; + for input in op_inputs(self.module, self.my_sig, &self.body.locals[..], &op)? + .into_iter() + .rev() + { + let stack_top = self.op_stack.pop().unwrap(); + assert_eq!(self.body.values[stack_top].ty, input); + inputs.push(Operand::Value(stack_top)); + } + inputs.reverse(); - let mut outputs = vec![]; - for output in op_outputs(self.module, &self.body.locals[..], &op)?.into_iter() { - let val = self.body.values.len() as ValueId; - outputs.push(val); - self.body.values.push(ValueDef { - kind: ValueKind::Inst(block, inst), - ty: output, + let mut outputs = vec![]; + for output in op_outputs(self.module, &self.body.locals[..], &op)?.into_iter() { + let val = self.body.values.len() as ValueId; + outputs.push(val); + self.body.values.push(ValueDef { + kind: ValueKind::Inst(block, inst), + ty: output, + }); + self.op_stack.push(val); + } + + self.body.blocks[block].insts.push(Inst { + operator: op, + outputs, + inputs, }); - self.op_stack.push(val); } - self.body.blocks[self.cur_block].insts.push(Inst { - operator: op, - outputs, - inputs, - }); - Ok(()) } }