diff --git a/hblang/src/codegen.rs b/hblang/src/codegen.rs index ab7fc44..435f880 100644 --- a/hblang/src/codegen.rs +++ b/hblang/src/codegen.rs @@ -234,9 +234,12 @@ impl RegAlloc { } } +#[derive(Clone)] struct FnLabel { offset: u32, name: Ident, + args: Rc<[Type]>, + ret: Type, } struct Variable { @@ -350,17 +353,94 @@ impl<'a> Codegen<'a> { self.path = path; self.input = input; + for expr in exprs { + match expr { + E::BinOp { + left: E::Ident { id, .. }, + op: T::Decl, + right: E::Closure { args, ret, .. }, + } => { + let args = args.iter().map(|arg| self.ty(&arg.ty)).collect::>(); + let ret = self.ty(ret); + self.declare_fn_label(*id, args, ret); + } + E::BinOp { + left: E::Ident { id, name, .. }, + op: T::Decl, + right: E::Struct { .. }, + } => { + self.records.push(Struct { + id: *id, + name: (*name).into(), + fields: Rc::from([]), + }); + } + _ => self.report(expr.pos(), "expected declaration"), + } + } + for expr in exprs { let E::BinOp { - left: E::Ident { .. }, + left: E::Ident { id, name, .. }, op: T::Decl, - .. + right, } = expr else { self.report(expr.pos(), format_args!("expected declaration")); }; - self.expr(expr, None); + match right { + E::Struct { fields, .. } => { + let fields = fields + .iter() + .map(|&(name, ty)| (name.into(), self.ty(&ty))) + .collect(); + self.records + .iter_mut() + .find(|r| r.id == *id) + .unwrap() + .fields = fields; + } + E::Closure { body, args, .. } => { + log::dbg!("fn: {}", name); + let frame = self.define_fn_label(*id); + if *name == "main" { + self.main = Some(frame.label); + } + + let fn_label = self.labels[frame.label as usize].clone(); + + log::dbg!("fn-args"); + let mut parama = 2..12; + for (arg, &ty) in args.iter().zip(fn_label.args.iter()) { + let loc = self.load_arg(ty, &mut parama); + self.vars.push(Variable { + id: arg.id, + value: Value { ty, loc }, + }); + } + + self.gpa.init_callee(); + self.ret = fn_label.ret; + + log::dbg!("fn-body"); + if self.expr(body, None).is_some() { + self.report(body.pos(), "expected all paths in the fucntion to return"); + } + self.vars.clear(); + + log::dbg!("fn-prelude, stack: {:x}", self.stack_size); + + log::dbg!("fn-relocs"); + self.write_fn_prelude(frame); + + log::dbg!("fn-ret"); + self.reloc_rets(); + self.ret(); + self.stack_size = 0; + } + _ => unreachable!(), + } } Ok(()) } @@ -538,16 +618,7 @@ impl<'a> Codegen<'a> { .position(|(n, _)| *n == name.as_ref()) .unwrap(); let (_, value) = field_values.remove(index); - if value.ty != ty { - self.report( - expr.pos(), - format_args!( - "expected {}, got {}", - self.display_ty(ty), - self.display_ty(value.ty) - ), - ); - } + self.assert_ty(expr.pos(), ty, value.ty); log::dbg!("ctor: {} {} {:?}", stack, offset, value.loc); self.assign( Value { @@ -588,22 +659,6 @@ impl<'a> Codegen<'a> { }; Some(Value { ty, loc }) } - E::BinOp { - left: E::Ident { id, name, .. }, - op: T::Decl, - right: E::Struct { fields, .. }, - } => { - let fields = fields - .iter() - .map(|&(name, ty)| (name.into(), self.ty(&ty))) - .collect(); - self.records.push(Struct { - id: *id, - name: (*name).into(), - fields, - }); - Some(Value::VOID) - } E::UnOp { op: T::Amp, val, @@ -643,50 +698,6 @@ impl<'a> Codegen<'a> { ), } } - E::BinOp { - left: E::Ident { name, id, .. }, - op: T::Decl, - right: E::Closure { - ret, body, args, .. - }, - } => { - log::dbg!("fn: {}", name); - let frame = self.add_label(*id); - if *name == "main" { - self.main = Some(frame.label); - } - - log::dbg!("fn-args"); - let mut parama = 2..12; - for arg in args.iter() { - let ty = self.ty(&arg.ty); - let loc = self.load_arg(ty, &mut parama); - self.vars.push(Variable { - id: arg.id, - value: Value { ty, loc }, - }); - } - - self.gpa.init_callee(); - self.ret = self.ty(ret); - - log::dbg!("fn-body"); - if self.expr(body, None).is_some() { - self.report(body.pos(), "expected all paths in the fucntion to return"); - } - self.vars.clear(); - - log::dbg!("fn-prelude, stack: {:x}", self.stack_size); - - log::dbg!("fn-relocs"); - self.write_fn_prelude(frame); - - log::dbg!("fn-ret"); - self.reloc_rets(); - self.ret(); - self.stack_size = 0; - Some(Value::VOID) - } E::BinOp { left: E::Ident { id, .. }, op: T::Decl, @@ -705,12 +716,14 @@ impl<'a> Codegen<'a> { func: E::Ident { id, .. }, args, } => { + let func = self.get_label(*id); + let fn_label = self.labels[func as usize].clone(); let mut parama = 2..12; - for arg in args.iter() { - let arg = self.expr(arg, None)?; + for (earg, &ty) in args.iter().zip(fn_label.args.iter()) { + let arg = self.expr(earg, Some(ty))?; + self.assert_ty(earg.pos(), ty, arg.ty); self.pass_arg(arg, &mut parama); } - let func = self.get_or_reserve_label(*id); self.code.call(func); let reg = self.gpa.allocate(); self.code.encode(instrs::cp(reg.0, 1)); @@ -731,16 +744,7 @@ impl<'a> Codegen<'a> { E::Return { val, pos } => { if let Some(val) = val { let val = self.expr(val, Some(self.ret))?; - if val.ty != self.ret { - self.report( - pos, - format_args!( - "expected {}, got {}", - self.display_ty(self.ret), - self.display_ty(val.ty) - ), - ); - } + self.assert_ty(pos, self.ret, val.ty); self.assign( Value { ty: self.ret, @@ -969,28 +973,20 @@ impl<'a> Codegen<'a> { } } - fn get_or_reserve_label(&mut self, name: Ident) -> LabelId { - if let Some(label) = self.labels.iter().position(|l| l.name == name) { - label as u32 - } else { - self.labels.push(FnLabel { offset: 0, name }); - self.labels.len() as u32 - 1 - } + fn declare_fn_label(&mut self, name: Ident, args: Rc<[Type]>, ret: Type) -> LabelId { + self.labels.push(FnLabel { + offset: 0, + name, + args, + ret, + }); + self.labels.len() as u32 - 1 } - fn add_label(&mut self, name: Ident) -> Frame { + fn define_fn_label(&mut self, name: Ident) -> Frame { let offset = self.code.code.len() as u32; - let label = if let Some(label) = self.labels.iter().position(|l| l.name == name) { - self.labels[label].offset = offset; - label as u32 - } else { - self.labels.push(FnLabel { - offset, - name: name.into(), - }); - self.labels.len() as u32 - 1 - }; - + let label = self.get_label(name); + self.labels[label as usize].offset = offset; Frame { label, prev_relocs: self.code.relocs.len(), @@ -1162,6 +1158,14 @@ impl<'a> Codegen<'a> { } } + fn assert_ty(&self, pos: parser::Pos, ty: Type, expected: Type) { + if ty != expected { + let ty = self.display_ty(ty); + let expected = self.display_ty(expected); + self.report(pos, format_args!("expected {ty}, got {expected}")); + } + } + fn report(&self, pos: parser::Pos, msg: impl std::fmt::Display) -> ! { let (line, col) = lexer::line_col(self.input, pos); println!("{}:{}:{}: {}", self.path.display(), line, col, msg);