Fix block type inferring

tc
azur 2023-04-14 13:16:15 +07:00
parent 50ae682203
commit 9eb4bf27fb
4 changed files with 135 additions and 76 deletions

View File

@ -5,13 +5,22 @@ use typing::infer::infer_exprs;
fn main() { fn main() {
let src = " let src = "
{ let r = {
let foo = let x =
let a = true in if 0 == 1
let b = false in then {
a + b; let x = true;
foo * 2 if x then 1 else 2
} }
else 34 + {
let foo = 30 in
foo + 5
};
let y = { 1 } * 2;
if 1 + 1 == 2
then x
else y
};
".to_string(); ".to_string();
let filename = "?".to_string(); let filename = "?".to_string();

View File

@ -89,6 +89,15 @@ pub enum Lit<'src> {
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub enum UnaryOp { Neg, Not } pub enum UnaryOp { Neg, Not }
impl Display for UnaryOp {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
match self {
UnaryOp::Neg => write!(f, "-"),
UnaryOp::Not => write!(f, "!"),
}
}
}
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub enum BinaryOp { pub enum BinaryOp {
Add, Sub, Mul, Div, Rem, Add, Sub, Mul, Div, Rem,
@ -96,6 +105,26 @@ pub enum BinaryOp {
Eq, Ne, Lt, Le, Gt, Ge, Eq, Ne, Lt, Le, Gt, Ge,
} }
impl Display for BinaryOp {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
match self {
BinaryOp::Add => write!(f, "+"),
BinaryOp::Sub => write!(f, "-"),
BinaryOp::Mul => write!(f, "*"),
BinaryOp::Div => write!(f, "/"),
BinaryOp::Rem => write!(f, "%"),
BinaryOp::And => write!(f, "&&"),
BinaryOp::Or => write!(f, "||"),
BinaryOp::Eq => write!(f, "=="),
BinaryOp::Ne => write!(f, "!="),
BinaryOp::Lt => write!(f, "<"),
BinaryOp::Le => write!(f, "<="),
BinaryOp::Gt => write!(f, ">"),
BinaryOp::Ge => write!(f, ">="),
}
}
}
pub type Spanned<T> = (T, Span); pub type Spanned<T> = (T, Span);
// Clone is needed for type checking since the type checking // Clone is needed for type checking since the type checking

View File

@ -10,11 +10,17 @@ use syntax::{
use super::typed::TExpr; use super::typed::TExpr;
macro_rules! unbox {
($e:expr) => {
(*$e.0, $e.1)
};
}
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
struct Infer<'src> { struct Infer<'src> {
env: HashMap<&'src str, Type>, env: HashMap<&'src str, Type>,
subst: Vec<Type>, subst: Vec<Type>,
constraints: Vec<(Type, Type)>, constraints: Vec<(Type, Type, SimpleSpan)>,
} }
impl<'src> Infer<'src> { impl<'src> Infer<'src> {
@ -38,6 +44,11 @@ impl<'src> Infer<'src> {
self.subst.get(i).cloned() self.subst.get(i).cloned()
} }
/// Add new constraint
fn add_constraint(&mut self, t1: Type, t2: Type, span: SimpleSpan) {
self.constraints.push((t1, t2, span));
}
/// Check if a type variable occurs in a type /// Check if a type variable occurs in a type
fn occurs(&self, i: usize, t: Type) -> bool { fn occurs(&self, i: usize, t: Type) -> bool {
use Type::*; use Type::*;
@ -137,7 +148,7 @@ impl<'src> Infer<'src> {
/// Solve the constraints by unifying them /// Solve the constraints by unifying them
fn solve(&mut self) -> Result<(), String> { fn solve(&mut self) -> Result<(), String> {
for (t1, t2) in self.constraints.clone().into_iter() { for (t1, t2, _span) in self.constraints.clone().into_iter() {
self.unify(t1, t2)?; self.unify(t1, t2)?;
} }
Ok(()) Ok(())
@ -248,14 +259,15 @@ impl<'src> Infer<'src> {
Block { Block {
exprs: exprst, exprs: exprst,
void, void,
ret_ty, ret_ty: self.substitute(ret_ty),
} }
}, },
} }
} }
/// Infer the type of an expression /// Infer the type of an expression
fn infer(&mut self, e: Expr<'src>, expected: Type) -> Result<TExpr<'src>, String> { fn infer(&mut self, e: (Expr<'src>, SimpleSpan), expected: Type) -> Result<TExpr<'src>, String> {
let (e, span) = e;
match e { match e {
// Literal values // Literal values
// Push the constraint (expected type to be the literal type) and // Push the constraint (expected type to be the literal type) and
@ -267,7 +279,7 @@ impl<'src> Infer<'src> {
Lit::Num(_) => Type::Num, Lit::Num(_) => Type::Num,
Lit::Str(_) => Type::Str, Lit::Str(_) => Type::Str,
}; };
self.constraints.push((expected, t)); self.add_constraint(expected, t, span);
Ok(TExpr::Lit(l)) Ok(TExpr::Lit(l))
}, },
@ -276,36 +288,36 @@ impl<'src> Infer<'src> {
Expr::Ident(ref x) => { Expr::Ident(ref x) => {
let t = self.env.get(x) let t = self.env.get(x)
.ok_or(format!("Unbound variable: {}", x))?; .ok_or(format!("Unbound variable: {}", x))?;
self.constraints.push((expected, t.clone())); self.add_constraint(expected, t.clone(), span);
Ok(TExpr::Ident(x.clone())) Ok(TExpr::Ident(x.clone()))
} }
// Unary & binary operators // Unary & binary operators
// The type of the left and right hand side are inferred and // The type of the left and right hand side are inferred and
// the expected type is determined by the operator // the expected type is determined by the operator
Expr::Unary(op, (expr, espan)) => match op { Expr::Unary(op, e) => match op {
// Numeric operators (Num -> Num) // Numeric operators (Num -> Num)
UnaryOp::Neg => { UnaryOp::Neg => {
let et = self.infer(*expr, Type::Num)?; let et = self.infer(unbox!(e), Type::Num)?;
self.constraints.push((expected, Type::Num)); self.add_constraint(expected, Type::Num, span);
Ok(TExpr::Unary { Ok(TExpr::Unary {
op, op,
expr: (Box::new(et), espan), expr: (Box::new(et), e.1),
ret_ty: Type::Num, ret_ty: Type::Num,
}) })
}, },
// Boolean operators (Bool -> Bool) // Boolean operators (Bool -> Bool)
UnaryOp::Not => { UnaryOp::Not => {
let et = self.infer(*expr, Type::Bool)?; let et = self.infer(unbox!(e), Type::Bool)?;
self.constraints.push((expected, Type::Bool)); self.add_constraint(expected, Type::Bool, span);
Ok(TExpr::Unary { Ok(TExpr::Unary {
op, op,
expr: (Box::new(et), espan), expr: (Box::new(et), e.1),
ret_ty: Type::Bool, ret_ty: Type::Bool,
}) })
}, },
} }
Expr::Binary(op, (lhs, lspan), (rhs, rspan)) => match op { Expr::Binary(op, lhs, rhs) => match op {
// Numeric operators (Num -> Num -> Num) // Numeric operators (Num -> Num -> Num)
BinaryOp::Add BinaryOp::Add
| BinaryOp::Sub | BinaryOp::Sub
@ -313,13 +325,13 @@ impl<'src> Infer<'src> {
| BinaryOp::Div | BinaryOp::Div
| BinaryOp::Rem | BinaryOp::Rem
=> { => {
let lt = self.infer(*lhs, Type::Num)?; let lt = self.infer(unbox!(lhs), Type::Num)?;
let rt = self.infer(*rhs, Type::Num)?; let rt = self.infer(unbox!(rhs), Type::Num)?;
self.constraints.push((expected, Type::Num)); self.add_constraint(expected, Type::Num, span);
Ok(TExpr::Binary { Ok(TExpr::Binary {
op, op,
lhs: (Box::new(lt), lspan), lhs: (Box::new(lt), lhs.1),
rhs: (Box::new(rt), rspan), rhs: (Box::new(rt), rhs.1),
ret_ty: Type::Num, ret_ty: Type::Num,
}) })
}, },
@ -327,13 +339,13 @@ impl<'src> Infer<'src> {
BinaryOp::And BinaryOp::And
| BinaryOp::Or | BinaryOp::Or
=> { => {
let lt = self.infer(*lhs, Type::Bool)?; let lt = self.infer(unbox!(lhs), Type::Bool)?;
let rt = self.infer(*rhs, Type::Bool)?; let rt = self.infer(unbox!(rhs), Type::Bool)?;
self.constraints.push((expected, Type::Bool)); self.add_constraint(expected, Type::Bool, span);
Ok(TExpr::Binary { Ok(TExpr::Binary {
op, op,
lhs: (Box::new(lt), lspan), lhs: (Box::new(lt), lhs.1),
rhs: (Box::new(rt), rspan), rhs: (Box::new(rt), rhs.1),
ret_ty: Type::Bool, ret_ty: Type::Bool,
}) })
}, },
@ -349,20 +361,20 @@ impl<'src> Infer<'src> {
// expected type for both the left and right hand side // expected type for both the left and right hand side
// so the type on both side have to be the same // so the type on both side have to be the same
let t = self.fresh(); let t = self.fresh();
let lt = self.infer(*lhs, t.clone())?; let lt = self.infer(unbox!(lhs), t.clone())?;
let rt = self.infer(*rhs, t)?; let rt = self.infer(unbox!(rhs), t)?;
self.constraints.push((expected, Type::Bool)); self.add_constraint(expected, Type::Bool, span);
Ok(TExpr::Binary { Ok(TExpr::Binary {
op, op,
lhs: (Box::new(lt), lspan), lhs: (Box::new(lt), lhs.1),
rhs: (Box::new(rt), rspan), rhs: (Box::new(rt), rhs.1),
ret_ty: Type::Bool, ret_ty: Type::Bool,
}) })
}, },
} }
// Lambda // Lambda
Expr::Lambda(args, ret, (b, bspan)) => { Expr::Lambda(args, ret, b) => {
// Get the return type or create a fresh type variable // Get the return type or create a fresh type variable
let rt = ret.unwrap_or(self.fresh()); let rt = ret.unwrap_or(self.fresh());
// Fill in the type of the arguments with a fresh type // Fill in the type of the arguments with a fresh type
@ -376,7 +388,7 @@ impl<'src> Infer<'src> {
xs.clone().into_iter().for_each(|(x, t)| { env.insert(x, t); }); xs.clone().into_iter().for_each(|(x, t)| { env.insert(x, t); });
let mut inf = self.clone(); let mut inf = self.clone();
inf.env = env; inf.env = env;
let bt = inf.infer(*b, rt.clone())?; let bt = inf.infer(unbox!(b), rt.clone())?;
// Add the substitutions & constraints from the body // Add the substitutions & constraints from the body
// if it doesn't already exist // if it doesn't already exist
@ -392,22 +404,22 @@ impl<'src> Infer<'src> {
} }
// Push the constraints // Push the constraints
self.constraints.push((expected, Type::Func( self.add_constraint(expected, Type::Func(
xs.clone().into_iter() xs.clone().into_iter()
.map(|x| x.1) .map(|x| x.1)
.collect(), .collect(),
Box::new(rt.clone()), Box::new(rt.clone()),
))); ), span);
Ok(TExpr::Lambda { Ok(TExpr::Lambda {
params: xs, params: xs,
body: (Box::new(bt), bspan), body: (Box::new(bt), b.1),
ret_ty: rt, ret_ty: rt,
}) })
}, },
// Call // Call
Expr::Call((f, fspan), args) => { Expr::Call(f, args) => {
// Generate fresh types for the arguments // Generate fresh types for the arguments
let freshes = args.clone().into_iter() let freshes = args.clone().into_iter()
.map(|_| self.fresh()) .map(|_| self.fresh())
@ -418,44 +430,41 @@ impl<'src> Infer<'src> {
Box::new(expected), Box::new(expected),
); );
// Expect the function to have the function type // Expect the function to have the function type
let ft = self.infer(*f, fsig)?; let ft = self.infer(unbox!(f), fsig)?;
// Infer the arguments // Infer the arguments
let xs = args.into_iter() let xs = args.into_iter()
.zip(freshes.into_iter()) .zip(freshes.into_iter())
.map(|((x, xspan), t)| { .map(|(x, t)| Ok((self.infer(x, t)?, span)))
let xt = self.infer(x, t)?;
Ok((xt, xspan))
})
.collect::<Result<Vec<_>, String>>()?; .collect::<Result<Vec<_>, String>>()?;
Ok(TExpr::Call { Ok(TExpr::Call {
func: (Box::new(ft), fspan), func: (Box::new(ft), f.1),
args: xs, args: xs,
}) })
}, },
// If // If
Expr::If { cond: (c, cspan), t: (t, tspan), f: (f, fspan) } => { Expr::If { cond, t, f } => {
// Condition has to be a boolean // Condition has to be a boolean
let ct = self.infer(*c, Type::Bool)?; let ct = self.infer(unbox!(cond), Type::Bool)?;
// The type of the if expression is the same as the // The type of the if expression is the same as the
// expected type // expected type
let tt = self.infer(*t, expected.clone())?; let tt = self.infer(unbox!(t), expected.clone())?;
let et = self.infer(*f, expected.clone())?; let et = self.infer(unbox!(f), expected.clone())?;
Ok(TExpr::If { Ok(TExpr::If {
cond: (Box::new(ct), cspan), cond: (Box::new(ct), cond.1),
t: (Box::new(tt), tspan), t: (Box::new(tt), t.1),
f: (Box::new(et), fspan), f: (Box::new(et), f.1),
br_ty: expected, br_ty: expected,
}) })
}, },
// Let & define // Let & define
Expr::Let { name, ty, value: (v, vspan), body: (b, bspan) } => { Expr::Let { name, ty, value, body } => {
// Infer the type of the value // Infer the type of the value
let ty = ty.unwrap_or(self.fresh()); let ty = ty.unwrap_or(self.fresh());
let vt = self.infer(*v, ty.clone())?; let vt = self.infer(unbox!(value), ty.clone())?;
// Create a new environment and add the binding to it // Create a new environment and add the binding to it
// and then use the new environment to infer the body // and then use the new environment to infer the body
@ -463,47 +472,58 @@ impl<'src> Infer<'src> {
env.insert(name.clone(), ty.clone()); env.insert(name.clone(), ty.clone());
let mut inf = Infer::new(); let mut inf = Infer::new();
inf.env = env; inf.env = env;
let bt = inf.infer(*b, expected)?; let bt = inf.infer(unbox!(body), expected.clone())?;
Ok(TExpr::Let { Ok(TExpr::Let {
name, ty, name, ty,
value: (Box::new(vt), vspan), value: (Box::new(vt), value.1),
body: (Box::new(bt), bspan), body: (Box::new(bt), body.1),
}) })
}, },
Expr::Define { name, ty, value: (v, vspan) } => { Expr::Define { name, ty, value } => {
let ty = ty.unwrap_or(self.fresh()); let ty = ty.unwrap_or(self.fresh());
let vt = self.infer(*v, ty.clone())?; let vt = self.infer(unbox!(value), ty.clone())?;
self.env.insert(name.clone(), ty.clone()); self.env.insert(name.clone(), ty.clone());
// Define always returns unit
self.constraints.push((expected, Type::Unit));
Ok(TExpr::Define { Ok(TExpr::Define {
name, ty, name, ty,
value: (Box::new(vt), vspan), value: (Box::new(vt), value.1),
}) })
}, },
// Block // Block
Expr::Block { exprs, void } => { Expr::Block { exprs, void } => {
// Infer the type of each expression // Infer the type of each expression
let mut last = None;
let len = exprs.len();
let xs = exprs.into_iter() let xs = exprs.into_iter()
.map(|(x, xspan)| { .enumerate()
let xt = self.infer(*x, expected.clone())?; .map(|(i, x)| {
Ok((xt, xspan)) let t = self.fresh();
let xt = self.infer(unbox!(x), t.clone())?;
// Save the type of the last expression
if i == len - 1 {
last = Some(t);
}
Ok((xt, x.1))
}) })
.collect::<Result<Vec<_>, String>>()?; .collect::<Result<Vec<_>, String>>()?;
let ret_ty = if void { let rt = if void || last.is_none() {
// If the block is void or there is no expression,
// the return type is unit
self.add_constraint(expected, Type::Unit, span);
Type::Unit Type::Unit
} else { } else {
// Otherwise, the return type is the same as the expected type
self.add_constraint(expected.clone(), last.unwrap(), span);
expected expected
}; };
Ok(TExpr::Block { Ok(TExpr::Block {
exprs: xs, exprs: xs,
void, ret_ty, void,
ret_ty: rt,
}) })
}, },
} }
@ -520,11 +540,11 @@ pub fn infer_exprs(es: Vec<(Expr, SimpleSpan)>) -> (Vec<(TExpr, SimpleSpan)>, St
// Errors // Errors
let mut errs = vec![]; let mut errs = vec![];
for e in es { for (e, s) in es {
let f = inf.fresh(); let f = inf.fresh();
let t = inf.infer(e.0, f).unwrap(); let t = inf.infer((e, s), f).unwrap();
tes.push(Some((t.clone(), e.1))); tes.push(Some((t.clone(), s)));
tes_nosub.push((t, e.1)); tes_nosub.push((t, s));
match inf.solve() { match inf.solve() {
Ok(_) => { Ok(_) => {

View File

@ -1,3 +1,4 @@
use chumsky::span::SimpleSpan;
use syntax::{ use syntax::{
expr::{ expr::{
BinaryOp, BinaryOp,
@ -57,4 +58,4 @@ pub enum TExpr<'src> {
void: bool, void: bool,
ret_ty: Type, ret_ty: Type,
}, },
} }