From f1deab11c977e855e80b13883e4a57bc75be6ef5 Mon Sep 17 00:00:00 2001 From: Jakub Doka Date: Tue, 12 Nov 2024 12:20:08 +0100 Subject: [PATCH] making better peepholes and fixing overoptimization on memory swaps --- lang/README.md | 2 +- lang/src/son.rs | 68 ++++++++++++++++--- ...ts_overwrite_aliasing_overoptimization.txt | 34 +++++----- ...son_tests_storing_into_nullable_struct.txt | 27 ++++---- lang/tests/son_tests_string_flip.txt | 44 ++++++------ 5 files changed, 112 insertions(+), 63 deletions(-) diff --git a/lang/README.md b/lang/README.md index 2442126..ab3bb6c 100644 --- a/lang/README.md +++ b/lang/README.md @@ -1088,7 +1088,7 @@ main := fn(): uint { i += 1 } - return arr[0].u + return arr[2].u } ``` diff --git a/lang/src/son.rs b/lang/src/son.rs index 756f224..36eddcf 100644 --- a/lang/src/son.rs +++ b/lang/src/son.rs @@ -107,13 +107,19 @@ macro_rules! inference { #[derive(Clone)] pub struct Nodes { values: Vec>, + queued_peeps: Vec, free: Nid, lookup: Lookup, } impl Default for Nodes { fn default() -> Self { - Self { values: Default::default(), free: Nid::MAX, lookup: Default::default() } + Self { + values: Default::default(), + queued_peeps: Default::default(), + free: Nid::MAX, + lookup: Default::default(), + } } } @@ -673,6 +679,7 @@ impl Nodes { let id = self.new_node_nop(ty, kind, inps); if let Some(opt) = self.peephole(id, tys) { debug_assert_ne!(opt, id); + self.queued_peeps.clear(); self.lock(opt); self.remove(id); self.unlock(opt); @@ -731,6 +738,7 @@ impl Nodes { fn late_peephole(&mut self, target: Nid, tys: &Types) -> Option { if let Some(id) = self.peephole(target, tys) { + self.queued_peeps.clear(); self.replace(target, id); return None; } @@ -759,6 +767,11 @@ impl Nodes { } if let Some(new) = self.peephole(node, tys) { + let plen = stack.len(); + stack.append(&mut self.queued_peeps); + for &p in &stack[plen..] { + self.lock(p); + } self.replace(node, new); self.push_adjacent_nodes(new, stack); } @@ -1043,18 +1056,14 @@ impl Nodes { K::Phi => { let &[ctrl, lhs, rhs] = self[target].inputs.as_slice() else { unreachable!() }; - if rhs == target { - return Some(lhs); - } - - if lhs == rhs { + if rhs == target || lhs == rhs { return Some(lhs); } if self[lhs].kind == Kind::Stre && self[rhs].kind == Kind::Stre && self[lhs].ty == self[rhs].ty - && !matches!(self[lhs].ty.expand(), ty::Kind::Struct(_)) + && self[lhs].ty.loc(tys) == Loc::Reg && self[lhs].inputs[2] == self[rhs].inputs[2] && self[lhs].inputs[3] == self[rhs].inputs[3] { @@ -1098,11 +1107,18 @@ impl Nodes { } } K::Stre => { + if target == 79 { + std::dbg!(std::backtrace::Backtrace::capture()); + } + let &[_, value, region, store, ..] = self[target].inputs.as_slice() else { unreachable!() }; if self[value].kind == Kind::Load && self[value].inputs[1] == region { + if target == 79 { + std::dbg!(1); + } return Some(store); } @@ -1121,6 +1137,9 @@ impl Nodes { 'eliminate: { if self[target].outputs.is_empty() { + if target == 79 { + std::dbg!(2); + } break 'eliminate; } @@ -1130,6 +1149,9 @@ impl Nodes { for &ele in self[value].outputs.clone().iter().filter(|&&n| n != target) { self[ele].peep_triggers.push(target); } + if target == 79 { + std::dbg!(3); + } break 'eliminate; } @@ -1138,6 +1160,9 @@ impl Nodes { }; if self[stack].ty != self[value].ty || self[stack].kind != Kind::Stck { + if target == 79 { + std::dbg!(3); + } break 'eliminate; } @@ -1162,8 +1187,22 @@ impl Nodes { } let Some(index) = unidentifed.iter().position(|&n| n == contact_point) else { + if target == 79 { + std::dbg!(5); + } break 'eliminate; }; + if self[self[cursor].inputs[1]].kind == Kind::Load + && self[value].outputs.iter().any(|&n| { + self.aclass_index(self[self[cursor].inputs[1]].inputs[1]).0 + == self.aclass_index(self[n].inputs[2]).0 + }) + { + if target == 79 { + std::dbg!(6); + } + break 'eliminate; + } unidentifed.remove(index); saved.push(contact_point); first_store = cursor; @@ -1175,6 +1214,9 @@ impl Nodes { } if !unidentifed.is_empty() { + if target == 79 { + std::dbg!(7); + } break 'eliminate; } @@ -1214,7 +1256,8 @@ impl Nodes { debug_assert_eq!(inps.len(), 4); inps[2] = region; inps[3] = prev_store; - prev_store = self.new_node(self[oper].ty, Kind::Stre, inps, tys); + prev_store = self.new_node_nop(self[oper].ty, Kind::Stre, inps); + self.queued_peeps.push(prev_store); } return Some(prev_store); @@ -1279,6 +1322,15 @@ impl Nodes { if self[cursor].inputs[0] == ctrl && self[cursor].inputs[2] == region && self[cursor].ty == self[target].ty + && (self[self[cursor].inputs[1]].kind != Kind::Load + || (!self[target].outputs.is_empty() + && self[target].outputs.iter().all(|&n| { + self[n].kind != Kind::Stre + || self + .aclass_index(self[self[cursor].inputs[1]].inputs[1]) + .0 + != self.aclass_index(self[n].inputs[2]).0 + }))) { return Some(self[cursor].inputs[1]); } diff --git a/lang/tests/son_tests_overwrite_aliasing_overoptimization.txt b/lang/tests/son_tests_overwrite_aliasing_overoptimization.txt index 8ca6b6f..4441af2 100644 --- a/lang/tests/son_tests_overwrite_aliasing_overoptimization.txt +++ b/lang/tests/son_tests_overwrite_aliasing_overoptimization.txt @@ -1,23 +1,21 @@ main: - ADDI64 r254, r254, -112d - ST r31, r254, 40a, 72h - LI64 r32, 4d - ADDI64 r33, r254, 24d - ADDI64 r34, r254, 0d - ST r32, r254, 24a, 8h - LI64 r35, 1d - ST r35, r254, 32a, 8h - ST r35, r254, 16a, 8h - BMC r33, r34, 16h + ADDI64 r254, r254, -88d + ST r31, r254, 24a, 64h + ADDI64 r32, r254, 0d + LI64 r33, 1d + ST r33, r254, 16a, 8h + LI64 r34, 4d + ST r34, r254, 0a, 8h + ST r33, r254, 8a, 8h JAL r31, r0, :opaque ST r1, r254, 0a, 16h - LD r36, r254, 8a, 8h - LD r37, r254, 16a, 8h - ADD64 r38, r37, r36 - LD r39, r254, 0a, 8h - SUB64 r1, r39, r38 - LD r31, r254, 40a, 72h - ADDI64 r254, r254, 112d + LD r35, r254, 8a, 8h + LD r36, r254, 16a, 8h + ADD64 r37, r36, r35 + LD r38, r254, 0a, 8h + SUB64 r1, r38, r37 + LD r31, r254, 24a, 64h + ADDI64 r254, r254, 88d JALA r0, r31, 0a opaque: ADDI64 r254, r254, -16d @@ -29,6 +27,6 @@ opaque: LD r1, r2, 0a, 16h ADDI64 r254, r254, 16d JALA r0, r31, 0a -code size: 323 +code size: 307 ret: 0 status: Ok(()) diff --git a/lang/tests/son_tests_storing_into_nullable_struct.txt b/lang/tests/son_tests_storing_into_nullable_struct.txt index 3b70861..f007bfe 100644 --- a/lang/tests/son_tests_storing_into_nullable_struct.txt +++ b/lang/tests/son_tests_storing_into_nullable_struct.txt @@ -49,22 +49,19 @@ optional: ADDI64 r254, r254, 16d JALA r0, r31, 0a optionala: - ADDI64 r254, r254, -56d - ADDI64 r7, r254, 0d - LI64 r8, 1d - ADDI64 r9, r254, 16d - ADDI64 r6, r254, 24d - ST r8, r254, 0a, 8h + ADDI64 r254, r254, -48d + ADDI64 r5, r254, 0d + ADDI64 r4, r254, 16d + ST r5, r254, 16a, 8h + LI64 r9, 1d ST r9, r254, 24a, 8h - ADDI64 r3, r6, 8d - BMC r7, r3, 8h - ADDI64 r4, r254, 8d - ST r4, r254, 40a, 8h - LI64 r2, 0d - ST r2, r254, 48a, 8h - BMC r6, r1, 32h - ADDI64 r254, r254, 56d + ADDI64 r12, r254, 8d + ST r12, r254, 32a, 8h + LI64 r11, 0d + ST r11, r254, 40a, 8h + BMC r4, r1, 32h + ADDI64 r254, r254, 48d JALA r0, r31, 0a -code size: 604 +code size: 577 ret: 100 status: Ok(()) diff --git a/lang/tests/son_tests_string_flip.txt b/lang/tests/son_tests_string_flip.txt index 2bea039..33856d5 100644 --- a/lang/tests/son_tests_string_flip.txt +++ b/lang/tests/son_tests_string_flip.txt @@ -1,5 +1,5 @@ main: - ADDI64 r254, r254, -32d + ADDI64 r254, r254, -40d LI64 r7, 1d LI64 r6, 4d LI64 r4, 0d @@ -7,36 +7,38 @@ main: CP r8, r4 6: JNE r8, r6, :0 LI64 r6, 2d + ADDI64 r3, r254, 32d CP r8, r4 - 4: LD r1, r254, 0a, 8h + 4: LD r1, r254, 16a, 8h JNE r8, r7, :1 JMP :2 - 1: MUL64 r11, r8, r6 + 1: MUL64 r12, r8, r6 ADD64 r8, r8, r7 SUB64 r9, r6, r8 MUL64 r9, r9, r6 - CP r2, r4 - 5: JNE r2, r6, :3 + CP r11, r4 + 5: JNE r11, r6, :3 JMP :4 - 3: ADD64 r10, r2, r7 - ADD64 r12, r9, r2 - MULI64 r12, r12, 8d - ADD64 r1, r11, r2 - ADD64 r12, r5, r12 + 3: ADD64 r10, r11, r7 + ADD64 r1, r12, r11 + ADD64 r2, r9, r11 MULI64 r1, r1, 8d + MULI64 r11, r2, 8d ADD64 r1, r5, r1 - BMC r12, r1, 8h - BMC r1, r12, 8h - CP r2, r10 + ADD64 r11, r5, r11 + BMC r1, r3, 8h + BMC r11, r1, 8h + BMC r3, r11, 8h + CP r11, r10 JMP :5 - 0: ADD64 r12, r8, r7 - MULI64 r10, r8, 8d - ADD64 r1, r5, r10 - ST r8, r1, 0a, 8h - CP r8, r12 + 0: ADD64 r2, r8, r7 + MULI64 r12, r8, 8d + ADD64 r3, r5, r12 + ST r8, r3, 0a, 8h + CP r8, r2 JMP :6 - 2: ADDI64 r254, r254, 32d + 2: ADDI64 r254, r254, 40d JALA r0, r31, 0a -code size: 255 -ret: 2 +code size: 271 +ret: 0 status: Ok(())