From 9eb4bf27fbbe7e397c485c6b387f3ca83738cd1a Mon Sep 17 00:00:00 2001 From: azur Date: Fri, 14 Apr 2023 13:16:15 +0700 Subject: [PATCH] Fix block type inferring --- bin/src/main.rs | 23 +++++-- syntax/src/expr.rs | 29 ++++++++ typing/src/infer.rs | 156 +++++++++++++++++++++++++------------------- typing/src/typed.rs | 3 +- 4 files changed, 135 insertions(+), 76 deletions(-) diff --git a/bin/src/main.rs b/bin/src/main.rs index 72845e5..d1e89d0 100644 --- a/bin/src/main.rs +++ b/bin/src/main.rs @@ -5,13 +5,22 @@ use typing::infer::infer_exprs; fn main() { let src = " - { - let foo = - let a = true in - let b = false in - a + b; - foo * 2 - } + let r = { + let x = + if 0 == 1 + then { + let x = true; + if x then 1 else 2 + } + else 34 + { + let foo = 30 in + foo + 5 + }; + let y = { 1 } * 2; + if 1 + 1 == 2 + then x + else y + }; ".to_string(); let filename = "?".to_string(); diff --git a/syntax/src/expr.rs b/syntax/src/expr.rs index e60cf3f..34f6a8d 100644 --- a/syntax/src/expr.rs +++ b/syntax/src/expr.rs @@ -89,6 +89,15 @@ pub enum Lit<'src> { #[derive(Clone, Debug)] pub enum UnaryOp { Neg, Not } +impl Display for UnaryOp { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + match self { + UnaryOp::Neg => write!(f, "-"), + UnaryOp::Not => write!(f, "!"), + } + } +} + #[derive(Clone, Debug)] pub enum BinaryOp { Add, Sub, Mul, Div, Rem, @@ -96,6 +105,26 @@ pub enum BinaryOp { Eq, Ne, Lt, Le, Gt, Ge, } +impl Display for BinaryOp { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + match self { + BinaryOp::Add => write!(f, "+"), + BinaryOp::Sub => write!(f, "-"), + BinaryOp::Mul => write!(f, "*"), + BinaryOp::Div => write!(f, "/"), + BinaryOp::Rem => write!(f, "%"), + BinaryOp::And => write!(f, "&&"), + BinaryOp::Or => write!(f, "||"), + BinaryOp::Eq => write!(f, "=="), + BinaryOp::Ne => write!(f, "!="), + BinaryOp::Lt => write!(f, "<"), + BinaryOp::Le => write!(f, "<="), + BinaryOp::Gt => write!(f, ">"), + BinaryOp::Ge => write!(f, ">="), + } + } +} + pub type Spanned = (T, Span); // Clone is needed for type checking since the type checking diff --git a/typing/src/infer.rs b/typing/src/infer.rs index 8ef7df9..8d43b62 100644 --- a/typing/src/infer.rs +++ b/typing/src/infer.rs @@ -10,11 +10,17 @@ use syntax::{ use super::typed::TExpr; +macro_rules! unbox { + ($e:expr) => { + (*$e.0, $e.1) + }; +} + #[derive(Clone, Debug)] struct Infer<'src> { env: HashMap<&'src str, Type>, subst: Vec, - constraints: Vec<(Type, Type)>, + constraints: Vec<(Type, Type, SimpleSpan)>, } impl<'src> Infer<'src> { @@ -38,6 +44,11 @@ impl<'src> Infer<'src> { self.subst.get(i).cloned() } + /// Add new constraint + fn add_constraint(&mut self, t1: Type, t2: Type, span: SimpleSpan) { + self.constraints.push((t1, t2, span)); + } + /// Check if a type variable occurs in a type fn occurs(&self, i: usize, t: Type) -> bool { use Type::*; @@ -137,7 +148,7 @@ impl<'src> Infer<'src> { /// Solve the constraints by unifying them fn solve(&mut self) -> Result<(), String> { - for (t1, t2) in self.constraints.clone().into_iter() { + for (t1, t2, _span) in self.constraints.clone().into_iter() { self.unify(t1, t2)?; } Ok(()) @@ -248,14 +259,15 @@ impl<'src> Infer<'src> { Block { exprs: exprst, void, - ret_ty, + ret_ty: self.substitute(ret_ty), } }, } } /// Infer the type of an expression - fn infer(&mut self, e: Expr<'src>, expected: Type) -> Result, String> { + fn infer(&mut self, e: (Expr<'src>, SimpleSpan), expected: Type) -> Result, String> { + let (e, span) = e; match e { // Literal values // Push the constraint (expected type to be the literal type) and @@ -267,7 +279,7 @@ impl<'src> Infer<'src> { Lit::Num(_) => Type::Num, Lit::Str(_) => Type::Str, }; - self.constraints.push((expected, t)); + self.add_constraint(expected, t, span); Ok(TExpr::Lit(l)) }, @@ -276,36 +288,36 @@ impl<'src> Infer<'src> { Expr::Ident(ref x) => { let t = self.env.get(x) .ok_or(format!("Unbound variable: {}", x))?; - self.constraints.push((expected, t.clone())); + self.add_constraint(expected, t.clone(), span); Ok(TExpr::Ident(x.clone())) } // Unary & binary operators // The type of the left and right hand side are inferred and // the expected type is determined by the operator - Expr::Unary(op, (expr, espan)) => match op { + Expr::Unary(op, e) => match op { // Numeric operators (Num -> Num) UnaryOp::Neg => { - let et = self.infer(*expr, Type::Num)?; - self.constraints.push((expected, Type::Num)); + let et = self.infer(unbox!(e), Type::Num)?; + self.add_constraint(expected, Type::Num, span); Ok(TExpr::Unary { op, - expr: (Box::new(et), espan), + expr: (Box::new(et), e.1), ret_ty: Type::Num, }) }, // Boolean operators (Bool -> Bool) UnaryOp::Not => { - let et = self.infer(*expr, Type::Bool)?; - self.constraints.push((expected, Type::Bool)); + let et = self.infer(unbox!(e), Type::Bool)?; + self.add_constraint(expected, Type::Bool, span); Ok(TExpr::Unary { op, - expr: (Box::new(et), espan), + expr: (Box::new(et), e.1), ret_ty: Type::Bool, }) }, } - Expr::Binary(op, (lhs, lspan), (rhs, rspan)) => match op { + Expr::Binary(op, lhs, rhs) => match op { // Numeric operators (Num -> Num -> Num) BinaryOp::Add | BinaryOp::Sub @@ -313,13 +325,13 @@ impl<'src> Infer<'src> { | BinaryOp::Div | BinaryOp::Rem => { - let lt = self.infer(*lhs, Type::Num)?; - let rt = self.infer(*rhs, Type::Num)?; - self.constraints.push((expected, Type::Num)); + let lt = self.infer(unbox!(lhs), Type::Num)?; + let rt = self.infer(unbox!(rhs), Type::Num)?; + self.add_constraint(expected, Type::Num, span); Ok(TExpr::Binary { op, - lhs: (Box::new(lt), lspan), - rhs: (Box::new(rt), rspan), + lhs: (Box::new(lt), lhs.1), + rhs: (Box::new(rt), rhs.1), ret_ty: Type::Num, }) }, @@ -327,13 +339,13 @@ impl<'src> Infer<'src> { BinaryOp::And | BinaryOp::Or => { - let lt = self.infer(*lhs, Type::Bool)?; - let rt = self.infer(*rhs, Type::Bool)?; - self.constraints.push((expected, Type::Bool)); + let lt = self.infer(unbox!(lhs), Type::Bool)?; + let rt = self.infer(unbox!(rhs), Type::Bool)?; + self.add_constraint(expected, Type::Bool, span); Ok(TExpr::Binary { op, - lhs: (Box::new(lt), lspan), - rhs: (Box::new(rt), rspan), + lhs: (Box::new(lt), lhs.1), + rhs: (Box::new(rt), rhs.1), ret_ty: Type::Bool, }) }, @@ -349,20 +361,20 @@ impl<'src> Infer<'src> { // expected type for both the left and right hand side // so the type on both side have to be the same let t = self.fresh(); - let lt = self.infer(*lhs, t.clone())?; - let rt = self.infer(*rhs, t)?; - self.constraints.push((expected, Type::Bool)); + let lt = self.infer(unbox!(lhs), t.clone())?; + let rt = self.infer(unbox!(rhs), t)?; + self.add_constraint(expected, Type::Bool, span); Ok(TExpr::Binary { op, - lhs: (Box::new(lt), lspan), - rhs: (Box::new(rt), rspan), + lhs: (Box::new(lt), lhs.1), + rhs: (Box::new(rt), rhs.1), ret_ty: Type::Bool, }) }, } // Lambda - Expr::Lambda(args, ret, (b, bspan)) => { + Expr::Lambda(args, ret, b) => { // Get the return type or create a fresh type variable let rt = ret.unwrap_or(self.fresh()); // Fill in the type of the arguments with a fresh type @@ -376,7 +388,7 @@ impl<'src> Infer<'src> { xs.clone().into_iter().for_each(|(x, t)| { env.insert(x, t); }); let mut inf = self.clone(); inf.env = env; - let bt = inf.infer(*b, rt.clone())?; + let bt = inf.infer(unbox!(b), rt.clone())?; // Add the substitutions & constraints from the body // if it doesn't already exist @@ -392,22 +404,22 @@ impl<'src> Infer<'src> { } // Push the constraints - self.constraints.push((expected, Type::Func( + self.add_constraint(expected, Type::Func( xs.clone().into_iter() .map(|x| x.1) .collect(), Box::new(rt.clone()), - ))); + ), span); Ok(TExpr::Lambda { params: xs, - body: (Box::new(bt), bspan), + body: (Box::new(bt), b.1), ret_ty: rt, }) }, // Call - Expr::Call((f, fspan), args) => { + Expr::Call(f, args) => { // Generate fresh types for the arguments let freshes = args.clone().into_iter() .map(|_| self.fresh()) @@ -418,44 +430,41 @@ impl<'src> Infer<'src> { Box::new(expected), ); // Expect the function to have the function type - let ft = self.infer(*f, fsig)?; + let ft = self.infer(unbox!(f), fsig)?; // Infer the arguments let xs = args.into_iter() .zip(freshes.into_iter()) - .map(|((x, xspan), t)| { - let xt = self.infer(x, t)?; - Ok((xt, xspan)) - }) + .map(|(x, t)| Ok((self.infer(x, t)?, span))) .collect::, String>>()?; Ok(TExpr::Call { - func: (Box::new(ft), fspan), + func: (Box::new(ft), f.1), args: xs, }) }, // If - Expr::If { cond: (c, cspan), t: (t, tspan), f: (f, fspan) } => { + Expr::If { cond, t, f } => { // Condition has to be a boolean - let ct = self.infer(*c, Type::Bool)?; + let ct = self.infer(unbox!(cond), Type::Bool)?; // The type of the if expression is the same as the // expected type - let tt = self.infer(*t, expected.clone())?; - let et = self.infer(*f, expected.clone())?; + let tt = self.infer(unbox!(t), expected.clone())?; + let et = self.infer(unbox!(f), expected.clone())?; Ok(TExpr::If { - cond: (Box::new(ct), cspan), - t: (Box::new(tt), tspan), - f: (Box::new(et), fspan), + cond: (Box::new(ct), cond.1), + t: (Box::new(tt), t.1), + f: (Box::new(et), f.1), br_ty: expected, }) }, // Let & define - Expr::Let { name, ty, value: (v, vspan), body: (b, bspan) } => { + Expr::Let { name, ty, value, body } => { // Infer the type of the value let ty = ty.unwrap_or(self.fresh()); - let vt = self.infer(*v, ty.clone())?; + let vt = self.infer(unbox!(value), ty.clone())?; // Create a new environment and add the binding to it // and then use the new environment to infer the body @@ -463,47 +472,58 @@ impl<'src> Infer<'src> { env.insert(name.clone(), ty.clone()); let mut inf = Infer::new(); inf.env = env; - let bt = inf.infer(*b, expected)?; + let bt = inf.infer(unbox!(body), expected.clone())?; Ok(TExpr::Let { name, ty, - value: (Box::new(vt), vspan), - body: (Box::new(bt), bspan), + value: (Box::new(vt), value.1), + body: (Box::new(bt), body.1), }) }, - Expr::Define { name, ty, value: (v, vspan) } => { + Expr::Define { name, ty, value } => { let ty = ty.unwrap_or(self.fresh()); - let vt = self.infer(*v, ty.clone())?; + let vt = self.infer(unbox!(value), ty.clone())?; self.env.insert(name.clone(), ty.clone()); - // Define always returns unit - self.constraints.push((expected, Type::Unit)); - Ok(TExpr::Define { name, ty, - value: (Box::new(vt), vspan), + value: (Box::new(vt), value.1), }) }, // Block Expr::Block { exprs, void } => { // Infer the type of each expression + let mut last = None; + let len = exprs.len(); let xs = exprs.into_iter() - .map(|(x, xspan)| { - let xt = self.infer(*x, expected.clone())?; - Ok((xt, xspan)) + .enumerate() + .map(|(i, x)| { + let t = self.fresh(); + let xt = self.infer(unbox!(x), t.clone())?; + // Save the type of the last expression + if i == len - 1 { + last = Some(t); + } + Ok((xt, x.1)) }) .collect::, String>>()?; - let ret_ty = if void { + let rt = if void || last.is_none() { + // If the block is void or there is no expression, + // the return type is unit + self.add_constraint(expected, Type::Unit, span); Type::Unit } else { + // Otherwise, the return type is the same as the expected type + self.add_constraint(expected.clone(), last.unwrap(), span); expected }; Ok(TExpr::Block { exprs: xs, - void, ret_ty, + void, + ret_ty: rt, }) }, } @@ -520,11 +540,11 @@ pub fn infer_exprs(es: Vec<(Expr, SimpleSpan)>) -> (Vec<(TExpr, SimpleSpan)>, St // Errors let mut errs = vec![]; - for e in es { + for (e, s) in es { let f = inf.fresh(); - let t = inf.infer(e.0, f).unwrap(); - tes.push(Some((t.clone(), e.1))); - tes_nosub.push((t, e.1)); + let t = inf.infer((e, s), f).unwrap(); + tes.push(Some((t.clone(), s))); + tes_nosub.push((t, s)); match inf.solve() { Ok(_) => { diff --git a/typing/src/typed.rs b/typing/src/typed.rs index 93906d7..0d3d70c 100644 --- a/typing/src/typed.rs +++ b/typing/src/typed.rs @@ -1,3 +1,4 @@ +use chumsky::span::SimpleSpan; use syntax::{ expr::{ BinaryOp, @@ -57,4 +58,4 @@ pub enum TExpr<'src> { void: bool, ret_ty: Type, }, -} +} \ No newline at end of file