generic functions work now

This commit is contained in:
Jakub Doka 2024-10-22 07:20:08 +02:00
parent 35d34dca54
commit 2aa5ba9abc
No known key found for this signature in database
GPG key ID: C6E9A89936B8C143
15 changed files with 238 additions and 49 deletions

View file

@ -398,6 +398,15 @@ main := fn(): int {
}
```
#### generic_functions
```hb
add := fn($T: type, a: T, b: T): T return a + b
main := fn(): int {
return add(u32, 2, 2) - add(int, 1, 3)
}
```
### Incomplete Examples
#### comptime_pointers
@ -479,15 +488,6 @@ main := fn(): int {
}
```
#### generic_functions
```hb
add := fn($T: type, a: T, b: T): T return a + b
main := fn(): int {
return add(u32, 2, 2) - add(int, 1, 3)
}
```
#### fb_driver
```hb
arm_fb_ptr := fn(): int return 100
@ -859,15 +859,15 @@ request_page := fn(page_count: u8): ^u8 {
create_back_buffer := fn(total_pages: int): ^u32 {
if total_pages <= 0xFF {
return @bitcast(request_page(total_pages))
return @bitcast(request_page(@intcast(total_pages)))
}
ptr := request_page(255)
remaining := total_pages - 0xFF
loop if remaining <= 0 break else {
if remaining < 0xFF {
request_page(remaining)
_f := request_page(@intcast(remaining))
} else {
request_page(0xFF)
_f := request_page(0xFF)
}
remaining -= 0xFF
}
@ -875,7 +875,7 @@ create_back_buffer := fn(total_pages: int): ^u32 {
}
main := fn(): void {
create_back_buffer(400)
_f := create_back_buffer(400)
return
}
```

View file

@ -281,6 +281,7 @@ impl TokenKind {
Self::Ne => (a != b) as i64,
Self::Band => a & b,
Self::Bor => a | b,
Self::Xor => a ^ b,
Self::Mod => a % b,
Self::Shr => a >> b,
s => todo!("{s}"),

View file

@ -12,7 +12,7 @@ use {
reg, task,
ty::{self, ArrayLen, Loc, Tuple},
vc::{BitSet, Vc},
Comptime, Func, Global, HashMap, Offset, OffsetIter, PLoc, Reloc, Sig, TypeParser,
Comptime, Func, Global, HashMap, Offset, OffsetIter, PLoc, Reloc, Sig, SymKey, TypeParser,
TypedReloc, Types,
},
alloc::{string::String, vec::Vec},
@ -25,6 +25,7 @@ use {
},
hashbrown::hash_map,
regalloc2::VReg,
std::panic,
};
const VOID: Nid = 0;
@ -211,6 +212,8 @@ impl Nodes {
return false;
}
debug_assert!(!matches!(self[target].kind, Kind::Call { .. }));
for i in 0..self[target].inputs.len() {
let inp = self[target].inputs[i];
let index = self[inp].outputs.iter().position(|&p| p == target).unwrap();
@ -1053,8 +1056,8 @@ impl ItemCtx {
}
fn finalize(&mut self) {
self.nodes.unlock(NEVER);
self.nodes.unlock_remove_scope(&core::mem::take(&mut self.scope));
self.nodes.unlock(NEVER);
self.nodes.unlock(MEM);
}
@ -1064,7 +1067,6 @@ impl ItemCtx {
fn emit_body_code(&mut self, sig: Sig, tys: &Types) -> usize {
let mut nodes = core::mem::take(&mut self.nodes);
nodes.visited.clear(nodes.values.len());
let fuc = Function::new(&mut nodes, tys, sig);
let mut ralloc = Regalloc::default(); // TODO: reuse
@ -1078,7 +1080,7 @@ impl ItemCtx {
let options = regalloc2::RegallocOptions {
verbose_log: false,
validate_ssa: false,
validate_ssa: cfg!(debug_assertions),
algorithm: regalloc2::Algorithm::Ion,
};
regalloc2::run_with_ctx(&fuc, &ralloc.env, &options, &mut ralloc.ctx)
@ -1154,7 +1156,8 @@ impl ItemCtx {
Kind::If => {
let &[_, cnd] = node.inputs.as_slice() else { unreachable!() };
if let Kind::BinOp { op } = fuc.nodes[cnd].kind
&& let Some((op, swapped)) = op.cond_op(node.ty.is_signed())
&& let Some((op, swapped)) =
op.cond_op(fuc.nodes[fuc.nodes[cnd].inputs[1]].ty.is_signed())
{
let &[lhs, rhs] = allocs else { unreachable!() };
let &[_, lh, rh] = fuc.nodes[cnd].inputs.as_slice() else {
@ -1742,8 +1745,8 @@ impl TypeParser for Codegen<'_> {
ty::Id::NEVER
}
fn find_local_ty(&mut self, _: Ident) -> Option<ty::Id> {
None
fn find_local_ty(&mut self, ident: Ident) -> Option<ty::Id> {
self.ci.scope.vars.iter().rfind(|v| (v.id == ident && v.value == NEVER)).map(|v| v.ty)
}
}
@ -2220,7 +2223,7 @@ impl<'a> Codegen<'a> {
Expr::Call { func, args, .. } => {
self.ci.call_count += 1;
let ty = self.ty(func);
let ty::Kind::Func(fu) = ty.expand() else {
let ty::Kind::Func(mut fu) = ty.expand() else {
self.report(
func.pos(),
fa!("compiler cant (yet) call '{}'", self.ty_display(ty)),
@ -2228,10 +2231,12 @@ impl<'a> Codegen<'a> {
return Value::NEVER;
};
let Some(sig) = self.compute_signature(&mut fu, func.pos(), args) else {
return Value::NEVER;
};
self.make_func_reachable(fu);
let fuc = &self.tys.ins.funcs[fu as usize];
let sig = fuc.sig.expect("TODO: generic functions");
let ast = &self.files[fuc.file as usize];
let &Expr::Closure { args: cargs, .. } = fuc.expr.get(ast) else { unreachable!() };
@ -2248,8 +2253,13 @@ impl<'a> Codegen<'a> {
}
let mut inps = Vc::from([self.ci.ctrl]);
for ((arg, carg), tyx) in args.iter().zip(cargs).zip(sig.args.range()) {
let ty = self.tys.ins.args[tyx];
let mut tys = sig.args.range();
for (arg, carg) in args.iter().zip(cargs) {
let ty = self.tys.ins.args[tys.next().unwrap()];
if ty == ty::Id::TYPE {
tys.next().unwrap();
continue;
}
let mut value = self.expr_ctx(arg, Ctx::default().with_ty(ty))?;
debug_assert_ne!(self.ci.nodes[value.id].kind, Kind::Stre);
self.assert_ty(arg.pos(), &mut value, ty, fa!("argument {}", carg.name));
@ -2258,6 +2268,7 @@ impl<'a> Codegen<'a> {
inps.push(value.id);
}
let argc = inps.len() as u32 - 1;
for &n in inps.iter().skip(1) {
self.ci.nodes.unlock(n);
}
@ -2286,7 +2297,6 @@ impl<'a> Codegen<'a> {
}
};
let argc = args.len() as u32;
self.ci.ctrl = self.ci.nodes.new_node(sig.ret, Kind::Call { func: fu, argc }, inps);
self.store_mem(VOID, VOID);
@ -2295,7 +2305,7 @@ impl<'a> Codegen<'a> {
}
Expr::Directive { name: "inline", args: [func, args @ ..], .. } => {
let ty = self.ty(func);
let ty::Kind::Func(fu) = ty.expand() else {
let ty::Kind::Func(mut fu) = ty.expand() else {
self.report(
func.pos(),
fa!(
@ -2307,8 +2317,11 @@ impl<'a> Codegen<'a> {
return Value::NEVER;
};
let Some(sig) = self.compute_signature(&mut fu, func.pos(), args) else {
return Value::NEVER;
};
let fuc = &self.tys.ins.funcs[fu as usize];
let sig = fuc.sig.expect("TODO: generic functions");
let ast = &self.files[fuc.file as usize];
let &Expr::Closure { args: cargs, body, .. } = fuc.expr.get(ast) else {
unreachable!()
@ -2327,11 +2340,20 @@ impl<'a> Codegen<'a> {
}
let arg_base = self.ci.scope.vars.len();
for ((arg, carg), tyx) in args.iter().zip(cargs).zip(sig.args.range()) {
let ty = self.tys.ins.args[tyx];
if self.tys.size_of(ty) == 0 {
let mut sig_args = sig.args.range();
for (arg, carg) in args.iter().zip(cargs) {
let ty = self.tys.ins.args[sig_args.next().unwrap()];
if ty == ty::Id::TYPE {
self.ci.scope.vars.push(Variable {
id: carg.id,
ty: self.tys.ins.args[sig_args.next().unwrap()],
ptr: false,
value: NEVER,
});
self.ci.nodes.lock(NEVER);
continue;
}
let mut value = self.expr_ctx(arg, Ctx::default().with_ty(ty))?;
debug_assert_ne!(self.ci.nodes[value.id].kind, Kind::Stre);
debug_assert_ne!(value.id, 0);
@ -2731,6 +2753,69 @@ impl<'a> Codegen<'a> {
}
}
fn compute_signature(&mut self, func: &mut ty::Func, pos: Pos, args: &[Expr]) -> Option<Sig> {
let fuc = &self.tys.ins.funcs[*func as usize];
let fast = self.files[fuc.file as usize].clone();
let &Expr::Closure { args: cargs, ret, .. } = fuc.expr.get(&fast) else {
unreachable!();
};
Some(if let Some(sig) = fuc.sig {
sig
} else {
let arg_base = self.tys.tmp.args.len();
let base = self.ci.scope.vars.len();
for (arg, carg) in args.iter().zip(cargs) {
let ty = self.ty(&carg.ty);
self.tys.tmp.args.push(ty);
let sym = parser::find_symbol(&fast.symbols, carg.id);
let ty = if sym.flags & idfl::COMPTIME == 0 {
// FIXME: could fuck us
ty::Id::UNDECLARED
} else {
debug_assert_eq!(
ty,
ty::Id::TYPE,
"TODO: we dont support anything except type generics"
);
let ty = self.ty(arg);
self.tys.tmp.args.push(ty);
ty
};
self.ci.scope.vars.push(Variable { id: carg.id, ty, ptr: false, value: NEVER });
}
let Some(args) = self.tys.pack_args(arg_base) else {
self.report(pos, "function instance has too many arguments");
return None;
};
let ret = self.ty(ret);
self.ci.scope.vars.truncate(base);
let sym = SymKey::FuncInst(*func, args);
let ct = |ins: &mut crate::TypeIns| {
let func_id = ins.funcs.len();
let fuc = &ins.funcs[*func as usize];
ins.funcs.push(Func {
file: fuc.file,
name: fuc.name,
base: Some(*func),
sig: Some(Sig { args, ret }),
expr: fuc.expr,
..Default::default()
});
ty::Kind::Func(func_id as _).compress()
};
*func = self.tys.syms.get_or_insert(sym, &mut self.tys.ins, ct).expand().inner();
Sig { args, ret }
})
}
fn assign_pattern(&mut self, pat: &Expr, right: Value) {
match *pat {
Expr::Ident { id, .. } => {
@ -2895,14 +2980,22 @@ impl<'a> Codegen<'a> {
let mut sig_args = sig.args.range();
for arg in args.iter() {
let ty = self.tys.ins.args[sig_args.next().unwrap()];
if ty == ty::Id::TYPE {
self.ci.scope.vars.push(Variable {
id: arg.id,
ty: self.tys.ins.args[sig_args.next().unwrap()],
ptr: false,
value: NEVER,
});
self.ci.nodes.lock(NEVER);
continue;
}
let mut deps = Vc::from([VOID]);
if ty.loc(&self.tys) == Loc::Stack {
deps.push(MEM);
}
let value = self.ci.nodes.new_node_nop(ty, Kind::Arg, [VOID]);
self.ci.nodes.lock(value);
let sym = parser::find_symbol(&ast.symbols, arg.id);
assert!(sym.flags & idfl::COMPTIME == 0, "TODO");
let ptr = self.tys.size_of(ty) > 8;
self.ci.scope.vars.push(Variable { id: arg.id, value, ty, ptr });
}
@ -3355,7 +3448,10 @@ impl<'a> Function<'a> {
self.add_instr(nid, ops);
for o in node.outputs.into_iter().rev() {
if self.nodes[o].inputs[0] == nid {
if self.nodes[o].inputs[0] == nid
|| (matches!(self.nodes[o].kind, Kind::Loop | Kind::Region)
&& self.nodes[o].inputs[1] == nid)
{
self.emit_node(o, nid);
}
}
@ -3819,11 +3915,11 @@ mod tests {
arrays;
inline;
idk;
generic_functions;
// Incomplete Examples;
//comptime_pointers;
//generic_types;
//generic_functions;
fb_driver;
// Purely Testing Examples;
@ -3835,8 +3931,8 @@ mod tests {
structs_in_registers;
comptime_function_from_another_file;
inline_test;
//inlined_generic_functions;
//some_generic_code;
inlined_generic_functions;
some_generic_code;
integer_inference_issues;
writing_into_string;
request_page;

View file

@ -3,7 +3,7 @@ continue_and_state_change:
LI64 r8, 4d
LI64 r9, 2d
LI64 r10, 10d
6: JLTU r2, r10, :0
6: JLTS r2, r10, :0
CP r1, r2
JMP :1
0: JNE r2, r9, :2
@ -68,7 +68,7 @@ main:
multiple_breaks:
LI64 r6, 3d
LI64 r5, 10d
4: JLTU r2, r5, :0
4: JLTS r2, r5, :0
CP r1, r2
JMP :1
0: ADDI64 r1, r2, 1d
@ -80,7 +80,7 @@ multiple_breaks:
state_change_in_break:
LI64 r5, 3d
LI64 r6, 10d
4: JLTU r2, r6, :0
4: JLTS r2, r6, :0
CP r1, r2
JMP :1
0: JNE r2, r5, :2

View file

@ -15,7 +15,7 @@ main:
CP r35, r32
CP r36, r32
CP r37, r32
5: JLTU r35, r33, :0
5: JLTS r35, r33, :0
ADDI64 r36, r36, 1d
CP r2, r32
CP r3, r36

View file

@ -0,0 +1,24 @@
add:
ADD64 r1, r2, r3
JALA r0, r31, 0a
add:
ADD32 r1, r2, r3
JALA r0, r31, 0a
main:
ADDI64 r254, r254, -24d
ST r31, r254, 0a, 24h
LI64 r3, 2d
CP r2, r3
JAL r31, r0, :add
CP r32, r1
LI64 r3, 3d
LI64 r2, 1d
JAL r31, r0, :add
ANDI r33, r32, 4294967295d
SUB64 r1, r33, r1
LD r31, r254, 0a, 24h
ADDI64 r254, r254, 24d
JALA r0, r31, 0a
code size: 162
ret: 0
status: Ok(())

View file

@ -4,7 +4,7 @@ main:
LI64 r5, 128d
LI64 r7, 0d
ADDI64 r4, r254, 0d
2: JLTU r7, r5, :0
2: JLTS r7, r5, :0
LD r2, r254, 42a, 1h
ANDI r1, r2, 255d
JMP :1

View file

@ -3,7 +3,7 @@ fib:
ST r31, r254, 0a, 40h
LI64 r1, 1d
LI64 r32, 2d
JGTU r2, r32, :0
JGTS r2, r32, :0
JMP :1
0: CP r33, r2
SUB64 r2, r33, r1

View file

@ -26,7 +26,7 @@ line:
ADDI64 r6, r254, 0d
LD r9, r4, 0a, 8h
LD r11, r2, 0a, 8h
JGTU r11, r9, :0
JGTS r11, r9, :0
JMP :0
0: JALA r0, r31, 0a
main:

View file

@ -0,0 +1,6 @@
main:
LI64 r1, 10d
JALA r0, r31, 0a
code size: 29
ret: 10
status: Ok(())

View file

@ -0,0 +1,50 @@
create_back_buffer:
ADDI64 r254, r254, -56d
ST r31, r254, 0a, 56h
LI64 r32, 255d
JGTS r2, r32, :0
AND r2, r2, r32
JAL r31, r0, :request_page
JMP :1
0: CP r33, r2
LI64 r34, 255d
CP r2, r34
JAL r31, r0, :request_page
LI64 r35, 0d
CP r2, r33
SUB64 r36, r2, r32
5: JGTS r36, r35, :2
JMP :1
2: CP r37, r1
JLTS r36, r32, :3
CP r2, r34
JAL r31, r0, :request_page
JMP :4
3: AND r2, r36, r32
JAL r31, r0, :request_page
4: SUB64 r36, r36, r32
CP r1, r37
JMP :5
1: LD r31, r254, 0a, 56h
ADDI64 r254, r254, 56d
JALA r0, r31, 0a
main:
ADDI64 r254, r254, -8d
ST r31, r254, 0a, 8h
LI64 r2, 400d
JAL r31, r0, :create_back_buffer
LD r31, r254, 0a, 8h
ADDI64 r254, r254, 8d
JALA r0, r31, 0a
request_page:
CP r12, r2
LI64 r5, 12d
LI64 r3, 2d
LI64 r2, 3d
LRA r4, r0, :"\0\u{1}xxxxxxxx\0"
ST r12, r4, 1a, 1h
ECA
JALA r0, r31, 0a
code size: 346
ret: 42
status: Ok(())

View file

@ -0,0 +1,12 @@
main:
ADDI64 r254, r254, -8d
ST r31, r254, 0a, 8h
JAL r31, r0, :some_func
LD r31, r254, 0a, 8h
ADDI64 r254, r254, 8d
JALA r0, r31, 0a
some_func:
JALA r0, r31, 0a
code size: 85
ret: 0
status: Ok(())

View file

@ -17,7 +17,7 @@ sqrt:
ADDI64 r7, r7, -1d
ADD64 r9, r4, r8
SLU64 r9, r9, r7
JLTU r2, r9, :2
JLTS r2, r9, :2
ADD64 r1, r8, r1
SUB64 r2, r2, r9
JMP :2
@ -25,5 +25,5 @@ sqrt:
JMP :3
1: JALA r0, r31, 0a
code size: 188
ret: 14
ret: 15
status: Ok(())

View file

@ -3,7 +3,7 @@ fib:
ST r31, r254, 0a, 32h
CP r32, r2
LI64 r33, 2d
JLTU r2, r33, :0
JLTS r2, r33, :0
CP r34, r32
ADDI64 r2, r34, -1d
JAL r31, r0, :fib

View file

@ -5,10 +5,10 @@ main:
LI64 r9, 1d
LI64 r8, 0d
ADDI64 r6, r254, 0d
4: JLTU r8, r5, :0
4: JLTS r8, r5, :0
LI64 r7, 10d
CP r8, r9
3: JLTU r8, r7, :1
3: JLTS r8, r7, :1
LD r9, r254, 2048a, 1h
ANDI r1, r9, 255d
JMP :2