Implement br_table support.

This commit is contained in:
Chris Fallin 2021-11-13 14:13:31 -08:00
parent 7f1652fb2e
commit e6024c4ffd

View file

@ -102,12 +102,12 @@ struct FunctionBodyBuilder<'a, 'b> {
module: &'b Module<'a>, module: &'b Module<'a>,
my_sig: SignatureId, my_sig: SignatureId,
body: &'b mut FunctionBody<'a>, body: &'b mut FunctionBody<'a>,
cur_block: BlockId, cur_block: Option<BlockId>,
ctrl_stack: Vec<Frame>, ctrl_stack: Vec<Frame>,
op_stack: Vec<ValueId>, op_stack: Vec<ValueId>,
} }
#[derive(Debug)] #[derive(Clone, Debug)]
enum Frame { enum Frame {
Block { Block {
start_depth: usize, start_depth: usize,
@ -175,7 +175,7 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> {
body, body,
ctrl_stack: vec![], ctrl_stack: vec![],
op_stack: vec![], op_stack: vec![],
cur_block: 0, cur_block: Some(0),
} }
} }
@ -212,7 +212,7 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> {
let result_values = self.op_stack.split_off(results.len()); let result_values = self.op_stack.split_off(results.len());
self.emit_branch(out, &result_values[..]); self.emit_branch(out, &result_values[..]);
assert_eq!(self.op_stack.len(), start_depth); assert_eq!(self.op_stack.len(), start_depth);
self.cur_block = out; self.cur_block = Some(out);
self.push_block_params(&results[..]); self.push_block_params(&results[..]);
} }
Some(Frame::If { Some(Frame::If {
@ -234,7 +234,7 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> {
assert_eq!(else_result_values.len(), results.len()); assert_eq!(else_result_values.len(), results.len());
self.emit_branch(el, &else_result_values[..]); self.emit_branch(el, &else_result_values[..]);
assert_eq!(self.op_stack.len(), start_depth); assert_eq!(self.op_stack.len(), start_depth);
self.cur_block = out; self.cur_block = Some(out);
self.push_block_params(&results[..]); self.push_block_params(&results[..]);
} }
Some(Frame::Else { Some(Frame::Else {
@ -248,7 +248,7 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> {
let result_values = self.op_stack.split_off(results.len()); let result_values = self.op_stack.split_off(results.len());
assert_eq!(self.op_stack.len(), start_depth); assert_eq!(self.op_stack.len(), start_depth);
self.emit_branch(out, &result_values[..]); self.emit_branch(out, &result_values[..]);
self.cur_block = out; self.cur_block = Some(out);
self.push_block_params(&results[..]); self.push_block_params(&results[..]);
} }
}, },
@ -273,7 +273,7 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> {
let initial_args = self.op_stack.split_off(params.len()); let initial_args = self.op_stack.split_off(params.len());
let start_depth = self.op_stack.len(); let start_depth = self.op_stack.len();
self.emit_branch(header, &initial_args[..]); self.emit_branch(header, &initial_args[..]);
self.cur_block = header; self.cur_block = Some(header);
self.push_block_params(&params[..]); self.push_block_params(&params[..]);
let out = self.create_block(); let out = self.create_block();
self.ctrl_stack.push(Frame::Loop { self.ctrl_stack.push(Frame::Loop {
@ -302,7 +302,7 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> {
params, params,
results, results,
}); });
self.cur_block = if_true; self.cur_block = Some(if_true);
self.emit_cond_branch(cond, if_true, &[], if_false, &[]); self.emit_cond_branch(cond, if_true, &[], if_false, &[]);
} }
@ -325,7 +325,7 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> {
params, params,
results, results,
}); });
self.cur_block = el; self.cur_block = Some(el);
} else { } else {
bail!("Else without If on top of frame stack"); bail!("Else without If on top of frame stack");
} }
@ -337,28 +337,46 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> {
Operator::BrIf { .. } => Some(self.op_stack.pop().unwrap()), Operator::BrIf { .. } => Some(self.op_stack.pop().unwrap()),
_ => unreachable!(), _ => unreachable!(),
}; };
// Pop skipped-over frames.
let _ = self.ctrl_stack.split_off(*relative_depth as usize);
// Get the frame we're branching to. // Get the frame we're branching to.
let frame = self.ctrl_stack.pop().unwrap(); let frame = self.relative_frame(*relative_depth).clone();
// Get the args off the stack. // Get the args off the stack.
let args = self.op_stack.split_off(frame.br_args().len()); let args = self.op_stack.split_off(frame.br_args().len());
// Truncate the result stack down to the expected height.
self.op_stack.truncate(frame.start_depth());
// Finally, generate the branch itself. // Finally, generate the branch itself.
match cond { match cond {
None => { None => {
self.emit_branch(frame.br_target(), &args[..]); self.emit_branch(frame.br_target(), &args[..]);
self.cur_block = None;
} }
Some(cond) => { Some(cond) => {
let cont = self.create_block(); let cont = self.create_block();
self.emit_cond_branch(cond, frame.br_target(), &args[..], cont, &[]); self.emit_cond_branch(cond, frame.br_target(), &args[..], cont, &[]);
self.cur_block = cont; self.cur_block = Some(cont);
} }
} }
} }
Operator::BrTable { .. } => {} Operator::BrTable { table } => {
// Get the selector index.
let index = self.op_stack.pop().unwrap();
// Get the signature of the default frame; this tells
// us the signature of all frames (since wasmparser
// validates the input for us). Pop that many args.
let default_frame = self.relative_frame(table.default());
let default_term_target = default_frame.br_target();
let arg_len = default_frame.br_args().len();
let args = self.op_stack.split_off(arg_len);
// Generate a branch terminator with the same args for
// every branch target.
let mut term_targets = vec![];
for target in table.targets() {
let target = target?;
let frame = self.relative_frame(target);
assert_eq!(frame.br_args().len(), args.len());
let block = frame.br_target();
term_targets.push(block);
}
self.emit_br_table(index, default_term_target, &term_targets[..], &args[..]);
}
_ => bail!("Unsupported operator: {:?}", op), _ => bail!("Unsupported operator: {:?}", op),
} }
@ -385,8 +403,12 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> {
} }
} }
fn relative_frame(&self, relative_depth: u32) -> &Frame {
&self.ctrl_stack[self.ctrl_stack.len() - 1 - relative_depth as usize]
}
fn emit_branch(&mut self, target: BlockId, args: &[ValueId]) { fn emit_branch(&mut self, target: BlockId, args: &[ValueId]) {
let block = self.cur_block; if let Some(block) = self.cur_block {
let args = args.iter().map(|&val| Operand::Value(val)).collect(); let args = args.iter().map(|&val| Operand::Value(val)).collect();
let target = BlockTarget { let target = BlockTarget {
block: target, block: target,
@ -394,6 +416,7 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> {
}; };
self.body.blocks[block].terminator = Terminator::Br { target }; self.body.blocks[block].terminator = Terminator::Br { target };
} }
}
fn emit_cond_branch( fn emit_cond_branch(
&mut self, &mut self,
@ -403,7 +426,7 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> {
if_false: BlockId, if_false: BlockId,
if_false_args: &[ValueId], if_false_args: &[ValueId],
) { ) {
let block = self.cur_block; if let Some(block) = self.cur_block {
let if_true_args = if_true_args let if_true_args = if_true_args
.iter() .iter()
.map(|&val| Operand::Value(val)) .map(|&val| Operand::Value(val))
@ -424,13 +447,42 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> {
}, },
}; };
} }
}
fn emit_br_table(
&mut self,
index: ValueId,
default_target: BlockId,
indexed_targets: &[BlockId],
args: &[ValueId],
) {
if let Some(block) = self.cur_block {
let args: Vec<Operand<'a>> = args.iter().map(|&arg| Operand::Value(arg)).collect();
let targets = indexed_targets
.iter()
.map(|&block| BlockTarget {
block,
args: args.clone(),
})
.collect();
let default = BlockTarget {
block: default_target,
args: args.clone(),
};
self.body.blocks[block].terminator = Terminator::Select {
value: Operand::Value(index),
targets,
default,
};
}
}
fn push_block_params(&mut self, tys: &[Type]) { fn push_block_params(&mut self, tys: &[Type]) {
assert_eq!(tys, self.body.blocks[self.cur_block].params); assert_eq!(tys, self.body.blocks[self.cur_block.unwrap()].params);
for (i, &ty) in tys.iter().enumerate() { for (i, &ty) in tys.iter().enumerate() {
let value_id = self.body.values.len() as ValueId; let value_id = self.body.values.len() as ValueId;
self.body.values.push(ValueDef { self.body.values.push(ValueDef {
kind: ValueKind::BlockParam(self.cur_block, i), kind: ValueKind::BlockParam(self.cur_block.unwrap(), i),
ty, ty,
}); });
self.op_stack.push(value_id); self.op_stack.push(value_id);
@ -438,7 +490,7 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> {
} }
fn emit(&mut self, op: Operator<'a>) -> Result<()> { fn emit(&mut self, op: Operator<'a>) -> Result<()> {
let block = self.cur_block; if let Some(block) = self.cur_block {
let inst = self.body.blocks[block].insts.len() as InstId; let inst = self.body.blocks[block].insts.len() as InstId;
let mut inputs = vec![]; let mut inputs = vec![];
@ -463,11 +515,12 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> {
self.op_stack.push(val); self.op_stack.push(val);
} }
self.body.blocks[self.cur_block].insts.push(Inst { self.body.blocks[block].insts.push(Inst {
operator: op, operator: op,
outputs, outputs,
inputs, inputs,
}); });
}
Ok(()) Ok(())
} }