diff --git a/lang/README.md b/lang/README.md index 81d228b4..80853dc2 100644 --- a/lang/README.md +++ b/lang/README.md @@ -264,6 +264,26 @@ main := fn(): uint { } ``` +#### unions +```hb +Union := union { + i: u32, + f: f32, + + $sconst := 0 + + $new := fn(i: u32): Self { + return .{i} + } +} + +main := fn(): uint { + v := Union.{f: 0} + u := Union.new(Union.sconst) + return v.i + u.i +} +``` + #### nullable_types ```hb main := fn(): uint { diff --git a/lang/src/fmt.rs b/lang/src/fmt.rs index 4ef1508f..b72d6f0d 100644 --- a/lang/src/fmt.rs +++ b/lang/src/fmt.rs @@ -70,7 +70,7 @@ fn token_group(kind: TokenKind) -> TokenGroup { | ShrAss | ShlAss => TG::Assign, DQuote | Quote => TG::String, Slf | Defer | Return | If | Else | Loop | Break | Continue | Fn | Idk | Die | Struct - | Packed | True | False | Null | Match | Enum => TG::Keyword, + | Packed | True | False | Null | Match | Enum | Union => TG::Keyword, } } @@ -345,6 +345,17 @@ impl<'a> Formatter<'a> { }, ) } + Expr::Union { fields, trailing_comma, .. } => self.fmt_fields( + f, + "union", + trailing_comma, + fields, + |s, StructField { name, ty, .. }, f| { + f.write_str(name)?; + f.write_str(": ")?; + s.fmt(ty, f) + }, + ), Expr::Enum { variants, trailing_comma, .. } => self.fmt_fields( f, "enum", diff --git a/lang/src/lexer.rs b/lang/src/lexer.rs index c41d442c..0c621bb6 100644 --- a/lang/src/lexer.rs +++ b/lang/src/lexer.rs @@ -148,6 +148,7 @@ pub enum TokenKind { Struct, Packed, Enum, + Union, True, False, Null, @@ -316,6 +317,7 @@ gen_token_kind! { Struct = b"struct", Packed = b"packed", Enum = b"enum", + Union = b"union", True = b"true", False = b"false", Null = b"null", diff --git a/lang/src/parser.rs b/lang/src/parser.rs index 40a79dbf..5e2b7cf9 100644 --- a/lang/src/parser.rs +++ b/lang/src/parser.rs @@ -386,6 +386,23 @@ impl<'a, 'b> Parser<'a, 'b> { captured: self.collect_captures(prev_boundary, prev_captured), trailing_comma: core::mem::take(&mut self.trailing_sep) || must_trail, }, + T::Union => E::Union { + pos, + fields: self.collect_fields(&mut must_trail, |s| { + if s.lexer.taste().kind != T::Colon { + return Some(None); + } + let name = s.expect_advance(T::Ident)?; + s.expect_advance(T::Colon)?; + Some(Some(StructField { + pos: name.start, + name: s.tok_str(name), + ty: s.expr()?, + })) + })?, + captured: self.collect_captures(prev_boundary, prev_captured), + trailing_comma: core::mem::take(&mut self.trailing_sep) || must_trail, + }, T::Enum => E::Enum { pos, variants: self.collect_fields(&mut must_trail, |s| { @@ -977,6 +994,13 @@ generate_expr! { trailing_comma: bool, packed: bool, }, + /// `'union' LIST('{', ',', '}', Ident ':' Expr)` + Union { + pos: Pos, + fields: FieldList<'a, StructField<'a>>, + captured: &'a [Ident], + trailing_comma: bool, + }, /// `'enum' LIST('{', ',', '}', Ident)` Enum { pos: Pos, diff --git a/lang/src/son.rs b/lang/src/son.rs index 88b312c1..27cc4a50 100644 --- a/lang/src/son.rs +++ b/lang/src/son.rs @@ -16,7 +16,7 @@ use { ty::{ self, Arg, ArrayLen, CompState, ConstData, EnumData, EnumField, FTask, FuncData, GlobalData, Loc, Module, Offset, OffsetIter, OptLayout, Sig, StringRef, StructData, - StructField, SymKey, Tuple, TypeBase, TypeIns, Types, + StructField, SymKey, Tuple, TypeBase, TypeIns, Types, UnionData, }, utils::{BitSet, Ent, Vc}, Ident, @@ -3701,64 +3701,94 @@ impl<'a> Codegen<'a> { .or(ctx.ty.map(|ty| self.tys.inner_of(ty).unwrap_or(ty))); inference!(sty, ctx, self, pos, "struct", ".{...}"); - let ty::Kind::Struct(s) = sty.expand() else { - let inferred = if ty.is_some() { "" } else { "inferred " }; - self.error( + match sty.expand() { + ty::Kind::Union(u) => { + let &[CtorField { pos: fpos, name, value }] = fields else { + return self.error( + pos, + fa!("union initializer needs to have exactly one field"), + ); + }; + + let mem = self.new_stack(pos, sty); + + let Some(index) = self.tys.find_union_field(u, name) else { + self.error( + fpos, + fa!("union '{}' does not have this field", self.ty_display(sty)), + ); + return Value::NEVER; + }; + + let (ty, offset) = (self.tys.union_fields(u)[index].ty, 0); + + let mut value = self.expr_ctx(&value, Ctx::default().with_ty(ty))?; + self.assert_ty(fpos, &mut value, ty, fa!("field {}", name)); + let mem = self.offset(mem, offset); + self.store_mem(mem, ty, value.id); + + Some(Value::ptr(mem).ty(sty)) + } + ty::Kind::Struct(s) => { + let mut offs = OffsetIter::new(s, self.tys) + .into_iter(self.tys) + .map(|(f, o)| (f.ty, o)) + .collect::>(); + let mem = self.new_stack(pos, sty); + for field in fields { + let Some(index) = self.tys.find_struct_field(s, field.name) else { + self.error( + field.pos, + fa!( + "struct '{}' does not have this field", + self.ty_display(sty) + ), + ); + continue; + }; + + let (ty, offset) = + mem::replace(&mut offs[index], (ty::Id::UNDECLARED, field.pos)); + + if ty == ty::Id::UNDECLARED { + self.error(field.pos, "the struct field is already initialized"); + self.error(offset, "previous initialization is here"); + continue; + } + + let mut value = + self.expr_ctx(&field.value, Ctx::default().with_ty(ty))?; + self.assert_ty(field.pos, &mut value, ty, fa!("field {}", field.name)); + let mem = self.offset(mem, offset); + self.store_mem(mem, ty, value.id); + } + + let field_list = self + .tys + .struct_fields(s) + .iter() + .zip(offs) + .filter(|&(_, (ty, _))| ty != ty::Id::UNDECLARED) + .map(|(f, _)| self.tys.names.ident_str(f.name)) + .intersperse(", ") + .collect::(); + + if !field_list.is_empty() { + self.error(pos, fa!("the struct initializer is missing {field_list}")); + } + + Some(Value::ptr(mem).ty(sty)) + } + _ => self.error( pos, fa!( - "the {inferred}type of the constructor is `{}`, \ - but thats not a struct", + "the {}type of the constructor is `{}`, \ + but thats not a struct or union", + if ty.is_some() { "" } else { "inferred " }, self.ty_display(sty) ), - ); - return Value::NEVER; - }; - - // TODO: dont allocate - let mut offs = OffsetIter::new(s, self.tys) - .into_iter(self.tys) - .map(|(f, o)| (f.ty, o)) - .collect::>(); - let mem = self.new_stack(pos, sty); - for field in fields { - let Some(index) = self.tys.find_struct_field(s, field.name) else { - self.error( - field.pos, - fa!("struct '{}' does not have this field", self.ty_display(sty)), - ); - continue; - }; - - let (ty, offset) = - mem::replace(&mut offs[index], (ty::Id::UNDECLARED, field.pos)); - - if ty == ty::Id::UNDECLARED { - self.error(field.pos, "the struct field is already initialized"); - self.error(offset, "previous initialization is here"); - continue; - } - - let mut value = self.expr_ctx(&field.value, Ctx::default().with_ty(ty))?; - self.assert_ty(field.pos, &mut value, ty, fa!("field {}", field.name)); - let mem = self.offset(mem, offset); - self.store_mem(mem, ty, value.id); + ), } - - let field_list = self - .tys - .struct_fields(s) - .iter() - .zip(offs) - .filter(|&(_, (ty, _))| ty != ty::Id::UNDECLARED) - .map(|(f, _)| self.tys.names.ident_str(f.name)) - .intersperse(", ") - .collect::(); - - if !field_list.is_empty() { - self.error(pos, fa!("the struct initializer is missing {field_list}")); - } - - Some(Value::ptr(mem).ty(sty)) } Expr::Block { stmts, .. } => { let base = self.ci.scope.vars.len(); @@ -4216,6 +4246,37 @@ impl<'a> Codegen<'a> { let intrnd = self.tys.names.project(name); self.gen_enum_variant(pos, e, intrnd) } + ty::Kind::Union(u) => { + let TypeBase { ast, file, .. } = *self.tys.ins.unions[u]; + if let Some(f) = + self.tys.find_union_field(u, name).map(|i| &self.tys.union_fields(u)[i]) + { + Some(Value::ptr(vtarget.id).ty(f.ty)) + } else if let Expr::Struct { fields: [.., CommentOr::Or(Err(_))], .. } = + ast.get(&self.files[file.index()]) + && let ty = self.find_type(pos, self.ci.file, file, u.into(), Err(name)) + && let ty::Kind::Func(_) = ty.expand() + { + return Some(Err((ty, vtarget))); + } else { + let field_list = self + .tys + .union_fields(u) + .iter() + .map(|f| self.tys.names.ident_str(f.name)) + .intersperse("', '") + .collect::(); + self.error( + pos, + fa!( + "the '{}' does not have this field, \ + but it does have '{field_list}'", + self.ty_display(tty) + ), + ); + Value::NEVER + } + } ty::Kind::Struct(s) => { let TypeBase { ast, file, .. } = *self.tys.ins.structs[s]; if let Some((offset, ty)) = OffsetIter::offset_of(self.tys, s, name) { @@ -4250,6 +4311,10 @@ impl<'a> Codegen<'a> { let TypeBase { file, .. } = *self.tys.ins.structs[s]; self.find_type_as_value(pos, file, s, Err(name), ctx) } + ty::Kind::Union(u) => { + let TypeBase { file, .. } = *self.tys.ins.unions[u]; + self.find_type_as_value(pos, file, u, Err(name), ctx) + } ty::Kind::Module(m) => self.find_type_as_value(pos, m, m, Err(name), ctx), ty::Kind::Enum(e) => { let intrnd = self.tys.names.project(name); @@ -5700,6 +5765,19 @@ impl<'a> Codegen<'a> { |s, field| EnumField { name: s.tys.names.intern(field.name) }, |s, base| s.ins.enums.push(EnumData { base }), ), + Expr::Union { pos, fields, captured, .. } => self.parse_base_ty( + pos, + expr, + captured, + fields, + sc, + |s| [&mut s.ins.struct_fields, &mut s.tmp.struct_fields], + |s, field| { + let ty = s.parse_ty(sc.anon(), &field.ty); + StructField { name: s.tys.names.intern(field.name), ty } + }, + |s, base| s.ins.unions.push(UnionData { base, ..Default::default() }), + ), Expr::Closure { pos, args, ret, .. } if let Some(name) = sc.name => { let func = FuncData { file: sc.file, @@ -5871,6 +5949,7 @@ mod tests { structs; struct_scopes; enums; + unions; nullable_types; struct_operators; global_variables; diff --git a/lang/src/ty.rs b/lang/src/ty.rs index 31c5d427..afd14085 100644 --- a/lang/src/ty.rs +++ b/lang/src/ty.rs @@ -14,6 +14,7 @@ use { }, hashbrown::hash_map, }; + macro_rules! impl_deref { ($for:ty { $name:ident: $base:ty }) => { impl Deref for $for { @@ -33,11 +34,8 @@ macro_rules! impl_deref { } pub type ArrayLen = u32; - -impl Func { - pub const ECA: Func = Func(u32::MAX); - pub const MAIN: Func = Func(u32::MIN); -} +pub type Offset = u32; +pub type Size = u32; #[derive(Clone, Copy, PartialEq, Eq, Hash, Debug, Default, PartialOrd, Ord)] pub struct Tuple(pub u32); @@ -128,6 +126,11 @@ impl crate::ctx_map::CtxEntry for Id { debug_assert_ne!(en.pos, Pos::MAX); SymKey::Type(en.file, en.pos, en.captured) } + Kind::Union(e) => { + let en = &ctx.unions[e]; + debug_assert_ne!(en.pos, Pos::MAX); + SymKey::Type(en.file, en.pos, en.captured) + } Kind::Ptr(p) => SymKey::Pointer(&ctx.ptrs[p]), Kind::Opt(p) => SymKey::Optional(&ctx.opts[p]), Kind::Func(f) => { @@ -266,8 +269,8 @@ impl Id { Loc::Reg } Kind::Ptr(_) | Kind::Enum(_) | Kind::Builtin(_) => Loc::Reg, - Kind::Struct(_) if tys.size_of(*self) == 0 => Loc::Reg, - Kind::Struct(_) | Kind::Slice(_) | Kind::Opt(_) => Loc::Stack, + Kind::Struct(_) | Kind::Union(_) if tys.size_of(*self) == 0 => Loc::Reg, + Kind::Struct(_) | Kind::Union(_) | Kind::Slice(_) | Kind::Opt(_) => Loc::Stack, c @ (Kind::Func(_) | Kind::Global(_) | Kind::Module(_) | Kind::Const(_)) => { unreachable!("{c:?}") } @@ -431,6 +434,7 @@ type_kind! { Builtin, Struct, Enum, + Union, Ptr, Slice, Opt, @@ -441,6 +445,11 @@ type_kind! { } } +impl Func { + pub const ECA: Func = Func(u32::MAX); + pub const MAIN: Func = Func(u32::MIN); +} + impl Module { pub const MAIN: Self = Self(0); } @@ -527,6 +536,28 @@ impl core::fmt::Display for Display<'_> { f.write_str(file.ident_str(record.name)) } } + TK::Union(idx) => { + let record = &self.tys.ins.unions[idx]; + if record.name.is_null() { + f.write_str("[")?; + idx.fmt(f)?; + f.write_str("]{")?; + for (i, &StructField { name, ty }) in + self.tys.union_fields(idx).iter().enumerate() + { + if i != 0 { + f.write_str(", ")?; + } + f.write_str(self.tys.names.ident_str(name))?; + f.write_str(": ")?; + self.rety(ty).fmt(f)?; + } + f.write_str("}") + } else { + let file = &self.files[record.file.index()]; + f.write_str(file.ident_str(record.name)) + } + } TK::Enum(idx) => { let enm = &self.tys.ins.enums[idx]; debug_assert!(!enm.name.is_null()); @@ -563,9 +594,6 @@ impl core::fmt::Display for Display<'_> { } } -pub type Offset = u32; -pub type Size = u32; - #[derive(PartialEq, Eq, Hash, Clone, Copy)] pub enum SymKey<'a> { Pointer(&'a PtrData), @@ -642,6 +670,15 @@ pub struct EnumData { impl_deref!(EnumData { base: TypeBase }); +#[derive(Default)] +pub struct UnionData { + pub base: TypeBase, + pub size: Cell, + pub align: Cell, +} + +impl_deref!(UnionData { base: TypeBase }); + pub struct StructField { pub name: Ident, pub ty: Id, @@ -741,6 +778,7 @@ pub struct TypeIns { pub consts: EntVec, pub structs: EntVec, pub enums: EntVec, + pub unions: EntVec, pub ptrs: EntVec, pub opts: EntVec, pub slices: EntVec, @@ -779,6 +817,7 @@ impl Types { Kind::NEVER => |_| Ok(()), Kind::Enum(_) | Kind::Struct(_) + | Kind::Union(_) | Kind::Builtin(_) | Kind::Ptr(_) | Kind::Slice(_) @@ -804,6 +843,20 @@ impl Types { Tuple::new(sp, len) } + pub fn union_fields(&self, union: Union) -> &[StructField] { + &self.ins.struct_fields[self.union_field_range(union)] + } + + fn union_field_range(&self, union: Union) -> Range { + let start = self.ins.unions[union].field_start as usize; + let end = self + .ins + .unions + .next(union) + .map_or(self.ins.struct_fields.len(), |s| s.field_start as usize); + start..end + } + pub fn struct_fields(&self, strct: Struct) -> &[StructField] { &self.ins.struct_fields[self.struct_field_range(strct)] } @@ -870,6 +923,16 @@ impl Types { self.ins.structs[stru].size.set(oiter.offset); oiter.offset } + Kind::Union(union) => { + if self.ins.unions[union].size.get() != 0 { + return self.ins.unions[union].size.get(); + } + + let size = + self.union_fields(union).iter().map(|f| self.size_of(f.ty)).max().unwrap_or(0); + self.ins.unions[union].size.set(size); + size + } Kind::Enum(enm) => (self.enum_field_range(enm).len().ilog2() + 7) / 8, Kind::Opt(opt) => { let base = self.ins.opts[opt].base; @@ -886,6 +949,15 @@ impl Types { pub fn align_of(&self, ty: Id) -> Size { match ty.expand() { + Kind::Union(union) => { + if self.ins.unions[union].align.get() != 0 { + return self.ins.unions[union].align.get() as _; + } + let align = + self.union_fields(union).iter().map(|f| self.align_of(f.ty)).max().unwrap_or(1); + self.ins.unions[union].align.set(align.try_into().unwrap()); + align + } Kind::Struct(stru) => { if self.ins.structs[stru].align.get() != 0 { return self.ins.structs[stru].align.get() as _; @@ -957,6 +1029,11 @@ impl Types { self.struct_fields(s).iter().position(|f| f.name == name) } + pub fn find_union_field(&self, u: Union, name: &str) -> Option { + let name = self.names.project(name)?; + self.union_fields(u).iter().position(|f| f.name == name) + } + pub fn clear(&mut self) { self.syms.clear(); self.names.clear(); @@ -976,56 +1053,54 @@ impl Types { debug_assert_eq!(self.tasks.len(), 0); } + fn type_base_of(&self, id: Id) -> Option<&TypeBase> { + Some(match id.expand() { + Kind::Struct(s) => &*self.ins.structs[s], + Kind::Enum(e) => &*self.ins.enums[e], + Kind::Union(e) => &*self.ins.unions[e], + Kind::Builtin(_) + | Kind::Ptr(_) + | Kind::Slice(_) + | Kind::Opt(_) + | Kind::Func(_) + | Kind::Global(_) + | Kind::Module(_) + | Kind::Const(_) => return None, + }) + } + pub fn scope_of<'a>(&self, parent: Id, file: &'a parser::Ast) -> Option<&'a [Expr<'a>]> { - match parent.expand() { - Kind::Struct(s) => { - if let Expr::Struct { fields: [.., CommentOr::Or(Err(scope))], .. } = - self.ins.structs[s].ast.get(file) - { - Some(scope) - } else { - Some(&[]) - } - } - Kind::Enum(e) => { - if let Expr::Enum { variants: [.., CommentOr::Or(Err(scope))], .. } = - self.ins.enums[e].ast.get(file) - { - Some(scope) - } else { - Some(&[]) - } - } - Kind::Module(_) => Some(file.exprs()), - _ => None, + let base = match parent.expand() { + _ if let Some(base) = self.type_base_of(parent) => base, + Kind::Module(_) => return Some(file.exprs()), + _ => return None, + }; + + if let Expr::Struct { fields: [.., CommentOr::Or(Err(scope))], .. } + | Expr::Union { fields: [.., CommentOr::Or(Err(scope))], .. } + | Expr::Enum { variants: [.., CommentOr::Or(Err(scope))], .. } = base.ast.get(file) + { + Some(scope) + } else { + Some(&[]) } } pub fn parent_of(&self, ty: Id) -> Option { - match ty.expand() { - Kind::Struct(s) => Some(self.ins.structs[s].parent), - Kind::Enum(e) => Some(self.ins.enums[e].parent), - _ => None, - } + self.type_base_of(ty).map(|b| b.parent) } pub fn captures_of<'a>(&self, ty: Id, file: &'a parser::Ast) -> Option<(&'a [Ident], Tuple)> { - match ty.expand() { - Kind::Struct(s) => { - let &Expr::Struct { captured, .. } = self.ins.structs[s].ast.get(file) else { - unreachable!() - }; - Some((captured, self.ins.structs[s].captured)) - } - Kind::Enum(e) => { - let &Expr::Enum { captured, .. } = self.ins.enums[e].ast.get(file) else { - unreachable!() - }; - Some((captured, self.ins.enums[e].captured)) - } - _ => None, - } - .inspect(|(a, b)| debug_assert_eq!(a.len(), b.len())) + let base = self.type_base_of(ty)?; + + let (Expr::Struct { captured, .. } + | Expr::Enum { captured, .. } + | Expr::Union { captured, .. }) = *base.ast.get(file) + else { + unreachable!() + }; + debug_assert_eq!(captured.len(), base.captured.len()); + Some((captured, base.captured)) } } diff --git a/lang/tests/son_tests_unions.txt b/lang/tests/son_tests_unions.txt new file mode 100644 index 00000000..a68fed22 --- /dev/null +++ b/lang/tests/son_tests_unions.txt @@ -0,0 +1,11 @@ +main: + ADDI64 r254, r254, -4d + ST r0, r254, 0a, 4h + LD r13, r254, 0a, 4h + ANDI r13, r13, 4294967295d + CP r1, r13 + ADDI64 r254, r254, 4d + JALA r0, r31, 0a +code size: 81 +ret: 0 +status: Ok(())