diff --git a/lang/README.md b/lang/README.md index bf2062b..54ea2b6 100644 --- a/lang/README.md +++ b/lang/README.md @@ -593,6 +593,17 @@ main := fn(): uint { ### Purely Testing Examples +#### needless_unwrap +```hb +main := fn(): uint { + always_nn := @as(?^uint, &0) + ptr := @unwrap(always_nn) + always_n := @as(?^uint, null) + ptr = @unwrap(always_n) + return *ptr +} +``` + #### inlining_issues ```hb main := fn(): void { @@ -1349,3 +1360,23 @@ opaque := fn(): Foo { return .(3, 2) } ``` + +#### more_if_opts +```hb +main := fn(): uint { + opq1 := opaque() + opq2 := opaque() + a := 0 + + if opq1 == null { + } else a = *opq1 + if opq1 != null a = *opq1 + //if opq1 == null | opq2 == null { + //} else a = *opq1 + //if opq1 != null & opq2 != null a = *opq1 + + return a +} + +opaque := fn(): ?^uint return null +``` diff --git a/lang/src/son.rs b/lang/src/son.rs index ea05930..94ab709 100644 --- a/lang/src/son.rs +++ b/lang/src/son.rs @@ -18,7 +18,7 @@ use { alloc::{string::String, vec::Vec}, core::{ assert_matches::debug_assert_matches, - cell::RefCell, + cell::{Cell, RefCell}, fmt::{self, Debug, Display, Write}, format_args as fa, mem, ops::{self, Deref}, @@ -145,20 +145,21 @@ impl Nodes { self[target].loop_depth } - fn idepth(&mut self, target: Nid) -> IDomDepth { + fn idepth(&self, target: Nid) -> IDomDepth { if target == VOID { return 0; } - if self[target].depth == 0 { - self[target].depth = match self[target].kind { + if self[target].depth.get() == 0 { + let depth = match self[target].kind { Kind::End | Kind::Start => unreachable!("{:?}", self[target].kind), Kind::Region => { self.idepth(self[target].inputs[0]).max(self.idepth(self[target].inputs[1])) } _ => self.idepth(self[target].inputs[0]), } + 1; + self[target].depth.set(depth); } - self[target].depth + self[target].depth.get() } fn fix_loops(&mut self) { @@ -214,8 +215,10 @@ impl Nodes { return; } - let index = self[0].outputs.iter().position(|&p| p == node).unwrap(); - self[0].outputs.remove(index); + let current = self[node].inputs[0]; + + let index = self[current].outputs.iter().position(|&p| p == node).unwrap(); + self[current].outputs.remove(index); self[node].inputs[0] = deepest; debug_assert!( !self[deepest].outputs.contains(&node) @@ -424,7 +427,7 @@ impl Nodes { self[self[from].inputs[0]].inputs[index - 1] } - fn idom(&mut self, target: Nid) -> Nid { + fn idom(&self, target: Nid) -> Nid { match self[target].kind { Kind::Start => VOID, Kind::End => unreachable!(), @@ -436,7 +439,7 @@ impl Nodes { } } - fn common_dom(&mut self, mut a: Nid, mut b: Nid) -> Nid { + fn common_dom(&self, mut a: Nid, mut b: Nid) -> Nid { while a != b { let [ldepth, rdepth] = [self.idepth(a), self.idepth(b)]; if ldepth >= rdepth { @@ -893,81 +896,18 @@ impl Nodes { return Some(self.new_const(ty, op.apply_unop(value, is_float))); } } - K::Assert { kind, pos } => { - if self[target].ty == ty::Id::VOID { - let &[ctrl, cond] = self[target].inputs.as_slice() else { unreachable!() }; - if let K::CInt { value } = self[cond].kind { - let ty = if value != 0 { - ty::Id::NEVER - } else { - return Some(ctrl); - }; - return Some(self.new_node_nop(ty, K::Assert { kind, pos }, [ctrl, cond])); - } - - 'b: { - let mut cursor = ctrl; - loop { - if cursor == ENTRY { - break 'b; - } - - // TODO: do more inteligent checks on the condition - if self[cursor].kind == Kind::Else - && self[self[cursor].inputs[0]].inputs[1] == cond - { - return Some(ctrl); - } - if self[cursor].kind == Kind::Then - && self[self[cursor].inputs[0]].inputs[1] == cond - { - return Some(self.new_node_nop( - ty::Id::NEVER, - K::Assert { kind, pos }, - [ctrl, cond], - )); - } - - cursor = self.idom(cursor); - } - } - } - } K::If => { if self[target].ty == ty::Id::VOID { - let &[ctrl, cond] = self[target].inputs.as_slice() else { unreachable!() }; - if let K::CInt { value } = self[cond].kind { - let ty = if value == 0 { - ty::Id::LEFT_UNREACHABLE - } else { - ty::Id::RIGHT_UNREACHABLE - }; - return Some(self.new_node_nop(ty, K::If, [ctrl, cond])); - } - - 'b: { - let mut cursor = ctrl; - let ty = loop { - if cursor == ENTRY { - break 'b; - } - - // TODO: do more inteligent checks on the condition - if self[cursor].kind == Kind::Then - && self[self[cursor].inputs[0]].inputs[1] == cond - { - break ty::Id::RIGHT_UNREACHABLE; - } - if self[cursor].kind == Kind::Else - && self[self[cursor].inputs[0]].inputs[1] == cond - { - break ty::Id::LEFT_UNREACHABLE; - } - - cursor = self.idom(cursor); - }; - - return Some(self.new_node_nop(ty, K::If, [ctrl, cond])); + match self.try_opt_cond(target) { + CondOptRes::Unknown => {} + CondOptRes::Known { value, .. } => { + let ty = if value { + ty::Id::RIGHT_UNREACHABLE + } else { + ty::Id::LEFT_UNREACHABLE + }; + return Some(self.new_node_nop(ty, K::If, self[target].inputs.clone())); + } } } } @@ -1284,6 +1224,58 @@ impl Nodes { None } + fn try_opt_cond(&self, target: Nid) -> CondOptRes { + let &[ctrl, cond, ..] = self[target].inputs.as_slice() else { unreachable!() }; + if let Kind::CInt { value } = self[cond].kind { + return CondOptRes::Known { value: value != 0, pin: None }; + } + + let mut cursor = ctrl; + while cursor != ENTRY { + let ctrl = &self[cursor]; + // TODO: do more inteligent checks on the condition + if matches!(ctrl.kind, Kind::Then | Kind::Else) { + let other_cond = self[ctrl.inputs[0]].inputs[1]; + if let Some(value) = self.matches_cond(cond, other_cond) { + return CondOptRes::Known { + value: (ctrl.kind == Kind::Then) ^ !value, + pin: Some(cursor), + }; + } + } + + cursor = self.idom(cursor); + } + + CondOptRes::Unknown + } + + fn matches_cond(&self, to_match: Nid, matches: Nid) -> Option { + use TokenKind as K; + let [tn, mn] = [&self[to_match], &self[matches]]; + match (tn.kind, mn.kind) { + _ if to_match == matches => Some(true), + (Kind::BinOp { op: K::Ne }, Kind::BinOp { op: K::Eq }) + | (Kind::BinOp { op: K::Eq }, Kind::BinOp { op: K::Ne }) + if tn.inputs[1..] == mn.inputs[1..] => + { + Some(false) + } + (_, Kind::BinOp { op: K::Band }) => self + .matches_cond(to_match, mn.inputs[1]) + .or(self.matches_cond(to_match, mn.inputs[2])), + (_, Kind::BinOp { op: K::Bor }) => match ( + self.matches_cond(to_match, mn.inputs[1]), + self.matches_cond(to_match, mn.inputs[2]), + ) { + (None, Some(a)) | (Some(a), None) => Some(a), + (Some(b), Some(a)) if a == b => Some(a), + _ => None, + }, + _ => None, + } + } + fn is_const(&self, id: Nid) -> bool { matches!(self[id].kind, Kind::CInt { .. }) } @@ -1326,7 +1318,6 @@ impl Nodes { self[prev].outputs.swap_remove(index); self[with].outputs.push(target); self.remove(prev); - target } } @@ -1366,7 +1357,7 @@ impl Nodes { write!(out, "{:>4}: ", op.name()) } Kind::Call { func, args: _ } => { - write!(out, "call: {func} {} ", self[node].depth) + write!(out, "call: {func} {} ", self[node].depth.get()) } Kind::Global { global } => write!(out, "glob: {global:<5}"), Kind::Entry => write!(out, "ctrl: {:<5}", "entry"), @@ -1396,7 +1387,7 @@ impl Nodes { while self.visited.set(node) { match self[node].kind { Kind::Start => { - writeln!(out, "start: {}", self[node].depth)?; + writeln!(out, "start: {}", self[node].depth.get())?; let mut cfg_index = Nid::MAX; for o in iter(self, node) { self.basic_blocks_instr(out, o)?; @@ -1415,7 +1406,9 @@ impl Nodes { writeln!( out, "region{node}: {} {} {:?}", - self[node].depth, self[node].loop_depth, self[node].inputs + self[node].depth.get(), + self[node].loop_depth, + self[node].inputs )?; let mut cfg_index = Nid::MAX; for o in iter(self, node) { @@ -1430,7 +1423,9 @@ impl Nodes { writeln!( out, "loop{node}: {} {} {:?}", - self[node].depth, self[node].loop_depth, self[node].outputs + self[node].depth.get(), + self[node].loop_depth, + self[node].outputs )?; let mut cfg_index = Nid::MAX; for o in iter(self, node) { @@ -1448,7 +1443,9 @@ impl Nodes { writeln!( out, "b{node}: {} {} {:?}", - self[node].depth, self[node].loop_depth, self[node].outputs + self[node].depth.get(), + self[node].loop_depth, + self[node].outputs )?; let mut cfg_index = Nid::MAX; for o in iter(self, node) { @@ -1605,6 +1602,11 @@ impl Nodes { } } +enum CondOptRes { + Unknown, + Known { value: bool, pin: Option }, +} + impl ops::Index for Nodes { type Output = Node; @@ -1622,6 +1624,7 @@ impl ops::IndexMut for Nodes { #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)] pub enum AssertKind { NullCheck, + UnwrapCheck, } #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Default)] @@ -1649,11 +1652,6 @@ pub enum Kind { Return, // [ctrl] Die, - // [ctrl, cond] - Assert { - kind: AssertKind, - pos: Pos, - }, // [ctrl] CInt { value: i64, @@ -1678,6 +1676,11 @@ pub enum Kind { func: ty::Func, args: ty::Tuple, }, + // [ctrl, cond, value] + Assert { + kind: AssertKind, + pos: Pos, + }, // [ctrl] Stck, // [ctrl, memory] @@ -1742,7 +1745,7 @@ pub struct Node { ty: ty::Id, offset: Offset, ralloc_backref: RallocBRef, - depth: IDomDepth, + depth: Cell, lock_rc: LockRc, loop_depth: LoopDepth, aclass: AClassId, @@ -2020,7 +2023,9 @@ impl ItemCtx { mem::take(&mut self.ctrl).soft_remove(&mut self.nodes); self.nodes.iter_peeps(1000, stack); + } + fn unlock(&mut self) { self.nodes.unlock(MEM); self.nodes.unlock(NEVER); self.nodes.unlock(LOOPS); @@ -2583,7 +2588,7 @@ impl<'a> Codegen<'a> { Expr::Field { target, name, pos } => { let mut vtarget = self.raw_expr(target)?; self.strip_var(&mut vtarget); - self.unwrap_opt(pos, &mut vtarget); + self.implicit_unwrap(pos, &mut vtarget); let tty = vtarget.ty; if let ty::Kind::Module(m) = tty.expand() { @@ -2657,7 +2662,7 @@ impl<'a> Codegen<'a> { let ctx = Ctx { ty: ctx.ty.map(|ty| self.tys.make_ptr(ty)) }; let mut vl = self.expr_ctx(val, ctx)?; - self.unwrap_opt(val.pos(), &mut vl); + self.implicit_unwrap(val.pos(), &mut vl); let Some(base) = self.tys.base_of(vl.ty) else { self.report( @@ -2753,7 +2758,7 @@ impl<'a> Codegen<'a> { { let mut lhs = self.raw_expr_ctx(left, ctx)?; self.strip_var(&mut lhs); - self.unwrap_opt(left.pos(), &mut lhs); + self.implicit_unwrap(left.pos(), &mut lhs); match lhs.ty.expand() { _ if lhs.ty.is_pointer() @@ -2767,7 +2772,7 @@ impl<'a> Codegen<'a> { self.ci.nodes.unlock(lhs.id); let mut rhs = rhs?; self.strip_var(&mut rhs); - self.unwrap_opt(right.pos(), &mut rhs); + self.implicit_unwrap(right.pos(), &mut rhs); let (ty, aclass) = self.binop_ty(pos, &mut lhs, &mut rhs, op); let inps = [VOID, lhs.id, rhs.id]; let bop = @@ -2896,7 +2901,7 @@ impl<'a> Codegen<'a> { let mut val = self.raw_expr(expr)?; self.strip_var(&mut val); - let Some(ty) = self.tys.inner_of(val.ty) else { + if !val.ty.is_optional() { self.report( expr.pos(), fa!( @@ -2907,8 +2912,7 @@ impl<'a> Codegen<'a> { return Value::NEVER; }; - self.unwrap_opt_unchecked(ty, val.ty, &mut val); - val.ty = ty; + self.explicit_unwrap(expr.pos(), &mut val); Some(val) } Expr::Directive { name: "intcast", args: [expr], pos } => { @@ -4078,30 +4082,76 @@ impl<'a> Codegen<'a> { } fn finalize(&mut self, prev_err_len: usize) -> bool { + use {AssertKind as AK, CondOptRes as CR}; + self.ci.finalize(&mut self.pool.nid_stack, self.tys, self.files); - for (_, node) in self.ci.nodes.iter() { + let mut to_remove = vec![]; + for (id, node) in self.ci.nodes.iter() { let Kind::Assert { kind, pos } = node.kind else { continue }; - match kind { - AssertKind::NullCheck => match node.ty { - ty::Id::NEVER => { - self.report( - pos, - "the value is always null, some checks might need to be inverted", - ); - } - _ => { - self.report( - pos, - "can't prove the value is not 'null', \ - use '@unwrap()' if you believe compiler is stupid, \ - or explicitly check for null and handle it \ - ('if == null { /* handle */ } else { /* use opt */ }')", - ); - } - }, - } + let res = self.ci.nodes.try_opt_cond(id); + + // TODO: highlight the pin position + let msg = match (kind, res) { + (AK::UnwrapCheck, CR::Known { value: false, .. }) => { + "unwrap is not needed since the value is (provably) never null, \ + remove it, or replace with '@as(, )'" + } + (AK::UnwrapCheck, CR::Known { value: true, .. }) => { + "unwrap is incorrect since the value is (provably) always null, \ + make sure your logic is correct" + } + (AK::NullCheck, CR::Known { value: true, .. }) => { + "the value is always null, some checks might need to be inverted" + } + (AK::NullCheck, CR::Unknown) => { + "can't prove the value is not 'null', \ + use '@unwrap()' if you believe compiler is stupid, \ + or explicitly check for null and handle it \ + ('if == null { /* handle */ } else { /* use opt */ }')" + } + (AK::NullCheck, CR::Known { value: false, pin }) => { + to_remove.push((id, pin)); + continue; + } + (AK::UnwrapCheck, CR::Unknown) => { + to_remove.push((id, None)); + continue; + } + }; + self.report(pos, msg); } + to_remove.into_iter().for_each(|(n, pin)| { + let pin = pin.unwrap_or_else(|| { + let mut pin = self.ci.nodes[n].inputs[0]; + while matches!(self.ci.nodes[pin].kind, Kind::Assert { .. }) { + pin = self.ci.nodes[n].inputs[0]; + } + pin + }); + for mut out in self.ci.nodes[n].outputs.clone() { + if self.ci.nodes.is_cfg(out) { + let index = self.ci.nodes[out].inputs.iter().position(|&p| p == n).unwrap(); + self.ci.nodes.modify_input(out, index, self.ci.nodes[n].inputs[0]); + } else { + if !self.ci.nodes[out].kind.is_pinned() { + out = self.ci.nodes.modify_input(out, 0, pin); + } + let index = + self.ci.nodes[out].inputs[1..].iter().position(|&p| p == n).unwrap() + 1; + self.ci.nodes.modify_input(out, index, self.ci.nodes[n].inputs[2]); + } + } + debug_assert!( + self.ci.nodes.values[n as usize] + .as_ref() + .map_or(true, |n| !matches!(n.kind, Kind::Assert { .. })), + "{:?} {:?}", + self.ci.nodes[n], + self.ci.nodes[n].outputs.iter().map(|&o| &self.ci.nodes[o]).collect::>(), + ); + }); + self.ci.unlock(); self.errors.borrow().len() == prev_err_len } @@ -4200,21 +4250,32 @@ impl<'a> Codegen<'a> { } } - fn unwrap_opt(&mut self, pos: Pos, opt: &mut Value) { + fn implicit_unwrap(&mut self, pos: Pos, opt: &mut Value) { + self.unwrap_low(pos, opt, AssertKind::NullCheck); + } + + fn explicit_unwrap(&mut self, pos: Pos, opt: &mut Value) { + self.unwrap_low(pos, opt, AssertKind::UnwrapCheck); + } + + fn unwrap_low(&mut self, pos: Pos, opt: &mut Value, kind: AssertKind) { let Some(ty) = self.tys.inner_of(opt.ty) else { return }; let null_check = self.gen_null_check(*opt, ty, TokenKind::Eq); - // TODO: extract the if check int a fucntion - self.ci.ctrl.set( - self.ci.nodes.new_node( - ty::Id::VOID, - Kind::Assert { kind: AssertKind::NullCheck, pos }, - [self.ci.ctrl.get(), null_check], - ), - &mut self.ci.nodes, - ); let oty = mem::replace(&mut opt.ty, ty); self.unwrap_opt_unchecked(ty, oty, opt); + + // TODO: extract the if check int a fucntion + self.ci.ctrl.set( + self.ci.nodes.new_node(oty, Kind::Assert { kind, pos }, [ + self.ci.ctrl.get(), + null_check, + opt.id, + ]), + &mut self.ci.nodes, + ); + self.ci.nodes.pass_aclass(self.ci.nodes.aclass_index(opt.id).1, self.ci.ctrl.get()); + opt.id = self.ci.ctrl.get(); } fn unwrap_opt_unchecked(&mut self, ty: ty::Id, oty: ty::Id, opt: &mut Value) { @@ -4293,7 +4354,7 @@ impl<'a> Codegen<'a> { if let Some(inner) = self.tys.inner_of(src.ty) && inner.try_upcast(expected) == Some(expected) { - self.unwrap_opt(pos, src); + self.implicit_unwrap(pos, src); return self.assert_ty(pos, src, expected, hint); } @@ -4480,6 +4541,7 @@ mod tests { fb_driver; // Purely Testing Examples; + needless_unwrap; inlining_issues; null_check_test; only_break_loop; @@ -4519,5 +4581,6 @@ mod tests { aliasing_overoptimization; global_aliasing_overptimization; overwrite_aliasing_overoptimization; + more_if_opts; } } diff --git a/lang/src/son/hbvm.rs b/lang/src/son/hbvm.rs index 38aa166..67043b1 100644 --- a/lang/src/son/hbvm.rs +++ b/lang/src/son/hbvm.rs @@ -350,7 +350,7 @@ impl ItemCtx { let &[dst, lhs, rhs] = allocs else { unreachable!() }; self.emit(op(atr(dst), atr(lhs), atr(rhs))); } else if let Some(against) = op.cmp_against() { - let op_ty = fuc.nodes[lh].ty; + let op_ty = fuc.nodes[rh].ty; self.emit(extend(fuc.nodes[lh].ty, fuc.nodes[lh].ty.extend(), 0, 0)); self.emit(extend(fuc.nodes[rh].ty, fuc.nodes[rh].ty.extend(), 1, 1)); @@ -365,7 +365,7 @@ impl ItemCtx { let op_fn = opop.float_cmp(op_ty).unwrap(); self.emit(op_fn(atr(dst), atr(lhs), atr(rhs))); self.emit(instrs::not(atr(dst), atr(dst))); - } else if op_ty.is_integer() { + } else { let op_fn = if op_ty.is_signed() { instrs::cmps } else { instrs::cmpu }; self.emit(op_fn(atr(dst), atr(lhs), atr(rhs))); @@ -373,8 +373,6 @@ impl ItemCtx { if matches!(op, TokenKind::Eq | TokenKind::Lt | TokenKind::Gt) { self.emit(instrs::not(atr(dst), atr(dst))); } - } else { - todo!("unhandled operator: {op}"); } } else { todo!("unhandled operator: {op}"); diff --git a/lang/tests/son_tests_more_if_opts.txt b/lang/tests/son_tests_more_if_opts.txt new file mode 100644 index 0000000..370ebd2 --- /dev/null +++ b/lang/tests/son_tests_more_if_opts.txt @@ -0,0 +1,27 @@ +main: + ADDI64 r254, r254, -16d + ST r31, r254, 0a, 16h + JAL r31, r0, :opaque + CP r32, r1 + JAL r31, r0, :opaque + LI64 r6, 0d + CP r1, r32 + JNE r1, r6, :0 + CP r32, r1 + LI64 r1, 0d + CP r3, r32 + JMP :1 + 0: CP r3, r1 + LD r1, r3, 0a, 8h + 1: JEQ r3, r6, :2 + LD r1, r3, 0a, 8h + JMP :2 + 2: LD r31, r254, 0a, 16h + ADDI64 r254, r254, 16d + JALA r0, r31, 0a +opaque: + LI64 r1, 0d + JALA r0, r31, 0a +code size: 183 +ret: 0 +status: Ok(()) diff --git a/lang/tests/son_tests_needless_unwrap.txt b/lang/tests/son_tests_needless_unwrap.txt new file mode 100644 index 0000000..73b8726 --- /dev/null +++ b/lang/tests/son_tests_needless_unwrap.txt @@ -0,0 +1,6 @@ +test.hb:4:17: unwrap is not needed since the value is (provably) never null, remove it, or replace with '@as(, )' + ptr := @unwrap(always_nn) + ^ +test.hb:6:16: unwrap is incorrect since the value is (provably) always null, make sure your logic is correct + ptr = @unwrap(always_n) + ^