diff --git a/src/backend/mod.rs b/src/backend/mod.rs index 77aafcc..8b43c77 100644 --- a/src/backend/mod.rs +++ b/src/backend/mod.rs @@ -2,11 +2,14 @@ use crate::cfg::CFGInfo; use crate::entity::EntityRef; -use crate::ir::{FunctionBody, Value, ValueDef}; +use crate::ir::{ + ExportKind, Func, FuncDecl, FunctionBody, ImportKind, Module, Type, Value, ValueDef, +}; use crate::passes::rpo::RPO; use crate::Operator; use anyhow::Result; use std::borrow::Cow; +use std::collections::HashMap; pub mod stackify; use stackify::{Context as StackifyContext, WasmBlock}; @@ -15,7 +18,7 @@ use treeify::Trees; pub mod localify; use localify::Localifier; -pub struct WasmBackend<'a> { +pub struct WasmFuncBackend<'a> { body: &'a FunctionBody, rpo: RPO, trees: Trees, @@ -29,8 +32,8 @@ macro_rules! op { }; } -impl<'a> WasmBackend<'a> { - pub fn new(body: &'a FunctionBody) -> Result> { +impl<'a> WasmFuncBackend<'a> { + pub fn new(body: &'a FunctionBody) -> Result> { log::debug!("Backend compiling:\n{}\n", body.display_verbose("| ")); let cfg = CFGInfo::new(body); let rpo = RPO::compute(body); @@ -41,7 +44,7 @@ impl<'a> WasmBackend<'a> { log::debug!("Ctrl:\n{:?}\n", ctrl); let locals = Localifier::compute(body, &cfg, &trees); log::debug!("Locals:\n{:?}\n", locals); - Ok(WasmBackend { + Ok(WasmFuncBackend { body, rpo, trees, @@ -55,6 +58,7 @@ impl<'a> WasmBackend<'a> { self.locals .locals .values() + .skip(self.body.blocks[self.body.entry].params.len()) .map(|&ty| (1, wasm_encoder::ValType::from(ty))) .collect::>(), ); @@ -73,6 +77,7 @@ impl<'a> WasmBackend<'a> { } _ => {} } + func.instruction(&wasm_encoder::Instruction::End); log::debug!("Compiled to:\n{:?}\n", func); @@ -503,3 +508,188 @@ impl<'a> WasmBackend<'a> { } } } + +pub fn compile(module: &Module<'_>) -> anyhow::Result> { + let mut into_mod = wasm_encoder::Module::new(); + + let mut types = wasm_encoder::TypeSection::new(); + for (_sig, sig_data) in module.signatures() { + let params = sig_data + .params + .iter() + .map(|&ty| wasm_encoder::ValType::from(ty)); + let returns = sig_data + .returns + .iter() + .map(|&ty| wasm_encoder::ValType::from(ty)); + types.function(params, returns); + } + into_mod.section(&types); + + let mut imports = wasm_encoder::ImportSection::new(); + let import_map = module + .imports() + .filter_map(|import| match &import.kind { + &ImportKind::Func(func) => Some((func, (&import.module[..], &import.name[..]))), + _ => None, + }) + .collect::>(); + let mut num_imports = 0; + for (i, (func, func_decl)) in module.funcs().enumerate() { + match func_decl { + FuncDecl::Import(sig) => { + let (module_name, func_name) = import_map.get(&func).unwrap_or(&("", "")); + imports.import( + module_name, + func_name, + wasm_encoder::EntityType::Function(sig.index() as u32), + ); + } + FuncDecl::Body(..) => { + num_imports = i; + break; + } + } + } + into_mod.section(&imports); + + let mut funcs = wasm_encoder::FunctionSection::new(); + for (func, func_decl) in module.funcs().skip(num_imports) { + match func_decl { + FuncDecl::Import(_) => anyhow::bail!("Import comes after func with body: {}", func), + FuncDecl::Body(sig, _) => { + funcs.function(sig.index() as u32); + } + } + } + into_mod.section(&funcs); + + let mut tables = wasm_encoder::TableSection::new(); + for (_table, table_data) in module.tables() { + tables.table(wasm_encoder::TableType { + element_type: wasm_encoder::ValType::from(table_data.ty), + minimum: table_data + .func_elements + .as_ref() + .map(|elt| elt.len()) + .unwrap_or(0) as u32, + maximum: table_data.max, + }); + } + into_mod.section(&tables); + + let mut memories = wasm_encoder::MemorySection::new(); + for (_mem, mem_data) in module.memories() { + memories.memory(wasm_encoder::MemoryType { + minimum: mem_data.initial_pages as u64, + maximum: mem_data.maximum_pages.map(|val| val as u64), + memory64: false, + shared: false, + }); + } + into_mod.section(&memories); + + let mut globals = wasm_encoder::GlobalSection::new(); + for (_global, global_data) in module.globals() { + globals.global( + wasm_encoder::GlobalType { + val_type: wasm_encoder::ValType::from(global_data.ty), + mutable: global_data.mutable, + }, + &const_init(global_data.ty, global_data.value), + ); + } + into_mod.section(&globals); + + let mut exports = wasm_encoder::ExportSection::new(); + for export in module.exports() { + match &export.kind { + &ExportKind::Table(table) => { + exports.export( + &export.name[..], + wasm_encoder::ExportKind::Table, + table.index() as u32, + ); + } + &ExportKind::Func(func) => { + exports.export( + &export.name[..], + wasm_encoder::ExportKind::Func, + func.index() as u32, + ); + } + &ExportKind::Memory(mem) => { + exports.export( + &export.name[..], + wasm_encoder::ExportKind::Memory, + mem.index() as u32, + ); + } + &ExportKind::Global(global) => { + exports.export( + &export.name[..], + wasm_encoder::ExportKind::Global, + global.index() as u32, + ); + } + } + } + into_mod.section(&exports); + + if let Some(start) = module.start_func { + let start = wasm_encoder::StartSection { + function_index: start.index() as u32, + }; + into_mod.section(&start); + } + + let mut elem = wasm_encoder::ElementSection::new(); + for (table, table_data) in module.tables() { + if let Some(elts) = &table_data.func_elements { + for (i, &elt) in elts.iter().enumerate() { + if elt.is_valid() { + elem.active( + Some(table.index() as u32), + &wasm_encoder::ConstExpr::i32_const(i as i32), + wasm_encoder::ValType::FuncRef, + wasm_encoder::Elements::Functions(&[elt.index() as u32]), + ); + } + } + } + } + into_mod.section(&elem); + + let mut code = wasm_encoder::CodeSection::new(); + for (_func, func_decl) in module.funcs().skip(num_imports) { + let body = func_decl.body().unwrap(); + let body = WasmFuncBackend::new(body)?.compile()?; + code.function(&body); + } + into_mod.section(&code); + + let mut data = wasm_encoder::DataSection::new(); + for (mem, mem_data) in module.memories() { + for segment in &mem_data.segments { + data.active( + mem.index() as u32, + &wasm_encoder::ConstExpr::i32_const(segment.offset as i32), + segment.data.iter().copied(), + ); + } + } + into_mod.section(&data); + + Ok(into_mod.finish()) +} + +fn const_init(ty: Type, value: Option) -> wasm_encoder::ConstExpr { + let bits = value.unwrap_or(0); + match ty { + Type::I32 => wasm_encoder::ConstExpr::i32_const(bits as u32 as i32), + Type::I64 => wasm_encoder::ConstExpr::i64_const(bits as i64), + Type::F32 => wasm_encoder::ConstExpr::f32_const(f32::from_bits(bits as u32)), + Type::F64 => wasm_encoder::ConstExpr::f64_const(f64::from_bits(bits as u64)), + _ => unimplemented!(), + } +} diff --git a/src/frontend.rs b/src/frontend.rs index 3a298a4..6f33159 100644 --- a/src/frontend.rs +++ b/src/frontend.rs @@ -221,6 +221,9 @@ fn handle_payload<'a>( } } Payload::End(_) => {} + Payload::StartSection { func, .. } => { + module.start_func = Some(Func::from(func)); + } payload => { log::warn!("Skipping section: {:?}", payload); } diff --git a/src/ir/display.rs b/src/ir/display.rs index 39f6ff7..c3fa77a 100644 --- a/src/ir/display.rs +++ b/src/ir/display.rs @@ -147,6 +147,9 @@ pub struct ModuleDisplay<'a>(pub(crate) &'a Module<'a>); impl<'a> Display for ModuleDisplay<'a> { fn fmt(&self, f: &mut Formatter) -> FmtResult { writeln!(f, "module {{")?; + if let Some(func) = self.0.start_func { + writeln!(f, " start = {}", func)?; + } let mut sig_strs = HashMap::new(); for (sig, sig_data) in self.0.signatures() { let arg_tys = sig_data diff --git a/src/ir/module.rs b/src/ir/module.rs index 6a0a78f..9251d5c 100644 --- a/src/ir/module.rs +++ b/src/ir/module.rs @@ -14,6 +14,7 @@ pub struct Module<'a> { imports: Vec, exports: Vec, memories: EntityVec, + pub start_func: Option, } #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] @@ -133,6 +134,7 @@ impl<'a> Module<'a> { imports: vec![], exports: vec![], memories: EntityVec::default(), + start_func: None, } } } @@ -226,15 +228,7 @@ impl<'a> Module<'a> { } pub fn to_wasm_bytes(&self) -> Result> { - for (func, func_decl) in self.funcs.entries() { - log::debug!("Compiling: {}", func); - if let Some(body) = func_decl.body() { - let comp = backend::WasmBackend::new(body)?; - let _ = comp.compile()?; - } - } - - Ok(vec![]) + backend::compile(self) } pub fn display<'b>(&'b self) -> ModuleDisplay<'b>