diff --git a/Cargo.lock b/Cargo.lock index aa71341..0e3d12e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -23,6 +23,16 @@ dependencies = [ "yansi", ] +[[package]] +name = "bin" +version = "0.1.0" +dependencies = [ + "ariadne", + "chumsky", + "syntax", + "typing", +] + [[package]] name = "cc" version = "1.0.79" @@ -54,14 +64,6 @@ dependencies = [ "ahash", ] -[[package]] -name = "holymer" -version = "0.1.0" -dependencies = [ - "ariadne", - "chumsky", -] - [[package]] name = "libc" version = "0.2.140" @@ -96,6 +98,21 @@ dependencies = [ "winapi", ] +[[package]] +name = "syntax" +version = "0.1.0" +dependencies = [ + "chumsky", +] + +[[package]] +name = "typing" +version = "0.1.0" +dependencies = [ + "chumsky", + "syntax", +] + [[package]] name = "unicode-width" version = "0.1.10" diff --git a/Cargo.toml b/Cargo.toml index e3c045d..4a07dc3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,8 +1,7 @@ -[package] -name = "holymer" -version = "0.1.0" -edition = "2021" +[workspace] -[dependencies] -ariadne = "0.2.0" -chumsky = "1.0.0-alpha.3" +members = [ + "bin", + "syntax", + "typing", +] \ No newline at end of file diff --git a/bin/Cargo.toml b/bin/Cargo.toml new file mode 100644 index 0000000..5a8dc75 --- /dev/null +++ b/bin/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "bin" +version = "0.1.0" +edition = "2021" + +[dependencies] +ariadne = "0.2.0" +chumsky = "1.0.0-alpha.3" +syntax = { path = "../syntax" } +typing = { path = "../typing" } diff --git a/src/main.rs b/bin/src/main.rs similarity index 86% rename from src/main.rs rename to bin/src/main.rs index ce55261..72845e5 100644 --- a/src/main.rs +++ b/bin/src/main.rs @@ -1,9 +1,7 @@ use ariadne::{sources, Color, Label, Report, ReportKind}; use chumsky::{Parser, prelude::Input}; -use self::{parse::parser::{lexer, exprs_parser}}; - -pub mod parse; -pub mod typing; +use syntax::parser::{lexer, exprs_parser}; +use typing::infer::infer_exprs; fn main() { let src = " @@ -26,7 +24,13 @@ fn main() { .into_output_errors(); if let Some(ast) = ast.filter(|_| errs.len() + parse_errs.len() == 0) { - println!("{:?}", ast); + let (ast, e) = infer_exprs(ast.0); + if !e.is_empty() { + println!("{:?}", e); + } + if !ast.is_empty() { + println!("{:?}", ast); + } } parse_errs diff --git a/sketch.hlm b/sketch.hlm index db009b4..9c6cf0c 100644 --- a/sketch.hlm +++ b/sketch.hlm @@ -1,25 +1 @@ -mut ret_ty: Option - -if type_check(expr) = Return { - if ret_ty.is_some() && ret_ty != this_ty { - error - } else { - ret_ty = this_ty - } -} - -===== - -{ - if true { - return 1; - } - - if false { - return "Hello"; - } - - do_something(); - - return 4; -} \ No newline at end of file +let add = \x : num, y : num -> num = x + y; \ No newline at end of file diff --git a/src/typing/mod.rs b/src/typing/mod.rs deleted file mode 100644 index 53b10eb..0000000 --- a/src/typing/mod.rs +++ /dev/null @@ -1,2 +0,0 @@ -pub mod ty; -pub mod typed; \ No newline at end of file diff --git a/syntax/Cargo.toml b/syntax/Cargo.toml new file mode 100644 index 0000000..65cdf95 --- /dev/null +++ b/syntax/Cargo.toml @@ -0,0 +1,7 @@ +[package] +name = "syntax" +version = "0.1.0" +edition = "2021" + +[dependencies] +chumsky = "1.0.0-alpha.3" diff --git a/syntax/src/expr.rs b/syntax/src/expr.rs new file mode 100644 index 0000000..e60cf3f --- /dev/null +++ b/syntax/src/expr.rs @@ -0,0 +1,134 @@ +use std::fmt::{ Display, Formatter, self }; +use chumsky::span::SimpleSpan; + +use super::ty::Type; + +#[derive(Clone, Debug, PartialEq)] +pub enum Delim { Paren, Brack, Brace } + +// The tokens of the language. +// 'src is the lifetime of the source code string. +#[derive(Clone, Debug, PartialEq)] +pub enum Token<'src> { + Unit, Bool(bool), Num(f64), Str(&'src str), + Ident(&'src str), + + Add, Sub, Mul, Div, Rem, + Eq, Ne, Lt, Gt, Le, Ge, + And, Or, Not, + + Assign, Comma, Colon, Semicolon, + Open(Delim), Close(Delim), + Lambda, Arrow, + + Let, In, Func, Return, If, Then, Else, +} + +impl<'src> Display for Token<'src> { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + match self { + Token::Unit => write!(f, "()"), + Token::Bool(b) => write!(f, "{}", b), + Token::Num(n) => write!(f, "{}", n), + Token::Str(s) => write!(f, "\"{}\"", s), + Token::Ident(s) => write!(f, "{}", s), + + Token::Add => write!(f, "+"), + Token::Sub => write!(f, "-"), + Token::Mul => write!(f, "*"), + Token::Div => write!(f, "/"), + Token::Rem => write!(f, "%"), + Token::Eq => write!(f, "=="), + Token::Ne => write!(f, "!="), + Token::Lt => write!(f, "<"), + Token::Gt => write!(f, ">"), + Token::Le => write!(f, "<="), + Token::Ge => write!(f, ">="), + Token::And => write!(f, "&&"), + Token::Or => write!(f, "||"), + Token::Not => write!(f, "!"), + + Token::Assign => write!(f, "="), + Token::Comma => write!(f, ","), + Token::Colon => write!(f, ":"), + Token::Semicolon => write!(f, ";"), + Token::Open(d) => write!(f, "{}", match d { + Delim::Paren => "(", + Delim::Brack => "[", + Delim::Brace => "{", + }), + Token::Close(d) => write!(f, "{}", match d { + Delim::Paren => ")", + Delim::Brack => "]", + Delim::Brace => "}", + }), + Token::Lambda => write!(f, "\\"), + Token::Arrow => write!(f, "->"), + + Token::Let => write!(f, "let"), + Token::In => write!(f, "in"), + Token::Func => write!(f, "func"), + Token::Return => write!(f, "return"), + Token::If => write!(f, "if"), + Token::Then => write!(f, "then"), + Token::Else => write!(f, "else"), + } + } +} + +pub type Span = SimpleSpan; + +#[derive(Clone, Debug, PartialEq)] +pub enum Lit<'src> { + Unit, + Bool(bool), + Num(f64), + Str(&'src str), +} + +#[derive(Clone, Debug)] +pub enum UnaryOp { Neg, Not } + +#[derive(Clone, Debug)] +pub enum BinaryOp { + Add, Sub, Mul, Div, Rem, + And, Or, + Eq, Ne, Lt, Le, Gt, Ge, +} + +pub type Spanned = (T, Span); + +// Clone is needed for type checking since the type checking +// algorithm is recursive and sometimes consume the AST. +#[derive(Clone, Debug)] +pub enum Expr<'src> { + Lit(Lit<'src>), + Ident(&'src str), + + Unary(UnaryOp, Spanned>), + Binary(BinaryOp, Spanned>, Spanned>), + + Lambda(Vec<(&'src str, Option)>, Option, Spanned>), + Call(Spanned>, Vec>), + + If { + cond: Spanned>, + t: Spanned>, + f: Spanned>, + }, + Let { + name: &'src str, + ty: Option, + value: Spanned>, + body: Spanned>, + }, + Define { + name: &'src str, + ty: Option, + value: Spanned>, + }, + Block { + exprs: Vec>>, + void: bool, // True if last expression is discarded (ends with semicolon). + }, +} \ No newline at end of file diff --git a/src/parse/mod.rs b/syntax/src/lib.rs similarity index 94% rename from src/parse/mod.rs rename to syntax/src/lib.rs index 0acac3e..907b709 100644 --- a/src/parse/mod.rs +++ b/syntax/src/lib.rs @@ -1,9 +1,11 @@ +pub mod expr; pub mod parser; +pub mod ty; #[cfg(test)] mod tests { use chumsky::prelude::*; - use super::parser::*; + use super::{ expr::*, parser::* }; #[test] fn simple() { diff --git a/src/parse/parser.rs b/syntax/src/parser.rs similarity index 73% rename from src/parse/parser.rs rename to syntax/src/parser.rs index 9dca651..2bd49f1 100644 --- a/src/parse/parser.rs +++ b/syntax/src/parser.rs @@ -1,85 +1,6 @@ -use std::fmt::{ - Display, - Formatter, - self, -}; use chumsky::prelude::*; -use crate::typing::ty::Type; -#[derive(Clone, Debug, PartialEq)] -pub enum Delim { Paren, Brack, Brace } - -// The tokens of the language. -// 'src is the lifetime of the source code string. -#[derive(Clone, Debug, PartialEq)] -pub enum Token<'src> { - Unit, Bool(bool), Num(f64), Str(&'src str), - Ident(&'src str), - - Add, Sub, Mul, Div, Rem, - Eq, Ne, Lt, Gt, Le, Ge, - And, Or, Not, - - Assign, Comma, Colon, Semicolon, - Open(Delim), Close(Delim), - Lambda, Arrow, - - Let, In, Func, Return, If, Then, Else, -} - -impl<'src> Display for Token<'src> { - fn fmt(&self, f: &mut Formatter) -> fmt::Result { - match self { - Token::Unit => write!(f, "()"), - Token::Bool(b) => write!(f, "{}", b), - Token::Num(n) => write!(f, "{}", n), - Token::Str(s) => write!(f, "\"{}\"", s), - Token::Ident(s) => write!(f, "{}", s), - - Token::Add => write!(f, "+"), - Token::Sub => write!(f, "-"), - Token::Mul => write!(f, "*"), - Token::Div => write!(f, "/"), - Token::Rem => write!(f, "%"), - Token::Eq => write!(f, "=="), - Token::Ne => write!(f, "!="), - Token::Lt => write!(f, "<"), - Token::Gt => write!(f, ">"), - Token::Le => write!(f, "<="), - Token::Ge => write!(f, ">="), - Token::And => write!(f, "&&"), - Token::Or => write!(f, "||"), - Token::Not => write!(f, "!"), - - Token::Assign => write!(f, "="), - Token::Comma => write!(f, ","), - Token::Colon => write!(f, ":"), - Token::Semicolon => write!(f, ";"), - Token::Open(d) => write!(f, "{}", match d { - Delim::Paren => "(", - Delim::Brack => "[", - Delim::Brace => "{", - }), - Token::Close(d) => write!(f, "{}", match d { - Delim::Paren => ")", - Delim::Brack => "]", - Delim::Brace => "}", - }), - Token::Lambda => write!(f, "\\"), - Token::Arrow => write!(f, "->"), - - Token::Let => write!(f, "let"), - Token::In => write!(f, "in"), - Token::Func => write!(f, "func"), - Token::Return => write!(f, "return"), - Token::If => write!(f, "if"), - Token::Then => write!(f, "then"), - Token::Else => write!(f, "else"), - } - } -} - -pub type Span = SimpleSpan; +use super::{ expr::*, ty::Type }; pub fn lexer<'src>() -> impl Parser<'src, &'src str, Vec<(Token<'src>, Span)>, extra::Err>> { let num = text::int(10) @@ -159,61 +80,6 @@ pub fn lexer<'src>() -> impl Parser<'src, &'src str, Vec<(Token<'src>, Span)>, e .collect() } -#[derive(Clone, Debug, PartialEq)] -pub enum Lit<'src> { - Unit, - Bool(bool), - Num(f64), - Str(&'src str), -} - -#[derive(Clone, Debug)] -pub enum UnaryOp { Neg, Not } - -#[derive(Clone, Debug)] -pub enum BinaryOp { - Add, Sub, Mul, Div, Rem, - And, Or, - Eq, Ne, Lt, Le, Gt, Ge, -} - -pub type Spanned = (T, Span); - -// Clone is needed for type checking since the type checking -// algorithm is recursive and sometimes consume the AST. -#[derive(Clone, Debug)] -pub enum Expr<'src> { - Lit(Lit<'src>), - Ident(&'src str), - - Unary(UnaryOp, Spanned>), - Binary(BinaryOp, Spanned>, Spanned>), - - Lambda(Vec<(&'src str, Option)>, Spanned>), - Call(Spanned>, Vec>), - - If { - cond: Spanned>, - t: Spanned>, - f: Spanned>, - }, - Let { - name: &'src str, - ty: Option, - value: Spanned>, - body: Spanned>, - }, - Define { - name: &'src str, - ty: Option, - value: Spanned>, - }, - Block { - exprs: Vec>>, - void: bool, // True if last expression is discarded (ends with semicolon). - }, -} - // (a, s) -> (Box::new(a), s) fn boxspan(a: Spanned) -> Spanned> { (Box::new(a.0), a.1) @@ -255,22 +121,26 @@ pub fn expr_parser<'tokens, 'src: 'tokens>() -> impl Parser< ) .map(|e: Spanned| e.0); + // \x : t, y : t -> rt = e let lambda = just(Token::Lambda) .ignore_then( ( - symbol - .then( - just(Token::Colon) - .ignore_then(type_parser()) - .or_not()) - ) - .separated_by(just(Token::Comma)) + symbol.then( + just(Token::Colon) + .ignore_then(type_parser()) + .or_not()) + ).separated_by(just(Token::Comma)) .allow_trailing() .collect::>() ) - .then_ignore(just(Token::Arrow)) + .then( + just(Token::Arrow) + .ignore_then(type_parser()) + .or_not() + ) + .then_ignore(just(Token::Assign)) .then(expr.clone()) - .map(|(args, body)| Expr::Lambda(args, boxspan(body))); + .map(|((args, ret), body)| Expr::Lambda(args, ret, boxspan(body))); // ident (: type)? let bind = symbol @@ -444,7 +314,6 @@ pub fn type_parser<'tokens, 'src: 'tokens>() -> impl Parser< Token::Ident("num") => Type::Num, Token::Ident("str") => Type::Str, Token::Unit => Type::Unit, - Token::Ident(s) => Type::Var(s.to_string()), }; let tys_paren = ty.clone() diff --git a/src/typing/ty.rs b/syntax/src/ty.rs similarity index 72% rename from src/typing/ty.rs rename to syntax/src/ty.rs index b626bea..dace70a 100644 --- a/src/typing/ty.rs +++ b/syntax/src/ty.rs @@ -4,10 +4,10 @@ use std::fmt::{self, Display, Formatter}; #[derive(Clone, Debug, Eq, PartialEq)] pub enum Type { Unit, Bool, Num, Str, + Var(usize), // This type is only used during type inference. Func(Vec, Box), Tuple(Vec), Array(Box), - Var(String), } impl Display for Type { @@ -17,6 +17,7 @@ impl Display for Type { Type::Bool => write!(f, "Bool"), Type::Num => write!(f, "Num"), Type::Str => write!(f, "Str"), + Type::Var(id) => write!(f, "{}", itoa(id)), Type::Func(ref args, ref ret) => { write!(f, "({}", args[0])?; for arg in &args[1..] { @@ -32,7 +33,19 @@ impl Display for Type { write!(f, ")") } Type::Array(ref ty) => write!(f, "[{}]", ty), - Type::Var(ref id) => write!(f, "{}", id), } } +} + +/// Convert a number to a string of lowercase letters +pub fn itoa(i: usize) -> String { + let mut s = String::new(); + let mut i = i; + + while i >= 26 { + s.push((b'a' + (i % 26) as u8) as char); + i /= 26; + } + s.push((b'a' + i as u8) as char); + s } \ No newline at end of file diff --git a/typing/Cargo.toml b/typing/Cargo.toml new file mode 100644 index 0000000..c8c29cd --- /dev/null +++ b/typing/Cargo.toml @@ -0,0 +1,8 @@ +[package] +name = "typing" +version = "0.1.0" +edition = "2021" + +[dependencies] +chumsky = "1.0.0-alpha.3" +syntax = { path = "../syntax" } diff --git a/typing/src/infer.rs b/typing/src/infer.rs new file mode 100644 index 0000000..8ef7df9 --- /dev/null +++ b/typing/src/infer.rs @@ -0,0 +1,569 @@ +use std::collections::HashMap; +use chumsky::span::SimpleSpan; +use syntax::{ + expr::{ + Lit, UnaryOp, BinaryOp, + Expr, + }, + ty::*, +}; + +use super::typed::TExpr; + +#[derive(Clone, Debug)] +struct Infer<'src> { + env: HashMap<&'src str, Type>, + subst: Vec, + constraints: Vec<(Type, Type)>, +} + +impl<'src> Infer<'src> { + fn new() -> Self { + Infer { + env: HashMap::new(), + subst: Vec::new(), + constraints: Vec::new(), + } + } + + /// Generate a fresh type variable + fn fresh(&mut self) -> Type { + let i = self.subst.len(); + self.subst.push(Type::Var(i)); + Type::Var(i) + } + + /// Get a substitution for a type variable + fn subst(&self, i: usize) -> Option { + self.subst.get(i).cloned() + } + + /// Check if a type variable occurs in a type + fn occurs(&self, i: usize, t: Type) -> bool { + use Type::*; + match t { + Unit | Bool | Num | Str => false, + Var(j) => { + if let Some(t) = self.subst(j) { + if t != Var(j) { + return self.occurs(i, t); + } + } + i == j + }, + Func(args, ret) => { + args.into_iter().any(|t| self.occurs(i, t)) || self.occurs(i, *ret) + }, + Tuple(tys) => tys.into_iter().any(|t| self.occurs(i, t)), + Array(ty) => self.occurs(i, *ty), + } + } + + /// Unify two types + fn unify(&mut self, t1: Type, t2: Type) -> Result<(), String> { + use Type::*; + match (t1, t2) { + // Literal types + (Unit, Unit) + | (Bool, Bool) + | (Num, Num) + | (Str, Str) => Ok(()), + + // Variable + (Var(i), Var(j)) if i == j => Ok(()), // Same variables can be unified + (Var(i), t2) => { + // If the substitution is not the variable itself, + // unify the substitution with t2 + if let Some(t) = self.subst(i) { + if t != Var(i) { + return self.unify(t, t2); + } + } + // If the variable occurs in t2 + if self.occurs(i, t2.clone()) { + return Err(format!("Infinite type: '{} = {}", itoa(i), t2)); + } + // Set the substitution + self.subst[i] = t2; + Ok(()) + }, + (t1, Var(i)) => { + if let Some(t) = self.subst(i) { + if t != Var(i) { + return self.unify(t1, t); + } + } + if self.occurs(i, t1.clone()) { + return Err(format!("Infinite type: '{} = {}", itoa(i), t1)); + } + self.subst[i] = t1; + Ok(()) + }, + + // Function + (Func(a1, r1), Func(a2, r2)) => { + // Check the number of arguments + if a1.len() != a2.len() { + return Err(format!("Function argument mismatch: {} != {}", a1.len(), a2.len())); + } + // Unify the arguments + for (a1, a2) in a1.into_iter().zip(a2.into_iter()) { + self.unify(a1, a2)?; + } + // Unify the return types + self.unify(*r1, *r2) + }, + + // Tuple + (Tuple(t1), Tuple(t2)) => { + // Check the number of elements + if t1.len() != t2.len() { + return Err(format!("Tuple element mismatch: {} != {}", t1.len(), t2.len())); + } + // Unify the elements + for (t1, t2) in t1.into_iter().zip(t2.into_iter()) { + self.unify(t1, t2)?; + } + Ok(()) + }, + + // Array + (Array(t1), Array(t2)) => self.unify(*t1, *t2), + + // The rest will be type mismatch + (t1, t2) => Err(format!("Type mismatch: {} != {}", t1, t2)), + } + } + + /// Solve the constraints by unifying them + fn solve(&mut self) -> Result<(), String> { + for (t1, t2) in self.constraints.clone().into_iter() { + self.unify(t1, t2)?; + } + Ok(()) + } + + /// Substitute the type variables with the substitutions + fn substitute(&mut self, t: Type) -> Type { + use Type::*; + match t { + // Only match any type that can contain type variables + Var(i) => { + if let Some(t) = self.subst(i) { + if t != Var(i) { + return self.substitute(t); + } + } + Var(i) + }, + Func(args, ret) => { + Func( + args.into_iter().map(|t| self.substitute(t)).collect(), + Box::new(self.substitute(*ret)), + ) + }, + Tuple(tys) => Tuple(tys.into_iter().map(|t| self.substitute(t)).collect()), + Array(ty) => Array(Box::new(self.substitute(*ty))), + // The rest will be returned as is + _ => t, + } + } + + /// Find a type variable in (typed) expression and substitute them + fn substitute_texp(&mut self, e: TExpr<'src>) -> TExpr<'src> { + use TExpr::*; + match e { + Lit(_) | Ident(_) => e, + Unary { op, expr: (e, lspan), ret_ty } => { + Unary { + op, + expr: (Box::new(self.substitute_texp(*e)), lspan), + ret_ty, + } + }, + Binary { op, lhs: (lhs, lspan), rhs: (rhs, rspan), ret_ty } => { + let lhst = self.substitute_texp(*lhs); + let rhst = self.substitute_texp(*rhs); + Binary { + op, + lhs: (Box::new(lhst), lspan), + rhs: (Box::new(rhst), rspan), + ret_ty: self.substitute(ret_ty), + } + }, + Lambda { params, body: (body, bspan), ret_ty } => { + let bodyt = self.substitute_texp(*body); + let paramst = params.into_iter() + .map(|(name, ty)| (name, self.substitute(ty))) + .collect::>(); + Lambda { + params: paramst, + body: (Box::new(bodyt), bspan), + ret_ty: self.substitute(ret_ty), + } + }, + Call { func: (func, fspan), args } => { + let funct = self.substitute_texp(*func); + let argst = args.into_iter() + .map(|(arg, span)| (self.substitute_texp(arg), span)) + .collect::>(); + Call { + func: (Box::new(funct), fspan), + args: argst, + } + }, + If { cond: (cond, cspan), t: (t, tspan), f: (f, fspan), br_ty } => { + let condt = self.substitute_texp(*cond); + let tt = self.substitute_texp(*t); + let ft = self.substitute_texp(*f); + If { + cond: (Box::new(condt), cspan), + t: (Box::new(tt), tspan), + f: (Box::new(ft), fspan), + br_ty, + } + }, + Let { name, ty, value: (v, vspan), body: (b, bspan) } => { + let vt = self.substitute_texp(*v); + let bt = self.substitute_texp(*b); + Let { + name, + ty: self.substitute(ty), + value: (Box::new(vt), vspan), + body: (Box::new(bt), bspan), + } + }, + Define { name, ty, value: (v, vspan) } => { + let vt = self.substitute_texp(*v); + Define { + name, + ty: self.substitute(ty), + value: (Box::new(vt), vspan), + } + }, + Block { exprs, void, ret_ty } => { + let exprst = exprs.into_iter() + .map(|(e, span)| (self.substitute_texp(e), span)) + .collect::>(); + Block { + exprs: exprst, + void, + ret_ty, + } + }, + } + } + + /// Infer the type of an expression + fn infer(&mut self, e: Expr<'src>, expected: Type) -> Result, String> { + match e { + // Literal values + // Push the constraint (expected type to be the literal type) and + // return the typed expression + Expr::Lit(l) => { + let t = match l { + Lit::Unit => Type::Unit, + Lit::Bool(_) => Type::Bool, + Lit::Num(_) => Type::Num, + Lit::Str(_) => Type::Str, + }; + self.constraints.push((expected, t)); + Ok(TExpr::Lit(l)) + }, + + // Identifiers + // The same as literals but the type is looked up in the environment + Expr::Ident(ref x) => { + let t = self.env.get(x) + .ok_or(format!("Unbound variable: {}", x))?; + self.constraints.push((expected, t.clone())); + Ok(TExpr::Ident(x.clone())) + } + + // Unary & binary operators + // The type of the left and right hand side are inferred and + // the expected type is determined by the operator + Expr::Unary(op, (expr, espan)) => match op { + // Numeric operators (Num -> Num) + UnaryOp::Neg => { + let et = self.infer(*expr, Type::Num)?; + self.constraints.push((expected, Type::Num)); + Ok(TExpr::Unary { + op, + expr: (Box::new(et), espan), + ret_ty: Type::Num, + }) + }, + // Boolean operators (Bool -> Bool) + UnaryOp::Not => { + let et = self.infer(*expr, Type::Bool)?; + self.constraints.push((expected, Type::Bool)); + Ok(TExpr::Unary { + op, + expr: (Box::new(et), espan), + ret_ty: Type::Bool, + }) + }, + } + Expr::Binary(op, (lhs, lspan), (rhs, rspan)) => match op { + // Numeric operators (Num -> Num -> Num) + BinaryOp::Add + | BinaryOp::Sub + | BinaryOp::Mul + | BinaryOp::Div + | BinaryOp::Rem + => { + let lt = self.infer(*lhs, Type::Num)?; + let rt = self.infer(*rhs, Type::Num)?; + self.constraints.push((expected, Type::Num)); + Ok(TExpr::Binary { + op, + lhs: (Box::new(lt), lspan), + rhs: (Box::new(rt), rspan), + ret_ty: Type::Num, + }) + }, + // Boolean operators (Bool -> Bool -> Bool) + BinaryOp::And + | BinaryOp::Or + => { + let lt = self.infer(*lhs, Type::Bool)?; + let rt = self.infer(*rhs, Type::Bool)?; + self.constraints.push((expected, Type::Bool)); + Ok(TExpr::Binary { + op, + lhs: (Box::new(lt), lspan), + rhs: (Box::new(rt), rspan), + ret_ty: Type::Bool, + }) + }, + // Comparison operators ('a -> 'a -> Bool) + BinaryOp::Eq + | BinaryOp::Ne + | BinaryOp::Lt + | BinaryOp::Le + | BinaryOp::Gt + | BinaryOp::Ge + => { + // Create a fresh type variable and then use it as the + // expected type for both the left and right hand side + // so the type on both side have to be the same + let t = self.fresh(); + let lt = self.infer(*lhs, t.clone())?; + let rt = self.infer(*rhs, t)?; + self.constraints.push((expected, Type::Bool)); + Ok(TExpr::Binary { + op, + lhs: (Box::new(lt), lspan), + rhs: (Box::new(rt), rspan), + ret_ty: Type::Bool, + }) + }, + } + + // Lambda + Expr::Lambda(args, ret, (b, bspan)) => { + // Get the return type or create a fresh type variable + let rt = ret.unwrap_or(self.fresh()); + // Fill in the type of the arguments with a fresh type + let xs = args.into_iter() + .map(|(x, t)| (x, t.unwrap_or(self.fresh()))) + .collect::>(); + + // Create a new environment, and add the arguments to it + // and use the new environment to infer the body + let mut env = self.env.clone(); + xs.clone().into_iter().for_each(|(x, t)| { env.insert(x, t); }); + let mut inf = self.clone(); + inf.env = env; + let bt = inf.infer(*b, rt.clone())?; + + // Add the substitutions & constraints from the body + // if it doesn't already exist + for s in inf.subst { + if !self.subst.contains(&s) { + self.subst.push(s); + } + } + for c in inf.constraints { + if !self.constraints.contains(&c) { + self.constraints.push(c); + } + } + + // Push the constraints + self.constraints.push((expected, Type::Func( + xs.clone().into_iter() + .map(|x| x.1) + .collect(), + Box::new(rt.clone()), + ))); + + Ok(TExpr::Lambda { + params: xs, + body: (Box::new(bt), bspan), + ret_ty: rt, + }) + }, + + // Call + Expr::Call((f, fspan), args) => { + // Generate fresh types for the arguments + let freshes = args.clone().into_iter() + .map(|_| self.fresh()) + .collect::>(); + // Create a function type + let fsig = Type::Func( + freshes.clone(), + Box::new(expected), + ); + // Expect the function to have the function type + let ft = self.infer(*f, fsig)?; + // Infer the arguments + let xs = args.into_iter() + .zip(freshes.into_iter()) + .map(|((x, xspan), t)| { + let xt = self.infer(x, t)?; + Ok((xt, xspan)) + }) + .collect::, String>>()?; + + Ok(TExpr::Call { + func: (Box::new(ft), fspan), + args: xs, + }) + }, + + // If + Expr::If { cond: (c, cspan), t: (t, tspan), f: (f, fspan) } => { + // Condition has to be a boolean + let ct = self.infer(*c, Type::Bool)?; + // The type of the if expression is the same as the + // expected type + let tt = self.infer(*t, expected.clone())?; + let et = self.infer(*f, expected.clone())?; + + Ok(TExpr::If { + cond: (Box::new(ct), cspan), + t: (Box::new(tt), tspan), + f: (Box::new(et), fspan), + br_ty: expected, + }) + }, + + // Let & define + Expr::Let { name, ty, value: (v, vspan), body: (b, bspan) } => { + // Infer the type of the value + let ty = ty.unwrap_or(self.fresh()); + let vt = self.infer(*v, ty.clone())?; + + // Create a new environment and add the binding to it + // and then use the new environment to infer the body + let mut env = self.env.clone(); + env.insert(name.clone(), ty.clone()); + let mut inf = Infer::new(); + inf.env = env; + let bt = inf.infer(*b, expected)?; + + Ok(TExpr::Let { + name, ty, + value: (Box::new(vt), vspan), + body: (Box::new(bt), bspan), + }) + }, + Expr::Define { name, ty, value: (v, vspan) } => { + let ty = ty.unwrap_or(self.fresh()); + let vt = self.infer(*v, ty.clone())?; + self.env.insert(name.clone(), ty.clone()); + + // Define always returns unit + self.constraints.push((expected, Type::Unit)); + + Ok(TExpr::Define { + name, ty, + value: (Box::new(vt), vspan), + }) + }, + + // Block + Expr::Block { exprs, void } => { + // Infer the type of each expression + let xs = exprs.into_iter() + .map(|(x, xspan)| { + let xt = self.infer(*x, expected.clone())?; + Ok((xt, xspan)) + }) + .collect::, String>>()?; + + let ret_ty = if void { + Type::Unit + } else { + expected + }; + + Ok(TExpr::Block { + exprs: xs, + void, ret_ty, + }) + }, + } + } +} + +/// Infer a list of expressions +pub fn infer_exprs(es: Vec<(Expr, SimpleSpan)>) -> (Vec<(TExpr, SimpleSpan)>, String) { + let mut inf = Infer::new(); + // Typed expressions + let mut tes = vec![]; + // Typed expressions without substitutions + let mut tes_nosub = vec![]; + // Errors + let mut errs = vec![]; + + for e in es { + let f = inf.fresh(); + let t = inf.infer(e.0, f).unwrap(); + tes.push(Some((t.clone(), e.1))); + tes_nosub.push((t, e.1)); + + match inf.solve() { + Ok(_) => { + // Substitute the type variables for the solved expressions + tes = tes.into_iter() + .map(|te| match te { + Some((t, s)) => { + Some((inf.substitute_texp(t), s)) + }, + None => None, + }) + .collect(); + }, + Err(e) => { + errs.push(e); + // Replace the expression with None + tes.pop(); + tes.push(None); + }, + } + } + + // Union typed expressions, replacing None with the typed expression without substitutions + // None means that the expression has an error + let mut tes_union = vec![]; + for (te, te_nosub) in tes.into_iter().zip(tes_nosub.into_iter()) { + match te { + Some(t) => { + tes_union.push(t); + }, + None => { + tes_union.push(te_nosub); + }, + } + } + + ( + // Renamer::new().process(tes_union), + tes_union, + errs.join("\n") + ) +} \ No newline at end of file diff --git a/typing/src/lib.rs b/typing/src/lib.rs new file mode 100644 index 0000000..4c01380 --- /dev/null +++ b/typing/src/lib.rs @@ -0,0 +1,3 @@ +pub mod infer; +pub mod rename; +pub mod typed; \ No newline at end of file diff --git a/typing/src/rename.rs b/typing/src/rename.rs new file mode 100644 index 0000000..e69de29 diff --git a/src/typing/typed.rs b/typing/src/typed.rs similarity index 90% rename from src/typing/typed.rs rename to typing/src/typed.rs index d2b2452..93906d7 100644 --- a/src/typing/typed.rs +++ b/typing/src/typed.rs @@ -1,9 +1,11 @@ -use super::ty::Type; -use crate::parse::parser::{ - BinaryOp, - UnaryOp, - Lit, - Spanned, +use syntax::{ + expr::{ + BinaryOp, + UnaryOp, + Lit, + Spanned, + }, + ty::Type, }; // Typed version of the expression.