Implement br_table support.

This commit is contained in:
Chris Fallin 2021-11-13 14:13:31 -08:00
parent 7f1652fb2e
commit e6024c4ffd

View file

@ -102,12 +102,12 @@ struct FunctionBodyBuilder<'a, 'b> {
module: &'b Module<'a>,
my_sig: SignatureId,
body: &'b mut FunctionBody<'a>,
cur_block: BlockId,
cur_block: Option<BlockId>,
ctrl_stack: Vec<Frame>,
op_stack: Vec<ValueId>,
}
#[derive(Debug)]
#[derive(Clone, Debug)]
enum Frame {
Block {
start_depth: usize,
@ -175,7 +175,7 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> {
body,
ctrl_stack: vec![],
op_stack: vec![],
cur_block: 0,
cur_block: Some(0),
}
}
@ -212,7 +212,7 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> {
let result_values = self.op_stack.split_off(results.len());
self.emit_branch(out, &result_values[..]);
assert_eq!(self.op_stack.len(), start_depth);
self.cur_block = out;
self.cur_block = Some(out);
self.push_block_params(&results[..]);
}
Some(Frame::If {
@ -234,7 +234,7 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> {
assert_eq!(else_result_values.len(), results.len());
self.emit_branch(el, &else_result_values[..]);
assert_eq!(self.op_stack.len(), start_depth);
self.cur_block = out;
self.cur_block = Some(out);
self.push_block_params(&results[..]);
}
Some(Frame::Else {
@ -248,7 +248,7 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> {
let result_values = self.op_stack.split_off(results.len());
assert_eq!(self.op_stack.len(), start_depth);
self.emit_branch(out, &result_values[..]);
self.cur_block = out;
self.cur_block = Some(out);
self.push_block_params(&results[..]);
}
},
@ -273,7 +273,7 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> {
let initial_args = self.op_stack.split_off(params.len());
let start_depth = self.op_stack.len();
self.emit_branch(header, &initial_args[..]);
self.cur_block = header;
self.cur_block = Some(header);
self.push_block_params(&params[..]);
let out = self.create_block();
self.ctrl_stack.push(Frame::Loop {
@ -302,7 +302,7 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> {
params,
results,
});
self.cur_block = if_true;
self.cur_block = Some(if_true);
self.emit_cond_branch(cond, if_true, &[], if_false, &[]);
}
@ -325,7 +325,7 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> {
params,
results,
});
self.cur_block = el;
self.cur_block = Some(el);
} else {
bail!("Else without If on top of frame stack");
}
@ -337,28 +337,46 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> {
Operator::BrIf { .. } => Some(self.op_stack.pop().unwrap()),
_ => unreachable!(),
};
// Pop skipped-over frames.
let _ = self.ctrl_stack.split_off(*relative_depth as usize);
// Get the frame we're branching to.
let frame = self.ctrl_stack.pop().unwrap();
let frame = self.relative_frame(*relative_depth).clone();
// Get the args off the stack.
let args = self.op_stack.split_off(frame.br_args().len());
// Truncate the result stack down to the expected height.
self.op_stack.truncate(frame.start_depth());
// Finally, generate the branch itself.
match cond {
None => {
self.emit_branch(frame.br_target(), &args[..]);
self.cur_block = None;
}
Some(cond) => {
let cont = self.create_block();
self.emit_cond_branch(cond, frame.br_target(), &args[..], cont, &[]);
self.cur_block = cont;
self.cur_block = Some(cont);
}
}
}
Operator::BrTable { .. } => {}
Operator::BrTable { table } => {
// Get the selector index.
let index = self.op_stack.pop().unwrap();
// Get the signature of the default frame; this tells
// us the signature of all frames (since wasmparser
// validates the input for us). Pop that many args.
let default_frame = self.relative_frame(table.default());
let default_term_target = default_frame.br_target();
let arg_len = default_frame.br_args().len();
let args = self.op_stack.split_off(arg_len);
// Generate a branch terminator with the same args for
// every branch target.
let mut term_targets = vec![];
for target in table.targets() {
let target = target?;
let frame = self.relative_frame(target);
assert_eq!(frame.br_args().len(), args.len());
let block = frame.br_target();
term_targets.push(block);
}
self.emit_br_table(index, default_term_target, &term_targets[..], &args[..]);
}
_ => bail!("Unsupported operator: {:?}", op),
}
@ -385,14 +403,19 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> {
}
}
fn relative_frame(&self, relative_depth: u32) -> &Frame {
&self.ctrl_stack[self.ctrl_stack.len() - 1 - relative_depth as usize]
}
fn emit_branch(&mut self, target: BlockId, args: &[ValueId]) {
let block = self.cur_block;
let args = args.iter().map(|&val| Operand::Value(val)).collect();
let target = BlockTarget {
block: target,
args,
};
self.body.blocks[block].terminator = Terminator::Br { target };
if let Some(block) = self.cur_block {
let args = args.iter().map(|&val| Operand::Value(val)).collect();
let target = BlockTarget {
block: target,
args,
};
self.body.blocks[block].terminator = Terminator::Br { target };
}
}
fn emit_cond_branch(
@ -403,34 +426,63 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> {
if_false: BlockId,
if_false_args: &[ValueId],
) {
let block = self.cur_block;
let if_true_args = if_true_args
.iter()
.map(|&val| Operand::Value(val))
.collect();
let if_false_args = if_false_args
.iter()
.map(|&val| Operand::Value(val))
.collect();
self.body.blocks[block].terminator = Terminator::CondBr {
cond: Operand::Value(cond),
if_true: BlockTarget {
block: if_true,
args: if_true_args,
},
if_false: BlockTarget {
block: if_false,
args: if_false_args,
},
};
if let Some(block) = self.cur_block {
let if_true_args = if_true_args
.iter()
.map(|&val| Operand::Value(val))
.collect();
let if_false_args = if_false_args
.iter()
.map(|&val| Operand::Value(val))
.collect();
self.body.blocks[block].terminator = Terminator::CondBr {
cond: Operand::Value(cond),
if_true: BlockTarget {
block: if_true,
args: if_true_args,
},
if_false: BlockTarget {
block: if_false,
args: if_false_args,
},
};
}
}
fn emit_br_table(
&mut self,
index: ValueId,
default_target: BlockId,
indexed_targets: &[BlockId],
args: &[ValueId],
) {
if let Some(block) = self.cur_block {
let args: Vec<Operand<'a>> = args.iter().map(|&arg| Operand::Value(arg)).collect();
let targets = indexed_targets
.iter()
.map(|&block| BlockTarget {
block,
args: args.clone(),
})
.collect();
let default = BlockTarget {
block: default_target,
args: args.clone(),
};
self.body.blocks[block].terminator = Terminator::Select {
value: Operand::Value(index),
targets,
default,
};
}
}
fn push_block_params(&mut self, tys: &[Type]) {
assert_eq!(tys, self.body.blocks[self.cur_block].params);
assert_eq!(tys, self.body.blocks[self.cur_block.unwrap()].params);
for (i, &ty) in tys.iter().enumerate() {
let value_id = self.body.values.len() as ValueId;
self.body.values.push(ValueDef {
kind: ValueKind::BlockParam(self.cur_block, i),
kind: ValueKind::BlockParam(self.cur_block.unwrap(), i),
ty,
});
self.op_stack.push(value_id);
@ -438,37 +490,38 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> {
}
fn emit(&mut self, op: Operator<'a>) -> Result<()> {
let block = self.cur_block;
let inst = self.body.blocks[block].insts.len() as InstId;
if let Some(block) = self.cur_block {
let inst = self.body.blocks[block].insts.len() as InstId;
let mut inputs = vec![];
for input in op_inputs(self.module, self.my_sig, &self.body.locals[..], &op)?
.into_iter()
.rev()
{
let stack_top = self.op_stack.pop().unwrap();
assert_eq!(self.body.values[stack_top].ty, input);
inputs.push(Operand::Value(stack_top));
}
inputs.reverse();
let mut inputs = vec![];
for input in op_inputs(self.module, self.my_sig, &self.body.locals[..], &op)?
.into_iter()
.rev()
{
let stack_top = self.op_stack.pop().unwrap();
assert_eq!(self.body.values[stack_top].ty, input);
inputs.push(Operand::Value(stack_top));
}
inputs.reverse();
let mut outputs = vec![];
for output in op_outputs(self.module, &self.body.locals[..], &op)?.into_iter() {
let val = self.body.values.len() as ValueId;
outputs.push(val);
self.body.values.push(ValueDef {
kind: ValueKind::Inst(block, inst),
ty: output,
let mut outputs = vec![];
for output in op_outputs(self.module, &self.body.locals[..], &op)?.into_iter() {
let val = self.body.values.len() as ValueId;
outputs.push(val);
self.body.values.push(ValueDef {
kind: ValueKind::Inst(block, inst),
ty: output,
});
self.op_stack.push(val);
}
self.body.blocks[block].insts.push(Inst {
operator: op,
outputs,
inputs,
});
self.op_stack.push(val);
}
self.body.blocks[self.cur_block].insts.push(Inst {
operator: op,
outputs,
inputs,
});
Ok(())
}
}