From 86be4c06e1eee3a0e212d6683e27a9e48800ec68 Mon Sep 17 00:00:00 2001 From: Chris Fallin Date: Thu, 28 Mar 2024 16:36:07 -0700 Subject: [PATCH] Handle typed funcrefs. --- src/backend/mod.rs | 34 ++++++++++++--- src/frontend.rs | 82 ++++++++++++++++++++++++------------ src/ir.rs | 1 - src/ir/module.rs | 10 ++--- src/op_traits.rs | 22 ++++++++++ src/ops.rs | 14 ++++++ wasm_tests/typed-funcref.wat | 21 +++++++++ 7 files changed, 143 insertions(+), 41 deletions(-) create mode 100644 wasm_tests/typed-funcref.wat diff --git a/src/backend/mod.rs b/src/backend/mod.rs index 8535d76..296046b 100644 --- a/src/backend/mod.rs +++ b/src/backend/mod.rs @@ -909,6 +909,13 @@ impl<'a> WasmFuncBackend<'a> { } Operator::F32x4DemoteF64x2Zero => Some(wasm_encoder::Instruction::F32x4DemoteF64x2Zero), Operator::F64x2PromoteLowF32x4 => Some(wasm_encoder::Instruction::F64x2PromoteLowF32x4), + + Operator::CallRef { sig_index } => { + Some(wasm_encoder::Instruction::CallRef(sig_index.index() as u32)) + } + Operator::RefFunc { func_index } => { + Some(wasm_encoder::Instruction::RefFunc(func_index.index() as u32)) + } }; if let Some(inst) = inst { @@ -955,7 +962,7 @@ pub fn compile(module: &Module<'_>) -> anyhow::Result> { .func_elements .as_ref() .map(|elts| elts.len() as u32) - .unwrap_or(0), + .unwrap_or(table.initial), maximum: table.max, }) } @@ -1081,11 +1088,26 @@ pub fn compile(module: &Module<'_>) -> anyhow::Result> { if let Some(elts) = &table_data.func_elements { for (i, &elt) in elts.iter().enumerate() { if elt.is_valid() { - elem.active( - Some(table.index() as u32), - &wasm_encoder::ConstExpr::i32_const(i as i32), - wasm_encoder::Elements::Functions(&[elt.index() as u32]), - ); + match table_data.ty { + Type::FuncRef => { + elem.active( + Some(table.index() as u32), + &wasm_encoder::ConstExpr::i32_const(i as i32), + wasm_encoder::Elements::Functions(&[elt.index() as u32]), + ); + } + Type::TypedFuncRef(..) => { + elem.active( + Some(table.index() as u32), + &wasm_encoder::ConstExpr::i32_const(i as i32), + wasm_encoder::Elements::Expressions( + table_data.ty.into(), + &[wasm_encoder::ConstExpr::ref_func(elt.index() as u32)], + ), + ); + } + _ => unreachable!(), + } } } } diff --git a/src/frontend.rs b/src/frontend.rs index 5bae767..e28af19 100644 --- a/src/frontend.rs +++ b/src/frontend.rs @@ -133,7 +133,11 @@ fn handle_payload<'a>( ImportKind::Global(global) } TypeRef::Table(ty) => { - let table = module.frontend_add_table(ty.element_type.into(), None); + let table = module.frontend_add_table( + ty.element_type.into(), + ty.initial, + ty.maximum, + ); ImportKind::Table(table) } TypeRef::Memory(mem) => { @@ -174,7 +178,11 @@ fn handle_payload<'a>( Payload::TableSection(reader) => { for table in reader { let table = table?; - module.frontend_add_table(table.ty.element_type.into(), table.ty.maximum); + module.frontend_add_table( + table.ty.element_type.into(), + table.ty.initial, + table.ty.maximum, + ); } } Payload::FunctionSection(reader) => { @@ -318,7 +326,7 @@ fn handle_payload<'a>( } => { let table = Table::from(table_index.unwrap_or(0)); let offset = parse_init_expr(&offset_expr)?.unwrap_or(0) as usize; - match element.items { + let funcs = match element.items { wasmparser::ElementItems::Functions(items) => { let mut funcs = vec![]; for item in items { @@ -326,35 +334,51 @@ fn handle_payload<'a>( let func = Func::from(item); funcs.push(func); } - - let table_items = - module.tables[table].func_elements.as_mut().unwrap(); - let new_size = - offset.checked_add(funcs.len()).ok_or_else(|| { - FrontendError::TooLarge(format!( - "Overflowing element offset + length: {} + {}", - offset, - funcs.len() - )) - })?; - if new_size > table_items.len() { - static MAX_TABLE: usize = 100_000; - if new_size > MAX_TABLE { - bail!(FrontendError::TooLarge(format!( - "Too many table elements: {:?}", - new_size - ))); + funcs + } + wasmparser::ElementItems::Expressions(_, const_exprs) => { + let mut funcs = vec![]; + for const_expr in const_exprs { + let const_expr = const_expr?; + let mut func = None; + for op in const_expr.get_operators_reader() { + let op = op?; + match op { + wasmparser::Operator::End => {} + wasmparser::Operator::RefFunc { function_index } => { + func = Some(Func::from(function_index)); + } + wasmparser::Operator::RefNull { .. } => { + func = Some(Func::invalid()); + } + _ => panic!("Unsupported table-init op: {:?}", op), + } } - table_items.resize(new_size, Func::invalid()); + funcs.push(func.unwrap_or(Func::invalid())); } - table_items[offset..new_size].copy_from_slice(&funcs[..]); + funcs } - wasmparser::ElementItems::Expressions(..) => { - bail!(FrontendError::UnsupportedFeature( - "Expression element items".into() - )) + }; + + let table_items = module.tables[table].func_elements.as_mut().unwrap(); + let new_size = offset.checked_add(funcs.len()).ok_or_else(|| { + FrontendError::TooLarge(format!( + "Overflowing element offset + length: {} + {}", + offset, + funcs.len() + )) + })?; + if new_size > table_items.len() { + static MAX_TABLE: usize = 100_000; + if new_size > MAX_TABLE { + bail!(FrontendError::TooLarge(format!( + "Too many table elements: {:?}", + new_size + ))); } + table_items.resize(new_size, Func::invalid()); } + table_items[offset..new_size].copy_from_slice(&funcs[..]); } } } @@ -1394,7 +1418,9 @@ impl<'a, 'b> FunctionBodyBuilder<'a, 'b> { | wasmparser::Operator::F64x2ConvertLowI32x4S | wasmparser::Operator::F64x2ConvertLowI32x4U | wasmparser::Operator::F32x4DemoteF64x2Zero - | wasmparser::Operator::F64x2PromoteLowF32x4 => { + | wasmparser::Operator::F64x2PromoteLowF32x4 + | wasmparser::Operator::CallRef { .. } + | wasmparser::Operator::RefFunc { .. } => { self.emit(Operator::try_from(&op).unwrap(), loc)? } diff --git a/src/ir.rs b/src/ir.rs index 71b14d6..2a40cd0 100644 --- a/src/ir.rs +++ b/src/ir.rs @@ -26,7 +26,6 @@ impl From for Type { } impl From for Type { fn from(ty: wasmparser::RefType) -> Self { - assert!(ty.is_func_ref(), "only funcrefs are supported right now"); match ty.type_index() { Some(idx) => { let nullable = ty.is_nullable(); diff --git a/src/ir/module.rs b/src/ir/module.rs index 05795db..f4a438e 100644 --- a/src/ir/module.rs +++ b/src/ir/module.rs @@ -43,6 +43,7 @@ pub struct MemorySegment { #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct TableData { pub ty: Type, + pub initial: u32, pub max: Option, pub func_elements: Option>, } @@ -170,15 +171,12 @@ impl<'a> Module<'a> { } impl<'a> Module<'a> { - pub(crate) fn frontend_add_table(&mut self, ty: Type, max: Option) -> Table { - let func_elements = if ty == Type::FuncRef { - Some(vec![]) - } else { - None - }; + pub(crate) fn frontend_add_table(&mut self, ty: Type, initial: u32, max: Option) -> Table { + let func_elements = Some(vec![]); self.tables.push(TableData { ty, func_elements, + initial, max, }) } diff --git a/src/op_traits.rs b/src/op_traits.rs index 36fb535..8d2fadc 100644 --- a/src/op_traits.rs +++ b/src/op_traits.rs @@ -1,5 +1,6 @@ //! Metadata on operators. +use crate::entity::EntityRef; use crate::ir::{Module, Type, Value}; use crate::Operator; use anyhow::Result; @@ -475,6 +476,13 @@ pub fn op_inputs( Operator::F64x2ConvertLowI32x4U => Ok(Cow::Borrowed(&[Type::V128])), Operator::F32x4DemoteF64x2Zero => Ok(Cow::Borrowed(&[Type::V128])), Operator::F64x2PromoteLowF32x4 => Ok(Cow::Borrowed(&[Type::V128])), + + Operator::CallRef { sig_index } => { + let mut params = module.signatures[*sig_index].params.to_vec(); + params.push(Type::TypedFuncRef(true, sig_index.index() as u32)); + Ok(params.into()) + } + Operator::RefFunc { .. } => Ok(Cow::Borrowed(&[])), } } @@ -933,6 +941,14 @@ pub fn op_outputs( Operator::F64x2ConvertLowI32x4U => Ok(Cow::Borrowed(&[Type::V128])), Operator::F32x4DemoteF64x2Zero => Ok(Cow::Borrowed(&[Type::V128])), Operator::F64x2PromoteLowF32x4 => Ok(Cow::Borrowed(&[Type::V128])), + + Operator::CallRef { sig_index } => { + Ok(Vec::from(module.signatures[*sig_index].returns.clone()).into()) + } + Operator::RefFunc { func_index } => { + let ty = module.funcs[*func_index].sig(); + Ok(vec![Type::TypedFuncRef(true, ty.index() as u32)].into()) + } } } @@ -1397,6 +1413,9 @@ impl Operator { Operator::F64x2ConvertLowI32x4U => &[], Operator::F32x4DemoteF64x2Zero => &[], Operator::F64x2PromoteLowF32x4 => &[], + + Operator::CallRef { .. } => &[All], + Operator::RefFunc { .. } => &[], } } @@ -1885,6 +1904,9 @@ impl std::fmt::Display for Operator { Operator::F64x2ConvertLowI32x4U => write!(f, "f64x2convertlowi32x4u")?, Operator::F32x4DemoteF64x2Zero => write!(f, "f32x4demotef64x2zero")?, Operator::F64x2PromoteLowF32x4 => write!(f, "f64x2promotelowf32x4")?, + + Operator::CallRef { sig_index } => write!(f, "call_ref<{}>", sig_index)?, + Operator::RefFunc { func_index } => write!(f, "ref_func<{}>", func_index)?, } Ok(()) diff --git a/src/ops.rs b/src/ops.rs index 3072a65..b5634c1 100644 --- a/src/ops.rs +++ b/src/ops.rs @@ -630,6 +630,13 @@ pub enum Operator { F64x2ConvertLowI32x4U, F32x4DemoteF64x2Zero, F64x2PromoteLowF32x4, + + CallRef { + sig_index: Signature, + }, + RefFunc { + func_index: Func, + }, } #[test] @@ -1259,6 +1266,13 @@ impl<'a, 'b> std::convert::TryFrom<&'b wasmparser::Operator<'a>> for Operator { &wasmparser::Operator::F32x4DemoteF64x2Zero => Ok(Operator::F32x4DemoteF64x2Zero), &wasmparser::Operator::F64x2PromoteLowF32x4 => Ok(Operator::F64x2PromoteLowF32x4), + &wasmparser::Operator::CallRef { type_index } => Ok(Operator::CallRef { + sig_index: Signature::from(type_index), + }), + &wasmparser::Operator::RefFunc { function_index } => Ok(Operator::RefFunc { + func_index: Func::from(function_index), + }), + _ => Err(()), } } diff --git a/wasm_tests/typed-funcref.wat b/wasm_tests/typed-funcref.wat new file mode 100644 index 0000000..a532f58 --- /dev/null +++ b/wasm_tests/typed-funcref.wat @@ -0,0 +1,21 @@ +(module + (type $t (func (param i32 i32) (result i32))) + + (table $tab 10 10 (ref null $t)) + (table $tab2 10 10 (ref null $t)) + + (elem (table $tab2) (i32.const 0) (ref null $t) (ref.func $f)) + + (func $callit (param i32 i32 i32) (result i32) + (call_ref $t (local.get 1) + (local.get 2) + (table.get $tab (local.get 0)))) + + (func $setit (param i32 (ref null $t)) + (table.set $tab (local.get 0) (local.get 1))) + + (func $getf (result (ref null $t)) + (ref.func $f)) + + (func $f (param i32 i32) (result i32) + local.get 0))